/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.runtime.instructions.cp;

import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.ThreadFactory;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.apache.commons.lang3.concurrent.BasicThreadFactory;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.spark.api.java.function.VoidFunction;
import org.apache.spark.network.server.TransportServer;
import org.apache.spark.util.LongAccumulator;
import org.apache.sysds.api.DMLScript;
import org.apache.sysds.common.Types;
import org.apache.sysds.hops.recompile.Recompiler;
import org.apache.sysds.parser.Statement;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.LocalVariableMap;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.context.SparkExecutionContext;
import org.apache.sysds.runtime.controlprogram.paramserv.FederatedPSControlThread;
import org.apache.sysds.runtime.controlprogram.paramserv.LocalPSWorker;
import org.apache.sysds.runtime.controlprogram.paramserv.LocalParamServer;
import org.apache.sysds.runtime.controlprogram.paramserv.ParamServer;
import org.apache.sysds.runtime.controlprogram.paramserv.ParamservUtils;
import org.apache.sysds.runtime.controlprogram.paramserv.SparkPSBody;
import org.apache.sysds.runtime.controlprogram.paramserv.SparkPSWorker;
import org.apache.sysds.runtime.controlprogram.paramserv.dp.DataPartitionFederatedScheme;
import org.apache.sysds.runtime.controlprogram.paramserv.dp.DataPartitionLocalScheme;
import org.apache.sysds.runtime.controlprogram.paramserv.dp.FederatedDataPartitioner;
import org.apache.sysds.runtime.controlprogram.paramserv.dp.LocalDataPartitioner;
import org.apache.sysds.runtime.controlprogram.paramserv.rpc.PSRpcFactory;
import org.apache.sysds.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer;
import org.apache.sysds.runtime.controlprogram.parfor.stat.Timing;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.instructions.cp.ListObject;
import org.apache.sysds.runtime.instructions.cp.ParameterizedBuiltinCPInstruction;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.operators.Operator;
import org.apache.sysds.runtime.util.ProgramConverter;
import org.apache.sysds.utils.Statistics;

/**
 * CP instruction for the {@code paramserv} builtin.
 *
 * <p>Trains a model with a parameter server. Execution is federated whenever the features or
 * labels matrix is federated; otherwise it follows the {@code mode} parameter (LOCAL thread pool
 * or REMOTE_SPARK).
 */
