/*
 * Decompiled with CFR 0.152.
 */
package org.tugraz.sysds.hops.ipa;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.Stack;
import java.util.stream.Collectors;
import org.tugraz.sysds.common.Types;
import org.tugraz.sysds.hops.FunctionOp;
import org.tugraz.sysds.hops.Hop;
import org.tugraz.sysds.hops.HopsException;
import org.tugraz.sysds.hops.rewrite.HopRewriteUtils;
import org.tugraz.sysds.parser.DMLProgram;
import org.tugraz.sysds.parser.ForStatement;
import org.tugraz.sysds.parser.ForStatementBlock;
import org.tugraz.sysds.parser.FunctionStatement;
import org.tugraz.sysds.parser.FunctionStatementBlock;
import org.tugraz.sysds.parser.IfStatement;
import org.tugraz.sysds.parser.IfStatementBlock;
import org.tugraz.sysds.parser.StatementBlock;
import org.tugraz.sysds.parser.WhileStatement;
import org.tugraz.sysds.parser.WhileStatementBlock;

/**
 * Function call graph of a DML program.
 * <p>
 * The graph is built once, in the constructor, by walking the statement blocks of the main
 * program (or of a single statement block) and following every function call into the called
 * function's body. Besides the caller-to-callee edges it records, per callee function key:
 * <ul>
 *   <li>all calling {@link FunctionOp}s and the statement blocks that contain them,</li>
 *   <li>whether the function is part of a (direct or indirect) call cycle, and</li>
 *   <li>only when built from a {@link DMLProgram}: whether the function body contains no
 *       print, printf, persistent write, or nested function call ("side-effect free").</li>
 * </ul>
 * The main program is stored under the pseudo key {@code "_main"}. Several accessors map a
 * {@code null} function key to that pseudo key. Calls into the {@code _internal} namespace are
 * recorded as call sites, but they are neither added as graph nodes nor traversed.
 */
public class FunctionCallGraph {
    /** Pseudo function key under which the main program's direct callees are stored. */
    private static final String MAIN_FUNCTION_KEY = "_main";

    /** Namespace of internal functions; their calls are recorded but not traversed. */
    private static final String INTERNAL_NAMESPACE = "_internal";

    /** Caller key -> keys of directly called functions. Keys are main plus every traversed callee. */
    private final HashMap<String, HashSet<String>> _fGraph = new HashMap<>();

    /** Callee key -> every recorded function call operator calling it. */
    private final HashMap<String, ArrayList<FunctionOp>> _fCalls = new HashMap<>();

    /** Callee key -> statement blocks containing the recorded calls (one entry per recorded call). */
    private final HashMap<String, ArrayList<StatementBlock>> _fCallsSB = new HashMap<>();

    /** Keys of functions that participate in a call cycle. */
    private final HashSet<String> _fRecursive = new HashSet<>();

    /** Keys of called, non-internal functions whose bodies contain no side-effecting operation. */
    private final HashSet<String> _fSideEffectFree = new HashSet<>();

    /** Whether a traversed statement block contains a second-order builtin (per HopRewriteUtils). */
    private final boolean _containsSecondOrder;

    /**
     * Builds the call graph for the whole program, including the side-effect-free analysis.
     *
     * @param prog DML program whose main statement blocks are traversed
     */
    public FunctionCallGraph(DMLProgram prog) {
        _containsSecondOrder = constructFunctionCallGraph(prog);
    }

    /**
     * Builds the call graph for the functions reachable from a single statement block.
     * This variant does not populate the side-effect-free function set.
     *
     * @param sb statement block to start from; its DML program is used to resolve calls
     */
    public FunctionCallGraph(StatementBlock sb) {
        _containsSecondOrder = constructFunctionCallGraph(sb);
    }

    /**
     * Returns the keys of the functions directly called by the given function.
     *
     * @param fnamespace function namespace
     * @param fname      function name
     * @return the internal (mutable, not copied) set of callee keys, or {@code null} if the
     *         function is not a node of the graph
     */
    public Set<String> getCalledFunctions(String fnamespace, String fname) {
        return getCalledFunctions(DMLProgram.constructFunctionKey(fnamespace, fname));
    }

