/*
 * Decompiled with CFR 0.152.
 */
package com.intel.analytics.bigdl.dllib.example.nnframes.xgboost;

import com.intel.analytics.bigdl.dllib.NNContext$;
import com.intel.analytics.bigdl.dllib.nnframes.XGBClassifier;
import com.intel.analytics.bigdl.dllib.nnframes.XGBClassifierModel;
import ml.dmlc.xgboost4j.scala.spark.TrackerConf;
import org.apache.spark.SparkContext;
import org.apache.spark.ml.feature.StringIndexer;
import org.apache.spark.ml.feature.StringIndexerModel;
import org.apache.spark.ml.feature.VectorAssembler;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.sql.SQLContext$;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DoubleType$;
import org.apache.spark.sql.types.StringType$;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructField$;
import org.apache.spark.sql.types.StructType;
import scala.Array$;
import scala.MatchError;
import scala.Option;
import scala.Predef;
import scala.Predef$;
import scala.Tuple2;
import scala.Tuple4;
import scala.collection.Seq;
import scala.collection.SeqLike;
import scala.collection.immutable.Map;
import scala.collection.immutable.StringOps;
import scala.sys.package$;

/**
 * Example program that trains an {@link XGBClassifier} on a CSV file and saves
 * the fitted model.
 *
 * <p>This class is the Java view of a Scala singleton object: the only instance
 * is {@link #MODULE$}, and {@link #main(String[])} is an instance method that is
 * invoked on it.
 */
public final class xgbClassifierTrainingExample$ {
    /** The singleton instance. Initialized at the declaration so the field is a valid static final. */
    public static final xgbClassifierTrainingExample$ MODULE$ = new xgbClassifierTrainingExample$();

    /** Names of the four numeric input columns, in CSV column order. */
    private static final String[] FEATURE_COLUMNS = {"sepal length", "sepal width", "petal length", "petal width"};

    /** Name of the string label column read from the CSV. */
    private static final String CLASS_COLUMN = "class";

    /** Name of the numeric label column produced by the string indexer. */
    private static final String LABEL_COLUMN = "classIndex";

    /** Name of the assembled feature-vector column. */
    private static final String FEATURES_COLUMN = "features";

    private xgbClassifierTrainingExample$() {
    }

    /**
     * Reads the input CSV, trains an XGBoost classifier on the first of four
     * random splits and writes the fitted model to {@code modelsavePath}.
     *
     * <p>The numeric arguments are parsed before the Spark context is created,
     * and the context is stopped in a {@code finally} block, so a failure at any
     * stage does not leave a running context behind.
     *
     * @param args {@code inputPath numThreads numRound modelsavePath}; when fewer
     *             than four arguments are given a usage line is printed and the
     *             program exits with status 1
     * @throws NumberFormatException if {@code numThreads} or {@code numRound} is not an integer
     * @throws MatchError if the random split does not yield exactly four datasets
     */
    public void main(String[] args) {
        if (args.length < 4) {
            Predef$.MODULE$.println((Object)"Usage: program inputPath numThreads numRound modelsavePath");
            throw package$.MODULE$.exit(1);
        }
        String inputPath = args[0];
        // Parse before touching Spark so bad input fails fast without a live context.
        int numThreads = new StringOps(Predef$.MODULE$.augmentString(args[1])).toInt();
        int numRound = new StringOps(Predef$.MODULE$.augmentString(args[2])).toInt();
        String modelsavePath = args[3];

        SparkContext sc = NNContext$.MODULE$.initNNContext();
        try {
            SQLContext spark = SQLContext$.MODULE$.getOrCreate(sc);
            Dataset<Row> xgbInput = loadXgbInput(spark, inputPath);

            // Four-way split: training set, two evaluation sets, and a fourth
            // split that this example does not use.
            Dataset<Row>[] splits = xgbInput.randomSplit(new double[]{0.6, 0.2, 0.1, 0.1});
            if (splits == null || splits.length != 4) {
                throw new MatchError((Object)splits);
            }
            Dataset<Row> train = splits[0];
            Dataset<Row> eval1 = splits[1];
            Dataset<Row> eval2 = splits[2];

            XGBClassifier xgbClassifier = newClassifier(eval1, eval2, numThreads, numRound);
            XGBClassifierModel xgbClassificationModel = xgbClassifier.fit(train);
            xgbClassificationModel.save(modelsavePath);
        } finally {
            sc.stop();
        }
    }

