/*
 * Decompiled with CFR 0.152.
 */
package com.intel.analytics.bigdl.dllib.nn.quantized;

import com.intel.analytics.bigdl.dllib.nn.abstractnn.AbstractModule;
import com.intel.analytics.bigdl.dllib.nn.abstractnn.Activity;
import com.intel.analytics.bigdl.dllib.nn.quantized.Quantizer$;
import com.intel.analytics.bigdl.dllib.nn.quantized.Utils$;
import com.intel.analytics.bigdl.dllib.tensor.Tensor;
import com.intel.analytics.bigdl.dllib.tensor.TensorNumericMath;
import com.intel.analytics.bigdl.dllib.utils.Log4Error$;
import scala.Function1;
import scala.MatchError;
import scala.Predef$;
import scala.Serializable;
import scala.StringContext;
import scala.Tuple2;
import scala.collection.Seq;
import scala.collection.immutable.Nil$;
import scala.math.Numeric;
import scala.math.Ordering;
import scala.reflect.ClassTag;
import scala.runtime.BoxesRunTime;
import scala.runtime.DoubleRef;
import scala.runtime.RichInt$;

public final class Quantization$ {
    public static final Quantization$ MODULE$;

    static {
        new Quantization$();
    }

    public float findMax(float[] src, int start2, int end) {
        return BoxesRunTime.unboxToFloat((Object)Predef$.MODULE$.floatArrayOps((float[])Predef$.MODULE$.floatArrayOps(src).slice(start2, end)).max((Ordering)Ordering.Float$.MODULE$));
    }

    public float findMin(float[] src, int start2, int end) {
        return BoxesRunTime.unboxToFloat((Object)Predef$.MODULE$.floatArrayOps((float[])Predef$.MODULE$.floatArrayOps(src).slice(start2, end)).min((Ordering)Ordering.Float$.MODULE$));
    }

    public byte quantize(float value2, float max2, float min2) {
        return (byte)Math.round(1.0 * (double)value2 / (double)Math.max(Math.abs(max2), Math.abs(min2)) * (double)127);
    }

    public float dequantize(byte by, float max2, float min2) {
        return (float)by / (float)127 * Math.max(Math.abs(max2), Math.abs(min2));
    }

    public Tuple2<Object, Object> quantize(float[] src, int start2, int end, byte[] dst, int dstOffset) {
        float max2 = this.findMax(src, start2, end);
        float min2 = this.findMin(src, start2, end);
        RichInt$.MODULE$.until$extension0(Predef$.MODULE$.intWrapper(0), end - start2).foreach$mVc$sp((Function1)new Serializable(src, start2, dst, dstOffset, max2, min2){
            public static final long serialVersionUID = 0L;
            private final float[] src$1;
            private final int start$1;
            private final byte[] dst$1;
            private final int dstOffset$1;
            private final float max$1;
            private final float min$1;

            public final void apply(int i) {
                this.apply$mcVI$sp(i);
            }

            public void apply$mcVI$sp(int i) {
                this.dst$1[this.dstOffset$1 + i] = Quantization$.MODULE$.quantize(this.src$1[this.start$1 + i], this.max$1, this.min$1);
            }
            {
                this.src$1 = src$1;
                this.start$1 = start$1;
                this.dst$1 = dst$1;
                this.dstOffset$1 = dstOffset$1;
                this.max$1 = max$1;
                this.min$1 = min$1;
            }
        });
        return new Tuple2((Object)BoxesRunTime.boxToFloat((float)max2), (Object)BoxesRunTime.boxToFloat((float)min2));
    }

