/*
 * Decompiled with CFR 0.152.
 */
package com.intel.analytics.bigdl.dllib.example.nnframes.lightGBM;

import com.intel.analytics.bigdl.dllib.NNContext$;
import com.intel.analytics.bigdl.dllib.example.nnframes.lightGBM.Utils;
import com.intel.analytics.bigdl.dllib.example.nnframes.lightGBM.Utils$;
import com.intel.analytics.bigdl.dllib.example.nnframes.lightGBM.Utils$LGBMParams$;
import com.intel.analytics.bigdl.dllib.nnframes.LightGBMClassifier;
import com.intel.analytics.bigdl.dllib.nnframes.LightGBMClassifier$;
import com.intel.analytics.bigdl.dllib.nnframes.LightGBMClassifierModel;
import org.apache.spark.SparkContext;
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator;
import org.apache.spark.ml.feature.StringIndexer;
import org.apache.spark.ml.feature.StringIndexerModel;
import org.apache.spark.ml.feature.VectorAssembler;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.sql.SQLContext$;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DoubleType$;
import org.apache.spark.sql.types.StringType$;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructField$;
import org.apache.spark.sql.types.StructType;
import scala.Array$;
import scala.Function1;
import scala.MatchError;
import scala.Option;
import scala.Predef$;
import scala.Serializable;
import scala.Tuple2;
import scala.collection.Seq;
import scala.collection.SeqLike;
import scala.runtime.BoxesRunTime;

public final class LgbmClassifierTrain$ {
    public static final LgbmClassifierTrain$ MODULE$;

    static {
        new LgbmClassifierTrain$();
    }

    public void main(String[] args) {
        Utils.LGBMParams defaultParams = new Utils.LGBMParams(Utils$LGBMParams$.MODULE$.apply$default$1(), Utils$LGBMParams$.MODULE$.apply$default$2(), Utils$LGBMParams$.MODULE$.apply$default$3(), Utils$LGBMParams$.MODULE$.apply$default$4(), Utils$LGBMParams$.MODULE$.apply$default$5(), Utils$LGBMParams$.MODULE$.apply$default$6(), Utils$LGBMParams$.MODULE$.apply$default$7(), Utils$LGBMParams$.MODULE$.apply$default$8(), Utils$LGBMParams$.MODULE$.apply$default$9(), Utils$LGBMParams$.MODULE$.apply$default$10(), Utils$LGBMParams$.MODULE$.apply$default$11());
        Utils$.MODULE$.parser().parse((Seq<String>)Predef$.MODULE$.wrapRefArray((Object[])args), defaultParams).foreach((Function1)new Serializable(){
            public static final long serialVersionUID = 0L;

            public final void apply(Utils.LGBMParams params) {
                SparkContext sc = NNContext$.MODULE$.initNNContext("LGBM example");
                SQLContext spark = SQLContext$.MODULE$.getOrCreate(sc);
                StructType schema = new StructType((StructField[])((Object[])new StructField[]{new StructField("sepal length", (DataType)DoubleType$.MODULE$, true, StructField$.MODULE$.apply$default$4()), new StructField("sepal width", (DataType)DoubleType$.MODULE$, true, StructField$.MODULE$.apply$default$4()), new StructField("petal length", (DataType)DoubleType$.MODULE$, true, StructField$.MODULE$.apply$default$4()), new StructField("petal width", (DataType)DoubleType$.MODULE$, true, StructField$.MODULE$.apply$default$4()), new StructField("class", (DataType)StringType$.MODULE$, true, StructField$.MODULE$.apply$default$4())}));
                Dataset df = spark.read().schema(schema).csv(params.inputPath()).repartition(params.nPartition());
                StringIndexerModel stringIndexer = new StringIndexer().setInputCol("class").setOutputCol("classIndex").fit(df);
                Dataset labelTransformed = stringIndexer.transform(df).drop("class");
                VectorAssembler vectorAssembler = new VectorAssembler().setInputCols((String[])((Object[])new String[]{"sepal length", "sepal width", "petal length", "petal width"})).setOutputCol("features");
                Dataset dfinput = vectorAssembler.transform(labelTransformed).select("features", (Seq)Predef$.MODULE$.wrapRefArray((Object[])new String[]{"classIndex"}));
                Dataset[] datasetArray = dfinput.randomSplit(new double[]{0.8, 0.2});
                Option option = Array$.MODULE$.unapplySeq((Object)datasetArray);
                if (!option.isEmpty() && option.get() != null && ((SeqLike)option.get()).lengthCompare(2) == 0) {
                    Tuple2 tuple2;
                    Dataset train2 = (Dataset)((SeqLike)option.get()).apply(0);
                    Dataset test2 = (Dataset)((SeqLike)option.get()).apply(1);
                    Tuple2 tuple22 = tuple2 = new Tuple2((Object)train2, (Object)test2);
                    Dataset train3 = (Dataset)tuple22._1();
                    Dataset test3 = (Dataset)tuple22._2();
                    LightGBMClassifier classifier = new LightGBMClassifier(LightGBMClassifier$.MODULE$.$lessinit$greater$default$1());
                    classifier.setFeaturesCol("features");
                    classifier.setLabelCol("classIndex");
                    classifier.setNumIterations(params.numIterations());
                    classifier.setNumLeaves(params.numLeaves());
                    classifier.setMaxDepth(params.maxDepth());
                    classifier.setLambdaL1(params.lamda1());
                    classifier.setLambdaL2(params.lamda2());
                    classifier.setBaggingFreq(params.bagFreq());
                    classifier.setMaxBin(params.maxBin());
                    classifier.setNumIterations(params.numIterations());
                    LightGBMClassifierModel model = classifier.fit((Dataset<Row>)train3);
                    Dataset<Row> predictions = model.transform((Dataset<Row>)test3);
                    predictions.show(10);
                    MulticlassClassificationEvaluator evaluatorMulti = new MulticlassClassificationEvaluator().setLabelCol("classIndex").setMetricName("accuracy");
                    double acc = evaluatorMulti.evaluate(predictions);
                    Predef$.MODULE$.println((Object)new Tuple2((Object)"acc:", (Object)BoxesRunTime.boxToDouble((double)acc)));
                    model.saveNativeModel(params.modelSavePath());
                    sc.stop();
                    return;
                }
                throw new MatchError((Object)datasetArray);
            }
        });
    }

    private LgbmClassifierTrain$() {
        MODULE$ = this;
    }
}

