/*
 * Decompiled with CFR 0.152.
 */
package org.tugraz.sysds.runtime.controlprogram.federated;

import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.util.concurrent.GenericFutureListener;
import java.io.BufferedReader;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.util.Arrays;
import java.util.Map;
import org.apache.commons.lang.exception.ExceptionUtils;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.log4j.Logger;
import org.apache.wink.json4j.JSONObject;
import org.tugraz.sysds.common.Types;
import org.tugraz.sysds.conf.ConfigurationManager;
import org.tugraz.sysds.hops.OptimizerUtils;
import org.tugraz.sysds.parser.DataExpression;
import org.tugraz.sysds.runtime.DMLRuntimeException;
import org.tugraz.sysds.runtime.controlprogram.caching.CacheableData;
import org.tugraz.sysds.runtime.controlprogram.caching.MatrixObject;
import org.tugraz.sysds.runtime.controlprogram.caching.TensorObject;
import org.tugraz.sysds.runtime.controlprogram.federated.FederatedRequest;
import org.tugraz.sysds.runtime.controlprogram.federated.FederatedResponse;
import org.tugraz.sysds.runtime.controlprogram.parfor.util.IDSequence;
import org.tugraz.sysds.runtime.functionobjects.Multiply;
import org.tugraz.sysds.runtime.functionobjects.Plus;
import org.tugraz.sysds.runtime.instructions.cp.Data;
import org.tugraz.sysds.runtime.instructions.cp.ListObject;
import org.tugraz.sysds.runtime.io.IOUtilFunctions;
import org.tugraz.sysds.runtime.matrix.data.InputInfo;
import org.tugraz.sysds.runtime.matrix.data.LibMatrixAgg;
import org.tugraz.sysds.runtime.matrix.data.MatrixBlock;
import org.tugraz.sysds.runtime.matrix.data.OutputInfo;
import org.tugraz.sysds.runtime.matrix.operators.AggregateBinaryOperator;
import org.tugraz.sysds.runtime.matrix.operators.AggregateOperator;
import org.tugraz.sysds.runtime.matrix.operators.AggregateUnaryOperator;
import org.tugraz.sysds.runtime.matrix.operators.ScalarOperator;
import org.tugraz.sysds.runtime.meta.MatrixCharacteristics;
import org.tugraz.sysds.runtime.meta.MetaDataFormat;
import org.tugraz.sysds.utils.JSONHelper;

/**
 * Netty inbound handler that executes incoming {@link FederatedRequest}s against worker-local
 * variables and writes a {@link FederatedResponse} back to the sender.
 *
 * <p>Each request is dispatched on its {@link FederatedRequest.FedMethod}. Objects created by a
 * request (e.g. {@code READ}) are stored in the variable map under an ID drawn from the
 * {@link IDSequence}, and later requests refer to them by that ID. After the response has been
 * written, the channel is closed (see {@link CloseListener}).
 */
