/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.runtime.instructions.fed;

import java.util.Arrays;
import java.util.stream.Collectors;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.sysds.common.Types;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.federated.FederatedRange;
import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
import org.apache.sysds.runtime.controlprogram.federated.FederationMap;
import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.BooleanObject;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.instructions.fed.FEDInstruction;
import org.apache.sysds.runtime.instructions.fed.UnaryFEDInstruction;
import org.apache.sysds.runtime.lineage.LineageItem;
import org.apache.sysds.runtime.lineage.LineageItemUtils;
import org.apache.sysds.runtime.matrix.operators.Operator;

/**
 * Federated version of the {@code rshape} (matrix reshape) instruction.
 *
 * <p>Every federated worker reshapes its own partition locally using worker-specific target
 * dimensions; afterwards the federated ranges are rewritten to describe the reshaped partitions.
 * Only matrix outputs are supported.
 */
public class ReshapeFEDInstruction
extends UnaryFEDInstruction {
    /** Operand delimiter of serialized instruction strings (presumably mirrors {@code Lop.OPERAND_DELIMITOR}). */
    private static final String OPERAND_DELIM = "\u00b0";
    /**
     * Positions of the rows/cols operands in the raw (unparsed) instruction string. The raw split is
     * offset by one from the parsed parts used in {@link #parseInstruction(String)} (a leading
     * exec-type field, presumably), so parsed parts[2]/parts[3] correspond to raw 3/4.
     */
    private static final int RAW_ROWS_POS = 3;
    private static final int RAW_COLS_POS = 4;

    private final CPOperand _opRows;
    private final CPOperand _opCols;
    private final CPOperand _opDims;
    private final CPOperand _opByRow;

    private ReshapeFEDInstruction(Operator op, CPOperand in1, CPOperand in2, CPOperand in3, CPOperand in4, CPOperand in5, CPOperand out, String opcode, String istr) {
        super(FEDInstruction.FEDType.Reshape, op, in1, out, opcode, istr);
        this._opRows = in2;
        this._opCols = in3;
        this._opDims = in4;
        this._opByRow = in5;
    }

    /**
     * Parses a serialized {@code rshape} instruction.
     *
     * @param str the serialized instruction
     * @return the parsed federated reshape instruction
     * @throws DMLRuntimeException if the opcode is not {@code rshape}
     */
    public static ReshapeFEDInstruction parseInstruction(String str) {
        String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
        InstructionUtils.checkNumFields(parts, 6, 7);
        String opcode = parts[0];
        // reject foreign opcodes before constructing any operands
        if (!opcode.equalsIgnoreCase("rshape")) {
            throw new DMLRuntimeException("Unknown opcode while parsing an ReshapeInstruction: " + str);
        }
        CPOperand in1 = new CPOperand(parts[1]);
        CPOperand in2 = new CPOperand(parts[2]);
        CPOperand in3 = new CPOperand(parts[3]);
        CPOperand in4 = new CPOperand(parts[4]);
        CPOperand in5 = new CPOperand(parts[5]);
        CPOperand out = new CPOperand(parts[6]);
        return new ReshapeFEDInstruction(new Operator(true), in1, in2, in3, in4, in5, out, opcode, str);
    }

    /**
     * Reshapes the federated input matrix at all workers and derives the output federation map.
     *
     * @param ec the execution context holding the input/output variables
     * @throws DMLRuntimeException if the output is not a matrix, the input is not federated, the
     *         total cell counts differ, or some worker partition cannot be split into whole output
     *         rows (byrow) / columns (by column)
     */
    @Override
    public void processInstruction(ExecutionContext ec) {
        if (this.output.getDataType() != Types.DataType.MATRIX) {
            throw new DMLRuntimeException("Federated Reshape Instruction only supports matrix as output.");
        }

        MatrixObject mo1 = ec.getMatrixObject(this.input1);
        boolean byRow = ((BooleanObject) ec.getScalarInput(this._opByRow.getName(), Types.ValueType.BOOLEAN, this._opByRow.isLiteral())).getBooleanValue();
        int rows = (int) ec.getScalarInput(this._opRows).getLongValue();
        int cols = (int) ec.getScalarInput(this._opCols).getLongValue();

        if (!mo1.isFederated()) {
            throw new DMLRuntimeException("Federated Rshape: Federated input expected, but invoked w/ " + mo1.isFederated());
        }
        // widen before multiplying: an int product overflows for large reshapes
        if (mo1.getNumColumns() * mo1.getNumRows() != (long) rows * cols) {
            throw new DMLRuntimeException("Reshape matrix requires consistent numbers of input/output cells (" + mo1.getNumRows() + ":" + mo1.getNumColumns() + ", " + rows + ":" + cols + ").");
        }
        // each worker's cell count must be a multiple of the fixed output dimension
        long alignDim = byRow ? cols : rows;
        boolean isNotAligned = Arrays.stream(mo1.getFedMapping().getFederatedRanges())
            .anyMatch(range -> range.getSize() % alignDim != 0L);
        if (isNotAligned) {
            throw new DMLRuntimeException("Reshape matrix requires consistent numbers of input/output cells for each worker.");
        }

        String[] newInstStrings = getNewInstString(mo1, this.instString, rows, cols, byRow);
        long id = FederationUtils.getNextFedDataID();
        // presumably creates the output variable under `id` at each worker before the reshape runs
        FederatedRequest putVar = new FederatedRequest(FederatedRequest.RequestType.PUT_VAR, id,
            new Object[]{mo1.getMetaData().getDataCharacteristics(), mo1.getDataType()});
        FederatedRequest[] fr1 = FederationUtils.callInstruction(newInstStrings, this.output, id,
            new CPOperand[]{this.input1}, new long[]{mo1.getFedMapping().getID()},
            InstructionUtils.getExecType(this.instString));
        mo1.getFedMapping().execute(this.getTID(), true, putVar);
        mo1.getFedMapping().execute(this.getTID(), true, fr1, new FederatedRequest[0]);

        // NOTE(review): reshapedFedMap is mo1's own federation map; if getFederatedRanges() returns the
        // live range objects (the loop relies on that), the input's ranges are rewritten in place and
        // copyWithNewID is only applied afterwards — confirm this side effect on mo1 is intended.
        FederationMap reshapedFedMap = mo1.getFedMapping();
        FederatedRange[] ranges = reshapedFedMap.getFederatedRanges();
        for (int i = 0; i < ranges.length; ++i) {
            FederatedRange range = ranges[i];
            long cells = range.getSize(); // read before the dims below are rewritten
            long newRows = byRow ? cells / cols : rows;
            long newCols = byRow ? cols : cells / rows;
            // a range starts at 0 if it did before (or is the first one), otherwise right after its predecessor
            range.setBeginDim(0, (range.getBeginDims()[0] == 0L || i == 0) ? 0L : ranges[i - 1].getEndDims()[0]);
            range.setEndDim(0, range.getBeginDims()[0] + newRows);
            range.setBeginDim(1, (range.getBeginDims()[1] == 0L || i == 0) ? 0L : ranges[i - 1].getEndDims()[1]);
            range.setEndDim(1, range.getBeginDims()[1] + newCols);
        }

        MatrixObject out = ec.getMatrixObject(this.output);
        out.getDataCharacteristics().set(rows, cols, (int) mo1.getBlocksize(), mo1.getNnz());
        out.setFedMapping(reshapedFedMap.copyWithNewID(fr1[0].getID()));
    }

    /**
     * Builds one instruction string per federated worker. The dimension that depends on the worker's
     * partition (rows when reshaping by row, cols otherwise) is replaced by
     * {@code partitionCells / otherDimension}. If all partitions have the same cell count, the string
     * is built once and shared by all workers.
     *
     * @param mo1 the federated input matrix
     * @param instString the original serialized instruction
     * @param rows the global number of output rows
     * @param cols the global number of output columns
     * @param byRow whether the reshape is row-wise
     * @return one instruction string per federated worker
     */
    private static String[] getNewInstString(MatrixObject mo1, String instString, int rows, int cols, boolean byRow) {
        FederatedRange[] ranges = mo1.getFedMapping().getFederatedRanges();
        String[] instStrings = new String[mo1.getFedMapping().getSize()];
        long distinctSizes = Arrays.stream(ranges).mapToLong(FederatedRange::getSize).distinct().count();
        int numToBuild = distinctSizes == 1 ? 1 : instStrings.length;
        for (int i = 0; i < numToBuild; ++i) {
            long size = ranges[i].getSize();
            instStrings[i] = byRow
                ? replaceOperandValue(instString, RAW_ROWS_POS, rows, size / cols)
                : replaceOperandValue(instString, RAW_COLS_POS, cols, size / rows);
        }
        if (numToBuild == 1) {
            Arrays.fill(instStrings, instStrings[0]);
        }
        return instStrings;
    }

    /**
     * Returns {@code instString} with only the operand at raw position {@code pos} changed: its
     * leading literal value {@code oldValue} is replaced by {@code newValue}. Other operands (even
     * textually identical ones, e.g. when rows == cols) are left untouched.
     */
    private static String replaceOperandValue(String instString, int pos, long oldValue, long newValue) {
        // limit -1 keeps trailing empty fields so that join() reproduces the original string exactly
        String[] parts = instString.split(OPERAND_DELIM, -1);
        parts[pos] = replaceLeadingValue(parts[pos], String.valueOf(oldValue), String.valueOf(newValue));
        return String.join(OPERAND_DELIM, parts);
    }

    /**
     * Replaces {@code oldToken} at the start of {@code operand} with {@code newToken}, but only when
     * it is the complete leading number (not followed by another digit). Digits elsewhere in the
     * operand, such as in a value-type name like {@code INT64}, are never touched.
     */
    private static String replaceLeadingValue(String operand, String oldToken, String newToken) {
        boolean leadsWithValue = operand.startsWith(oldToken)
            && (operand.length() == oldToken.length() || !Character.isDigit(operand.charAt(oldToken.length())));
        return leadsWithValue ? newToken + operand.substring(oldToken.length()) : operand;
    }

    @Override
    public Pair<String, LineageItem> getLineageItem(ExecutionContext ec) {
        return Pair.of(this.output.getName(), new LineageItem(this.getOpcode(),
            LineageItemUtils.getLineage(ec, this.input1, this._opRows, this._opCols, this._opDims, this._opByRow)));
    }
}

