/*
 * Decompiled with CFR 0.152.
 */
package org.tugraz.sysds.runtime.instructions.cp;

import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.ThreadFactory;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.apache.commons.lang3.concurrent.BasicThreadFactory;
import org.apache.spark.api.java.function.VoidFunction;
import org.apache.spark.network.server.TransportServer;
import org.apache.spark.util.LongAccumulator;
import org.tugraz.sysds.api.DMLScript;
import org.tugraz.sysds.hops.recompile.Recompiler;
import org.tugraz.sysds.lops.LopProperties;
import org.tugraz.sysds.parser.Statement;
import org.tugraz.sysds.runtime.DMLRuntimeException;
import org.tugraz.sysds.runtime.controlprogram.LocalVariableMap;
import org.tugraz.sysds.runtime.controlprogram.caching.MatrixObject;
import org.tugraz.sysds.runtime.controlprogram.context.ExecutionContext;
import org.tugraz.sysds.runtime.controlprogram.context.SparkExecutionContext;
import org.tugraz.sysds.runtime.controlprogram.paramserv.LocalPSWorker;
import org.tugraz.sysds.runtime.controlprogram.paramserv.LocalParamServer;
import org.tugraz.sysds.runtime.controlprogram.paramserv.ParamServer;
import org.tugraz.sysds.runtime.controlprogram.paramserv.ParamservUtils;
import org.tugraz.sysds.runtime.controlprogram.paramserv.SparkPSBody;
import org.tugraz.sysds.runtime.controlprogram.paramserv.SparkPSWorker;
import org.tugraz.sysds.runtime.controlprogram.paramserv.dp.DataPartitionLocalScheme;
import org.tugraz.sysds.runtime.controlprogram.paramserv.dp.LocalDataPartitioner;
import org.tugraz.sysds.runtime.controlprogram.paramserv.rpc.PSRpcFactory;
import org.tugraz.sysds.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer;
import org.tugraz.sysds.runtime.controlprogram.parfor.stat.Timing;
import org.tugraz.sysds.runtime.instructions.cp.CPOperand;
import org.tugraz.sysds.runtime.instructions.cp.ListObject;
import org.tugraz.sysds.runtime.instructions.cp.ParameterizedBuiltinCPInstruction;
import org.tugraz.sysds.runtime.matrix.data.MatrixBlock;
import org.tugraz.sysds.runtime.matrix.operators.Operator;
import org.tugraz.sysds.runtime.util.ProgramConverter;
import org.tugraz.sysds.utils.Statistics;

