/*
 * Decompiled with CFR 0.152.
 */
package org.tugraz.sysds.hops.rewrite;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;
import org.tugraz.sysds.common.Types;
import org.tugraz.sysds.hops.Hop;
import org.tugraz.sysds.hops.rewrite.HopRewriteUtils;
import org.tugraz.sysds.hops.rewrite.ProgramRewriteStatus;
import org.tugraz.sysds.hops.rewrite.StatementBlockRewriteRule;
import org.tugraz.sysds.parser.ForStatement;
import org.tugraz.sysds.parser.ForStatementBlock;
import org.tugraz.sysds.parser.FunctionStatement;
import org.tugraz.sysds.parser.FunctionStatementBlock;
import org.tugraz.sysds.parser.IfStatement;
import org.tugraz.sysds.parser.IfStatementBlock;
import org.tugraz.sysds.parser.StatementBlock;
import org.tugraz.sysds.parser.WhileStatement;
import org.tugraz.sysds.parser.WhileStatementBlock;
import org.tugraz.sysds.runtime.lineage.LineageCacheConfig;

/**
 * Statement-block rewrite that inspects the body of for/while loops and switches off the
 * "requires lineage caching" flag on every hop that depends on a loop variable.
 *
 * <p>A hop is treated as loop dependent when its name is one of the loop variables, when it
 * is a transient read of a variable already recorded as a dependent root, or when any of its
 * inputs is loop dependent. Names written by a transient write reached during a traversal
 * that found loop-dependent hops are recorded as dependent roots, so the dependency carries
 * over to later statement blocks.
 *
 * <p>The rewrite never changes the structure of the program: the incoming statement blocks
 * are always returned as they were passed in.
 */
