/*
 * Decompiled with CFR 0.152.
 */
package org.tugraz.sysds.hops.rewrite;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import org.tugraz.sysds.hops.Hop;
import org.tugraz.sysds.hops.IndexingOp;
import org.tugraz.sysds.hops.LeftIndexingOp;
import org.tugraz.sysds.hops.LiteralOp;
import org.tugraz.sysds.hops.rewrite.HopRewriteRule;
import org.tugraz.sysds.hops.rewrite.HopRewriteUtils;
import org.tugraz.sysds.hops.rewrite.ProgramRewriteStatus;

/**
 * Hop rewrite rule that vectorizes chains of indexing operations.
 * <p>
 * For every input of every visited hop, two rewrites are attempted:
 * <ul>
 * <li>{@code vectorizeRightLeftIndexingChains}: a chain of left indexing ops, each writing one
 * full row (or column) that was obtained by right indexing from the same source at consecutive
 * indices, is replaced by a single range left indexing fed by a single range right indexing,
 * e.g., {@code B[,1]=A[,1]; B[,2]=A[,2];} becomes {@code B[,1:2]=A[,1:2]}.</li>
 * <li>{@code vectorizeLeftIndexing}: a chain of scalar left indexing ops that all write the same
 * row (or column) expression is rewired to write into an extracted row (column) vector, which is
 * then written back with one row (column) left indexing.</li>
 * </ul>
 * A right indexing counterpart ({@code vectorizeRightIndexing}) is implemented but not invoked
 * by this rule.
 */
public class RewriteIndexingVectorization extends HopRewriteRule {

    @Override
    public ArrayList<Hop> rewriteHopDAGs(ArrayList<Hop> roots, ProgramRewriteStatus state) {
        if (roots == null) {
            return roots;
        }
        for (Hop h : roots) {
            rule_IndexingVectorization(h);
        }
        return roots;
    }

    @Override
    public Hop rewriteHopDAG(Hop root, ProgramRewriteStatus state) {
        if (root == null) {
            return root;
        }
        rule_IndexingVectorization(root);
        return root;
    }

    /**
     * Applies the indexing vectorization rewrites to every input of {@code hop}, then recurses
     * into the (possibly replaced) input. Processed hops are marked visited.
     *
     * @param hop current hop; hops that are already visited are skipped
     */
    private void rule_IndexingVectorization(Hop hop) {
        if (hop.isVisited()) {
            return;
        }
        for (int i = 0; i < hop.getInput().size(); i++) {
            Hop hi = hop.getInput().get(i);
            // e.g., B[,1]=A[,1]; B[,2]=A[,2]; -> B[,1:2]=A[,1:2]
            hi = vectorizeRightLeftIndexingChains(hi);
            // e.g., X[i,1]=a; X[i,3]=b; -> tmp=X[i,]; tmp[1,1]=a; tmp[1,3]=b; X[i,]=tmp
            hi = vectorizeLeftIndexing(hi);
            // recurse into the rewritten input
            rule_IndexingVectorization(hi);
        }
        hop.setVisited();
    }