    /**
     * Returns the keys of the functions directly called by the given function.
     *
     * @param fkey function key, or {@code null} for the main program
     * @return the internal (mutable, not copied) set of callee keys, or {@code null} if the
     *         key is not a node of the graph (e.g., the program has no functions)
     */
    public Set<String> getCalledFunctions(String fkey) {
        return _fGraph.get(fkey == null ? MAIN_FUNCTION_KEY : fkey);
    }

    /**
     * Returns all recorded function call operators that call the given function.
     *
     * @param fkey callee function key; {@code null} (the main program) yields an empty list
     * @return the internal list of call operators, or {@code null} if no call was recorded
     */
    public List<FunctionOp> getFunctionCalls(String fkey) {
        // the main program cannot be called
        if (fkey == null) {
            return Collections.emptyList();
        }
        return _fCalls.get(fkey);
    }

    /**
     * Returns the statement blocks containing the recorded calls of the given function.
     *
     * @param fkey callee function key; {@code null} (the main program) yields an empty list
     * @return the internal list of statement blocks, or {@code null} if no call was recorded
     */
    public List<StatementBlock> getFunctionCallsSB(String fkey) {
        if (fkey == null) {
            return Collections.emptyList();
        }
        return _fCallsSB.get(fkey);
    }

    /**
     * Removes the given function from the graph: its call sites, its node and outgoing edges,
     * all incoming edges, and its recursive / side-effect-free flags.
     *
     * @param fkey key of the function to remove
     */
    public void removeFunctionCalls(String fkey) {
        _fCalls.remove(fkey);
        _fCallsSB.remove(fkey);
        _fRecursive.remove(fkey);
        // keep flag state consistent with replaceFunctionCalls (no stale side-effect-free entry)
        _fSideEffectFree.remove(fkey);
        _fGraph.remove(fkey);
        for (HashSet<String> callees : _fGraph.values()) {
            callees.remove(fkey);
        }
    }

    /**
     * Removes a single recorded call site of the given function, i.e., the first occurrence of
     * {@code fop} in its call list and the first occurrence of {@code sb} in its block list.
     *
     * @param fkey callee function key
     * @param fop  function call operator to remove
     * @param sb   statement block to remove
     */
    public void removeFunctionCall(String fkey, FunctionOp fop, StatementBlock sb) {
        ArrayList<FunctionOp> fops = _fCalls.get(fkey);
        if (fops != null) {
            fops.remove(fop);
        }
        ArrayList<StatementBlock> sbs = _fCallsSB.get(fkey);
        if (sbs != null) {
            sbs.remove(sb);
        }
    }

    /**
     * Moves the recorded call sites of {@code fkeyOld} to {@code fkey} and drops everything
     * else known about {@code fkeyOld}. The old function's graph node, edges, and recursive /
     * side-effect-free flags are removed and are NOT transferred to {@code fkey}.
     *
     * @param fkeyOld key whose call sites are moved
     * @param fkey    new key for those call sites
     */
    public void replaceFunctionCalls(String fkeyOld, String fkey) {
        ArrayList<FunctionOp> fops = _fCalls.remove(fkeyOld);
        ArrayList<StatementBlock> sbs = _fCallsSB.remove(fkeyOld);
        // do not install null lists (would clobber existing entries and break removeFunctionCall)
        if (fops != null) {
            _fCalls.put(fkey, fops);
        }
        if (sbs != null) {
            _fCallsSB.put(fkey, sbs);
        }
        // additional cleanups for the no longer reachable old function
        _fRecursive.remove(fkeyOld);
        _fSideEffectFree.remove(fkeyOld);
        _fGraph.remove(fkeyOld);
        for (HashSet<String> callees : _fGraph.values()) {
            callees.remove(fkeyOld);
        }
    }

    /**
     * Indicates whether the given function is part of a call cycle.
     *
     * @param fnamespace function namespace
     * @param fname      function name
     * @return true if the function was marked recursive during construction
     */
    public boolean isRecursiveFunction(String fnamespace, String fname) {
        return isRecursiveFunction(DMLProgram.constructFunctionKey(fnamespace, fname));
    }

