/*
 * Decompiled with CFR 0.152.
 */
package org.tugraz.sysds.runtime.instructions.gpu;

import java.util.ArrayList;
import jcuda.Pointer;
import org.tugraz.sysds.api.DMLScript;
import org.tugraz.sysds.runtime.DMLRuntimeException;
import org.tugraz.sysds.runtime.controlprogram.caching.MatrixObject;
import org.tugraz.sysds.runtime.controlprogram.context.ExecutionContext;
import org.tugraz.sysds.runtime.functionobjects.SwapIndex;
import org.tugraz.sysds.runtime.instructions.InstructionUtils;
import org.tugraz.sysds.runtime.instructions.cp.CPOperand;
import org.tugraz.sysds.runtime.instructions.gpu.GPUInstruction;
import org.tugraz.sysds.runtime.instructions.gpu.context.ExecutionConfig;
import org.tugraz.sysds.runtime.instructions.gpu.context.GPUContext;
import org.tugraz.sysds.runtime.matrix.data.LibMatrixCUDA;
import org.tugraz.sysds.runtime.matrix.data.LibMatrixCuDNN;
import org.tugraz.sysds.runtime.matrix.data.LibMatrixDNN;
import org.tugraz.sysds.runtime.matrix.data.MatrixBlock;
import org.tugraz.sysds.runtime.matrix.operators.ReorgOperator;
import org.tugraz.sysds.runtime.util.DnnUtils;
import org.tugraz.sysds.utils.GPUStatistics;

