/*
 * Decompiled with CFR 0.152.
 */
package org.tugraz.sysds.runtime.controlprogram.context;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.tugraz.sysds.api.DMLScript;
import org.tugraz.sysds.common.Types;
import org.tugraz.sysds.runtime.DMLRuntimeException;
import org.tugraz.sysds.runtime.controlprogram.LocalVariableMap;
import org.tugraz.sysds.runtime.controlprogram.Program;
import org.tugraz.sysds.runtime.controlprogram.caching.CacheableData;
import org.tugraz.sysds.runtime.controlprogram.caching.FrameObject;
import org.tugraz.sysds.runtime.controlprogram.caching.MatrixObject;
import org.tugraz.sysds.runtime.controlprogram.caching.TensorObject;
import org.tugraz.sysds.runtime.data.TensorBlock;
import org.tugraz.sysds.runtime.instructions.Instruction;
import org.tugraz.sysds.runtime.instructions.cp.CPOperand;
import org.tugraz.sysds.runtime.instructions.cp.Data;
import org.tugraz.sysds.runtime.instructions.cp.ListObject;
import org.tugraz.sysds.runtime.instructions.cp.ScalarObject;
import org.tugraz.sysds.runtime.instructions.cp.ScalarObjectFactory;
import org.tugraz.sysds.runtime.instructions.gpu.context.GPUContext;
import org.tugraz.sysds.runtime.instructions.gpu.context.GPUObject;
import org.tugraz.sysds.runtime.lineage.Lineage;
import org.tugraz.sysds.runtime.lineage.LineageItem;
import org.tugraz.sysds.runtime.lineage.LineagePath;
import org.tugraz.sysds.runtime.matrix.data.FrameBlock;
import org.tugraz.sysds.runtime.matrix.data.InputInfo;
import org.tugraz.sysds.runtime.matrix.data.MatrixBlock;
import org.tugraz.sysds.runtime.matrix.data.OutputInfo;
import org.tugraz.sysds.runtime.matrix.data.Pair;
import org.tugraz.sysds.runtime.meta.DataCharacteristics;
import org.tugraz.sysds.runtime.meta.MatrixCharacteristics;
import org.tugraz.sysds.runtime.meta.MetaData;
import org.tugraz.sysds.runtime.meta.MetaDataFormat;
import org.tugraz.sysds.runtime.util.HDFSTool;
import org.tugraz.sysds.utils.Statistics;

public class ExecutionContext {
    protected static final Log LOG = LogFactory.getLog(ExecutionContext.class.getName());
    protected Program _prog = null;
    protected LocalVariableMap _variables;
    protected Lineage _lineage;
    protected LineagePath _lineagePath = new LineagePath();
    protected List<GPUContext> _gpuContexts = new ArrayList<GPUContext>();

    /**
     * Creates a context with a fresh variable map, a lineage trace if
     * {@code DMLScript.LINEAGE} is set, and no program.
     */
    protected ExecutionContext() {
        this(true, DMLScript.LINEAGE, null);
    }

    /**
     * Creates a context.
     *
     * @param allocateVariableMap whether to allocate a new, empty {@link LocalVariableMap}
     * @param allocateLineage     whether to allocate a new {@link Lineage} trace
     * @param prog                the program executed in this context (may be {@code null})
     */
    protected ExecutionContext(boolean allocateVariableMap, boolean allocateLineage, Program prog) {
        this._variables = allocateVariableMap ? new LocalVariableMap() : null;
        this._lineage = allocateLineage ? new Lineage() : null;
        this._prog = prog;
    }

    /**
     * Creates a context over an existing symbol table, without lineage and without a program.
     *
     * @param vars the variable map to use (not copied)
     */
    public ExecutionContext(LocalVariableMap vars) {
        this._variables = vars;
        this._lineage = null;
        this._prog = null;
    }

    public Program getProgram() {
        return this._prog;
    }

    public void setProgram(Program prog) {
        this._prog = prog;
    }

    public LocalVariableMap getVariables() {
        return this._variables;
    }

    public void setVariables(LocalVariableMap vars) {
        this._variables = vars;
    }

    public Lineage getLineage() {
        return this._lineage;
    }

    public void setLineage(Lineage lineage) {
        this._lineage = lineage;
    }

