/*
 * Decompiled with CFR 0.152.
 */
package org.tugraz.sysds.hops.rewrite;

import java.util.ArrayList;
import java.util.Arrays;
import org.apache.commons.collections.CollectionUtils;
import org.tugraz.sysds.hops.AggBinaryOp;
import org.tugraz.sysds.hops.Hop;
import org.tugraz.sysds.hops.HopsException;
import org.tugraz.sysds.hops.rewrite.HopRewriteRule;
import org.tugraz.sysds.hops.rewrite.HopRewriteUtils;
import org.tugraz.sysds.hops.rewrite.ProgramRewriteStatus;
import org.tugraz.sysds.utils.Explain;

/**
 * Hop rewrite that re-associates chains of matrix multiplications ({@code A %*% B %*% C ...})
 * into the parenthesization with the lowest estimated cost. The cost model is the classic
 * matrix-chain-order dynamic program over the chain's dimensions ({@code d[i]*d[k+1]*d[j+1]}
 * per multiplication).
 *
 * <p>A chain is only rewritten when the dimensions of all chain inputs are known (positive).
 * An intermediate matrix multiply is folded into a chain only if all of the following hold:
 * it has not been visited, it has no left PM input, and it is consumed exactly once.
 */
public class RewriteMatrixMultChainOptimization extends HopRewriteRule {

    /**
     * Applies the matrix-multiplication chain optimization to every root of a DAG.
     *
     * @param roots DAG roots to rewrite in place; may be {@code null}
     * @param state rewrite status, forwarded to the chain optimization
     * @return the same {@code roots} list, or {@code null} if {@code roots} is {@code null}
     */
    @Override
    public ArrayList<Hop> rewriteHopDAGs(ArrayList<Hop> roots, ProgramRewriteStatus state) {
        if (roots == null) {
            return null;
        }
        for (Hop h : roots) {
            rule_OptimizeMMChains(h, state);
        }
        return roots;
    }

    /**
     * Applies the matrix-multiplication chain optimization to a single DAG root.
     *
     * @param root  DAG root to rewrite in place; may be {@code null}
     * @param state rewrite status, forwarded to the chain optimization
     * @return the same {@code root}, or {@code null} if {@code root} is {@code null}
     */
    @Override
    public Hop rewriteHopDAG(Hop root, ProgramRewriteStatus state) {
        if (root == null) {
            return null;
        }
        rule_OptimizeMMChains(root, state);
        return root;
    }

    /**
     * Depth-first traversal that optimizes the chain whose last (root) operator is each
     * eligible matrix multiply. Already-visited hops are skipped. This includes every hop
     * that a previously processed chain marked as visited.
     */
    private void rule_OptimizeMMChains(Hop hop, ProgramRewriteStatus state) {
        if (hop.isVisited()) {
            return;
        }
        // matrix multiplies with a left PM input (hasLeftPMInput) are not used as chain roots
        if (HopRewriteUtils.isMatrixMultiply(hop) && !((AggBinaryOp) hop).hasLeftPMInput()) {
            prepAndOptimizeMMChain(hop, state);
        }
        for (Hop hi : hop.getInput()) {
            rule_OptimizeMMChains(hi, state);
        }
        hop.setVisited();
    }

    /**
     * Collects the maximal chain ending at {@code hop} and optimizes it if it has more than two
     * inputs.
     *
     * <p>Starting from the inputs of {@code hop}, every chain element that is an expandable
     * matrix multiply is replaced in place by its two inputs. The multiply itself is recorded
     * in {@code mmOperators}. Every element inspected is marked visited.
     *
     * @throws HopsException if an expanded matrix multiply does not have exactly two inputs
     */
    private void prepAndOptimizeMMChain(Hop hop, ProgramRewriteStatus state) {
        if (LOG.isTraceEnabled()) {
            LOG.trace("MM Chain Optimization for HOP: (" + hop.getClass().getSimpleName() + ", " + hop.getHopID() + ", " + hop.getName() + ")");
        }
        // chain inputs (leaves) in left-to-right order
        ArrayList<Hop> mmChain = new ArrayList<Hop>(hop.getInput());
        // matrix-multiply hops forming the chain; index 0 is the chain root
        ArrayList<Hop> mmOperators = new ArrayList<Hop>();
        mmOperators.add(hop);

        int i = 0;
        while (i < mmChain.size()) {
            Hop h = mmChain.get(i);
            boolean expandable = false;
            // Check h itself (not the chain root) for a left PM input: expanded operators are
            // later re-wired with new inputs, so one carrying that property must stay a leaf.
            if (HopRewriteUtils.isMatrixMultiply(h) && !((AggBinaryOp) h).hasLeftPMInput() && !h.isVisited()) {
                // only expand if the result is consumed exactly once (one parent, referenced once)
                expandable = h.getParent().size() == 1 && inputCount(h.getParent().get(0), h) <= 1;
                if (!expandable) {
                    // NOTE(review): this stops expanding the remainder of the chain rather than
                    // treating h as a leaf and continuing; kept as-is to preserve plan behavior.
                    break;
                }
            }
            h.setVisited();
            if (!expandable) {
                ++i;
                continue;
            }
            ArrayList<Hop> tempList = h.getInput();
            if (tempList.size() != 2) {
                throw new HopsException(hop.printErrorLocation() + "Hops::rule_OptimizeMMChain(): AggBinary must have exactly two inputs.");
            }
            // replace h by its two inputs; position i is re-examined on the next iteration
            mmOperators.add(h);
            mmChain.set(i, tempList.get(0));
            mmChain.add(i + 1, tempList.get(1));
        }

        if (LOG.isTraceEnabled()) {
            LOG.trace("Identified MM Chain: ");
            for (Hop h : mmChain) {
                logTraceHop(h, 1);
            }
        }
        // a two-input chain has only one possible association
        if (mmChain.size() == 2) {
            return;
        }
        optimizeMMChain(hop, mmChain, mmOperators, state);
    }

