/*
 * Decompiled with CFR 0.152.
 */
package zingg;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import zingg.ZinggBase;
import zingg.block.Canopy;
import zingg.block.Tree;
import zingg.client.ZinggClientException;
import zingg.client.ZinggOptions;
import zingg.model.Model;
import zingg.preprocess.StopWords;
import zingg.util.Analytics;
import zingg.util.BlockingTreeUtil;
import zingg.util.DSUtil;
import zingg.util.Metric;
import zingg.util.ModelUtil;
import zingg.util.PipeUtil;

/**
 * Training phase: reads the user's labelled pairs, learns blocking (indexing) rules and a
 * similarity model from them, and persists both.
 */
public class Trainer
extends ZinggBase {
    protected static String name = "zingg.Trainer";
    public static final Log LOG = LogFactory.getLog(Trainer.class);

    /**
     * Minimum number of labelled matches, and separately of labelled non-matches, required
     * before training is attempted.
     */
    protected static final long MIN_TRAINING_PAIRS = 5L;

    public Trainer() {
        this.setZinggOptions(ZinggOptions.TRAIN);
    }

    /**
     * Runs the training phase.
     *
     * <p>Steps, in order:
     * <ol>
     *   <li>load and stop-word-preprocess the labelled training data;</li>
     *   <li>pair records within each {@code z_cluster};</li>
     *   <li>split the pairs into matches ({@code z_isMatch == 1}) and non-matches
     *       ({@code z_isMatch == 0}) and validate them;</li>
     *   <li>learn and write the blocking tree from a sample of the input data;</li>
     *   <li>train and save the similarity model;</li>
     *   <li>report training-data metrics.</li>
     * </ol>
     *
     * @throws ZinggClientException if training data is insufficient, or if any step fails
     *     (the original failure is attached as the cause)
     */
    @Override
    public void execute() throws ZinggClientException {
        // Declared outside the try so the cached dataset can be released in finally.
        Dataset<Row> tra = null;
        try {
            LOG.info("Reading inputs for training phase ...");
            LOG.info("Initializing learning similarity rules");
            Dataset<Row> traOriginal = DSUtil.getTraining(this.spark, this.args);
            tra = StopWords.preprocessForStopWords(this.spark, this.args, traOriginal);
            tra = DSUtil.joinWithItself(tra, "z_cluster", true);
            // Cached because both the positive and negative subsets below are derived from it.
            tra = tra.cache();
            Dataset<Row> positives = tra.filter(tra.col("z_isMatch").equalTo(1));
            Dataset<Row> negatives = tra.filter(tra.col("z_isMatch").equalTo(0));
            this.verifyTraining(positives, negatives);

            Dataset<Row> testDataOriginal = PipeUtil.read(this.spark, true, this.args.getNumPartitions(), false, this.args.getData());
            Dataset<Row> testData = StopWords.preprocessForStopWords(this.spark, this.args, testDataOriginal);
            Tree<Canopy> blockingTree = BlockingTreeUtil.createBlockingTreeFromSample(testData, positives, 0.5, -1L, this.args, this.hashFunctions);
            if (blockingTree == null || blockingTree.getSubTrees() == null) {
                LOG.warn("Seems like no indexing rules have been learnt");
            }
            // NOTE(review): a null tree is still handed to writeBlockingTree here — confirm
            // that BlockingTreeUtil tolerates it, or fail fast instead.
            BlockingTreeUtil.writeBlockingTree(this.spark, this.ctx, blockingTree, this.args);
            LOG.info("Learnt indexing rules and saved output at " + this.args.getZinggDir());

            Model model = ModelUtil.createModel(positives, negatives, new Model(this.featurers), this.spark);
            model.save(this.args.getModel());
            LOG.info("Learnt similarity rules and saved output at " + this.args.getZinggDir());

            Analytics.track("trainingDataMatches", Metric.approxCount(positives), this.args.getCollectMetrics());
            Analytics.track("trainingDataNonmatches", Metric.approxCount(negatives), this.args.getCollectMetrics());
            LOG.info("Finished Learning phase");
        }
        catch (ZinggClientException e) {
            // Already a user-facing error (e.g. from verifyTraining); do not re-wrap it.
            throw e;
        }
        catch (Exception e) {
            // Log through the class logger (not stderr) and keep the original as the cause.
            LOG.error("Training phase failed", e);
            String message = e.getMessage() != null ? e.getMessage() : e.toString();
            ZinggClientException wrapped = new ZinggClientException(message);
            wrapped.initCause(e);
            throw wrapped;
        }
        finally {
            // Release the cached training pairs; otherwise they stay in storage for the
            // lifetime of the Spark session.
            if (tra != null) {
                tra.unpersist();
            }
        }
    }

    /**
     * Checks that enough labelled data exists to train.
     *
     * <p>Both counts are logged before the threshold check.
     *
     * @param positives labelled matching pairs
     * @param negatives labelled non-matching pairs
     * @throws ZinggClientException if either dataset is {@code null}, or if either contains
     *     fewer than {@link #MIN_TRAINING_PAIRS} rows
     */
    public void verifyTraining(Dataset<Row> positives, Dataset<Row> negatives) throws ZinggClientException {
        if (positives == null) {
            throw new ZinggClientException("Unable to train as insufficient positive training data found. ");
        }
        if (negatives == null) {
            throw new ZinggClientException("Unable to train as insufficient negative training data found. ");
        }
        // Each count() triggers a Spark action.
        long posCount = positives.count();
        LOG.warn("Training on positive pairs - " + posCount);
        long negCount = negatives.count();
        LOG.warn("Training on negative pairs - " + negCount);
        if (posCount < MIN_TRAINING_PAIRS || negCount < MIN_TRAINING_PAIRS) {
            throw new ZinggClientException("Unable to train as insufficient training data found. Training data has " + posCount + " matches and " + negCount + " non matches. Please run findTrainingData and label till you have sufficient labelled data to build the models");
        }
    }
}