    public LineagePath getLineagePath() {
        return this._lineagePath;
    }

    public void setLineagePath(LineagePath lp) {
        this._lineagePath = lp;
    }

    /**
     * Returns the GPU context at the given position.
     *
     * @param index position in the list of GPU contexts
     * @return the GPU context, or {@code null} if {@code index} is out of range
     */
    public GPUContext getGPUContext(int index) {
        // explicit bounds check instead of using IndexOutOfBoundsException for control flow
        return (index >= 0 && index < this._gpuContexts.size()) ? this._gpuContexts.get(index) : null;
    }

    public void setGPUContexts(List<GPUContext> gpuContexts) {
        this._gpuContexts = gpuContexts;
    }

    public List<GPUContext> getGPUContexts() {
        return this._gpuContexts;
    }

    public int getNumGPUContexts() {
        return this._gpuContexts.size();
    }

    /**
     * Looks up a variable in the symbol table.
     *
     * @param name variable name
     * @return the bound data object, or whatever the variable map returns for unknown names
     */
    public Data getVariable(String name) {
        return this._variables.get(name);
    }

    /**
     * Resolves an operand: scalar operands go through {@link #getScalarInput(CPOperand)}
     * (so literals are materialized); all others are looked up by name.
     */
    public Data getVariable(CPOperand operand) {
        return operand.getDataType().isScalar() ? this.getScalarInput(operand) : this.getVariable(operand.getName());
    }

    public void setVariable(String name, Data val) {
        this._variables.put(name, val);
    }

    public boolean containsVariable(CPOperand operand) {
        return this.containsVariable(operand.getName());
    }

    public boolean containsVariable(String name) {
        return this._variables.keySet().contains(name);
    }

    public Data removeVariable(String name) {
        return this._variables.remove(name);
    }

    /**
     * Replaces the metadata of an existing variable.
     *
     * @param fname variable name
     * @param md    new metadata
     * @throws DMLRuntimeException if the variable does not exist
     */
    public void setMetaData(String fname, MetaData md) {
        Data dat = this._variables.get(fname);
        if (dat == null) {
            // previously surfaced as a NullPointerException; report it like the other lookups
            throw new DMLRuntimeException(ExecutionContext.getNonExistingVarError(fname));
        }
        dat.setMetaData(md);
    }

    /**
     * Returns the metadata of an existing variable.
     *
     * @throws DMLRuntimeException if the variable does not exist
     */
    public MetaData getMetaData(String varname) {
        Data tmp = this._variables.get(varname);
        if (tmp == null) {
            throw new DMLRuntimeException(ExecutionContext.getNonExistingVarError(varname));
        }
        return tmp.getMetaData();
    }

    /** @return true iff {@code varname} is bound to a {@link MatrixObject} */
    public boolean isMatrixObject(String varname) {
        return this.getVariable(varname) instanceof MatrixObject;
    }

    public MatrixObject getMatrixObject(CPOperand input) {
        return this.getMatrixObject(input.getName());
    }

    /**
     * @return the matrix object bound to {@code varname}
     * @throws DMLRuntimeException if the variable does not exist or is not a matrix
     */
    public MatrixObject getMatrixObject(String varname) {
        Data dat = this.getVariable(varname);
        if (dat == null) {
            throw new DMLRuntimeException(ExecutionContext.getNonExistingVarError(varname));
        }
        if (!(dat instanceof MatrixObject)) {
            throw new DMLRuntimeException("Variable '" + varname + "' is not a matrix.");
        }
        return (MatrixObject) dat;
    }

    /**
     * @return the tensor object bound to {@code varname}
     * @throws DMLRuntimeException if the variable does not exist or is not a tensor
     */
    public TensorObject getTensorObject(String varname) {
        Data dat = this.getVariable(varname);
        if (dat == null) {
            throw new DMLRuntimeException(ExecutionContext.getNonExistingVarError(varname));
        }
        if (!(dat instanceof TensorObject)) {
            throw new DMLRuntimeException("Variable '" + varname + "' is not a tensor.");
        }
        return (TensorObject) dat;
    }

    /** @return true iff {@code varname} is bound to a {@link FrameObject} */
    public boolean isFrameObject(String varname) {
        return this.getVariable(varname) instanceof FrameObject;
    }

