/*
 * Decompiled with CFR 0.152.
 */
package com.intel.analytics.bigdl.dllib.utils.serializer;

import com.intel.analytics.bigdl.dllib.nn.Container;
import com.intel.analytics.bigdl.dllib.nn.abstractnn.AbstractModule;
import com.intel.analytics.bigdl.dllib.nn.abstractnn.Activity;
import com.intel.analytics.bigdl.dllib.tensor.QuantizedTensor;
import com.intel.analytics.bigdl.dllib.tensor.Storage;
import com.intel.analytics.bigdl.dllib.tensor.Tensor;
import com.intel.analytics.bigdl.dllib.tensor.TensorNumericMath;
import com.intel.analytics.bigdl.dllib.utils.File$;
import com.intel.analytics.bigdl.dllib.utils.FileReader;
import com.intel.analytics.bigdl.dllib.utils.FileReader$;
import com.intel.analytics.bigdl.dllib.utils.Log4Error$;
import com.intel.analytics.bigdl.dllib.utils.Table;
import com.intel.analytics.bigdl.dllib.utils.serializer.BigDLDataType$;
import com.intel.analytics.bigdl.dllib.utils.serializer.BigDLStorage$;
import com.intel.analytics.bigdl.dllib.utils.serializer.DeserializeContext;
import com.intel.analytics.bigdl.dllib.utils.serializer.DeserializeContext$;
import com.intel.analytics.bigdl.dllib.utils.serializer.ModuleSerializer$;
import com.intel.analytics.bigdl.dllib.utils.serializer.ProtoStorageType$;
import com.intel.analytics.bigdl.dllib.utils.serializer.SerConst$;
import com.intel.analytics.bigdl.dllib.utils.serializer.converters.DataReaderWriter;
import com.intel.analytics.bigdl.dllib.utils.serializer.converters.DataReaderWriter$;
import com.intel.analytics.bigdl.dllib.utils.serializer.converters.TensorConverter$;
import com.intel.analytics.bigdl.serialization.Bigdl;
import com.intel.analytics.shaded.protobuf_v_3_5_1.CodedInputStream;
import java.io.ByteArrayInputStream;
import java.io.DataInputStream;
import java.io.InputStream;
import java.io.ObjectInputStream;
import java.security.DigestInputStream;
import java.security.MessageDigest;
import java.util.Map;
import scala.Enumeration;
import scala.Function1;
import scala.Predef$;
import scala.Serializable;
import scala.StringContext;
import scala.Tuple2;
import scala.collection.IterableLike;
import scala.collection.JavaConverters$;
import scala.collection.Seq;
import scala.collection.immutable.StringOps;
import scala.collection.mutable.HashMap;
import scala.collection.mutable.HashSet;
import scala.reflect.ClassTag;
import scala.reflect.ClassTag$;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;
import scala.runtime.RichInt$;

public final class ModuleLoader$ {
    public static final ModuleLoader$ MODULE$;

    static {
        new ModuleLoader$();
    }

    /**
     * Loads a serialized BigDL module from a protobuf definition file.
     *
     * @param modelPath  path of the protobuf-encoded {@code BigDLModule}
     * @param weightPath path of a separate weight file, or {@code null} when the tensor
     *                   storages are embedded in the module definition (GLOBAL_STORAGE attribute)
     * @param evidence$1 class tag of the numeric type {@code T}
     * @param ev         tensor numeric implementation for {@code T}
     * @return the deserialized module
     */
    public <T> AbstractModule<Activity, Activity, T> loadFromFile(String modelPath, String weightPath, ClassTag<T> evidence$1, TensorNumericMath.TensorNumeric<T> ev) {
        Bigdl.BigDLModule.Builder modelBuilder = Bigdl.BigDLModule.newBuilder();
        byte[] inputBytes = File$.MODULE$.readBytes(modelPath);
        CodedInputStream cis = CodedInputStream.newInstance(new ByteArrayInputStream(inputBytes));
        // Lift protobuf's default message size limit so large models can be parsed.
        cis.setSizeLimit(Integer.MAX_VALUE);
        modelBuilder.mergeFrom(cis);
        Bigdl.BigDLModule bigDLModel = modelBuilder.build();
        HashMap storages = new HashMap();
        DeserializeContext deserializationContext;
        if (weightPath == null) {
            // Storages live inside the protobuf definition itself.
            deserializationContext = new DeserializeContext(bigDLModel, (HashMap<Object, Object>)storages, ProtoStorageType$.MODULE$, DeserializeContext$.MODULE$.apply$default$4());
            this.initTensorStorage(deserializationContext, evidence$1, ev);
        } else {
            // Storages are read from the separate binary weight file.
            deserializationContext = new DeserializeContext(bigDLModel, (HashMap<Object, Object>)storages, BigDLStorage$.MODULE$, DeserializeContext$.MODULE$.apply$default$4());
            this.initTensorStorage(deserializationContext, weightPath, evidence$1, ev);
        }
        return ModuleSerializer$.MODULE$.load(deserializationContext, evidence$1, ev).module();
    }

