/*
 * Decompiled with CFR 0.152.
 */
package com.intel.analytics.bigdl.dllib.example.nnframes.xgboost;

import com.intel.analytics.bigdl.dllib.NNContext$;
import com.intel.analytics.bigdl.dllib.example.nnframes.xgboost.Params;
import com.intel.analytics.bigdl.dllib.example.nnframes.xgboost.Params$;
import com.intel.analytics.bigdl.dllib.example.nnframes.xgboost.Task;
import com.intel.analytics.bigdl.dllib.nnframes.XGBClassifier;
import com.intel.analytics.bigdl.dllib.nnframes.XGBClassifierModel;
import ml.dmlc.xgboost4j.scala.spark.TrackerConf;
import org.apache.spark.SparkContext;
import org.apache.spark.ml.feature.StringIndexer;
import org.apache.spark.ml.feature.StringIndexerModel;
import org.apache.spark.ml.feature.VectorAssembler;
import org.apache.spark.rdd.RDD;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.Row$;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.sql.SQLContext$;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.LongType$;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructField$;
import org.apache.spark.sql.types.StructType;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import scala.Array$;
import scala.Function1;
import scala.MatchError;
import scala.Option;
import scala.Predef;
import scala.Predef$;
import scala.Serializable;
import scala.Tuple2;
import scala.Tuple4;
import scala.collection.Seq;
import scala.collection.SeqLike;
import scala.collection.immutable.IndexedSeq$;
import scala.collection.immutable.Map;
import scala.collection.immutable.StringOps;
import scala.collection.mutable.StringBuilder;
import scala.reflect.ClassTag$;
import scala.runtime.BoxesRunTime;
import scala.runtime.ObjectRef;
import scala.runtime.RichInt$;
import scala.runtime.ScalaRunTime$;
import scopt.OptionParser;
import scopt.Read$;

/**
 * Singleton entry point of an example job that trains an {@link XGBClassifier}
 * on a tab-delimited CSV data set and saves the resulting model.
 *
 * <p>This file is decompiler output (see the file header). The single instance
 * lives in {@link #MODULE$}, and the anonymous {@code Serializable} classes
 * below are compiler-generated function objects whose captured values are
 * passed in through synthetic constructor arguments.
 *
 * <p>Overall flow of {@link #main(String[])}: parse options, read the CSV,
 * turn every row into a space-separated string via {@code Task.rowToLibsvm},
 * parse that string back into {@code featureNum + 1} long columns, index the
 * first column as the label, assemble the remaining columns into a feature
 * vector, split the data four ways, train, save, and log timings.
 */
public final class xgbClassifierTrainingExampleOnCriteoClickLogsDataset$ {
    /** The one instance of this class; assigned inside the private constructor. */
    public static final xgbClassifierTrainingExampleOnCriteoClickLogsDataset$ MODULE$;
    /** Number of feature columns; the constructor sets it to 39. Column {@code _c0} is the label. */
    private final int featureNum;
    /** Command-line option parser, populated in the private constructor. */
    private final OptionParser<Params> parser;

    // Instantiates the singleton; the constructor itself stores the instance into MODULE$.
    static {
        new xgbClassifierTrainingExampleOnCriteoClickLogsDataset$();
    }

    /**
     * @return the number of feature columns (the label column is not counted)
     */
    public int featureNum() {
        return this.featureNum;
    }

