/*
 * Decompiled with CFR 0.152.
 */
package org.tugraz.sysds.hops.rewrite;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import org.tugraz.sysds.api.DMLScript;
import org.tugraz.sysds.common.Types;
import org.tugraz.sysds.hops.DataOp;
import org.tugraz.sysds.hops.FunctionOp;
import org.tugraz.sysds.hops.Hop;
import org.tugraz.sysds.hops.LeftIndexingOp;
import org.tugraz.sysds.hops.UnaryOp;
import org.tugraz.sysds.hops.rewrite.HopRewriteUtils;
import org.tugraz.sysds.hops.rewrite.ProgramRewriteStatus;
import org.tugraz.sysds.hops.rewrite.StatementBlockRewriteRule;
import org.tugraz.sysds.parser.ForStatement;
import org.tugraz.sysds.parser.ForStatementBlock;
import org.tugraz.sysds.parser.IfStatement;
import org.tugraz.sysds.parser.IfStatementBlock;
import org.tugraz.sysds.parser.Statement;
import org.tugraz.sysds.parser.StatementBlock;
import org.tugraz.sysds.parser.VariableSet;
import org.tugraz.sysds.parser.WhileStatement;
import org.tugraz.sysds.parser.WhileStatementBlock;

/**
 * Statement-block rewrite that marks matrix loop variables as update-in-place candidates.
 *
 * <p>For every {@code while}/{@code for} statement block, a variable is recorded as a
 * candidate if it:
 * <ul>
 *   <li>is a matrix;</li>
 *   <li>is updated inside the loop;</li>
 *   <li>is live-out of the loop;</li>
 *   <li>passes the body-level applicability checks implemented by the private helpers
 *       below.</li>
 * </ul>
 * The candidate list is stored on the loop block via
 * {@code StatementBlock.setUpdateInPlaceVars}. This rule never restructures blocks: it
 * always returns the input block itself and never splits DAGs.
 */