    /** Default value of {@code weightPath} for {@link #loadFromFile}: no separate weight file. */
    public <T> String loadFromFile$default$2() {
        return null;
    }

    /**
     * Populates {@code context.storages()} from a binary weight file.
     *
     * <p>Layout read here: magic number (int), storage count (int), then per storage
     * its id (int), data-type code (int), size (int) and a payload decoded by the
     * matching {@code DataReaderWriter}; finally a digest length (int) and the digest
     * bytes. The digest covers everything from the magic number through the last storage.
     */
    private <T> void initTensorStorage(DeserializeContext context, String weightPath, ClassTag<T> evidence$2, TensorNumericMath.TensorNumeric<T> ev) {
        int magicNo = SerConst$.MODULE$.MAGIC_NO();
        FileReader fr = null;
        InputStream in = null;
        HashMap<Object, Object> storages = context.storages();
        try {
            fr = FileReader$.MODULE$.apply(weightPath);
            in = fr.open();
            MessageDigest digest = MessageDigest.getInstance(SerConst$.MODULE$.DIGEST_TYPE());
            DigestInputStream digestInputStream = new DigestInputStream(in, digest);
            DataInputStream dataInputStream = new DataInputStream(digestInputStream);
            digestInputStream.on(true);
            int magicNumber = dataInputStream.readInt();
            Log4Error$.MODULE$.invalidInputError(magicNumber == magicNo, "Magic number mismatch, expected " + magicNo + ", actual " + magicNumber, Log4Error$.MODULE$.invalidInputError$default$3());
            int totalCount = dataInputStream.readInt();
            for (int i = 0; i < totalCount; ++i) {
                int storageId = dataInputStream.readInt();
                Enumeration.Value dataType = BigDLDataType$.MODULE$.apply(dataInputStream.readInt());
                DataReaderWriter reader = DataReaderWriter$.MODULE$.apply(dataType);
                int size = dataInputStream.readInt();
                Object data = reader.read(dataInputStream, size);
                storages.update((Object)BoxesRunTime.boxToInteger((int)storageId), data);
            }
            // Stop hashing: the trailing digest itself is not part of the checksum.
            digestInputStream.on(false);
            byte[] calculatedDigest = digestInputStream.getMessageDigest().digest();
            int digestLen = dataInputStream.readInt();
            // Validate the (untrusted) length before allocating a buffer for it.
            Log4Error$.MODULE$.invalidInputError(calculatedDigest.length == digestLen, "checksum error, size mismatch", Log4Error$.MODULE$.invalidInputError$default$3());
            byte[] storedDigest = new byte[digestLen];
            // read(byte[]) may return early; readFully guarantees the whole digest is read.
            dataInputStream.readFully(storedDigest);
            Log4Error$.MODULE$.invalidInputError(MessageDigest.isEqual(calculatedDigest, storedDigest), "check sum error, please check weight file", Log4Error$.MODULE$.invalidInputError$default$3());
        }
        finally {
            // Nested so that a failing in.close() still closes the reader.
            try {
                if (in != null) {
                    in.close();
                }
            }
            finally {
                if (fr != null) {
                    fr.close();
                }
            }
        }
    }