    /**
     * Loads the CSV at {@code inputPath} with the fixed schema, indexes the
     * string class column into {@code classIndex}, assembles the four numeric
     * columns into a {@code features} vector and keeps only those two columns.
     */
    private static Dataset<Row> loadXgbInput(SQLContext spark, String inputPath) {
        Dataset<Row> df = spark.read().schema(inputSchema()).csv(inputPath);
        StringIndexerModel stringIndexer = new StringIndexer()
                .setInputCol(CLASS_COLUMN)
                .setOutputCol(LABEL_COLUMN)
                .fit(df);
        Dataset<Row> labelTransformed = stringIndexer.transform(df).drop(CLASS_COLUMN);
        VectorAssembler vectorAssembler = new VectorAssembler()
                .setInputCols(FEATURE_COLUMNS)
                .setOutputCol(FEATURES_COLUMN);
        return vectorAssembler.transform(labelTransformed).select(FEATURES_COLUMN, LABEL_COLUMN);
    }

    /** Builds the CSV schema: four nullable double columns followed by a nullable string class column. */
    private static StructType inputSchema() {
        StructField[] fields = new StructField[FEATURE_COLUMNS.length + 1];
        for (int i = 0; i < FEATURE_COLUMNS.length; i++) {
            fields[i] = nullableField(FEATURE_COLUMNS[i], (DataType)DoubleType$.MODULE$);
        }
        fields[FEATURE_COLUMNS.length] = nullableField(CLASS_COLUMN, (DataType)StringType$.MODULE$);
        return new StructType(fields);
    }

    /** Creates a nullable field with the default (fourth-argument) metadata. */
    private static StructField nullableField(String name, DataType type) {
        return new StructField(name, type, true, StructField$.MODULE$.apply$default$4());
    }

    /**
     * Creates the classifier with the tracker configuration, the two named
     * evaluation sets and the fixed training settings of this example.
     */
    @SuppressWarnings({"unchecked", "rawtypes"})
    private static XGBClassifier newClassifier(Dataset<Row> eval1, Dataset<Row> eval2, int numThreads, int numRound) {
        Map evalSets = scalaMap(
                new Tuple2<Object, Object>("eval1", eval1),
                new Tuple2<Object, Object>("eval2", eval2));
        Map xgbParam = scalaMap(
                new Tuple2<Object, Object>("tracker_conf", new TrackerConf(3600L, "scala")),
                new Tuple2<Object, Object>("eval_sets", evalSets));

        XGBClassifier xgbClassifier = new XGBClassifier((Map<String, Object>)xgbParam);
        xgbClassifier.setFeaturesCol(FEATURES_COLUMN);
        xgbClassifier.setLabelCol(LABEL_COLUMN);
        xgbClassifier.setNumClass(3);
        xgbClassifier.setMaxDepth(2);
        xgbClassifier.setNumWorkers(1);
        xgbClassifier.setNthread(numThreads);
        xgbClassifier.setNumRound(numRound);
        xgbClassifier.setTreeMethod("auto");
        xgbClassifier.setObjective("multi:softprob");
        xgbClassifier.setTimeoutRequestWorkers(180000L);
        return xgbClassifier;
    }

    /** Builds a Scala immutable map from key/value pairs (the Java spelling of {@code Map(k -> v, ...)}). */
    @SuppressWarnings({"unchecked", "rawtypes"})
    private static Map scalaMap(Tuple2... entries) {
        return (Map)Predef$.MODULE$.Map().apply((Seq)Predef$.MODULE$.wrapRefArray((Object[])entries));
    }
}