public class RewriteMarkLoopVariablesUpdateInPlace
extends StatementBlockRewriteRule {
    @Override
    public boolean createsSplitDag() {
        return false;
    }

    /**
     * Computes and stores the update-in-place candidates of a loop block.
     *
     * @param sb the statement block to inspect; only while/for blocks are modified
     * @param status rewrite status (unused by this rule)
     * @return a single-element list containing {@code sb} itself
     */
    @Override
    public List<StatementBlock> rewriteStatementBlock(StatementBlock sb, ProgramRewriteStatus status) {
        // In SPARK execution mode nothing is marked; the block is returned untouched.
        if (DMLScript.getGlobalExecMode() == Types.ExecMode.SPARK) {
            return Arrays.asList(sb);
        }
        if (sb instanceof WhileStatementBlock || sb instanceof ForStatementBlock) {
            // The loop body is the same for every variable, so resolve it once.
            ArrayList<StatementBlock> body = (sb instanceof WhileStatementBlock)
                ? ((WhileStatement) sb.getStatement(0)).getBody()
                : ((ForStatement) sb.getStatement(0)).getBody();
            VariableSet updated = sb.variablesUpdated();
            VariableSet liveout = sb.liveOut();
            ArrayList<String> candidates = new ArrayList<>();
            for (String varname : updated.getVariableNames()) {
                // Require live-out so that variables local to the loop are excluded.
                if (updated.getVariable(varname).getDataType() == Types.DataType.MATRIX
                    && liveout.containsVariable(varname)
                    && rIsApplicableForUpdateInPlace(body, varname)) {
                    candidates.add(varname);
                }
            }
            sb.setUpdateInPlaceVars(candidates);
        }
        return Arrays.asList(sb);
    }

    /**
     * Recursively checks whether all blocks in {@code sbs} that touch {@code varname}
     * permit update-in-place. Blocks that neither read nor update the variable are skipped.
     *
     * <p>Handling by block kind:
     * <ul>
     *   <li><b>Nested while/for blocks:</b> valid only if they already list {@code varname}
     *       as a candidate. NOTE(review): this relies on nested loops being rewritten before
     *       their enclosing loop; confirm the rule-application order.</li>
     *   <li><b>If blocks:</b> recurse into the if-body and, if that passed, the else-body.</li>
     *   <li><b>Other blocks:</b> first the DAG-level check runs. Only if it fails are the
     *       per-root checks applied.</li>
     * </ul>
     * Evaluation stops at the first invalid block.
     *
     * @param sbs statement blocks to inspect
     * @param varname variable name under consideration
     * @return {@code true} if update-in-place is applicable for {@code varname}
     */
    private boolean rIsApplicableForUpdateInPlace(ArrayList<StatementBlock> sbs, String varname) {
        boolean ret = true;
        for (StatementBlock sb : sbs) {
            if (!sb.variablesRead().containsVariable(varname)
                && !sb.variablesUpdated().containsVariable(varname)) {
                continue; // untouched blocks are irrelevant
            }
            if (sb instanceof WhileStatementBlock || sb instanceof ForStatementBlock) {
                ret &= sb.getUpdateInPlaceVars().contains(varname);
            } else if (sb instanceof IfStatementBlock) {
                IfStatement istmt = (IfStatement) sb.getStatement(0);
                ret &= rIsApplicableForUpdateInPlace(istmt.getIfBody(), varname);
                if (ret && istmt.getElseBody() != null) {
                    ret &= rIsApplicableForUpdateInPlace(istmt.getElseBody(), varname);
                }
            } else if (sb.getHops() != null && !isApplicableForUpdateInPlace(sb.getHops(), varname)) {
                // DAG-level check failed: fall back to checking each root individually.
                for (Hop hop : sb.getHops()) {
                    ret &= isApplicableForUpdateInPlace(hop, varname);
                }
            }
            if (!ret) {
                break; // early abort once invalid
            }
        }
        return ret;
    }

    /**
     * Per-root check.
     *
     * <p>Cases:
     * <ul>
     *   <li>A {@link FunctionOp} that outputs {@code varname} is invalid.</li>
     *   <li>A root not named {@code varname} is valid.</li>
     *   <li>A root named {@code varname} is valid only if it is a left-indexing write into
     *       {@code varname} (see {@link #probeLixRoot}). In addition, the hop read as the
     *       left-indexing target must be consumed only by that left-indexing op or by
     *       {@code nrow}/{@code ncol} unary ops.</li>
     * </ul>
     *
     * @param hop root hop to check
     * @param varname variable name under consideration
     * @return {@code true} if this root permits update-in-place of {@code varname}
     */
    private static boolean isApplicableForUpdateInPlace(Hop hop, String varname) {
        if (hop instanceof FunctionOp && ((FunctionOp) hop).containsOutput(varname)) {
            return false;
        }
        // Null-safe: a root without a name cannot be the write of varname.
        if (!varname.equals(hop.getName())) {
            return true;
        }
        if (!probeLixRoot(hop, varname)) {
            return false;
        }
        // probeLixRoot guarantees both inputs below exist.
        Hop lix = hop.getInput().get(0);
        Hop target = lix.getInput().get(0);
        boolean validLix = true;
        for (Hop p : target.getParent()) {
            validLix &= p == lix
                || (p instanceof UnaryOp && ((UnaryOp) p).getOp() == Hop.OpOp1.NROW)
                || (p instanceof UnaryOp && ((UnaryOp) p).getOp() == Hop.OpOp1.NCOL);
        }
        return validLix;
    }

    /**
     * DAG-level check over all roots of a basic block.
     *
     * <p>Returns {@code false} if more than one root is a left-indexing write into
     * {@code varname}. Otherwise every other root is validated with
     * {@link #rProbeOtherRoot}. The skipped roots are those whose first input is that
     * left-indexing op. Roots without inputs are always probed, which previously raised
     * {@code IndexOutOfBoundsException}. Visit flags are reset before and after the probe.
     *
     * @param hops root hops of the block
     * @param varname variable name under consideration
     * @return {@code true} if the DAG as a whole permits update-in-place of {@code varname}
     */
    private static boolean isApplicableForUpdateInPlace(ArrayList<Hop> hops, String varname) {
        Hop bLix = null;
        for (Hop hop : hops) {
            if (probeLixRoot(hop, varname)) {
                if (bLix != null) {
                    return false; // multiple root-level left-indexing writes
                }
                bLix = hop.getInput().get(0);
            }
        }
        boolean valid = true;
        Hop.resetVisitStatus(hops);
        for (Hop hop : hops) {
            List<Hop> inputs = hop.getInput();
            if (!inputs.isEmpty() && inputs.get(0) == bLix) {
                continue; // the single left-indexing root itself
            }
            valid &= rProbeOtherRoot(hop, varname);
        }
        Hop.resetVisitStatus(hops);
        return valid;
    }

    /**
     * Tests whether {@code root} is a matrix {@link DataOp} fed by a matrix
     * {@link LeftIndexingOp} whose target input is a {@link DataOp} named {@code varname}.
     *
     * @param root root hop to test
     * @param varname variable name under consideration
     * @return {@code true} if the pattern matches
     */
    private static boolean probeLixRoot(Hop root, String varname) {
        if (!(root instanceof DataOp) || !root.isMatrix() || root.getInput().isEmpty()) {
            return false;
        }
        Hop lix = root.getInput().get(0);
        if (!lix.isMatrix() || !(lix instanceof LeftIndexingOp) || lix.getInput().isEmpty()) {
            return false;
        }
        Hop target = lix.getInput().get(0);
        return target instanceof DataOp && varname.equals(target.getName());
    }

    /**
     * Recursively validates the sub-DAG rooted at {@code hop}.
     *
     * <p>The sub-DAG is invalid if it contains any of the following:
     * <ul>
     *   <li>a {@link LeftIndexingOp};</li>
     *   <li>a transient read of {@code varname};</li>
     *   <li>a hop already visited during this probe, i.e. a shared sub-expression.</li>
     * </ul>
     * The shared sub-expression case is conservative: such DAGs fall back to the per-root
     * checks in {@link #rIsApplicableForUpdateInPlace}. Marks every reached hop as visited.
     *
     * @param hop hop to probe
     * @param varname variable name under consideration
     * @return {@code true} if the sub-DAG is valid
     */
    private static boolean rProbeOtherRoot(Hop hop, String varname) {
        if (hop.isVisited()) {
            return false;
        }
        boolean valid = !(hop instanceof LeftIndexingOp)
            && !(HopRewriteUtils.isData(hop, Types.OpOpData.TRANSIENTREAD)
                && varname.equals(hop.getName()));
        for (Hop c : hop.getInput()) {
            valid &= rProbeOtherRoot(c, varname);
        }
        hop.setVisited();
        return valid;
    }

    /** This rule only operates on individual statement blocks; the list is returned unchanged. */
    @Override
    public List<StatementBlock> rewriteStatementBlocks(List<StatementBlock> sbs, ProgramRewriteStatus sate) {
        return sbs;
    }
}

