/*
 * Decompiled with CFR 0.152.
 */
package org.tugraz.sysds.runtime.instructions.fed;

import java.util.ArrayList;
import java.util.Map;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import org.apache.commons.lang3.tuple.ImmutablePair;
import org.apache.commons.lang3.tuple.MutablePair;
import org.apache.commons.lang3.tuple.Pair;
import org.tugraz.sysds.runtime.DMLRuntimeException;
import org.tugraz.sysds.runtime.controlprogram.caching.MatrixObject;
import org.tugraz.sysds.runtime.controlprogram.context.ExecutionContext;
import org.tugraz.sysds.runtime.controlprogram.federated.FederatedData;
import org.tugraz.sysds.runtime.controlprogram.federated.FederatedRange;
import org.tugraz.sysds.runtime.controlprogram.federated.FederatedRequest;
import org.tugraz.sysds.runtime.controlprogram.federated.FederatedResponse;
import org.tugraz.sysds.runtime.functionobjects.Multiply;
import org.tugraz.sysds.runtime.functionobjects.Plus;
import org.tugraz.sysds.runtime.instructions.InstructionUtils;
import org.tugraz.sysds.runtime.instructions.cp.CPOperand;
import org.tugraz.sysds.runtime.instructions.fed.BinaryFEDInstruction;
import org.tugraz.sysds.runtime.instructions.fed.FEDInstruction;
import org.tugraz.sysds.runtime.matrix.data.MatrixBlock;
import org.tugraz.sysds.runtime.matrix.operators.AggregateBinaryOperator;
import org.tugraz.sysds.runtime.matrix.operators.Operator;
import org.tugraz.sysds.runtime.util.CommonThreadPool;

/**
 * Federated instruction for the aggregate-binary opcode {@code ba+*} (matrix multiplication).
 *
 * <p>{@link #processInstruction(ExecutionContext)} picks one of three strategies:
 * <ul>
 *   <li>left operand federated and right operand has a single column: one matrix-vector
 *       request per federated partition;</li>
 *   <li>right operand federated and left operand has a single row: one vector-matrix
 *       request per federated partition;</li>
 *   <li>otherwise a general matrix-matrix product: one operand stays federated and the other
 *       is read locally and shipped to the workers one column at a time (federated operand on
 *       the left) or one row at a time (federated operand on the right).</li>
 * </ul>
 * In every strategy the partial results returned by the partitions are <em>added</em> into
 * the output. A partition that covers only part of the inner (shared) dimension contributes
 * a partial sum to output cells that other partitions contribute to as well.
 */