    public void dequantize(float[] src, int start2, int end, byte[] dst, int dstOffset, float max2, float min2) {
        Log4Error$.MODULE$.invalidInputError(src.length >= end, new StringContext((Seq)Predef$.MODULE$.wrapRefArray((Object[])new String[]{"you write too much elements"})).s((Seq)Nil$.MODULE$), Log4Error$.MODULE$.invalidInputError$default$3());
        RichInt$.MODULE$.until$extension0(Predef$.MODULE$.intWrapper(0), end - start2).foreach$mVc$sp((Function1)new Serializable(src, start2, dst, dstOffset, max2, min2){
            public static final long serialVersionUID = 0L;
            private final float[] src$3;
            private final int start$3;
            private final byte[] dst$3;
            private final int dstOffset$3;
            private final float max$2;
            private final float min$2;

            public final void apply(int i) {
                this.apply$mcVI$sp(i);
            }

            public void apply$mcVI$sp(int i) {
                this.src$3[this.start$3 + i] = Quantization$.MODULE$.dequantize(this.dst$3[this.dstOffset$3 + i], this.max$2, this.min$2);
            }
            {
                this.src$3 = src$3;
                this.start$3 = start$3;
                this.dst$3 = dst$3;
                this.dstOffset$3 = dstOffset$3;
                this.max$2 = max$2;
                this.min$2 = min$2;
            }
        });
    }

    public Tuple2<float[], float[]> quantize(float[] src, int start2, int end, byte[] dst, int dstOffset, int[] size) {
        Log4Error$.MODULE$.invalidInputError(size.length == 2, new StringContext((Seq)Predef$.MODULE$.wrapRefArray((Object[])new String[]{"only support 2-dim matrix"})).s((Seq)Nil$.MODULE$), Log4Error$.MODULE$.invalidInputError$default$3());
        Log4Error$.MODULE$.invalidInputError(BoxesRunTime.unboxToInt((Object)Predef$.MODULE$.intArrayOps(size).product((Numeric)Numeric.IntIsIntegral$.MODULE$)) == end - start2, new StringContext((Seq)Predef$.MODULE$.wrapRefArray((Object[])new String[]{"number of elements does not match"})).s((Seq)Nil$.MODULE$), Log4Error$.MODULE$.invalidInputError$default$3());
        int height = size[0];
        int width = size[1];
        float[] max2 = new float[height];
        float[] min2 = new float[height];
        RichInt$.MODULE$.until$extension0(Predef$.MODULE$.intWrapper(0), height).foreach$mVc$sp((Function1)new Serializable(src, start2, dst, dstOffset, width, max2, min2){
            public static final long serialVersionUID = 0L;
            private final float[] src$2;
            private final int start$2;
            private final byte[] dst$2;
            private final int dstOffset$2;
            private final int width$1;
            private final float[] max$4;
            private final float[] min$4;

            public final void apply(int i) {
                this.apply$mcVI$sp(i);
            }

            public void apply$mcVI$sp(int i) {
                Tuple2<Object, Object> maxAndMin = Quantization$.MODULE$.quantize(this.src$2, this.start$2 + i * this.width$1, this.start$2 + (i + 1) * this.width$1, this.dst$2, this.dstOffset$2 + i * this.width$1);
                this.max$4[i] = BoxesRunTime.unboxToFloat((Object)maxAndMin._1());
                this.min$4[i] = BoxesRunTime.unboxToFloat((Object)maxAndMin._2());
            }
            {
                this.src$2 = src$2;
                this.start$2 = start$2;
                this.dst$2 = dst$2;
                this.dstOffset$2 = dstOffset$2;
                this.width$1 = width$1;
                this.max$4 = max$4;
                this.min$4 = min$4;
            }
        });
        return new Tuple2((Object)max2, (Object)min2);
    }