public class DnnGPUInstruction
extends GPUInstruction {
    // Input operands. Which of them are populated depends on the constructor
    // that was used; operands a constructor does not receive stay null.
    private CPOperand _input1;
    private CPOperand _input2;
    private CPOperand _input3;
    private CPOperand _input4;
    private CPOperand _input5;
    private CPOperand _input6;
    private CPOperand _input7;
    private CPOperand _input8;
    // Output operands. _output is set by every constructor; _output2.._output5
    // are only set by the multi-output constructors.
    private CPOperand _output;
    private CPOperand _output2;
    private CPOperand _output3;
    private CPOperand _output4;
    private CPOperand _output5;
    // Shape operands; only assigned by the constructors that take stride/padding/shape lists.
    private ArrayList<CPOperand> _input_shape;
    private ArrayList<CPOperand> _filter_shape;
    // Default to empty (typed) lists; replaced by the constructors that take stride/padding lists.
    private ArrayList<CPOperand> _stride = new ArrayList<CPOperand>();
    private ArrayList<CPOperand> _padding = new ArrayList<CPOperand>();
    // Value handed to the constructors; parseInstruction passes either a parsed
    // instruction field or 0.0 depending on the opcode.
    private double _intermediateMemoryBudget = 0.0;

    /**
     * Creates a two-input, single-output DNN GPU instruction.
     *
     * <p>Only the opcodes {@code bias_add}, {@code bias_multiply} and
     * {@code relu_backward} are accepted (exact, case-sensitive comparison).
     *
     * @param in1 first input operand
     * @param in2 second input operand
     * @param out output operand
     * @param opcode instruction opcode
     * @param istr instruction string, forwarded to the superclass constructor
     * @param intermediateMemoryBudget value stored as the intermediate memory budget
     * @throws DMLRuntimeException if {@code opcode} is not one of the accepted opcodes
     */
    public DnnGPUInstruction(CPOperand in1, CPOperand in2, CPOperand out, String opcode, String istr, double intermediateMemoryBudget) {
        super(new ReorgOperator(SwapIndex.getSwapIndexFnObject()), opcode, istr);
        if (!opcode.equals("bias_add") && !opcode.equals("bias_multiply") && !opcode.equals("relu_backward")) {
            throw new DMLRuntimeException("Incorrect usage. Expected the opcode to be bias_add or bias_multiply or relu_backward, but found " + opcode);
        }
        _gputype = GPUInstruction.GPUINSTRUCTION_TYPE.Dnn;
        _input1 = in1;
        _input2 = in2;
        _output = out;
        _intermediateMemoryBudget = intermediateMemoryBudget;
    }

    /**
     * Creates a DNN GPU instruction with six inputs and two outputs.
     *
     * <p>No opcode validation is performed here. In this file the constructor
     * is invoked by {@code parseInstruction} for the {@code lstm} opcode.
     *
     * @param in1 first input operand
     * @param in2 second input operand
     * @param in3 third input operand
     * @param in4 fourth input operand
     * @param in5 fifth input operand
     * @param in6 sixth input operand
     * @param out first output operand
     * @param out2 second output operand
     * @param opcode instruction opcode
     * @param istr instruction string, forwarded to the superclass constructor
     * @param intermediateMemoryBudget value stored as the intermediate memory budget
     * @throws DMLRuntimeException declared for callers; not thrown directly by this constructor body
     */
    public DnnGPUInstruction(CPOperand in1, CPOperand in2, CPOperand in3, CPOperand in4, CPOperand in5, CPOperand in6, CPOperand out, CPOperand out2, String opcode, String istr, double intermediateMemoryBudget) throws DMLRuntimeException {
        super(new ReorgOperator(SwapIndex.getSwapIndexFnObject()), opcode, istr);
        _gputype = GPUInstruction.GPUINSTRUCTION_TYPE.Dnn;
        _intermediateMemoryBudget = intermediateMemoryBudget;
        // inputs
        _input1 = in1;
        _input2 = in2;
        _input3 = in3;
        _input4 = in4;
        _input5 = in5;
        _input6 = in6;
        // outputs
        _output = out;
        _output2 = out2;
    }

    /**
     * Creates a DNN GPU instruction with up to eight inputs and up to five outputs.
     *
     * <p>No opcode validation is performed here. {@code parseInstruction} uses
     * this constructor for {@code batch_norm2d}, {@code lstm_backward},
     * {@code batch_norm2d_backward} and {@code batch_norm2d_train}, passing
     * {@code null} for operands an opcode does not use.
     *
     * @param in1 first input operand
     * @param in2 second input operand
     * @param in3 third input operand
     * @param in4 fourth input operand
     * @param in5 fifth input operand
     * @param in6 sixth input operand
     * @param in7 seventh input operand (may be {@code null})
     * @param in8 eighth input operand (may be {@code null})
     * @param out first output operand
     * @param out2 second output operand
     * @param out3 third output operand
     * @param out4 fourth output operand (may be {@code null})
     * @param out5 fifth output operand (may be {@code null})
     * @param opcode instruction opcode
     * @param istr instruction string, forwarded to the superclass constructor
     * @param intermediateMemoryBudget value stored as the intermediate memory budget
     * @throws DMLRuntimeException declared for callers; not thrown directly by this constructor body
     */
    public DnnGPUInstruction(CPOperand in1, CPOperand in2, CPOperand in3, CPOperand in4, CPOperand in5, CPOperand in6, CPOperand in7, CPOperand in8, CPOperand out, CPOperand out2, CPOperand out3, CPOperand out4, CPOperand out5, String opcode, String istr, double intermediateMemoryBudget) throws DMLRuntimeException {
        super(new ReorgOperator(SwapIndex.getSwapIndexFnObject()), opcode, istr);
        _gputype = GPUInstruction.GPUINSTRUCTION_TYPE.Dnn;
        _intermediateMemoryBudget = intermediateMemoryBudget;
        // inputs
        _input1 = in1;
        _input2 = in2;
        _input3 = in3;
        _input4 = in4;
        _input5 = in5;
        _input6 = in6;
        _input7 = in7;
        _input8 = in8;
        // outputs
        _output = out;
        _output2 = out2;
        _output3 = out3;
        _output4 = out4;
        _output5 = out5;
    }

    public DnnGPUInstruction(CPOperand in1, CPOperand in2, CPOperand in3, CPOperand out, String opcode, String istr, double intermediateMemoryBudget) throws DMLRuntimeException {
        super(new ReorgOperator(SwapIndex.getSwapIndexFnObject()), opcode, istr);
        if (!opcode.equals("channel_sums")) {
            throw new DMLRuntimeException("Incorrect usage. Expected the opcode to be channel_sums, but found " + opcode);
        }
        this._input1 = in1;
        this._input2 = in2;
        this._input3 = in3;
        this._gputype = GPUInstruction.GPUINSTRUCTION_TYPE.Dnn;
        this._output = out;
        this._intermediateMemoryBudget = intermediateMemoryBudget;
    }

    public DnnGPUInstruction(CPOperand in1, CPOperand in2, CPOperand in3, CPOperand in4, CPOperand out, String opcode, String istr, double intermediateMemoryBudget) throws DMLRuntimeException {
        super(new ReorgOperator(SwapIndex.getSwapIndexFnObject()), opcode, istr);
        if (!opcode.equals("update_nesterov_x")) {
            throw new DMLRuntimeException("Incorrect opcode: " + opcode);
        }
        this._input1 = in1;
        this._input2 = in2;
        this._input3 = in3;
        this._input4 = in4;
        this._gputype = GPUInstruction.GPUINSTRUCTION_TYPE.Dnn;
        this._output = out;
        this._intermediateMemoryBudget = intermediateMemoryBudget;
    }

    /**
     * Creates a three-input instruction that also carries stride, padding,
     * input-shape and filter-shape operand lists.
     *
     * <p>Delegates to the two-input variant and then records the third input.
     *
     * @param in1 first input operand
     * @param in2 second input operand
     * @param in3 third input operand (may be {@code null}; {@code parseInstruction} passes
     *            {@code null} for the pooling-backward form without a third operand)
     * @param out output operand
     * @param opcode instruction opcode
     * @param istr instruction string
     * @param stride stride operands
     * @param padding padding operands
     * @param input_shape input shape operands
     * @param filter_shape filter shape operands
     * @param intermediateMemoryBudget value stored as the intermediate memory budget
     */
    public DnnGPUInstruction(CPOperand in1, CPOperand in2, CPOperand in3, CPOperand out, String opcode, String istr, ArrayList<CPOperand> stride, ArrayList<CPOperand> padding, ArrayList<CPOperand> input_shape, ArrayList<CPOperand> filter_shape, double intermediateMemoryBudget) {
        this(in1, in2, out, opcode, istr, stride, padding, input_shape, filter_shape, intermediateMemoryBudget);
        // The delegate constructor does not know about a third input.
        _input3 = in3;
    }

    /**
     * Creates a two-input instruction that also carries stride, padding,
     * input-shape and filter-shape operand lists.
     *
     * <p>No opcode validation is performed here. The supplied lists are stored
     * by reference (no defensive copy).
     *
     * @param in1 first input operand
     * @param in2 second input operand (may be {@code null}; {@code parseInstruction} passes
     *            {@code null} for the pooling opcodes)
     * @param out output operand
     * @param opcode instruction opcode
     * @param istr instruction string, forwarded to the superclass constructor
     * @param stride stride operands
     * @param padding padding operands
     * @param input_shape input shape operands
     * @param filter_shape filter shape operands
     * @param intermediateMemoryBudget value stored as the intermediate memory budget
     */
    public DnnGPUInstruction(CPOperand in1, CPOperand in2, CPOperand out, String opcode, String istr, ArrayList<CPOperand> stride, ArrayList<CPOperand> padding, ArrayList<CPOperand> input_shape, ArrayList<CPOperand> filter_shape, double intermediateMemoryBudget) {
        super(new ReorgOperator(SwapIndex.getSwapIndexFnObject()), opcode, istr);
        _gputype = GPUInstruction.GPUINSTRUCTION_TYPE.Dnn;
        // operands
        _input1 = in1;
        _input2 = in2;
        _output = out;
        // geometry
        _input_shape = input_shape;
        _filter_shape = filter_shape;
        _stride = stride;
        _padding = padding;
        _intermediateMemoryBudget = intermediateMemoryBudget;
    }

    public DnnGPUInstruction(CPOperand in, CPOperand in2, CPOperand in3, CPOperand in4, CPOperand in5, CPOperand in6, CPOperand out, String opcode, String istr, double intermediateMemoryBudget) {
        super(new ReorgOperator(SwapIndex.getSwapIndexFnObject()), opcode, istr);
        if (!opcode.equals("batch_norm2d_test")) {
            throw new DMLRuntimeException("Incorrect usage. Expected the opcode to be batch_norm2d_test, but found " + opcode);
        }
        this._input1 = in;
        this._input2 = in2;
        this._input3 = in3;
        this._input4 = in4;
        this._input5 = in5;
        this._input6 = in6;
        this._gputype = GPUInstruction.GPUINSTRUCTION_TYPE.Dnn;
        this._output = out;
        this._intermediateMemoryBudget = intermediateMemoryBudget;
    }

    /**
     * Parses an instruction string into a {@link DnnGPUInstruction}.
     *
     * <p>The string is split by
     * {@code InstructionUtils.getInstructionPartsWithValueType}; {@code parts[0]}
     * is the opcode (matched case-insensitively) and the remaining parts are
     * interpreted per opcode:
     * <ul>
     *   <li>{@code conv2d}, {@code conv2d_backward_filter}, {@code conv2d_backward_data}:
     *       2 inputs, stride(2), padding(2), input shape(4), filter shape(4), output, memory budget</li>
     *   <li>{@code maxpooling_backward}, {@code avgpooling_backward}: same layout as conv2d, optionally
     *       with one extra input before the output (when {@code parts.length == 18})</li>
     *   <li>{@code conv2d_bias_add}: 3 inputs, stride, padding, input shape, filter shape, output, memory budget</li>
     *   <li>{@code maxpooling}, {@code avgpooling}: 1 input, stride, padding, input shape, filter shape,
     *       output, memory budget</li>
     *   <li>{@code bias_add}, {@code relu_backward}, {@code bias_multiply}: 2 inputs, output, memory budget</li>
     *   <li>{@code channel_sums}: 3 inputs, output</li>
     *   <li>{@code update_nesterov_x}: 4 inputs, output</li>
     *   <li>{@code lstm}: 6 inputs, 2 outputs</li>
     *   <li>{@code batch_norm2d}, {@code lstm_backward}: 8 inputs, 5 outputs</li>
     *   <li>{@code batch_norm2d_backward}: 6 inputs, 3 outputs</li>
     *   <li>{@code batch_norm2d_test}: 6 inputs, 1 output</li>
     *   <li>{@code batch_norm2d_train}: 7 inputs, 5 outputs</li>
     * </ul>
     * Opcodes without a memory-budget field get a budget of {@code 0.0}.
     *
     * @param str instruction string
     * @return the parsed instruction
     * @throws DMLRuntimeException if the opcode is not one of the above (field-count
     *         violations are reported by {@code InstructionUtils.checkNumFields})
     */
    public static DnnGPUInstruction parseInstruction(String str) {
        String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
        String opcode = parts[0];
        if (opcode.equalsIgnoreCase("conv2d") || opcode.equalsIgnoreCase("conv2d_backward_filter") || opcode.equalsIgnoreCase("conv2d_backward_data")) {
            InstructionUtils.checkNumFields(parts, 16);
            CPOperand in1 = new CPOperand(parts[1]);
            CPOperand in2 = new CPOperand(parts[2]);
            CPOperand out = new CPOperand(parts[15]);
            ArrayList<CPOperand> stride = parseOperandSlice(parts, 3, 2);
            ArrayList<CPOperand> padding = parseOperandSlice(parts, 5, 2);
            ArrayList<CPOperand> input_shape = parseOperandSlice(parts, 7, 4);
            ArrayList<CPOperand> filter_shape = parseOperandSlice(parts, 11, 4);
            return new DnnGPUInstruction(in1, in2, out, opcode, str, stride, padding, input_shape, filter_shape, Double.parseDouble(parts[16]));
        }
        if (opcode.equalsIgnoreCase("maxpooling_backward") || opcode.equalsIgnoreCase("avgpooling_backward")) {
            // An 18-part instruction carries an additional third input right before the output.
            boolean withMaxPoolOut = parts.length == 18;
            if (!withMaxPoolOut) {
                InstructionUtils.checkNumFields(parts, 16);
            }
            CPOperand in1 = new CPOperand(parts[1]);
            CPOperand in2 = new CPOperand(parts[2]);
            CPOperand in3 = withMaxPoolOut ? new CPOperand(parts[15]) : null;
            CPOperand out = new CPOperand(parts[withMaxPoolOut ? 16 : 15]);
            double memBudget = Double.parseDouble(parts[withMaxPoolOut ? 17 : 16]);
            ArrayList<CPOperand> stride = parseOperandSlice(parts, 3, 2);
            ArrayList<CPOperand> padding = parseOperandSlice(parts, 5, 2);
            ArrayList<CPOperand> input_shape = parseOperandSlice(parts, 7, 4);
            ArrayList<CPOperand> filter_shape = parseOperandSlice(parts, 11, 4);
            return new DnnGPUInstruction(in1, in2, in3, out, opcode, str, stride, padding, input_shape, filter_shape, memBudget);
        }
        if (opcode.equalsIgnoreCase("conv2d_bias_add")) {
            InstructionUtils.checkNumFields(parts, 17);
            CPOperand in1 = new CPOperand(parts[1]);
            CPOperand in2 = new CPOperand(parts[2]);
            CPOperand in3 = new CPOperand(parts[3]);
            CPOperand out = new CPOperand(parts[16]);
            ArrayList<CPOperand> stride = parseOperandSlice(parts, 4, 2);
            ArrayList<CPOperand> padding = parseOperandSlice(parts, 6, 2);
            ArrayList<CPOperand> input_shape = parseOperandSlice(parts, 8, 4);
            ArrayList<CPOperand> filter_shape = parseOperandSlice(parts, 12, 4);
            return new DnnGPUInstruction(in1, in2, in3, out, opcode, str, stride, padding, input_shape, filter_shape, Double.parseDouble(parts[17]));
        }
        if (opcode.equalsIgnoreCase("maxpooling") || opcode.equalsIgnoreCase("avgpooling")) {
            InstructionUtils.checkNumFields(parts, 15);
            CPOperand in1 = new CPOperand(parts[1]);
            CPOperand out = new CPOperand(parts[14]);
            ArrayList<CPOperand> stride = parseOperandSlice(parts, 2, 2);
            ArrayList<CPOperand> padding = parseOperandSlice(parts, 4, 2);
            ArrayList<CPOperand> input_shape = parseOperandSlice(parts, 6, 4);
            ArrayList<CPOperand> filter_shape = parseOperandSlice(parts, 10, 4);
            // Pooling has no second input operand.
            return new DnnGPUInstruction(in1, null, out, opcode, str, stride, padding, input_shape, filter_shape, Double.parseDouble(parts[15]));
        }
        if (opcode.equalsIgnoreCase("bias_add") || opcode.equalsIgnoreCase("relu_backward") || opcode.equalsIgnoreCase("bias_multiply")) {
            InstructionUtils.checkNumFields(parts, 4);
            CPOperand in1 = new CPOperand(parts[1]);
            CPOperand in2 = new CPOperand(parts[2]);
            CPOperand out = new CPOperand(parts[3]);
            return new DnnGPUInstruction(in1, in2, out, opcode, str, Double.parseDouble(parts[4]));
        }
        if (opcode.equalsIgnoreCase("channel_sums")) {
            InstructionUtils.checkNumFields(parts, 4);
            CPOperand in = new CPOperand(parts[1]);
            CPOperand in2 = new CPOperand(parts[2]);
            CPOperand in3 = new CPOperand(parts[3]);
            CPOperand out = new CPOperand(parts[4]);
            return new DnnGPUInstruction(in, in2, in3, out, opcode, str, 0.0);
        }
        if (opcode.equalsIgnoreCase("update_nesterov_x")) {
            InstructionUtils.checkNumFields(parts, 5);
            CPOperand in = new CPOperand(parts[1]);
            CPOperand in2 = new CPOperand(parts[2]);
            CPOperand in3 = new CPOperand(parts[3]);
            CPOperand in4 = new CPOperand(parts[4]);
            CPOperand out = new CPOperand(parts[5]);
            return new DnnGPUInstruction(in, in2, in3, in4, out, opcode, str, 0.0);
        }
        if (opcode.equalsIgnoreCase("lstm")) {
            InstructionUtils.checkNumFields(parts, 8);
            CPOperand in1 = new CPOperand(parts[1]);
            CPOperand in2 = new CPOperand(parts[2]);
            CPOperand in3 = new CPOperand(parts[3]);
            CPOperand in4 = new CPOperand(parts[4]);
            CPOperand in5 = new CPOperand(parts[5]);
            CPOperand in6 = new CPOperand(parts[6]);
            CPOperand out = new CPOperand(parts[7]);
            CPOperand out2 = new CPOperand(parts[8]);
            return new DnnGPUInstruction(in1, in2, in3, in4, in5, in6, out, out2, opcode, str, 0.0);
        }
        if (opcode.equalsIgnoreCase("batch_norm2d") || opcode.equalsIgnoreCase("lstm_backward")) {
            InstructionUtils.checkNumFields(parts, 13);
            CPOperand in1 = new CPOperand(parts[1]);
            CPOperand in2 = new CPOperand(parts[2]);
            CPOperand in3 = new CPOperand(parts[3]);
            CPOperand in4 = new CPOperand(parts[4]);
            CPOperand in5 = new CPOperand(parts[5]);
            CPOperand in6 = new CPOperand(parts[6]);
            CPOperand in7 = new CPOperand(parts[7]);
            CPOperand in8 = new CPOperand(parts[8]);
            CPOperand out = new CPOperand(parts[9]);
            CPOperand out2 = new CPOperand(parts[10]);
            CPOperand out3 = new CPOperand(parts[11]);
            CPOperand out4 = new CPOperand(parts[12]);
            CPOperand out5 = new CPOperand(parts[13]);
            return new DnnGPUInstruction(in1, in2, in3, in4, in5, in6, in7, in8, out, out2, out3, out4, out5, opcode, str, 0.0);
        }
        if (opcode.equalsIgnoreCase("batch_norm2d_backward")) {
            InstructionUtils.checkNumFields(parts, 9);
            CPOperand in1 = new CPOperand(parts[1]);
            CPOperand in2 = new CPOperand(parts[2]);
            CPOperand in3 = new CPOperand(parts[3]);
            CPOperand in4 = new CPOperand(parts[4]);
            CPOperand in5 = new CPOperand(parts[5]);
            CPOperand in6 = new CPOperand(parts[6]);
            CPOperand out = new CPOperand(parts[7]);
            CPOperand out2 = new CPOperand(parts[8]);
            CPOperand out3 = new CPOperand(parts[9]);
            // Inputs 7-8 and outputs 4-5 are unused by this opcode.
            return new DnnGPUInstruction(in1, in2, in3, in4, in5, in6, null, null, out, out2, out3, null, null, opcode, str, 0.0);
        }
        if (opcode.equalsIgnoreCase("batch_norm2d_test")) {
            InstructionUtils.checkNumFields(parts, 7);
            CPOperand in = new CPOperand(parts[1]);
            CPOperand in2 = new CPOperand(parts[2]);
            CPOperand in3 = new CPOperand(parts[3]);
            CPOperand in4 = new CPOperand(parts[4]);
            CPOperand in5 = new CPOperand(parts[5]);
            CPOperand in6 = new CPOperand(parts[6]);
            CPOperand out = new CPOperand(parts[7]);
            return new DnnGPUInstruction(in, in2, in3, in4, in5, in6, out, opcode, str, 0.0);
        }
        if (opcode.equalsIgnoreCase("batch_norm2d_train")) {
            InstructionUtils.checkNumFields(parts, 12);
            CPOperand in1 = new CPOperand(parts[1]);
            CPOperand in2 = new CPOperand(parts[2]);
            CPOperand in3 = new CPOperand(parts[3]);
            CPOperand in4 = new CPOperand(parts[4]);
            CPOperand in5 = new CPOperand(parts[5]);
            CPOperand in6 = new CPOperand(parts[6]);
            CPOperand in7 = new CPOperand(parts[7]);
            CPOperand out = new CPOperand(parts[8]);
            CPOperand out2 = new CPOperand(parts[9]);
            CPOperand out3 = new CPOperand(parts[10]);
            CPOperand out4 = new CPOperand(parts[11]);
            CPOperand out5 = new CPOperand(parts[12]);
            // Input 8 is unused by this opcode.
            return new DnnGPUInstruction(in1, in2, in3, in4, in5, in6, in7, null, out, out2, out3, out4, out5, opcode, str, 0.0);
        }
        throw new DMLRuntimeException("Unknown opcode while parsing a DnnGPUInstruction: " + str);
    }

    /**
     * Builds operands from {@code count} consecutive instruction parts, in order.
     *
     * @param parts instruction parts
     * @param start index of the first part to convert
     * @param count number of consecutive parts to convert
     * @return a new list holding one {@link CPOperand} per selected part
     */
    private static ArrayList<CPOperand> parseOperandSlice(String[] parts, int start, int count) {
        ArrayList<CPOperand> operands = new ArrayList<CPOperand>(count);
        for (int i = start; i < start + count; i++) {
            operands.add(new CPOperand(parts[i]));
        }
        return operands;
    }

    /**
     * Executes a {@code bias_add} or {@code bias_multiply} instruction on GPU context 0.
     *
     * <p>Reads the two matrix inputs, allocates a dense output with the dimensions
     * of the first input, runs the matching {@code LibMatrixCUDA} routine and then
     * releases inputs and output.
     *
     * @param instOpcode opcode to execute (matched case-insensitively)
     * @param ec execution context providing the variables and the GPU context
     * @throws DMLRuntimeException if {@code instOpcode} is neither {@code bias_add}
     *         nor {@code bias_multiply}
     */
    private void processBiasInstruction(String instOpcode, ExecutionContext ec) {
        boolean isAdd = instOpcode.equalsIgnoreCase("bias_add");
        boolean isMultiply = instOpcode.equalsIgnoreCase("bias_multiply");
        // Reject unsupported opcodes before acquiring anything: previously such a
        // call allocated and released an output that no kernel had ever written.
        if (!isAdd && !isMultiply) {
            throw new DMLRuntimeException("Incorrect usage. Expected the opcode to be bias_add or bias_multiply, but found " + instOpcode);
        }
        GPUStatistics.incrementNoOfExecutedGPUInst();
        MatrixObject input = this.getMatrixInputForGPUInstruction(ec, this._input1.getName());
        MatrixObject bias = this.getMatrixInputForGPUInstruction(ec, this._input2.getName());
        MatrixObject out = this.getDenseMatrixOutputForGPUInstruction(ec, this._output.getName(), input.getNumRows(), input.getNumColumns());
        if (isAdd) {
            LibMatrixCUDA.biasAdd(ec.getGPUContext(0), this.getExtendedOpcode(), input, bias, out);
        } else {
            LibMatrixCUDA.biasMultiply(ec.getGPUContext(0), this.getExtendedOpcode(), input, bias, out);
        }
        ec.releaseMatrixInputForGPUInstruction(this._input1.getName());
        ec.releaseMatrixInputForGPUInstruction(this._input2.getName());
        ec.releaseMatrixOutputForGPUInstruction(this._output.getName());
    }

    /**
     * Executes the {@code batch_norm2d} forward pass on GPU context 0.
     *
     * <p>Inputs 1-5 are matrices (image, scale, bias, running mean, running
     * variance); input 6 is the mode string and input 7 is epsilon. In
     * {@code train} mode, input 8 is read as well and {@code 1 - input8} is
     * passed as the exponential average factor; outputs 2-5 are produced on the
     * GPU. In {@code test} mode, outputs 2-5 are set to matrix blocks sized like
     * the running mean / running variance.
     *
     * @param ec execution context providing the variables and the GPU context
     * @throws DMLRuntimeException if the mode is neither {@code train} nor {@code test}
     */
    private void processBatchNorm2dInstruction(ExecutionContext ec) throws DMLRuntimeException {
        GPUStatistics.incrementNoOfExecutedGPUInst();
        MatrixObject img = getMatrixInputForGPUInstruction(ec, _input1.getName());
        MatrixObject scaleMat = getMatrixInputForGPUInstruction(ec, _input2.getName());
        MatrixObject biasMat = getMatrixInputForGPUInstruction(ec, _input3.getName());
        MatrixObject meanIn = getMatrixInputForGPUInstruction(ec, _input4.getName());
        MatrixObject varIn = getMatrixInputForGPUInstruction(ec, _input5.getName());
        String mode = ec.getScalarInput(_input6).getStringValue();
        double eps = ec.getScalarInput(_input7).getDoubleValue();
        MatrixObject result = getDenseMatrixOutputForGPUInstruction(ec, _output.getName(), img.getNumRows(), img.getNumColumns());

        boolean training = mode.equalsIgnoreCase("train");
        if (!training && !mode.equalsIgnoreCase("test")) {
            throw new DMLRuntimeException("Incorrect mode: Expected either train or test, but found " + mode);
        }

        if (training) {
            double avgFactor = 1.0 - ec.getScalarInput(_input8.getName(), _input8.getValueType(), _input8.isLiteral()).getDoubleValue();
            MatrixObject meanOut = getDenseMatrixOutputForGPUInstruction(ec, _output2.getName(), meanIn.getNumRows(), meanIn.getNumColumns());
            MatrixObject varOut = getDenseMatrixOutputForGPUInstruction(ec, _output3.getName(), varIn.getNumRows(), varIn.getNumColumns());
            MatrixObject savedMean = getDenseMatrixOutputForGPUInstruction(ec, _output4.getName(), meanIn.getNumRows(), meanIn.getNumColumns());
            MatrixObject savedInvVar = getDenseMatrixOutputForGPUInstruction(ec, _output5.getName(), varIn.getNumRows(), varIn.getNumColumns());
            LibMatrixCuDNN.batchNormalizationForwardTraining(ec.getGPUContext(0), getExtendedOpcode(), img, scaleMat, biasMat, meanIn, varIn, result, meanOut, varOut, eps, avgFactor, savedMean, savedInvVar);
            for (CPOperand extraOut : new CPOperand[] {_output2, _output3, _output4, _output5}) {
                ec.releaseMatrixOutputForGPUInstruction(extraOut.getName());
            }
        } else {
            LibMatrixCuDNN.batchNormalizationForwardInference(ec.getGPUContext(0), getExtendedOpcode(), img, scaleMat, biasMat, meanIn, varIn, result, eps);
            // The extra outputs still have to be bound to something in test mode.
            ec.setMatrixOutput(_output2.getName(), new MatrixBlock((int)meanIn.getNumRows(), (int)meanIn.getNumColumns(), true));
            ec.setMatrixOutput(_output3.getName(), new MatrixBlock((int)varIn.getNumRows(), (int)varIn.getNumColumns(), true));
            ec.setMatrixOutput(_output4.getName(), new MatrixBlock((int)meanIn.getNumRows(), (int)meanIn.getNumColumns(), true));
            ec.setMatrixOutput(_output5.getName(), new MatrixBlock((int)varIn.getNumRows(), (int)varIn.getNumColumns(), true));
        }

        for (CPOperand matrixIn : new CPOperand[] {_input1, _input2, _input3, _input4, _input5}) {
            ec.releaseMatrixInputForGPUInstruction(matrixIn.getName());
        }
        ec.releaseMatrixOutputForGPUInstruction(_output.getName());
    }

    /**
     * Executes the {@code batch_norm2d_train} forward pass on GPU context 0.
     *
     * <p>Inputs 1-5 are matrices (image, scale, bias, running mean, running
     * variance); input 6 is epsilon and {@code 1 - input7} is passed as the
     * exponential average factor. All five outputs are allocated as dense GPU
     * matrices: output 1 has the image dimensions, outputs 2 and 4 have the
     * running-mean dimensions, outputs 3 and 5 the running-variance dimensions.
     *
     * @param ec execution context providing the variables and the GPU context
     * @throws DMLRuntimeException declared for callers; propagated from the called routines
     */
    private void processBatchNorm2dTrainInstruction(ExecutionContext ec) throws DMLRuntimeException {
        GPUStatistics.incrementNoOfExecutedGPUInst();
        MatrixObject img = getMatrixInputForGPUInstruction(ec, _input1.getName());
        MatrixObject scaleMat = getMatrixInputForGPUInstruction(ec, _input2.getName());
        MatrixObject biasMat = getMatrixInputForGPUInstruction(ec, _input3.getName());
        MatrixObject meanIn = getMatrixInputForGPUInstruction(ec, _input4.getName());
        MatrixObject varIn = getMatrixInputForGPUInstruction(ec, _input5.getName());

        double eps = ec.getScalarInput(_input6.getName(), _input6.getValueType(), _input6.isLiteral()).getDoubleValue();
        double avgFactor = 1.0 - ec.getScalarInput(_input7.getName(), _input7.getValueType(), _input7.isLiteral()).getDoubleValue();

        MatrixObject result = getDenseMatrixOutputForGPUInstruction(ec, _output.getName(), img.getNumRows(), img.getNumColumns());
        MatrixObject meanOut = getDenseMatrixOutputForGPUInstruction(ec, _output2.getName(), meanIn.getNumRows(), meanIn.getNumColumns());
        MatrixObject varOut = getDenseMatrixOutputForGPUInstruction(ec, _output3.getName(), varIn.getNumRows(), varIn.getNumColumns());
        MatrixObject savedMean = getDenseMatrixOutputForGPUInstruction(ec, _output4.getName(), meanIn.getNumRows(), meanIn.getNumColumns());
        MatrixObject savedInvVar = getDenseMatrixOutputForGPUInstruction(ec, _output5.getName(), varIn.getNumRows(), varIn.getNumColumns());

        LibMatrixCuDNN.batchNormalizationForwardTraining(ec.getGPUContext(0), getExtendedOpcode(), img, scaleMat, biasMat, meanIn, varIn, result, meanOut, varOut, eps, avgFactor, savedMean, savedInvVar);

        for (CPOperand matrixIn : new CPOperand[] {_input1, _input2, _input3, _input4, _input5}) {
            ec.releaseMatrixInputForGPUInstruction(matrixIn.getName());
        }
        for (CPOperand matrixOut : new CPOperand[] {_output, _output2, _output3, _output4, _output5}) {
            ec.releaseMatrixOutputForGPUInstruction(matrixOut.getName());
        }
    }

    /**
     * Executes the {@code batch_norm2d_test} (inference) forward pass on GPU context 0.
     *
     * <p>Inputs 1-5 are matrices (image, scale, bias, running mean, running
     * variance) and input 6 is epsilon. The single output is a dense matrix with
     * the image dimensions.
     *
     * @param ec execution context providing the variables and the GPU context
     * @throws DMLRuntimeException declared for callers; propagated from the called routines
     */
    private void processBatchNorm2dTestInstruction(ExecutionContext ec) throws DMLRuntimeException {
        GPUStatistics.incrementNoOfExecutedGPUInst();
        MatrixObject img = getMatrixInputForGPUInstruction(ec, _input1.getName());
        MatrixObject scaleMat = getMatrixInputForGPUInstruction(ec, _input2.getName());
        MatrixObject biasMat = getMatrixInputForGPUInstruction(ec, _input3.getName());
        MatrixObject meanIn = getMatrixInputForGPUInstruction(ec, _input4.getName());
        MatrixObject varIn = getMatrixInputForGPUInstruction(ec, _input5.getName());
        double eps = ec.getScalarInput(_input6.getName(), _input6.getValueType(), _input6.isLiteral()).getDoubleValue();
        MatrixObject result = getDenseMatrixOutputForGPUInstruction(ec, _output.getName(), img.getNumRows(), img.getNumColumns());

        LibMatrixCuDNN.batchNormalizationForwardInference(ec.getGPUContext(0), getExtendedOpcode(), img, scaleMat, biasMat, meanIn, varIn, result, eps);

        for (CPOperand matrixIn : new CPOperand[] {_input1, _input2, _input3, _input4, _input5}) {
            ec.releaseMatrixInputForGPUInstruction(matrixIn.getName());
        }
        ec.releaseMatrixOutputForGPUInstruction(_output.getName());
    }

    /**
     * Executes the {@code batch_norm2d_backward} pass on GPU context 0.
     *
     * <p>Inputs 1, 2, 3, 5 and 6 are matrices (image, dout, scale, saved mean,
     * saved inverse variance); input 4 is the scalar epsilon. Output 1 is
     * allocated with the image dimensions, outputs 2 and 3 with the scale
     * dimensions.
     *
     * @param ec execution context providing the variables and the GPU context
     * @throws DMLRuntimeException declared for callers; propagated from the called routines
     */
    public void processBatchNorm2dBackwardInstruction(ExecutionContext ec) throws DMLRuntimeException {
        GPUStatistics.incrementNoOfExecutedGPUInst();
        MatrixObject img = getMatrixInputForGPUInstruction(ec, _input1.getName());
        MatrixObject upstreamGrad = getMatrixInputForGPUInstruction(ec, _input2.getName());
        MatrixObject scaleMat = getMatrixInputForGPUInstruction(ec, _input3.getName());
        double eps = ec.getScalarInput(_input4).getDoubleValue();
        MatrixObject savedMean = getMatrixInputForGPUInstruction(ec, _input5.getName());
        MatrixObject savedInvVar = getMatrixInputForGPUInstruction(ec, _input6.getName());

        MatrixObject gradX = getDenseMatrixOutputForGPUInstruction(ec, _output.getName(), img.getNumRows(), img.getNumColumns());
        MatrixObject gradScale = getDenseMatrixOutputForGPUInstruction(ec, _output2.getName(), scaleMat.getNumRows(), scaleMat.getNumColumns());
        MatrixObject gradBias = getDenseMatrixOutputForGPUInstruction(ec, _output3.getName(), scaleMat.getNumRows(), scaleMat.getNumColumns());

        LibMatrixCuDNN.batchNormalizationBackward(ec.getGPUContext(0), getExtendedOpcode(), img, upstreamGrad, scaleMat, gradX, gradScale, gradBias, eps, savedMean, savedInvVar);

        // Input 4 is a scalar and therefore has nothing to release.
        for (CPOperand matrixIn : new CPOperand[] {_input1, _input2, _input3, _input5, _input6}) {
            ec.releaseMatrixInputForGPUInstruction(matrixIn.getName());
        }
        for (CPOperand matrixOut : new CPOperand[] {_output, _output2, _output3}) {
            ec.releaseMatrixOutputForGPUInstruction(matrixOut.getName());
        }
    }

    /**
     * Executes the {@code relu_backward} instruction on GPU context 0.
     *
     * <p>Reads the two matrix inputs, allocates a dense output with the
     * dimensions of the first input, calls {@code LibMatrixCUDA.reluBackward}
     * and releases inputs and output.
     *
     * @param ec execution context providing the variables and the GPU context
     */
    public void processReLUBackwardInstruction(ExecutionContext ec) {
        GPUStatistics.incrementNoOfExecutedGPUInst();
        MatrixObject activationIn = getMatrixInputForGPUInstruction(ec, _input1.getName());
        MatrixObject upstreamGrad = getMatrixInputForGPUInstruction(ec, _input2.getName());
        MatrixObject result = getDenseMatrixOutputForGPUInstruction(ec, _output.getName(), activationIn.getNumRows(), activationIn.getNumColumns());

        LibMatrixCUDA.reluBackward(ec.getGPUContext(0), getExtendedOpcode(), activationIn, upstreamGrad, result);

        for (CPOperand matrixIn : new CPOperand[] {_input1, _input2}) {
            ec.releaseMatrixInputForGPUInstruction(matrixIn.getName());
        }
        ec.releaseMatrixOutputForGPUInstruction(_output.getName());
    }

    /**
     * Executes the {@code channel_sums} instruction on GPU context 0.
     *
     * <p>Input 1 is the matrix; inputs 2 and 3 are scalars read as {@code C}
     * and {@code HW}. The input must have exactly {@code C * HW} columns. The
     * output is a dense {@code C x 1} matrix.
     *
     * @param ec execution context providing the variables and the GPU context
     * @throws DMLRuntimeException if {@code C * HW} differs from the number of input columns
     */
    private void processChannelSumsInstruction(ExecutionContext ec) {
        GPUStatistics.incrementNoOfExecutedGPUInst();
        MatrixObject input = this.getMatrixInputForGPUInstruction(ec, this._input1.getName());
        int C = (int)ec.getScalarInput(this._input2.getName(), this._input2.getValueType(), this._input2.isLiteral()).getLongValue();
        int HW = (int)ec.getScalarInput(this._input3.getName(), this._input3.getValueType(), this._input3.isLiteral()).getLongValue();
        // Widen before multiplying: an int product could wrap around and
        // accidentally match (or mismatch) the column count.
        if ((long)C * HW != input.getNumColumns()) {
            throw new DMLRuntimeException("Expected rows*cols" + C + "*" + HW + " to be equal to number of columns of input " + input.getNumColumns());
        }
        MatrixObject outputBlock = this.getDenseMatrixOutputForGPUInstruction(ec, this._output.getName(), C, 1L);
        LibMatrixCUDA.channelSums(ec.getGPUContext(0), this.getExtendedOpcode(), input, outputBlock, C, HW);
        ec.releaseMatrixInputForGPUInstruction(this._input1.getName());
        ec.releaseMatrixOutputForGPUInstruction(this._output.getName());
    }

    /**
     * Executes the {@code update_nesterov_x} instruction on GPU context 0.
     *
     * <p>Inputs 1-3 are matrices (input, v, v_prev) and input 4 is the scalar
     * {@code mu}. The output is a dense matrix with the dimensions of input 1.
     * The {@code update_nesterov_x} kernel is launched over all
     * {@code rows * cols} cells.
     *
     * @param ec execution context providing the variables and the GPU context
     * @throws DMLRuntimeException if the dimensions or the cell count are rejected
     *         by {@code LibMatrixCUDA.toInt}
     */
    private void processNesterovUpdateInstruction(ExecutionContext ec) {
        GPUStatistics.incrementNoOfExecutedGPUInst();
        MatrixObject input = this.getMatrixInputForGPUInstruction(ec, this._input1.getName());
        MatrixObject v = this.getMatrixInputForGPUInstruction(ec, this._input2.getName());
        MatrixObject v_prev = this.getMatrixInputForGPUInstruction(ec, this._input3.getName());
        // Keep mu as a double: the previous (int) cast truncated fractional
        // values such as 0.9 to 0 before they reached the kernel.
        double mu = ec.getScalarInput(this._input4).getDoubleValue();
        int rows = LibMatrixCUDA.toInt(input.getNumRows());
        int cols = LibMatrixCUDA.toInt(input.getNumColumns());
        // Multiply in long so that an oversized matrix is reported by toInt
        // instead of silently wrapping around in int arithmetic.
        int numCells = LibMatrixCUDA.toInt((long)rows * cols);
        MatrixObject out = this.getDenseMatrixOutputForGPUInstruction(ec, this._output.getName(), rows, cols);
        GPUContext gCtx = ec.getGPUContext(0);
        String instName = this.getExtendedOpcode();
        LibMatrixCUDA.getCudaKernels(gCtx).launchKernel("update_nesterov_x",
            ExecutionConfig.getConfigForSimpleVectorOperations(numCells),
            LibMatrixCUDA.getDensePointer(gCtx, input, instName),
            LibMatrixCUDA.getDensePointer(gCtx, v, instName),
            LibMatrixCUDA.getDensePointer(gCtx, v_prev, instName),
            mu,
            LibMatrixCUDA.getDensePointer(gCtx, out, instName),
            numCells);
        ec.releaseMatrixInputForGPUInstruction(this._input1.getName());
        ec.releaseMatrixInputForGPUInstruction(this._input2.getName());
        ec.releaseMatrixInputForGPUInstruction(this._input3.getName());
        ec.releaseMatrixOutputForGPUInstruction(this._output.getName());
    }

    /**
     * Narrows a {@code long} to an {@code int}, rejecting values outside the
     * open interval ({@code Integer.MIN_VALUE}, {@code Integer.MAX_VALUE}).
     * Note that both bounds themselves are rejected as well.
     *
     * @param num value to narrow
     * @return {@code num} as an {@code int}
     * @throws DMLRuntimeException if {@code num} is not strictly between
     *         {@code Integer.MIN_VALUE} and {@code Integer.MAX_VALUE}
     */
    private static int toInt(long num) throws DMLRuntimeException {
        boolean withinRange = num > Integer.MIN_VALUE && num < Integer.MAX_VALUE;
        if (withinRange) {
            return (int)num;
        }
        throw new DMLRuntimeException("GPU : Exceeded supported size " + num);
    }

    /**
     * Runs the {@code lstm_backward} GPU instruction.
     *
     * <p>Steps performed here:
     * <ol>
     *   <li>Read {@code _input4} (out0); M is its number of columns.</li>
     *   <li>Read {@code _input2} (W) and {@code _input3} (bias); D is
     *       {@code rows(W) - M}. Both are repacked by the
     *       {@code prepare_lstm_weight} kernel into a temporary buffer of
     *       {@code (D + M + 2) * 4 * M} cells.</li>
     *   <li>Read {@code _input1} (X); N is its number of rows and T is
     *       {@code cols(X) / D}. X is repacked by the
     *       {@code prepare_lstm_input} kernel into a temporary buffer of
     *       {@code N * T * D} cells.</li>
     *   <li>Delegate to {@link LibMatrixCuDNN#lstmBackward}, passing the
     *       names of the remaining inputs ({@code _input7}, {@code _input8})
     *       and of the five outputs.</li>
     *   <li>Free the two temporary buffers and release {@code _input4} and
     *       {@code _input5}.</li>
     * </ol>
     * The outputs are not released in this method; they are handed to
     * {@code lstmBackward} by name only.
     *
     * @param ec execution context providing the GPU context, scalars and matrices
     * @throws DMLRuntimeException if a dimension or buffer cell count does not
     *         fit the range accepted by {@code toInt}
     */
    private void processLstmBackwardInstruction(ExecutionContext ec) throws DMLRuntimeException {
        GPUStatistics.incrementNoOfExecutedGPUInst();
        GPUContext gCtx = ec.getGPUContext(0);
        String instructionName = this.getExtendedOpcode();
        MatrixObject out0 = this.getMatrixInputForGPUInstruction(ec, this._input4.getName());
        int M = DnnGPUInstruction.toInt(out0.getNumColumns());
        Pointer out0Pointer = LibMatrixCUDA.getDensePointer(gCtx, out0, instructionName);
        MatrixObject W = this.getMatrixInputForGPUInstruction(ec, this._input2.getName());
        MatrixObject bias = this.getMatrixInputForGPUInstruction(ec, this._input3.getName());
        long numRowsW = W.getNumRows();
        int D = DnnGPUInstruction.toInt(numRowsW) - M;
        Pointer sysdsWPointer = LibMatrixCuDNN.getDensePointerForCuDNN(gCtx, W, instructionName, D + M, 4 * M);
        Pointer sysdsBiasPointer = LibMatrixCuDNN.getDensePointerForCuDNN(gCtx, bias, instructionName, 1, 4 * M);
        // Cell count and byte size are computed in long arithmetic: the int
        // product would silently wrap around for large weight matrices.
        int numWeightCells = DnnGPUInstruction.toInt((long)(D + M + 2) * 4L * M);
        Pointer cudnnWPointer = gCtx.allocate(instructionName, (long)numWeightCells * LibMatrixCUDA.sizeOfDataType);
        LibMatrixCUDA.getCudaKernels(gCtx).launchKernel("prepare_lstm_weight", ExecutionConfig.getConfigForSimpleVectorOperations(numWeightCells), sysdsWPointer, sysdsBiasPointer, cudnnWPointer, D, M);
        // W and bias are no longer needed once they have been repacked.
        ec.releaseMatrixInputForGPUInstruction(this._input2.getName());
        ec.releaseMatrixInputForGPUInstruction(this._input3.getName());
        MatrixObject X = this.getMatrixInputForGPUInstruction(ec, this._input1.getName());
        Pointer xPointer = LibMatrixCUDA.getDensePointer(gCtx, X, instructionName);
        int N = DnnGPUInstruction.toInt(X.getNumRows());
        long numColsX = X.getNumColumns();
        int T = DnnGPUInstruction.toInt(numColsX / (long)D);
        // Same overflow protection for the repacked input buffer.
        int numInputCells = DnnGPUInstruction.toInt((long)N * T * D);
        Pointer cudnnInput = gCtx.allocate(instructionName, (long)numInputCells * LibMatrixCUDA.sizeOfDataType);
        LibMatrixCUDA.getCudaKernels(gCtx).launchKernel("prepare_lstm_input", ExecutionConfig.getConfigForSimpleVectorOperations(numInputCells), xPointer, cudnnInput, N, D, T * D, numInputCells);
        ec.releaseMatrixInputForGPUInstruction(this._input1.getName());
        Pointer c0Pointer = LibMatrixCUDA.getDensePointer(gCtx, this.getMatrixInputForGPUInstruction(ec, this._input5.getName()), instructionName);
        boolean return_sequences = ec.getScalarInput(this._input6.getName(), this._input6.getValueType(), this._input6.isLiteral()).getBooleanValue();
        String dxName = this._output.getName();
        String dwName = this._output2.getName();
        String dbName = this._output3.getName();
        String dhxName = this._output4.getName();
        String dcxName = this._output5.getName();
        String doutName = this._input7.getName();
        String dcyName = this._input8.getName();
        LibMatrixCuDNN.lstmBackward(ec, gCtx, instructionName, cudnnInput, out0Pointer, c0Pointer, cudnnWPointer, doutName, dcyName, dxName, dwName, dbName, dhxName, dcxName, return_sequences, N, M, D, T);
        // Free the temporary repacked buffers allocated above.
        gCtx.cudaFreeHelper(instructionName, cudnnWPointer, DMLScript.EAGER_CUDA_FREE);
        gCtx.cudaFreeHelper(instructionName, cudnnInput, DMLScript.EAGER_CUDA_FREE);
        ec.releaseMatrixInputForGPUInstruction(this._input4.getName());
        ec.releaseMatrixInputForGPUInstruction(this._input5.getName());
    }

    /**
     * Runs the {@code lstm} GPU instruction.
     *
     * <p>Steps performed here:
     * <ol>
     *   <li>Read {@code _input4} (out0); M is its number of columns.</li>
     *   <li>Read {@code _input2} (W) and {@code _input3} (bias); D is
     *       {@code rows(W) - M}. Both are repacked by the
     *       {@code prepare_lstm_weight} kernel into a temporary buffer of
     *       {@code (D + M + 2) * 4 * M} cells.</li>
     *   <li>Read the boolean scalar {@code _input6} (return_sequences).</li>
     *   <li>Read {@code _input1} (X); N is its number of rows and T is
     *       {@code cols(X) / D}. X is repacked by the
     *       {@code prepare_lstm_input} kernel into a temporary buffer of
     *       {@code N * T * D} cells.</li>
     *   <li>Delegate to {@link LibMatrixCuDNN#lstm} with the names of the two
     *       outputs.</li>
     *   <li>Free the temporary buffers and release {@code _input4},
     *       {@code _input5} and both outputs.</li>
     * </ol>
     *
     * @param ec execution context providing the GPU context, scalars and matrices
     * @throws DMLRuntimeException if a dimension or buffer cell count does not
     *         fit the range accepted by {@code toInt}
     */
    private void processLstmInstruction(ExecutionContext ec) throws DMLRuntimeException {
        GPUStatistics.incrementNoOfExecutedGPUInst();
        GPUContext gCtx = ec.getGPUContext(0);
        String instructionName = this.getExtendedOpcode();
        MatrixObject out0 = this.getMatrixInputForGPUInstruction(ec, this._input4.getName());
        int M = DnnGPUInstruction.toInt(out0.getNumColumns());
        Pointer out0Pointer = LibMatrixCUDA.getDensePointer(gCtx, out0, instructionName);
        MatrixObject W = this.getMatrixInputForGPUInstruction(ec, this._input2.getName());
        MatrixObject bias = this.getMatrixInputForGPUInstruction(ec, this._input3.getName());
        long numRowsW = W.getNumRows();
        int D = DnnGPUInstruction.toInt(numRowsW) - M;
        Pointer sysdsWPointer = LibMatrixCuDNN.getDensePointerForCuDNN(gCtx, W, instructionName, D + M, 4 * M);
        Pointer sysdsBiasPointer = LibMatrixCuDNN.getDensePointerForCuDNN(gCtx, bias, instructionName, 1, 4 * M);
        // Cell count and byte size are computed in long arithmetic: the int
        // product would silently wrap around for large weight matrices.
        int numWeightCells = DnnGPUInstruction.toInt((long)(D + M + 2) * 4L * M);
        Pointer cudnnWPointer = gCtx.allocate(instructionName, (long)numWeightCells * LibMatrixCUDA.sizeOfDataType);
        LibMatrixCUDA.getCudaKernels(gCtx).launchKernel("prepare_lstm_weight", ExecutionConfig.getConfigForSimpleVectorOperations(numWeightCells), sysdsWPointer, sysdsBiasPointer, cudnnWPointer, D, M);
        // W and bias are no longer needed once they have been repacked.
        ec.releaseMatrixInputForGPUInstruction(this._input2.getName());
        ec.releaseMatrixInputForGPUInstruction(this._input3.getName());
        boolean return_sequences = ec.getScalarInput(this._input6.getName(), this._input6.getValueType(), this._input6.isLiteral()).getBooleanValue();
        MatrixObject X = this.getMatrixInputForGPUInstruction(ec, this._input1.getName());
        Pointer xPointer = LibMatrixCUDA.getDensePointer(gCtx, X, instructionName);
        int N = DnnGPUInstruction.toInt(X.getNumRows());
        long numColsX = X.getNumColumns();
        int T = DnnGPUInstruction.toInt(numColsX / (long)D);
        // Same overflow protection for the repacked input buffer.
        int numInputCells = DnnGPUInstruction.toInt((long)N * T * D);
        Pointer cudnnInput = gCtx.allocate(instructionName, (long)numInputCells * LibMatrixCUDA.sizeOfDataType);
        LibMatrixCUDA.getCudaKernels(gCtx).launchKernel("prepare_lstm_input", ExecutionConfig.getConfigForSimpleVectorOperations(numInputCells), xPointer, cudnnInput, N, D, T * D, numInputCells);
        ec.releaseMatrixInputForGPUInstruction(this._input1.getName());
        Pointer c0Pointer = LibMatrixCUDA.getDensePointer(gCtx, this.getMatrixInputForGPUInstruction(ec, this._input5.getName()), instructionName);
        LibMatrixCuDNN.lstm(ec, gCtx, instructionName, cudnnInput, cudnnWPointer, out0Pointer, c0Pointer, return_sequences, this._output.getName(), this._output2.getName(), N, M, D, T);
        // Free the temporary repacked buffers allocated above.
        gCtx.cudaFreeHelper(instructionName, cudnnWPointer, DMLScript.EAGER_CUDA_FREE);
        gCtx.cudaFreeHelper(instructionName, cudnnInput, DMLScript.EAGER_CUDA_FREE);
        ec.releaseMatrixInputForGPUInstruction(this._input4.getName());
        ec.releaseMatrixInputForGPUInstruction(this._input5.getName());
        ec.releaseMatrixOutputForGPUInstruction(this._output2.getName());
        ec.releaseMatrixOutputForGPUInstruction(this._output.getName());
    }

    /**
     * Entry point that dispatches on the instruction opcode.
     *
     * <p>The opcodes {@code bias_add}, {@code bias_multiply},
     * {@code relu_backward}, {@code channel_sums}, {@code update_nesterov_x},
     * {@code lstm}, {@code lstm_backward} and the {@code batch_norm2d*} family
     * are forwarded to their dedicated {@code process*} methods. All remaining
     * opcodes (conv2d, conv2d_bias_add, conv2d_backward_filter,
     * conv2d_backward_data, max/avg pooling and their backward variants) read
     * padding, stride, input shape and filter shape from the operand lists,
     * validate the matrix dimensions against them, and delegate to
     * {@link LibMatrixCuDNN}.
     *
     * @param ec execution context providing the GPU context, scalars and matrices
     * @throws DMLRuntimeException if an input has unexpected dimensions or the
     *         opcode is not handled here
     */
    @Override
    public void processInstruction(ExecutionContext ec) {
        final String opcode = this.instOpcode;
        if (opcode.equalsIgnoreCase("bias_add") || opcode.equalsIgnoreCase("bias_multiply")) {
            this.processBiasInstruction(opcode, ec);
            return;
        }
        if (opcode.equalsIgnoreCase("relu_backward")) {
            this.processReLUBackwardInstruction(ec);
            return;
        }
        if (opcode.equalsIgnoreCase("channel_sums")) {
            this.processChannelSumsInstruction(ec);
            return;
        }
        if (opcode.equalsIgnoreCase("update_nesterov_x")) {
            this.processNesterovUpdateInstruction(ec);
            return;
        }
        if (opcode.equalsIgnoreCase("lstm")) {
            this.processLstmInstruction(ec);
            return;
        }
        if (opcode.equalsIgnoreCase("lstm_backward")) {
            this.processLstmBackwardInstruction(ec);
            return;
        }
        if (opcode.equalsIgnoreCase("batch_norm2d")) {
            this.processBatchNorm2dInstruction(ec);
            return;
        }
        if (opcode.equalsIgnoreCase("batch_norm2d_backward")) {
            this.processBatchNorm2dBackwardInstruction(ec);
            return;
        }
        if (opcode.equalsIgnoreCase("batch_norm2d_test")) {
            this.processBatchNorm2dTestInstruction(ec);
            return;
        }
        if (opcode.equalsIgnoreCase("batch_norm2d_train")) {
            this.processBatchNorm2dTrainInstruction(ec);
            return;
        }

        // Convolution / pooling family: everything below shares the same
        // shape parameters.
        GPUStatistics.incrementNoOfExecutedGPUInst();
        int pad_h = DnnGPUInstruction.getScalarInput(ec, this._padding, 0);
        int pad_w = DnnGPUInstruction.getScalarInput(ec, this._padding, 1);
        int stride_h = DnnGPUInstruction.getScalarInput(ec, this._stride, 0);
        int stride_w = DnnGPUInstruction.getScalarInput(ec, this._stride, 1);
        int N = DnnGPUInstruction.getScalarInput(ec, this._input_shape, 0);
        int C = DnnGPUInstruction.getScalarInput(ec, this._input_shape, 1);
        int H = DnnGPUInstruction.getScalarInput(ec, this._input_shape, 2);
        int W = DnnGPUInstruction.getScalarInput(ec, this._input_shape, 3);
        int K = DnnGPUInstruction.getScalarInput(ec, this._filter_shape, 0);
        int R = DnnGPUInstruction.getScalarInput(ec, this._filter_shape, 2);
        int S = DnnGPUInstruction.getScalarInput(ec, this._filter_shape, 3);
        int P = (int)DnnUtils.getP(H, R, stride_h, pad_h);
        int Q = (int)DnnUtils.getQ(W, S, stride_w, pad_w);

        // Column counts are computed in long arithmetic; the int products
        // could overflow and corrupt both the checks and the output sizes.
        final long CHW = (long)C * H * W;
        final long CRS = (long)C * R * S;
        final long KPQ = (long)K * P * Q;
        final long CPQ = (long)C * P * Q;

        if (opcode.equalsIgnoreCase("conv2d")) {
            MatrixObject image = this.getMatrixInputForGPUInstruction(ec, this._input1.getName());
            MatrixObject filter = this.getMatrixInputForGPUInstruction(ec, this._input2.getName());
            if (image.getNumRows() != (long)N || image.getNumColumns() != CHW) {
                throw new DMLRuntimeException("Incorrect dimensions for image in conv2d");
            }
            if (filter.getNumRows() != (long)K || filter.getNumColumns() != CRS) {
                throw new DMLRuntimeException("Incorrect dimensions for filter in conv2d");
            }
            MatrixObject out = this.getDenseMatrixOutputForGPUInstruction(ec, this._output.getName(), N, KPQ);
            LibMatrixCuDNN.conv2d(ec.getGPUContext(0), this.getExtendedOpcode(), image, filter, out, N, C, H, W, K, R, S, pad_h, pad_w, stride_h, stride_w, P, Q, this._intermediateMemoryBudget);
        } else if (opcode.equalsIgnoreCase("conv2d_bias_add")) {
            // For this opcode the operand order is image, bias, filter.
            MatrixObject image = this.getMatrixInputForGPUInstruction(ec, this._input1.getName());
            MatrixObject bias = this.getMatrixInputForGPUInstruction(ec, this._input2.getName());
            MatrixObject filter = this.getMatrixInputForGPUInstruction(ec, this._input3.getName());
            if (image.getNumRows() != (long)N || image.getNumColumns() != CHW) {
                throw new DMLRuntimeException("Incorrect dimensions for image in conv2d");
            }
            if (filter.getNumRows() != (long)K || filter.getNumColumns() != CRS) {
                throw new DMLRuntimeException("Incorrect dimensions for filter in conv2d");
            }
            MatrixObject out = this.getDenseMatrixOutputForGPUInstruction(ec, this._output.getName(), N, KPQ);
            LibMatrixCuDNN.conv2dBiasAdd(ec.getGPUContext(0), this.getExtendedOpcode(), image, bias, filter, out, N, C, H, W, K, R, S, pad_h, pad_w, stride_h, stride_w, P, Q, this._intermediateMemoryBudget);
        } else if (opcode.equalsIgnoreCase("conv2d_backward_filter")) {
            MatrixObject image = this.getMatrixInputForGPUInstruction(ec, this._input1.getName());
            MatrixObject dout = this.getMatrixInputForGPUInstruction(ec, this._input2.getName());
            if (image.getNumRows() != (long)N || image.getNumColumns() != CHW) {
                throw new DMLRuntimeException("Incorrect dimensions for image in conv2d_backward_filter");
            }
            if (dout.getNumRows() != (long)N || dout.getNumColumns() != KPQ) {
                throw new DMLRuntimeException("Incorrect dimensions for dout in conv2d_backward_filter: " + dout.getNumRows() + " != " + N + " || " + dout.getNumColumns() + " != " + KPQ);
            }
            MatrixObject out = this.getDenseMatrixOutputForGPUInstruction(ec, this._output.getName(), K, CRS);
            LibMatrixCuDNN.conv2dBackwardFilter(ec.getGPUContext(0), this.getExtendedOpcode(), image, dout, out, N, C, H, W, K, R, S, pad_h, pad_w, stride_h, stride_w, P, Q, this._intermediateMemoryBudget);
        } else if (opcode.equalsIgnoreCase("conv2d_backward_data")) {
            MatrixObject filter = this.getMatrixInputForGPUInstruction(ec, this._input1.getName());
            MatrixObject dout = this.getMatrixInputForGPUInstruction(ec, this._input2.getName());
            if (filter.getNumRows() != (long)K || filter.getNumColumns() != CRS) {
                throw new DMLRuntimeException("Incorrect dimensions for filter in convolution_backward_data");
            }
            if (dout.getNumRows() != (long)N || dout.getNumColumns() != KPQ) {
                throw new DMLRuntimeException("Incorrect dimensions for dout in conv2d_backward_data: " + dout.getNumRows() + " != " + N + " || " + dout.getNumColumns() + " != " + KPQ);
            }
            MatrixObject out = this.getDenseMatrixOutputForGPUInstruction(ec, this._output.getName(), N, CHW);
            LibMatrixCuDNN.conv2dBackwardData(ec.getGPUContext(0), this.getExtendedOpcode(), filter, dout, out, N, C, H, W, K, R, S, pad_h, pad_w, stride_h, stride_w, P, Q, this._intermediateMemoryBudget);
        } else if (opcode.equalsIgnoreCase("maxpooling") || opcode.equalsIgnoreCase("avgpooling")) {
            MatrixObject image = this.getMatrixInputForGPUInstruction(ec, this._input1.getName());
            if (image.getNumRows() != (long)N || image.getNumColumns() != CHW) {
                throw new DMLRuntimeException("Incorrect dimensions for image in maxpooling: " + image.getNumRows() + " != " + N + " || " + image.getNumColumns() + " != " + CHW);
            }
            MatrixObject out = this.getDenseMatrixOutputForGPUInstruction(ec, this._output.getName(), N, CPQ);
            LibMatrixDNN.PoolingType poolType = opcode.equalsIgnoreCase("maxpooling") ? LibMatrixDNN.PoolingType.MAX : LibMatrixDNN.PoolingType.AVG;
            LibMatrixCuDNN.pooling(ec.getGPUContext(0), this.getExtendedOpcode(), image, out, N, C, H, W, K, R, S, pad_h, pad_w, stride_h, stride_w, P, Q, poolType, this._intermediateMemoryBudget);
        } else if (opcode.equalsIgnoreCase("maxpooling_backward") || opcode.equalsIgnoreCase("avgpooling_backward")) {
            MatrixObject image = this.getMatrixInputForGPUInstruction(ec, this._input1.getName());
            MatrixObject dout = this.getMatrixInputForGPUInstruction(ec, this._input2.getName());
            // The third input is optional; null is passed through when absent.
            MatrixObject maxPoolOutput = this._input3 != null ? this.getMatrixInputForGPUInstruction(ec, this._input3.getName()) : null;
            if (dout.getNumRows() != (long)N || dout.getNumColumns() != CPQ) {
                throw new DMLRuntimeException("Incorrect dimensions for dout in maxpooling_backward");
            }
            if (image.getNumRows() != (long)N || image.getNumColumns() != CHW) {
                // Report the value that was actually compared (C*H*W).
                throw new DMLRuntimeException("Incorrect dimensions for image in maxpooling_backward: " + image.getNumRows() + " != " + N + " || " + image.getNumColumns() + " != " + CHW);
            }
            MatrixObject out = this.getDenseMatrixOutputForGPUInstruction(ec, this._output.getName(), N, CHW);
            LibMatrixDNN.PoolingType poolType = opcode.equalsIgnoreCase("maxpooling_backward") ? LibMatrixDNN.PoolingType.MAX : LibMatrixDNN.PoolingType.AVG;
            LibMatrixCuDNN.poolingBackward(ec.getGPUContext(0), this.getExtendedOpcode(), image, dout, maxPoolOutput, out, N, C, H, W, K, R, S, pad_h, pad_w, stride_h, stride_w, P, Q, poolType, this._intermediateMemoryBudget);
        } else {
            throw new DMLRuntimeException("Unsupported GPU context for " + opcode);
        }

        // Release exactly the inputs that the selected branch acquired.
        ec.releaseMatrixInputForGPUInstruction(this._input1.getName());
        boolean isPool = opcode.equalsIgnoreCase("maxpooling") || opcode.equalsIgnoreCase("avgpooling");
        boolean isPoolBackward = opcode.equalsIgnoreCase("maxpooling_backward") || opcode.equalsIgnoreCase("avgpooling_backward");
        if (!isPool) {
            // Forward pooling is the only branch with a single matrix input.
            ec.releaseMatrixInputForGPUInstruction(this._input2.getName());
        }
        if (opcode.equalsIgnoreCase("conv2d_bias_add") || isPoolBackward && this._input3 != null) {
            ec.releaseMatrixInputForGPUInstruction(this._input3.getName());
        }
        ec.releaseMatrixOutputForGPUInstruction(this._output.getName());
    }

    /**
     * Reads the scalar operand at position {@code index} of {@code aL} from
     * the execution context and returns its long value narrowed to
     * {@code int} (plain cast, no range check).
     *
     * @param ec    execution context used to resolve the scalar
     * @param aL    list of operands
     * @param index position of the operand to read
     * @return the scalar's long value cast to {@code int}
     */
    private static int getScalarInput(ExecutionContext ec, ArrayList<CPOperand> aL, int index) {
        CPOperand operand = aL.get(index);
        long value = ec.getScalarInput(operand).getLongValue();
        return (int)value;
    }
}