    public FrameObject getFrameObject(CPOperand input) {
        return this.getFrameObject(input.getName());
    }

    /**
     * @return the frame object bound to {@code varname}
     * @throws DMLRuntimeException if the variable does not exist or is not a frame
     */
    public FrameObject getFrameObject(String varname) {
        Data dat = this.getVariable(varname);
        if (dat == null) {
            throw new DMLRuntimeException(ExecutionContext.getNonExistingVarError(varname));
        }
        if (!(dat instanceof FrameObject)) {
            throw new DMLRuntimeException("Variable '" + varname + "' is not a frame.");
        }
        return (FrameObject) dat;
    }

    public CacheableData<?> getCacheableData(CPOperand input) {
        return this.getCacheableData(input.getName());
    }

    /**
     * @return the cacheable data (matrix, tensor or frame) bound to {@code varname}
     * @throws DMLRuntimeException if the variable does not exist or is not cacheable
     */
    public CacheableData<?> getCacheableData(String varname) {
        Data dat = this.getVariable(varname);
        if (dat == null) {
            throw new DMLRuntimeException(ExecutionContext.getNonExistingVarError(varname));
        }
        if (!(dat instanceof CacheableData)) {
            throw new DMLRuntimeException("Variable '" + varname + "' is not a matrix, tensor or frame.");
        }
        return (CacheableData<?>) dat;
    }

    public void releaseCacheableData(String varname) {
        this.getCacheableData(varname).release();
    }

    public DataCharacteristics getDataCharacteristics(String varname) {
        return this.getMetaData(varname).getDataCharacteristics();
    }

    /** Acquires a read lock on the named matrix and returns its block; pair with {@link #releaseMatrixInput(String)}. */
    public MatrixBlock getMatrixInput(String varName) {
        return this.getMatrixObject(varName).acquireRead();
    }

    /** Acquires a read lock on the named tensor and returns its block; pair with {@link #releaseTensorInput(String)}. */
    public TensorBlock getTensorInput(String varName) {
        return (TensorBlock) this.getTensorObject(varName).acquireRead();
    }

    /**
     * Updates the dimensions of a matrix variable, keeping its block size and I/O formats.
     * Does nothing if the dimensions already match.
     *
     * @throws DMLRuntimeException if the existing metadata is not a {@link MetaDataFormat}
     */
    public void setMetaData(String varName, long nrows, long ncols) {
        MatrixObject mo = this.getMatrixObject(varName);
        if (mo.getNumRows() == nrows && mo.getNumColumns() == ncols) {
            return;
        }
        MetaData oldMetaData = mo.getMetaData();
        // instanceof is false for null, so this also covers missing metadata
        if (!(oldMetaData instanceof MetaDataFormat)) {
            throw new DMLRuntimeException("Metadata not available");
        }
        MetaDataFormat oldFormat = (MetaDataFormat) oldMetaData;
        MatrixCharacteristics mc = new MatrixCharacteristics(nrows, ncols, (int) mo.getBlocksize());
        mo.setMetaData(new MetaDataFormat(mc, oldFormat.getOutputInfo(), oldFormat.getInputInfo()));
    }

    /**
     * Reconciles two dimension values where a negative value means "unknown".
     *
     * @return the larger of the two (i.e. the known one, if exactly one is known)
     * @throws DMLRuntimeException if both are known and differ
     */
    private static long validateDimensions(long d1, long d2) {
        if (d1 >= 0L && d2 >= 0L && d1 != d2) {
            throw new DMLRuntimeException("Incorrect dimensions:" + d1 + " != " + d2);
        }
        return Math.max(d1, d2);
    }

    /**
     * Allocates the named matrix as a dense GPU output and marks its nnz as unknown (-1).
     *
     * @return the matrix object and the flag returned by {@code acquireDeviceModifyDense()}
     */
    public Pair<MatrixObject, Boolean> getDenseMatrixOutputForGPUInstruction(String varName, long numRows, long numCols) {
        MatrixObject mo = this.allocateGPUMatrixObject(varName, numRows, numCols);
        boolean allocated = mo.getGPUObject(this.getGPUContext(0)).acquireDeviceModifyDense();
        mo.getDataCharacteristics().setNonZeros(-1L);
        return new Pair<MatrixObject, Boolean>(mo, allocated);
    }

