/*
 * Decompiled with CFR 0.152.
 */
package zingg;

import java.util.List;
import java.util.Scanner;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.spark.sql.Column;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SaveMode;
import zingg.Labeller;
import zingg.client.ZinggClientException;
import zingg.client.ZinggOptions;
import zingg.client.pipe.Pipe;
import zingg.util.DSUtil;
import zingg.util.LabelMatchType;
import zingg.util.PipeUtil;

public class LabelUpdater
extends Labeller {
    protected static String name = "zingg.LabelUpdater";
    public static final Log LOG = LogFactory.getLog(LabelUpdater.class);

    public LabelUpdater() {
        this.setZinggOptions(ZinggOptions.UPDATE_LABEL);
    }

    @Override
    public void execute() throws ZinggClientException {
        try {
            LOG.info("Reading inputs for updateLabelling phase ...");
            Dataset<Row> markedRecords = PipeUtil.read(this.spark, false, false, PipeUtil.getTrainingDataMarkedPipe(this.args));
            this.processRecordsCli(markedRecords);
            LOG.info("Finished updataLabelling phase");
        }
        catch (Exception e) {
            e.printStackTrace();
            throw new ZinggClientException(e.getMessage());
        }
    }

    @Override
    public void processRecordsCli(Dataset<Row> lines) throws ZinggClientException {
        LOG.info("Processing Records for CLI updateLabelling");
        if (lines != null && lines.count() > 0L) {
            this.getMarkedRecordsStat(lines);
            this.printMarkedRecordsStat();
            List<Column> displayCols = DSUtil.getFieldDefColumns(lines, this.args, false, this.args.getShowConcise());
            try {
                Dataset updatedRecords = null;
                Dataset recordsToUpdate = lines;
                int selectedOption = -1;
                Scanner sc = new Scanner(System.in);
                do {
                    System.out.print("\n\tPlease enter the cluster id (or 9 to exit): ");
                    String cluster_id = sc.next();
                    if (cluster_id.equals("9")) {
                        LOG.info("User has exit in the middle. Updating the records.");
                        break;
                    }
                    Dataset currentPair = lines.filter(lines.col("z_cluster").equalTo((Object)cluster_id));
                    if (currentPair.isEmpty()) {
                        System.out.println("\tInvalid cluster id. Enter '9' to exit");
                        continue;
                    }
                    int matchFlag = (Integer)((Row)currentPair.head()).getAs("z_isMatch");
                    String preMsg = String.format("\n\tThe record pairs belonging to the input cluster id %s are:", cluster_id);
                    String matchType = LabelMatchType.get((double)((double)matchFlag)).msg;
                    String postMsg = String.format("\tThe above pair is labeled as %s\n", matchType);
                    selectedOption = this.displayRecordsAndGetUserInput(DSUtil.select((Dataset<Row>)currentPair, displayCols), preMsg, postMsg);
                    this.updateLabellerStat(selectedOption, 1);
                    this.updateLabellerStat(matchFlag, -1);
                    this.printMarkedRecordsStat();
                    if (selectedOption == 9) {
                        LOG.info("User has quit in the middle. Updating the records.");
                        break;
                    }
                    recordsToUpdate = recordsToUpdate.filter(recordsToUpdate.col("z_cluster").notEqual((Object)cluster_id));
                    if (updatedRecords != null) {
                        updatedRecords = updatedRecords.filter(updatedRecords.col("z_cluster").notEqual((Object)cluster_id));
                    }
                    updatedRecords = this.updateRecords(selectedOption, (Dataset<Row>)currentPair, (Dataset<Row>)updatedRecords);
                } while (selectedOption != 9);
                if (updatedRecords != null) {
                    updatedRecords = updatedRecords.union(recordsToUpdate);
                }
                this.writeLabelledOutput(updatedRecords);
                sc.close();
                LOG.info("Processing finished.");
            }
            catch (Exception e) {
                if (LOG.isDebugEnabled()) {
                    e.printStackTrace();
                }
                LOG.warn("An error has occured while Updating Label. " + e.getMessage());
                throw new ZinggClientException("An error while updating label", e);
            }
        } else {
            LOG.info("There is no marked record for updating. Please run findTrainingData/label jobs to generate training data.");
        }
    }

    @Override
    protected Pipe getOutputPipe() {
        Pipe p = PipeUtil.getTrainingDataMarkedPipe(this.args);
        p.setMode(SaveMode.Overwrite);
        return p;
    }
}

