/*
 * Decompiled with CFR 0.152.
 */
package hex.aggregator;

import hex.DataInfo;
import hex.Model;
import hex.ModelBuilder;
import hex.ModelCategory;
import hex.ToEigenVec;
import hex.aggregator.AggregatorModel;
import hex.util.LinearAlgebraUtils;
import java.util.Arrays;
import water.DKV;
import water.Iced;
import water.IcedUtils;
import water.Job;
import water.Key;
import water.Keyed;
import water.Lockable;
import water.MRTask;
import water.Scope;
import water.exceptions.H2OModelBuilderIllegalArgumentException;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.fvec.Vec;
import water.util.ArrayUtils;
import water.util.IcedInt;
import water.util.Log;

public class Aggregator
extends ModelBuilder<AggregatorModel, AggregatorModel.AggregatorParameters, AggregatorModel.AggregatorOutput> {
    /**
     * Supplies {@link LinearAlgebraUtils#toEigen} as this builder's {@link ToEigenVec} implementation.
     */
    @Override
    public ToEigenVec getToEigenVec() {
        final ToEigenVec eigen = LinearAlgebraUtils.toEigen;
        return eigen;
    }

    /**
     * Advertises this builder as stable.
     */
    @Override
    public ModelBuilder.BuilderVisibility builderVisibility() {
        final ModelBuilder.BuilderVisibility visibility = ModelBuilder.BuilderVisibility.Stable;
        return visibility;
    }

    /**
     * Returns {@code false}: the aggregator is an unsupervised builder.
     */
    @Override
    public boolean isSupervised() {
        final boolean supervised = false;
        return supervised;
    }

    /**
     * Creates the driver whose {@code computeImpl()} performs the aggregation.
     */
    @Override
    protected AggregatorDriver trainModelImpl() {
        final AggregatorDriver driver = new AggregatorDriver();
        return driver;
    }

    /**
     * Lists the model categories this builder produces: clustering only.
     */
    @Override
    public ModelCategory[] can_build() {
        final ModelCategory[] categories = {ModelCategory.Clustering};
        return categories;
    }

    /**
     * Creates a builder for the given parameters and immediately runs the cheap
     * (non-expensive) validation pass via {@code init(false)}.
     *
     * @param parms aggregator parameters
     */
    public Aggregator(AggregatorModel.AggregatorParameters parms) {
        super(parms);
        init(false);
    }

    /**
     * Creates a builder backed by default {@link AggregatorModel.AggregatorParameters};
     * the flag is forwarded unchanged to the superclass constructor.
     *
     * @param startup_once passed through to {@link ModelBuilder}
     */
    public Aggregator(boolean startup_once) {
        super(new AggregatorModel.AggregatorParameters(), startup_once);
    }

    /**
     * Validates the aggregator-specific parameters, then runs the generic builder validation.
     *
     * <p>On the expensive pass, an {@code AUTO} categorical encoding is replaced by {@code Eigen}
     * before {@code super.init} runs. Errors are recorded for a non-positive
     * {@code target_num_exemplars} and for a {@code rel_tol_num_exemplars} outside the open
     * interval (0, 1).
     *
     * @param expensive whether this is the full (training-time) validation pass
     * @throws H2OModelBuilderIllegalArgumentException if any validation error was recorded
     */
    @Override
    public void init(boolean expensive) {
        AggregatorModel.AggregatorParameters p = (AggregatorModel.AggregatorParameters)this._parms;
        boolean autoEncoding = p._categorical_encoding == Model.Parameters.CategoricalEncodingScheme.AUTO;
        if (expensive && autoEncoding) {
            p._categorical_encoding = Model.Parameters.CategoricalEncodingScheme.Eigen;
        }
        if (p._target_num_exemplars <= 0) {
            error("_target_num_exemplars", "target_num_exemplars must be > 0.");
        }
        double relTol = p._rel_tol_num_exemplars;
        if (relTol <= 0.0 || relTol >= 1.0) {
            error("_rel_tol_num_exemplars", "rel_tol_num_exemplars must be inside 0...1.");
        }
        super.init(expensive);
        if (error_count() > 0) {
            throw H2OModelBuilderIllegalArgumentException.makeFromBuilder(this);
        }
    }

    /**
     * Rewrites an exemplar-assignment column in place. Each value found among the mapping's
     * source gids is replaced by the corresponding target gid. Values without an entry are left
     * untouched.
     */
    private static class RenumberTask
    extends MRTask<RenumberTask> {
        /** {@code _map[0][i]} is a source gid and {@code _map[1][i]} is its replacement. */
        final long[][] _map;

        public RenumberTask(AggregateTask.GIDMapping mapping) {
            _map = mapping.unsortedList();
        }

        @Override
        public void map(Chunk c2) {
            for (int row = 0; row < c2._len; ++row) {
                int hit = ArrayUtils.find(_map[0], c2.at8(row));
                if (hit >= 0) {
                    c2.set(row, _map[1][hit]);
                }
            }
        }
    }

    private static class AggregateTask
    extends MRTask<AggregateTask> {
        final double _delta;
        final Key _dataInfoKey;
        final Key _jobKey;
        final int _maxExemplars;
        Exemplar[] _exemplars;
        Key _terminateKey;
        GIDMapping _mapping;

        /**
         * Configures one aggregation pass.
         *
         * @param dataInfoKey  DKV key of the {@link DataInfo} used to extract rows
         * @param radius       merge radius; stored squared in {@code _delta}
         * @param jobKey       job that receives progress updates
         * @param maxExemplars exemplar count above which the pass is aborted
         * @param terminateKey key of the shared termination flag, or {@code null} to disable early
         *                     termination; when non-null the flag is reset to 0 here
         */
        public AggregateTask(Key<DataInfo> dataInfoKey, double radius, Key<Job> jobKey, int maxExemplars, Key terminateKey) {
            _dataInfoKey = dataInfoKey;
            _jobKey = jobKey;
            _maxExemplars = maxExemplars;
            _delta = radius * radius;
            _terminateKey = terminateKey;
            if (terminateKey != null) {
                DKV.put(terminateKey, new IcedInt(0));
            }
        }

        /**
         * Reports whether the shared termination flag in the DKV has been raised (value 1).
         * Always {@code false} when no terminate key was configured.
         */
        private boolean isTerminated() {
            if (_terminateKey == null) {
                return false;
            }
            IcedInt flag = (IcedInt)DKV.getGet((Key)_terminateKey);
            return flag._val == 1;
        }

        /**
         * Raises the shared termination flag (sets it to 1). Does nothing when no terminate key was
         * configured.
         */
        private void terminate() {
            if (_terminateKey == null) {
                return;
            }
            DKV.put(_terminateKey, new IcedInt(1));
        }

        /**
         * Greedy, single-pass exemplar search over one chunk.
         *
         * <p>The last chunk in {@code chks} is the assignment column; the remaining chunks are the data
         * columns, read through the {@link DataInfo}. The first row of the chunk always becomes an exemplar.
         * Every later row is compared only against exemplars with identical categorical bin ids:
         * <ul>
         *   <li>If its squared distance to the nearest such exemplar is below {@code _delta}, that
         *       exemplar's count is incremented and the row's assignment is set to the exemplar's gid.</li>
         *   <li>Otherwise the row becomes a new exemplar whose gid is its global row index.</li>
         * </ul>
         *
         * <p>The shared termination flag is polled every 100 rows. If it is set, the method returns early
         * without assigning {@code _exemplars}. The flag is raised here when the exemplar buffer grows
         * beyond {@code 2 * _maxExemplars} slots, or when the final exemplar count exceeds
         * {@code _maxExemplars}.
         */
        @Override
        public void map(Chunk[] chks) {
            this._mapping = new GIDMapping();
            // Growable buffer; trailing slots may be null (see Exemplar.addExemplar / trim).
            Exemplar[] es = new Exemplar[4];
            Chunk[] dataChks = Arrays.copyOf(chks, chks.length - 1);
            Chunk assignmentChk = chks[chks.length - 1];
            DataInfo di = (DataInfo)this._dataInfoKey.get();
            assert (di != null);
            DataInfo.Row row = di.newDenseRow();
            int nCols = row.nNums;
            for (int r2 = 0; r2 < chks[0]._len; ++r2) {
                // Poll the shared flag only occasionally: it is a DKV lookup.
                if (r2 % 100 == 0 && this.isTerminated()) {
                    return;
                }
                long rowIndex = chks[0].start() + (long)r2;
                row = di.extractDenseRow(dataChks, r2, row);
                // Copy out of the reused Row so the exemplar keeps its own arrays.
                double[] data = Arrays.copyOf(row.numVals, nCols);
                int[] cats = Arrays.copyOf(row.binIds, row.binIds.length);
                if (r2 == 0) {
                    Exemplar ex = new Exemplar(data, cats, rowIndex);
                    es = Exemplar.addExemplar(es, ex);
                    assignmentChk.set(r2, ex.gid);
                    continue;
                }
                double distanceToNearestExemplar = Double.MAX_VALUE;
                int closestExemplarIndex = 0;
                int index = 0;
                long gid = -1L;
                for (Exemplar e2 : es) {
                    if (null == e2) break;
                    // Only exemplars with identical categorical ids are candidates.
                    if (!Arrays.equals(cats, e2.cats)) {
                        ++index;
                        continue;
                    }
                    // Passing the current best as threshold lets the distance computation stop early.
                    double distToExemplar = e2.squaredEuclideanDistance(data, distanceToNearestExemplar);
                    if (distToExemplar < distanceToNearestExemplar) {
                        distanceToNearestExemplar = distToExemplar;
                        closestExemplarIndex = index;
                        gid = e2.gid;
                    }
                    // First exemplar within the radius wins; no need to look further.
                    if (distanceToNearestExemplar < this._delta) break;
                    ++index;
                }
                if (distanceToNearestExemplar < this._delta) {
                    ++es[closestExemplarIndex]._cnt;
                    assignmentChk.set(r2, gid);
                    continue;
                }
                Exemplar ex = new Exemplar(data, cats, rowIndex);
                assert (Arrays.equals(cats, ex.cats));
                // The buffer doubles when full, so its capacity (not the exemplar count) is compared here.
                if ((es = Exemplar.addExemplar(es, ex)).length > 2 * this._maxExemplars) {
                    this.terminate();
                }
                assignmentChk.set(r2, rowIndex);
            }
            this._exemplars = Exemplar.trim(es);
            if (this._exemplars.length > this._maxExemplars) {
                this.terminate();
            }
            if (this.isTerminated()) {
                return;
            }
            assert (this._exemplars.length <= chks[0].len());
            long sum = 0L;
            for (Exemplar e3 : this._exemplars) {
                sum += e3._cnt;
            }
            assert (sum <= (long)chks[0].len());
            ((Job)this._jobKey.get()).update(1L, "Aggregating.");
        }

        /**
         * Merges the exemplars found by another task into this one.
         *
         * <p>If either side has already aborted, or holds more than {@code _maxExemplars} exemplars, the
         * shared termination flag is raised and both sides drop their state.
         *
         * <p>Otherwise each remote exemplar is handled in one of two ways:
         * <ul>
         *   <li>It is folded into the nearest local exemplar that has the same categorical bin ids and
         *       lies within {@code _delta} (squared distance). Its count is added there, and the redirection
         *       of its gid is recorded in {@code _mapping}.</li>
         *   <li>Otherwise it is appended as a new local exemplar.</li>
         * </ul>
         * The categorical check mirrors the one done per row in {@code map()}, so rows with different
         * categorical values never end up in the same exemplar.
         *
         * @param mrt the task to merge into this one; its exemplars are released afterwards
         */
        @Override
        public void reduce(AggregateTask mrt) {
            if (this.isTerminated() || this._exemplars == null || mrt._exemplars == null || this._exemplars.length > this._maxExemplars || mrt._exemplars.length > this._maxExemplars) {
                this.terminate();
                this._mapping = null;
                this._exemplars = null;
                mrt._exemplars = null;
            }
            if (this.isTerminated()) {
                return;
            }
            // Adopt the gid redirections the other side has already accumulated.
            for (int i = 0; i < mrt._mapping.len; ++i) {
                this._mapping.set(mrt._mapping.pairSet[i].first, mrt._mapping.pairSet[i].second);
            }
            Exemplar[] remote = mrt._exemplars;
            long localCounts = 0L;
            for (Exemplar e : this._exemplars) {
                localCounts += e._cnt;
            }
            long remoteCounts = 0L;
            for (Exemplar e : remote) {
                remoteCounts += e._cnt;
            }
            for (Exemplar re : remote) {
                double distanceToNearestExemplar = Double.MAX_VALUE;
                int closestExemplarIndex = 0;
                int index = 0;
                for (Exemplar le : this._exemplars) {
                    // Remaining slots are unused buffer capacity.
                    if (null == le) break;
                    // Exemplars with different categorical ids must never be merged.
                    if (!Arrays.equals(re.cats, le.cats)) {
                        ++index;
                        continue;
                    }
                    double distToExemplar = le.squaredEuclideanDistance(re.data, distanceToNearestExemplar);
                    if (distToExemplar < distanceToNearestExemplar) {
                        distanceToNearestExemplar = distToExemplar;
                        closestExemplarIndex = index;
                    }
                    if (distanceToNearestExemplar < this._delta) break;
                    ++index;
                }
                if (distanceToNearestExemplar < this._delta) {
                    this._exemplars[closestExemplarIndex]._cnt += re._cnt;
                    this._mapping.set(re.gid, this._exemplars[closestExemplarIndex].gid);
                    continue;
                }
                this._exemplars = Exemplar.addExemplar(this._exemplars, IcedUtils.deepCopy(re));
            }
            mrt._exemplars = null;
            this._exemplars = Exemplar.trim(this._exemplars);
            assert ((long)this._exemplars.length <= localCounts + remoteCounts);
            long sum = 0L;
            for (Exemplar e : this._exemplars) {
                sum += e._cnt;
            }
            assert (sum == localCounts + remoteCounts);
            ((Job)this._jobKey.get()).update(1L, "Aggregating.");
        }

        /**
         * Growable list of gid redirections {@code (first -> second)} produced while merging exemplars.
         *
         * <p>When {@code from -> to} is added, every existing pair whose target is {@code from} is first
         * re-pointed at {@code to}. The new pair is then appended.
         */
        private static class GIDMapping
        extends Iced<GIDMapping> {
            // Declared before pairSet on purpose. Field initializers run in textual order, so sizing
            // pairSet from a capacity that is not yet initialized would allocate a zero-length array.
            // The first set() would then throw ArrayIndexOutOfBoundsException.
            int capacity = 32;
            int len = 0;
            MyPair[] pairSet = new MyPair[this.capacity];

            /**
             * Records that gid {@code from} now maps to {@code to}, and redirects existing pairs that
             * targeted {@code from}.
             */
            void set(long from, long to) {
                for (int i = 0; i < this.len; ++i) {
                    MyPair p = this.pairSet[i];
                    if (p.second == from) {
                        p.second = to;
                    }
                }
                if (this.len == this.pairSet.length) {
                    // Grow from the real array length so capacity can never disagree with the storage.
                    this.capacity = Math.max(1, this.pairSet.length * 2);
                    this.pairSet = Arrays.copyOf(this.pairSet, this.capacity);
                }
                this.pairSet[this.len++] = new MyPair(from, to);
            }

            /**
             * Returns the pairs as two parallel arrays in insertion order: row 0 holds the sources and
             * row 1 the targets.
             */
            long[][] unsortedList() {
                long[][] li = new long[2][this.len];
                MyPair[] pl = this.pairSet;
                for (int i = 0; i < this.len; ++i) {
                    li[0][i] = pl[i].first;
                    li[1][i] = pl[i].second;
                }
                return li;
            }
        }

        /**
         * Mutable {@code (first, second)} pair of longs, ordered by {@code first} only.
         */
        static class MyPair
        extends Iced<MyPair>
        implements Comparable<MyPair> {
            long first;
            long second;

            public MyPair(long f2, long s2) {
                first = f2;
                second = s2;
            }

            /** No-arg constructor; both fields stay 0. */
            public MyPair() {
            }

            /** Orders pairs by {@code first}; returns -1, 0 or 1. */
            @Override
            public int compareTo(MyPair o2) {
                return Long.compare(first, o2.first);
            }
        }
    }

    class AggregatorDriver
    extends ModelBuilder.Driver {
        AggregatorDriver() {
            super(Aggregator.this);
        }

        /**
         * Trains the aggregator model.
         *
         * <p>Steps:
         * <ol>
         *   <li>Validate parameters and lock a fresh {@link AggregatorModel}.</li>
         *   <li>Wrap the training frame in a {@link DataInfo} built with the configured
         *       {@code _transform}.</li>
         *   <li>Binary-search a radius scale until the exemplar count falls within
         *       {@code rel_tol_num_exemplars} of {@code target_num_exemplars}.</li>
         *   <li>Renumber the per-row exemplar assignments.</li>
         *   <li>Build the output frame, plus the mapping frame if requested.</li>
         * </ol>
         *
         * <p>Cleanup always runs in {@code finally}:
         * <ul>
         *   <li>The model is unlocked and its assignment/output keys are untracked from {@link Scope}.</li>
         *   <li>A leftover early-termination flag is removed from the DKV. Previously this happened only
         *       on success, leaking the key on failure or cancellation.</li>
         *   <li>The temporary {@link DataInfo} is removed.</li>
         * </ul>
         */
        @Override
        public void computeImpl() {
            AggregatorModel model = null;
            DataInfo di = null;
            // Key of the shared early-termination flag; non-null only while it may still be in the DKV.
            Key terminateKey = null;
            try {
                Aggregator.this.init(true);
                if (Aggregator.this.error_count() > 0) {
                    throw new IllegalArgumentException("Found validation errors: " + Aggregator.this.validationErrors());
                }
                AggregatorModel.AggregatorParameters parms = (AggregatorModel.AggregatorParameters)Aggregator.this._parms;
                model = new AggregatorModel(Aggregator.this.dest(), parms, new AggregatorModel.AggregatorOutput(Aggregator.this));
                model.delete_and_lock(Aggregator.this._job);
                Frame orig = Aggregator.this.train();
                Aggregator.this._job.update(1L, "Preprocessing data.");
                di = new DataInfo(orig, null, true, parms._transform, false, false, false);
                DKV.put(di);
                double radiusBase = 0.1 / Math.pow(Math.log(orig.numRows()), 1.0 / (double)orig.numCols());
                int targetNumExemplars = (int)Math.min((long)parms._target_num_exemplars, orig.numRows());
                Aggregator.this._job.update(0L, "Aggregating.");
                // Binary search on the radius scale 'mid' within [lo, hi]:
                // too many exemplars -> raise lo (bigger radius); too few -> lower hi (smaller radius).
                double lo = 0.0;
                double hi = 256.0;
                double mid = 8.0;
                int noNewExemplarsIterCount = 0;
                int previousNumExemplars = 0;
                double tol = parms._rel_tol_num_exemplars;
                int upperLimit = (int)((1.0 + tol) * (double)targetNumExemplars);
                int lowerLimit = (int)((1.0 - tol) * (double)targetNumExemplars);
                terminateKey = Key.make();
                AggregateTask aggTask;
                Vec assignment;
                int numExemplars;
                while (true) {
                    Log.info("radius_scale lo/mid/hi: " + lo + "/" + mid + "/" + hi);
                    double radius = mid * radiusBase;
                    if ((long)targetNumExemplars == orig.numRows()) {
                        // Target equals the row count: radius 0 keeps every row as its own exemplar.
                        radius = 0.0;
                    }
                    // Append a fresh zero column that receives the per-row exemplar assignment.
                    Vec[] vecs = Arrays.copyOf(orig.vecs(), orig.vecs().length + 1);
                    assignment = orig.anyVec().makeZero();
                    vecs[vecs.length - 1] = assignment;
                    Log.info("Aggregating with radius " + String.format("%5f", radius) + ":");
                    aggTask = (AggregateTask)new AggregateTask(di._key, radius, Aggregator.this._job._key, upperLimit, radius == 0.0 ? null : terminateKey).doAll(vecs);
                    if (radius == 0.0) {
                        Log.info(" Returning original dataset.");
                        numExemplars = aggTask._exemplars.length;
                        assert ((long)numExemplars == orig.numRows());
                        break;
                    }
                    if (aggTask.isTerminated() && Math.abs(hi - lo) < 0.001 * Math.abs(lo + hi)) {
                        // The search interval has collapsed but the pass was still cut short: redo it
                        // with the exemplar cap raised to the row count so it runs to completion.
                        aggTask = (AggregateTask)new AggregateTask(di._key, radius, Aggregator.this._job._key, (int)orig.numRows(), terminateKey).doAll(vecs);
                        Log.info(" Running again without early cutout.");
                        numExemplars = aggTask._exemplars.length;
                        break;
                    }
                    if (aggTask.isTerminated() || aggTask._exemplars.length > upperLimit) {
                        Log.info(" Too many exemplars.");
                        lo = mid;
                    } else {
                        numExemplars = aggTask._exemplars.length;
                        Log.info(" " + numExemplars + " exemplars.");
                        if (numExemplars >= lowerLimit && numExemplars <= upperLimit) {
                            Log.info(" Within " + 100.0 * tol + "% of target number of exemplars. Done.");
                            break;
                        }
                        Log.info(" Too few exemplars.");
                        hi = mid;
                        if (previousNumExemplars == numExemplars) {
                            ++noNewExemplarsIterCount;
                        }
                        if (noNewExemplarsIterCount > parms._num_iteration_without_new_exemplar) {
                            Log.info("Exiting with " + numExemplars + " exemplars as last " + parms._num_iteration_without_new_exemplar + " iterations did not accure any more exemplars");
                            break;
                        }
                        previousNumExemplars = numExemplars;
                    }
                    mid = lo + (hi - lo) / 2.0;
                }
                Aggregator.this._job.update(1L, "Aggregation finished. Got " + numExemplars + " examplars");
                assert (!aggTask.isTerminated());
                DKV.remove(terminateKey);
                // Already cleaned up; nothing left for the finally block to remove.
                terminateKey = null;
                String msg = "Creating exemplar assignments.";
                Log.info(msg);
                Aggregator.this._job.update(1L, msg);
                new RenumberTask(aggTask._mapping).doAll(assignment);
                model._exemplars = aggTask._exemplars;
                model._counts = new long[aggTask._exemplars.length];
                for (int i = 0; i < aggTask._exemplars.length; ++i) {
                    model._counts[i] = aggTask._exemplars[i]._cnt;
                }
                model._exemplar_assignment_vec_key = assignment._key;
                AggregatorModel.AggregatorOutput output = (AggregatorModel.AggregatorOutput)model._output;
                output._output_frame = Key.make("aggregated_" + parms._train.toString() + "_by_" + model._key);
                msg = "Creating output frame.";
                Log.info(msg);
                Aggregator.this._job.update(1L, msg);
                model.createFrameOfExemplars((Frame)parms._train.get(), output._output_frame);
                if (((AggregatorModel.AggregatorParameters)model._parms)._save_mapping_frame) {
                    output._mapping_frame = Key.make("aggregated_mapping_" + parms._train.toString() + "_by_" + model._key);
                    model.createMappingOfExemplars(output._mapping_frame);
                }
                Aggregator.this._job.update(1L, "Done.");
                model.update(Aggregator.this._job);
            } finally {
                if (model != null) {
                    model.unlock(Aggregator.this._job);
                    Scope.untrack(model._exemplar_assignment_vec_key);
                    AggregatorModel.AggregatorOutput out = (AggregatorModel.AggregatorOutput)model._output;
                    Frame outFrame = out._output_frame != null ? out._output_frame.get() : null;
                    if (outFrame != null) {
                        Scope.untrack(outFrame.keys());
                    }
                    Frame mappingFrame = out._mapping_frame != null ? out._mapping_frame.get() : null;
                    if (mappingFrame != null) {
                        Scope.untrack(mappingFrame.keys());
                    }
                }
                if (terminateKey != null) {
                    // Only reached on failure/cancellation; the success path removes the flag itself.
                    DKV.remove(terminateKey);
                }
                if (di != null) {
                    di.remove();
                }
            }
        }
    }

    /**
     * One representative row ("exemplar") of the aggregated data, together with the number of rows
     * assigned to it.
     */
    public static class Exemplar
    extends Iced<Exemplar> {
        /** Numeric values of the row that created this exemplar; NaN marks a missing value. */
        final double[] data;
        /** Categorical bin ids of the row that created this exemplar. */
        final int[] cats;
        /** Group id; the callers in this file pass the creating row's global index. */
        final long gid;
        /** Number of rows represented by this exemplar (starts at 1). */
        long _cnt;

        Exemplar(double[] d2, int[] c2, long id) {
            this.data = d2;
            this.cats = c2;
            this.gid = id;
            this._cnt = 1L;
        }

        /**
         * Stores {@code e2} in the slot right after the last non-null entry of {@code es}.
         *
         * <p>If the array is full, a copy of twice the size is returned. Otherwise {@code es} is modified
         * in place and returned. An empty input yields a new one-element array.
         */
        public static Exemplar[] addExemplar(Exemplar[] es, Exemplar e2) {
            if (es.length == 0) {
                return new Exemplar[]{e2};
            }
            int idx = es.length - 1;
            while (idx >= 0 && es[idx] == null) {
                --idx;
            }
            if (idx == es.length - 1) {
                Exemplar[] grown = Arrays.copyOf(es, es.length << 1);
                grown[es.length] = e2;
                return grown;
            }
            es[idx + 1] = e2;
            return es;
        }

        /** Returns a copy of {@code es} without its trailing {@code null} slots. */
        public static Exemplar[] trim(Exemplar[] es) {
            int idx = es.length - 1;
            while (idx >= 0 && es[idx] == null) {
                --idx;
            }
            return Arrays.copyOf(es, idx + 1);
        }

        /**
         * Squared Euclidean distance between this exemplar's numeric values and {@code e2}.
         *
         * <p>Coordinates where either side is missing (NaN) are skipped, and the partial sum is rescaled by
         * {@code ncols / compared}. While no missing value has been seen, the loop stops as soon as the sum
         * exceeds {@code thresh}. The rescale factor is then >= 1, so the result still exceeds
         * {@code thresh}.
         *
         * @return the (rescaled) squared distance. Returns 0 when there are no numeric columns, so rows
         *         are distinguished only by their categorical ids. Returns +Infinity when no coordinate
         *         could be compared.
         */
        private double squaredEuclideanDistance(double[] e2, double thresh) {
            final double[] e1 = this.data;
            final int ncols = e1.length;
            double sum = 0.0;
            int compared = 0;
            boolean missing = false;
            for (int j = 0; j < ncols; ++j) {
                double d1 = e1[j];
                double d2 = e2[j];
                if (!Exemplar.isMissing(d1) && !Exemplar.isMissing(d2)) {
                    double dist = d1 - d2;
                    sum += dist * dist;
                    ++compared;
                } else {
                    missing = true;
                }
                if (!missing && sum > thresh) break;
            }
            if (compared == 0) {
                // ncols / 0 used to turn the result into NaN, so identical all-categorical rows never merged.
                return ncols == 0 ? 0.0 : Double.POSITIVE_INFINITY;
            }
            return sum * ((double)ncols / (double)compared);
        }

        private static boolean isMissing(double x2) {
            return Double.isNaN(x2);
        }
    }
}

