/*
 * Decompiled with CFR 0.152.
 */
package hex.coxph;

import hex.DataInfo;
import hex.Model;
import hex.ModelCategory;
import hex.ModelMetricsRegressionCoxPH;
import hex.StringPair;
import hex.coxph.CoxPH;
import hex.coxph.CoxPHMojoWriter;
import hex.coxph.Storage;
import hex.schemas.CoxPHModelV3;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.Set;
import water.AutoBuffer;
import water.DKV;
import water.Futures;
import water.Job;
import water.Key;
import water.MRTask;
import water.api.schemas3.ModelSchemaV3;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.fvec.NewChunk;
import water.fvec.Vec;
import water.rapids.ast.prims.mungers.AstGroup;
import water.udf.CFuncRef;
import water.util.ArrayUtils;
import water.util.IcedHashMap;
import water.util.IcedInt;

public class CoxPHModel
extends Model<CoxPHModel, CoxPHParameters, CoxPHOutput> {
    /**
     * Creates a Cox PH metric builder configured from this model's start/stop columns and
     * stratification settings.
     *
     * @param domain response domain; ignored by this implementation
     * @return a new, empty CoxPH regression metric builder
     */
    @Override
    public ModelMetricsRegressionCoxPH.MetricBuilderRegressionCoxPH makeMetricBuilder(String[] domain) {
        final CoxPHParameters p = (CoxPHParameters) this._parms;
        final String start = p._start_column;
        final String stop = p._stop_column;
        final boolean stratified = p.isStratified();
        return new ModelMetricsRegressionCoxPH.MetricBuilderRegressionCoxPH(start, stop, stratified, p._stratify_by);
    }

    /**
     * Returns a fresh REST schema instance describing this model type.
     *
     * @return a new {@code CoxPHModelV3}
     */
    public ModelSchemaV3 schema() {
        final ModelSchemaV3 modelSchema = new CoxPHModelV3();
        return modelSchema;
    }

    /**
     * Creates a Cox proportional hazards model; all state is handed to the base {@code Model}.
     *
     * @param destKey key under which the model is stored
     * @param parms   training parameters
     * @param output  training output (coefficients, strata metadata, ...)
     */
    public CoxPHModel(Key destKey, CoxPHParameters parms, CoxPHOutput output) {
        super(destKey, parms, output);
    }

    /**
     * Scores {@code adaptFrm} by computing, per row, the linear predictor minus the per-stratum
     * baseline (see {@code CoxPHScore}), written to a single numeric column named {@code "lp"}.
     *
     * <p>When {@code computeMetrics} is set, a new (empty) metric builder is returned alongside
     * the predictions; it is not populated here.
     *
     * @param fr               original frame (unused here)
     * @param adaptFrm         frame already adapted to the training layout
     * @param destination_key  key of the output prediction frame
     * @param job              owning job (unused here)
     * @param computeMetrics   whether to return a metric builder
     * @param customMetricFunc custom metric reference (unused here)
     * @return predictions frame (used both as predictions and as the scored frame) plus optional metric builder
     */
    @Override
    protected Model.PredictScoreResult predictScoreImpl(Frame fr, Frame adaptFrm, String destination_key, Job job, boolean computeMetrics, CFuncRef customMetricFunc) {
        final CoxPHParameters parms = (CoxPHParameters) this._parms;
        final CoxPHOutput output = (CoxPHOutput) this._output;

        // DataInfo needs to know how many of the response-like columns survived adaptation.
        int presentResponses = 0;
        for (String name : parms.responseCols()) {
            if (adaptFrm.find(name) != -1) {
                presentResponses++;
            }
        }

        final DataInfo scoringInfo = output.data_info.scoringInfo(output._names, adaptFrm, presentResponses, false);
        final boolean hasOffset = parms._offset_column != null;
        final CoxPHScore task = new CoxPHScore(scoringInfo, output, parms.isStratified(), hasOffset);
        // Vec.T_NUM: the single output column is numeric.
        final CoxPHScore finished = task.doAll(Vec.T_NUM, scoringInfo._adaptedFrame);
        final Frame predictions = finished.outputFrame(Key.make(destination_key), new String[]{"lp"}, null);

        final ModelMetricsRegressionCoxPH.MetricBuilderRegressionCoxPH metrics =
                computeMetrics ? this.makeMetricBuilder(null) : null;
        return new Model.PredictScoreResult(this, metrics, predictions, predictions);
    }

    /**
     * Adapts a test frame to the training layout. For a stratified model whose test frame lacks
     * the synthetic strata column, a NaN placeholder is added before the base adaptation and is
     * afterwards replaced by a strata vector computed from the stratify-by columns using the
     * training strata map. Columns that were only present to build strata are then removed.
     * Both the placeholder and the computed vector are registered in {@code _toDelete}.
     *
     * @param test           frame to adapt in place
     * @param expensive      passed through to the base implementation
     * @param computeMetrics passed through to the base implementation
     * @return adaptation messages from the base implementation
     */
    @Override
    public String[] adaptTestForTrain(Frame test, boolean expensive, boolean computeMetrics) {
        final CoxPHParameters parms = (CoxPHParameters) this._parms;
        final boolean needsStrata = parms.isStratified() && test.vec(parms._strata_column) == null;
        if (needsStrata) {
            // Placeholder so the base-class adaptation sees a strata column.
            final Vec placeholder = test.anyVec().makeCon(Double.NaN);
            this._toDelete.put(placeholder._key, "adapted missing strata vector");
            test.add(parms._strata_column, placeholder);
        }

        final String[] messages = super.adaptTestForTrain(test, expensive, computeMetrics);
        if (!needsStrata) {
            return messages;
        }

        final CoxPHOutput output = (CoxPHOutput) this._output;
        final Vec strata = CoxPH.StrataTask.makeStrataVec(test, parms._stratify_by, output._strataMap, parms._single_node_mode);
        this._toDelete.put(strata._key, "adapted missing strata vector");
        test.replace(test.find(parms._strata_column), strata);
        if (output._strataOnlyCols != null) {
            test.remove(output._strataOnlyCols);
        }
        return messages;
    }

    /**
     * Row-at-a-time scoring is not supported; this model scores whole frames in
     * {@code predictScoreImpl} instead.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public double[] score0(double[] data, double[] preds) {
        final String reason = "CoxPHModel.score0 should never be called";
        throw new UnsupportedOperationException(reason);
    }

    /**
     * Removes the auxiliary output frames (cumulative-hazard variance, baseline hazard and
     * baseline survival) before delegating to the base removal.
     *
     * @param fs      futures collecting pending removals
     * @param cascade passed through to the base implementation
     * @return {@code fs}
     */
    @Override
    protected Futures remove_impl(Futures fs, boolean cascade) {
        final CoxPHOutput out = (CoxPHOutput) this._output;
        for (Key<Frame> frameKey : Arrays.asList(out._var_cumhaz_2, out._baseline_hazard, out._baseline_survival)) {
            this.remove(fs, frameKey);
        }
        super.remove_impl(fs, cascade);
        return fs;
    }

    /**
     * Removes the frame stored under {@code key}, if the key is set and resolves to a frame.
     *
     * @param fs  futures collecting pending removals
     * @param key frame key; may be {@code null}
     */
    private void remove(Futures fs, Key<Frame> key) {
        if (key == null) {
            return;
        }
        final Frame target = key.get();
        if (target != null) {
            target.remove(fs);
        }
    }

    /**
     * Returns a MOJO writer bound to this model.
     *
     * @return a new {@code CoxPHMojoWriter}
     */
    @Override
    public CoxPHMojoWriter getMojo() {
        final CoxPHMojoWriter writer = new CoxPHMojoWriter(this);
        return writer;
    }

    /**
     * Reports MOJO support: models configured with interactions never have a MOJO; otherwise the
     * base implementation decides (and is only consulted in that case).
     *
     * @return {@code true} if a MOJO can be produced
     */
    @Override
    public boolean haveMojo() {
        final boolean noInteractions = ((CoxPHParameters) this._parms).interactionSpec() == null;
        return noInteractions && super.haveMojo();
    }

    /**
     * Computes, for every row, {@code row . coef - lpBase[stratum]}, where {@code lpBase[s]} is the
     * dot product of stratum {@code s}'s covariate means with the coefficients. Emits NA when the
     * row's predictors are bad or when its stratum value is NaN.
     */
    private static class CoxPHScore
    extends MRTask<CoxPHScore> {
        private DataInfo _dinfo;
        private double[] _coef;
        private double[] _lpBase;
        private int _numStart;
        private boolean _hasStrata;

        /**
         * @param dinfo      scoring layout of the adapted frame
         * @param o2         model output providing coefficients and per-stratum means
         * @param hasStrata  whether the stratum index is read from the first response chunk
         * @param hasOffsets whether an offset column is present; if so a coefficient of 1.0 is appended
         */
        private CoxPHScore(DataInfo dinfo, CoxPHOutput o2, boolean hasStrata, boolean hasOffsets) {
            final double[][] catMeans = o2._x_mean_cat;
            final double[][] numMeans = o2._x_mean_num;
            final int strataCount = catMeans.length;
            this._dinfo = dinfo;
            this._hasStrata = hasStrata;
            this._coef = hasOffsets ? ArrayUtils.append(o2._coef, 1.0) : o2._coef;
            // Numeric coefficients follow the categorical ones; stratum 0's categorical width is the split point.
            this._numStart = catMeans[0].length;
            this._lpBase = new double[strataCount];
            for (int stratum = 0; stratum < strataCount; ++stratum) {
                this._lpBase[stratum] = this.baselineFor(catMeans[stratum], numMeans[stratum]);
            }
        }

        /** Dot product of one stratum's categorical and numeric means with the matching coefficients. */
        private double baselineFor(double[] catMean, double[] numMean) {
            double acc = 0.0;
            for (int j = 0; j < catMean.length; ++j) {
                acc += catMean[j] * this._coef[j];
            }
            for (int j = 0; j < numMean.length; ++j) {
                acc += numMean[j] * this._coef[j + this._numStart];
            }
            return acc;
        }

        @Override
        public void map(Chunk[] chks, NewChunk nc) {
            final DataInfo.Row row = this._dinfo.newDenseRow();
            final int len = chks[0]._len;
            for (int rid = 0; rid < len; ++rid) {
                this._dinfo.extractDenseRow(chks, rid, row);
                if (row.predictors_bad) {
                    nc.addNA();
                    continue;
                }
                // Without stratification every row uses stratum 0.
                final double stratum = this._hasStrata ? chks[this._dinfo.responseChunkId(0)].atd(rid) : 0.0;
                if (Double.isNaN(stratum)) {
                    // NaN stratum (presumably a strata level unknown to the model) -> no prediction.
                    nc.addNA();
                } else {
                    nc.addNum(row.innerProduct(this._coef) - this._lpBase[(int) stratum]);
                }
            }
        }
    }

    /**
     * Dense row matrix associated with a frame key. Its custom serialization hooks transfer only
     * the key; on read, if nothing is stored under the key in DKV, {@code toFrame} is invoked
     * with it (presumably materializing this matrix as a frame there — see
     * {@code Storage.DenseRowMatrix}).
     */
    public static class FrameMatrix
    extends Storage.DenseRowMatrix {
        Key<Frame> _frame_key;

        /**
         * @param frame_key key associated with this matrix's frame form
         * @param rows      number of rows
         * @param cols      number of columns
         */
        FrameMatrix(Key<Frame> frame_key, int rows, int cols) {
            super(rows, cols);
            this._frame_key = frame_key;
        }

        /** Serialization hook: writes only the frame key. */
        public final AutoBuffer write_impl(AutoBuffer ab) {
            final Key<Frame> key = this._frame_key;
            Key.write_impl(key, ab);
            return ab;
        }

        /** Deserialization hook: reads the frame key and calls {@code toFrame} if DKV has no value for it. */
        public final FrameMatrix read_impl(AutoBuffer ab) {
            this._frame_key = Key.read_impl(null, ab);
            final boolean alreadyStored = DKV.getGet(this._frame_key) != null;
            if (!alreadyStored) {
                this.toFrame(this._frame_key);
            }
            return this;
        }
    }

    /**
     * Training output of a Cox PH model: coefficients and their statistics, per-stratum covariate
     * means, baseline hazard data, and the metadata (strata map, strata-only columns, interaction
     * spec) needed to rebuild strata and interaction columns at scoring time.
     */
    public static class CoxPHOutput
    extends Model.Output {
        Model.InteractionSpec _interactionSpec;
        DataInfo data_info;
        IcedHashMap<AstGroup.G, IcedInt> _strataMap;
        String[] _strataOnlyCols;
        public String[] _coef_names;
        public double[] _coef;
        public double[] _exp_coef;
        public double[] _exp_neg_coef;
        public double[] _se_coef;
        public double[] _z_coef;
        double[][] _var_coef;
        double _null_loglik;
        double _loglik;
        double _loglik_test;
        double _wald_test;
        double _score_test;
        double _rsq;
        double _maxrsq;
        double _lre;
        int _iter;
        double[][] _x_mean_cat;
        double[][] _x_mean_num;
        double[] _mean_offset;
        String[] _offset_names;
        long _n;
        long _n_missing;
        long _total_event;
        double[] _time;
        double[] _n_risk;
        double[] _n_event;
        double[] _n_censor;
        double[] _cumhaz_0;
        double[] _var_cumhaz_1;
        FrameMatrix _var_cumhaz_2_matrix;
        Key<Frame> _var_cumhaz_2;
        Key<Frame> _baseline_hazard;
        FrameMatrix _baseline_hazard_matrix;
        Key<Frame> _baseline_survival;
        FrameMatrix _baseline_survival_matrix;
        CoxPHParameters.CoxPHTies _ties;
        String _formula;
        double _concordance;

        /**
         * @param coxPH     the builder whose parameters are recorded
         * @param adaptFr   adapted training frame
         * @param train     original training frame (source of stratify-by columns missing from {@code adaptFr})
         * @param strataMap mapping of stratify-by level combinations to stratum indices
         */
        public CoxPHOutput(CoxPH coxPH, Frame adaptFr, Frame train, IcedHashMap<AstGroup.G, IcedInt> strataMap) {
            super(coxPH, CoxPHOutput.fullFrame(coxPH, adaptFr, train));
            final CoxPHParameters parms = (CoxPHParameters) coxPH._parms;
            // fullFrame() places stratify-by columns absent from adaptFr in front of adaptFr's columns, so
            // (assuming _names mirrors that frame) they are the leading entries of _names.
            final int strataOnlyCount = this._names.length - adaptFr._names.length;
            this._strataOnlyCols = Arrays.copyOf(this._names, strataOnlyCount);
            this._ties = parms._ties;
            this._formula = parms.toFormula(train);
            this._interactionSpec = parms.interactionSpec();
            this._strataMap = strataMap;
        }

        @Override
        public int nclasses() {
            return 1;
        }

        /**
         * Returns {@code adaptFr} unchanged for unstratified models; otherwise a new frame holding the
         * stratify-by columns that {@code adaptFr} lacks (taken from {@code train}) followed by
         * all of {@code adaptFr}'s columns.
         */
        private static Frame fullFrame(CoxPH coxPH, Frame adaptFr, Frame train) {
            final CoxPHParameters parms = (CoxPHParameters) coxPH._parms;
            if (parms.isStratified()) {
                final Frame full = new Frame(new Vec[0]);
                for (String col : parms._stratify_by) {
                    if (adaptFr.vec(col) == null) {
                        full.add(col, train.vec(col));
                    }
                }
                full.add(adaptFr);
                return full;
            }
            return adaptFr;
        }

        @Override
        public ModelCategory getModelCategory() {
            return ModelCategory.CoxPH;
        }

        /** Returns an interaction builder when an interaction spec was configured, else {@code null}. */
        @Override
        public Model.InteractionBuilder interactionBuilder() {
            if (this._interactionSpec == null) {
                return null;
            }
            return new CoxPHInteractionBuilder();
        }

        /** Appends the configured interaction columns to a frame, using this output's DataInfo settings. */
        private class CoxPHInteractionBuilder
        implements Model.InteractionBuilder {
            private CoxPHInteractionBuilder() {
            }

            @Override
            public Frame makeInteractions(Frame f2) {
                final Model.InteractionPair[] pairs = CoxPHOutput.this._interactionSpec.makeInteractionPairs(f2);
                final DataInfo dinfo = CoxPHOutput.this.data_info;
                final boolean standardize = dinfo._predictor_transform == DataInfo.TransformType.STANDARDIZE;
                f2.add(Model.makeInteractions(f2, false, pairs, dinfo._useAllFactorLevels, dinfo._skipMissing, standardize));
                return f2;
            }
        }
    }

    /**
     * Parameters of a Cox proportional hazards model.
     */
    public static class CoxPHParameters
    extends Model.Parameters {
        public String _start_column;
        public String _stop_column;
        // Name of the synthetic column holding each row's stratum index.
        final String _strata_column = "__strata";
        public String[] _stratify_by;
        public CoxPHTies _ties = CoxPHTies.efron;
        public double _init = 0.0;
        public double _lre_min = 9.0;
        public int _max_iterations = 20;
        public boolean _use_all_factor_levels;
        public String[] _interactions_only;
        public String[] _interactions = null;
        public StringPair[] _interaction_pairs = null;
        public boolean _calc_cumhaz = true;
        public boolean _single_node_mode = false;

        @Override
        public String algoName() {
            return "CoxPH";
        }

        @Override
        public String fullName() {
            return "Cox Proportional Hazards";
        }

        @Override
        public String javaName() {
            return CoxPHModel.class.getName();
        }

        @Override
        public long progressUnits() {
            return (this._max_iterations + 1) * 2 + 1;
        }

        /**
         * Names of the columns treated as responses rather than predictors: the start column (if
         * any), the synthetic strata column (if stratified), the stop column and the event column.
         *
         * @return response column names, in that order
         */
        String[] responseCols() {
            String[] cols = this._start_column != null ? new String[]{this._start_column} : new String[0];
            if (this.isStratified()) {
                // The strata index column is carried as a response, not the start column a second time.
                cols = ArrayUtils.append(cols, this._strata_column);
            }
            return ArrayUtils.append(cols, this._stop_column, this._response_column);
        }

        Vec startVec() {
            return this.train().vec(this._start_column);
        }

        Vec stopVec() {
            return this.train().vec(this._stop_column);
        }

        /**
         * Builds the interaction spec. The interactions-only set is the sorted union of
         * {@code _interactions_only} and {@code _stratify_by} when both are given, otherwise
         * whichever of them is set.
         *
         * @return interaction spec as produced by {@code Model.InteractionSpec.create}
         */
        Model.InteractionSpec interactionSpec() {
            final String[] interOnly;
            if (this._interactions_only != null && this._stratify_by != null) {
                // Both inputs are sorted first (presumably required by ArrayUtils.union with lexo=true).
                final String[] io = this._interactions_only.clone();
                Arrays.sort(io);
                final String[] sb = this._stratify_by.clone();
                Arrays.sort(sb);
                interOnly = ArrayUtils.union(io, sb, true);
            } else {
                interOnly = this._interactions_only != null ? this._interactions_only : this._stratify_by;
            }
            return Model.InteractionSpec.create(this._interactions, this._interaction_pairs, interOnly, this._stratify_by);
        }

        boolean isStratified() {
            return this._stratify_by != null && this._stratify_by.length > 0;
        }

        /**
         * Renders an R-style formula such as {@code Surv(start, stop, event) ~ x1 + x2 + offset(o) + a:b + strata(s)}.
         * Plain covariates are the frame's columns that are not special (start/stop/response/strata/
         * weights/ignored), not the offset, not stratify-by and not interactions-only columns.
         *
         * @param f2 frame whose column names supply the covariates
         * @return the formula string
         */
        String toFormula(Frame f2) {
            final StringBuilder sb = new StringBuilder();
            sb.append("Surv(");
            if (this._start_column != null) {
                sb.append(this._start_column).append(", ");
            }
            sb.append(this._stop_column).append(", ").append(this._response_column);
            sb.append(") ~ ");
            final Set<String> stratifyBy = this._stratify_by != null
                    ? new HashSet<String>(Arrays.asList(this._stratify_by))
                    : Collections.<String>emptySet();
            final Set<String> interactionsOnly = this._interactions_only != null
                    ? new HashSet<String>(Arrays.asList(this._interactions_only))
                    : Collections.<String>emptySet();
            final Set<String> specialCols = this.specialColumns();
            String sep = "";
            for (String col : f2._names) {
                final boolean isOffset = this._offset_column != null && this._offset_column.equals(col);
                if (isOffset || stratifyBy.contains(col) || interactionsOnly.contains(col) || specialCols.contains(col)) {
                    continue;
                }
                sb.append(sep).append(col);
                sep = " + ";
            }
            if (this._offset_column != null) {
                sb.append(sep).append("offset(").append(this._offset_column).append(")");
                // Later terms must be separated from the offset even when there are no plain covariates.
                sep = " + ";
            }
            final Model.InteractionSpec interactionSpec = this.interactionSpec();
            if (interactionSpec != null) {
                for (Model.InteractionPair ip : interactionSpec.makeInteractionPairs(f2)) {
                    sb.append(sep);
                    appendTerm(sb, f2._names[ip.getV1()], stratifyBy);
                    sb.append(":");
                    appendTerm(sb, f2._names[ip.getV2()], stratifyBy);
                    sep = " + ";
                }
            }
            if (this._stratify_by != null) {
                // Add standalone strata(...) terms only for columns not already shown inside an interaction.
                final String soFar = sb.toString();
                for (String col : this._stratify_by) {
                    final String strataCol = "strata(" + col + ")";
                    if (soFar.contains(strataCol)) {
                        continue;
                    }
                    sb.append(sep).append(strataCol);
                    sep = " + ";
                }
            }
            return sb.toString();
        }

        /** Columns that are never listed as plain covariates in the formula. */
        private Set<String> specialColumns() {
            final Set<String> cols = new HashSet<String>();
            if (this._start_column != null) {
                cols.add(this._start_column);
            }
            if (this._stop_column != null) {
                cols.add(this._stop_column);
            }
            cols.add(this._response_column);
            cols.add(this._strata_column);
            if (this._weights_column != null) {
                cols.add(this._weights_column);
            }
            if (this._ignored_columns != null) {
                cols.addAll(Arrays.asList(this._ignored_columns));
            }
            return cols;
        }

        /** Appends {@code col}, wrapped in {@code strata(...)} when it is a stratify-by column. */
        private static void appendTerm(StringBuilder sb, String col, Set<String> stratifyBy) {
            if (stratifyBy.contains(col)) {
                sb.append("strata(").append(col).append(")");
            } else {
                sb.append(col);
            }
        }

        public static enum CoxPHTies {
            efron,
            breslow;

        }
    }
}

