/*
 * Decompiled with CFR 0.152.
 */
package water.rapids.ast.prims.mungers;

import java.util.HashMap;
import water.H2O;
import water.MRTask;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.fvec.NewChunk;
import water.fvec.Vec;
import water.rapids.Env;
import water.rapids.ast.AstPrimitive;
import water.rapids.ast.AstRoot;
import water.rapids.ast.params.AstNumList;
import water.rapids.ast.prims.mungers.AstGroup;
import water.rapids.vals.ValFrame;
import water.util.IcedHashMap;
import water.util.Log;

/**
 * Rapids primitive {@code grouped_permute}.
 *
 * <p>Within each group (keyed by the single group-by column), every row is classified by its
 * {@code permuteBy} categorical level. Rows whose level is exactly {@code "D"} go on one side; all
 * other rows go on the other side. On each side, rows sharing the same {@code permCol} value are
 * collapsed and their {@code keepCol} values summed. The output then contains one row for every
 * pairing of a "D"-side entry with a non-"D"-side entry of the same group.
 *
 * <p>Output columns: the group-by column, {@code In} / {@code Out} (the paired {@code permCol}
 * values from the "D" side and the other side), and {@code InAmnt} / {@code OutAmnt} (their summed
 * {@code keepCol} amounts).
 */