    /**
     * Populates {@code context.storages()} from the GLOBAL_STORAGE attribute of the module.
     *
     * <p>Each entry's key is parsed as the tensor id. The decoded tensor is stored under
     * that id, and its storage under the storage id recorded in the proto. Only DENSE and
     * QUANT tensor types are supported; other types are reported via
     * {@code invalidOperationError}.
     */
    public <T> void initTensorStorage(DeserializeContext context, ClassTag<T> evidence$3, TensorNumericMath.TensorNumeric<T> ev) {
        Map<String, Bigdl.AttrValue> attrMap = context.bigdlModule().getAttrMap();
        Map<String, Bigdl.AttrValue> storagesMap = attrMap.get(SerConst$.MODULE$.GLOBAL_STORAGE()).getNameAttrListValue().getAttrMap();
        for (Map.Entry<String, Bigdl.AttrValue> entry : storagesMap.entrySet()) {
            HashMap<Object, Object> storages = context.storages();
            int tensorId = Integer.parseInt(entry.getKey());
            Bigdl.AttrValue attrValue = entry.getValue();
            Bigdl.BigDLTensor tensorValue = attrValue.getTensorValue();
            int storageId = tensorValue.getStorage().getId();
            Tensor tensor = (Tensor)TensorConverter$.MODULE$.getAttributeValue(context, attrValue, evidence$3, ev);
            Bigdl.TensorType tensorType = tensorValue.getTensorType();
            Object tensorStorage;
            if (Bigdl.TensorType.DENSE.equals(tensorType)) {
                tensorStorage = tensor.storage();
            } else if (Bigdl.TensorType.QUANT.equals(tensorType)) {
                tensorStorage = ((QuantizedTensor)tensor).getStorage();
            } else {
                Log4Error$.MODULE$.invalidOperationError(false, "Unsupported Tensor Type " + tensorType, "only suupport DENSE and QUANT", Log4Error$.MODULE$.invalidOperationError$default$4());
                tensorStorage = BoxedUnit.UNIT;
            }
            // Tensor ids and storage ids share the same storages map.
            storages.update((Object)BoxesRunTime.boxToInteger((int)tensorId), (Object)tensor);
            storages.update((Object)BoxesRunTime.boxToInteger((int)storageId), tensorStorage);
        }
    }

    /**
     * Loads the module at {@code modelPath} and copies its weights/biases into
     * {@code definition} for the named layers.
     *
     * @param layers layer names to copy; {@code null} means every layer reachable from
     *               {@code definition}
     */
    public <T> void loadFromDefinition(AbstractModule<Activity, Activity, T> definition, String modelPath, HashSet<String> layers, ClassTag<T> evidence$4, TensorNumericMath.TensorNumeric<T> ev) {
        AbstractModule<Activity, Activity, T> loadedModule = this.loadFromFile(modelPath, this.loadFromFile$default$2(), evidence$4, ev);
        HashSet<String> layersToCopy = layers;
        if (layersToCopy == null) {
            layersToCopy = new HashSet<String>();
            this.com$intel$analytics$bigdl$dllib$utils$serializer$ModuleLoader$$getAllLayers(definition, layersToCopy, evidence$4);
        }
        this.copyParams(definition, loadedModule, layersToCopy, evidence$4);
    }

    /** Default value of {@code layers} for {@link #loadFromDefinition}: copy all layers. */
    public <T> HashSet<String> loadFromDefinition$default$3() {
        return null;
    }

    /** Adds the name of {@code module} and, recursively, of every submodule of a Container to {@code layers}. */
    public <T> void com$intel$analytics$bigdl$dllib$utils$serializer$ModuleLoader$$getAllLayers(AbstractModule<Activity, Activity, T> module, HashSet<String> layers, ClassTag<T> evidence$5) {
        layers.add((Object)module.getName());
        if (module instanceof Container) {
            ((Container)module).modules().foreach((Function1)new Serializable(layers){
                public static final long serialVersionUID = 0L;
                private final HashSet layers$1;

                public final void apply(AbstractModule<Activity, Activity, Object> subModule) {
                    ModuleLoader$.MODULE$.com$intel$analytics$bigdl$dllib$utils$serializer$ModuleLoader$$getAllLayers(subModule, (HashSet<String>)this.layers$1, ClassTag$.MODULE$.apply(Object.class));
                }
                {
                    this.layers$1 = layers$1;
                }
            });
        }
    }