    /**
     * Merges a chain of left indexing ops, each assigning a full row (or column) taken by right
     * indexing from the same source at consecutive indices, into one range left indexing whose
     * right-hand side is one range right indexing.
     * <p>
     * All ops of the chain except the root ({@code hi}) must have exactly one consumer. The root
     * may have several, and all of them are rewired to the new operator.
     *
     * @param hi candidate root of the chain
     * @return the new left indexing op if at least two consecutive pairs were merged, otherwise
     *         {@code hi} unchanged
     */
    private static Hop vectorizeRightLeftIndexingChains(Hop hi) {
        // valid root: left indexing whose right-hand side is an exclusively consumed right indexing
        if (!(hi instanceof LeftIndexingOp)
            || !(hi.getInput().get(1) instanceof IndexingOp)
            || hi.getInput().get(1).getParent().size() != 1) {
            return hi;
        }
        LeftIndexingOp lix0 = (LeftIndexingOp) hi;
        IndexingOp rix0 = (IndexingOp) hi.getInput().get(1);

        // both ops must address a single row or a single column, and the same kind of slice
        if (!(lix0.isRowLowerEqualsUpper() || lix0.isColLowerEqualsUpper())
            || lix0.isRowLowerEqualsUpper() != rix0.isRowLowerEqualsUpper()
            || lix0.isColLowerEqualsUpper() != rix0.isColLowerEqualsUpper()) {
            return hi;
        }
        boolean row = lix0.isRowLowerEqualsUpper();
        if (!(row ? HopRewriteUtils.isFullRowIndexing(lix0) : HopRewriteUtils.isFullColumnIndexing(lix0))
            || !(row ? HopRewriteUtils.isFullRowIndexing(rix0) : HopRewriteUtils.isFullColumnIndexing(rix0))) {
            return hi;
        }

        // collect the chain of consecutive left/right indexing pairs (top-down)
        List<LeftIndexingOp> lix = new ArrayList<>();
        List<IndexingOp> rix = new ArrayList<>();
        lix.add(lix0);
        rix.add(rix0);
        LeftIndexingOp clix = lix0;
        IndexingOp crix = rix0;
        while (isConsecutiveLeftRightIndexing(clix, crix, clix.getInput().get(0))
            && clix.getInput().get(0).getParent().size() == 1
            && clix.getInput().get(0).getInput().get(1).getParent().size() == 1) {
            clix = (LeftIndexingOp) clix.getInput().get(0);
            crix = (IndexingOp) clix.getInput().get(1);
            lix.add(clix);
            rix.add(crix);
        }
        if (lix.size() < 2) {
            return hi;
        }

        // the bottom-most pair holds the lower bounds; extend the range by (chain length - 1)
        int offset = lix.size() - 1;
        IndexingOp rixn = rix.get(offset);
        Hop rlrix = rixn.getInput().get(1);
        Hop rurix = row ? HopRewriteUtils.createBinary(rlrix, new LiteralOp(offset), Hop.OpOp2.PLUS)
            : rixn.getInput().get(2);
        Hop clrix = rixn.getInput().get(3);
        Hop curix = row ? rixn.getInput().get(4)
            : HopRewriteUtils.createBinary(clrix, new LiteralOp(offset), Hop.OpOp2.PLUS);
        IndexingOp rixNew = HopRewriteUtils.createIndexingOp(rixn.getInput().get(0), rlrix, rurix, clrix, curix);

        LeftIndexingOp lixn = lix.get(offset);
        Hop rllix = lixn.getInput().get(2);
        Hop rulix = row ? HopRewriteUtils.createBinary(rllix, new LiteralOp(offset), Hop.OpOp2.PLUS)
            : lixn.getInput().get(3);
        Hop cllix = lixn.getInput().get(4);
        Hop culix = row ? lixn.getInput().get(5)
            : HopRewriteUtils.createBinary(cllix, new LiteralOp(offset), Hop.OpOp2.PLUS);
        LeftIndexingOp lixNew = HopRewriteUtils.createLeftIndexingOp(lixn.getInput().get(0), rixNew, rllix, rulix, cllix, culix);

        // rewire ALL consumers of the root (iterate a copy: replaceChildReference mutates the list);
        // rewiring only the first one would leave other consumers pointing at a detached hop
        for (Hop parent : new ArrayList<>(hi.getParent())) {
            HopRewriteUtils.replaceChildReference(parent, hi, lixNew);
        }
        for (int i = 0; i < lix.size(); i++) {
            HopRewriteUtils.removeAllChildReferences(lix.get(i));
            HopRewriteUtils.removeAllChildReferences(rix.get(i));
        }
        LOG.debug("Applied vectorizeRightLeftIndexingChains (line " + lixNew.getBeginLine() + ")");
        return lixNew;
    }

