/*
 * Decompiled with CFR 0.152.
 */
package zingg;

import java.math.BigDecimal;
import java.math.RoundingMode;
import java.util.List;
import java.util.Scanner;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.spark.sql.Column;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.functions;
import zingg.ZinggBase;
import zingg.client.Arguments;
import zingg.client.ZinggClientException;
import zingg.client.ZinggOptions;
import zingg.client.pipe.Pipe;
import zingg.util.DSUtil;
import zingg.util.LabelMatchType;
import zingg.util.PipeUtil;

/**
 * Command-line labelling phase.
 *
 * <p>Loads statistics about already-marked pairs, then walks through the unmarked candidate
 * pairs one cluster at a time. For each pair it prints the records and Zingg's prediction, and
 * reads the user's decision (0 = no match, 1 = match, 2 = not sure, 9 = quit). Finally it writes
 * the labelled pairs to the training-data "marked" pipe.
 */
public class Labeller
extends ZinggBase {
    protected static String name = "zingg.Labeller";
    public static final Log LOG = LogFactory.getLog(Labeller.class);

    /** CLI option that stops labelling early; it is not a labelling decision. */
    private static final int QUIT_OPTION = 9;

    long positivePairsCount;
    long negativePairsCount;
    long notSurePairsCount;
    long totalCount;

    /**
     * Single reader over {@code System.in}, reused for every prompt. A {@link Scanner} buffers
     * input ahead of what it has returned, so creating a new one per prompt can silently drop
     * answers that were already typed or piped in. It is intentionally never closed, because
     * closing it would also close {@code System.in}.
     */
    private Scanner cliScanner;

    public Labeller() {
        this.setZinggOptions(ZinggOptions.LABEL);
    }

    /**
     * Runs the labelling phase: loads marked-record statistics, then labels unmarked records
     * interactively.
     *
     * @throws ZinggClientException wrapping any failure; the original exception is kept as the
     *     cause
     */
    @Override
    public void execute() throws ZinggClientException {
        try {
            LOG.info("Reading inputs for labelling phase ...");
            getMarkedRecordsStat(getMarkedRecords());
            Dataset<Row> unmarkedRecords = getUnmarkedRecords();
            processRecordsCli(unmarkedRecords);
            LOG.info("Finished labelling phase");
        } catch (Exception e) {
            LOG.error("Labelling phase failed", e);
            // Chain the cause so callers and logs keep the original stack trace.
            throw new ZinggClientException(e.getMessage(), e);
        }
    }

    /**
     * Initialises the running counters from previously marked records. Nothing changes when
     * {@code markedRecords} is null.
     *
     * @param markedRecords already-labelled rows, or null if none exist
     */
    protected void getMarkedRecordsStat(Dataset<Row> markedRecords) {
        if (markedRecords != null) {
            this.positivePairsCount = getMatchedMarkedRecordsStat(markedRecords);
            this.negativePairsCount = getUnmatchedMarkedRecordsStat(markedRecords);
            this.notSurePairsCount = getUnsureMarkedRecordsStat(markedRecords);
            // Each labelled pair is stored as two rows.
            this.totalCount = markedRecords.count() / 2L;
        }
    }

    /** Returns the distinct {@code z_cluster} values; each one identifies a pair to label. */
    public List<Row> getClusterIds(Dataset<Row> lines) {
        return lines.select("z_cluster").distinct().collectAsList();
    }

    /** Returns the columns to show the user for each pair, as chosen by {@code DSUtil}. */
    public List<Column> getDisplayColumns(Dataset<Row> lines, Arguments args) {
        return DSUtil.getFieldDefColumns(lines, args, false, args.getShowConcise());
    }

    /**
     * Returns the cached rows whose {@code z_cluster} equals the {@code index}-th entry of
     * {@code clusterIds}.
     */
    public Dataset<Row> getCurrentPair(Dataset<Row> lines, int index, List<Row> clusterIds) {
        Object clusterId = clusterIds.get(index).getAs("z_cluster");
        return lines.filter(lines.col("z_cluster").equalTo(clusterId)).cache();
    }

    /** Reads {@code z_score} from the first row of the pair. */
    public double getScore(Dataset<Row> currentPair) {
        Double score = currentPair.head().getAs("z_score");
        return score;
    }

    /** Reads {@code z_prediction} from the first row of the pair. */
    public double getPrediction(Dataset<Row> currentPair) {
        Double prediction = currentPair.head().getAs("z_prediction");
        return prediction;
    }

    /** Builds the progress line shown above each pair. */
    public String getMsg1(int index, int totalPairs) {
        return String.format("\tCurrent labelling round  : %d/%d pairs labelled\n", index, totalPairs);
    }

    /**
     * Builds the prediction line shown below each pair.
     *
     * @param prediction model prediction; {@code -1.0} means no model is available yet
     * @param score similarity score, truncated (not rounded) to two decimals for display
     */
    public String getMsg2(double prediction, double score) {
        if (prediction == -1.0) {
            return "\tZingg does not do any prediction for the above pairs as Zingg is still collecting training data to build the preliminary models.";
        }
        // Only look up the match type when it is actually displayed.
        String matchType = LabelMatchType.get(prediction).msg;
        return String.format("\tZingg predicts the above records %s with a similarity score of %.2f", matchType, truncateToHundredths(score));
    }

    /**
     * Truncates toward negative infinity at two decimals. The truncation works on the decimal
     * string form of the double, so values like 0.29 are not shown as 0.28 (as
     * {@code Math.floor(0.29 * 100)} would produce). Non-finite values are returned unchanged.
     */
    private static double truncateToHundredths(double score) {
        if (!Double.isFinite(score)) {
            return score;
        }
        return BigDecimal.valueOf(score).setScale(2, RoundingMode.FLOOR).doubleValue();
    }

    /**
     * Interactively labels every pair in {@code lines} and writes the results.
     *
     * @param lines unmarked candidate records, grouped into pairs by {@code z_cluster}; may be null
     * @throws ZinggClientException if anything fails during the labelling loop or the write
     */
    public void processRecordsCli(Dataset<Row> lines) throws ZinggClientException {
        LOG.info("Processing Records for CLI Labelling");
        // Cache before counting, so the count fills the cache instead of evaluating the input
        // plan once for the count and again for the labelling loop.
        Dataset<Row> records = lines == null ? null : lines.cache();
        if (records == null || records.count() <= 0L) {
            LOG.info("It seems there are no unmarked records at this moment. Please run findTrainingData job to build some pairs to be labelled and then run this labeler.");
            return;
        }
        printMarkedRecordsStat();
        List<Column> displayCols = getDisplayColumns(records, this.args);
        List<Row> clusterIDs = getClusterIds(records);
        try {
            Dataset<Row> updatedRecords = null;
            int totalPairs = clusterIDs.size();
            for (int index = 0; index < totalPairs; ++index) {
                Dataset<Row> currentPair = getCurrentPair(records, index, clusterIDs);
                double score = getScore(currentPair);
                double prediction = getPrediction(currentPair);
                String msg1 = getMsg1(index, totalPairs);
                String msg2 = getMsg2(prediction, score);
                int selectedOption = displayRecordsAndGetUserInput(DSUtil.select(currentPair, displayCols), msg1, msg2);
                if (selectedOption == QUIT_OPTION) {
                    // Check for quit before updating stats: quitting labels nothing, so it
                    // must not increase totalCount.
                    LOG.info("User has quit in the middle. Updating the records.");
                    break;
                }
                updateLabellerStat(selectedOption, 1);
                printMarkedRecordsStat();
                updatedRecords = updateRecords(selectedOption, currentPair, updatedRecords);
            }
            writeLabelledOutput(updatedRecords);
            LOG.warn("Processing finished.");
        } catch (Exception e) {
            LOG.debug("Labelling failed", e);
            LOG.warn("Labelling error has occured " + e.getMessage());
            throw new ZinggClientException("An error has occured while Labelling.", e);
        }
    }

    /**
     * Prints the records between the two messages and returns the user's choice from
     * {@link #readCliInput()}.
     */
    protected int displayRecordsAndGetUserInput(Dataset<Row> records, String preMessage, String postMessage) {
        System.out.println(preMessage);
        records.show(false);
        System.out.println(postMessage);
        System.out.println("\tWhat do you think? Your choices are: ");
        return readCliInput();
    }

    /**
     * Tags {@code newRecords} with {@code z_isMatch = matchValue} and appends them to
     * {@code updatedRecords}.
     *
     * @return the combined dataset; just the tagged new records when {@code updatedRecords} is null
     */
    protected Dataset<Row> updateRecords(int matchValue, Dataset<Row> newRecords, Dataset<Row> updatedRecords) {
        Dataset<Row> tagged = newRecords.withColumn("z_isMatch", functions.lit(matchValue));
        return updatedRecords == null ? tagged : updatedRecords.union(tagged);
    }

    /**
     * Prompts until the user enters one of 0, 1, 2 or 9, and returns it. Throws
     * {@code NoSuchElementException} if standard input is exhausted first.
     */
    int readCliInput() {
        Scanner sc = getCliScanner();
        System.out.println();
        System.out.println("\tNo, they do not match : 0");
        System.out.println("\tYes, they match       : 1");
        System.out.println("\tNot sure              : 2");
        System.out.println();
        System.out.println("\tTo exit               : 9");
        System.out.println();
        System.out.print("\tPlease enter your choice [0,1,2 or 9]: ");
        while (!sc.hasNext("[0129]")) {
            sc.next();
            System.out.println("Nope, please enter one of the allowed options!");
        }
        return Integer.parseInt(sc.next());
    }

    /** Lazily creates the single shared {@code System.in} scanner. */
    private Scanner getCliScanner() {
        if (cliScanner == null) {
            cliScanner = new Scanner(System.in);
        }
        return cliScanner;
    }

    /**
     * Adds {@code increment} to {@code totalCount} for any option. Also adds it to the bucket
     * for option 1 (match), 0 (no match) or 2 (not sure); other options change only the total.
     */
    protected void updateLabellerStat(int selected_option, int increment) {
        this.totalCount += increment;
        if (selected_option == 1) {
            this.positivePairsCount += increment;
        } else if (selected_option == 0) {
            this.negativePairsCount += increment;
        } else if (selected_option == 2) {
            this.notSurePairsCount += increment;
        }
    }

    /** Prints the running match / no-match / not-sure counts against the total. */
    protected void printMarkedRecordsStat() {
        String msg = String.format("\tLabelled pairs so far    : %d/%d MATCH, %d/%d DO NOT MATCH, %d/%d NOT SURE", this.positivePairsCount, this.totalCount, this.negativePairsCount, this.totalCount, this.notSurePairsCount, this.totalCount);
        System.out.println();
        System.out.println();
        System.out.println();
        System.out.println(msg);
    }

    /**
     * Writes the labelled records to {@link #getOutputPipe()}; logs a warning and does nothing
     * when {@code records} is null.
     */
    protected void writeLabelledOutput(Dataset<Row> records) throws ZinggClientException {
        if (records == null) {
            LOG.warn("No records to be labelled.");
            return;
        }
        PipeUtil.write(records, this.args, this.ctx, getOutputPipe());
    }

    /** Returns the training-data "marked" pipe that labelled pairs are written to. */
    protected Pipe getOutputPipe() {
        return PipeUtil.getTrainingDataMarkedPipe(this.args);
    }
}