    public void dequantize(float[] data2, int start2, int end, byte[] quantizedData, int offset, float[] max2, float[] min2, int[] size) {
        Log4Error$.MODULE$.invalidInputError(max2.length == min2.length, new StringContext((Seq)Predef$.MODULE$.wrapRefArray((Object[])new String[]{"the number of max doesn't match with the number of min"})).s((Seq)Nil$.MODULE$), Log4Error$.MODULE$.invalidInputError$default$3());
        Log4Error$.MODULE$.invalidInputError(size.length == 2, new StringContext((Seq)Predef$.MODULE$.wrapRefArray((Object[])new String[]{"only support 2-dim matrix"})).s((Seq)Nil$.MODULE$), Log4Error$.MODULE$.invalidInputError$default$3());
        Log4Error$.MODULE$.invalidInputError(max2.length == size[0], new StringContext((Seq)Predef$.MODULE$.wrapRefArray((Object[])new String[]{"the number of max(", ") doesn't match the size(", ")"})).s((Seq)Predef$.MODULE$.genericWrapArray((Object)new Object[]{BoxesRunTime.boxToInteger((int)max2.length), BoxesRunTime.boxToInteger((int)size[1])})), Log4Error$.MODULE$.invalidInputError$default$3());
        Log4Error$.MODULE$.invalidInputError(BoxesRunTime.unboxToInt((Object)Predef$.MODULE$.intArrayOps(size).product((Numeric)Numeric.IntIsIntegral$.MODULE$)) == end - start2, new StringContext((Seq)Predef$.MODULE$.wrapRefArray((Object[])new String[]{"number of elements does not match"})).s((Seq)Nil$.MODULE$), Log4Error$.MODULE$.invalidInputError$default$3());
        int height = size[0];
        int width = size[1];
        RichInt$.MODULE$.until$extension0(Predef$.MODULE$.intWrapper(0), height).foreach$mVc$sp((Function1)new Serializable(data2, start2, quantizedData, offset, max2, min2, width){
            public static final long serialVersionUID = 0L;
            private final float[] data$1;
            private final int start$4;
            private final byte[] quantizedData$1;
            private final int offset$1;
            private final float[] max$3;
            private final float[] min$3;
            private final int width$2;

            public final void apply(int i) {
                this.apply$mcVI$sp(i);
            }

            public void apply$mcVI$sp(int i) {
                Quantization$.MODULE$.dequantize(this.data$1, this.start$4 + i * this.width$2, this.start$4 + (i + 1) * this.width$2, this.quantizedData$1, this.offset$1 + i * this.width$2, this.max$3[i], this.min$3[i]);
            }
            {
                this.data$1 = data$1;
                this.start$4 = start$4;
                this.quantizedData$1 = quantizedData$1;
                this.offset$1 = offset$1;
                this.max$3 = max$3;
                this.min$3 = min$3;
                this.width$2 = width$2;
            }
        });
    }

    public int[] get2Dim(int[] shape) {
        Log4Error$.MODULE$.invalidInputError(shape.length > 1, new StringContext((Seq)Predef$.MODULE$.wrapRefArray((Object[])new String[]{"error size dimension, which must be great than 1"})).s((Seq)Nil$.MODULE$), Log4Error$.MODULE$.invalidInputError$default$3());
        int first = shape[0];
        int last = BoxesRunTime.unboxToInt((Object)Predef$.MODULE$.intArrayOps((int[])Predef$.MODULE$.intArrayOps(shape).slice(1, shape.length)).product((Numeric)Numeric.IntIsIntegral$.MODULE$));
        return new int[]{first, last};
    }

    public Tuple2<float[], float[]> quantize(Tensor<Object> input, byte[] buffer, int offset) {
        Tuple2<Object, Object> tuple2;
        block5: {
            Tuple2 tuple22;
            int length = input.nElement();
            int n = input.dim();
            switch (n) {
                default: {
                    if (n > 1) {
                        int[] size = this.get2Dim(input.size());
                        int start2 = input.storageOffset() - 1;
                        int end = start2 + length;
                        Tuple2<float[], float[]> tuple23 = this.quantize((float[])input.storage().array(), start2, end, buffer, offset, size);
                        if (tuple23 != null) {
                            Tuple2 tuple24;
                            float[] max2 = (float[])tuple23._1();
                            float[] min2 = (float[])tuple23._2();
                            Tuple2 tuple25 = tuple24 = new Tuple2((Object)max2, (Object)min2);
                            float[] max3 = (float[])tuple25._1();
                            float[] min3 = (float[])tuple25._2();
                            tuple22 = new Tuple2((Object)max3, (Object)min3);
                            break;
                        }
                        throw new MatchError(tuple23);
                    }
                    Log4Error$.MODULE$.invalidOperationError(false, new StringContext((Seq)Predef$.MODULE$.wrapRefArray((Object[])new String[]{"unsupported input dim ", ""})).s((Seq)Predef$.MODULE$.genericWrapArray((Object)new Object[]{BoxesRunTime.boxToInteger((int)input.dim())})), Log4Error$.MODULE$.invalidOperationError$default$3(), Log4Error$.MODULE$.invalidOperationError$default$4());
                    tuple22 = null;
                    break;
                }
                case 1: {
                    Tuple2 tuple26;
                    tuple2 = this.quantize((float[])input.storage().array(), input.storageOffset() - 1, length, buffer, offset);
                    if (tuple2 == null) break block5;
                    float max4 = BoxesRunTime.unboxToFloat((Object)tuple2._1());
                    float min4 = BoxesRunTime.unboxToFloat((Object)tuple2._2());
                    Tuple2 tuple27 = tuple26 = new Tuple2((Object)BoxesRunTime.boxToFloat((float)max4), (Object)BoxesRunTime.boxToFloat((float)min4));
                    float max5 = BoxesRunTime.unboxToFloat((Object)tuple27._1());
                    float min5 = BoxesRunTime.unboxToFloat((Object)tuple27._2());
                    tuple22 = new Tuple2((Object)new float[]{max5}, (Object)new float[]{min5});
                }
            }
            return tuple22;
        }
        throw new MatchError(tuple2);
    }

