/*
 * Decompiled with CFR 0.152.
 */
package zingg;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.spark.api.java.function.MapFunction;
import org.apache.spark.sql.Column;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Encoder;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.catalyst.encoders.RowEncoder;
import org.apache.spark.sql.functions;
import org.apache.spark.sql.types.StructType;
import org.apache.spark.storage.StorageLevel;
import zingg.ZinggBase;
import zingg.block.Block;
import zingg.block.Canopy;
import zingg.block.Tree;
import zingg.client.ZinggClientException;
import zingg.client.ZinggOptions;
import zingg.client.pipe.Pipe;
import zingg.client.util.Util;
import zingg.model.LabelModel;
import zingg.model.Model;
import zingg.preprocess.StopWords;
import zingg.scala.DFUtil;
import zingg.util.BlockingTreeUtil;
import zingg.util.DSUtil;
import zingg.util.ModelUtil;
import zingg.util.PipeUtil;

/**
 * Finds candidate record pairs for a user to label.
 *
 * <p>The job reads the input data, builds a blocking tree from a sample of it, pairs sampled
 * records that fall into the same block, and writes a small set of those pairs to the unmarked
 * training-data pipe. When enough labelled positive and negative pairs exist, a model is trained
 * and the pairs chosen by {@link #getUncertain(Dataset)} are written; otherwise a random selection
 * of the blocked pairs is written instead.
 */