    /**
     * Allocates the named matrix as a sparse GPU output with the given nnz.
     *
     * @return the matrix object and the flag returned by {@code acquireDeviceModifySparse()}
     */
    public Pair<MatrixObject, Boolean> getSparseMatrixOutputForGPUInstruction(String varName, long numRows, long numCols, long nnz) {
        MatrixObject mo = this.allocateGPUMatrixObject(varName, numRows, numCols);
        // nnz is set before acquiring so the sparse allocation can see it
        mo.getDataCharacteristics().setNonZeros(nnz);
        boolean allocated = mo.getGPUObject(this.getGPUContext(0)).acquireDeviceModifySparse();
        return new Pair<MatrixObject, Boolean>(mo, allocated);
    }

    /**
     * Reconciles the matrix dimensions with the requested ones, creates a GPU object on the
     * first GPU context if none exists, and adds a write lock to it.
     *
     * @throws DMLRuntimeException if known dimensions disagree
     */
    public MatrixObject allocateGPUMatrixObject(String varName, long numRows, long numCols) {
        MatrixObject mo = this.getMatrixObject(varName);
        long dim1;
        long dim2;
        try {
            dim1 = ExecutionContext.validateDimensions(mo.getNumRows(), numRows);
            dim2 = ExecutionContext.validateDimensions(mo.getNumColumns(), numCols);
        }
        catch (DMLRuntimeException e) {
            throw new DMLRuntimeException("Incorrect dimensions given to allocateGPUMatrixObject: [" + numRows + "," + numCols + "], [" + mo.getNumRows() + "," + mo.getNumColumns() + "]", e);
        }
        if (dim1 != mo.getNumRows() || dim2 != mo.getNumColumns()) {
            mo.getDataCharacteristics().setDimension(dim1, dim2);
        }
        GPUContext gCtx = this.getGPUContext(0);
        if (mo.getGPUObject(gCtx) == null) {
            GPUObject newGObj = gCtx.createGPUObject(mo);
            mo.setGPUObject(gCtx, newGObj);
        }
        mo.getGPUObject(gCtx).addWriteLock();
        return mo;
    }

    /**
     * Ensures the named matrix has a GPU object on the first GPU context and acquires it for
     * device read; pair with {@link #releaseMatrixInputForGPUInstruction(String)}.
     */
    public MatrixObject getMatrixInputForGPUInstruction(String varName, String opcode) {
        GPUContext gCtx = this.getGPUContext(0);
        MatrixObject mo = this.getMatrixObject(varName);
        // getMatrixObject throws on missing variables; kept defensively for overriding subclasses
        if (mo == null) {
            throw new DMLRuntimeException("No matrix object available for variable:" + varName);
        }
        if (mo.getGPUObject(gCtx) == null) {
            GPUObject newGObj = gCtx.createGPUObject(mo);
            mo.setGPUObject(gCtx, newGObj);
        }
        mo.getGPUObject(gCtx).acquireDeviceRead(opcode);
        return mo;
    }

    public void releaseMatrixInput(String varName) {
        this.getMatrixObject(varName).release();
    }

    public void releaseMatrixInput(String ... varNames) {
        for (String varName : varNames) {
            this.releaseMatrixInput(varName);
        }
    }

    public void releaseMatrixInputForGPUInstruction(String varName) {
        this.getMatrixObject(varName).getGPUObject(this.getGPUContext(0)).releaseInput();
    }

    /** Acquires a read lock on the named frame and returns its block; pair with {@link #releaseFrameInput(String)}. */
    public FrameBlock getFrameInput(String varName) {
        return (FrameBlock) this.getFrameObject(varName).acquireRead();
    }

    public void releaseFrameInput(String varName) {
        this.getFrameObject(varName).release();
    }

    public void releaseTensorInput(String varName) {
        this.getTensorObject(varName).release();
    }

    public void releaseTensorInput(String ... varNames) {
        for (String varName : varNames) {
            this.releaseTensorInput(varName);
        }
    }

    /** Returns the operand's literal value, or the scalar bound to its name. */
    public ScalarObject getScalarInput(CPOperand input) {
        return input.isLiteral() ? input.getLiteral() : this.getScalarInput(input.getName(), input.getValueType(), false);
    }

