/*
 * Decompiled with CFR 0.152.
 */
package zingg.util;

import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.sql.Column;
import org.apache.spark.sql.DataFrameReader;
import org.apache.spark.sql.DataFrameWriter;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SaveMode;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.functions;
import org.apache.spark.storage.StorageLevel;
import zingg.client.Arguments;
import zingg.client.ZinggClientException;
import zingg.client.pipe.ElasticPipe;
import zingg.client.pipe.InMemoryPipe;
import zingg.client.pipe.Pipe;
import zingg.scala.DFUtil;

/**
 * Static helpers that read Spark datasets from {@link Pipe} definitions, write datasets
 * back to them, and build the {@link Pipe}s for the locations held in {@link Arguments}.
 */
public class PipeUtil {
    public static final Log LOG = LogFactory.getLog(PipeUtil.class);

    /** Format value marking a pipe whose data is held by the pipe object itself. */
    private static final String FORMAT_IN_MEMORY = "inMemory";

    /** Format value for which {@link #write} currently performs no save. */
    private static final String FORMAT_CASSANDRA = "org.apache.spark.sql.cassandra";

    /** Number of rows taken by {@link #sample(SparkSession, Pipe)}. */
    private static final int SAMPLE_SIZE = 10;

    /**
     * Builds a reader configured with the pipe's format, optional schema and properties.
     *
     * @param spark session the reader is created from
     * @param p     pipe describing the source
     * @return the configured reader
     */
    private static DataFrameReader getReader(SparkSession spark, Pipe p) {
        DataFrameReader reader = spark.read();
        LOG.warn("Reading input " + p.getFormat());
        reader = reader.format(p.getFormat());
        if (p.getSchema() != null) {
            reader = reader.schema(p.getSchema());
        }
        for (String key : p.getProps().keySet()) {
            reader = reader.option(key, p.get(key));
        }
        // Set after the pipe properties, so it replaces any "mode" property on the pipe.
        return reader.option("mode", "PERMISSIVE");
    }

    /**
     * Loads the data of a single pipe.
     *
     * <p>An in-memory pipe supplies its own records; any other pipe is loaded through
     * {@code reader}, from the "location" property when the pipe has one.
     *
     * @param reader    reader already configured for {@code p}
     * @param p         pipe to load
     * @param addSource when true, a "z_source" column holding the pipe name is added
     * @return the loaded dataset
     * @throws ZinggClientException wrapping any failure raised while loading
     */
    private static Dataset<Row> read(DataFrameReader reader, Pipe p, boolean addSource) throws ZinggClientException {
        LOG.warn("Reading " + p);
        try {
            Dataset<Row> input;
            // equals(), not ==: a format string built at runtime is a different object
            // from the literal even when the characters match.
            if (FORMAT_IN_MEMORY.equals(p.getFormat())) {
                input = ((InMemoryPipe) p).getRecords();
            } else if (p.getProps().containsKey("location")) {
                input = reader.load(p.get("location"));
            } else {
                input = reader.load();
            }
            if (addSource) {
                input = input.withColumn("z_source", functions.lit(p.getName()));
            }
            return input;
        } catch (Exception ex) {
            LOG.warn(ex.getMessage());
            throw new ZinggClientException("Could not read data.", ex);
        }
    }

    /**
     * Configures a reader for {@code p} and loads it.
     *
     * @throws ZinggClientException if the data cannot be read
     */
    private static Dataset<Row> readInternal(SparkSession spark, Pipe p, boolean addSource) throws ZinggClientException {
        return read(getReader(spark, p), p, addSource);
    }

    /**
     * Joins label rows onto {@code file}, matching on z_zid, z_source and z_cluster.
     *
     * <p>Only z_zid, z_source, z_isMatch and z_cluster are kept from {@code jdbc}. They are
     * prefixed with "z_" so they cannot clash with the columns of {@code file}; the match
     * flag is renamed back to "z_isMatch" and replaces the one dropped from {@code file}.
     * The schema and a sample of both inputs are printed to standard output.
     *
     * @param jdbc dataset providing the z_isMatch values
     * @param file dataset whose rows receive those values
     * @return the columns of {@code file} (minus its own z_isMatch) plus the joined z_isMatch
     */
    public static Dataset<Row> joinTrainingSetstoGetLabels(Dataset<Row> jdbc, Dataset<Row> file) {
        Dataset<Row> records = file.drop("z_isMatch");
        records.printSchema();
        records.show();

        Dataset<Row> labels = jdbc.select(jdbc.col("z_zid"), jdbc.col("z_source"), jdbc.col("z_isMatch"), jdbc.col("z_cluster"));
        String[] cols = labels.columns();
        for (int i = 0; i < cols.length; ++i) {
            cols[i] = "z_" + cols[i];
        }
        labels = labels.toDF(cols).cache();
        labels = labels.withColumnRenamed("z_z_isMatch", "z_isMatch");
        labels.printSchema();
        labels.show();

        LOG.warn("Building labels ");
        Column sameRecord = records.col("z_zid").equalTo(labels.col("z_z_zid"))
                .and(records.col("z_source").equalTo(labels.col("z_z_source")))
                .and(records.col("z_cluster").equalTo(labels.col("z_z_cluster")));
        Dataset<Row> pairs = records.join(labels, sameRecord);
        LOG.warn("Pairs are " + pairs.count());
        return pairs.drop("z_z_source").drop("z_z_zid").drop("z_z_cluster");
    }