public class AggregateBinaryFEDInstruction
extends BinaryFEDInstruction {
    public AggregateBinaryFEDInstruction(Operator op, CPOperand in1, CPOperand in2, CPOperand out, String opcode, String istr) {
        super(FEDInstruction.FEDType.AggregateBinary, op, in1, in2, out, opcode, istr);
    }

    /**
     * Parses a {@code ba+*} instruction string.
     *
     * <p>Expected fields after the opcode: input 1, input 2, output, and an integer {@code k}.
     * {@code k} is forwarded to {@code InstructionUtils.getMatMultOperator}; presumably it is
     * the degree of parallelism (TODO confirm).
     *
     * @param str the serialized instruction
     * @return the parsed instruction
     * @throws DMLRuntimeException if the opcode is not {@code ba+*} or the field count is wrong
     */
    public static AggregateBinaryFEDInstruction parseInstruction(String str) {
        String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
        String opcode = parts[0];
        if (!opcode.equalsIgnoreCase("ba+*")) {
            throw new DMLRuntimeException("AggregateBinaryInstruction.parseInstruction():: Unknown opcode " + opcode);
        }
        InstructionUtils.checkNumFields(parts, 4);
        CPOperand in1 = new CPOperand(parts[1]);
        CPOperand in2 = new CPOperand(parts[2]);
        CPOperand out = new CPOperand(parts[3]);
        int k = Integer.parseInt(parts[4]);
        return new AggregateBinaryFEDInstruction(InstructionUtils.getMatMultOperator(k), in1, in2, out, opcode, str);
    }

    /**
     * Dispatches to the matrix-vector, vector-matrix or general matrix-matrix federated path.
     * A local operand obtained via {@code acquireRead()} is released again even if the
     * federated computation throws.
     */
    @Override
    public void processInstruction(ExecutionContext ec) {
        MatrixObject mo1 = ec.getMatrixObject(this.input1.getName());
        MatrixObject mo2 = ec.getMatrixObject(this.input2.getName());
        MatrixObject out = ec.getMatrixObject(this.output.getName());
        AggregateBinaryOperator ab_op = (AggregateBinaryOperator)this._optr;
        if (mo1.isFederated() && mo2.getNumColumns() == 1L) {
            // federated matrix %*% column vector
            MatrixBlock vector = mo2.acquireRead();
            try {
                AggregateBinaryFEDInstruction.federatedAggregateBinaryMV(mo1, vector, out, ab_op, true);
            }
            finally {
                mo2.release();
            }
        } else if (mo2.isFederated() && mo1.getNumRows() == 1L) {
            // row vector %*% federated matrix
            MatrixBlock vector = mo1.acquireRead();
            try {
                AggregateBinaryFEDInstruction.federatedAggregateBinaryMV(mo2, vector, out, ab_op, false);
            }
            finally {
                mo1.release();
            }
        } else {
            AggregateBinaryFEDInstruction.federatedAggregateBinary(mo1, mo2, out);
        }
    }

    /**
     * General federated matrix-matrix multiplication {@code out = mo1 %*% mo2}.
     *
     * <p>{@code distributeCols == true} means {@code mo1} stays federated. {@code mo2} is read
     * locally and each of its columns is sent as a matrix-vector request. {@code false} means
     * {@code mo2} stays federated. {@code mo1} is read locally and each of its rows is sent as
     * a vector-matrix request. When both operands are federated, the smaller one (by cell
     * count) is read locally.
     *
     * @throws DMLRuntimeException if neither input is federated, or the result dimensions
     *         exceed {@code Integer.MAX_VALUE}
     */
    private static void federatedAggregateBinary(MatrixObject mo1, MatrixObject mo2, MatrixObject out) {
        if (!mo1.isFederated() && !mo2.isFederated()) {
            throw new DMLRuntimeException("Federated matrix multiplication requires at least one federated input");
        }
        // Validate before doing any remote work; the result block is int-indexed.
        if (mo1.getNumRows() > Integer.MAX_VALUE || mo2.getNumColumns() > Integer.MAX_VALUE) {
            throw new DMLRuntimeException("Federated matrix is too large for federated distribution");
        }
        boolean distributeCols;
        if (mo1.isFederated() && mo2.isFederated()) {
            // both federated: keep the larger one remote, read the smaller one locally
            distributeCols = mo2.getNumColumns() * mo2.getNumRows() < mo1.getNumColumns() * mo1.getNumRows();
        } else {
            // exactly one side is federated: that side must be the one whose mapping is used
            distributeCols = mo1.isFederated();
        }
        MatrixObject fedMo = distributeCols ? mo1 : mo2;
        MatrixObject localMo = distributeCols ? mo2 : mo1;
        Map<FederatedRange, FederatedData> mapping = fedMo.getFedMapping();
        ArrayList<Pair<FederatedRange, MatrixBlock>> results = new ArrayList<>();
        MatrixBlock matrixBlock = localMo.acquireRead();
        try {
            ArrayList<FederatedMMTask> tasks = new ArrayList<>();
            for (Map.Entry<FederatedRange, FederatedData> fedMap : mapping.entrySet()) {
                // each task fills its own pair; the list keeps them for combining afterwards
                MutablePair<FederatedRange, MatrixBlock> resultPair = new MutablePair<>();
                tasks.add(new FederatedMMTask(fedMap.getKey(), fedMap.getValue(), resultPair, matrixBlock, distributeCols));
                results.add(resultPair);
            }
            ExecutorService pool = CommonThreadPool.get(mapping.size());
            CommonThreadPool.invokeAndShutdown(pool, tasks);
        }
        finally {
            localMo.release();
        }
        out.acquireModify(AggregateBinaryFEDInstruction.combinePartialMMResults(results, (int)mo1.getNumRows(), (int)mo2.getNumColumns()));
        out.release();
    }

    /**
     * Sums the per-partition matrix-matrix results into a new {@code rows x cols} block.
     * Each partial block is added at the begin offset of its range. Ranges may overlap when
     * partitions split the inner dimension, so plain copying would drop contributions.
     */
    private static MatrixBlock combinePartialMMResults(ArrayList<Pair<FederatedRange, MatrixBlock>> results, int rows, int cols) {
        MatrixBlock resultBlock = new MatrixBlock(rows, cols, false);
        for (Pair<FederatedRange, MatrixBlock> partialResult : results) {
            int[] dimsLower = partialResult.getLeft().getBeginDimsInt();
            AggregateBinaryFEDInstruction.addPartialResult(resultBlock, partialResult.getRight(), dimsLower[0], dimsLower[1]);
        }
        resultBlock.recomputeNonZeros();
        return resultBlock;
    }

    /**
     * Adds {@code partial} cell by cell into {@code target}. Cell (0, 0) of {@code partial}
     * lands at ({@code rowOffset}, {@code colOffset}) of {@code target}.
     */
    private static void addPartialResult(MatrixBlock target, MatrixBlock partial, int rowOffset, int colOffset) {
        for (int r = 0; r < partial.getNumRows(); ++r) {
            int targetRow = rowOffset + r;
            for (int c = 0; c < partial.getNumColumns(); ++c) {
                int targetCol = colOffset + c;
                target.quickSetValue(targetRow, targetCol, target.quickGetValue(targetRow, targetCol) + partial.quickGetValue(r, c));
            }
        }
    }

    /**
     * Federated matrix-vector ({@code fedMo %*% vector}) or vector-matrix
     * ({@code vector %*% fedMo}) multiplication.
     *
     * <p>One request is sent per federated partition, with the vector sliced to that
     * partition. The partial results are summed into {@code output}, and the output's
     * dimensions and non-zero count are set.
     *
     * @param fedMo          the federated matrix operand
     * @param vector         the local vector operand
     * @param output         receives the result
     * @param op             must be a multiply/plus aggregate-binary operator
     * @param matrixVectorOp {@code true} for {@code fedMo %*% vector} (column vector),
     *                       {@code false} for {@code vector %*% fedMo} (row vector)
     * @throws DMLRuntimeException on an unsupported operator, an oversized result, or any
     *         failed federated request
     */
    public static void federatedAggregateBinaryMV(MatrixObject fedMo, MatrixBlock vector, MatrixObject output, AggregateBinaryOperator op, boolean matrixVectorOp) {
        if (!(op.binaryFn instanceof Multiply) || !(op.aggOp.increOp.fn instanceof Plus)) {
            throw new DMLRuntimeException("Only matrix-vector is supported for federated binary aggregation");
        }
        long resultRows = matrixVectorOp ? fedMo.getNumRows() : 1L;
        long resultCols = matrixVectorOp ? 1L : fedMo.getNumColumns();
        if (resultRows > Integer.MAX_VALUE || resultCols > Integer.MAX_VALUE) {
            throw new DMLRuntimeException("Federated matrix is too large for federated distribution");
        }
        output.getDataCharacteristics().setRows(resultRows).setCols(resultCols);
        MatrixBlock resultBlock = new MatrixBlock((int)resultRows, (int)resultCols, false);

        // Send all requests first so the partitions can work concurrently, then collect.
        ArrayList<Pair<FederatedRange, Future<FederatedResponse>>> pendingResponses = new ArrayList<>();
        for (Map.Entry<FederatedRange, FederatedData> entry : fedMo.getFedMapping().entrySet()) {
            FederatedRange range = entry.getKey();
            Future<FederatedResponse> future = AggregateBinaryFEDInstruction.executeMVMultiply(range, entry.getValue(), vector, matrixVectorOp);
            pendingResponses.add(new ImmutablePair<>(range, future));
        }
        try {
            for (Pair<FederatedRange, Future<FederatedResponse>> pending : pendingResponses) {
                FederatedResponse federatedResponse = pending.getRight().get();
                if (!federatedResponse.isSuccessful()) {
                    throw new DMLRuntimeException("Federated binary aggregation failed: " + federatedResponse.getErrorMessage());
                }
                AggregateBinaryFEDInstruction.combinePartialMVResults(pending.getLeft(), federatedResponse, resultBlock, matrixVectorOp);
            }
        }
        catch (InterruptedException e) {
            // preserve the interrupt status for callers further up the stack
            Thread.currentThread().interrupt();
            throw new DMLRuntimeException("Federated binary aggregation failed", e);
        }
        catch (DMLRuntimeException e) {
            throw e;
        }
        catch (Exception e) {
            throw new DMLRuntimeException("Federated binary aggregation failed", e);
        }
        long nnz = resultBlock.recomputeNonZeros();
        output.acquireModify(resultBlock);
        output.getDataCharacteristics().setNonZeros(nnz);
        output.release();
    }

    /**
     * Adds one partition's matrix-vector / vector-matrix result into {@code resultBlock}.
     * Matrix-vector results are offset by the partition's first row. Vector-matrix results
     * are offset by its first column.
     */
    private static void combinePartialMVResults(FederatedRange range, FederatedResponse federatedResponse, MatrixBlock resultBlock, boolean matrixVectorOp) {
        int[] beginDims = range.getBeginDimsInt();
        MatrixBlock mb = (MatrixBlock)federatedResponse.getData();
        int rowOffset = matrixVectorOp ? beginDims[0] : 0;
        int colOffset = matrixVectorOp ? 0 : beginDims[1];
        AggregateBinaryFEDInstruction.addPartialResult(resultBlock, mb, rowOffset, colOffset);
    }

    /**
     * Sends one {@code MATVECMULT} request for a single federated partition.
     *
     * @param range          index range of the partition inside the federated matrix
     *                       (begin inclusive, end exclusive, as used by the slicing below)
     * @param fedData        handle of the partition; must be initialized
     * @param vector         the full vector, or a vector already sized to the partition
     *                       (then it is sent as-is)
     * @param matrixVectorOp {@code true}: partition %*% column vector, with the vector sliced
     *                       by the partition's column range. {@code false}: row vector %*%
     *                       partition, with the vector sliced by the partition's row range.
     * @return the pending response of the federated worker
     * @throws DMLRuntimeException if {@code fedData} is not initialized
     */
    private static Future<FederatedResponse> executeMVMultiply(FederatedRange range, FederatedData fedData, MatrixBlock vector, boolean matrixVectorOp) {
        if (!fedData.isInitialized()) {
            throw new DMLRuntimeException("Not all FederatedData was initialized for federated matrix");
        }
        int[] beginDimsInt = range.getBeginDimsInt();
        int[] endDimsInt = range.getEndDimsInt();
        MatrixBlock vectorSlice;
        if (!matrixVectorOp) {
            int length = endDimsInt[0] - beginDimsInt[0];
            if (vector.getNumColumns() == length) {
                vectorSlice = vector;
            } else {
                vectorSlice = new MatrixBlock(1, length, false);
                vector.slice(0, 0, beginDimsInt[0], endDimsInt[0] - 1, vectorSlice);
            }
        } else {
            int length = endDimsInt[1] - beginDimsInt[1];
            if (vector.getNumRows() == length) {
                vectorSlice = vector;
            } else {
                vectorSlice = new MatrixBlock(length, 1, false);
                vector.slice(beginDimsInt[1], endDimsInt[1] - 1, 0, 0, vectorSlice);
            }
        }
        ArrayList<Object> params = new ArrayList<Object>();
        params.add(vectorSlice);
        params.add(matrixVectorOp);
        return fedData.executeFederatedOperation(new FederatedRequest(FederatedRequest.FedMethod.MATVECMULT, params), true);
    }

    /**
     * Computes the contribution of one federated partition to a matrix-matrix product. The
     * computation is a sequence of blocking MV (or VM) requests. The contribution and its
     * target range in the output are stored into the supplied pair.
     */
    private static class FederatedMMTask
    implements Callable<Void> {
        private final FederatedRange _range;
        private final FederatedData _data;
        private final MutablePair<FederatedRange, MatrixBlock> _result;
        private final MatrixBlock _otherMatrix;
        private final boolean _distributeCols;

        public FederatedMMTask(FederatedRange range, FederatedData fedData, MutablePair<FederatedRange, MatrixBlock> result, MatrixBlock otherMatrix, boolean distributeCols) {
            this._range = range;
            this._data = fedData;
            this._result = result;
            this._otherMatrix = otherMatrix;
            this._distributeCols = distributeCols;
        }

        @Override
        public Void call() throws Exception {
            if (this._distributeCols) {
                this.executeColWiseMVMultiplication();
            } else {
                this.executeRowWiseVMMultiplications();
            }
            return null;
        }

        /**
         * Handles local %*% federated. For every row of the local matrix, that row's columns
         * aligned with the partition's rows are sent as a vector-matrix request. The result
         * has as many rows as the local matrix and spans the partition's column range.
         */
        private void executeRowWiseVMMultiplications() throws InterruptedException, ExecutionException {
            int[] beginDims = this._range.getBeginDimsInt();
            int[] endDims = this._range.getEndDimsInt();
            int rows = this._otherMatrix.getNumRows();
            int partCols = endDims[1] - beginDims[1];
            MatrixBlock result = new MatrixBlock(rows, partCols, false);
            this._result.setLeft(new FederatedRange(new long[]{0, beginDims[1]}, new long[]{rows, endDims[1]}));
            // reused across iterations; each request completes (get()) before the next slice
            MatrixBlock vec = new MatrixBlock(1, endDims[0] - beginDims[0], false);
            for (int r = 0; r < rows; ++r) {
                this._otherMatrix.slice(r, r, beginDims[0], endDims[0] - 1, vec);
                FederatedResponse response = AggregateBinaryFEDInstruction.executeMVMultiply(this._range, this._data, vec, false).get();
                if (!response.isSuccessful()) {
                    throw new DMLRuntimeException("Federated Matrix-Matrix Multiplication failed: " + response.getErrorMessage());
                }
                result.copy(r, r, 0, partCols - 1, (MatrixBlock)response.getData(), true);
            }
            this._result.setRight(result);
        }

        /**
         * Handles federated %*% local. For every column of the local matrix, that column's
         * rows aligned with the partition's columns are sent as a matrix-vector request. The
         * result spans the partition's row range and has as many columns as the local matrix.
         */
        private void executeColWiseMVMultiplication() throws InterruptedException, ExecutionException {
            int[] beginDims = this._range.getBeginDimsInt();
            int[] endDims = this._range.getEndDimsInt();
            int partRows = endDims[0] - beginDims[0];
            int cols = this._otherMatrix.getNumColumns();
            MatrixBlock result = new MatrixBlock(partRows, cols, false);
            this._result.setLeft(new FederatedRange(new long[]{beginDims[0], 0}, new long[]{endDims[0], cols}));
            // reused across iterations; each request completes (get()) before the next slice
            MatrixBlock vec = new MatrixBlock(endDims[1] - beginDims[1], 1, false);
            for (int c = 0; c < cols; ++c) {
                this._otherMatrix.slice(beginDims[1], endDims[1] - 1, c, c, vec);
                FederatedResponse response = AggregateBinaryFEDInstruction.executeMVMultiply(this._range, this._data, vec, true).get();
                if (!response.isSuccessful()) {
                    throw new DMLRuntimeException("Federated Matrix-Matrix Multiplication failed: " + response.getErrorMessage());
                }
                result.copy(0, partRows - 1, c, c, (MatrixBlock)response.getData(), true);
            }
            this._result.setRight(result);
        }
    }
}