    /**
     * Resolves a scalar input.
     *
     * @param name      literal text if {@code isLiteral}, otherwise a variable name
     * @param vt        value type used when creating a literal
     * @param isLiteral whether {@code name} is a literal
     * @throws DMLRuntimeException if the variable is unknown
     */
    public ScalarObject getScalarInput(String name, Types.ValueType vt, boolean isLiteral) {
        if (isLiteral) {
            return ScalarObjectFactory.createScalarObject(vt, name);
        }
        Data obj = this.getVariable(name);
        if (obj == null) {
            throw new DMLRuntimeException("Unknown variable: " + name);
        }
        return (ScalarObject) obj;
    }

    public void setScalarOutput(String varName, ScalarObject so) {
        this.setVariable(varName, so);
    }

    public ListObject getListObject(CPOperand input) {
        return this.getListObject(input.getName());
    }

    /**
     * @return the list bound to {@code name}
     * @throws DMLRuntimeException if the variable does not exist or is not a list
     */
    public ListObject getListObject(String name) {
        Data dat = this.getVariable(name);
        if (dat == null) {
            throw new DMLRuntimeException(ExecutionContext.getNonExistingVarError(name));
        }
        if (!(dat instanceof ListObject)) {
            throw new DMLRuntimeException("Variable '" + name + "' is not a list.");
        }
        return (ListObject) dat;
    }

    /**
     * Flattens a (possibly nested) list into its matrices, in depth-first order.
     *
     * @throws DMLRuntimeException if the list contains anything other than matrices or lists
     */
    private List<MatrixObject> getMatricesFromList(ListObject lo) {
        ArrayList<MatrixObject> ret = new ArrayList<MatrixObject>();
        for (Data e : lo.getData()) {
            if (e instanceof MatrixObject) {
                ret.add((MatrixObject) e);
            }
            else if (e instanceof ListObject) {
                ret.addAll(this.getMatricesFromList((ListObject) e));
            }
            else {
                throw new DMLRuntimeException("List must contain only matrices or lists for rbind/cbind.");
            }
        }
        return ret;
    }

    /**
     * Updates the metadata of a GPU-produced matrix to binary-block format and releases the
     * GPU output.
     *
     * @throws DMLRuntimeException if no output is allocated on the first GPU context
     */
    public void releaseMatrixOutputForGPUInstruction(String varName) {
        MatrixObject mo = this.getMatrixObject(varName);
        GPUContext gCtx = this.getGPUContext(0);
        if (mo.getGPUObject(gCtx) == null || !mo.getGPUObject(gCtx).isAllocated()) {
            throw new DMLRuntimeException("No output is allocated on GPU");
        }
        this.setMetaData(varName, new MetaDataFormat(mo.getDataCharacteristics(), OutputInfo.BinaryBlockOutputInfo, InputInfo.BinaryBlockInputInfo));
        mo.getGPUObject(gCtx).releaseOutput();
    }

    /** Stores {@code outputData} into the named matrix variable (acquire-modify, then release). */
    public void setMatrixOutput(String varName, MatrixBlock outputData) {
        MatrixObject mo = this.getMatrixObject(varName);
        mo.acquireModify(outputData);
        mo.release();
        this.setVariable(varName, mo);
    }

    /** Like {@link #setMatrixOutput(String, MatrixBlock)}, first applying {@code flag} if it is an in-place update type. */
    public void setMatrixOutput(String varName, MatrixBlock outputData, MatrixObject.UpdateType flag) {
        if (flag.isInPlace()) {
            this.getMatrixObject(varName).setUpdateType(flag);
        }
        this.setMatrixOutput(varName, outputData);
    }

    /** Same as {@link #setMatrixOutput(String, MatrixBlock, MatrixObject.UpdateType)}; {@code opcode} is not used here. */
    public void setMatrixOutput(String varName, MatrixBlock outputData, MatrixObject.UpdateType flag, String opcode) {
        this.setMatrixOutput(varName, outputData, flag);
    }

    /** Stores {@code outputData} into the named tensor variable (acquire-modify, then release). */
    public void setTensorOutput(String varName, TensorBlock outputData) {
        TensorObject to = this.getTensorObject(varName);
        to.acquireModify(outputData);
        to.release();
        this.setVariable(varName, to);
    }

