/*
 * Decompiled with CFR 0.152.
 */
package org.tugraz.sysds.runtime.controlprogram.paramserv.dp;

import java.io.Serializable;
import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.tugraz.sysds.parser.Statement;
import org.tugraz.sysds.runtime.controlprogram.context.SparkExecutionContext;
import org.tugraz.sysds.runtime.controlprogram.paramserv.ParamservUtils;
import org.tugraz.sysds.runtime.controlprogram.paramserv.dp.DCSparkScheme;
import org.tugraz.sysds.runtime.controlprogram.paramserv.dp.DRRSparkScheme;
import org.tugraz.sysds.runtime.controlprogram.paramserv.dp.DRSparkScheme;
import org.tugraz.sysds.runtime.controlprogram.paramserv.dp.DataPartitionSparkScheme;
import org.tugraz.sysds.runtime.controlprogram.paramserv.dp.ORSparkScheme;
import org.tugraz.sysds.runtime.instructions.spark.data.PartitionedBroadcast;
import org.tugraz.sysds.runtime.matrix.data.MatrixBlock;
import org.tugraz.sysds.runtime.util.DataConverter;

/**
 * Spark-side data partitioner for the parameter server.
 *
 * <p>Selects a {@link DataPartitionSparkScheme} for the requested {@link Statement.PSScheme} and
 * eagerly builds the broadcast helper vectors (worker indicators and/or global permutations)
 * that the selected scheme is handed via {@code setWorkerIndicator} / {@code setGlobalPermutation}.
 */
public class SparkDataPartitioner
implements Serializable {
    private static final long serialVersionUID = 6841548626711057448L;
    private DataPartitionSparkScheme _scheme;

    /**
     * Creates a partitioner for the given scheme and builds its broadcast inputs.
     *
     * @param scheme     partitioning scheme to use
     * @param sec        Spark execution context used to create the broadcasts
     * @param numEntries number of entries (rows) to be partitioned
     * @param numWorkers number of workers; must be positive
     * @throws IllegalArgumentException if {@code numWorkers} is not positive or the scheme is
     *                                  not supported by this partitioner
     */
    protected SparkDataPartitioner(Statement.PSScheme scheme, SparkExecutionContext sec, int numEntries, int numWorkers) {
        if (numWorkers <= 0) {
            // Guard here: round-robin would otherwise fail with "/ by zero" deep inside a lambda.
            throw new IllegalArgumentException("Number of workers must be positive, but was " + numWorkers);
        }
        switch (scheme) {
            case DISJOINT_CONTIGUOUS:
                this._scheme = new DCSparkScheme();
                this.createDCIndicator(sec, numWorkers, numEntries);
                break;
            case DISJOINT_ROUND_ROBIN:
                this._scheme = new DRRSparkScheme();
                this.createDRIndicator(sec, numWorkers, numEntries);
                break;
            case DISJOINT_RANDOM:
                // One global permutation plus a contiguous worker indicator.
                this._scheme = new DRSparkScheme();
                this.createGlobalPermutations(sec, numEntries, 1);
                this.createDCIndicator(sec, numWorkers, numEntries);
                break;
            case OVERLAP_RESHUFFLE:
                // One independently seeded permutation per worker.
                this._scheme = new ORSparkScheme();
                this.createGlobalPermutations(sec, numEntries, numWorkers);
                break;
            default:
                // Fail fast instead of leaving _scheme null and hitting an NPE in doPartitioning.
                throw new IllegalArgumentException("Unsupported data partition scheme: " + scheme);
        }
    }

    /**
     * Builds and broadcasts a round-robin worker indicator: entry {@code n} is assigned to
     * worker {@code n % numWorkers}.
     */
    private void createDRIndicator(SparkExecutionContext sec, int numWorkers, int numEntries) {
        double[] vector = IntStream.range(0, numEntries).mapToDouble(n -> n % numWorkers).toArray();
        MatrixBlock vectorMB = DataConverter.convertToMatrixBlock(vector, true);
        this._scheme.setWorkerIndicator(sec.getBroadcastForMatrixObject(ParamservUtils.newMatrixObject(vectorMB)));
    }

    /**
     * Builds and broadcasts a contiguous worker indicator: entries are split into consecutive
     * batches of {@code ceil(numEntries / numWorkers)} and each entry holds its batch index.
     * When there are more workers than batches, the trailing workers receive no entries.
     */
    private void createDCIndicator(SparkExecutionContext sec, int numWorkers, int numEntries) {
        double[] vector = new double[numEntries];
        int batchSize = (int)Math.ceil((double)numEntries / (double)numWorkers);
        // Worker 0's batch is already encoded by the array's default 0.0 values.
        for (int i = 1; i < numWorkers; ++i) {
            int begin = batchSize * i;
            if (begin >= numEntries) {
                // All entries are assigned; continuing would call Arrays.fill with begin > end.
                break;
            }
            int end = Math.min(begin + batchSize, numEntries);
            Arrays.fill(vector, begin, end, (double)i);
        }
        MatrixBlock vectorMB = DataConverter.convertToMatrixBlock(vector, true);
        this._scheme.setWorkerIndicator(sec.getBroadcastForMatrixObject(ParamservUtils.newMatrixObject(vectorMB)));
    }

    /**
     * Builds {@code numPerm} seeded permutations of {@code numEntries} entries, broadcasts each,
     * and hands the list to the scheme. Permutation {@code i} uses seed {@code SEED + i}.
     */
    private void createGlobalPermutations(SparkExecutionContext sec, int numEntries, int numPerm) {
        List<PartitionedBroadcast<MatrixBlock>> perms = IntStream.range(0, numPerm)
            .mapToObj(i -> this.createGlobalPermutation(sec, numEntries, ParamservUtils.SEED + (long)i))
            .collect(Collectors.toList());
        this._scheme.setGlobalPermutation(perms);
    }

    /**
     * Samples one permutation with the given seed and broadcasts its inverse mapping: slot
     * {@code k} holds the 0-based position at which value {@code k + 1} appears in the sample.
     */
    private PartitionedBroadcast<MatrixBlock> createGlobalPermutation(SparkExecutionContext sec, int numEntries, long seed) {
        MatrixBlock perm = MatrixBlock.sampleOperations(numEntries, numEntries, false, seed);
        // NOTE(review): assumes the sample is a dense, 1-based permutation of 1..numEntries
        // (the "- 1" below relies on this) — confirm against MatrixBlock.sampleOperations.
        double[] sample = perm.getDenseBlockValues();
        double[] vector = new double[numEntries];
        for (int j = 0; j < sample.length; ++j) {
            vector[(int)sample[j] - 1] = j;
        }
        MatrixBlock vectorMB = DataConverter.convertToMatrixBlock(vector, true);
        return sec.getBroadcastForMatrixObject(ParamservUtils.newMatrixObject(vectorMB));
    }

    /**
     * Delegates partitioning of one row to the configured scheme.
     *
     * @param numWorkers number of workers
     * @param features   feature block passed through to the scheme
     * @param labels     label block passed through to the scheme
     * @param rowID      row identifier; must fit in an {@code int}
     * @return the scheme's partitioning result
     * @throws ArithmeticException if {@code rowID} does not fit in an {@code int}
     */
    public DataPartitionSparkScheme.Result doPartitioning(int numWorkers, MatrixBlock features, MatrixBlock labels, long rowID) {
        // The scheme API is int-based; fail loudly rather than silently wrapping a large row id.
        return this._scheme.doPartitioning(numWorkers, Math.toIntExact(rowID), features, labels);
    }
}