    /**
     * Checks whether {@code input} (the target of {@code lix}) is itself a left indexing op fed by
     * right indexing, which continues the chain of {@code lix}/{@code rix}. It must perform the
     * same full row/column indexing, read from the same source hop, and use indices that
     * {@link HopRewriteUtils#isConsecutiveIndex} accepts relative to {@code lix}/{@code rix}.
     *
     * @param lix   current (upper) left indexing op of the chain
     * @param rix   right indexing op feeding {@code lix}
     * @param input candidate next (lower) left indexing op
     * @return true if {@code input} extends the chain
     */
    private static boolean isConsecutiveLeftRightIndexing(LeftIndexingOp lix, IndexingOp rix, Hop input) {
        if (!(input instanceof LeftIndexingOp) || !(input.getInput().get(1) instanceof IndexingOp)) {
            return false;
        }
        boolean row = lix.isRowLowerEqualsUpper();
        LeftIndexingOp lix2 = (LeftIndexingOp) input;
        IndexingOp rix2 = (IndexingOp) input.getInput().get(1);

        // same full row/column access pattern on both sides
        boolean access = row
            ? HopRewriteUtils.isFullRowIndexing(lix2) && HopRewriteUtils.isFullRowIndexing(rix2)
            : HopRewriteUtils.isFullColumnIndexing(lix2) && HopRewriteUtils.isFullColumnIndexing(rix2);
        // right indexing reads the same source hop (identity comparison)
        boolean rixInputs = rix.getInput().get(0) == rix2.getInput().get(0);
        // lower-bound inputs: left indexing 2 (row) / 4 (col), right indexing 1 (row) / 3 (col)
        int lixPos = row ? 2 : 4;
        int rixPos = row ? 1 : 3;
        boolean consecutive =
            HopRewriteUtils.isConsecutiveIndex(lix2.getInput().get(lixPos), lix.getInput().get(lixPos))
            && HopRewriteUtils.isConsecutiveIndex(rix2.getInput().get(rixPos), rix.getInput().get(rixPos));
        return access && rixInputs && consecutive;
    }

    /**
     * Merges multiple scalar right indexing ops on a common input that use the same row (or,
     * failing that, the same column) expression hop. A new right indexing extracts that row
     * (column), and all candidates are rewired to index into it at row (column) 1.
     * <p>
     * Not invoked by this rule.
     *
     * @param hop candidate hop; only scalar {@link IndexingOp}s are considered
     */
    @SuppressWarnings("unused")
    private static void vectorizeRightIndexing(Hop hop) {
        if (!(hop instanceof IndexingOp)) {
            return;
        }
        IndexingOp ihop0 = (IndexingOp) hop;
        if (!ihop0.isRowLowerEqualsUpper() || !ihop0.isColLowerEqualsUpper()) {
            return;
        }
        Hop input = ihop0.getInput().get(0);

        // candidates reading the same row; row expressions are compared by hop identity
        Hop rowExpr = ihop0.getInput().get(1);
        List<Hop> rowHops = new ArrayList<>();
        rowHops.add(ihop0);
        for (Hop c : input.getParent()) {
            if (c != ihop0 && c instanceof IndexingOp && c.getInput().get(0) == input
                && ((IndexingOp) c).isRowLowerEqualsUpper() && c.getInput().get(1) == rowExpr) {
                rowHops.add(c);
            }
        }
        if (rowHops.size() > 1) {
            IndexingOp newRix = new IndexingOp("tmp", input.getDataType(), input.getValueType(), input,
                rowExpr, rowExpr, new LiteralOp(1L), HopRewriteUtils.createValueHop(input, false), true, false);
            HopRewriteUtils.setOutputParameters(newRix, -1L, -1L, input.getBlocksize(), -1L);
            newRix.refreshSizeInformation();
            for (Hop c : rowHops) {
                HopRewriteUtils.removeChildReference(c, input);
                HopRewriteUtils.addChildReference(c, newRix, 0);
                HopRewriteUtils.removeChildReferenceByPos(c, c.getInput().get(1), 1);
                HopRewriteUtils.addChildReference(c, new LiteralOp(1L), 1);
                HopRewriteUtils.removeChildReferenceByPos(c, c.getInput().get(2), 2);
                HopRewriteUtils.addChildReference(c, new LiteralOp(1L), 2);
                c.refreshSizeInformation();
            }
            LOG.debug("Applied vectorizeRightIndexingRow");
            return;
        }

        // otherwise: candidates reading the same column
        Hop colExpr = ihop0.getInput().get(3);
        List<Hop> colHops = new ArrayList<>();
        colHops.add(ihop0);
        for (Hop c : input.getParent()) {
            if (c != ihop0 && c instanceof IndexingOp && c.getInput().get(0) == input
                && ((IndexingOp) c).isColLowerEqualsUpper() && c.getInput().get(3) == colExpr) {
                colHops.add(c);
            }
        }
        if (colHops.size() > 1) {
            IndexingOp newRix = new IndexingOp("tmp", input.getDataType(), input.getValueType(), input,
                new LiteralOp(1L), HopRewriteUtils.createValueHop(input, true), colExpr, colExpr, false, true);
            HopRewriteUtils.setOutputParameters(newRix, -1L, -1L, input.getBlocksize(), -1L);
            newRix.refreshSizeInformation();
            for (Hop c : colHops) {
                HopRewriteUtils.removeChildReference(c, input);
                HopRewriteUtils.addChildReference(c, newRix, 0);
                HopRewriteUtils.replaceChildReference(c, c.getInput().get(3), new LiteralOp(1L), 3);
                HopRewriteUtils.replaceChildReference(c, c.getInput().get(4), new LiteralOp(1L), 4);
                c.refreshSizeInformation();
            }
            LOG.debug("Applied vectorizeRightIndexingCol");
        }
    }