    /** Stores {@code outputData} into the named frame variable (acquire-modify, then release). */
    public void setFrameOutput(String varName, FrameBlock outputData) {
        FrameObject fo = this.getFrameObject(varName);
        fo.acquireModify(outputData);
        fo.release();
        this.setVariable(varName, fo);
    }

    public List<MatrixBlock> getMatrixInputs(CPOperand[] inputs) {
        return this.getMatrixInputs(inputs, false);
    }

    /**
     * Acquires read locks on all matrix operands (and, if requested, on all matrices nested in
     * list operands) and returns their blocks in that order.
     * Pair with {@link #releaseMatrixInputs(CPOperand[], boolean)}.
     */
    public List<MatrixBlock> getMatrixInputs(CPOperand[] inputs, boolean includeList) {
        // toCollection(ArrayList::new): toList() does not guarantee a mutable result, but we append below
        List<MatrixBlock> ret = Arrays.stream(inputs)
            .filter(in -> in.isMatrix())
            .map(in -> this.getMatrixInput(in.getName()))
            .collect(Collectors.toCollection(ArrayList::new));
        if (includeList) {
            List<ListObject> lolist = Arrays.stream(inputs)
                .filter(in -> in.isList())
                .map(in -> this.getListObject(in.getName()))
                .collect(Collectors.toList());
            for (ListObject lo : lolist) {
                for (MatrixObject mo : this.getMatricesFromList(lo)) {
                    ret.add(mo.acquireRead());
                }
            }
        }
        return ret;
    }

    /** Resolves all scalar operands (literals or variables), in operand order. */
    public List<ScalarObject> getScalarInputs(CPOperand[] inputs) {
        return Arrays.stream(inputs)
            .filter(in -> in.isScalar())
            .map(in -> this.getScalarInput(in))
            .collect(Collectors.toList());
    }

    public void releaseMatrixInputs(CPOperand[] inputs) {
        this.releaseMatrixInputs(inputs, false);
    }

    /** Releases the read locks acquired by {@link #getMatrixInputs(CPOperand[], boolean)}. */
    public void releaseMatrixInputs(CPOperand[] inputs, boolean includeList) {
        Arrays.stream(inputs)
            .filter(in -> in.isMatrix())
            .forEach(in -> this.releaseMatrixInput(in.getName()));
        if (includeList) {
            List<ListObject> lolist = Arrays.stream(inputs)
                .filter(in -> in.isList())
                .map(in -> this.getListObject(in.getName()))
                .collect(Collectors.toList());
            for (ListObject lo : lolist) {
                for (MatrixObject mo : this.getMatricesFromList(lo)) {
                    mo.release();
                }
            }
        }
    }

    /**
     * Disables cleanup for all cacheable data bound to the given variables, including cacheable
     * items directly contained in list variables, and returns their previous cleanup flags in
     * visiting order. Pass the result to {@link #unpinVariables(List, boolean[])} to restore.
     * <p>
     * The array has one slot per non-list variable plus one per cacheable list item; slots for
     * non-cacheable variables are never written and stay unused at the end.
     */
    public boolean[] pinVariables(List<String> varList) {
        // pass 1: size the state array
        int nlist = 0;
        int nlistItems = 0;
        for (String var : varList) {
            Data dat = this._variables.get(var);
            if (dat instanceof ListObject) {
                nlistItems += ((ListObject) dat).getNumCacheableData();
                ++nlist;
            }
        }
        boolean[] varsState = new boolean[varList.size() - nlist + nlistItems];

        // pass 2: record all states before modifying any, so objects reachable via several
        // variables are recorded with their original flag
        int pos = 0;
        for (String var : varList) {
            Data dat = this._variables.get(var);
            if (dat instanceof CacheableData) {
                varsState[pos++] = ((CacheableData<?>) dat).isCleanupEnabled();
            }
            else if (dat instanceof ListObject) {
                for (Data dat2 : ((ListObject) dat).getData()) {
                    if (dat2 instanceof CacheableData) {
                        varsState[pos++] = ((CacheableData<?>) dat2).isCleanupEnabled();
                    }
                }
            }
        }

        // pass 3: disable cleanup
        for (String var : varList) {
            Data dat = this._variables.get(var);
            if (dat instanceof CacheableData) {
                ((CacheableData<?>) dat).enableCleanup(false);
            }
            else if (dat instanceof ListObject) {
                for (Data dat2 : ((ListObject) dat).getData()) {
                    if (dat2 instanceof CacheableData) {
                        ((CacheableData<?>) dat2).enableCleanup(false);
                    }
                }
            }
        }
        return varsState;
    }