    /**
     * Re-associates the chain optimally if all input dimensions are known. Otherwise the DAG is
     * left untouched.
     *
     * @param hop         chain root, used for error locations
     * @param mmChain     chain inputs in left-to-right order
     * @param mmOperators the chain's matrix-multiply hops; index 0 is the root and is kept as root
     * @param state       rewrite status (not used by this implementation)
     */
    protected void optimizeMMChain(Hop hop, ArrayList<Hop> mmChain, ArrayList<Hop> mmOperators, ProgramRewriteStatus state) {
        double[] dimsArray = new double[mmChain.size() + 1];
        if (!getDimsArray(hop, mmChain, dimsArray)) {
            return;
        }
        clearLinksWithinChain(hop, mmOperators);
        int size = mmChain.size();
        int[][] split = mmChainDP(dimsArray, size);
        if (LOG.isTraceEnabled()) {
            LOG.trace("Optimal MM Chain: ");
        }
        // operator 0 is the root; operators 1..size-2 are assigned to the inner nodes
        mmChainRelinkHops(mmOperators.get(0), 0, size - 1, mmChain, mmOperators, 1, split, 1);
    }

    /**
     * Matrix-chain-order dynamic program.
     *
     * @param dimArray chain dimensions; input {@code i} is {@code dimArray[i] x dimArray[i+1]}
     * @param size     number of chain inputs
     * @return {@code split[i][j]} = index {@code k} of the cheapest split of sub-chain
     *         {@code i..j} into {@code (i..k)(k+1..j)}, or {@code -1} for {@code i >= j}.
     *         Ties keep the smallest {@code k}.
     */
    private static int[][] mmChainDP(double[] dimArray, int size) {
        double[][] dpMatrix = new double[size][size]; // zero-initialized: single inputs cost 0
        int[][] split = new int[size][size];
        for (int i = 0; i < size; ++i) {
            Arrays.fill(split[i], -1);
        }
        for (int l = 2; l <= size; ++l) {
            for (int i = 0; i < size - l + 1; ++i) {
                int j = i + l - 1;
                dpMatrix[i][j] = Double.MAX_VALUE;
                for (int k = i; k < j; ++k) {
                    double cost = dpMatrix[i][k] + dpMatrix[k + 1][j] + dimArray[i] * dimArray[k + 1] * dimArray[j + 1];
                    if (cost < dpMatrix[i][j]) {
                        dpMatrix[i][j] = cost;
                        split[i][j] = k;
                    }
                }
                if (LOG.isTraceEnabled()) {
                    LOG.trace("mmchainopt [i=" + (i + 1) + ",j=" + (j + 1) + "]: costs = " + dpMatrix[i][j] + ", split = " + (split[i][j] + 1));
                }
            }
        }
        return split;
    }

    /**
     * Rebuilds the links of sub-chain {@code i..j} under operator {@code h} according to
     * {@code split}. Inner nodes draw unused operators from {@code mmOperators}, starting at
     * {@code opIndex}.
     *
     * <p>The operator cursor is threaded through both recursive calls. The left and right
     * sub-trees therefore receive disjoint operators. Passing the same index to both sides
     * would reuse one operator twice, giving it four inputs and two parents, and orphan
     * another operator.
     *
     * @param h           operator (or, for {@code i == j}, leaf) for this sub-chain
     * @param i           first chain index of the sub-chain
     * @param j           last chain index of the sub-chain
     * @param mmChain     chain inputs in left-to-right order
     * @param mmOperators pool of matrix-multiply hops whose in-chain links were cleared
     * @param opIndex     index of the first unused operator in {@code mmOperators}
     * @param split       split table from the dynamic program
     * @param level       indentation level for trace output
     */
    protected final void mmChainRelinkHops(Hop h, int i, int j, ArrayList<Hop> mmChain, ArrayList<Hop> mmOperators, int opIndex, int[][] split, int level) {
        relinkSubChain(h, i, j, mmChain, mmOperators, opIndex, split, level);
    }