    /**
     * Vectorizes a chain of scalar left indexing ops that write the same row (tried first) or the
     * same column.
     *
     * @param hop candidate hop; only scalar {@link LeftIndexingOp}s are considered
     * @return the new top-level left indexing op if a rewrite was applied, otherwise {@code hop}
     */
    private static Hop vectorizeLeftIndexing(Hop hop) {
        if (!(hop instanceof LeftIndexingOp)) {
            return hop;
        }
        LeftIndexingOp ihop0 = (LeftIndexingOp) hop;
        if (!ihop0.isRowLowerEqualsUpper() || !ihop0.isColLowerEqualsUpper()) {
            return hop;
        }
        Hop ret = vectorizeLeftIndexingChain(ihop0, true);
        if (ret != null) {
            LOG.debug("Applied vectorizeLeftIndexingRow for hop " + hop.getHopID());
            return ret;
        }
        ret = vectorizeLeftIndexingChain(ihop0, false);
        if (ret != null) {
            LOG.debug("Applied vectorizeLeftIndexingCol for hop " + hop.getHopID());
            return ret;
        }
        return hop;
    }

    /**
     * Row ({@code row=true}) or column ({@code row=false}) variant of {@code vectorizeLeftIndexing}.
     * <p>
     * It collects the chain of left indexing targets below {@code ihop0} that meet all of:
     * <ul>
     * <li>at most one consumer;</li>
     * <li>a single row (column);</li>
     * <li>the same row (column) expression hop as {@code ihop0};</li>
     * <li>a target with {@code getDim2()} (row case) or {@code getDim1()} (column case) greater than 1.</li>
     * </ul>
     * If the chain has at least two ops, its bottom op is rewired to write into a right-indexed
     * row (column) vector of the chain's target, and every op's row (column) index is reset to 1.
     * A new left indexing that writes the vector back is then placed on top of {@code ihop0}.
     *
     * @param ihop0 scalar left indexing op at the top of the chain
     * @param row   true for the row variant, false for the column variant
     * @return the new top-level left indexing op, or {@code null} if the rewrite does not apply
     *         (in which case nothing was modified)
     */
    private static Hop vectorizeLeftIndexingChain(LeftIndexingOp ihop0, boolean row) {
        // input position of the row (2) or column (4) lower bound; the upper bound follows it
        final int lowerPos = row ? 2 : 4;
        final Hop indexExpr = ihop0.getInput().get(lowerPos);

        List<LeftIndexingOp> ihops = new ArrayList<>();
        ihops.add(ihop0);
        LeftIndexingOp current = ihop0;
        while (current.getInput().get(0) instanceof LeftIndexingOp) {
            LeftIndexingOp tmp = (LeftIndexingOp) current.getInput().get(0);
            Hop target = tmp.getInput().get(0);
            boolean singleSlice = row ? tmp.isRowLowerEqualsUpper() : tmp.isColLowerEqualsUpper();
            long extent = row ? target.getDim2() : target.getDim1();
            if (tmp.getParent().size() > 1 || !singleSlice
                || tmp.getInput().get(lowerPos) != indexExpr || extent <= 1L) {
                break;
            }
            ihops.add(tmp);
            current = tmp;
        }
        if (ihops.size() < 2) {
            return null;
        }

        Hop input = current.getInput().get(0);

        // right indexing that extracts the affected row/column vector of the chain's target
        IndexingOp newRix = row
            ? new IndexingOp("tmp1", input.getDataType(), input.getValueType(), input,
                indexExpr, indexExpr, new LiteralOp(1L), HopRewriteUtils.createValueHop(input, false), true, false)
            : new IndexingOp("tmp1", input.getDataType(), input.getValueType(), input,
                new LiteralOp(1L), HopRewriteUtils.createValueHop(input, true), indexExpr, indexExpr, false, true);
        HopRewriteUtils.setOutputParameters(newRix, -1L, -1L, input.getBlocksize(), -1L);
        newRix.refreshSizeInformation();
        resetInputVisitStatus(newRix);

        // the bottom op of the chain now writes into the extracted vector
        HopRewriteUtils.removeChildReference(current, input);
        HopRewriteUtils.addChildReference(current, newRix, 0);

        // reset the row/column index of all chain ops to 1 and refresh sizes bottom-up
        for (int i = ihops.size() - 1; i >= 0; i--) {
            LeftIndexingOp c = ihops.get(i);
            HopRewriteUtils.replaceChildReference(c, c.getInput().get(lowerPos), new LiteralOp(1L), lowerPos);
            HopRewriteUtils.replaceChildReference(c, c.getInput().get(lowerPos + 1), new LiteralOp(1L), lowerPos + 1);
            if (row) {
                c.setRowLowerEqualsUpper(true);
            } else {
                c.setColLowerEqualsUpper(true);
            }
            c.refreshSizeInformation();
        }

        // capture and detach ihop0's consumers BEFORE creating newLix, which takes ihop0 as an
        // input (and presumably registers itself as a parent) and must not be rewired itself
        List<Hop> parents = new ArrayList<>();
        List<Integer> positions = new ArrayList<>();
        detachFromParents(ihop0, parents, positions);

        LeftIndexingOp newLix = row
            ? new LeftIndexingOp("tmp2", input.getDataType(), input.getValueType(), input, ihop0,
                indexExpr, indexExpr, new LiteralOp(1L), HopRewriteUtils.createValueHop(input, false), true, false)
            : new LeftIndexingOp("tmp2", input.getDataType(), input.getValueType(), input, ihop0,
                new LiteralOp(1L), HopRewriteUtils.createValueHop(input, true), indexExpr, indexExpr, false, true);
        HopRewriteUtils.setOutputParameters(newLix, -1L, -1L, input.getBlocksize(), -1L);
        newLix.refreshSizeInformation();
        resetInputVisitStatus(newLix);

        attachToParents(newLix, parents, positions);
        return newLix;
    }