    /**
     * Restores cleanup flags recorded by {@link #pinVariables(List)}; {@code varList} must
     * resolve to the same objects, in the same order, as when pinning.
     */
    public void unpinVariables(List<String> varList, boolean[] varsState) {
        int pos = 0;
        for (String var : varList) {
            Data dat = this._variables.get(var);
            if (dat instanceof CacheableData) {
                ((CacheableData<?>) dat).enableCleanup(varsState[pos++]);
            }
            else if (dat instanceof ListObject) {
                for (Data dat2 : ((ListObject) dat).getData()) {
                    if (dat2 instanceof CacheableData) {
                        ((CacheableData<?>) dat2).enableCleanup(varsState[pos++]);
                    }
                }
            }
        }
    }

    /** @return a snapshot of all variable names in the symbol table */
    public ArrayList<String> getVarList() {
        return new ArrayList<String>(this._variables.keySet());
    }

    /** @return names of all variables bound to partitioned matrices */
    public ArrayList<String> getVarListPartitioned() {
        ArrayList<String> ret = new ArrayList<String>();
        for (String var : this._variables.keySet()) {
            Data dat = this._variables.get(var);
            if (dat instanceof MatrixObject && ((MatrixObject) dat).isPartitioned()) {
                ret.add(var);
            }
        }
        return ret;
    }

    /**
     * Cleans up a data object: cacheable data directly, lists by cleaning up their directly
     * contained cacheable items. {@code null} and other types are ignored.
     */
    public final void cleanupDataObject(Data dat) {
        if (dat == null) {
            return;
        }
        if (dat instanceof CacheableData) {
            this.cleanupCacheableData((CacheableData<?>) dat);
        }
        else if (dat instanceof ListObject) {
            for (Data dat2 : ((ListObject) dat).getData()) {
                if (dat2 instanceof CacheableData) {
                    this.cleanupCacheableData((CacheableData<?>) dat2);
                }
            }
        }
    }

    /**
     * If cleanup is enabled for {@code mo} and no variable in the symbol table still references
     * it, clears its data and deletes its HDFS file (and the {@code .mtd} sidecar) if one exists.
     * Does nothing when caching is inactive and there is no backing file.
     *
     * @throws DMLRuntimeException wrapping any failure during cleanup
     */
    public void cleanupCacheableData(CacheableData<?> mo) {
        if (DMLScript.JMLC_MEM_STATISTICS) {
            Statistics.removeCPMemObject(System.identityHashCode(mo));
        }
        boolean fileExists = mo.isHDFSFileExists() && mo.getFileName() != null;
        if (!CacheableData.isCachingActive() && !fileExists) {
            return;
        }
        try {
            if (mo.isCleanupEnabled() && !this.getVariables().hasReferences(mo)) {
                mo.clearData();
                if (fileExists) {
                    HDFSTool.deleteFileIfExistOnHDFS(mo.getFileName());
                    HDFSTool.deleteFileIfExistOnHDFS(mo.getFileName() + ".mtd");
                }
            }
        }
        catch (Exception ex) {
            throw new DMLRuntimeException(ex);
        }
    }

    /** Records {@code inst} in the lineage trace. @throws DMLRuntimeException if lineage is not enabled */
    public void traceLineage(Instruction inst) {
        this.requireLineage().trace(inst, this);
    }

    /** @throws DMLRuntimeException if lineage is not enabled */
    public LineageItem getLineageItem(CPOperand input) {
        return this.requireLineage().get(input);
    }

    /** @throws DMLRuntimeException if lineage is not enabled */
    public LineageItem getOrCreateLineageItem(CPOperand input) {
        return this.requireLineage().getOrCreate(input);
    }

    /** Returns the lineage trace, failing if this context was created without one. */
    private Lineage requireLineage() {
        if (this._lineage == null) {
            throw new DMLRuntimeException("Lineage Trace unavailable.");
        }
        return this._lineage;
    }

    private static String getNonExistingVarError(String varname) {
        return "Variable '" + varname + "' does not exist in the symbol table.";
    }
}