    /**
     * Recursive worker for {@link #mmChainRelinkHops}.
     *
     * @return index of the next unused operator in {@code mmOperators} after this sub-chain
     */
    private int relinkSubChain(Hop h, int i, int j, ArrayList<Hop> mmChain, ArrayList<Hop> mmOperators, int opIndex, int[][] split, int level) {
        if (i == j) {
            // single chain input: end of recursion, nothing to link
            logTraceHop(h, level);
            return opIndex;
        }
        if (LOG.isTraceEnabled()) {
            LOG.trace(Explain.getIdentation(level) + "(");
        }
        int k = split[i][j];
        int nextOp = opIndex;

        // left input: the leaf itself for a one-element sub-chain, else the next free operator
        Hop left = (i == k) ? mmChain.get(i) : mmOperators.get(nextOp++);
        h.getInput().add(left);
        left.getParent().add(h);

        // right input: same rule for the sub-chain k+1..j
        Hop right = (k + 1 == j) ? mmChain.get(j) : mmOperators.get(nextOp++);
        h.getInput().add(right);
        right.getParent().add(h);

        nextOp = relinkSubChain(left, i, k, mmChain, mmOperators, nextOp, split, level + 1);
        nextOp = relinkSubChain(right, k + 1, j, mmChain, mmOperators, nextOp, split, level + 1);

        // children are fully linked and sized, so h's sizes can now be derived
        h.refreshSizeInformation();
        if (LOG.isTraceEnabled()) {
            LOG.trace(Explain.getIdentation(level) + ")");
        }
        return nextOp;
    }

    /**
     * Detaches every chain operator from its two inputs (both directions), so that the chain
     * can be relinked in a new order.
     *
     * @throws HopsException in either of these cases:
     *         <ul>
     *         <li>an operator does not have exactly two inputs;</li>
     *         <li>a non-root operator has more than one parent.</li>
     *         </ul>
     */
    protected static void clearLinksWithinChain(Hop hop, ArrayList<Hop> operators) {
        for (int i = 0; i < operators.size(); ++i) {
            Hop op = operators.get(i);
            if (op.getInput().size() != 2 || (i != 0 && op.getParent().size() > 1)) {
                throw new HopsException(hop.printErrorLocation() + "Unexpected error while applying optimization on matrix-mult chain. \n");
            }
            Hop input1 = op.getInput().get(0);
            Hop input2 = op.getInput().get(1);
            op.getInput().clear();
            // remove(Object) drops one occurrence each, which also handles input1 == input2
            input1.getParent().remove(op);
            input2.getParent().remove(op);
        }
    }

    /**
     * Fills {@code dimsArray} such that chain input {@code i} is
     * {@code dimsArray[i] x dimsArray[i+1]}.
     *
     * @return {@code false} (without writing {@code dimsArray}) if any input has an unknown
     *         (non-positive) dimension, {@code true} otherwise
     * @throws HopsException if adjacent chain inputs have mismatching inner dimensions
     */
    protected static boolean getDimsArray(Hop hop, ArrayList<Hop> chain, double[] dimsArray) {
        for (Hop h : chain) {
            if (h.getDim1() <= 0L || h.getDim2() <= 0L) {
                return false;
            }
        }
        for (int i = 0; i < chain.size(); ++i) {
            if (i == 0) {
                dimsArray[i] = chain.get(i).getDim1();
                if (dimsArray[i] <= 0.0) {
                    throw new HopsException(hop.printErrorLocation() + "Hops::optimizeMMChain() : Invalid Matrix Dimension: " + dimsArray[i]);
                }
            } else if (chain.get(i - 1).getDim2() != chain.get(i).getDim1()) {
                throw new HopsException(hop.printErrorLocation() + "Hops::optimizeMMChain() : Matrix Dimension Mismatch: " + chain.get(i - 1).getDim2() + " != " + chain.get(i).getDim1());
            }
            dimsArray[i + 1] = chain.get(i).getDim2();
            if (dimsArray[i + 1] <= 0.0) {
                throw new HopsException(hop.printErrorLocation() + "Hops::optimizeMMChain() : Invalid Matrix Dimension: " + dimsArray[i + 1]);
            }
        }
        return true;
    }

    /** Returns how many times {@code h} occurs among the inputs of {@code p}. */
    private static int inputCount(Hop p, Hop h) {
        return CollectionUtils.cardinality(h, p.getInput());
    }

    /** Logs a one-line description of {@code hop} at trace level, indented by {@code level}. */
    private static void logTraceHop(Hop hop, int level) {
        if (LOG.isTraceEnabled()) {
            String offset = Explain.getIdentation(level);
            LOG.trace(offset + "Hop " + hop.getName() + "(" + hop.getClass().getSimpleName() + ", " + hop.getHopID() + ") " + hop.getDim1() + "x" + hop.getDim2());
        }
    }
}