public class MarkForLineageReuse
extends StatementBlockRewriteRule {
    /**
     * @return always {@code false}
     */
    @Override
    public boolean createsSplitDag() {
        return false;
    }

    /**
     * Unmarks loop-dependent hops in the body of {@code sb} if it is a for or while block.
     *
     * <p>Nothing is done when {@code sb} is not a loop statement block or when
     * {@code LineageCacheConfig.ReuseCacheType.isNone()} holds. For a for block the loop
     * variable is the name of the iteration variable; for a while block the loop variables are
     * the variables updated by the block that are also read by the loop predicate.
     *
     * @param sb statement block to inspect
     * @param status rewrite status (not used by this rule)
     * @return a single-element list holding {@code sb}
     */
    @Override
    public List<StatementBlock> rewriteStatementBlock(StatementBlock sb, ProgramRewriteStatus status) {
        if (!HopRewriteUtils.isLoopStatementBlock(sb) || LineageCacheConfig.ReuseCacheType.isNone()) {
            return Arrays.asList(sb);
        }
        if (sb instanceof ForStatementBlock) {
            ForStatement fstmt = (ForStatement)sb.getStatement(0);
            Set<String> loopVar = new HashSet<>(Arrays.asList(fstmt.getIterablePredicate().getIterVar().getName()));
            HashSet<String> deproots = new HashSet<>();
            this.rUnmarkLoopDepVarsSB(fstmt.getBody(), deproots, loopVar);
        }
        if (sb instanceof WhileStatementBlock) {
            WhileStatement wstmt = (WhileStatement)sb.getStatement(0);
            // Loop variables of a while loop: updated inside the block AND read by the predicate.
            Set<String> loopVar = sb.variablesUpdated().getVariableNames().stream()
                .filter(v -> wstmt.getConditionalPredicate().variablesRead().containsVariable((String)v))
                .collect(Collectors.toSet());
            HashSet<String> deproots = new HashSet<>();
            this.rUnmarkLoopDepVarsSB(wstmt.getBody(), deproots, loopVar);
        }
        return Arrays.asList(sb);
    }

    /**
     * Repeatedly sweeps over a list of statement blocks, unmarking loop-dependent hops and
     * growing {@code deproots} with the names of dependent roots.
     *
     * <p>Each sweep copies {@code deproots} into a working set that is passed to the bodies of
     * nested for/while/if/function blocks (which may add to it), while blocks with their own
     * hop DAG update {@code deproots} directly. After the sweep the working set is merged back
     * into {@code deproots}. Sweeping stops after at most {@code sbs.size()} sweeps, or earlier
     * once {@code deproots} is non-empty and equal to the working set.
     *
     * @param sbs statement blocks to process
     * @param deproots names of dependent roots known so far; extended in place
     * @param loopVar names of the loop variables
     */
    private void rUnmarkLoopDepVarsSB(ArrayList<StatementBlock> sbs, HashSet<String> deproots, Set<String> loopVar) {
        HashSet<String> newdepsbs = new HashSet<>();
        int sweeps = 0;
        do {
            newdepsbs.clear();
            newdepsbs.addAll(deproots);
            for (StatementBlock sb : sbs) {
                if (sb instanceof ForStatementBlock) {
                    ForStatement fstmt = (ForStatement)sb.getStatement(0);
                    this.rUnmarkLoopDepVarsSB(fstmt.getBody(), newdepsbs, loopVar);
                } else if (sb instanceof WhileStatementBlock) {
                    WhileStatement wstmt = (WhileStatement)sb.getStatement(0);
                    this.rUnmarkLoopDepVarsSB(wstmt.getBody(), newdepsbs, loopVar);
                } else if (sb instanceof IfStatementBlock) {
                    IfStatement ifstmt = (IfStatement)sb.getStatement(0);
                    this.rUnmarkLoopDepVarsSB(ifstmt.getIfBody(), newdepsbs, loopVar);
                    if (ifstmt.getElseBody() != null) {
                        this.rUnmarkLoopDepVarsSB(ifstmt.getElseBody(), newdepsbs, loopVar);
                    }
                } else if (sb instanceof FunctionStatementBlock) {
                    FunctionStatement fnstmt = (FunctionStatement)sb.getStatement(0);
                    this.rUnmarkLoopDepVarsSB(fnstmt.getBody(), newdepsbs, loopVar);
                } else if (sb.getHops() != null) {
                    this.rUnmarkLoopDepVarsBasicBlock(sb, deproots, loopVar);
                }
            }
            deproots.addAll(newdepsbs);
            sweeps++;
        } while (sweeps < sbs.size() && (deproots.isEmpty() || !deproots.equals(newdepsbs)));
    }

    /**
     * Unmarks loop-dependent hops in a statement block that carries its own hop DAG.
     *
     * <p>The DAG is traversed up to once per variable updated by the block, because a root
     * discovered in one traversal can make further hops dependent in the next. The method
     * returns as soon as a traversal yields no change to a non-empty {@code deproots}.
     *
     * @param sb statement block whose {@code getHops()} is non-null
     * @param deproots names of dependent roots known so far; extended in place
     * @param loopVar names of the loop variables
     */
    private void rUnmarkLoopDepVarsBasicBlock(StatementBlock sb, HashSet<String> deproots, Set<String> loopVar) {
        for (int j = 0; j < sb.variablesUpdated().getSize(); ++j) {
            HashSet<String> newdeproots = new HashSet<>(deproots);
            for (Hop root : sb.getHops()) {
                // Every root is traversed with a fresh dephops set, so the visit flags are
                // reset first to have shared sub-DAGs evaluated again for this root.
                Hop.resetVisitStatus(sb.getHops());
                HashSet<Long> dephops = new HashSet<>();
                this.rUnmarkLoopDepVars(root, loopVar, newdeproots, dephops);
            }
            if (!deproots.isEmpty() && deproots.equals(newdeproots)) {
                return;
            }
            deproots.addAll(newdeproots);
        }
    }

    /**
     * Depth-first traversal of a hop DAG that unmarks loop-dependent hops.
     *
     * <p>Inputs are processed before the hop itself. A loop-dependent hop has its ID added to
     * {@code dephops} and gets {@code setRequiresLineageCaching(false)}. A transient write
     * adds its name to {@code deproots} whenever {@code dephops} is non-empty at that point.
     * Hops already flagged as visited are skipped, and each processed hop is flagged visited.
     *
     * @param hop current hop
     * @param loopVar names of the loop variables
     * @param deproots names of dependent roots; extended in place
     * @param dephops IDs of the loop-dependent hops found in this traversal; extended in place
     */
    private void rUnmarkLoopDepVars(Hop hop, Set<String> loopVar, HashSet<String> deproots, HashSet<Long> dephops) {
        if (hop.isVisited()) {
            return;
        }
        for (Hop input : hop.getInput()) {
            this.rUnmarkLoopDepVars(input, loopVar, deproots, dephops);
        }
        boolean loopdephop = loopVar.contains(hop.getName())
            || (HopRewriteUtils.isData(hop, Types.OpOpData.TRANSIENTREAD) && deproots.contains(hop.getName()));
        // Dependency is inherited from any input that was found to be loop dependent.
        for (Hop input : hop.getInput()) {
            loopdephop |= dephops.contains(input.getHopID());
        }
        if (loopdephop) {
            dephops.add(hop.getHopID());
            hop.setRequiresLineageCaching(false);
        }
        if (HopRewriteUtils.isData(hop, Types.OpOpData.TRANSIENTWRITE) && !dephops.isEmpty()) {
            deproots.add(hop.getName());
        }
        hop.setVisited();
    }

    /**
     * No-op at the level of statement-block lists.
     *
     * @param sbs statement blocks
     * @param status rewrite status (not used by this rule)
     * @return {@code sbs}, unchanged
     */
    @Override
    public List<StatementBlock> rewriteStatementBlocks(List<StatementBlock> sbs, ProgramRewriteStatus status) {
        return sbs;
    }
}

