/*
 * Decompiled with CFR 0.152.
 */
package zingg.model;

import java.io.IOException;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.spark.ml.Estimator;
import org.apache.spark.ml.Pipeline;
import org.apache.spark.ml.PipelineStage;
import org.apache.spark.ml.Transformer;
import org.apache.spark.ml.classification.LogisticRegression;
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator;
import org.apache.spark.ml.evaluation.Evaluator;
import org.apache.spark.ml.feature.PolynomialExpansion;
import org.apache.spark.ml.feature.VectorAssembler;
import org.apache.spark.ml.param.ParamMap;
import org.apache.spark.ml.tuning.CrossValidator;
import org.apache.spark.ml.tuning.CrossValidatorModel;
import org.apache.spark.ml.tuning.ParamGridBuilder;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import zingg.client.FieldDefinition;
import zingg.feature.Feature;
import zingg.model.VectorValueExtractor;
import zingg.similarity.function.BaseSimilarityFunction;

/**
 * Match classifier built from per-field similarity functions followed by a
 * Spark ML pipeline (vector assembly, degree-3 polynomial expansion and
 * logistic regression) that is tuned with cross validation.
 *
 * <p>Typical use: construct, {@link #register(SparkSession)}, then either
 * {@link #fit(Dataset, Dataset)} or {@link #load(String)}, and finally
 * {@link #predict(Dataset)}.
 */