    /** Calls {@code resetVisitStatus()} on every input of {@code hop}. */
    private static void resetInputVisitStatus(Hop hop) {
        for (Hop child : hop.getInput()) {
            child.resetVisitStatus();
        }
    }

    /**
     * Removes {@code hop} from the inputs of all its parents. For each removal, the parent is
     * appended to {@code parents} and the input index it occupied (at removal time) to
     * {@code positions}. A parent that references {@code hop} k times is recorded k times.
     */
    private static void detachFromParents(Hop hop, List<Hop> parents, List<Integer> positions) {
        parents.addAll(hop.getParent()); // snapshot: the parent list is mutated below
        for (Hop parent : parents) {
            int pos = HopRewriteUtils.getChildReferencePos(parent, hop);
            HopRewriteUtils.removeChildReferenceByPos(parent, hop, pos);
            positions.add(pos);
        }
    }

    /**
     * Inserts {@code replacement} at the positions recorded by {@link #detachFromParents}.
     * Insertions undo the removals in reverse (LIFO) order. This restores every slot at its
     * original index even when a parent referenced the old hop several times with other
     * operands in between; forward order would permute such operands.
     */
    private static void attachToParents(Hop replacement, List<Hop> parents, List<Integer> positions) {
        for (int i = parents.size() - 1; i >= 0; i--) {
            HopRewriteUtils.addChildReference(parents.get(i), replacement, positions.get(i));
        }
    }
}