public class ParamservBuiltinCPInstruction
extends ParameterizedBuiltinCPInstruction {
    private static final int DEFAULT_BATCH_SIZE = 64;
    private static final Statement.PSFrequency DEFAULT_UPDATE_FREQUENCY = Statement.PSFrequency.EPOCH;
    private static final Statement.PSScheme DEFAULT_SCHEME = Statement.PSScheme.DISJOINT_CONTIGUOUS;
    private static final Statement.PSModeType DEFAULT_MODE = Statement.PSModeType.LOCAL;
    private static final Statement.PSUpdateType DEFAULT_TYPE = Statement.PSUpdateType.ASP;
    private static final boolean LDEBUG = false;

    public ParamservBuiltinCPInstruction(Operator op, LinkedHashMap<String, String> paramsMap, CPOperand out, String opcode, String istr) {
        super(op, paramsMap, out, opcode, istr);
    }

    @Override
    public void processInstruction(ExecutionContext ec) {
        Statement.PSModeType mode = this.getPSMode();
        switch (mode) {
            case LOCAL: {
                this.runLocally(ec, mode);
                break;
            }
            case REMOTE_SPARK: {
                this.runOnSpark((SparkExecutionContext)ec, mode);
                break;
            }
            default: {
                throw new DMLRuntimeException(String.format("Paramserv func: not support mode %s", new Object[]{mode}));
            }
        }
    }

    private void runOnSpark(SparkExecutionContext sec, Statement.PSModeType mode) {
        Timing tSetup = DMLScript.STATISTICS ? new Timing(true) : null;
        int workerNum = this.getWorkerNum(mode);
        String updFunc = this.getParam("upd");
        String aggFunc = this.getParam("agg");
        LocalVariableMap newVarsMap = this.createVarsMap(sec);
        ExecutionContext newEC = ParamservUtils.createExecutionContext(sec, newVarsMap, updFunc, aggFunc, 1);
        ExecutionContext aggServiceEC = ParamservUtils.copyExecutionContext(newEC, 1).get(0);
        ListObject model = sec.getListObject(this.getParam("model"));
        ParamServer ps = ParamservBuiltinCPInstruction.createPS(mode, aggFunc, this.getUpdateType(), workerNum, model, aggServiceEC);
        String host = sec.getSparkContext().getConf().get("spark.driver.host");
        TransportServer server = PSRpcFactory.createServer(sec.getSparkContext().getConf(), (LocalParamServer)ps, host);
        Recompiler.recompileProgramBlockHierarchy2Forced(newEC.getProgram().getProgramBlocks(), 0L, new HashSet<String>(), LopProperties.ExecType.CP);
        SparkPSBody body = new SparkPSBody(newEC);
        HashMap<String, byte[]> clsMap = new HashMap<String, byte[]>();
        String program = ProgramConverter.serializeSparkPSBody(body, clsMap);
        LongAccumulator aSetup = sec.getSparkContext().sc().longAccumulator("setup");
        LongAccumulator aWorker = sec.getSparkContext().sc().longAccumulator("workersNum");
        LongAccumulator aUpdate = sec.getSparkContext().sc().longAccumulator("modelUpdate");
        LongAccumulator aIndex = sec.getSparkContext().sc().longAccumulator("batchIndex");
        LongAccumulator aGrad = sec.getSparkContext().sc().longAccumulator("gradCompute");
        LongAccumulator aRPC = sec.getSparkContext().sc().longAccumulator("rpcRequest");
        LongAccumulator aBatch = sec.getSparkContext().sc().longAccumulator("numBatches");
        LongAccumulator aEpoch = sec.getSparkContext().sc().longAccumulator("numEpochs");
        SparkPSWorker worker = new SparkPSWorker(this.getParam("upd"), this.getParam("agg"), this.getFrequency(), this.getEpochs(), this.getBatchSize(), program, clsMap, sec.getSparkContext().getConf(), server.getPort(), aSetup, aWorker, aUpdate, aIndex, aGrad, aRPC, aBatch, aEpoch);
        if (DMLScript.STATISTICS) {
            Statistics.accPSSetupTime((long)tSetup.stop());
        }
        MatrixObject features = sec.getMatrixObject(this.getParam("features"));
        MatrixObject labels = sec.getMatrixObject(this.getParam("labels"));
        try {
            ParamservUtils.doPartitionOnSpark(sec, features, labels, this.getScheme(), workerNum).foreach((VoidFunction)worker);
        }
        catch (Exception e) {
            throw new DMLRuntimeException("Paramserv function failed: ", e);
        }
        finally {
            server.close();
        }
        if (DMLScript.STATISTICS) {
            Statistics.accPSSetupTime(aSetup.value());
            Statistics.incWorkerNumber(aWorker.value());
            Statistics.accPSLocalModelUpdateTime(aUpdate.value());
            Statistics.accPSBatchIndexingTime(aIndex.value());
            Statistics.accPSGradientComputeTime(aGrad.value());
            Statistics.accPSRpcRequestTime(aRPC.value());
        }
        sec.setVariable(this.output.getName(), ps.getResult());
    }

    private void runLocally(ExecutionContext ec, Statement.PSModeType mode) {
        Timing tSetup = DMLScript.STATISTICS ? new Timing(true) : null;
        int workerNum = this.getWorkerNum(mode);
        BasicThreadFactory factory = new BasicThreadFactory.Builder().namingPattern("workers-pool-thread-%d").build();
        ExecutorService es = Executors.newFixedThreadPool(workerNum, (ThreadFactory)factory);
        String updFunc = this.getParam("upd");
        String aggFunc = this.getParam("agg");
        LocalVariableMap newVarsMap = this.createVarsMap(ec);
        ExecutionContext newEC = ParamservUtils.createExecutionContext(ec, newVarsMap, updFunc, aggFunc, ParamservBuiltinCPInstruction.getParLevel(workerNum));
        List<ExecutionContext> workerECs = ParamservUtils.copyExecutionContext(newEC, workerNum);
        ExecutionContext aggServiceEC = ParamservUtils.copyExecutionContext(newEC, 1).get(0);
        Statement.PSFrequency freq = this.getFrequency();
        Statement.PSUpdateType updateType = this.getUpdateType();
        ListObject model = ec.getListObject(this.getParam("model"));
        ParamServer ps = ParamservBuiltinCPInstruction.createPS(mode, aggFunc, updateType, workerNum, model, aggServiceEC);
        List<LocalPSWorker> workers = IntStream.range(0, workerNum).mapToObj(i -> new LocalPSWorker(i, updFunc, freq, this.getEpochs(), this.getBatchSize(), (ExecutionContext)workerECs.get(i), ps)).collect(Collectors.toList());
        Statement.PSScheme scheme = this.getScheme();
        this.partitionLocally(scheme, ec, workers);
        if (DMLScript.STATISTICS) {
            Statistics.accPSSetupTime((long)tSetup.stop());
        }
        if (LOG.isDebugEnabled()) {
            LOG.debug((Object)String.format("\nConfiguration of paramserv func: \nmode: %s \nworkerNum: %d \nupdate frequency: %s \nstrategy: %s \ndata partitioner: %s", new Object[]{mode, workerNum, freq, updateType, scheme}));
        }
        try {
            for (Future ret : es.invokeAll(workers)) {
                ret.get();
            }
            ec.setVariable(this.output.getName(), ps.getResult());
        }
        catch (InterruptedException | ExecutionException e) {
            throw new DMLRuntimeException("ParamservBuiltinCPInstruction: some error occurred: ", e);
        }
        finally {
            es.shutdownNow();
        }
    }

    private LocalVariableMap createVarsMap(ExecutionContext ec) {
        LocalVariableMap varsMap = new LocalVariableMap();
        ListObject hyperParams = this.getHyperParams(ec);
        if (hyperParams != null) {
            varsMap.put("hyperparams", hyperParams);
        }
        return varsMap;
    }

    private Statement.PSModeType getPSMode() {
        Statement.PSModeType mode;
        if (!this.getParameterMap().containsKey("mode")) {
            return DEFAULT_MODE;
        }
        try {
            mode = Statement.PSModeType.valueOf(this.getParam("mode"));
        }
        catch (IllegalArgumentException e) {
            throw new DMLRuntimeException(String.format("Paramserv function: not support ps execution mode '%s'", this.getParam("mode")));
        }
        return mode;
    }

    private int getEpochs() {
        int epochs = Integer.valueOf(this.getParam("epochs"));
        if (epochs <= 0) {
            throw new DMLRuntimeException(String.format("Paramserv function: The argument '%s' could not be less than or equal to 0.", "epochs"));
        }
        return epochs;
    }

    private static int getParLevel(int workerNum) {
        return Math.max((int)Math.ceil((double)ParamservBuiltinCPInstruction.getRemainingCores() / (double)workerNum), 1);
    }

    private Statement.PSUpdateType getUpdateType() {
        Statement.PSUpdateType updType;
        if (!this.getParameterMap().containsKey("utype")) {
            return DEFAULT_TYPE;
        }
        try {
            updType = Statement.PSUpdateType.valueOf(this.getParam("utype"));
        }
        catch (IllegalArgumentException e) {
            throw new DMLRuntimeException(String.format("Paramserv function: not support update type '%s'.", this.getParam("utype")));
        }
        if (updType == Statement.PSUpdateType.SSP) {
            throw new DMLRuntimeException("Paramserv function: Not support update type SSP.");
        }
        return updType;
    }

    private Statement.PSFrequency getFrequency() {
        if (!this.getParameterMap().containsKey("freq")) {
            return DEFAULT_UPDATE_FREQUENCY;
        }
        try {
            return Statement.PSFrequency.valueOf(this.getParam("freq"));
        }
        catch (IllegalArgumentException e) {
            throw new DMLRuntimeException(String.format("Paramserv function: not support '%s' update frequency.", this.getParam("freq")));
        }
    }

    private static int getRemainingCores() {
        return InfrastructureAnalyzer.getLocalParallelism();
    }

    private int getWorkerNum(Statement.PSModeType mode) {
        switch (mode) {
            case LOCAL: {
                return this.getParameterMap().containsKey("k") ? Integer.valueOf(this.getParam("k")) : ParamservBuiltinCPInstruction.getRemainingCores();
            }
            case REMOTE_SPARK: {
                return this.getParameterMap().containsKey("k") ? Integer.valueOf(this.getParam("k")) : SparkExecutionContext.getDefaultParallelism(true);
            }
        }
        throw new DMLRuntimeException("Unsupported parameter server: " + mode.name());
    }

    private static ParamServer createPS(Statement.PSModeType mode, String aggFunc, Statement.PSUpdateType updateType, int workerNum, ListObject model, ExecutionContext ec) {
        switch (mode) {
            case LOCAL: 
            case REMOTE_SPARK: {
                return LocalParamServer.create(model, aggFunc, updateType, ec, workerNum);
            }
        }
        throw new DMLRuntimeException("Unsupported parameter server: " + mode.name());
    }

    private long getBatchSize() {
        if (!this.getParameterMap().containsKey("batchsize")) {
            return 64L;
        }
        long batchSize = Integer.valueOf(this.getParam("batchsize")).intValue();
        if (batchSize <= 0L) {
            throw new DMLRuntimeException(String.format("Paramserv function: the number of argument '%s' could not be less than or equal to 0.", "batchsize"));
        }
        return batchSize;
    }

    private ListObject getHyperParams(ExecutionContext ec) {
        ListObject hyperparams = null;
        if (this.getParameterMap().containsKey("hyperparams")) {
            hyperparams = ec.getListObject(this.getParam("hyperparams"));
        }
        return hyperparams;
    }

    private void partitionLocally(Statement.PSScheme scheme, ExecutionContext ec, List<LocalPSWorker> workers) {
        MatrixObject features = ec.getMatrixObject(this.getParam("features"));
        MatrixObject labels = ec.getMatrixObject(this.getParam("labels"));
        DataPartitionLocalScheme.Result result = new LocalDataPartitioner(scheme).doPartitioning(workers.size(), (MatrixBlock)features.acquireReadAndRelease(), (MatrixBlock)labels.acquireReadAndRelease());
        List<MatrixObject> pfs = result.pFeatures;
        List<MatrixObject> pls = result.pLabels;
        if (pfs.size() < workers.size()) {
            if (LOG.isWarnEnabled()) {
                LOG.warn((Object)String.format("There is only %d batches of data but has %d workers. Hence, reset the number of workers with %d.", pfs.size(), workers.size(), pfs.size()));
            }
            workers = workers.subList(0, pfs.size());
        }
        for (int i = 0; i < workers.size(); ++i) {
            workers.get(i).setFeatures(pfs.get(i));
            workers.get(i).setLabels(pls.get(i));
        }
    }

    private Statement.PSScheme getScheme() {
        Statement.PSScheme scheme = DEFAULT_SCHEME;
        if (this.getParameterMap().containsKey("scheme")) {
            try {
                scheme = Statement.PSScheme.valueOf(this.getParam("scheme"));
            }
            catch (IllegalArgumentException e) {
                throw new DMLRuntimeException(String.format("Paramserv function: not support data partition scheme '%s'", this.getParam("scheme")));
            }
        }
        return scheme;
    }
}