    /**
     * Runs the example: reads the training data, preprocesses it, trains an
     * XGBoost classifier, saves the model and logs the time spent in each phase.
     *
     * @param args command-line arguments, parsed by {@link #parser()} into a {@code Params}
     * @throws MatchError if {@code randomSplit} does not yield exactly four datasets
     */
    public void main(String[] args) {
        Logger log2 = LoggerFactory.getLogger(this.getClass());
        // Parse the arguments on top of a Params built from all six constructor defaults.
        // NOTE(review): .get() is called directly on the parse result; presumably this
        // throws when parsing fails (e.g. a required option is missing) - confirm.
        Params params = (Params)this.parser().parse((Seq<String>)Predef$.MODULE$.wrapRefArray((Object[])args), new Params(Params$.MODULE$.$lessinit$greater$default$1(), Params$.MODULE$.$lessinit$greater$default$2(), Params$.MODULE$.$lessinit$greater$default$3(), Params$.MODULE$.$lessinit$greater$default$4(), Params$.MODULE$.$lessinit$greater$default$5(), Params$.MODULE$.$lessinit$greater$default$6())).get();
        String trainingDataPath = params.trainingDataPath();
        String modelSavePath = params.modelSavePath();
        int numThread = params.numThread();
        int numRound = params.numRound();
        int maxDepth = params.maxDepth();
        int numWorkers = params.numWorkers();
        SparkContext sc = NNContext$.MODULE$.initNNContext();
        SQLContext spark = SQLContext$.MODULE$.getOrCreate(sc);
        Task task = new Task();
        // tStart is also the reference point for the end-to-end time logged at the end.
        long tStart = System.nanoTime();
        // Headerless, tab-delimited CSV with schema inference enabled.
        Dataset df = spark.read().option("header", "false").option("inferSchema", "true").option("delimiter", "\t").csv(trainingDataPath);
        long tBeforePreprocess = System.nanoTime();
        // Nanoseconds converted to seconds; the variable is reused for every timing below.
        float elapsed = (float)((double)(tBeforePreprocess - tStart) / 1.0E9);
        log2.info(new StringBuilder().append((Object)"--reading data time is ").append((Object)BoxesRunTime.boxToFloat((float)elapsed)).append((Object)" s").toString());
        // Convert every Row to a String with Task.rowToLibsvm (implementation not visible here).
        RDD processedRdd = df.rdd().map((Function1)new Serializable(task){
            public static final long serialVersionUID = 0L;
            private final Task task$1;

            public final String apply(Row row) {
                return this.task$1.rowToLibsvm(row);
            }
            {
                this.task$1 = task$1;
            }
        }, ClassTag$.MODULE$.apply(String.class));
        // Build the target schema: featureNum + 1 long columns named _c0 .. _c<featureNum>
        // (the range 0 to featureNum is inclusive), each created with the third
        // StructField argument set to true.
        ObjectRef structFieldArray = ObjectRef.create((Object)new StructField[this.featureNum() + 1]);
        RichInt$.MODULE$.to$extension0(Predef$.MODULE$.intWrapper(0), this.featureNum()).foreach$mVc$sp((Function1)new Serializable(structFieldArray){
            public static final long serialVersionUID = 0L;
            private final ObjectRef structFieldArray$1;

            public final void apply(int i) {
                this.apply$mcVI$sp(i);
            }

            public void apply$mcVI$sp(int i) {
                ((StructField[])this.structFieldArray$1.elem)[i] = new StructField(new StringBuilder().append((Object)"_c").append((Object)((Object)BoxesRunTime.boxToInteger((int)i)).toString()).toString(), (DataType)LongType$.MODULE$, true, StructField$.MODULE$.apply$default$4());
            }
            {
                this.structFieldArray$1 = structFieldArray$1;
            }
        });
        StructType schema = new StructType((StructField[])structFieldArray.elem);
        // Split each string on a single space, then parse tokens 0 .. featureNum as longs
        // and wrap them in a Row matching the schema above.
        RDD rowRDD = processedRdd.map((Function1)new Serializable(){
            public static final long serialVersionUID = 0L;

            public final String[] apply(String x$1) {
                return x$1.split(" ");
            }
        }, ClassTag$.MODULE$.apply(ScalaRunTime$.MODULE$.arrayClass(String.class))).map((Function1)new Serializable(){
            public static final long serialVersionUID = 0L;

            public final Row apply(String[] row) {
                return Row$.MODULE$.fromSeq((Seq)RichInt$.MODULE$.to$extension0(Predef$.MODULE$.intWrapper(0), xgbClassifierTrainingExampleOnCriteoClickLogsDataset$.MODULE$.featureNum()).map((Function1)new Serializable(this, row){
                    public static final long serialVersionUID = 0L;
                    private final String[] row$2;

                    public final long apply(int i) {
                        return this.apply$mcJI$sp(i);
                    }

                    public long apply$mcJI$sp(int i) {
                        return new StringOps(Predef$.MODULE$.augmentString(this.row$2[i])).toLong();
                    }
                    {
                        this.row$2 = row$2;
                    }
                }, IndexedSeq$.MODULE$.canBuildFrom()));
            }
        }, ClassTag$.MODULE$.apply(Row.class));
        // df is reassigned: from here on it is the all-long DataFrame, not the raw CSV.
        df = spark.createDataFrame(rowRDD, schema);
        // Index the label column _c0 into "classIndex", then drop the original column.
        StringIndexerModel stringIndexer = new StringIndexer().setInputCol("_c0").setOutputCol("classIndex").fit(df);
        Dataset labelTransformed = stringIndexer.transform(df).drop("_c0");
        // Feature column names are _c1 .. _c<featureNum>: index i holds "_c" + (i + 1).
        ObjectRef inputCols = ObjectRef.create((Object)new String[this.featureNum()]);
        RichInt$.MODULE$.to$extension0(Predef$.MODULE$.intWrapper(0), this.featureNum() - 1).foreach$mVc$sp((Function1)new Serializable(inputCols){
            public static final long serialVersionUID = 0L;
            private final ObjectRef inputCols$1;

            public final void apply(int i) {
                this.apply$mcVI$sp(i);
            }

            public void apply$mcVI$sp(int i) {
                ((String[])this.inputCols$1.elem)[i] = new StringBuilder().append((Object)"_c").append((Object)((Object)BoxesRunTime.boxToInteger((int)(i + 1))).toString()).toString();
            }
            {
                this.inputCols$1 = inputCols$1;
            }
        });
        // Assemble the feature columns into a single "features" column and keep only
        // "features" and "classIndex" as the classifier input.
        VectorAssembler vectorAssembler = new VectorAssembler().setInputCols((String[])inputCols.elem).setOutputCol("features");
        Dataset xgbInput = vectorAssembler.transform(labelTransformed).select("features", (Seq)Predef$.MODULE$.wrapRefArray((Object[])new String[]{"classIndex"}));
        // Four-way random split with weights 0.6 / 0.2 / 0.1 / 0.1.
        Dataset[] datasetArray = xgbInput.randomSplit(new double[]{0.6, 0.2, 0.1, 0.1});
        // Decompiled form of destructuring the array into four values; anything other
        // than exactly four elements falls through to the MatchError at the bottom.
        Option option = Array$.MODULE$.unapplySeq((Object)datasetArray);
        if (!option.isEmpty() && option.get() != null && ((SeqLike)option.get()).lengthCompare(4) == 0) {
            Tuple4 tuple4;
            Dataset train2 = (Dataset)((SeqLike)option.get()).apply(0);
            Dataset eval1 = (Dataset)((SeqLike)option.get()).apply(1);
            Dataset eval2 = (Dataset)((SeqLike)option.get()).apply(2);
            Dataset test2 = (Dataset)((SeqLike)option.get()).apply(3);
            Tuple4 tuple42 = tuple4 = new Tuple4((Object)train2, (Object)eval1, (Object)eval2, (Object)test2);
            Dataset train3 = (Dataset)tuple42._1();
            Dataset eval12 = (Dataset)tuple42._2();
            Dataset eval22 = (Dataset)tuple42._3();
            // NOTE(review): test3 is never used after this point.
            Dataset test3 = (Dataset)tuple42._4();
            // cache() followed by count() on the training set and both evaluation sets;
            // presumably done so preprocessing is finished before the training timer
            // starts - confirm.
            train3.cache().count();
            eval12.cache().count();
            eval22.cache().count();
            long tBeforeTraining = System.nanoTime();
            elapsed = (float)((double)(tBeforeTraining - tBeforePreprocess) / 1.0E9);
            log2.info(new StringBuilder().append((Object)"--preprocess time is ").append((Object)BoxesRunTime.boxToFloat((float)elapsed)).append((Object)" s").toString());
            // Constructor parameters for the classifier:
            //   "tracker_conf" -> TrackerConf(0L, "scala")
            //   "eval_sets"    -> { "eval1" -> second split, "eval2" -> third split }
            Map xgbParam = (Map)Predef$.MODULE$.Map().apply((Seq)Predef$.MODULE$.wrapRefArray((Object[])new Tuple2[]{Predef.ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.ArrowAssoc((Object)"tracker_conf"), (Object)new TrackerConf(0L, "scala")), Predef.ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.ArrowAssoc((Object)"eval_sets"), (Object)Predef$.MODULE$.Map().apply((Seq)Predef$.MODULE$.wrapRefArray((Object[])new Tuple2[]{Predef.ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.ArrowAssoc((Object)"eval1"), (Object)eval12), Predef.ArrowAssoc$.MODULE$.$minus$greater$extension(Predef$.MODULE$.ArrowAssoc((Object)"eval2"), (Object)eval22)})))}));
            XGBClassifier xgbClassifier = new XGBClassifier((Map<String, Object>)xgbParam);
            xgbClassifier.setFeaturesCol("features");
            xgbClassifier.setLabelCol("classIndex");
            // NOTE(review): two classes are configured together with the
            // "multi:softprob" objective set below - confirm this pairing is intended.
            xgbClassifier.setNumClass(2);
            // The next four values come straight from the parsed command-line options.
            xgbClassifier.setNumWorkers(numWorkers);
            xgbClassifier.setMaxDepth(maxDepth);
            xgbClassifier.setNthread(numThread);
            xgbClassifier.setNumRound(numRound);
            xgbClassifier.setTreeMethod("auto");
            xgbClassifier.setObjective("multi:softprob");
            xgbClassifier.setTimeoutRequestWorkers(180000L);
            // Train on the first split only.
            XGBClassifierModel xgbClassificationModel = xgbClassifier.fit((Dataset<Row>)train3);
            long tAfterTraining = System.nanoTime();
            elapsed = (float)((double)(tAfterTraining - tBeforeTraining) / 1.0E9);
            log2.info(new StringBuilder().append((Object)"--training time is ").append((Object)BoxesRunTime.boxToFloat((float)elapsed)).append((Object)" s").toString());
            xgbClassificationModel.save(modelSavePath);
            long tAfterSave = System.nanoTime();
            elapsed = (float)((double)(tAfterSave - tAfterTraining) / 1.0E9);
            log2.info(new StringBuilder().append((Object)"--model save time is ").append((Object)BoxesRunTime.boxToFloat((float)elapsed)).append((Object)" s").toString());
            // End-to-end time is measured from tStart (before the CSV read) to after the save.
            elapsed = (float)((double)(tAfterSave - tStart) / 1.0E9);
            log2.info(new StringBuilder().append((Object)"--end-to-end time is ").append((Object)BoxesRunTime.boxToFloat((float)elapsed)).append((Object)" s").toString());
            // NOTE(review): there is no try/finally, so this stop() is only reached when
            // every step above completed without throwing.
            sc.stop();
            return;
        }
        throw new MatchError((Object)datasetArray);
    }

    /**
     * @return the command-line option parser built in the constructor
     */
    public OptionParser<Params> parser() {
        return this.parser;
    }

    /**
     * Publishes the singleton, fixes {@code featureNum} at 39 and registers the six
     * command-line options. Every option's action returns a copy of the incoming
     * {@code Params} with exactly one constructor position replaced by the parsed
     * value; the other five positions are filled from the {@code copy$default$N}
     * accessors of the same instance.
     */
    private xgbClassifierTrainingExampleOnCriteoClickLogsDataset$() {
        MODULE$ = this;
        this.featureNum = 39;
        this.parser = new OptionParser<Params>(){
            {
                // 'i' / "trainingDataPath": required string, replaces copy position 1.
                this.opt('i', "trainingDataPath", Read$.MODULE$.stringRead()).text("trainingData Path").action(new Serializable(this){
                    public static final long serialVersionUID = 0L;

                    public final Params apply(String v, Params p) {
                        return p.copy(v, p.copy$default$2(), p.copy$default$3(), p.copy$default$4(), p.copy$default$5(), p.copy$default$6());
                    }
                }).required();
                // 's' / "modelSavePath": required string, replaces copy position 2.
                this.opt('s', "modelSavePath", Read$.MODULE$.stringRead()).text("savePath of model").action(new Serializable(this){
                    public static final long serialVersionUID = 0L;

                    public final Params apply(String v, Params p) {
                        String x$3 = v;
                        String x$4 = p.copy$default$1();
                        int x$5 = p.copy$default$3();
                        int x$6 = p.copy$default$4();
                        int x$7 = p.copy$default$5();
                        int x$8 = p.copy$default$6();
                        return p.copy(x$4, x$3, x$5, x$6, x$7, x$8);
                    }
                }).required();
                // 't' / "numThread": int, replaces copy position 3.
                this.opt('t', "numThread", Read$.MODULE$.intRead()).text("threads num").action(new Serializable(this){
                    public static final long serialVersionUID = 0L;

                    public final Params apply(int v, Params p) {
                        int x$9 = v;
                        String x$10 = p.copy$default$1();
                        String x$11 = p.copy$default$2();
                        int x$12 = p.copy$default$4();
                        int x$13 = p.copy$default$5();
                        int x$14 = p.copy$default$6();
                        return p.copy(x$10, x$11, x$9, x$12, x$13, x$14);
                    }
                });
                // 'r' / "numRound": int, replaces copy position 4.
                this.opt('r', "numRound", Read$.MODULE$.intRead()).text("Round num").action(new Serializable(this){
                    public static final long serialVersionUID = 0L;

                    public final Params apply(int v, Params p) {
                        int x$15 = v;
                        String x$16 = p.copy$default$1();
                        String x$17 = p.copy$default$2();
                        int x$18 = p.copy$default$3();
                        int x$19 = p.copy$default$5();
                        int x$20 = p.copy$default$6();
                        return p.copy(x$16, x$17, x$18, x$15, x$19, x$20);
                    }
                });
                // 'd' / "maxDepth": int, replaces copy position 5.
                this.opt('d', "maxDepth", Read$.MODULE$.intRead()).text("maxDepth").action(new Serializable(this){
                    public static final long serialVersionUID = 0L;

                    public final Params apply(int v, Params p) {
                        int x$21 = v;
                        String x$22 = p.copy$default$1();
                        String x$23 = p.copy$default$2();
                        int x$24 = p.copy$default$3();
                        int x$25 = p.copy$default$4();
                        int x$26 = p.copy$default$6();
                        return p.copy(x$22, x$23, x$24, x$25, x$21, x$26);
                    }
                });
                // 'w' / "numWorkers": int, replaces copy position 6.
                this.opt('w', "numWorkers", Read$.MODULE$.intRead()).text("Workers num").action(new Serializable(this){
                    public static final long serialVersionUID = 0L;

                    public final Params apply(int v, Params p) {
                        int x$27 = v;
                        String x$28 = p.copy$default$1();
                        String x$29 = p.copy$default$2();
                        int x$30 = p.copy$default$3();
                        int x$31 = p.copy$default$4();
                        int x$32 = p.copy$default$5();
                        return p.copy(x$28, x$29, x$30, x$31, x$32, x$27);
                    }
                });
            }
        };
    }
}