    /**
     * Indicates whether the given function is part of a call cycle.
     *
     * @param fkey function key, or {@code null} for the main program
     * @return true if the function was marked recursive during construction
     */
    public boolean isRecursiveFunction(String fkey) {
        return _fRecursive.contains(fkey == null ? MAIN_FUNCTION_KEY : fkey);
    }

    /**
     * Indicates whether the given function was classified as side-effect free.
     *
     * @param fnamespace function namespace
     * @param fname      function name
     * @return true if the function is in the side-effect-free set
     */
    public boolean isSideEffectFreeFunction(String fnamespace, String fname) {
        return isSideEffectFreeFunction(DMLProgram.constructFunctionKey(fnamespace, fname));
    }

    /**
     * Indicates whether the given function was classified as side-effect free, i.e., its body
     * contains no print, printf, persistent write, or function call. This is only computed when
     * the graph was built from a {@link DMLProgram}.
     *
     * @param fkey function key, or {@code null} for the main program
     * @return true if the function is in the side-effect-free set
     */
    public boolean isSideEffectFreeFunction(String fkey) {
        return _fSideEffectFree.contains(fkey == null ? MAIN_FUNCTION_KEY : fkey);
    }

    /**
     * Returns the keys of all functions that are nodes of the graph (excluding the main program).
     *
     * @return a new set of reachable function keys
     */
    public Set<String> getReachableFunctions() {
        return getReachableFunctions(Collections.emptySet());
    }

    /**
     * Returns the keys of all functions that are nodes of the graph, excluding the main program
     * and the given blacklist.
     *
     * @param blacklist function keys to exclude (must not be null)
     * @return a new set of reachable function keys
     */
    public Set<String> getReachableFunctions(Set<String> blacklist) {
        return _fGraph.keySet().stream()
            .filter(p -> !blacklist.contains(p) && !MAIN_FUNCTION_KEY.equals(p))
            .collect(Collectors.toSet());
    }

    /**
     * Indicates whether the given function is a node of the graph.
     *
     * @param fnamespace function namespace
     * @param fname      function name
     * @return true if the function is a node of the graph
     */
    public boolean isReachableFunction(String fnamespace, String fname) {
        return isReachableFunction(DMLProgram.constructFunctionKey(fnamespace, fname));
    }

    /**
     * Indicates whether the given function is a node of the graph.
     *
     * @param fkey function key, or {@code null} for the main program
     * @return true if the function is a node of the graph
     */
    public boolean isReachableFunction(String fkey) {
        return isReachableFunction(fkey, false);
    }

    /**
     * Indicates whether the given function is reachable.
     *
     * @param fkey function key, or {@code null} for the main program
     * @param deep if false, check whether the key is a graph node; if true, check whether any
     *             node (including main) has an edge to it
     * @return true if reachable under the selected check
     */
    protected boolean isReachableFunction(String fkey, boolean deep) {
        String lfkey = fkey == null ? MAIN_FUNCTION_KEY : fkey;
        if (!deep) {
            return _fGraph.containsKey(lfkey);
        }
        return _fGraph.values().stream().anyMatch(callees -> callees.contains(lfkey));
    }

    /**
     * Indicates whether any traversed statement block contains a second-order builtin.
     *
     * @return true if a second-order builtin call was found during construction
     */
    public boolean containsSecondOrderCall() {
        return _containsSecondOrder;
    }

    /**
     * Constructs the graph from all main-program statement blocks and then classifies every
     * called, non-internal function as side-effect free or not.
     *
     * @param prog DML program
     * @return true if a second-order builtin was found
     */
    private boolean constructFunctionCallGraph(DMLProgram prog) {
        // early abort: without function definitions there is nothing to analyze
        if (!prog.hasFunctionStatementBlocks()) {
            return false;
        }
        boolean ret = false;
        try {
            Stack<String> fstack = new Stack<>();
            // shared across all main statement blocks to avoid redundant main->callee edges
            HashSet<String> lfset = new HashSet<>();
            _fGraph.put(MAIN_FUNCTION_KEY, new HashSet<>());
            for (StatementBlock sblk : prog.getStatementBlocks()) {
                ret |= rConstructFunctionCallGraph(MAIN_FUNCTION_KEY, sblk, fstack, lfset);
            }
            _fSideEffectFree.addAll(_fCalls.keySet().stream()
                .filter(s -> !s.startsWith(INTERNAL_NAMESPACE))
                .filter(s -> isSideEffectFree(prog.getFunctionStatementBlock(s)))
                .collect(Collectors.toList()));
        }
        catch (HopsException ex) {
            throw new RuntimeException(ex);
        }
        return ret;
    }