public class TrainingDataFinder
extends ZinggBase {
    protected static String name = "zingg.TrainingDataFinder";
    public static final Log LOG = LogFactory.getLog(TrainingDataFinder.class);

    /** Label/cluster bookkeeping columns dropped from the self-joined labelled training pairs. */
    private static final String[] TRAINING_LABEL_COLUMNS = {"z_isMatch", "z_z_isMatch", "z_cluster", "z_z_cluster"};

    /** Threshold on labelled pair counts used to decide whether to top up / train a model. */
    private static final long MIN_LABELLED_PAIRS = 5L;

    /** Approximate number of blocked pairs written when no model can be trained. */
    private static final double RANDOM_PAIR_TARGET = 20.0;

    /** Maximum number of pairs taken from each predicted class in {@link #getUncertain(Dataset)}. */
    private static final int UNCERTAIN_PAIRS_PER_CLASS = 10;

    public TrainingDataFinder() {
        this.setZinggOptions(ZinggOptions.FIND_TRAINING_DATA);
    }

    /**
     * Loads previously labelled training data.
     *
     * @return the labelled records, or {@code null} when {@link DSUtil#getTraining} finds none
     *     (the caller in {@link #execute()} checks for {@code null})
     * @throws ZinggClientException propagated from {@link DSUtil#getTraining}
     */
    public Dataset<Row> getTraining() throws ZinggClientException {
        return DSUtil.getTraining(this.spark, this.args);
    }

    /**
     * Runs the training-data-finding job end to end and writes the selected pairs via
     * {@link #writeUncertain(Dataset, Dataset)}.
     *
     * @throws ZinggClientException if any step fails; a {@code ZinggClientException} raised inside
     *     the job is rethrown unchanged, any other exception is logged and wrapped
     */
    @Override
    public void execute() throws ZinggClientException {
        try {
            Dataset<Row> data = PipeUtil.read(this.spark, true, true, this.args.getData());
            LOG.warn("Read input data " + data.count());

            Dataset<Row> posPairs = null;
            Dataset<Row> negPairs = null;
            Dataset<Row> trFile = this.getTraining();
            if (trFile != null) {
                trFile = StopWords.preprocessForStopWords(this.spark, this.args, trFile);
                // Pair up labelled records that share a cluster, then split by their match label.
                Dataset<Row> trPairs = DSUtil.joinWithItself(trFile, "z_cluster", true);
                posPairs = trPairs.filter(trPairs.col("z_isMatch").equalTo(1)).drop(TRAINING_LABEL_COLUMNS);
                negPairs = trPairs.filter(trPairs.col("z_isMatch").equalTo(0)).drop(TRAINING_LABEL_COLUMNS);
                LOG.warn("Read training samples " + posPairs.count() + " neg " + negPairs.count());
            }

            // Too few labelled positives: add generated positive pairs.
            // NOTE(review): this tops up at <= MIN_LABELLED_PAIRS while the model branch below needs
            // >= MIN_LABELLED_PAIRS, so exactly that many labelled positives hits both; confirm intended.
            // NOTE(review): union() is positional; assumes both inputs share column order — confirm.
            if (posPairs == null || posPairs.count() <= MIN_LABELLED_PAIRS) {
                Dataset<Row> posSamplesOriginal = this.getPositiveSamples(data);
                Dataset<Row> posSamples = StopWords.preprocessForStopWords(this.spark, this.args, posSamplesOriginal);
                posPairs = posPairs != null ? posPairs.union(posSamples) : posSamples;
            }
            posPairs = posPairs.cache();
            if (negPairs != null) {
                negPairs = negPairs.cache();
            }

            Dataset<Row> sampleOriginal = data
                    .sample(false, (double) this.args.getLabelDataSampleSize())
                    .repartition(this.args.getNumPartitions())
                    .persist(StorageLevel.MEMORY_ONLY());
            sampleOriginal = DSUtil.getFieldDefColumnsDS(sampleOriginal, this.args, true);
            LOG.info("Preprocessing DS for stopWords");
            Dataset<Row> sample = StopWords.preprocessForStopWords(this.spark, this.args, sampleOriginal);

            Tree<Canopy> tree = BlockingTreeUtil.createBlockingTree(sample, posPairs, 1.0, -1L, this.args, this.hashFunctions);
            // The casts select the Java MapFunction overload of Dataset.map; the result is a Row dataset
            // whose schema is the sample schema with the hash column appended.
            @SuppressWarnings({"unchecked", "rawtypes"})
            Dataset<Row> blocked = sample.map((MapFunction) new Block.BlockFunction(tree),
                    (Encoder) RowEncoder.apply((StructType) Block.appendHashCol(sample.schema())));
            blocked = blocked.repartition(this.args.getNumPartitions(), new Column[]{blocked.col("z_hash")}).cache();
            // Candidate pairs: sampled records that landed in the same block.
            Dataset<Row> blocks = DSUtil.joinWithItself(blocked, "z_hash", true).cache();

            if (negPairs != null && posPairs.count() >= MIN_LABELLED_PAIRS && negPairs.count() >= MIN_LABELLED_PAIRS) {
                this.writeModelSelectedPairs(posPairs, negPairs, blocks, sampleOriginal);
            } else {
                this.writeRandomPairs(blocks, sampleOriginal);
            }
        }
        catch (ZinggClientException e) {
            // Already the type callers expect; rethrow without losing its stack trace.
            throw e;
        }
        catch (Exception e) {
            LOG.error("Error while finding training data", e);
            String message = e.getMessage() != null ? e.getMessage() : e.toString();
            throw new ZinggClientException(message);
        }
    }

    /**
     * Trains a {@link LabelModel} on the labelled pairs, scores the blocked pairs with it and
     * writes the pairs chosen by {@link #getUncertain(Dataset)}.
     */
    private void writeModelSelectedPairs(Dataset<Row> posPairs, Dataset<Row> negPairs,
            Dataset<Row> blocks, Dataset<Row> sampleOriginal) throws Exception {
        if (LOG.isDebugEnabled()) {
            LOG.debug("num blocks " + blocks.count());
        }
        Model model = ModelUtil.createModel(posPairs, negPairs, new LabelModel(this.featurers), this.spark);
        Dataset<Row> dupes = model.predict(blocks);
        if (LOG.isDebugEnabled()) {
            LOG.debug("num dupes " + dupes.count());
        }
        LOG.info("Writing uncertain pairs");
        dupes = dupes.persist(StorageLevel.MEMORY_ONLY());
        Dataset<Row> uncertain = this.getUncertain(dupes);
        this.writeUncertain(uncertain, sampleOriginal);
    }

    /**
     * Writes roughly {@link #RANDOM_PAIR_TARGET} randomly sampled blocked pairs, tagged with
     * {@code z_prediction = -1.0} and {@code z_score = 0.0}. Used when there are not enough
     * labelled pairs to train a model.
     */
    private void writeRandomPairs(Dataset<Row> blocks, Dataset<Row> sampleOriginal) throws ZinggClientException {
        LOG.info("Writing uncertain pairs when either positive or negative samples not provided ");
        long numBlockedPairs = blocks.count();
        // Dataset.sample without replacement only accepts a fraction in [0, 1]. A plain
        // RANDOM_PAIR_TARGET / count would exceed 1 (or be Infinity for 0) when fewer pairs were
        // blocked, so cap it: in that case every available pair is kept.
        double fraction = numBlockedPairs == 0L ? 0.0 : Math.min(1.0, RANDOM_PAIR_TARGET / numBlockedPairs);
        Dataset<Row> randomPairs = blocks.sample(false, fraction)
                .withColumn("z_prediction", functions.lit(-1.0))
                .withColumn("z_score", functions.lit(0.0));
        this.writeUncertain(randomPairs, sampleOriginal);
    }

    /**
     * Assigns cluster ids to the pairs, aligns and post-processes them against the sampled
     * records, orders them by {@code z_cluster} and writes them to {@link #getUnmarkedLocation()}.
     *
     * @param dupesActual pairs to write
     * @param sampleOrginal sampled original records passed to {@link DSUtil#postprocess}
     * @throws ZinggClientException propagated from {@link PipeUtil#write}
     */
    public void writeUncertain(Dataset<Row> dupesActual, Dataset<Row> sampleOrginal) throws ZinggClientException {
        dupesActual = DFUtil.addClusterRowNumber(dupesActual, this.spark);
        dupesActual = Util.addUniqueCol(dupesActual, "z_cluster");
        Dataset<Row> dupes1 = DSUtil.alignDupes(dupesActual, this.args);
        dupes1 = DSUtil.postprocess(dupes1, sampleOrginal);
        Dataset<Row> dupes2 = dupes1.orderBy("z_cluster", new String[0]);
        LOG.debug("uncertain output schema is " + dupes2.schema());
        PipeUtil.write(dupes2, this.args, this.ctx, this.getUnmarkedLocation());
    }

    /** Returns the pipe that candidate pairs are written to for labelling. */
    public Pipe getUnmarkedLocation() {
        return PipeUtil.getTrainingDataUnmarkedPipe(this.args);
    }

    /**
     * Selects the scored pairs to show a user: up to {@link #UNCERTAIN_PAIRS_PER_CLASS} pairs
     * predicted as matches ({@code z_prediction == 1.0}) with the lowest {@code z_score}, plus
     * up to the same number predicted as non-matches ({@code z_prediction == 0.0}) with the
     * highest {@code z_score}. Presumably these are the pairs the model is least sure about,
     * assuming {@code z_score} is a match score — confirm.
     *
     * @param dupes scored pairs with {@code z_prediction} and {@code z_score} columns
     * @return the union of the two selections
     */
    public Dataset<Row> getUncertain(Dataset<Row> dupes) {
        Dataset<Row> pos = dupes.filter(dupes.col("z_prediction").equalTo(1.0));
        pos = pos.sort(new Column[]{functions.asc("z_score")}).cache();
        if (LOG.isDebugEnabled()) {
            LOG.debug("num pos " + pos.count());
        }
        pos = pos.limit(UNCERTAIN_PAIRS_PER_CLASS);

        Dataset<Row> neg = dupes.filter(dupes.col("z_prediction").equalTo(0.0));
        neg = neg.sort(new Column[]{functions.desc("z_score")}).cache();
        if (LOG.isDebugEnabled()) {
            LOG.debug("num neg " + neg.count());
        }
        neg = neg.limit(UNCERTAIN_PAIRS_PER_CLASS);
        return pos.union(neg);
    }

    /**
     * Builds generated positive pairs: samples the data at the configured label-data sample
     * fraction, keeps the field-definition columns and self-joins the sample on {@code z_zid}.
     * Presumably {@code z_zid} is a unique row id, so each record is paired with itself — confirm.
     *
     * @param data full input data
     * @return the self-joined pairs
     * @throws Exception declared for callers; not thrown explicitly here
     */
    public Dataset<Row> getPositiveSamples(Dataset<Row> data) throws Exception {
        if (LOG.isDebugEnabled()) {
            long count = data.count();
            LOG.debug("Total count is " + count);
            LOG.debug("Label data sample size is " + this.args.getLabelDataSampleSize());
        }
        Dataset<Row> posSample = data.sample(false, (double) this.args.getLabelDataSampleSize());
        posSample = DSUtil.getFieldDefColumnsDS(posSample, this.args, true);
        if (LOG.isDebugEnabled()) {
            LOG.debug("Sampled " + posSample.count());
        }
        posSample = posSample.cache();
        Dataset<Row> posPairs = DSUtil.joinWithItself(posSample, "z_zid", false);
        LOG.info("Created positive sample pairs ");
        if (LOG.isDebugEnabled()) {
            LOG.debug("Pos Sample pairs count " + posPairs.count());
        }
        return posPairs;
    }
}