    /**
     * For every name in {@code layers} that has parameters in {@code definition}, requires the
     * same name in {@code mirror} and copies its "weight"/"bias" from {@code mirror} into
     * {@code definition}.
     */
    private <T> void copyParams(AbstractModule<Activity, Activity, T> definition, AbstractModule<Activity, Activity, T> mirror, HashSet<String> layers, ClassTag<T> evidence$6) {
        Table parameterTable = definition.getParametersTable();
        Table copiedParameterTable = mirror.getParametersTable();
        layers.foreach((Function1)new Serializable(evidence$6, parameterTable, copiedParameterTable){
            public static final long serialVersionUID = 0L;
            private final ClassTag evidence$6$1;
            private final Table parameterTable$1;
            private final Table copiedParameterTable$1;

            public final void apply(String name) {
                if (this.parameterTable$1.contains(name)) {
                    Log4Error$.MODULE$.invalidInputError(this.copiedParameterTable$1.contains(name), new StringContext((Seq)Predef$.MODULE$.wrapRefArray((Object[])new String[]{"", " does not exist in loaded module"})).s((Seq)Predef$.MODULE$.genericWrapArray((Object)new Object[]{name})), Log4Error$.MODULE$.invalidInputError$default$3());
                    ModuleLoader$.MODULE$.com$intel$analytics$bigdl$dllib$utils$serializer$ModuleLoader$$copyParams((Table)this.parameterTable$1.get(name).get(), (Table)this.copiedParameterTable$1.get(name).get(), this.evidence$6$1);
                }
            }
            {
                this.evidence$6$1 = evidence$6$1;
                this.parameterTable$1 = parameterTable$1;
                this.copiedParameterTable$1 = copiedParameterTable$1;
            }
        });
    }

    /** Copies the "weight" and "bias" entries of one layer from {@code copyParams2} into {@code params}. */
    public <T> void com$intel$analytics$bigdl$dllib$utils$serializer$ModuleLoader$$copyParams(Table params, Table copyParams2, String paraName$unused, ClassTag<T> evidence$7, boolean unused) {
        this.com$intel$analytics$bigdl$dllib$utils$serializer$ModuleLoader$$copyParams(params, copyParams2, evidence$7);
    }

    public <T> void com$intel$analytics$bigdl$dllib$utils$serializer$ModuleLoader$$copyParams(Table params, Table copyParams2, ClassTag<T> evidence$7) {
        this.copyParam(params, copyParams2, "weight", evidence$7);
        this.copyParam(params, copyParams2, "bias", evidence$7);
    }

    /**
     * Copies parameter {@code paraName} from {@code copyParams2} (source) into {@code params}
     * (destination), if the destination has it.
     *
     * <p>Array-valued parameters (e.g. quantized weights stored as {@code Tensor[]}) are copied
     * element by element and must have matching types and lengths; otherwise a single tensor is copied.
     */
    private <T> void copyParam(Table params, Table copyParams2, String paraName, ClassTag<T> evidence$8) {
        if (!params.contains(paraName)) {
            return;
        }
        Object source = copyParams2.get(paraName).get();
        Object target = params.get(paraName).get();
        if (source instanceof Tensor[]) {
            Log4Error$.MODULE$.invalidInputError(target instanceof Tensor[], "param type mismatch!", Log4Error$.MODULE$.invalidInputError$default$3());
            // Source is the loaded mirror; destination is the definition, matching the single-tensor branch below.
            Tensor[] copies = (Tensor[])source;
            Tensor[] origins = (Tensor[])target;
            Log4Error$.MODULE$.invalidInputError(copies.length == origins.length, "param array length mismatch!", Log4Error$.MODULE$.invalidInputError$default$3());
            for (int i = 0; i < copies.length; ++i) {
                origins[i].copy(copies[i]);
            }
        } else {
            ((Tensor)target).copy((Tensor)source);
        }
    }

    private ModuleLoader$() {
        MODULE$ = this;
    }
}