    /**
     * Constructs the graph starting from a single statement block (no side-effect analysis).
     *
     * @param sb statement block to start from
     * @return true if a second-order builtin was found
     */
    private boolean constructFunctionCallGraph(StatementBlock sb) {
        // early abort: without function definitions there is nothing to analyze
        if (!sb.getDMLProg().hasFunctionStatementBlocks()) {
            return false;
        }
        try {
            Stack<String> fstack = new Stack<>();
            HashSet<String> lfset = new HashSet<>();
            _fGraph.put(MAIN_FUNCTION_KEY, new HashSet<>());
            return rConstructFunctionCallGraph(MAIN_FUNCTION_KEY, sb, fstack, lfset);
        }
        catch (HopsException ex) {
            throw new RuntimeException(ex);
        }
    }

    /**
     * Recursively walks a statement block on behalf of caller {@code fkey} and records all
     * function calls found in the hop DAGs of its (nested) generic statement blocks. Only the
     * bodies of while/if/for/function blocks are walked; their predicate hops are not inspected.
     *
     * @param fkey   key of the function whose body is being walked (or the main pseudo key)
     * @param sb     statement block to walk
     * @param fstack keys of the functions currently being traversed (for cycle detection)
     * @param lfset  callee keys already handled within the current caller context
     * @return true if a second-order builtin was found in this block or any traversed callee
     */
    private boolean rConstructFunctionCallGraph(String fkey, StatementBlock sb,
        Stack<String> fstack, HashSet<String> lfset)
    {
        boolean ret = false;
        if (sb instanceof WhileStatementBlock) {
            WhileStatement ws = (WhileStatement) sb.getStatement(0);
            for (StatementBlock current : ws.getBody()) {
                ret |= rConstructFunctionCallGraph(fkey, current, fstack, lfset);
            }
        }
        else if (sb instanceof IfStatementBlock) {
            IfStatement ifs = (IfStatement) sb.getStatement(0);
            for (StatementBlock current : ifs.getIfBody()) {
                ret |= rConstructFunctionCallGraph(fkey, current, fstack, lfset);
            }
            // same null guard as rHasSideEffects applies to the else body
            if (ifs.getElseBody() != null) {
                for (StatementBlock current : ifs.getElseBody()) {
                    ret |= rConstructFunctionCallGraph(fkey, current, fstack, lfset);
                }
            }
        }
        else if (sb instanceof ForStatementBlock) {
            ForStatement fs = (ForStatement) sb.getStatement(0);
            for (StatementBlock current : fs.getBody()) {
                ret |= rConstructFunctionCallGraph(fkey, current, fstack, lfset);
            }
        }
        else if (sb instanceof FunctionStatementBlock) {
            FunctionStatement fsb = (FunctionStatement) sb.getStatement(0);
            for (StatementBlock current : fsb.getBody()) {
                ret |= rConstructFunctionCallGraph(fkey, current, fstack, lfset);
            }
        }
        else {
            // generic statement block
            ArrayList<Hop> hopsDAG = sb.getHops();
            if (hopsDAG == null || hopsDAG.isEmpty()) {
                return false;
            }
            ret = HopRewriteUtils.containsSecondOrderBuiltin(hopsDAG);
            // only the root hops of the DAG are checked for function calls
            for (Hop h : hopsDAG) {
                if (h instanceof FunctionOp) {
                    ret |= addFunctionCall(fkey, (FunctionOp) h, sb, fstack, lfset);
                }
            }
        }
        return ret;
    }