    public void dequantize(Tensor<Object> input, byte[] buffer, int offset, float[] max2, float[] min2) {
        int start2 = input.storageOffset() - 1;
        int end = start2 + input.nElement();
        int n = input.dim();
        switch (n) {
            default: {
                if (n > 1) {
                    this.dequantize((float[])input.storage().array(), start2, end, buffer, offset, max2, min2, this.get2Dim(input.size()));
                    break;
                }
                Log4Error$.MODULE$.invalidOperationError(false, new StringContext((Seq)Predef$.MODULE$.wrapRefArray((Object[])new String[]{"unsupported input dim ", ""})).s((Seq)Predef$.MODULE$.genericWrapArray((Object)new Object[]{BoxesRunTime.boxToInteger((int)input.dim())})), Log4Error$.MODULE$.invalidOperationError$default$3(), Log4Error$.MODULE$.invalidOperationError$default$4());
                break;
            }
            case 1: {
                this.dequantize((float[])input.storage().array(), start2, end, buffer, offset, max2[0], min2[0]);
            }
        }
    }

    public double loss(float[] before, float[] after2, int start2, int end) {
        DoubleRef lossValue = DoubleRef.create((double)0.0);
        RichInt$.MODULE$.until$extension0(Predef$.MODULE$.intWrapper(start2), end).foreach$mVc$sp((Function1)new Serializable(before, after2, lossValue){
            public static final long serialVersionUID = 0L;
            private final float[] before$1;
            private final float[] after$1;
            private final DoubleRef lossValue$1;

            public final void apply(int i) {
                this.apply$mcVI$sp(i);
            }

            public void apply$mcVI$sp(int i) {
                this.lossValue$1.elem += (double)Math.abs(this.before$1[i] - this.after$1[i]);
            }
            {
                this.before$1 = before$1;
                this.after$1 = after$1;
                this.lossValue$1 = lossValue$1;
            }
        });
        return lossValue.elem;
    }

    public double loss(Tensor<Object> before, Tensor<Object> after2) {
        float[] beforeArray = (float[])before.storage().array();
        float[] afterArray = (float[])after2.storage().array();
        int start2 = 0;
        int end = before.nElement();
        return this.loss(beforeArray, afterArray, start2, end) / (double)BoxesRunTime.unboxToFloat((Object)Predef$.MODULE$.floatArrayOps(beforeArray).sum((Numeric)Numeric.FloatIsFractional$.MODULE$));
    }

    public <T> AbstractModule<Activity, Activity, T> quantize(AbstractModule<Activity, Activity, T> model, ClassTag<T> evidence$1, TensorNumericMath.TensorNumeric<T> ev) {
        AbstractModule<Activity, Activity, T> clonedModel = model.cloneModule();
        Predef$.MODULE$.println((Object)"Converting model now");
        AbstractModule<Activity, Activity, T> quantizedModel = Quantizer$.MODULE$.quantize(clonedModel, evidence$1, ev);
        Predef$.MODULE$.println((Object)"Converting model successfully");
        Tensor[] paras = (Tensor[])quantizedModel.parameters()._1();
        Utils$.MODULE$.reorganizeParameters(paras, evidence$1, ev);
        return quantizedModel;
    }

    private Quantization$() {
        MODULE$ = this;
    }
}

