/*
 * Decompiled with CFR 0.152.
 */
package org.tugraz.sysds.runtime.controlprogram.paramserv;

import java.util.concurrent.Callable;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.tugraz.sysds.api.DMLScript;
import org.tugraz.sysds.parser.Statement;
import org.tugraz.sysds.runtime.DMLRuntimeException;
import org.tugraz.sysds.runtime.controlprogram.caching.MatrixObject;
import org.tugraz.sysds.runtime.controlprogram.context.ExecutionContext;
import org.tugraz.sysds.runtime.controlprogram.paramserv.PSWorker;
import org.tugraz.sysds.runtime.controlprogram.paramserv.ParamServer;
import org.tugraz.sysds.runtime.controlprogram.paramserv.ParamservUtils;
import org.tugraz.sysds.runtime.controlprogram.parfor.stat.Timing;
import org.tugraz.sysds.runtime.instructions.cp.ListObject;
import org.tugraz.sysds.utils.Statistics;

/**
 * Parameter-server worker that trains on its local feature/label partition and is executed
 * as a {@link Callable}.
 * <p>
 * For every epoch the worker slices mini-batches of {@code _batchSize} rows out of
 * {@code _features}/{@code _labels}, evaluates the gradient instruction {@code _inst} on each
 * batch and exchanges model/gradients with the {@link ParamServer} {@code _ps}. The exchange
 * happens either after every batch ({@code BATCH}) or once per epoch ({@code EPOCH}).
 */