    /**
     * Reads every pipe and combines the results into one dataset.
     *
     * <p>The first pipe seeds the result. A later pipe whose "type" property is "join" is
     * combined through {@link #joinTrainingSetstoGetLabels}; any other pipe is unioned.
     *
     * @param addLineNo when true, the result is passed through {@code DFUtil.addRowNumber}
     * @param addSource when true, each pipe's rows get a "z_source" column
     * @param pipes     pipes to read, in order
     * @throws ZinggClientException if no pipes are given or a pipe cannot be read
     */
    private static Dataset<Row> readInternal(SparkSession spark, boolean addLineNo, boolean addSource, Pipe ... pipes) throws ZinggClientException {
        if (pipes == null || pipes.length == 0) {
            // Without this the callers fail later with a bare NullPointerException.
            throw new ZinggClientException("No pipes were supplied to read from.");
        }
        Dataset<Row> input = null;
        for (Pipe p : pipes) {
            if (input == null) {
                input = readInternal(spark, p, addSource);
                // count() runs a Spark job, so only pay for it when the message is emitted.
                if (LOG.isDebugEnabled()) {
                    LOG.debug("input size is " + input.count());
                }
            } else if ("join".equals(p.get("type"))) {
                LOG.warn("joining inputs");
                Dataset<Row> toJoin = readInternal(spark, p, addSource);
                LOG.warn("input now size is " + toJoin.count());
                input = joinTrainingSetstoGetLabels(input, toJoin);
            } else {
                input = input.union(readInternal(spark, p, addSource));
            }
        }
        if (addLineNo) {
            input = DFUtil.addRowNumber(input, spark);
        }
        return input;
    }

    /**
     * Reads and combines the given pipes, persisting the result with
     * {@code StorageLevel.MEMORY_ONLY()}.
     *
     * @param addLineNo when true, the result is passed through {@code DFUtil.addRowNumber}
     * @param addSource when true, each pipe's rows get a "z_source" column
     * @param pipes     pipes to read, in order
     * @return the combined, persisted dataset
     * @throws ZinggClientException if no pipes are given or a pipe cannot be read
     */
    public static Dataset<Row> read(SparkSession spark, boolean addLineNo, boolean addSource, Pipe ... pipes) throws ZinggClientException {
        return readInternal(spark, addLineNo, addSource, pipes).persist(StorageLevel.MEMORY_ONLY());
    }

    /**
     * Reads up to ten rows of a pipe with schema inference switched on and malformed rows
     * dropped, logging each row taken.
     *
     * @return a new data frame holding only the sampled rows, with the inferred schema
     * @throws ZinggClientException if the pipe cannot be read
     */
    public static Dataset<Row> sample(SparkSession spark, Pipe p) throws ZinggClientException {
        DataFrameReader reader = getReader(spark, p)
                .option("inferSchema", true)
                .option("mode", "DROPMALFORMED");
        LOG.info("reader is ready to sample with inferring " + p.get("location"));
        LOG.warn("Reading input of type " + p.getFormat());
        Dataset<Row> input = read(reader, p, false);
        List<Row> values = input.takeAsList(SAMPLE_SIZE);
        values.forEach(r -> LOG.warn(r));
        return spark.createDataFrame(values, input.schema());
    }

    /**
     * Reads and combines the given pipes, repartitions the result and persists it with
     * {@code StorageLevel.MEMORY_ONLY()}.
     *
     * @param addLineNo     when true, the result is passed through {@code DFUtil.addRowNumber}
     * @param numPartitions partition count passed to {@code repartition}
     * @param addSource     when true, each pipe's rows get a "z_source" column
     * @param pipes         pipes to read, in order
     * @return the combined, repartitioned and persisted dataset
     * @throws ZinggClientException if no pipes are given or a pipe cannot be read
     */
    public static Dataset<Row> read(SparkSession spark, boolean addLineNo, int numPartitions, boolean addSource, Pipe ... pipes) throws ZinggClientException {
        return readInternal(spark, addLineNo, addSource, pipes)
                .repartition(numPartitions)
                .persist(StorageLevel.MEMORY_ONLY());
    }

