/*
 * Decompiled with CFR 0.152.
 */
package zingg;

import java.util.ArrayList;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.spark.api.java.function.MapFunction;
import org.apache.spark.sql.Column;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Encoder;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.catalyst.encoders.RowEncoder;
import org.apache.spark.sql.functions;
import org.apache.spark.sql.types.StructType;
import scala.collection.Iterator;
import scala.collection.JavaConverters;
import zingg.ZinggBase;
import zingg.block.Block;
import zingg.block.Canopy;
import zingg.block.Tree;
import zingg.client.ZinggClientException;
import zingg.client.ZinggOptions;
import zingg.model.Model;
import zingg.preprocess.StopWords;
import zingg.util.Analytics;
import zingg.util.BlockingTreeUtil;
import zingg.util.DSUtil;
import zingg.util.GraphUtil;
import zingg.util.PipeUtil;

/**
 * Zingg phase that reads the configured input data, blocks it on a hash
 * column, builds candidate pairs within each block, scores the pairs with a
 * previously saved {@link Model}, and writes the resulting graph together
 * with per-record min/max scores to the configured output.
 *
 * <p>All datasets handled here use the {@code z_}-prefixed bookkeeping
 * columns visible in the code: {@code z_zid}, {@code z_z_zid},
 * {@code z_hash}, {@code z_cluster}, {@code z_score}, {@code z_prediction}.
 */
