/*
 * Decompiled with CFR 0.152.
 */
package hex.generic;

import hex.Model;
import hex.ModelBuilder;
import hex.ModelMetrics;
import hex.ModelMetricsAutoEncoder;
import hex.ModelMetricsBinomial;
import hex.ModelMetricsClustering;
import hex.ModelMetricsMultinomial;
import hex.ModelMetricsOrdinal;
import hex.ModelMetricsRegression;
import hex.ModelMetricsRegressionCoxPH;
import hex.PojoWriter;
import hex.generic.GenericModelMojoWriter;
import hex.generic.GenericModelOutput;
import hex.generic.GenericModelParameters;
import hex.generic.PojoLoader;
import hex.genmodel.CategoricalEncoding;
import hex.genmodel.GenModel;
import hex.genmodel.ModelMojoReader;
import hex.genmodel.MojoModel;
import hex.genmodel.MojoReaderBackend;
import hex.genmodel.MojoReaderBackendFactory;
import hex.genmodel.algos.kmeans.KMeansMojoModel;
import hex.genmodel.descriptor.ModelDescriptor;
import hex.genmodel.descriptor.ModelDescriptorBuilder;
import hex.genmodel.easy.EasyPredictModelWrapper;
import hex.genmodel.easy.RowData;
import hex.genmodel.easy.exception.PredictException;
import hex.tree.isofor.ModelMetricsAnomaly;
import java.io.IOException;
import java.io.InputStream;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import water.DKV;
import water.Futures;
import water.H2O;
import water.Iced;
import water.Job;
import water.JobUpdatePostMap;
import water.Key;
import water.MRTask;
import water.fvec.ByteVec;
import water.fvec.Chunk;
import water.fvec.Frame;
import water.fvec.NewChunk;
import water.udf.CFuncRef;
import water.util.ArrayUtils;
import water.util.Log;
import water.util.RowDataUtils;

/**
 * A model that wraps an externally produced MOJO or POJO so that it can be used for scoring inside H2O.
 *
 * <p>The wrapped {@link GenModel} is held by a {@link GenModelSource}. The source remembers the key of the frame
 * holding the artifact's bytes and keeps the deserialized {@link GenModel} in a transient field. It rebuilds the
 * {@link GenModel} from those bytes whenever that field is empty, e.g. after this object has been serialized
 * and read back.
 */
