/*
 * Decompiled with CFR 0.152.
 */
package org.apache.spark.ml.optim.aggregator;

import java.io.Serializable;
import org.apache.spark.broadcast.Broadcast;
import org.apache.spark.ml.feature.InstanceBlock;
import org.apache.spark.ml.linalg.BLAS$;
import org.apache.spark.ml.linalg.DenseMatrix;
import org.apache.spark.ml.linalg.DenseVector;
import org.apache.spark.ml.linalg.DenseVector$;
import org.apache.spark.ml.linalg.Matrix;
import org.apache.spark.ml.linalg.SparseMatrix;
import org.apache.spark.ml.linalg.Vector;
import org.apache.spark.ml.linalg.Vectors$;
import org.apache.spark.ml.optim.aggregator.DifferentiableLossAggregator;
import scala.Array$;
import scala.Function0;
import scala.Function1;
import scala.MatchError;
import scala.Option;
import scala.Predef$;
import scala.collection.mutable.ArrayOps;
import scala.math.Numeric;
import scala.math.package$;
import scala.reflect.ClassTag$;
import scala.reflect.ScalaSignature;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;
import scala.runtime.java8.JFunction0;
import scala.runtime.java8.JFunction1;

@ScalaSignature(bytes="\u0006\u0001u3Qa\u0003\u0007\u0001!aA\u0001B\u000b\u0001\u0003\u0002\u0003\u0006I\u0001\f\u0005\t_\u0001\u0011\t\u0011)A\u0005a!)A\b\u0001C\u0001{!9\u0011\t\u0001b\u0001\n#\u0012\u0005B\u0002$\u0001A\u0003%1\tC\u0004H\u0001\t\u0007I\u0011\u0002\"\t\r!\u0003\u0001\u0015!\u0003D\u0011!I\u0005\u0001#b\u0001\n\u0013Q\u0005\u0002C+\u0001\u0011\u000b\u0007I\u0011\u0002,\t\u000ba\u0003A\u0011A-\u0003%\tcwnY6B\rR\u000bum\u001a:fO\u0006$xN\u001d\u0006\u0003\u001b9\t!\"Y4he\u0016<\u0017\r^8s\u0015\ty\u0001#A\u0003paRLWN\u0003\u0002\u0012%\u0005\u0011Q\u000e\u001c\u0006\u0003'Q\tQa\u001d9be.T!!\u0006\f\u0002\r\u0005\u0004\u0018m\u00195f\u0015\u00059\u0012aA8sON\u0019\u0001!G\u0010\u0011\u0005iiR\"A\u000e\u000b\u0003q\tQa]2bY\u0006L!AH\u000e\u0003\r\u0005s\u0017PU3g!\u0011\u0001\u0013eI\u0015\u000e\u00031I!A\t\u0007\u00039\u0011KgMZ3sK:$\u0018.\u00192mK2{7o]!hOJ,w-\u0019;peB\u0011AeJ\u0007\u0002K)\u0011a\u0005E\u0001\bM\u0016\fG/\u001e:f\u0013\tASEA\u0007J]N$\u0018M\\2f\u00052|7m\u001b\t\u0003A\u0001\tABZ5u\u0013:$XM]2faR\u001c\u0001\u0001\u0005\u0002\u001b[%\u0011af\u0007\u0002\b\u0005>|G.Z1o\u00039\u00117mQ8fM\u001aL7-[3oiN\u00042!\r\u001b7\u001b\u0005\u0011$BA\u001a\u0013\u0003%\u0011'o\\1eG\u0006\u001cH/\u0003\u00026e\tI!I]8bI\u000e\f7\u000f\u001e\t\u0003oij\u0011\u0001\u000f\u0006\u0003sA\ta\u0001\\5oC2<\u0017BA\u001e9\u0005\u00191Vm\u0019;pe\u00061A(\u001b8jiz\"\"A\u0010!\u0015\u0005%z\u0004\"B\u0018\u0004\u0001\u0004\u0001\u0004\"\u0002\u0016\u0004\u0001\u0004a\u0013a\u00013j[V\t1\t\u0005\u0002\u001b\t&\u0011Qi\u0007\u0002\u0004\u0013:$\u0018\u0001\u00023j[\u0002\n1B\\;n\r\u0016\fG/\u001e:fg\u0006aa.^7GK\u0006$XO]3tA\u0005\t2m\\3gM&\u001c\u0017.\u001a8ug\u0006\u0013(/Y=\u0016\u0003-\u00032A\u0007'O\u0013\ti5DA\u0003BeJ\f\u0017\u0010\u0005\u0002\u001b\u001f&\u0011\u0001k\u0007\u0002\u0007\t>,(\r\\3)\u0005!\u0011\u0006C\u0001\u000eT\u0013\t!6DA\u0005ue\u0006t7/[3oi\u00061A.\u001b8fCJ,\u0012A\u000e\u0015\u0003\u0013I\u000b1!\u00193e)\tQ6,D\u0001\u0001
\u0011\u0015a&\u00021\u0001$\u0003\u0015\u0011Gn\\2l\u0001")
/*
 * Accumulates the loss and gradient of an accelerated failure time (AFT) survival regression
 * objective over blocks of instances ({@link InstanceBlock}).
 *
 * Per instance i, with label t_i > 0, censor indicator delta_i, margin m_i and scale sigma,
 * the accumulated loss term is
 *     delta_i * log(sigma) - delta_i * eps_i + exp(eps_i),   eps_i = (log(t_i) - m_i) / sigma,
 * which presumably is the Weibull-AFT negative log-likelihood with coefficient-independent
 * terms dropped (TODO confirm against the Scala source / Spark docs).
 *
 * Layout of the broadcast coefficient vector (length dim):
 *   [0, dim - 2) : linear coefficients, one per feature (numFeatures = dim - 2);
 *   dim - 2      : intercept (added to margins and its gradient accumulated only if fitIntercept);
 *   dim - 1      : log(sigma), the log of the scale parameter.
 * The broadcast vector must be a DenseVector; otherwise the first access to the coefficients
 * throws IllegalArgumentException.
 *
 * NOTE(review): this is decompiled Scala. The *$lzycompute methods and bitmap$* fields are the
 * compiler's encoding of Scala lazy vals, and the *$ / $init$ calls forward to trait defaults;
 * keep their names and ordering unless the Scala source changes too.
 */
public class BlockAFTAggregator
implements DifferentiableLossAggregator<InstanceBlock, BlockAFTAggregator> {
    /** Lazily-extracted raw coefficient array (transient: rebuilt from the broadcast). */
    private transient double[] coefficientsArray;
    /** Lazily-built vector of the first numFeatures coefficients (the linear part). */
    private transient Vector linear;
    private final boolean fitIntercept;
    private final Broadcast<Vector> bcCoefficients;
    /** Total number of coefficients, i.e. the size of the broadcast vector. */
    private final int dim;
    /** Number of features, equal to dim - 2 (intercept and log-sigma occupy the last two slots). */
    private final int numFeatures;
    private double weightSum;
    private double lossSum;
    private double[] gradientSumArray;
    /** Init flags for the transient lazy vals: bit 1 = coefficientsArray, bit 2 = linear. */
    private volatile transient byte bitmap$trans$0;
    /** Init flag for the lazily-created gradientSumArray. */
    private volatile boolean bitmap$0;

    /** Delegates to the {@code DifferentiableLossAggregator} trait's default {@code merge}. */
    @Override
    public DifferentiableLossAggregator merge(DifferentiableLossAggregator other) {
        return DifferentiableLossAggregator.merge$(this, other);
    }

    /** Delegates to the {@code DifferentiableLossAggregator} trait's default {@code gradient}. */
    @Override
    public Vector gradient() {
        return DifferentiableLossAggregator.gradient$(this);
    }

    /** Delegates to the {@code DifferentiableLossAggregator} trait's default {@code weight}. */
    @Override
    public double weight() {
        return DifferentiableLossAggregator.weight$(this);
    }

    /** Delegates to the {@code DifferentiableLossAggregator} trait's default {@code loss}. */
    @Override
    public double loss() {
        return DifferentiableLossAggregator.loss$(this);
    }

    /** Returns the accumulated weight (here: the number of instances added). */
    @Override
    public double weightSum() {
        return this.weightSum;
    }

    /** Scala setter for {@code weightSum}. */
    @Override
    public void weightSum_$eq(double x$1) {
        this.weightSum = x$1;
    }

    /** Returns the accumulated (unnormalized) loss. */
    @Override
    public double lossSum() {
        return this.lossSum;
    }

    /** Scala setter for {@code lossSum}. */
    @Override
    public void lossSum_$eq(double x$1) {
        this.lossSum = x$1;
    }

    /**
     * Slow path of the lazy {@code gradientSumArray}: under the instance lock, initializes the
     * array once via the trait default. Presumably that allocates a zero array of length
     * {@code dim} (TODO confirm).
     */
    private double[] gradientSumArray$lzycompute() {
        BlockAFTAggregator blockAFTAggregator = this;
        synchronized (blockAFTAggregator) {
            if (!this.bitmap$0) {
                this.gradientSumArray = DifferentiableLossAggregator.gradientSumArray$(this);
                this.bitmap$0 = true;
            }
        }
        return this.gradientSumArray;
    }

    /** Returns the gradient accumulator, initializing it on first access. */
    @Override
    public double[] gradientSumArray() {
        return !this.bitmap$0 ? this.gradientSumArray$lzycompute() : this.gradientSumArray;
    }

    /** Returns the number of coefficients (features + intercept + log-sigma). */
    @Override
    public int dim() {
        return this.dim;
    }

    private int numFeatures() {
        return this.numFeatures;
    }

    /**
     * Slow path of the lazy {@code coefficientsArray}: extracts the backing array of the broadcast
     * coefficients (no copy is made here).
     *
     * @throws IllegalArgumentException if the broadcast coefficients are not a DenseVector
     */
    private double[] coefficientsArray$lzycompute() {
        BlockAFTAggregator blockAFTAggregator = this;
        synchronized (blockAFTAggregator) {
            if ((byte)(this.bitmap$trans$0 & 1) == 0) {
                Option option;
                Vector vector = (Vector)this.bcCoefficients.value();
                if (!(vector instanceof DenseVector) || (option = DenseVector$.MODULE$.unapply((DenseVector)vector)).isEmpty()) {
                    throw new IllegalArgumentException(new StringBuilder(54).append("coefficients only supports dense vector").append(" but got type ").append(this.bcCoefficients.value().getClass()).append(".").toString());
                }
                this.coefficientsArray = (double[])option.get();
                this.bitmap$trans$0 = (byte)(this.bitmap$trans$0 | 1);
            }
        }
        return this.coefficientsArray;
    }

    private double[] coefficientsArray() {
        return (byte)(this.bitmap$trans$0 & 1) == 0 ? this.coefficientsArray$lzycompute() : this.coefficientsArray;
    }

    /** Slow path of the lazy {@code linear}: a dense copy of the first numFeatures coefficients. */
    private Vector linear$lzycompute() {
        BlockAFTAggregator blockAFTAggregator = this;
        synchronized (blockAFTAggregator) {
            if ((byte)(this.bitmap$trans$0 & 2) == 0) {
                this.linear = Vectors$.MODULE$.dense((double[])new ArrayOps.ofDouble(Predef$.MODULE$.doubleArrayOps(this.coefficientsArray())).take(this.numFeatures()));
                this.bitmap$trans$0 = (byte)(this.bitmap$trans$0 | 2);
            }
        }
        return this.linear;
    }

    private Vector linear() {
        return (byte)(this.bitmap$trans$0 & 2) == 0 ? this.linear$lzycompute() : this.linear;
    }

    /**
     * Adds one block of instances to the running loss and gradient sums.
     *
     * <p>The block's per-instance weight slot is read as the censor indicator {@code delta}
     * (presumably 1.0 = event observed, 0.0 = censored; confirm against the caller).
     *
     * @param block instances to add; its matrix must be transposed, its feature count must
     *     equal {@code dim - 2}, and every label must be strictly positive
     * @return this aggregator, for chaining
     * @throws IllegalArgumentException if any of the requirements above is violated
     */
    @Override
    public BlockAFTAggregator add(InstanceBlock block) {
        Predef$.MODULE$.require(block.matrix().isTransposed());
        Predef$.MODULE$.require(this.numFeatures() == block.numFeatures(), (Function0 & Serializable & scala.Serializable)() -> new StringBuilder(66).append("Dimensions mismatch when adding new ").append("instance. Expecting ").append(this.numFeatures()).append(" but got ").append(block.numFeatures()).append(".").toString());
        Predef$.MODULE$.require(new ArrayOps.ofDouble(Predef$.MODULE$.doubleArrayOps(block.labels())).forall((Function1)(JFunction1.mcZD.sp & Serializable & scala.Serializable)x$1 -> x$1 > 0.0), (Function0 & Serializable & scala.Serializable)() -> "The lifetime or label should be  greater than 0.");
        int size = block.size();
        double intercept = this.coefficientsArray()[this.dim() - 2];
        // The last coefficient stores log(sigma); exponentiate to recover the scale parameter.
        double sigma = package$.MODULE$.exp(this.coefficientsArray()[this.dim() - 1]);
        // Loop-invariant: compute log(sigma) once per block rather than once per instance.
        double logSigma = package$.MODULE$.log(sigma);
        // vec starts as the per-instance margins: intercept (if fitted) + X * linear.
        DenseVector vec = this.fitIntercept ? Vectors$.MODULE$.dense((double[])Array$.MODULE$.fill(size, (Function0)(JFunction0.mcD.sp & Serializable & scala.Serializable)() -> intercept, ClassTag$.MODULE$.Double())).toDense() : Vectors$.MODULE$.zeros(size).toDense();
        BLAS$.MODULE$.gemv(1.0, block.matrix(), this.linear(), 1.0, vec);
        double localLossSum = 0.0;
        double sigmaGradSum = 0.0;
        for (int i = 0; i < size; ++i) {
            double multiplier;
            double ti = block.getLabel(i);
            double delta = block.getWeight().apply$mcDI$sp(i);
            double margin = vec.apply(i);
            double epsilon = (package$.MODULE$.log(ti) - margin) / sigma;
            double expEpsilon = package$.MODULE$.exp(epsilon);
            localLossSum += delta * logSigma - delta * epsilon + expEpsilon;
            // Overwrite the margin in place with d(loss_i)/d(margin_i); from here on vec holds
            // gradient multipliers rather than margins.
            vec.values()[i] = multiplier = (delta - expEpsilon) / sigma;
            // d(loss_i)/d(log sigma) = delta + eps * (delta - exp(eps)).
            sigmaGradSum += delta + multiplier * sigma * epsilon;
        }
        this.lossSum_$eq(this.lossSum() + localLossSum);
        // Every instance counts as weight 1 here; the weight slot was consumed as the censor above.
        this.weightSum_$eq(this.weightSum() + (double)size);
        // Linear-coefficient gradient: gradientSumArray[0, numFeatures) += X^T * vec.
        Matrix matrix = block.matrix();
        if (matrix instanceof DenseMatrix) {
            DenseMatrix denseMatrix = (DenseMatrix)matrix;
            // A transposed DenseMatrix stores its numRows x numCols values row-major. Read as a
            // column-major numCols x numRows matrix, that buffer is X^T, so plain "N" gemv
            // computes X^T * vec without materializing a transpose.
            BLAS$.MODULE$.nativeBLAS().dgemv("N", denseMatrix.numCols(), denseMatrix.numRows(), 1.0, denseMatrix.values(), denseMatrix.numCols(), vec.values(), 1, 1.0, this.gradientSumArray(), 1);
        } else if (matrix instanceof SparseMatrix) {
            SparseMatrix sparseMatrix = (SparseMatrix)matrix;
            DenseVector linearGradSumVec = Vectors$.MODULE$.zeros(this.numFeatures()).toDense();
            BLAS$.MODULE$.gemv(1.0, (Matrix)sparseMatrix.transpose(), (Vector)vec, 0.0, linearGradSumVec);
            // Add only into the first numFeatures slots; the last two are handled below.
            BLAS$.MODULE$.getBLAS(this.numFeatures()).daxpy(this.numFeatures(), 1.0, linearGradSumVec.values(), 1, this.gradientSumArray(), 1);
        } else {
            throw new MatchError((Object)matrix);
        }
        if (this.fitIntercept) {
            // The margin's derivative w.r.t. the intercept is 1, so its gradient is sum(vec).
            int n = this.dim() - 2;
            this.gradientSumArray()[n] = this.gradientSumArray()[n] + BoxesRunTime.unboxToDouble((Object)new ArrayOps.ofDouble(Predef$.MODULE$.doubleArrayOps(vec.values())).sum((Numeric)Numeric.DoubleIsFractional$.MODULE$));
        }
        int n = this.dim() - 1;
        this.gradientSumArray()[n] = this.gradientSumArray()[n] + sigmaGradSum;
        return this;
    }

    /**
     * Creates an aggregator for the given broadcast coefficients.
     *
     * @param fitIntercept whether the intercept (slot {@code dim - 2}) is used and its gradient
     *     accumulated
     * @param bcCoefficients broadcast coefficients laid out as described on the class; must be
     *     dense (checked lazily on first {@link #add})
     */
    public BlockAFTAggregator(boolean fitIntercept, Broadcast<Vector> bcCoefficients) {
        this.fitIntercept = fitIntercept;
        this.bcCoefficients = bcCoefficients;
        // Trait initializer runs before dim is assigned (Scala init order); presumably it only
        // sets the trait's own state (weightSum/lossSum). TODO confirm.
        DifferentiableLossAggregator.$init$(this);
        this.dim = ((Vector)bcCoefficients.value()).size();
        this.numFeatures = this.dim() - 2;
    }
}

