/*
 * Decompiled with CFR 0.152.
 */
package hex.naivebayes;

import hex.DataInfo;
import hex.ModelBuilder;
import hex.ModelCategory;
import hex.ModelMetrics;
import hex.naivebayes.NaiveBayesModel;
import java.util.ArrayList;
import java.util.Arrays;
import water.H2O;
import water.Job;
import water.Key;
import water.Keyed;
import water.Lockable;
import water.MRTask;
import water.exceptions.H2OModelBuilderIllegalArgumentException;
import water.fvec.Chunk;
import water.util.ArrayUtils;
import water.util.PrettyPrint;
import water.util.TwoDimTable;

/**
 * Model builder for Naive Bayes classification.
 *
 * <p>Requires a categorical response. Training is one distributed pass ({@link NBTask}) that
 * accumulates response-level counts, per-(predictor level, response level) counts for categorical
 * predictors, and per-response-level sums / sums of squares for numeric predictors. The driver then
 * turns these into Laplace-smoothed a-priori and conditional probabilities (categorical predictors)
 * and a mean / sample standard deviation (numeric predictors), and stores them on the model output.
 */
public class NaiveBayes
extends ModelBuilder<NaiveBayesModel, NaiveBayesModel.NaiveBayesParameters, NaiveBayesModel.NaiveBayesOutput> {
    /** Naive Bayes always needs a response column. */
    @Override
    public boolean isSupervised() {
        return true;
    }

    @Override
    protected NaiveBayesDriver trainModelImpl() {
        return new NaiveBayesDriver();
    }

    @Override
    public ModelCategory[] can_build() {
        return new ModelCategory[]{ModelCategory.Unknown};
    }

    @Override
    public boolean havePojo() {
        return true;
    }

    @Override
    public boolean haveMojo() {
        return false;
    }

    /**
     * Reports an error on {@code _train} when the conditional-probability tables would not fit in
     * the driver node's free memory.
     *
     * <p>Every predictor gets one row per response level. A categorical predictor contributes one
     * cell per factor level, and a numeric predictor contributes two cells (mean, standard
     * deviation). The estimate is therefore {@code responseLevels * sum(cellsPerPredictor)} doubles,
     * which matches the arrays allocated in {@code computeStatsFillModel}. The last column of
     * {@code _train} supplies the response-level count, and every other column is counted as a
     * predictor (presumably the response is placed last by the base builder; verify if that changes).
     */
    @Override
    protected void checkMemoryFootPrint_impl() {
        final int npredCols = this._train.numCols() - 1;
        final String[][] domains = this._train.domains();
        long cellsPerResponseLevel = 0L;
        for (int i = 0; i < npredCols; ++i) {
            cellsPerResponseLevel += domains[i] == null ? 2L : (long) domains[i].length;
        }
        // Widen before multiplying so the product cannot overflow int arithmetic.
        final long mem_usage = (long) this._train.lastVec().cardinality() * cellsPerResponseLevel * 8L;
        final long max_mem = H2O.SELF._heartbeat.get_free_mem();
        if (mem_usage > max_mem) {
            String msg = "Conditional probabilities won't fit in the driver node's memory (" + PrettyPrint.bytes(mem_usage) + " > " + PrettyPrint.bytes(max_mem) + ") - try reducing the number of columns, the number of response classes or the number of categorical factors of the predictors.";
            this.error("_train", msg);
        }
    }

    /** Creates a builder for the given parameters and runs the cheap (non-expensive) validation. */
    public NaiveBayes(NaiveBayesModel.NaiveBayesParameters parms) {
        super(parms);
        this.init(false);
    }

    /** Creates a builder with default parameters; used for registration at startup. */
    public NaiveBayes(boolean startup_once) {
        super(new NaiveBayesModel.NaiveBayesParameters(), startup_once);
    }

    /**
     * Validates the parameters and the training frame.
     *
     * @param expensive when true and no errors have been recorded so far, also runs the
     *                  memory-footprint check
     */
    @Override
    public void init(boolean expensive) {
        super.init(expensive);
        if (this._response != null) {
            if (!this._response.isCategorical()) {
                this.error("_response", "Response must be a categorical column");
            } else if (this._response.isConst()) {
                this.error("_response", "Response must have at least two unique categorical levels");
            }
        }
        final NaiveBayesModel.NaiveBayesParameters p = this._parms;
        if (p._laplace < 0.0) {
            this.error("_laplace", "Laplace smoothing must be a number >= 0");
        }
        if (p._min_sdev < 1.0E-10) {
            this.error("_min_sdev", "Min. standard deviation must be at least 1e-10");
        }
        if (p._eps_sdev < 0.0) {
            this.error("_eps_sdev", "Threshold for standard deviation must be positive");
        }
        if (p._min_prob < 1.0E-10) {
            this.error("_min_prob", "Min. probability must be at least 1e-10");
        }
        if (p._eps_prob < 0.0) {
            this.error("_eps_prob", "Threshold for probability must be positive");
        }
        this.hide("_balance_classes", "Balance classes is not applicable to NaiveBayes.");
        this.hide("_class_sampling_factors", "Class sampling factors is not applicable to NaiveBayes.");
        this.hide("_max_after_balance_size", "Max after balance size is not applicable to NaiveBayes.");
        if (expensive && this.error_count() == 0) {
            this.checkMemoryFootPrint();
        }
    }

    /** Returns a new array of {@code n} copies of {@code value}. */
    private static String[] filledArray(int n, String value) {
        String[] arr = new String[n];
        Arrays.fill(arr, value);
        return arr;
    }

    /**
     * Builds the one-row summary table: number of response levels and the smallest and largest
     * a-priori probability.
     */
    private TwoDimTable createModelSummaryTable(NaiveBayesModel.NaiveBayesOutput output) {
        final double[] apriori = output._apriori_raw;
        double apriori_min = apriori[0];
        double apriori_max = apriori[0];
        for (int i = 1; i < apriori.length; ++i) {
            if (apriori[i] < apriori_min) {
                apriori_min = apriori[i];
            } else if (apriori[i] > apriori_max) {
                apriori_max = apriori[i];
            }
        }
        String[] colHeaders = {"Number of Response Levels", "Min Apriori Probability", "Max Apriori Probability"};
        String[] colTypes = {"long", "double", "double"};
        String[] colFormat = {"%d", "%.5f", "%.5f"};
        TwoDimTable table = new TwoDimTable("Model Summary", null, new String[1], colHeaders, colTypes, colFormat, "");
        table.set(0, 0, apriori.length);
        table.set(0, 1, apriori_min);
        table.set(0, 2, apriori_max);
        return table;
    }

    /**
     * Distributed pass that accumulates the sufficient statistics for Naive Bayes.
     *
     * <p>The input columns must be laid out as the {@link DataInfo} adapted frame: categorical
     * predictors first ({@code _dinfo._cats} of them), then numeric predictors. Rows that contain
     * a NaN in any chunk, or that have weight 0, are skipped. Any other weight except 1 is rejected.
     */
    private static class NBTask
    extends MRTask<NBTask> {
        protected final Key<Job> _jobKey;
        final DataInfo _dinfo;
        final String[][] _domains;
        final int _nrescat;
        final int _npreds;
        /** Number of rows that contributed to the statistics. */
        public int _nobs;
        /** [response level] -> row count. */
        public int[] _rescnt;
        /** [categorical predictor][response level][predictor level] -> row count. */
        public int[][][] _jntcnt;
        /** [numeric predictor][response level] -> {sum of x, sum of x^2}. */
        public double[][][] _jntsum;

        public NBTask(Key<Job> jobKey, DataInfo dinfo, int nres) {
            this._jobKey = jobKey;
            this._dinfo = dinfo;
            this._nrescat = nres;
            this._domains = dinfo._adaptedFrame.domains();
            this._npreds = dinfo._cats + dinfo._nums;
        }

        @Override
        public void map(Chunk[] chks) {
            // Bail out without allocating anything; reduce() tolerates the resulting null arrays.
            if (this._jobKey.get().stop_requested()) {
                return;
            }
            this._nobs = 0;
            this._rescnt = new int[this._nrescat];
            if (this._dinfo._cats > 0) {
                this._jntcnt = new int[this._dinfo._cats][][];
                for (int c = 0; c < this._dinfo._cats; ++c) {
                    this._jntcnt[c] = new int[this._nrescat][this._domains[c].length];
                }
            }
            if (this._dinfo._nums > 0) {
                this._jntsum = new double[this._dinfo._nums][][];
                for (int c = 0; c < this._dinfo._nums; ++c) {
                    this._jntsum[c] = new double[this._nrescat][2];
                }
            }
            final Chunk res = chks[this._dinfo.responseChunkId(0)];
            final Chunk wts = this._dinfo._weights ? chks[this._dinfo.weightChunkId()] : null;
            rows:
            for (int row = 0; row < chks[0]._len; ++row) {
                if (wts != null) {
                    final double w = wts.atd(row);
                    if (w == 0.0) continue;
                    if (w != 1.0) {
                        throw new IllegalArgumentException("Weights must be either 0 or 1 for Naive Bayes.");
                    }
                }
                // Skip the row if any of its entries is missing.
                for (Chunk chk : chks) {
                    if (Double.isNaN(chk.atd(row))) continue rows;
                }
                final int rlevel = (int) res.atd(row);
                for (int col = 0; col < this._dinfo._cats; ++col) {
                    final int plevel = (int) chks[col].atd(row);
                    this._jntcnt[col][rlevel][plevel]++;
                }
                for (int col = 0; col < this._dinfo._nums; ++col) {
                    final double x = chks[this._dinfo._cats + col].atd(row);
                    final double[] acc = this._jntsum[col][rlevel];
                    acc[0] += x;
                    acc[1] += x * x;
                }
                this._rescnt[rlevel]++;
                ++this._nobs;
            }
        }

        @Override
        public void reduce(NBTask nt) {
            // A task that stopped early never allocated its accumulators: nothing to merge.
            if (nt._rescnt == null) {
                return;
            }
            // This task stopped early but the other did not: adopt its results wholesale.
            if (this._rescnt == null) {
                this._nobs = nt._nobs;
                this._rescnt = nt._rescnt;
                this._jntcnt = nt._jntcnt;
                this._jntsum = nt._jntsum;
                return;
            }
            this._nobs += nt._nobs;
            ArrayUtils.add(this._rescnt, nt._rescnt);
            if (this._jntcnt != null) {
                for (int col = 0; col < this._jntcnt.length; ++col) {
                    ArrayUtils.add(this._jntcnt[col], nt._jntcnt[col]);
                }
            }
            if (this._jntsum != null) {
                for (int col = 0; col < this._jntsum.length; ++col) {
                    ArrayUtils.add(this._jntsum[col], nt._jntsum[col]);
                }
            }
        }
    }

    /** Driver that runs {@link NBTask} and fills the model output from its statistics. */
    class NaiveBayesDriver
    extends ModelBuilder.Driver {
        NaiveBayesDriver() {
            super(NaiveBayes.this);
        }

        /** True when the user asked to stop and the stop was not caused by a timeout. */
        private boolean haltRequested() {
            return NaiveBayes.this.stop_requested() && !NaiveBayes.this.timeout();
        }

        /**
         * Turns the accumulated statistics into probabilities and tables on {@code model._output},
         * then optionally scores the training and validation frames.
         *
         * @return false if the job was stopped (or produced no statistics) before the model was
         *         completed; true otherwise
         */
        public boolean computeStatsFillModel(NaiveBayesModel model, DataInfo dinfo, NBTask tsk) {
            final NaiveBayesModel.NaiveBayesOutput out = model._output;
            final NaiveBayesModel.NaiveBayesParameters parms = NaiveBayes.this._parms;
            out._levels = NaiveBayes.this._response.domain();
            out._rescnt = tsk._rescnt;
            out._ncats = dinfo._cats;
            // NBTask.map allocates nothing once a stop is requested, so the counts can be absent
            // even after a timeout-driven stop; there is nothing to build a model from then.
            if (haltRequested() || tsk._rescnt == null) {
                return false;
            }
            NaiveBayes.this._job.update(1L, "Initializing arrays for model statistics");
            final String[][] domains = out._domains;
            final double[] apriori = new double[tsk._nrescat];
            final double[][][] pcond = new double[tsk._npreds][][];
            for (int p = 0; p < pcond.length; ++p) {
                final int ncells = domains[p] == null ? 2 : domains[p].length;
                pcond[p] = new double[tsk._nrescat][ncells];
            }
            if (haltRequested()) {
                return false;
            }
            NaiveBayes.this._job.update(1L, "Computing probabilities for categorical cols");
            for (int lvl = 0; lvl < apriori.length; ++lvl) {
                apriori[lvl] = ((double) tsk._rescnt[lvl] + parms._laplace) / ((double) tsk._nobs + (double) tsk._nrescat * parms._laplace);
            }
            fillCategoricalProbs(pcond, domains, dinfo, tsk, parms._laplace);
            if (haltRequested()) {
                return false;
            }
            NaiveBayes.this._job.update(1L, "Computing mean and standard deviation for numeric cols");
            fillNumericStats(pcond, dinfo, tsk);
            out._apriori_raw = apriori;
            out._pcond_raw = pcond;
            fillPcondTables(out, pcond, dinfo);
            final int nres = NaiveBayes.this._response.cardinality();
            out._apriori = new TwoDimTable("A Priori Response Probabilities", null, new String[1], NaiveBayes.this._response.domain(), filledArray(nres, "double"), filledArray(nres, "%5f"), "", new String[1][], new double[][]{apriori});
            out._model_summary = NaiveBayes.this.createModelSummaryTable(out);
            if (haltRequested()) {
                return false;
            }
            NaiveBayes.this._job.update(1L, "Scoring and computing metrics on training data");
            if (parms._compute_metrics) {
                model.score(parms.train()).delete();
                out._training_metrics = ModelMetrics.getFromDKV(model, parms.train());
            }
            if (haltRequested()) {
                return false;
            }
            NaiveBayes.this._job.update(1L, "Scoring and computing metrics on validation data");
            if (NaiveBayes.this._valid != null) {
                model.score(parms.valid()).delete();
                out._validation_metrics = ModelMetrics.getFromDKV(model, parms.valid());
            }
            return true;
        }

        /**
         * Fills Laplace-smoothed P(predictor level | response level) for every categorical
         * predictor: (count + laplace) / (response count + levels * laplace).
         */
        private void fillCategoricalProbs(double[][][] pcond, String[][] domains, DataInfo dinfo, NBTask tsk, double laplace) {
            for (int col = 0; col < dinfo._cats; ++col) {
                assert (pcond[col].length == tsk._nrescat);
                final double nlevels = domains[col].length;
                for (int lvl = 0; lvl < pcond[col].length; ++lvl) {
                    final double denom = (double) tsk._rescnt[lvl] + nlevels * laplace;
                    for (int j = 0; j < pcond[col][lvl].length; ++j) {
                        pcond[col][lvl][j] = ((double) tsk._jntcnt[col][lvl][j] + laplace) / denom;
                    }
                }
            }
        }

        /**
         * Fills {mean, sample standard deviation} per response level for every numeric predictor.
         * The variance is computed from the sum and the sum of squares. Rounding can push it
         * slightly below zero, so negative values are clamped to 0. A NaN (e.g. from a response
         * level with at most one row) is left as-is, because {@code Math.max} propagates NaN.
         */
        private void fillNumericStats(double[][][] pcond, DataInfo dinfo, NBTask tsk) {
            for (int col = 0; col < dinfo._nums; ++col) {
                final int cidx = dinfo._cats + col;
                for (int lvl = 0; lvl < tsk._nrescat; ++lvl) {
                    final double num = tsk._rescnt[lvl];
                    final double pmean = tsk._jntsum[col][lvl][0] / num;
                    final double pvar = tsk._jntsum[col][lvl][1] / (num - 1.0) - pmean * pmean * num / (num - 1.0);
                    pcond[cidx][lvl][0] = pmean;
                    pcond[cidx][lvl][1] = Math.sqrt(Math.max(0.0, pvar));
                }
            }
        }

        /** Builds one display table per predictor from the raw conditional statistics. */
        private void fillPcondTables(NaiveBayesModel.NaiveBayesOutput out, double[][][] pcond, DataInfo dinfo) {
            out._pcond = new TwoDimTable[pcond.length];
            final String[] rowNames = NaiveBayes.this._response.domain();
            for (int col = 0; col < dinfo._cats; ++col) {
                final String name = NaiveBayes.this._train.name(col);
                final String[] colNames = NaiveBayes.this._train.vec(col).domain();
                out._pcond[col] = new TwoDimTable(name, null, rowNames, colNames, filledArray(colNames.length, "double"), filledArray(colNames.length, "%5f"), "Y_by_" + name, new String[rowNames.length][], pcond[col]);
            }
            for (int col = 0; col < dinfo._nums; ++col) {
                final int cidx = dinfo._cats + col;
                final String name = NaiveBayes.this._train.name(cidx);
                out._pcond[cidx] = new TwoDimTable(name, null, rowNames, new String[]{"Mean", "Std_Dev"}, new String[]{"double", "double"}, new String[]{"%5f", "%5f"}, "Y_by_" + name, new String[rowNames.length][], pcond[cidx]);
            }
        }

        /**
         * Validates inputs, runs the distributed pass and publishes the model. The model lock is
         * always released and the temporary {@link DataInfo} always removed, even on failure.
         */
        @Override
        public void computeImpl() {
            NaiveBayesModel model = null;
            DataInfo dinfo = null;
            try {
                NaiveBayes.this.init(true);
                if (NaiveBayes.this.error_count() > 0) {
                    throw H2OModelBuilderIllegalArgumentException.makeFromBuilder(NaiveBayes.this);
                }
                dinfo = new DataInfo(NaiveBayes.this._train, NaiveBayes.this._valid, 1, false, DataInfo.TransformType.NONE, DataInfo.TransformType.NONE, true, false, false, NaiveBayes.this._weights != null, false, NaiveBayes.this._fold != null);
                model = new NaiveBayesModel(NaiveBayes.this.dest(), NaiveBayes.this._parms, new NaiveBayesModel.NaiveBayesOutput(NaiveBayes.this));
                model.delete_and_lock(NaiveBayes.this._job);
                NaiveBayes.this._job.update(1L, "Begin distributed Naive Bayes calculation");
                NBTask tsk = new NBTask(NaiveBayes.this._job._key, dinfo, NaiveBayes.this._response.cardinality()).doAll(dinfo._adaptedFrame);
                if (this.computeStatsFillModel(model, dinfo, tsk)) {
                    model.update(NaiveBayes.this._job);
                }
            } finally {
                if (model != null) {
                    model.unlock(NaiveBayes.this._job);
                }
                if (dinfo != null) {
                    dinfo.remove();
                }
            }
        }
    }
}

