/*
 * Decompiled with CFR 0.152.
 */
package zingg.util;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.List;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import org.apache.spark.storage.StorageLevel;
import zingg.block.Block;
import zingg.block.Canopy;
import zingg.block.Tree;
import zingg.client.Arguments;
import zingg.client.FieldDefinition;
import zingg.client.MatchType;
import zingg.client.ZinggClientException;
import zingg.client.util.ListMap;
import zingg.client.util.Util;
import zingg.hash.HashFunction;
import zingg.util.Heuristics;
import zingg.util.PipeUtil;

/**
 * Static helpers that learn a blocking tree ({@code Tree<Canopy>}) from data, write it to the
 * blocking-tree pipe as a single binary column, and read it back.
 */
public class BlockingTreeUtil {
    public static final Log LOG = LogFactory.getLog(BlockingTreeUtil.class);

    /**
     * Learns a blocking tree from a random sample of {@code testData}.
     *
     * <p>The sample is drawn without replacement at {@code sampleFraction} and cached in memory
     * while the tree is learned; the cache is released before returning, whether learning
     * succeeds or fails. Both the sample and {@code positives} are collected to the driver
     * ({@code collectAsList}) to build the root {@link Canopy}.
     *
     * <p>Only field definitions that have a match type, and whose match type does not contain
     * {@link MatchType#DONT_USE}, are offered to the tree learner.
     *
     * @param testData       data to sample records from
     * @param positives      positive examples handed to {@link Block} and the root canopy
     * @param sampleFraction fraction of {@code testData} to sample (without replacement)
     * @param blockSize      block size to learn with, or {@code -1} to derive it from the sample
     *                       count and {@code args.getBlockSize()} via {@code Heuristics.getMaxBlockSize}
     * @param args           job arguments providing field definitions and the configured block size
     * @param hashFunctions  hash functions available to the learner, keyed by data type
     * @return the learned blocking tree
     * @throws Exception if sampling, collection or tree learning fails
     */
    public static Tree<Canopy> createBlockingTree(Dataset<Row> testData, Dataset<Row> positives, double sampleFraction, long blockSize, Arguments args, ListMap<DataType, HashFunction> hashFunctions) throws Exception {
        Dataset<Row> sample = testData.sample(false, sampleFraction).persist(StorageLevel.MEMORY_ONLY());
        try {
            long totalCount = sample.count();
            if (LOG.isDebugEnabled()) {
                // The extra count() jobs below are only worth paying for when debugging.
                LOG.debug("Learning blocking rules for sample count " + totalCount + " and pos " + positives.count() + " and testData count " + testData.count());
            }
            if (blockSize == -1L) {
                blockSize = Heuristics.getMaxBlockSize(totalCount, args.getBlockSize());
            }
            LOG.info("Learning indexing rules for block size " + blockSize);
            positives = positives.coalesce(1);
            Block cblock = new Block(sample, positives, hashFunctions, blockSize);
            Canopy root = new Canopy(sample.collectAsList(), positives.collectAsList());

            // Declared as ArrayList to match what getBlockingTree is handed in the original code.
            ArrayList<FieldDefinition> fd = new ArrayList<>();
            for (FieldDefinition def : args.getFieldDefinition()) {
                if (def.getMatchType() != null && !def.getMatchType().contains(MatchType.DONT_USE)) {
                    fd.add(def);
                }
            }

            Tree<Canopy> blockingTree = cblock.getBlockingTree(null, null, root, fd);
            if (LOG.isDebugEnabled()) {
                LOG.debug("The blocking tree is ");
                blockingTree.print(2);
            }
            return blockingTree;
        } finally {
            // The sample was persisted above; release it so the cached blocks do not
            // outlive tree learning.
            sample.unpersist();
        }
    }