    /**
     * Records a single function call of caller {@code fkey}, adds the call edge, and traverses
     * the callee body unless it is already on the traversal stack. In that case, the callee and
     * every function above it on the stack are marked recursive.
     *
     * @param fkey   caller key
     * @param fop    function call operator
     * @param sb     statement block containing {@code fop}
     * @param fstack keys of the functions currently being traversed
     * @param lfset  callee keys already handled within the current caller context
     * @return true if a second-order builtin was found while traversing the callee
     */
    private boolean addFunctionCall(String fkey, FunctionOp fop, StatementBlock sb,
        Stack<String> fstack, HashSet<String> lfset)
    {
        String lfkey = fop.getFunctionKey();

        // keep every call site, including repeated and internal ones
        _fCalls.computeIfAbsent(lfkey, k -> new ArrayList<>()).add(fop);
        _fCallsSB.computeIfAbsent(lfkey, k -> new ArrayList<>()).add(sb);

        // prevent redundant call edges; internal functions are not part of the graph
        if (lfset.contains(lfkey) || fop.getFunctionNamespace().equals(INTERNAL_NAMESPACE)) {
            return false;
        }

        _fGraph.computeIfAbsent(lfkey, k -> new HashSet<>());
        _fGraph.get(fkey).add(lfkey);

        boolean ret = false;
        if (!fstack.contains(lfkey)) {
            // first visit in this call chain: descend into the callee body
            fstack.push(lfkey);
            FunctionStatementBlock fsb = sb.getDMLProg()
                .getFunctionStatementBlock(fop.getFunctionNamespace(), fop.getFunctionName());
            if (fsb == null) {
                throw new HopsException("Function call graph: no function definition found for '"
                    + lfkey + "'.");
            }
            FunctionStatement fs = (FunctionStatement) fsb.getStatement(0);
            for (StatementBlock csb : fs.getBody()) {
                ret |= rConstructFunctionCallGraph(lfkey, csb, fstack, new HashSet<>());
            }
            fstack.pop();
        }
        else {
            // call cycle: the callee and all functions called on the way back to it are recursive
            _fRecursive.add(lfkey);
            int ix = fstack.indexOf(lfkey);
            for (int i = ix + 1; i < fstack.size(); i++) {
                _fRecursive.add(fstack.get(i));
            }
        }

        // mark as visited for the current caller context
        lfset.add(lfkey);
        return ret;
    }

    /**
     * Indicates whether a function body contains no side-effecting operation.
     *
     * @param fsb function statement block to inspect
     * @return true if no statement block of the body has side effects
     */
    private static boolean isSideEffectFree(FunctionStatementBlock fsb) {
        FunctionStatement fstmt = (FunctionStatement) fsb.getStatement(0);
        for (StatementBlock csb : fstmt.getBody()) {
            if (rHasSideEffects(csb)) {
                return false;
            }
        }
        return true;
    }

    /**
     * Recursively checks a statement block for side effects. A generic block has side effects
     * if any root hop is a print, printf, persistent write, or function call (function calls
     * are treated conservatively). Loop and if blocks are checked via their bodies only.
     *
     * @param sb statement block to inspect
     * @return true if a side-effecting operation was found
     */
    private static boolean rHasSideEffects(StatementBlock sb) {
        boolean ret = false;
        if (sb instanceof ForStatementBlock) {
            ForStatement fstmt = (ForStatement) sb.getStatement(0);
            for (StatementBlock csb : fstmt.getBody()) {
                ret |= rHasSideEffects(csb);
            }
        }
        else if (sb instanceof WhileStatementBlock) {
            WhileStatement wstmt = (WhileStatement) sb.getStatement(0);
            for (StatementBlock csb : wstmt.getBody()) {
                ret |= rHasSideEffects(csb);
            }
        }
        else if (sb instanceof IfStatementBlock) {
            IfStatement istmt = (IfStatement) sb.getStatement(0);
            for (StatementBlock csb : istmt.getIfBody()) {
                ret |= rHasSideEffects(csb);
            }
            if (istmt.getElseBody() != null) {
                for (StatementBlock csb : istmt.getElseBody()) {
                    ret |= rHasSideEffects(csb);
                }
            }
        }
        else if (sb.getHops() != null) {
            for (Hop root : sb.getHops()) {
                ret |= HopRewriteUtils.isUnary(root, Hop.OpOp1.PRINT)
                    || HopRewriteUtils.isNary(root, Types.OpOpN.PRINTF)
                    || HopRewriteUtils.isData(root, Types.OpOpData.PERSISTENTWRITE)
                    || root instanceof FunctionOp;
            }
        }
        return ret;
    }
}