public class GenericModel
extends Model<GenericModel, GenericModelParameters, GenericModelOutput>
implements Model.Contributions {

    /** Algorithm-specific behaviour overrides, keyed by the algorithm name stored in the MOJO. */
    private static final Map<String, ModelBehavior[]> MODEL_BEHAVIORS;

    static {
        Map<String, ModelBehavior[]> behaviors = new HashMap<>();
        behaviors.put("gam", new ModelBehavior[]{ModelBehavior.USE_MOJO_PREDICT});
        MODEL_BEHAVIORS = Collections.unmodifiableMap(behaviors);
    }

    /** Algorithm name taken from the MOJO, or {@code "pojo"} for POJO-backed models. */
    private final String _algoName;
    /** Where the wrapped GenModel comes from: the backing frame key plus a lazily (re)built instance. */
    private final GenModelSource<?> _genModelSource;

    /**
     * Creates a generic model backed by a MOJO.
     *
     * @param selfKey    key under which this model is stored
     * @param parms      model parameters; when the MOJO carries model attributes with parameters, those are
     *                   converted and stored in {@code _modelParameters}
     * @param output     initial output, replaced by one built from the MOJO's descriptor and attributes
     * @param mojoModel  the already deserialized MOJO
     * @param mojoSource key of the frame holding the raw MOJO bytes
     */
    public GenericModel(Key<GenericModel> selfKey, GenericModelParameters parms, GenericModelOutput output, MojoModel mojoModel, Key<Frame> mojoSource) {
        super(selfKey, parms, output);
        _algoName = mojoModel._algoName;
        _genModelSource = new MojoModelSource(mojoSource, mojoModel);
        _output = new GenericModelOutput(mojoModel._modelDescriptor, mojoModel._modelAttributes, mojoModel._reproducibilityInformation);
        if (mojoModel._modelAttributes != null && mojoModel._modelAttributes.getModelParameters() != null) {
            _parms._modelParameters = GenericModelParameters.convertParameters(mojoModel._modelAttributes.getModelParameters());
        }
    }

    /**
     * Creates a generic model backed by a compiled POJO.
     *
     * @param selfKey    key under which this model is stored
     * @param parms      model parameters
     * @param output     initial output, replaced by one built from a descriptor generated for the POJO
     * @param pojoModel  the instantiated POJO
     * @param pojoSource key of the frame holding the POJO source code
     */
    public GenericModel(Key<GenericModel> selfKey, GenericModelParameters parms, GenericModelOutput output, GenModel pojoModel, Key<Frame> pojoSource) {
        super(selfKey, parms, output);
        _algoName = "pojo";
        _genModelSource = new PojoModelSource(pojoSource, pojoModel);
        _output = new GenericModelOutput(ModelDescriptorBuilder.makeDescriptor(pojoModel));
    }

    /**
     * Deserializes a MOJO from the bytes stored in {@code mojoBytes}, using the MEMORY reader-backend caching strategy.
     *
     * @throws IllegalStateException if the bytes cannot be read as a MOJO (the {@link IOException} is chained)
     */
    private static MojoModel reconstructMojo(ByteVec mojoBytes) {
        // try-with-resources guarantees the ByteVec stream is closed even if reading fails.
        try (InputStream mojoStream = mojoBytes.openStream(null)) {
            MojoReaderBackend readerBackend = MojoReaderBackendFactory.createReaderBackend(mojoStream, MojoReaderBackendFactory.CachingStrategy.MEMORY);
            return ModelMojoReader.readFrom(readerBackend, true);
        } catch (IOException e) {
            throw new IllegalStateException("Unreachable MOJO file: " + mojoBytes._key, e);
        }
    }

    /**
     * Creates a metric builder matching the wrapped model's category.
     *
     * @param domain response domain (used by classification builders)
     * @throws IllegalStateException         if the category is {@code Unknown}
     * @throws UnsupportedOperationException for categories without metric support, unless
     *                                       {@code _disable_algo_check} is set
     */
    @Override
    public ModelMetrics.MetricBuilder makeMetricBuilder(String[] domain) {
        switch (_output.getModelCategory()) {
            case Unknown:
                throw new IllegalStateException("Model category is unknown");
            case Binomial:
                return new ModelMetricsBinomial.MetricBuilderBinomial(domain);
            case Multinomial:
                return new ModelMetricsMultinomial.MetricBuilderMultinomial(_output.nclasses(), domain, _parms._auc_type);
            case Ordinal:
                return new ModelMetricsOrdinal.MetricBuilderOrdinal(_output.nclasses(), domain);
            case Regression:
                return new ModelMetricsRegression.MetricBuilderRegression();
            case Clustering: {
                // Clustering metrics need the cluster count, which only a K-Means MOJO exposes here.
                GenModel genModel = genModel();
                if (genModel instanceof KMeansMojoModel) {
                    int numClusters = ((KMeansMojoModel) genModel).getNumClusters();
                    return new ModelMetricsClustering.MetricBuilderClustering(_output.nfeatures(), numClusters);
                }
                return unsupportedMetricsBuilder();
            }
            case AutoEncoder:
                return new ModelMetricsAutoEncoder.MetricBuilderAutoEncoder(_output.nfeatures());
            case DimReduction:
            case WordEmbedding:
                return unsupportedMetricsBuilder();
            case CoxPH:
                return new ModelMetricsRegressionCoxPH.MetricBuilderRegressionCoxPH("start", "stop", false, new String[0]);
            case AnomalyDetection:
                return new ModelMetricsAnomaly.MetricBuilderAnomaly();
            default:
                throw H2O.unimpl();
        }
    }

    /** Skips frame adaptation when the algorithm scores through the MOJO directly (see {@link ModelBehavior}). */
    @Override
    protected Frame adaptFrameForScore(Frame fr, boolean computeMetrics, List<Frame> tmpFrames) {
        if (hasBehavior(ModelBehavior.USE_MOJO_PREDICT)) {
            return fr;
        }
        return super.adaptFrameForScore(fr, computeMetrics, tmpFrames);
    }

    /** Routes scoring through {@link #predictScoreMojoImpl} for algorithms flagged with {@code USE_MOJO_PREDICT}. */
    @Override
    protected Model.PredictScoreResult predictScoreImpl(Frame fr, Frame adaptFrm, String destination_key, Job job, boolean computeMetrics, CFuncRef customMetricFunc) {
        if (hasBehavior(ModelBehavior.USE_MOJO_PREDICT)) {
            return predictScoreMojoImpl(fr, destination_key, job, computeMetrics);
        }
        return super.predictScoreImpl(fr, adaptFrm, destination_key, job, computeMetrics, customMetricFunc);
    }

    /**
     * Scores {@code fr} row by row through {@link EasyPredictModelWrapper} and optionally accumulates metrics.
     *
     * @param fr              frame to score (not adapted)
     * @param destination_key key of the resulting prediction frame
     * @param job             job used to detect stop requests; may be {@code null}
     * @param computeMetrics  whether to build a metric builder while scoring
     * @return predictions plus the (possibly {@code null}) metric builder
     */
    Model.PredictScoreResult predictScoreMojoImpl(Frame fr, String destination_key, Job<?> job, boolean computeMetrics) {
        GenModel model = genModel();
        assert model.isSupervised() : "MOJO Predict only works for supervised models";
        String[] names = model.getOutputNames();
        String[][] domains = model.getOutputDomains();
        // Outputs with a domain become categorical columns, the rest numeric. The codes 4/3 were inlined by the
        // decompiler (presumably Vec.T_CAT / Vec.T_NUM). The byte casts keep the conditional byte-typed; with int
        // literals alone the assignment into byte[] would not compile.
        byte[] types = new byte[domains.length];
        for (int i = 0; i < types.length; i++) {
            types[i] = domains[i] != null ? (byte) 4 : (byte) 3;
        }
        PredictScoreMojoTask task = new PredictScoreMojoTask(computeMetrics, job);
        Frame predictFr = task.doAll(types, fr).outputFrame(Key.make(destination_key), names, domains);
        return new Model.PredictScoreResult(this, task._mb, predictFr, predictFr);
    }

    /**
     * Handles model categories for which metrics cannot be computed.
     *
     * @return a no-op {@link MetricBuilderGeneric} when {@code _disable_algo_check} is set
     * @throws UnsupportedOperationException otherwise
     */
    private ModelMetrics.MetricBuilder unsupportedMetricsBuilder() {
        if (_parms._disable_algo_check) {
            Log.warn("Model category `" + _output._modelCategory + "` currently doesn't support calculating model metrics. Model metrics will not be available.");
            return new MetricBuilderGeneric(genModel().getPredsSize(_output._modelCategory));
        }
        throw new UnsupportedOperationException(_output._modelCategory + " is not supported.");
    }

    /** Delegates single-row scoring to the wrapped GenModel. */
    @Override
    protected double[] score0(double[] data, double[] preds) {
        return genModel().score0(data, preds);
    }

    /** Delegates single-row scoring with an offset; a zero offset uses the offset-free overload. */
    @Override
    protected double[] score0(double[] data, double[] preds, double offset) {
        if (offset == 0.0) {
            return score0(data, preds);
        }
        return genModel().score0(data, offset, preds);
    }

    /**
     * Describes how input frames are adapted for this model.
     *
     * <p>The categorical encoding comes from the GenModel. Weights, offset and fold columns come from the MOJO
     * descriptor (none for POJOs). The response column is reported only for supervised models.
     *
     * @throws UnsupportedOperationException if the GenModel uses a parametrized categorical encoding
     */
    @Override
    protected Model.AdaptFrameParameters makeAdaptFrameParameters() {
        final GenModel genModel = genModel();
        CategoricalEncoding encoding = genModel.getCategoricalEncoding();
        if (encoding.isParametrized()) {
            throw new UnsupportedOperationException("Models with categorical encoding '" + encoding + "' are not currently supported for predicting and/or calculating metrics.");
        }
        final Model.Parameters.CategoricalEncodingScheme encodingScheme = Model.Parameters.CategoricalEncodingScheme.fromGenModel(encoding);
        final ModelDescriptor descriptor = genModel instanceof MojoModel ? ((MojoModel) genModel)._modelDescriptor : null;
        return new Model.AdaptFrameParameters() {

            @Override
            public Model.Parameters.CategoricalEncodingScheme getCategoricalEncoding() {
                return encodingScheme;
            }

            @Override
            public String getWeightsColumn() {
                return descriptor != null ? descriptor.weightsColumn() : null;
            }

            @Override
            public String getOffsetColumn() {
                return descriptor != null ? descriptor.offsetColumn() : null;
            }

            @Override
            public String getFoldColumn() {
                return descriptor != null ? descriptor.foldColumn() : null;
            }

            @Override
            public String getResponseColumn() {
                return genModel.isSupervised() ? genModel.getResponseName() : null;
            }

            @Override
            public String getTreatmentColumn() {
                return null;
            }

            @Override
            public double missingColumnsType() {
                return Double.NaN;
            }

            @Override
            public int getMaxCategoricalLevels() {
                // NOTE(review): -1 presumably means "no limit" - confirm against Model.AdaptFrameParameters.
                return -1;
            }
        };
    }

    /** Output column names are taken straight from the GenModel. */
    @Override
    protected String[] makeScoringNames() {
        return genModel().getOutputNames();
    }

    @Override
    protected boolean needsPostProcess() {
        return false;
    }

    /**
     * Returns a writer that re-exports the original MOJO bytes.
     *
     * @throws IllegalStateException if this model is backed by a POJO
     */
    @Override
    public GenericModelMojoWriter getMojo() {
        if (_genModelSource instanceof MojoModelSource) {
            return new GenericModelMojoWriter(_genModelSource.backingByteVec());
        }
        throw new IllegalStateException("Cannot create a MOJO from a POJO");
    }

    /**
     * Returns the wrapped GenModel.
     *
     * <p>The instance stored in the DKV is preferred, because its transient GenModel may already be built and
     * reusing it avoids re-reading the backing bytes. If this model is not (yet) in the DKV, this instance's own
     * source is used instead of failing with a NullPointerException.
     */
    private GenModel genModel() {
        GenericModel self = (GenericModel) DKV.getGet(_key);
        GenModelSource<?> source = self != null ? self._genModelSource : _genModelSource;
        return source.get();
    }

    /** Forces the GenModel to be loaded before scoring starts. */
    @Override
    protected Model.BigScorePredict setupBigScorePredict(Model.BigScore bs) {
        GenModel genmodel = genModel();
        assert genmodel != null;
        return super.setupBigScorePredict(bs);
    }

    /**
     * Removes this model. When {@code _path} is set, the backing source frame is removed as well.
     */
    @Override
    protected Futures remove_impl(Futures fs, boolean cascade) {
        // NOTE(review): presumably a model imported from a path owns its source frame, while one built from a
        // user-provided frame does not - confirm before changing this condition.
        if (_parms._path != null) {
            _genModelSource.remove(fs, cascade);
        }
        return super.remove_impl(fs, cascade);
    }

    @Override
    public Frame scoreContributions(Frame frame, Key<Frame> destination_key) {
        return scoreContributions(frame, destination_key, null);
    }

    /**
     * Computes per-feature contributions for every row of {@code frame}.
     *
     * <p>Columns the GenModel does not know about are dropped from a shallow copy of the input frame before
     * scoring. The original frame is left untouched.
     *
     * @param frame           frame to explain
     * @param destination_key key of the resulting contributions frame
     * @param job             job to report progress to; may be {@code null}
     * @return a frame with one numeric column per contribution name
     */
    @Override
    public Frame scoreContributions(Frame frame, Key<Frame> destination_key, Job<Frame> job) {
        EasyPredictModelWrapper wrapper = makeWrapperWithContributions();
        GenModel model = wrapper.getModel();
        String[] columnNames = model.getOrigNames() != null ? model.getOrigNames() : model.getNames();
        Frame adaptFrm = new Frame(frame);
        adaptFrm.remove(ArrayUtils.difference(frame._names, columnNames));
        String[] outputNames = wrapper.getContributionNames();
        // (byte) 3 is the numeric column type code inlined by the decompiler (presumably Vec.T_NUM).
        return new GenericScoreContributionsTask(wrapper)
                .withPostMapAction(JobUpdatePostMap.forJob(job))
                .doAll(outputNames.length, (byte) 3, adaptFrm)
                .outputFrame(destination_key, outputNames, null);
    }

    /**
     * Builds an {@link EasyPredictModelWrapper} with contributions enabled and unknown categorical levels mapped
     * to NA.
     *
     * @throws RuntimeException wrapping the {@link IOException} thrown while configuring the wrapper
     */
    EasyPredictModelWrapper makeWrapperWithContributions() {
        EasyPredictModelWrapper.Config config;
        try {
            config = new EasyPredictModelWrapper.Config()
                    .setModel(genModel())
                    .setConvertUnknownCategoricalLevelsToNa(true)
                    .setEnableContributions(true);
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
        return new EasyPredictModelWrapper(config);
    }

    /** POJO class name: the simple name of the original algorithm's builder class followed by "Model". */
    @Override
    protected String toJavaModelClassName() {
        return ModelBuilder.make(_output._original_model_identifier, null, null).getClass().getSimpleName() + "Model";
    }

    @Override
    protected String toJavaAlgo() {
        return _output._original_model_identifier;
    }

    @Override
    protected String toJavaUUID() {
        return genModel().getUUID();
    }

    /**
     * Creates a POJO writer by delegating to the builder of the MOJO's original algorithm.
     *
     * @throws UnsupportedOperationException if the wrapped GenModel is not a MOJO
     */
    @Override
    protected PojoWriter makePojoWriter() {
        GenModel genModel = genModel();
        // Check the very instance we are about to cast, so the cast below cannot fail.
        if (!(genModel instanceof MojoModel)) {
            throw new UnsupportedOperationException("Only MOJO models can be converted to POJO.");
        }
        MojoModel mojoModel = (MojoModel) genModel;
        ModelBuilder builder = ModelBuilder.make(mojoModel._algoName, null, null);
        return builder.makePojoWriter(this, mojoModel);
    }

    /** A POJO can only be generated when the wrapped GenModel is a MOJO. */
    @Override
    public boolean havePojo() {
        return genModel() instanceof MojoModel;
    }

    /** Returns whether {@code behavior} is registered for this model's algorithm in {@code MODEL_BEHAVIORS}. */
    boolean hasBehavior(ModelBehavior behavior) {
        ModelBehavior[] behaviors = MODEL_BEHAVIORS.get(_algoName);
        return behaviors != null && ArrayUtils.find(behaviors, behavior) >= 0;
    }

    /** Optional per-algorithm deviations from the default generic-model scoring path. */
    static enum ModelBehavior {
        /** Score rows directly through the MOJO ({@link #predictScoreMojoImpl}) instead of adapting the frame. */
        USE_MOJO_PREDICT;
    }

    /** Computes per-row feature contributions with an {@link EasyPredictModelWrapper}. */
    private class GenericScoreContributionsTask
    extends MRTask<GenericScoreContributionsTask> {
        /** Not serialized; rebuilt in {@link #setupLocal()} when it is absent on a node. */
        private transient EasyPredictModelWrapper _wrapper;

        GenericScoreContributionsTask(EasyPredictModelWrapper wrapper) {
            _wrapper = wrapper;
        }

        @Override
        protected void setupLocal() {
            if (_wrapper == null) {
                _wrapper = GenericModel.this.makeWrapperWithContributions();
            }
        }

        @Override
        public void map(Chunk[] cs, NewChunk[] ncs) {
            try {
                predict(cs, ncs);
            } catch (PredictException e) {
                throw new RuntimeException(e);
            }
        }

        /** Appends the contributions of every row in the chunk set to {@code ncs}. */
        private void predict(Chunk[] cs, NewChunk[] ncs) throws PredictException {
            RowData rowData = new RowData();
            byte[] types = _fr.types();
            for (int row = 0; row < cs[0]._len; row++) {
                RowDataUtils.extractChunkRow(cs, _fr._names, types, row, rowData);
                float[] contributions = _wrapper.predictContributions(rowData);
                NewChunk.addNums(ncs, contributions);
            }
        }
    }

    /** GenModel source backed by POJO source code stored in a frame. */
    private static class PojoModelSource
    extends GenModelSource<PojoModelSource> {
        PojoModelSource(Key<Frame> pojoSource, GenModel pojoModel) {
            super(pojoSource, pojoModel);
        }

        /**
         * Recompiles the POJO from its source code.
         *
         * @throws RuntimeException wrapping the {@link IOException} if the source cannot be loaded
         */
        @Override
        GenModel reconstructGenModel(ByteVec bv) {
            Key<Frame> pojoKey = getSourceKey();
            try {
                return PojoLoader.loadPojoFromSourceCode(bv, pojoKey);
            } catch (IOException e) {
                // Chain the cause so the underlying I/O failure is not lost.
                throw new RuntimeException("Unable to load POJO source code from Vec " + pojoKey, e);
            }
        }
    }

    /** GenModel source backed by MOJO bytes stored in a frame. */
    private static class MojoModelSource
    extends GenModelSource<MojoModelSource> {
        MojoModelSource(Key<Frame> mojoSource, MojoModel mojoModel) {
            super(mojoSource, mojoModel);
        }

        @Override
        GenModel reconstructGenModel(ByteVec bv) {
            return GenericModel.reconstructMojo(bv);
        }
    }

    /**
     * Pairs the key of the frame holding a model artifact with the GenModel deserialized from it.
     *
     * <p>Only the key is serialized. The GenModel is transient and is rebuilt from the frame's bytes on demand.
     */
    private static abstract class GenModelSource<T extends Iced<T>>
    extends Iced<T> {
        private final Key<Frame> _source;
        private volatile transient GenModel _genModel;

        GenModelSource(Key<Frame> source, GenModel genModel) {
            _source = source;
            _genModel = genModel;
        }

        /**
         * Returns the GenModel, rebuilding it from the backing bytes if needed. Double-checked locking on the
         * volatile field ensures concurrent callers rebuild it at most once per instance.
         */
        GenModel get() {
            if (_genModel == null) {
                synchronized (this) {
                    if (_genModel == null) {
                        _genModel = reconstructGenModel(backingByteVec());
                    }
                }
            }
            assert _genModel != null;
            return _genModel;
        }

        /** Removes the backing frame, if it still exists. */
        void remove(Futures fs, boolean cascade) {
            Frame sourceFrame = _source.get();
            if (sourceFrame != null) {
                sourceFrame.remove(fs, cascade);
            }
        }

        /** Builds a GenModel from the raw bytes of the backing frame. */
        abstract GenModel reconstructGenModel(ByteVec bv);

        /**
         * Returns the vector holding the artifact bytes.
         *
         * @throws IllegalStateException if the backing frame no longer exists
         */
        ByteVec backingByteVec() {
            Frame sourceFrame = _source.get();
            if (sourceFrame == null) {
                throw new IllegalStateException("Source frame " + _source + " of the generic model is no longer available");
            }
            return (ByteVec) sourceFrame.anyVec();
        }

        Key<Frame> getSourceKey() {
            return _source;
        }
    }

    /** Placeholder metric builder for categories without metric support: it collects nothing and yields no metrics. */
    private static class MetricBuilderGeneric
    extends ModelMetrics.MetricBuilder<MetricBuilderGeneric> {
        private MetricBuilderGeneric(int predsSize) {
            _work = new double[predsSize];
        }

        @Override
        public double[] perRow(double[] ds, float[] yact, Model m) {
            return ds;
        }

        @Override
        public ModelMetrics makeModelMetrics(Model m, Frame f, Frame adaptedFrame, Frame preds) {
            return null;
        }
    }

    /** Scores rows through a thread-safe GenModel instance and optionally accumulates metrics. */
    private class PredictScoreMojoTask
    extends MRTask<PredictScoreMojoTask> {
        private final boolean _computeMetrics;
        private final Job<?> _j;
        /** Metric builder for this task's rows; {@code null} if metrics are off or the chunk was skipped. */
        private ModelMetrics.MetricBuilder _mb;

        public PredictScoreMojoTask(boolean computeMetrics, Job<?> job) {
            _computeMetrics = computeMetrics;
            _j = job;
        }

        @Override
        public void map(Chunk[] cs, NewChunk[] ncs) {
            if (isCancelled() || (_j != null && _j.stop_requested())) {
                return;
            }
            EasyPredictModelWrapper wrapper = makeWrapper();
            GenModel model = wrapper.getModel();
            String[] responseDomain = model.getDomainValues(model.getResponseName());
            Model.AdaptFrameParameters adaptFrameParameters = GenericModel.this.makeAdaptFrameParameters();
            _mb = _computeMetrics ? GenericModel.this.makeMetricBuilder(responseDomain) : null;
            try {
                predict(wrapper, adaptFrameParameters, responseDomain, cs, ncs);
            } catch (PredictException e) {
                throw new RuntimeException(e);
            }
        }

        /**
         * Predicts every row of the chunk set, appends the raw predictions to {@code ncs} and, when a metric
         * builder is present, feeds rows whose response value is available into it.
         */
        private void predict(EasyPredictModelWrapper wrapper, Model.AdaptFrameParameters adaptFrameParameters, String[] responseDomain, Chunk[] cs, NewChunk[] ncs) throws PredictException {
            byte[] types = _fr.types();
            String offsetColumn = adaptFrameParameters.getOffsetColumn();
            String weightsColumn = adaptFrameParameters.getWeightsColumn();
            String responseColumn = adaptFrameParameters.getResponseColumn();
            boolean isClassifier = wrapper.getModel().isClassifier();
            float[] yact = new float[1];
            for (int row = 0; row < cs[0]._len; row++) {
                RowData rowData = new RowData();
                RowDataUtils.extractChunkRow(cs, _fr._names, types, row, rowData);
                double offset = numericValueOrDefault(rowData, offsetColumn, 0.0);
                double[] result = wrapper.predictRaw(rowData, offset);
                for (int col = 0; col < ncs.length; col++) {
                    ncs[col].addNum(result[col]);
                }
                if (_mb == null) {
                    continue;
                }
                Object response = responseColumn != null && rowData.containsKey(responseColumn) ? rowData.get(responseColumn) : null;
                if (response == null) {
                    continue;
                }
                double weight = numericValueOrDefault(rowData, weightsColumn, 1.0);
                if (isClassifier) {
                    // Rows whose response level is not in the model's domain are excluded from metrics.
                    int idx = ArrayUtils.find(responseDomain, String.valueOf(response));
                    if (idx < 0) {
                        continue;
                    }
                    yact[0] = idx;
                } else {
                    yact[0] = ((Number) response).floatValue();
                }
                _mb.perRow(result, yact, weight, offset, GenericModel.this);
            }
        }

        /** Reads {@code column} from the row as a Double, or returns {@code defaultValue} if the column is absent. */
        private double numericValueOrDefault(RowData rowData, String column, double defaultValue) {
            return column != null && rowData.containsKey(column) ? (Double) rowData.get(column) : defaultValue;
        }

        /**
         * Merges metric builders. Either side may be {@code null}, e.g. when its map returned early after
         * cancellation. The non-null side wins, so no builder is passed {@code null} and no metrics are dropped.
         */
        @Override
        public void reduce(PredictScoreMojoTask other) {
            super.reduce(other);
            if (_mb == null) {
                _mb = other._mb;
            } else if (other._mb != null) {
                _mb.reduce(other._mb);
            }
        }

        /** Builds a wrapper around a thread-safe GenModel instance, mapping unknown categorical levels to NA. */
        EasyPredictModelWrapper makeWrapper() {
            EasyPredictModelWrapper.Config config = new EasyPredictModelWrapper.Config()
                    .setModel(GenericModel.this.genModel().internal_threadSafeInstance())
                    .setConvertUnknownCategoricalLevelsToNa(true);
            return new EasyPredictModelWrapper(config);
        }
    }
}