public class Model
implements Serializable {
    public static final Log LOG = LogFactory.getLog(Model.class);
    public static final Log DbLOG = LogFactory.getLog("WEB");

    // Column names shared by the pipeline stages; values are unchanged from the
    // literals that were previously repeated throughout the class.
    private static final String SIM_COL_PREFIX = "z_sim";
    private static final String FEATURE_VECTOR_COL = "z_featurevector";
    private static final String FEATURE_COL = "z_feature";
    private static final String LABEL_COL = "z_isMatch";
    private static final String PROBABILITY_COL = "z_probability";
    private static final String PREDICTION_COL = "z_prediction";
    private static final String SCORE_COL = "z_score";
    private static final String RAW_PREDICTION_COL = "rawPrediction";

    /** Stages of the ML pipeline, in execution order. */
    List<PipelineStage> pipelineStage;
    /** Similarity functions applied to the data before the ML pipeline runs. */
    List<BaseSimilarityFunction> featureCreators = new ArrayList<BaseSimilarityFunction>();
    /** Logistic regression estimator that is the last stage of the pipeline. */
    LogisticRegression lr;
    /** Fitted (or loaded) model; {@code null} until {@link #fit} or {@link #load} is called. */
    Transformer transformer;
    BinaryClassificationEvaluator binaryClassificationEvaluator;
    /** Helper columns that {@link #predict(Dataset, boolean)} drops when asked to. */
    List<String> columnsAdded;
    /** Reads {@code z_probability} and writes {@code z_score}. */
    VectorValueExtractor vve;

    /**
     * Wires the similarity functions of every field into feature columns and
     * builds the (unfitted) pipeline stages.
     *
     * <p>Each similarity function reads the field's column and writes to
     * {@code z_sim<n>}, where {@code n} is a counter running across all fields
     * in the iteration order of {@code f}.
     *
     * @param f similarity features keyed by the field definition they apply to
     */
    public Model(Map<FieldDefinition, Feature> f) {
        this.pipelineStage = new ArrayList<PipelineStage>();
        this.columnsAdded = new ArrayList<String>();
        int count = 0;
        for (Map.Entry<FieldDefinition, Feature> entry : f.entrySet()) {
            FieldDefinition fd = entry.getKey();
            Feature fea = entry.getValue();
            // Feature's declaration is not visible from this file, so the elements
            // are taken as Object and cast instead of relying on a raw-typed list.
            for (Object candidate : fea.getSimFunctions()) {
                BaseSimilarityFunction sf = (BaseSimilarityFunction)candidate;
                String outputCol = SIM_COL_PREFIX + count;
                sf.setInputCol(fd.fieldName);
                sf.setOutputCol(outputCol);
                this.columnsAdded.add(outputCol);
                this.featureCreators.add(sf);
                ++count;
            }
        }

        // The assembler must see only the similarity columns, so its inputs are
        // captured before the remaining helper columns are appended.
        VectorAssembler assembler = new VectorAssembler();
        assembler.setInputCols(this.columnsAdded.toArray(new String[this.columnsAdded.size()]));
        assembler.setOutputCol(FEATURE_VECTOR_COL);
        this.columnsAdded.add(FEATURE_VECTOR_COL);
        this.pipelineStage.add(assembler);

        PolynomialExpansion polyExpansion = new PolynomialExpansion();
        polyExpansion.setInputCol(FEATURE_VECTOR_COL);
        polyExpansion.setOutputCol(FEATURE_COL);
        polyExpansion.setDegree(3);
        this.columnsAdded.add(FEATURE_COL);
        this.pipelineStage.add(polyExpansion);

        this.lr = new LogisticRegression();
        this.lr.setMaxIter(100);
        this.lr.setFeaturesCol(FEATURE_COL);
        this.lr.setLabelCol(LABEL_COL);
        this.lr.setProbabilityCol(PROBABILITY_COL);
        this.lr.setPredictionCol(PREDICTION_COL);
        this.lr.setFitIntercept(true);
        this.pipelineStage.add(this.lr);

        this.vve = new VectorValueExtractor();
        this.vve.setInputCol(PROBABILITY_COL);
        this.vve.setOutputCol(SCORE_COL);
        this.columnsAdded.add(PROBABILITY_COL);
        this.columnsAdded.add(RAW_PREDICTION_COL);
    }

    /**
     * Calls {@code register(spark)} on every similarity function and on the
     * score extractor.
     *
     * @param spark session the functions are registered with
     */
    public void register(SparkSession spark) {
        if (this.featureCreators != null) {
            for (BaseSimilarityFunction bsf : this.featureCreators) {
                bsf.register(spark);
            }
        }
        this.vve.register(spark);
    }

    /**
     * Builds an increasing grid of values starting at {@code begin} and
     * containing every generated value that is {@code <= end}.
     *
     * <p>Values are produced by repeated floating-point addition or
     * multiplication, so they carry the usual accumulated rounding error.
     *
     * @param begin first value of the grid
     * @param end inclusive upper bound
     * @param jump step: a factor when {@code isMultiple} is true, otherwise an increment
     * @param isMultiple {@code true} for a geometric grid, {@code false} for an arithmetic one
     * @return the grid values; empty when {@code begin > end} or either bound is NaN
     * @throws IllegalArgumentException if applying {@code jump} does not strictly
     *         increase the current value (for example a zero or negative increment,
     *         a factor {@code <= 1}, a non-positive start with a factor, or a step
     *         below floating-point resolution); such input used to loop forever
     */
    public static double[] getGrid(double begin, double end, double jump, boolean isMultiple) {
        List<Double> values = new ArrayList<Double>();
        double alpha = begin;
        while (alpha <= end) {
            values.add(alpha);
            double next = isMultiple ? alpha * jump : alpha + jump;
            // A NaN result falls through and ends the loop, exactly as before.
            if (next <= alpha) {
                throw new IllegalArgumentException("Grid step " + jump
                        + (isMultiple ? " (multiplicative)" : " (additive)")
                        + " does not advance from " + alpha + " towards " + end);
            }
            alpha = next;
        }
        double[] grid = new double[values.size()];
        for (int i = 0; i < grid.length; ++i) {
            grid[i] = values.get(i);
        }
        return grid;
    }

    /**
     * Trains the pipeline on the union of the positive and negative samples.
     *
     * <p>The similarity features are computed first, then a 2-fold cross
     * validation searches over the regularisation parameter (1e-4 up to 1,
     * multiplied by 10 each step) and the decision threshold (0.4 up to 0.55
     * in steps of 0.05). The resulting model is kept for {@link #predict}.
     *
     * <p>When debug logging is enabled the training input is also written as
     * CSV under {@code /tmp/input/<currentTimeMillis>}.
     *
     * @param pos labelled matching pairs
     * @param neg labelled non-matching pairs
     */
    public void fit(Dataset<Row> pos, Dataset<Row> neg) {
        Dataset<Row> input = this.transform(pos.union(neg)).coalesce(1).cache();
        try {
            if (LOG.isDebugEnabled()) {
                input.write().csv("/tmp/input/" + System.currentTimeMillis());
            }
            Pipeline pipeline = new Pipeline();
            pipeline.setStages(this.pipelineStage.toArray(new PipelineStage[this.pipelineStage.size()]));
            LOG.debug("Pipeline is " + pipeline);

            ParamMap[] paramGrid = new ParamGridBuilder()
                    .addGrid(this.lr.regParam(), Model.getGrid(1.0E-4, 1.0, 10.0, true))
                    .addGrid(this.lr.threshold(), Model.getGrid(0.4, 0.55, 0.05, false))
                    .build();

            this.binaryClassificationEvaluator = new BinaryClassificationEvaluator();
            this.binaryClassificationEvaluator.setLabelCol(LABEL_COL);

            CrossValidator cv = new CrossValidator();
            cv.setEstimator(pipeline);
            cv.setEvaluator(this.binaryClassificationEvaluator);
            cv.setEstimatorParamMaps(paramGrid);
            cv.setNumFolds(2);

            this.transformer = cv.fit(input);
            LOG.debug("threshold after fitting is " + this.lr.getThreshold());
        } finally {
            // Fitting is eager, so the cached training data is no longer needed;
            // release it instead of leaving it pinned in executor storage.
            input.unpersist();
        }
    }

    /**
     * Replaces the current model with a cross-validator model read from disk.
     *
     * @param path location previously written by {@link #save(String)}
     */
    public void load(String path) {
        this.transformer = CrossValidatorModel.load(path);
    }

    /**
     * Scores {@code data} and drops the helper columns added along the way.
     *
     * @param data records to score
     * @return the scored records
     * @throws IllegalStateException if the model has been neither fitted nor loaded
     */
    public Dataset<Row> predict(Dataset<Row> data) {
        return this.predict(data, true);
    }

    /**
     * Computes the similarity features, applies the fitted model and then the
     * score extractor.
     *
     * @param data records to score
     * @param isDrop when {@code true}, the columns listed in {@code columnsAdded}
     *        are removed from the result
     * @return the scored records
     * @throws IllegalStateException if the model has been neither fitted nor loaded
     */
    public Dataset<Row> predict(Dataset<Row> data, boolean isDrop) {
        this.requireFitted();
        LOG.info("threshold while predicting is " + this.lr.getThreshold());
        Dataset<Row> predictWithFeatures = this.transformer.transform(this.transform(data));
        LOG.debug(predictWithFeatures.schema());
        predictWithFeatures = this.vve.transform(predictWithFeatures);
        LOG.debug("Original schema is " + predictWithFeatures.schema());
        if (!isDrop) {
            LOG.debug("Return schema is " + predictWithFeatures.schema());
            return predictWithFeatures;
        }
        Dataset<Row> returnDS = predictWithFeatures.drop(this.columnsAdded.toArray(new String[this.columnsAdded.size()]));
        LOG.debug("Return schema after dropping additional columns is " + returnDS.schema());
        return returnDS;
    }

    /**
     * Writes the fitted cross-validator model to {@code path}, overwriting any
     * existing content.
     *
     * @param path destination of the model
     * @throws IOException if the model cannot be written
     * @throws IllegalStateException if the model has been neither fitted nor loaded
     */
    public void save(String path) throws IOException {
        this.requireFitted();
        ((CrossValidatorModel)this.transformer).write().overwrite().save(path);
    }

    /**
     * Applies every similarity function in turn, each one transforming the
     * output of the previous one.
     *
     * @param input records to compute similarity features for
     * @return the result of the last similarity function, or {@code input}
     *         itself when there are none
     */
    public Dataset<Row> transform(Dataset<Row> input) {
        Dataset<Row> current = input;
        for (BaseSimilarityFunction bsf : this.featureCreators) {
            current = bsf.transform(current);
        }
        return current;
    }

    /** Fails fast with a descriptive error when no model is available yet. */
    private void requireFitted() {
        if (this.transformer == null) {
            throw new IllegalStateException("Model has not been fitted or loaded; call fit() or load() first");
        }
    }
}