    /**
     * Writes the dataset to each pipe in turn.
     *
     * <p>An in-memory pipe receives the dataset itself. For other pipes the writer uses the
     * pipe's format, properties and save mode ("Append" when the pipe has none), and saves
     * to the "location" property when one is present.
     *
     * @param toWriteOrig dataset to write
     * @param args        not used by this method
     * @param ctx         context whose configuration receives the Elasticsearch settings
     * @param pipes       destinations, processed in order
     * @throws ZinggClientException wrapping any failure raised while writing
     */
    public static void write(Dataset<Row> toWriteOrig, Arguments args, JavaSparkContext ctx, Pipe ... pipes) throws ZinggClientException {
        try {
            for (Pipe p : pipes) {
                DataFrameWriter<Row> writer = toWriteOrig.write();
                LOG.warn("Writing output " + p);
                if (FORMAT_IN_MEMORY.equals(p.getFormat())) {
                    p.setDataset(toWriteOrig);
                    // NOTE(review): this leaves the method, so pipes listed after an
                    // in-memory pipe are never written. Kept as found; confirm intent.
                    return;
                }
                if (p.getMode() != null) {
                    writer.mode(p.getMode());
                } else {
                    writer.mode("Append");
                }
                if (p.getFormat().equals("org.elasticsearch.spark.sql")) {
                    ctx.getConf().set(ElasticPipe.NODE, p.getProps().get(ElasticPipe.NODE));
                    ctx.getConf().set(ElasticPipe.PORT, p.getProps().get(ElasticPipe.PORT));
                    ctx.getConf().set(ElasticPipe.ID, "z_zid");
                    ctx.getConf().set(ElasticPipe.RESOURCE, p.getName());
                }
                writer = writer.format(p.getFormat());
                for (String key : p.getProps().keySet()) {
                    writer = writer.option(key, p.get(key));
                }
                // NOTE(review): reference comparison kept on purpose. It decides whether
                // the save is skipped, so switching to equals() would stop writes that
                // happen today for format strings that are equal but not the same object.
                // Confirm the intended Cassandra behaviour before changing it.
                if (p.getFormat() == FORMAT_CASSANDRA) {
                    continue;
                }
                if (p.getProps().containsKey("location")) {
                    LOG.warn("Writing file");
                    writer.save(p.get("location"));
                    continue;
                }
                // The writer already carries format, mode and options, so jdbc needs no
                // separate handling.
                writer.save();
            }
        } catch (Exception ex) {
            // Keep the cause so the original stack trace is not lost.
            throw new ZinggClientException(ex.getMessage(), ex);
        }
    }

    /**
     * Writes the rows of each distinct "z_source" value separately, without that column.
     *
     * <p>Every subset is passed to {@link #write} with the full {@code pipes} array.
     *
     * @param toWrite dataset containing a "z_source" column
     * @throws ZinggClientException if any write fails
     */
    public static void writePerSource(Dataset<Row> toWrite, Arguments args, JavaSparkContext ctx, Pipe[] pipes) throws ZinggClientException {
        List<Row> sources = toWrite.select("z_source").distinct().collectAsList();
        for (Row source : sources) {
            Dataset<Row> subset = toWrite
                    .filter(toWrite.col("z_source").equalTo(source.get(0)))
                    .drop("z_source");
            write(subset, args, ctx, pipes);
        }
    }

    /** Returns a parquet pipe located at {@code args.getZinggTrainingDataUnmarkedDir()}. */
    public static Pipe getTrainingDataUnmarkedPipe(Arguments args) {
        return locationPipe("parquet", args.getZinggTrainingDataUnmarkedDir());
    }

    /** Returns a parquet pipe located at {@code args.getZinggTrainingDataMarkedDir()}. */
    public static Pipe getTrainingDataMarkedPipe(Arguments args) {
        return locationPipe("parquet", args.getZinggTrainingDataMarkedDir());
    }

    /** Returns a text pipe located at {@code args.getZinggModelDocFile()}. */
    public static Pipe getModelDocumentationPipe(Arguments args) {
        return locationPipe("text", args.getZinggModelDocFile());
    }

    /**
     * Returns a csv pipe with a header row, located at {@code fileName}, in overwrite mode.
     *
     * @param args     not used by this method
     * @param fileName value stored as the pipe's "location" property
     */
    public static Pipe getStopWordsPipe(Arguments args, String fileName) {
        Pipe p = new Pipe();
        p.setFormat("csv");
        p.setProp("header", "true");
        p.setProp("location", fileName);
        p.setMode(SaveMode.Overwrite);
        return p;
    }

    /** Returns a parquet pipe located at {@code args.getBlockFile()}, in overwrite mode. */
    public static Pipe getBlockingTreePipe(Arguments args) {
        Pipe p = locationPipe("parquet", args.getBlockFile());
        p.setMode(SaveMode.Overwrite);
        return p;
    }

    /**
     * Returns the formats of the given pipes joined with commas.
     *
     * @return the comma-separated formats, or an empty string for an empty array
     */
    public static String getPipesAsString(Pipe[] pipes) {
        return Arrays.stream(pipes)
                .map(p -> p.getFormat())
                .collect(Collectors.joining(","));
    }

    /** Creates a pipe with the given format and "location" property. */
    private static Pipe locationPipe(String format, String location) {
        Pipe p = new Pipe();
        p.setFormat(format);
        p.setProp("location", location);
        return p;
    }
}

