/*
 * Decompiled with CFR 0.152.
 */
package org.tugraz.sysds.runtime.controlprogram.paramserv.dp;

import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.tugraz.sysds.runtime.controlprogram.paramserv.ParamservUtils;
import org.tugraz.sysds.runtime.controlprogram.paramserv.dp.DataPartitionSparkScheme;
import org.tugraz.sysds.runtime.instructions.spark.data.PartitionedBroadcast;
import org.tugraz.sysds.runtime.matrix.data.MatrixBlock;
import scala.Tuple2;

/**
 * Spark-side partitioning for the disjoint-random (DR) scheme.
 *
 * <p>For each row of an input row block, the row's shifted position is read from the first
 * global permutation broadcast ({@code _globalPerms.get(0)}). That position is then used to read
 * the assigned worker ID from the worker indicator broadcast ({@code _workerIndicator}). Features
 * and labels are routed with the same lookups, so row {@code r} of both ends up at the same
 * worker and shifted position.
 */
public class DRSparkScheme
extends DataPartitionSparkScheme {
    private static final long serialVersionUID = -7655310624144544544L;

    /**
     * Number of rows per block assumed for the permutation and worker indicator broadcasts
     * (the value the original code hard-coded in two places). NOTE(review): must equal the
     * block size those broadcasts were created with — confirm against the producer side.
     */
    private static final int BLOCK_SIZE = 1000;

    /** Protected no-arg constructor; presumably instantiated by a factory in this package. */
    protected DRSparkScheme() {
    }

    /**
     * Partitions one row block of features and labels.
     *
     * @param numWorkers number of workers (not used by this scheme; worker assignment comes
     *                   entirely from the worker indicator broadcast)
     * @param rblkID     1-based row-block index of {@code features}/{@code labels}; also used to
     *                   fetch the matching block of the global permutation
     * @param features   feature rows of this row block
     * @param labels     label rows of this row block
     * @return per-row {@code (workerID, (shiftedPosition, row))} tuples for features and labels
     */
    @Override
    public DataPartitionSparkScheme.Result doPartitioning(int numWorkers, int rblkID, MatrixBlock features, MatrixBlock labels) {
        List<Tuple2<Integer, Tuple2<Long, MatrixBlock>>> pfs = partition(rblkID, features);
        List<Tuple2<Integer, Tuple2<Long, MatrixBlock>>> pls = partition(rblkID, labels);
        return new DataPartitionSparkScheme.Result(pfs, pls);
    }

    /**
     * Slices {@code mb} into single rows and tags each row with its target worker and its
     * shifted position from the global permutation.
     *
     * @param rblkID 1-based row-block index used to fetch the matching permutation block
     * @param mb     the row block to split
     * @return one {@code (workerID, (shiftedPosition, row))} tuple per row of {@code mb}, in row
     *         order; empty if {@code mb} has no rows
     */
    private List<Tuple2<Integer, Tuple2<Long, MatrixBlock>>> partition(int rblkID, MatrixBlock mb) {
        // Wildcard cast keeps this independent of how _globalPerms is parameterized.
        PartitionedBroadcast<?> globalPerm = (PartitionedBroadcast<?>) _globalPerms.get(0);
        MatrixBlock partialPerm = (MatrixBlock) globalPerm.getBlock(rblkID, 1);
        return IntStream.range(0, mb.getNumRows())
            .mapToObj(r -> {
                // sliceMatrixBlock takes 1-based inclusive bounds, hence r + 1.
                MatrixBlock rowMB = ParamservUtils.sliceMatrixBlock(mb, r + 1, r + 1);
                long shiftedPosition = (long) partialPerm.getValue(r, 0);
                int workerID = lookupWorker(shiftedPosition);
                return new Tuple2<>(workerID, new Tuple2<>(shiftedPosition, rowMB));
            })
            .collect(Collectors.toList());
    }

    /**
     * Reads the worker ID stored in the worker indicator at the given shifted position.
     *
     * <p>Treats {@code shiftedPosition} as a 0-based global row index (consistent with the
     * {@code position / BLOCK_SIZE + 1} block-id arithmetic). NOTE(review): confirm the
     * permutation values are 0-based. The row inside the indicator block is
     * {@code position % BLOCK_SIZE}. The previous code used {@code position / BLOCK_SIZE}, which
     * sent every row of a shuffled block to the same worker.
     */
    private int lookupWorker(long shiftedPosition) {
        int indicatorBlkID = (int) (shiftedPosition / BLOCK_SIZE + 1);
        int rowInBlk = (int) (shiftedPosition % BLOCK_SIZE);
        MatrixBlock indicator = (MatrixBlock) _workerIndicator.getBlock(indicatorBlkID, 1);
        return (int) indicator.getValue(rowInBlk, 0);
    }
}