public class ParamservBuiltinCPInstruction
extends ParameterizedBuiltinCPInstruction {
    private static final Log LOG = LogFactory.getLog(ParamservBuiltinCPInstruction.class.getName());
    public static final int DEFAULT_BATCH_SIZE = 64;
    private static final Statement.PSFrequency DEFAULT_UPDATE_FREQUENCY = Statement.PSFrequency.EPOCH;
    private static final Statement.PSScheme DEFAULT_SCHEME = Statement.PSScheme.DISJOINT_CONTIGUOUS;
    private static final Statement.PSRuntimeBalancing DEFAULT_RUNTIME_BALANCING = Statement.PSRuntimeBalancing.NONE;
    private static final Statement.FederatedPSScheme DEFAULT_FEDERATED_SCHEME = Statement.FederatedPSScheme.KEEP_DATA_ON_WORKER;
    private static final Statement.PSModeType DEFAULT_MODE = Statement.PSModeType.LOCAL;
    private static final Statement.PSUpdateType DEFAULT_TYPE = Statement.PSUpdateType.ASP;
    public static final int DEFAULT_NBATCHES = 1;
    private static final boolean DEFAULT_MODELAVG = false;

    public ParamservBuiltinCPInstruction(Operator op, LinkedHashMap<String, String> paramsMap, CPOperand out, String opcode, String istr) {
        super(op, paramsMap, out, opcode, istr);
    }

    /**
     * Runs the parameter server: federated if features or labels are federated, otherwise in the
     * mode given by the {@code mode} parameter.
     *
     * @throws DMLRuntimeException if the requested mode is not supported
     */
    @Override
    public void processInstruction(ExecutionContext ec) {
        boolean federatedInput = ec.getMatrixObject(getParam("features")).isFederated()
            || ec.getMatrixObject(getParam("labels")).isFederated();
        if (federatedInput) {
            runFederated(ec);
            return;
        }
        Statement.PSModeType mode = getPSMode();
        switch (mode) {
            case LOCAL:
                runLocally(ec, mode);
                break;
            case REMOTE_SPARK:
                runOnSpark((SparkExecutionContext) ec, mode);
                break;
            default:
                throw new DMLRuntimeException(String.format("Paramserv func: not support mode %s", mode));
        }
    }

    /**
     * Federated execution: partitions the federated data, creates one control thread per
     * resulting worker, and runs them on a local thread pool.
     */
    private void runFederated(ExecutionContext ec) {
        if (DMLScript.STATISTICS) {
            Statistics.getPSExecutionTimer().start();
        }
        Timing tSetup = DMLScript.STATISTICS ? new Timing(true) : null;
        LOG.info("PARAMETER SERVER");
        LOG.info("[+] Running in federated mode");
        String updFunc = getParam("upd");
        String aggFunc = getParam("agg");
        Statement.PSUpdateType updateType = getUpdateType();
        Statement.PSFrequency freq = getFrequency();
        Statement.FederatedPSScheme federatedPSScheme = getFederatedScheme();
        Statement.PSRuntimeBalancing runtimeBalancing = getRuntimeBalancing();
        boolean weighting = getWeighting();
        int seed = getSeed();
        int nbatches = getNbatches();
        if (LOG.isInfoEnabled()) {
            LOG.info("[+] Update Type: " + updateType);
            LOG.info("[+] Frequency: " + freq);
            LOG.info("[+] Data Partitioning: " + federatedPSScheme);
            LOG.info("[+] Runtime Balancing: " + runtimeBalancing);
            LOG.info("[+] Weighting: " + weighting);
            LOG.info("[+] Seed: " + seed);
        }
        if (tSetup != null) {
            Statistics.accPSSetupTime((long) tSetup.stop());
        }

        Timing tDataPartitioning = DMLScript.STATISTICS ? new Timing(true) : null;
        DataPartitionFederatedScheme.Result result = new FederatedDataPartitioner(federatedPSScheme, seed)
            .doPartitioning(ec.getMatrixObject(getParam("features")), ec.getMatrixObject(getParam("labels")));
        // The partitioner, not the user, decides the number of federated workers.
        int workerNum = result._workerNum;
        if (tDataPartitioning != null) {
            Statistics.accFedPSDataPartitioningTime((long) tDataPartitioning.stop());
        }
        if (tSetup != null) {
            tSetup.start();
        }

        LocalVariableMap newVarsMap = createVarsMap(ec);
        ExecutionContext newEC = ParamservUtils.createExecutionContext(ec, newVarsMap, updFunc, aggFunc, -1, true);
        List<ExecutionContext> federatedWorkerECs = ParamservUtils.copyExecutionContext(newEC, workerNum);
        ExecutionContext aggServiceEC = ParamservUtils.copyExecutionContext(newEC, 1).get(0);
        ListObject model = ec.getListObject(getParam("model"));
        MatrixObject valFeatures = getOptionalMatrix(ec, "val_features");
        MatrixObject valLabels = getOptionalMatrix(ec, "val_labels");
        boolean modelAvg = getModelAvg();
        int numBatchesPerEpoch = getNumBatchesPerEpoch(runtimeBalancing, result._balanceMetrics);
        ParamServer ps = createPS(Statement.PSModeType.FEDERATED, aggFunc, updateType, freq, workerNum, model,
            aggServiceEC, getValFunction(), numBatchesPerEpoch, valFeatures, valLabels, nbatches, modelAvg);

        int epochs = getEpochs();
        long batchSize = getBatchSize();
        List<FederatedPSControlThread> threads = IntStream.range(0, workerNum)
            .mapToObj(i -> new FederatedPSControlThread(i, updFunc, freq, runtimeBalancing, weighting, epochs,
                batchSize, numBatchesPerEpoch, federatedWorkerECs.get(i), ps, nbatches, modelAvg))
            .collect(Collectors.toList());
        // Every control thread needs exactly one features/labels partition.
        if (result._pFeatures.size() != workerNum || result._pLabels.size() != workerNum) {
            throw new DMLRuntimeException("ParamservBuiltinCPInstruction: Federated data partitioning does not match threads!");
        }
        for (int i = 0; i < workerNum; i++) {
            FederatedPSControlThread thread = threads.get(i);
            thread.setFeatures(result._pFeatures.get(i));
            thread.setLabels(result._pLabels.get(i));
            thread.setup(result._weightingFactors.get(i));
        }
        if (tSetup != null) {
            Statistics.accPSSetupTime((long) tSetup.stop());
        }

        // Created only now so that setup failures above cannot leave an executor behind.
        ExecutorService es = createWorkerPool(workerNum);
        try {
            for (Future<?> ret : es.invokeAll(threads)) {
                ret.get();
            }
            ec.setVariable(output.getName(), ps.getResult());
            if (DMLScript.STATISTICS) {
                Statistics.accPSExecutionTime((long) Statistics.getPSExecutionTimer().stop());
            }
        }
        catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new DMLRuntimeException("ParamservBuiltinCPInstruction: unknown error: ", e);
        }
        catch (ExecutionException e) {
            throw new DMLRuntimeException("ParamservBuiltinCPInstruction: unknown error: ", e);
        }
        finally {
            es.shutdownNow();
        }
    }

    /**
     * Spark execution: starts an RPC server for a local parameter server on the driver and runs
     * {@link SparkPSWorker}s over the Spark-partitioned data. The RPC server is always closed,
     * even if setup fails.
     */
    private void runOnSpark(SparkExecutionContext sec, Statement.PSModeType mode) {
        Timing tSetup = DMLScript.STATISTICS ? new Timing(true) : null;
        int workerNum = getWorkerNum(mode);
        String updFunc = getParam("upd");
        String aggFunc = getParam("agg");
        int nbatches = getNbatches();
        boolean modelAvg = getModelAvg();
        LocalVariableMap newVarsMap = createVarsMap(sec);
        ExecutionContext newEC = ParamservUtils.createExecutionContext(sec, newVarsMap, updFunc, aggFunc, 1);
        ExecutionContext aggServiceEC = ParamservUtils.copyExecutionContext(newEC, 1).get(0);
        ListObject model = sec.getListObject(getParam("model"));
        ParamServer ps = createPS(mode, aggFunc, getUpdateType(), getFrequency(), workerNum, model, aggServiceEC, nbatches, modelAvg);
        String host = sec.getSparkContext().getConf().get("spark.driver.host");
        TransportServer server = PSRpcFactory.createServer(sec.getSparkContext().getConf(), (LocalParamServer) ps, host);
        try {
            Recompiler.recompileProgramBlockHierarchy2Forced(newEC.getProgram().getProgramBlocks(), 0L, new HashSet<String>(), Types.ExecType.CP);
            SparkPSBody body = new SparkPSBody(newEC);
            HashMap<String, byte[]> clsMap = new HashMap<String, byte[]>();
            String program = ProgramConverter.serializeSparkPSBody(body, clsMap);

            // Accumulators collect worker-side timings/counters back on the driver.
            LongAccumulator aSetup = sec.getSparkContext().sc().longAccumulator("setup");
            LongAccumulator aWorker = sec.getSparkContext().sc().longAccumulator("workersNum");
            LongAccumulator aUpdate = sec.getSparkContext().sc().longAccumulator("modelUpdate");
            LongAccumulator aIndex = sec.getSparkContext().sc().longAccumulator("batchIndex");
            LongAccumulator aGrad = sec.getSparkContext().sc().longAccumulator("gradCompute");
            LongAccumulator aRPC = sec.getSparkContext().sc().longAccumulator("rpcRequest");
            LongAccumulator aBatch = sec.getSparkContext().sc().longAccumulator("numBatches");
            LongAccumulator aEpoch = sec.getSparkContext().sc().longAccumulator("numEpochs");
            SparkPSWorker worker = new SparkPSWorker(updFunc, aggFunc, getFrequency(), getEpochs(), getBatchSize(),
                program, clsMap, sec.getSparkContext().getConf(), server.getPort(), aSetup, aWorker, aUpdate,
                aIndex, aGrad, aRPC, aBatch, aEpoch, nbatches, modelAvg);
            if (tSetup != null) {
                Statistics.accPSSetupTime((long) tSetup.stop());
            }

            MatrixObject features = sec.getMatrixObject(getParam("features"));
            MatrixObject labels = sec.getMatrixObject(getParam("labels"));
            try {
                ParamservUtils.doPartitionOnSpark(sec, features, labels, getScheme(), workerNum).foreach((VoidFunction) worker);
            }
            catch (Exception e) {
                throw new DMLRuntimeException("Paramserv function failed: ", e);
            }

            if (DMLScript.STATISTICS) {
                Statistics.accPSSetupTime(aSetup.value());
                Statistics.incWorkerNumber(aWorker.value());
                Statistics.accPSLocalModelUpdateTime(aUpdate.value());
                Statistics.accPSBatchIndexingTime(aIndex.value());
                Statistics.accPSGradientComputeTime(aGrad.value());
                Statistics.accPSRpcRequestTime(aRPC.value());
            }
        }
        finally {
            server.close();
        }
        sec.setVariable(output.getName(), ps.getResult());
    }

    /**
     * Local execution: partitions the data, then runs one {@link LocalPSWorker} per non-empty
     * partition on a local thread pool.
     *
     * <p>The parameter server, worker execution contexts, workers and thread pool are all sized to
     * the number of partitions actually produced. If fewer partitions than requested workers
     * exist, every submitted worker still has data.
     */
    private void runLocally(ExecutionContext ec, Statement.PSModeType mode) {
        if (DMLScript.STATISTICS) {
            Statistics.getPSExecutionTimer().start();
        }
        Timing tSetup = DMLScript.STATISTICS ? new Timing(true) : null;
        int requestedWorkers = getWorkerNum(mode);
        Statement.PSScheme scheme = getScheme();
        DataPartitionLocalScheme.Result partitions = partitionLocally(scheme, ec, requestedWorkers);
        List<MatrixObject> pfs = partitions.pFeatures;
        List<MatrixObject> pls = partitions.pLabels;
        int workerNum = Math.min(requestedWorkers, Math.min(pfs.size(), pls.size()));
        if (workerNum <= 0) {
            throw new DMLRuntimeException("Paramserv function: data partitioning produced no partitions.");
        }

        String updFunc = getParam("upd");
        String aggFunc = getParam("agg");
        LocalVariableMap newVarsMap = createVarsMap(ec);
        ExecutionContext newEC = ParamservUtils.createExecutionContext(ec, newVarsMap, updFunc, aggFunc, getParLevel(workerNum));
        List<ExecutionContext> workerECs = ParamservUtils.copyExecutionContext(newEC, workerNum);
        ExecutionContext aggServiceEC = ParamservUtils.copyExecutionContext(newEC, 1).get(0);
        Statement.PSFrequency freq = getFrequency();
        Statement.PSUpdateType updateType = getUpdateType();
        int epochs = getEpochs();
        long batchSize = getBatchSize();
        // Per-worker row estimate from the requested worker count; double (not float) keeps
        // exact integer results for large row counts.
        double rowsPerWorker = Math.ceil((double) ec.getMatrixObject(getParam("features")).getNumRows() / (double) requestedWorkers);
        int numBatchesPerEpoch = (int) Math.ceil(rowsPerWorker / (double) batchSize);
        int nbatches = getNbatches();
        ListObject model = ec.getListObject(getParam("model"));
        MatrixObject valFeatures = getOptionalMatrix(ec, "val_features");
        MatrixObject valLabels = getOptionalMatrix(ec, "val_labels");
        boolean modelAvg = getModelAvg();
        ParamServer ps = createPS(mode, aggFunc, updateType, freq, workerNum, model, aggServiceEC,
            getValFunction(), numBatchesPerEpoch, valFeatures, valLabels, nbatches, modelAvg);
        List<LocalPSWorker> workers = IntStream.range(0, workerNum)
            .mapToObj(i -> new LocalPSWorker(i, updFunc, freq, epochs, batchSize, workerECs.get(i), ps, nbatches, modelAvg))
            .collect(Collectors.toList());
        for (int i = 0; i < workerNum; i++) {
            workers.get(i).setFeatures(pfs.get(i));
            workers.get(i).setLabels(pls.get(i));
        }
        if (tSetup != null) {
            Statistics.accPSSetupTime((long) tSetup.stop());
        }
        if (LOG.isDebugEnabled()) {
            LOG.debug(String.format("\nConfiguration of paramserv func: \nmode: %s \nworkerNum: %d \nupdate frequency: %s \nstrategy: %s \ndata partitioner: %s",
                mode, workerNum, freq, updateType, scheme));
        }

        // Created only now so that setup failures above cannot leave an executor behind.
        ExecutorService es = createWorkerPool(workerNum);
        try {
            for (Future<?> ret : es.invokeAll(workers)) {
                ret.get();
            }
            ec.setVariable(output.getName(), ps.getResult());
            if (DMLScript.STATISTICS) {
                Statistics.accPSExecutionTime((long) Statistics.getPSExecutionTimer().stop());
            }
        }
        catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new DMLRuntimeException("ParamservBuiltinCPInstruction: some error occurred: ", e);
        }
        catch (ExecutionException e) {
            throw new DMLRuntimeException("ParamservBuiltinCPInstruction: some error occurred: ", e);
        }
        finally {
            es.shutdownNow();
        }
    }

    /** Creates the fixed-size pool (named "workers-pool-thread-%d") that runs the PS workers. */
    private static ExecutorService createWorkerPool(int workerNum) {
        BasicThreadFactory factory = new BasicThreadFactory.Builder().namingPattern("workers-pool-thread-%d").build();
        return Executors.newFixedThreadPool(workerNum, factory);
    }

    /** Returns the matrix bound to the optional parameter {@code key}, or null if it is absent. */
    private MatrixObject getOptionalMatrix(ExecutionContext ec, String key) {
        String varName = getParam(key);
        return varName != null ? ec.getMatrixObject(varName) : null;
    }

    /** Builds the variable map for the worker/aggregation functions (only "hyperparams", if given). */
    private LocalVariableMap createVarsMap(ExecutionContext ec) {
        LocalVariableMap varsMap = new LocalVariableMap();
        ListObject hyperParams = getHyperParams(ec);
        if (hyperParams != null) {
            varsMap.put("hyperparams", hyperParams);
        }
        return varsMap;
    }

    /** Parses the {@code mode} parameter, defaulting to LOCAL. */
    private Statement.PSModeType getPSMode() {
        if (!getParameterMap().containsKey("mode")) {
            return DEFAULT_MODE;
        }
        try {
            return Statement.PSModeType.valueOf(getParam("mode"));
        }
        catch (IllegalArgumentException e) {
            throw new DMLRuntimeException(String.format("Paramserv function: not support ps execution mode '%s'", getParam("mode")));
        }
    }

    /** Parses the mandatory {@code epochs} parameter; it must be positive. */
    private int getEpochs() {
        int epochs = Integer.parseInt(getParam("epochs"));
        if (epochs <= 0) {
            throw new DMLRuntimeException(String.format("Paramserv function: The argument '%s' could not be less than or equal to 0.", "epochs"));
        }
        return epochs;
    }

    /** Per-worker parallelism: available local cores split across workers, at least 1. */
    private static int getParLevel(int workerNum) {
        return Math.max((int) Math.ceil((double) getRemainingCores() / (double) workerNum), 1);
    }

    /** Parses the {@code utype} parameter (default ASP); SSP is rejected. */
    private Statement.PSUpdateType getUpdateType() {
        if (!getParameterMap().containsKey("utype")) {
            return DEFAULT_TYPE;
        }
        Statement.PSUpdateType updType;
        try {
            updType = Statement.PSUpdateType.valueOf(getParam("utype"));
        }
        catch (IllegalArgumentException e) {
            throw new DMLRuntimeException(String.format("Paramserv function: not support update type '%s'.", getParam("utype")));
        }
        if (updType == Statement.PSUpdateType.SSP) {
            throw new DMLRuntimeException("Paramserv function: Not support update type SSP.");
        }
        return updType;
    }

    /** Parses the {@code freq} parameter, defaulting to EPOCH. */
    private Statement.PSFrequency getFrequency() {
        if (!getParameterMap().containsKey("freq")) {
            return DEFAULT_UPDATE_FREQUENCY;
        }
        try {
            return Statement.PSFrequency.valueOf(getParam("freq"));
        }
        catch (IllegalArgumentException e) {
            throw new DMLRuntimeException(String.format("Paramserv function: not support '%s' update frequency.", getParam("freq")));
        }
    }

    /** Parses the {@code runtime_balancing} parameter, defaulting to NONE. */
    private Statement.PSRuntimeBalancing getRuntimeBalancing() {
        if (!getParameterMap().containsKey("runtime_balancing")) {
            return DEFAULT_RUNTIME_BALANCING;
        }
        try {
            return Statement.PSRuntimeBalancing.valueOf(getParam("runtime_balancing"));
        }
        catch (IllegalArgumentException e) {
            throw new DMLRuntimeException(String.format("Paramserv function: not support '%s' runtime balancing.", getParam("runtime_balancing")));
        }
    }

    private static int getRemainingCores() {
        return InfrastructureAnalyzer.getLocalParallelism();
    }

    /**
     * Number of workers: the explicit {@code k} parameter (must be positive), otherwise the local
     * core count (LOCAL) or Spark's default parallelism (REMOTE_SPARK).
     */
    private int getWorkerNum(Statement.PSModeType mode) {
        switch (mode) {
            case LOCAL:
                return getParameterMap().containsKey("k") ? getExplicitWorkerNum() : getRemainingCores();
            case REMOTE_SPARK:
                return getParameterMap().containsKey("k") ? getExplicitWorkerNum() : SparkExecutionContext.getDefaultParallelism(true);
            default:
                throw new DMLRuntimeException("Unsupported parameter server: " + mode.name());
        }
    }

    /** Parses the user-supplied {@code k}; rejects values that would yield an empty worker pool. */
    private int getExplicitWorkerNum() {
        int k = Integer.parseInt(getParam("k"));
        if (k <= 0) {
            throw new DMLRuntimeException(String.format("Paramserv function: The argument '%s' could not be less than or equal to 0.", "k"));
        }
        return k;
    }

    /** Creates a parameter server without validation function/data. */
    private static ParamServer createPS(Statement.PSModeType mode, String aggFunc, Statement.PSUpdateType updateType, Statement.PSFrequency freq, int workerNum, ListObject model, ExecutionContext ec, int nbatches, boolean modelAvg) {
        return createPS(mode, aggFunc, updateType, freq, workerNum, model, ec, null, -1, null, null, nbatches, modelAvg);
    }

    /** Creates the parameter server; every supported mode uses a {@link LocalParamServer}. */
    private static ParamServer createPS(Statement.PSModeType mode, String aggFunc, Statement.PSUpdateType updateType, Statement.PSFrequency freq, int workerNum, ListObject model, ExecutionContext ec, String valFunc, int numBatchesPerEpoch, MatrixObject valFeatures, MatrixObject valLabels, int nbatches, boolean modelAvg) {
        switch (mode) {
            case LOCAL:
            case REMOTE_SPARK:
            case FEDERATED:
                return LocalParamServer.create(model, aggFunc, updateType, freq, ec, workerNum, valFunc, numBatchesPerEpoch, valFeatures, valLabels, nbatches, modelAvg);
            default:
                throw new DMLRuntimeException("Unsupported parameter server: " + mode.name());
        }
    }

    /** Parses the {@code batchsize} parameter (default {@value #DEFAULT_BATCH_SIZE}); must be positive. */
    private long getBatchSize() {
        if (!getParameterMap().containsKey("batchsize")) {
            return DEFAULT_BATCH_SIZE;
        }
        long batchSize = Integer.parseInt(getParam("batchsize"));
        if (batchSize <= 0L) {
            throw new DMLRuntimeException(String.format("Paramserv function: the number of argument '%s' could not be less than or equal to 0.", "batchsize"));
        }
        return batchSize;
    }

    /** Returns the {@code hyperparams} list object, or null if the parameter is absent. */
    private ListObject getHyperParams(ExecutionContext ec) {
        return getParameterMap().containsKey("hyperparams") ? ec.getListObject(getParam("hyperparams")) : null;
    }

    /**
     * Partitions features/labels for {@code workerNum} local workers with the given scheme.
     * Logs a warning when fewer partitions than workers are produced; the caller is responsible
     * for shrinking the worker count accordingly.
     */
    private DataPartitionLocalScheme.Result partitionLocally(Statement.PSScheme scheme, ExecutionContext ec, int workerNum) {
        MatrixObject features = ec.getMatrixObject(getParam("features"));
        MatrixObject labels = ec.getMatrixObject(getParam("labels"));
        DataPartitionLocalScheme.Result result = new LocalDataPartitioner(scheme).doPartitioning(workerNum, (MatrixBlock) features.acquireReadAndRelease(), (MatrixBlock) labels.acquireReadAndRelease());
        int numPartitions = result.pFeatures.size();
        if (numPartitions < workerNum && LOG.isWarnEnabled()) {
            LOG.warn(String.format("There is only %d batches of data but has %d workers. Hence, reset the number of workers with %d.", numPartitions, workerNum, numPartitions));
        }
        return result;
    }

    /** Parses the {@code scheme} parameter as a local partition scheme (default DISJOINT_CONTIGUOUS). */
    private Statement.PSScheme getScheme() {
        if (!getParameterMap().containsKey("scheme")) {
            return DEFAULT_SCHEME;
        }
        try {
            return Statement.PSScheme.valueOf(getParam("scheme"));
        }
        catch (IllegalArgumentException e) {
            throw new DMLRuntimeException(String.format("Paramserv function: not support data partition scheme '%s'", getParam("scheme")));
        }
    }

    /** Parses the {@code scheme} parameter as a federated scheme (default KEEP_DATA_ON_WORKER). */
    private Statement.FederatedPSScheme getFederatedScheme() {
        if (!getParameterMap().containsKey("scheme")) {
            return DEFAULT_FEDERATED_SCHEME;
        }
        try {
            return Statement.FederatedPSScheme.valueOf(getParam("scheme"));
        }
        catch (IllegalArgumentException e) {
            throw new DMLRuntimeException(String.format("Paramserv function in federated mode: not support data partition scheme '%s'", getParam("scheme")));
        }
    }

    /**
     * Batches per epoch for federated runs: ceil(rows / batchsize), where rows is the minimum
     * (CYCLE_MIN, BASELINE), maximum (CYCLE_MAX) or average (all other modes) partition size.
     */
    private int getNumBatchesPerEpoch(Statement.PSRuntimeBalancing runtimeBalancing, DataPartitionFederatedScheme.BalanceMetrics balanceMetrics) {
        double rows;
        switch (runtimeBalancing) {
            case CYCLE_MIN:
            case BASELINE:
                rows = balanceMetrics._minRows;
                break;
            case CYCLE_MAX:
                rows = balanceMetrics._maxRows;
                break;
            default:
                // CYCLE_AVG, SCALE_BATCH and anything else use the average partition size.
                rows = balanceMetrics._avgRows;
                break;
        }
        // double (not float) keeps the division exact for large row counts.
        return (int) Math.ceil(rows / (double) getBatchSize());
    }

    private boolean getWeighting() {
        return getParameterMap().containsKey("weighting") && Boolean.parseBoolean(getParam("weighting"));
    }

    /** Returns the name of the validation function ({@code val}), or null if absent. */
    private String getValFunction() {
        return getParameterMap().containsKey("val") ? getParam("val") : null;
    }

    /** Parses {@code seed}; when absent, falls back to the (truncated) current time in millis. */
    private int getSeed() {
        return getParameterMap().containsKey("seed") ? Integer.parseInt(getParam("seed")) : (int) System.currentTimeMillis();
    }

    private boolean getModelAvg() {
        if (!getParameterMap().containsKey("modelAvg")) {
            return DEFAULT_MODELAVG;
        }
        return Boolean.parseBoolean(getParam("modelAvg"));
    }

    /** Parses {@code nbatches}, defaulting to {@value #DEFAULT_NBATCHES}. */
    private int getNbatches() {
        if (!getParameterMap().containsKey("nbatches")) {
            return DEFAULT_NBATCHES;
        }
        return Integer.parseInt(getParam("nbatches"));
    }
}