public class LocalPSWorker
extends PSWorker
implements Callable<Void> {
    protected static final Log LOG = LogFactory.getLog(LocalPSWorker.class.getName());
    private static final long serialVersionUID = 5195390748495357295L;

    /** Divisor used to report byte sizes as kilobytes in the debug log. */
    private static final long KB = 1024L;

    /** No-arg constructor; NOTE(review): presumably needed for (de)serialization — confirm. */
    protected LocalPSWorker() {
    }

    /**
     * Creates a local worker.
     *
     * @param workerID  id used when pulling from / pushing to the parameter server
     * @param updFunc   name of the gradient/update function, passed through to {@link PSWorker}
     * @param freq      how often gradients are pushed ({@code BATCH} or {@code EPOCH})
     * @param epochs    number of passes over the local partition
     * @param batchSize number of rows per mini-batch
     * @param ec        execution context in which the gradient instruction is evaluated
     * @param ps        parameter server to communicate with
     */
    public LocalPSWorker(int workerID, String updFunc, Statement.PSFrequency freq, int epochs, long batchSize, ExecutionContext ec, ParamServer ps) {
        super(workerID, updFunc, freq, epochs, batchSize, ec, ps);
    }

    @Override
    public String getWorkerName() {
        return String.format("Local worker_%d", _workerID);
    }

    /**
     * Runs the training loop for the configured update frequency.
     *
     * @return always {@code null}
     * @throws DMLRuntimeException wrapping any failure (including an unsupported frequency),
     *         with a message naming this worker
     */
    @Override
    public Void call() throws Exception {
        incWorkerNumber();
        try {
            long dataSize = _features.getNumRows();
            // Number of mini-batches per epoch; the last batch may be smaller than _batchSize.
            int batchIter = (int) Math.ceil((double) dataSize / (double) _batchSize);
            switch (_freq) {
                case BATCH:
                    computeBatch(dataSize, batchIter);
                    break;
                case EPOCH:
                    computeEpoch(dataSize, batchIter);
                    break;
                default:
                    throw new DMLRuntimeException(String.format("%s not support update frequency %s", getWorkerName(), _freq));
            }
            if (LOG.isDebugEnabled()) {
                LOG.debug(String.format("%s: job finished.", getWorkerName()));
            }
        }
        catch (Exception e) {
            throw new DMLRuntimeException(String.format("%s failed", getWorkerName()), e);
        }
        return null;
    }

    /**
     * EPOCH frequency: pulls the model once per epoch, updates a local copy after every batch
     * except the last, accrues the gradients of all batches and pushes them once at the end
     * of the epoch.
     */
    private void computeEpoch(long dataSize, int batchIter) {
        for (int i = 0; i < _epochs; ++i) {
            ListObject params = pullModel();
            ListObject accGradients = null;
            for (int j = 0; j < batchIter; ++j) {
                ListObject gradients = computeGradients(params, dataSize, batchIter, i, j);
                // The last batch's gradients are only accrued, not applied locally: the
                // accumulated result is pushed to the server right after this loop.
                boolean localUpdate = j < batchIter - 1;
                accGradients = ParamservUtils.accrueGradients(accGradients, gradients, !localUpdate);
                if (localUpdate) {
                    params = updateModel(params, gradients, i, j, batchIter);
                }
                accNumBatches(1);
            }
            // NOTE(review): if batchIter == 0 (empty partition) accGradients is still null here
            // and null is pushed — confirm ParamServer.push tolerates this.
            pushGradients(accGradients);
            ParamservUtils.cleanupListObject(_ec, "model");
            accNumEpochs(1);
            if (LOG.isDebugEnabled()) {
                LOG.debug(String.format("%s: finished %d epoch.", getWorkerName(), i + 1));
            }
        }
    }

    /**
     * Applies {@code gradients} to the worker-local model copy via
     * {@link ParamServer#updateLocalModel} and returns the updated model.
     */
    private ListObject updateModel(ListObject globalParams, ListObject gradients, int i, int j, int batchIter) {
        Timing tUpd = DMLScript.STATISTICS ? new Timing(true) : null;
        globalParams = _ps.updateLocalModel(_ec, gradients, globalParams);
        accLocalModelUpdateTime(tUpd);
        if (LOG.isDebugEnabled()) {
            // Report in kilobytes, consistent with the other size logs of this worker.
            LOG.debug(String.format("%s: local global parameter [size:%d kb] updated. [Epoch:%d  Total epoch:%d  Iteration:%d  Total iteration:%d]", getWorkerName(), globalParams.getDataSize() / KB, i + 1, _epochs, j + 1, batchIter));
        }
        return globalParams;
    }

    /**
     * BATCH frequency: for every mini-batch pulls the current model, computes the batch
     * gradients and pushes them immediately.
     */
    private void computeBatch(long dataSize, int totalIter) {
        for (int i = 0; i < _epochs; ++i) {
            for (int j = 0; j < totalIter; ++j) {
                ListObject globalParams = pullModel();
                ListObject gradients = computeGradients(globalParams, dataSize, totalIter, i, j);
                pushGradients(gradients);
                ParamservUtils.cleanupListObject(_ec, "model");
                accNumBatches(1);
            }
            accNumEpochs(1);
            if (LOG.isDebugEnabled()) {
                LOG.debug(String.format("%s: finished %d epoch.", getWorkerName(), i + 1));
            }
        }
    }

    /** Pulls the current model from the parameter server for this worker. */
    private ListObject pullModel() {
        ListObject globalParams = _ps.pull(_workerID);
        if (LOG.isDebugEnabled()) {
            LOG.debug(String.format("%s: successfully pull the global parameters [size:%d kb] from ps.", getWorkerName(), globalParams.getDataSize() / KB));
        }
        return globalParams;
    }

    /** Pushes this worker's gradients to the parameter server. */
    private void pushGradients(ListObject gradients) {
        _ps.push(_workerID, gradients);
        if (LOG.isDebugEnabled()) {
            LOG.debug(String.format("%s: successfully push the gradients [size:%d kb] to ps.", getWorkerName(), gradients.getDataSize() / KB));
        }
    }

    /**
     * Computes the gradients of mini-batch {@code j} in epoch {@code i}.
     * <p>
     * Binds {@code params} as "model" and the sliced batch as "features"/"labels" in the
     * execution context, runs {@code _inst} and returns the list bound to {@code _output}.
     * The batch variables are cleaned up afterwards; "model" is left for the caller to clean.
     */
    private ListObject computeGradients(ListObject params, long dataSize, int batchIter, int i, int j) {
        _ec.setVariable("model", params);
        // 1-based row range of batch j; the last batch is truncated at dataSize.
        long begin = (long) j * _batchSize + 1L;
        long end = Math.min((long) (j + 1) * _batchSize, dataSize);
        Timing tSlic = DMLScript.STATISTICS ? new Timing(true) : null;
        MatrixObject bFeatures = ParamservUtils.sliceMatrix(_features, begin, end);
        MatrixObject bLabels = ParamservUtils.sliceMatrix(_labels, begin, end);
        accBatchIndexingTime(tSlic);
        _ec.setVariable("features", bFeatures);
        _ec.setVariable("labels", bLabels);
        if (LOG.isDebugEnabled()) {
            LOG.debug(String.format("%s: got batch data [size:%d kb] of index from %d to %d [last index: %d]. [Epoch:%d  Total epoch:%d  Iteration:%d  Total iteration:%d]", getWorkerName(), bFeatures.getDataSize() / KB + bLabels.getDataSize() / KB, begin, end, dataSize, i + 1, _epochs, j + 1, batchIter));
        }
        Timing tGrad = DMLScript.STATISTICS ? new Timing(true) : null;
        _inst.processInstruction(_ec);
        accGradientComputeTime(tGrad);
        ListObject gradients = _ec.getListObject(_output.getName());
        ParamservUtils.cleanupData(_ec, "features");
        ParamservUtils.cleanupData(_ec, "labels");
        return gradients;
    }

    @Override
    protected void incWorkerNumber() {
        if (DMLScript.STATISTICS) {
            Statistics.incWorkerNumber();
        }
    }

    // The timing accumulators below guard against a null Timing: callers create the Timing
    // only when DMLScript.STATISTICS is set, and the flag is re-read here.

    @Override
    protected void accLocalModelUpdateTime(Timing time) {
        if (DMLScript.STATISTICS && time != null) {
            Statistics.accPSLocalModelUpdateTime((long) time.stop());
        }
    }

    @Override
    protected void accBatchIndexingTime(Timing time) {
        if (DMLScript.STATISTICS && time != null) {
            Statistics.accPSBatchIndexingTime((long) time.stop());
        }
    }

    @Override
    protected void accGradientComputeTime(Timing time) {
        if (DMLScript.STATISTICS && time != null) {
            Statistics.accPSGradientComputeTime((long) time.stop());
        }
    }
}