public class Matcher
extends ZinggBase {
    protected static String name = "zingg.Matcher";
    public static final Log LOG = LogFactory.getLog(Matcher.class);

    /** Creates a matcher and registers it under {@link ZinggOptions#MATCH}. */
    public Matcher() {
        this.setZinggOptions(ZinggOptions.MATCH);
    }

    /**
     * Reads the input pipes configured in {@code args.getData()}.
     *
     * @return the dataset returned by {@code PipeUtil.read}
     * @throws ZinggClientException if reading fails
     */
    protected Dataset<Row> getTestData() throws ZinggClientException {
        return PipeUtil.read(this.spark, true, this.args.getNumPartitions(), true, this.args.getData());
    }

    /**
     * Applies the stored blocking tree to every row, adding the hash column,
     * and repartitions the result on {@code z_hash}.
     *
     * @param testData rows to block
     * @return {@code testData} with the schema extended by
     *         {@code Block.appendHashCol}, repartitioned by {@code z_hash}
     * @throws Exception if the blocking tree cannot be read or applied
     * @throws ZinggClientException propagated from the blocking utilities
     */
    @SuppressWarnings({"unchecked", "rawtypes"})
    protected Dataset<Row> getBlocked(Dataset<Row> testData) throws Exception, ZinggClientException {
        LOG.debug("Blocking model file location is " + this.args.getBlockFile());
        Tree<Canopy> tree = BlockingTreeUtil.readBlockingTree(this.spark, this.args);
        // Raw casts are kept: the generic signature of Block.BlockFunction is not visible here.
        Dataset<Row> blocked = testData.map(
                (MapFunction)new Block.BlockFunction(tree),
                (Encoder)RowEncoder.apply((StructType)Block.appendHashCol(testData.schema())));
        return blocked.repartition(this.args.getNumPartitions(), blocked.col("z_hash"));
    }

    /**
     * Builds candidate pairs by joining the blocked dataset with itself on
     * {@code z_hash} and caches the result.
     *
     * @param blocked dataset containing a {@code z_hash} column
     * @return the cached self-join produced by {@code DSUtil.joinWithItself}
     * @throws Exception if the join fails
     */
    protected Dataset<Row> getBlocks(Dataset<Row> blocked) throws Exception {
        return DSUtil.joinWithItself(blocked, "z_hash", true).cache();
    }

    /**
     * Builds candidate pairs from a slim (id, hash) dataset and then attaches
     * the full record values for both sides of each pair.
     *
     * <p>Pairs are first formed on ids only, keeping those with
     * {@code z_zid > z_z_zid} so each unordered pair appears once and
     * self-pairs are dropped; the record values from {@code bAll} are joined
     * in afterwards.
     *
     * @param blocked dataset with {@code z_zid} and {@code z_hash} columns
     * @param bAll    dataset holding the full records, keyed by {@code z_zid}
     * @return pairs joined with the values of both records
     * @throws Exception if any join fails
     */
    protected Dataset<Row> getBlocks(Dataset<Row> blocked, Dataset<Row> bAll) throws Exception {
        int numPartitions = this.args.getNumPartitions();
        Dataset<Row> pairs = blocked.as("first")
                .join(blocked.as("second"), "z_hash")
                .selectExpr("first.z_zid as z_zid", "second.z_zid as z_z_zid");
        // show() runs a Spark job and prints to stdout; only do so when debugging,
        // matching the other show() calls in this class.
        if (LOG.isDebugEnabled()) {
            pairs.show();
        }
        pairs = pairs.filter(pairs.col("z_zid").gt(pairs.col("z_z_zid")));
        // count() is a full Spark action; skip it when the message would be discarded.
        if (LOG.isWarnEnabled()) {
            LOG.warn("Num comparisons " + pairs.count());
        }
        pairs = pairs.repartition(numPartitions, pairs.col("z_zid"));
        bAll = bAll.repartition(numPartitions, bAll.col("z_zid"));
        pairs = pairs.join(bAll, "z_zid");
        LOG.warn("Joining with actual values");
        // The second join is on z_z_zid, so this presumably renames bAll's columns
        // with the z_ prefix (z_zid -> z_z_zid) -- confirm against DSUtil.
        bAll = DSUtil.getPrefixedColumnsDS(bAll);
        pairs = pairs.repartition(numPartitions, pairs.col("z_z_zid"));
        pairs = pairs.join(bAll, "z_z_zid");
        LOG.warn("Joining again with actual values");
        return pairs;
    }

    /**
     * Marks every row as a match by adding {@code z_prediction} and
     * {@code z_score} columns, both set to the literal {@code 1.0}.
     *
     * @param allEqual rows to annotate
     * @return {@code allEqual} with the two literal columns added
     */
    protected Dataset<Row> massageAllEquals(Dataset<Row> allEqual) {
        return allEqual
                .withColumn("z_prediction", functions.lit(1.0))
                .withColumn("z_score", functions.lit(1.0));
    }

    /**
     * Creates a {@link Model} over the configured featurers, registers it
     * with the Spark session and loads it from {@code args.getModel()}.
     *
     * @return the loaded model
     */
    protected Model getModel() {
        Model model = new Model(this.featurers);
        model.register(this.spark);
        model.load(this.args.getModel());
        return model;
    }

    /**
     * Projects a blocked dataset down to its {@code z_zid} and
     * {@code z_hash} columns.
     *
     * @param blocked dataset containing both columns
     * @return the two-column projection
     */
    protected Dataset<Row> selectColsFromBlocked(Dataset<Row> blocked) {
        return blocked.select("z_zid", "z_hash");
    }

    /**
     * Runs the match phase end to end: read, preprocess, block, pair,
     * predict, and write.
     *
     * @throws ZinggClientException if any step fails; the message of the
     *         underlying exception is carried over
     */
    @Override
    public void execute() throws ZinggClientException {
        try {
            Dataset<Row> testDataOriginal = this.getTestData();
            testDataOriginal = DSUtil.getFieldDefColumnsDS(testDataOriginal, this.args, true);
            Dataset<Row> testData = StopWords.preprocessForStopWords(this.spark, this.args, testDataOriginal);
            testData = testData.repartition(this.args.getNumPartitions(), testData.col("z_zid"));
            long count = testData.count();
            LOG.info("Read " + count);
            Analytics.track("dataCount", count, this.args.getCollectMetrics());
            Dataset<Row> blocked = this.getBlocked(testData);
            LOG.info("Blocked ");
            if (LOG.isDebugEnabled()) {
                LOG.debug("Num distinct hashes " + blocked.select("z_hash").distinct().count());
            }
            Dataset<Row> blocks = this.getBlocks(this.selectColsFromBlocked(blocked), testData);
            if (LOG.isDebugEnabled()) {
                LOG.debug("block size" + blocks.count());
            }
            Model model = this.getModel();
            Dataset<Row> dupes = model.predict(blocks);
            if (LOG.isDebugEnabled()) {
                LOG.debug("Found dupes " + dupes.count());
            }
            Dataset<Row> dupesActual = this.getDupesActualForGraph(dupes);
            // The original (pre stop-word) records are what gets written out.
            this.writeOutput(testDataOriginal, dupesActual);
        }
        catch (Exception e) {
            // Keep the stack trace available to developers via the logger
            // rather than printing straight to stderr.
            LOG.debug("Match phase failed", e);
            throw new ZinggClientException(e.getMessage());
        }
    }

    /**
     * Builds the match graph, attaches per-record min/max scores and writes
     * the result to {@code args.getOutput()}. Does nothing when no output is
     * configured.
     *
     * <p>The bookkeeping columns {@code z_hash}, {@code z_z_zid},
     * {@code z_zid} and {@code z_source} are dropped before writing.
     *
     * @param blocked     records used as graph vertices
     * @param dupesActual matched pairs used as graph edges
     * @throws ZinggClientException if building or writing the output fails
     */
    public void writeOutput(Dataset<Row> blocked, Dataset<Row> dupesActual) throws ZinggClientException {
        try {
            if (this.args.getOutput() != null) {
                dupesActual = dupesActual.cache();
                Dataset<Row> graph = GraphUtil.buildGraph(blocked, dupesActual).cache();
                Dataset<Row> score = this.getMinMaxScores(dupesActual, graph).cache();
                graph = graph.repartition(this.args.getNumPartitions(), graph.col("z_zid")).cache();
                Dataset<Row> graphWithScores = DSUtil.joinZColFirst(score, graph, "z_zid", false).cache();
                graphWithScores = graphWithScores
                        .drop("z_hash")
                        .drop("z_z_zid")
                        .drop("z_zid")
                        .drop("z_source");
                PipeUtil.write(graphWithScores, this.args, this.ctx, this.args.getOutput());
            }
        }
        catch (Exception e) {
            // A failed write must not look like a successful run: log with the
            // stack trace and surface the failure to the caller.
            LOG.error("Unable to write match output", e);
            throw new ZinggClientException(e.getMessage());
        }
    }

    /**
     * Computes, for every record id, the minimum and maximum score over all
     * pairs the record takes part in.
     *
     * <p>Pairs that share a {@code z_cluster} in {@code graph} but are absent
     * from {@code dupes} are included with a score of {@code 0.0}. Scores are
     * gathered from both sides of each pair, so the id column of the result
     * is named {@code z_z_zid}.
     *
     * @param dupes scored pairs with {@code z_zid}, {@code z_z_zid} and
     *              {@code z_score} columns
     * @param graph records with {@code z_zid} and {@code z_cluster} columns
     * @return one row per id with columns {@code z_z_zid},
     *         {@code z_minScore} and {@code z_maxScore}
     */
    protected Dataset<Row> getMinMaxScores(Dataset<Row> dupes, Dataset<Row> graph) {
        int numPartitions = this.args.getNumPartitions();
        if (LOG.isDebugEnabled()) {
            dupes.show(500);
        }
        Dataset<Row> clusterIds = graph.select("z_zid", "z_cluster");
        clusterIds = clusterIds.repartition(numPartitions, clusterIds.col("z_cluster"));
        Dataset<Row> dupesWithIds = dupes.select("z_zid", "z_z_zid");
        LOG.warn("Dupes with ids ");
        if (LOG.isDebugEnabled()) {
            dupesWithIds.show(500);
        }
        // Every pair of records that ended up in the same cluster, each pair once.
        Dataset<Row> graphPairsFound = clusterIds.as("first")
                .join(clusterIds.as("second"), "z_cluster")
                .selectExpr("first.z_zid as z_zid", "second.z_zid as z_z_zid");
        graphPairsFound = graphPairsFound.filter(
                graphPairsFound.col("z_zid").gt(graphPairsFound.col("z_z_zid")));
        LOG.warn("graph pairs ");
        if (LOG.isDebugEnabled()) {
            graphPairsFound.show(500);
        }
        // Cluster-mates that were never scored as a direct pair get a dummy 0.0.
        Dataset<Row> graphPairsExtra = graphPairsFound.except(dupesWithIds);
        Dataset<Row> extraWithDummyScore = graphPairsExtra.withColumn("z_score", functions.lit(0.0));
        LOG.warn("graph pairs extra");
        if (LOG.isDebugEnabled()) {
            graphPairsExtra.show(500);
        }
        Dataset<Row> leftScores = dupes.select("z_score", "z_zid")
                .union(extraWithDummyScore.select("z_score", "z_zid"));
        Dataset<Row> rightScores = dupes.select("z_score", "z_z_zid")
                .union(extraWithDummyScore.select("z_score", "z_z_zid"));
        // Rename the left-side id column so both sides stack into one id column.
        Dataset<Row> leftScoresRenamed = leftScores.toDF("z_score", "z_z_zid").cache();
        Dataset<Row> allScores = leftScoresRenamed.union(rightScores);
        allScores = allScores.repartition(numPartitions, allScores.col("z_z_zid"));
        return allScores.groupBy(allScores.col("z_z_zid"))
                .agg(functions.min("z_score").as("z_minScore"),
                        functions.max("z_score").as("z_maxScore"));
    }

    /**
     * Keeps only the pairs predicted as matches, i.e. rows whose
     * {@code z_prediction} equals {@code 1.0}. All columns of {@code dupes}
     * are retained.
     *
     * @param dupes scored pairs
     * @return the rows of {@code dupes} with {@code z_prediction == 1.0}
     */
    protected Dataset<Row> getDupesActualForGraph(Dataset<Row> dupes) {
        LOG.debug("dupes al");
        if (LOG.isDebugEnabled()) {
            dupes.show(false);
        }
        return dupes.filter(dupes.col("z_prediction").equalTo(1.0));
    }

    /**
     * Projects a scored-pair dataset down to {@code z_zid}, {@code z_z_zid},
     * {@code z_prediction} and {@code z_score}, in that order.
     *
     * @param dupesActual scored pairs containing the four columns
     * @return the four-column projection
     */
    protected Dataset<Row> selectColsFromDupes(Dataset<Row> dupesActual) {
        return dupesActual.select(
                dupesActual.col("z_zid"),
                dupesActual.col("z_z_zid"),
                dupesActual.col("z_prediction"),
                dupesActual.col("z_score"));
    }
}