public class FederatedWorkerHandler
extends ChannelInboundHandlerAdapter {
    protected static Logger log = Logger.getLogger(FederatedWorkerHandler.class);
    private final IDSequence _seq;
    private final Map<Long, CacheableData<?>> _vars;

    /**
     * @param seq  generator for new variable IDs; also used as the monitor that serializes
     *             request processing in {@link #channelRead}
     * @param vars map from variable ID to the data object stored on this worker
     */
    public FederatedWorkerHandler(IDSequence seq, Map<Long, CacheableData<?>> vars) {
        this._seq = seq;
        this._vars = vars;
    }

    /**
     * Executes one received request and writes the response back on {@code ctx}.
     *
     * <p>Response construction and submission run while holding the monitor of the ID sequence.
     * They are therefore serialized with respect to any other code that synchronizes on that
     * same sequence object.
     *
     * @throws DMLRuntimeException if {@code msg} is not a {@link FederatedRequest}
     */
    @Override
    public void channelRead(ChannelHandlerContext ctx, Object msg) {
        log.debug("Received: " + msg.getClass().getSimpleName());
        if (!(msg instanceof FederatedRequest)) {
            throw new DMLRuntimeException("FederatedWorkerHandler: Received object no instance of `FederatedRequest`.");
        }
        FederatedRequest request = (FederatedRequest) msg;
        FederatedRequest.FedMethod method = request.getMethod();
        log.debug("Received command: " + method.name());
        synchronized (_seq) {
            FederatedResponse response = constructResponse(request);
            if (!response.isSuccessful()) {
                log.error("Method " + method + " failed: " + response.getErrorMessage());
            }
            ctx.writeAndFlush(response).addListener(new CloseListener());
        }
    }

    /**
     * Dispatches the request to the handler for its method.
     *
     * @return the handler's response. Unsupported methods produce an ERROR response. Any
     *         exception thrown by a handler is also converted into an ERROR response that
     *         carries the full stack trace.
     */
    private FederatedResponse constructResponse(FederatedRequest request) {
        FederatedRequest.FedMethod method = request.getMethod();
        try {
            switch (method) {
                case READ:
                    return readMatrix(request);
                case MATVECMULT:
                    return executeMatVecMult(request);
                case TRANSFER:
                    return getVariableData(request);
                case AGGREGATE:
                    return executeAggregation(request);
                case SCALAR:
                    return executeScalarOperation(request);
                default:
                    break;
            }
            String message = String.format("Method %s is not supported.", method);
            return new FederatedResponse(FederatedResponse.Type.ERROR, message);
        }
        catch (Exception exception) {
            return new FederatedResponse(FederatedResponse.Type.ERROR, ExceptionUtils.getFullStackTrace(exception));
        }
    }

    /** READ: expects one parameter, the file name of the matrix to read. */
    private FederatedResponse readMatrix(FederatedRequest request) {
        checkNumParams(request.getNumParams(), 1);
        String filename = (String) request.getParam(0);
        return readMatrix(filename);
    }

    /**
     * Creates a matrix object for {@code filename}, using rows, cols and format from its
     * metadata file. Registers the object under a fresh variable ID.
     *
     * @return SUCCESS carrying the new variable ID, or ERROR if the metadata could not be parsed
     * @throws DMLRuntimeException wrapping any failure while reading the metadata file
     */
    private FederatedResponse readMatrix(String filename) {
        MatrixCharacteristics mc = new MatrixCharacteristics();
        mc.setBlocksize(ConfigurationManager.getBlocksize());
        MatrixObject mo = new MatrixObject(Types.ValueType.FP64, filename);
        OutputInfo oi = null;
        InputInfo ii = null;
        try {
            String mtdname = DataExpression.getMTDFileName(filename);
            Path path = new Path(mtdname);
            try (FileSystem fs = IOUtilFunctions.getFileSystem(mtdname);
                 BufferedReader br = new BufferedReader(new InputStreamReader(fs.open(path)))) {
                JSONObject mtd = JSONHelper.parse(br);
                if (mtd == null) {
                    return new FederatedResponse(FederatedResponse.Type.ERROR, "Could not parse metadata file");
                }
                mc.setRows(mtd.getLong("rows"));
                mc.setCols(mtd.getLong("cols"));
                String format = mtd.getString("format");
                oi = OutputInfo.outputInfoFromStringExternal(format);
                ii = OutputInfo.getMatchingInputInfo(oi);
            }
        }
        catch (Exception ex) {
            throw new DMLRuntimeException(ex);
        }
        mo.setMetaData(new MetaDataFormat(mc, oi, ii));
        mo.acquireRead();
        try {
            mo.refreshMetaData();
        }
        finally {
            // Release even if refreshMetaData() fails, so the read acquisition is not left outstanding.
            mo.release();
        }
        long id = _seq.getNextID();
        _vars.put(id, mo);
        return new FederatedResponse(FederatedResponse.Type.SUCCESS, id);
    }

    /** MATVECMULT: expects (MatrixBlock vector, Boolean isMatVecMult, Long varID). */
    private FederatedResponse executeMatVecMult(FederatedRequest request) {
        checkNumParams(request.getNumParams(), 3);
        MatrixBlock vector = (MatrixBlock) request.getParam(0);
        boolean isMatVecMult = (Boolean) request.getParam(1);
        long varID = (Long) request.getParam(2);
        return executeMatVecMult(varID, vector, isMatVecMult);
    }

    /**
     * Multiplies the stored matrix with {@code vector}.
     *
     * @param isMatVecMult if true computes {@code matrix %*% vector}, otherwise {@code vector %*% matrix}
     * @return SUCCESS with the product block, or ERROR if {@code varID} is unknown
     */
    private FederatedResponse executeMatVecMult(long varID, MatrixBlock vector, boolean isMatVecMult) {
        MatrixObject matTo = (MatrixObject) _vars.get(varID);
        if (matTo == null) {
            return unknownVariable(varID);
        }
        MatrixBlock matBlock1 = (MatrixBlock) matTo.acquireReadAndRelease();
        AggregateBinaryOperator ab_op = new AggregateBinaryOperator(Multiply.getMultiplyFnObject(),
            new AggregateOperator(0.0, Plus.getPlusFnObject()));
        MatrixBlock result = isMatVecMult
            ? matBlock1.aggregateBinaryOperations(matBlock1, vector, new MatrixBlock(), ab_op)
            : vector.aggregateBinaryOperations(vector, matBlock1, new MatrixBlock(), ab_op);
        return new FederatedResponse(FederatedResponse.Type.SUCCESS, result);
    }

    /** TRANSFER: expects one parameter, the Long variable ID to send back. */
    private FederatedResponse getVariableData(FederatedRequest request) {
        checkNumParams(request.getNumParams(), 1);
        long varID = (Long) request.getParam(0);
        return getVariableData(varID);
    }

    /**
     * Returns the data of the stored variable.
     *
     * @return SUCCESS with the tensor/matrix block or list data, or ERROR for an unknown ID or an
     *         unsupported data type
     */
    private FederatedResponse getVariableData(long varID) {
        Data dataObject = _vars.get(varID);
        if (dataObject == null) {
            return unknownVariable(varID);
        }
        switch (dataObject.getDataType()) {
            case TENSOR:
                return new FederatedResponse(FederatedResponse.Type.SUCCESS, ((TensorObject) dataObject).acquireReadAndRelease());
            case MATRIX:
                return new FederatedResponse(FederatedResponse.Type.SUCCESS, ((MatrixObject) dataObject).acquireReadAndRelease());
            case LIST:
                return new FederatedResponse(FederatedResponse.Type.SUCCESS, ((ListObject) dataObject).getData());
            default:
                break;
        }
        return new FederatedResponse(FederatedResponse.Type.ERROR, "FederatedWorkerHandler: Not possible to send datatype " + dataObject.getDataType().name());
    }

    /** AGGREGATE: expects (AggregateUnaryOperator operator, Long varID). */
    private FederatedResponse executeAggregation(FederatedRequest request) {
        checkNumParams(request.getNumParams(), 2);
        AggregateUnaryOperator operator = (AggregateUnaryOperator) request.getParam(0);
        long varID = (Long) request.getParam(1);
        return executeAggregation(varID, operator);
    }

    /**
     * Applies a unary aggregate to the stored matrix.
     *
     * <p>The output block is allocated with extra rows or columns for the operator's correction,
     * when it has one. These are dropped again after aggregation. The matrix is released in a
     * finally block on every path, including the error return.
     *
     * @return SUCCESS with the aggregated block, or ERROR for an unknown ID, a non-matrix
     *         variable, or a failure inside the aggregation itself
     */
    private FederatedResponse executeAggregation(long varID, AggregateUnaryOperator operator) {
        Data dataObject = _vars.get(varID);
        if (dataObject == null) {
            return unknownVariable(varID);
        }
        if (dataObject.getDataType() != Types.DataType.MATRIX) {
            return new FederatedResponse(FederatedResponse.Type.ERROR, "FederatedWorkerHandler: Aggregation only supported for matrices, not for " + dataObject.getDataType().name());
        }
        MatrixObject matrixObject = (MatrixObject) dataObject;
        MatrixBlock matrixBlock = matrixObject.acquireRead();
        try {
            MatrixCharacteristics mc = new MatrixCharacteristics();
            operator.indexFn.computeDimension(matrixObject.getDataCharacteristics(), mc);
            int outNumRows = (int) mc.getRows();
            int outNumCols = (int) mc.getCols();
            if (operator.aggOp.existsCorrection()) {
                int numMissing = operator.aggOp.correction.getNumRemovedRowsColumns();
                if (operator.aggOp.correction.isRows()) {
                    outNumRows += numMissing;
                } else {
                    outNumCols += numMissing;
                }
            }
            MatrixBlock ret = new MatrixBlock(outNumRows, outNumCols, operator.aggOp.initialValue);
            try {
                LibMatrixAgg.aggregateUnaryMatrix(matrixBlock, ret, operator);
            }
            catch (Exception e) {
                return new FederatedResponse(FederatedResponse.Type.ERROR, "FederatedWorkerHandler: " + e);
            }
            ret.dropLastRowsOrColumns(operator.aggOp.correction);
            return new FederatedResponse(FederatedResponse.Type.SUCCESS, ret);
        }
        finally {
            matrixObject.release();
        }
    }

    /** SCALAR: expects (ScalarOperator operator, Long varID). */
    private FederatedResponse executeScalarOperation(FederatedRequest request) {
        checkNumParams(request.getNumParams(), 2);
        ScalarOperator operator = (ScalarOperator) request.getParam(0);
        long varID = (Long) request.getParam(1);
        return executeScalarOperation(varID, operator);
    }

    /**
     * Applies a scalar operator to the stored matrix and returns the result in a new block.
     * The stored matrix is released in a finally block.
     *
     * @return SUCCESS with the result block, or ERROR for an unknown ID or a non-matrix variable
     */
    private FederatedResponse executeScalarOperation(long varID, ScalarOperator operator) {
        Data dataObject = _vars.get(varID);
        if (dataObject == null) {
            return unknownVariable(varID);
        }
        if (dataObject.getDataType() != Types.DataType.MATRIX) {
            return new FederatedResponse(FederatedResponse.Type.ERROR, "FederatedWorkerHandler: ScalarOperator dont support " + dataObject.getDataType().name());
        }
        MatrixObject matrixObject = (MatrixObject) dataObject;
        MatrixBlock inBlock = matrixObject.acquireRead();
        try {
            MatrixBlock retBlock = inBlock.scalarOperations(operator, new MatrixBlock());
            return new FederatedResponse(FederatedResponse.Type.SUCCESS, retBlock);
        }
        finally {
            matrixObject.release();
        }
    }

    /**
     * Wraps {@code result} in a new binary-block matrix object and registers it under a fresh
     * variable ID. Not currently called from within this class.
     *
     * @return SUCCESS carrying the new variable ID
     */
    private FederatedResponse createMatrixObject(MatrixBlock result) {
        MatrixObject resTo = new MatrixObject(Types.ValueType.FP64, OptimizerUtils.getUniqueTempFileName());
        MetaDataFormat metadata = new MetaDataFormat(
            new MatrixCharacteristics(result.getNumRows(), result.getNumColumns()),
            OutputInfo.BinaryBlockOutputInfo, InputInfo.BinaryBlockInputInfo);
        resTo.setMetaData(metadata);
        resTo.acquireModify(result);
        resTo.release();
        long result_var = _seq.getNextID();
        _vars.put(result_var, resTo);
        return new FederatedResponse(FederatedResponse.Type.SUCCESS, result_var);
    }

    /** Builds the ERROR response returned when a request references a variable ID not in the map. */
    private static FederatedResponse unknownVariable(long varID) {
        return new FederatedResponse(FederatedResponse.Type.ERROR, "FederatedWorkerHandler: Unknown variable ID " + varID);
    }

    /**
     * Checks that a request carries one of the allowed parameter counts.
     *
     * @throws DMLRuntimeException if {@code actual} matches none of {@code expected}
     */
    private static void checkNumParams(int actual, int ... expected) {
        if (Arrays.stream(expected).anyMatch(x -> x == actual)) {
            return;
        }
        throw new DMLRuntimeException("FederatedWorkerHandler: Received wrong amount of params: expected=" + Arrays.toString(expected) + ", actual=" + actual);
    }

    /** Logs exceptions from the pipeline (e.g. a non-request message) and closes the connection. */
    @Override
    public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
        log.error("FederatedWorkerHandler: exception in channel pipeline", cause);
        ctx.close();
    }

    /** Closes the channel once the response write has completed, whether or not it succeeded. */
    private static class CloseListener
    implements ChannelFutureListener {
        private CloseListener() {
        }

        @Override
        public void operationComplete(ChannelFuture channelFuture) throws InterruptedException, DMLRuntimeException {
            // Close first so a failed write does not leak the connection. Do not block on the
            // close future: listeners may run on the channel's I/O thread, where Netty advises
            // against await()/sync().
            channelFuture.channel().close();
            if (!channelFuture.isSuccess()) {
                throw new DMLRuntimeException("Federated Worker Write failed");
            }
        }
    }
}