    /**
     * Samples {@code testData} and learns a blocking tree from that sample.
     *
     * <p>NOTE(review): {@link #createBlockingTree} samples its input again with the same
     * {@code sampleFraction}, so the rows actually used are roughly a
     * {@code sampleFraction * sampleFraction} fraction of {@code testData}. This is kept as-is
     * because callers may have tuned {@code sampleFraction} around it — confirm whether intended.
     *
     * @param testData       data to sample records from
     * @param positives      positive examples used by the learner
     * @param sampleFraction fraction used for each of the two sampling steps
     * @param blockSize      block size, or {@code -1} to derive it
     * @param args           job arguments
     * @param hashFunctions  hash functions available to the learner, keyed by data type
     * @return the learned blocking tree
     * @throws Exception if sampling or tree learning fails
     */
    public static Tree<Canopy> createBlockingTreeFromSample(Dataset<Row> testData, Dataset<Row> positives, double sampleFraction, long blockSize, Arguments args, ListMap<DataType, HashFunction> hashFunctions) throws Exception {
        Dataset<Row> sample = testData.sample(false, sampleFraction);
        return BlockingTreeUtil.createBlockingTree(sample, positives, sampleFraction, blockSize, args, hashFunctions);
    }

    /**
     * Serializes {@code blockingTree} with {@code Util.convertObjectIntoByteArray} and writes it
     * as a single-row, single-partition DataFrame with one non-nullable binary column
     * {@code BlockingTree} to the pipe returned by {@code PipeUtil.getBlockingTreePipe(args)}.
     *
     * @param spark        active Spark session
     * @param ctx          Java Spark context used to parallelize the row and passed to the writer
     * @param blockingTree tree to persist
     * @param args         job arguments identifying the target pipe
     * @throws Exception            if serialization or writing fails
     * @throws ZinggClientException propagated from the pipe writer
     */
    public static void writeBlockingTree(SparkSession spark, JavaSparkContext ctx, Tree<Canopy> blockingTree, Arguments args) throws Exception, ZinggClientException {
        byte[] byteArray = Util.convertObjectIntoByteArray(blockingTree);
        StructType schema = DataTypes.createStructType(new StructField[] {
            DataTypes.createStructField("BlockingTree", DataTypes.BinaryType, false)
        });
        ArrayList<byte[]> objList = new ArrayList<>(1);
        objList.add(byteArray);
        // Spark's Function is already Serializable. The (Object) cast makes it explicit that the
        // whole byte[] is the single column value, not a varargs array of values.
        JavaRDD<Row> rowRDD = ctx.parallelize(objList).map(bytes -> RowFactory.create((Object) bytes));
        Dataset<Row> df = spark.createDataFrame(rowRDD, schema).coalesce(1);
        PipeUtil.write(df, args, ctx, PipeUtil.getBlockingTreePipe(args));
    }

    /**
     * Reads the blocking tree previously written by {@link #writeBlockingTree}.
     *
     * <p>Only the first column of the first row of the blocking-tree pipe is used; it is expected
     * to hold the serialized tree bytes.
     *
     * @param spark active Spark session
     * @param args  job arguments identifying the source pipe and partition count
     * @return the deserialized blocking tree
     * @throws IllegalStateException if the blocking-tree pipe contains no rows
     * @throws Exception             if reading or deserialization fails
     * @throws ZinggClientException  propagated from the pipe reader
     */
    public static Tree<Canopy> readBlockingTree(SparkSession spark, Arguments args) throws Exception, ZinggClientException {
        Dataset<Row> tree = PipeUtil.read(spark, false, args.getNumPartitions(), false, PipeUtil.getBlockingTreePipe(args));
        List<Row> rows = tree.takeAsList(1);
        if (rows.isEmpty()) {
            throw new IllegalStateException("Blocking tree pipe returned no rows; cannot read blocking tree");
        }
        byte[] byteArrayBack = (byte[]) rows.get(0).get(0);
        // NOTE(review): this is Java-native deserialization of bytes read from storage;
        // it is only safe if the model location is trusted.
        @SuppressWarnings("unchecked")
        Tree<Canopy> blockingTree = (Tree<Canopy>) Util.revertObjectFromByteArray(byteArrayBack);
        return blockingTree;
    }
}