public class AstGroupedPermute
extends AstPrimitive {
    @Override
    public String[] args() {
        return new String[]{"ary", "permCol", "groupBy", "permuteBy", "keepCol"};
    }

    @Override
    public int nargs() {
        return 6;
    }

    @Override
    public String str() {
        return "grouped_permute";
    }

    /**
     * Builds the per-group cross product of "D"-side and non-"D"-side entries.
     *
     * @throws IllegalArgumentException if the group-by list does not contain exactly one column, or
     *     if the {@code permuteBy} column has no categorical domain
     */
    @Override
    public ValFrame apply(Env env, Env.StackHelp stk, AstRoot[] asts) {
        Frame fr = stk.track(asts[1].exec(env)).getFrame();
        final int permCol = (int) asts[2].exec(env).getNum();
        AstNumList groupby = AstGroup.check(fr.numCols(), asts[3]);
        final int[] gbCols = groupby.expand4();
        final int permuteBy = (int) asts[4].exec(env).getNum();
        final int keepCol = (int) asts[5].exec(env).getNum();

        // BuildGroups keys groups on gbCols[0] only, and buildOutput emits exactly 5 columns
        // (key + 4). Any other group-by count would fail deep inside the distributed job.
        if (gbCols.length != 1) {
            throw new IllegalArgumentException(
                "grouped_permute supports exactly one group-by column, got " + gbCols.length);
        }
        // BuildGroups maps each permuteBy value through the column's domain; fail fast instead of
        // hitting a NullPointerException inside map() for a non-categorical column.
        if (fr.vec(permuteBy).domain() == null) {
            throw new IllegalArgumentException(
                "grouped_permute: permuteBy column " + permuteBy + " must be categorical");
        }

        final int n = gbCols.length;
        final String[][] frDomains = fr.domains();
        String[] names = new String[n + 4];
        String[][] domains = new String[n + 4][];
        for (int i = 0; i < n; ++i) {
            names[i] = fr.name(gbCols[i]);
            domains[i] = frDomains[gbCols[i]];
        }
        names[n] = "In";
        names[n + 1] = "Out";
        names[n + 2] = "InAmnt";
        names[n + 3] = "OutAmnt";
        // In/Out carry permCol values; InAmnt/OutAmnt carry summed keepCol values.
        domains[n] = frDomains[permCol];
        domains[n + 1] = frDomains[permCol];
        domains[n + 2] = frDomains[keepCol];
        domains[n + 3] = frDomains[keepCol];

        long start = System.currentTimeMillis();
        BuildGroups groups = new BuildGroups(gbCols, permuteBy, permCol, keepCol).doAll(fr);
        Log.info("Elapsed time: " + (System.currentTimeMillis() - start) / 1000.0 + "s");

        start = System.currentTimeMillis();
        SmashGroups sg = new SmashGroups(groups._grps);
        H2O.submitTask(sg).join();
        Log.info("Elapsed time: " + (System.currentTimeMillis() - start) / 1000.0 + "s");

        double[][][] pairsPerGroup = sg._res.values().toArray(new double[0][][]);
        return new ValFrame(buildOutput(pairsPerGroup, names, domains));
    }

    /**
     * Materializes the per-group pair rows into a new frame.
     *
     * @param a one entry per group; each entry is an array of 5-element rows
     *     {@code {key, in, out, inAmnt, outAmnt}}
     * @param names output column names (5 entries)
     * @param domains output column domains (5 entries)
     * @return the output frame; the temporary index frame is always deleted
     */
    private static Frame buildOutput(final double[][][] a, String[] names, String[][] domains) {
        // One row per group index; each map() call expands the groups in its chunk.
        Frame dVec = new Frame(Vec.makeSeq(0L, a.length));
        try {
            long start = System.currentTimeMillis();
            Frame res = new MRTask() {
                @Override
                public void map(Chunk[] cs, NewChunk[] ncs) {
                    for (int i = 0; i < cs[0]._len; ++i) {
                        for (double[] row : a[(int) cs[0].at8(i)]) {
                            for (int k = 0; k < row.length; ++k) {
                                ncs[k].addNum(row[k]);
                            }
                        }
                    }
                }
            }.doAll(5, Vec.T_NUM, dVec).outputFrame(null, names, domains);
            Log.info("Elapsed time: " + (System.currentTimeMillis() - start) / 1000.0 + "s");
            return res;
        } finally {
            // Release the temporary index vec even if the job above fails.
            dVec.delete();
        }
    }

    /**
     * Fork/join task that turns each group's two sides into the cross-product rows.
     *
     * <p>The index range {@code [_lo, _hi)} is split recursively. Each leaf handles one group
     * and stores its rows in the shared {@code _res} map under the group key.
     */
    private static class SmashGroups
    extends H2O.H2OCountedCompleter<SmashGroups> {
        private final IcedHashMap<Long, IcedHashMap<Long, double[]>[]> _grps;
        // Dense index -> group key, so the group set can be split by index range.
        private final HashMap<Integer, Long> _map;
        private int _hi;
        private int _lo;
        SmashGroups _left;
        SmashGroups _rite;
        // Shared by all clones (shallow copy): group key -> that group's output rows.
        private IcedHashMap<Long, double[][]> _res;

        SmashGroups(IcedHashMap<Long, IcedHashMap<Long, double[]>[]> grps) {
            _grps = grps;
            _lo = 0;
            _hi = _grps.size();
            _res = new IcedHashMap<>();
            _map = new HashMap<>();
            int idx = 0;
            for (Long key : _grps.keySet()) {
                _map.put(idx++, key);
            }
        }

        @Override
        public void compute2() {
            assert (_left == null && _rite == null);
            if (_hi - _lo >= 2) {
                int mid = (_lo + _hi) >>> 1;
                _left = copyAndInit();
                _rite = copyAndInit();
                _left._hi = mid;
                _rite._lo = mid;
                addToPendingCount(1);
                _left.fork();
                _rite.compute2();
                return;
            }
            if (_hi > _lo) {
                smash();
            }
            tryComplete();
        }

        /** Emits {key, in, out, inAmnt, outAmnt} for every ("D" side, other side) pair of group {@code _lo}. */
        private void smash() {
            long key = _map.get(_lo);
            IcedHashMap<Long, double[]>[] pair = _grps.get(key);
            double[][] res = new double[pair[0].size() * pair[1].size()][];
            int d = 0;
            for (double[] in : pair[0].values()) {
                for (double[] out : pair[1].values()) {
                    res[d++] = new double[]{key, in[0], out[0], in[1], out[1]};
                }
            }
            _res.put(key, res);
        }

        /** Shallow clone that completes into this task, with fresh child links and pending count. */
        private SmashGroups copyAndInit() {
            SmashGroups x = (SmashGroups) clone();
            x.setCompleter(this);
            x._rite = null;
            x._left = null;
            x.setPendingCount(0);
            return x;
        }
    }

    /**
     * Distributed pass that aggregates rows into
     * {@code group key -> [ "D" side, other side ]}.
     *
     * <p>Each side maps {@code permCol value -> {permCol value, summed keepCol}}.
     */
    private static class BuildGroups
    extends MRTask<BuildGroups> {
        // Created once per node in setupLocal() and shared by all map() clones on that node.
        IcedHashMap<Long, IcedHashMap<Long, double[]>[]> _grps;
        private final int[] _gbCols;
        private final int _permuteBy;
        private final int _permuteCol;
        private final int _amntCol;

        BuildGroups(int[] gbCols, int permuteBy, int permuteCol, int amntCol) {
            _gbCols = gbCols;
            _permuteBy = permuteBy;
            _permuteCol = permuteCol;
            _amntCol = amntCol;
        }

        @Override
        public void setupLocal() {
            _grps = new IcedHashMap<>();
        }

        @Override
        public void map(Chunk[] chks) {
            String[] dom = chks[_permuteBy].vec().domain();
            // Chunk-local aggregation (thread-confined), merged into the node-shared map at the end.
            IcedHashMap<Long, IcedHashMap<Long, double[]>[]> local = new IcedHashMap<>();
            for (int row = 0; row < chks[0]._len; ++row) {
                long jid = chks[_gbCols[0]].at8(row);
                long rid = chks[_permuteCol].at8(row);
                double[] aci = new double[]{rid, chks[_amntCol].atd(row)};
                int type = dom[(int) chks[_permuteBy].at8(row)].equals("D") ? 0 : 1;
                IcedHashMap<Long, double[]>[] sides = local.get(jid);
                if (sides == null) {
                    sides = newSides();
                    local.put(jid, sides);
                }
                double[] prev = sides[type].putIfAbsent(rid, aci);
                if (prev != null) {
                    prev[1] += aci[1];
                }
            }
            merge(local);
        }

        @Override
        public void reduce(BuildGroups t) {
            // Clones on the same node share _grps and were already merged in map().
            if (_grps != t._grps) {
                merge(t._grps);
            }
        }

        /**
         * Merges {@code other} into {@code _grps}.
         *
         * <p>New groups and new permCol entries are adopted as-is. Amounts for existing entries
         * are added under a lock on the accumulator, because concurrent map() calls on the same
         * node merge into the same shared arrays.
         */
        private void merge(IcedHashMap<Long, IcedHashMap<Long, double[]>[]> other) {
            for (Long gid : other.keySet()) {
                IcedHashMap<Long, double[]>[] theirs = other.get(gid);
                IcedHashMap<Long, double[]>[] ours = _grps.putIfAbsent(gid, theirs);
                if (ours == null) {
                    continue; // whole group adopted
                }
                for (int side = 0; side < 2; ++side) {
                    for (Long rid : theirs[side].keySet()) {
                        double[] add = theirs[side].get(rid);
                        double[] acc = ours[side].putIfAbsent(rid, add);
                        if (acc != null) {
                            synchronized (acc) {
                                acc[1] += add[1];
                            }
                        }
                    }
                }
            }
        }

        /** Creates the two per-group side maps: index 0 for "D" rows, index 1 for all others. */
        @SuppressWarnings("unchecked")
        private static IcedHashMap<Long, double[]>[] newSides() {
            // Generic arrays cannot be created directly; the raw array holds only these two maps.
            return new IcedHashMap[]{new IcedHashMap<Long, double[]>(), new IcedHashMap<Long, double[]>()};
        }
    }
}

