/*
 * Decompiled with CFR 0.152.
 */
package com.intel.analytics.bigdl.dllib.utils.serializer;

import com.intel.analytics.bigdl.dllib.nn.abstractnn.AbstractModule;
import com.intel.analytics.bigdl.dllib.nn.abstractnn.Activity;
import com.intel.analytics.bigdl.dllib.tensor.TensorNumericMath;
import com.intel.analytics.bigdl.dllib.utils.File$;
import com.intel.analytics.bigdl.dllib.utils.FileWriter;
import com.intel.analytics.bigdl.dllib.utils.FileWriter$;
import com.intel.analytics.bigdl.dllib.utils.Log4Error$;
import com.intel.analytics.bigdl.dllib.utils.serializer.BigDLStorage$;
import com.intel.analytics.bigdl.dllib.utils.serializer.ModuleData;
import com.intel.analytics.bigdl.dllib.utils.serializer.ModuleSerializer$;
import com.intel.analytics.bigdl.dllib.utils.serializer.ProtoStorageType$;
import com.intel.analytics.bigdl.dllib.utils.serializer.SerConst$;
import com.intel.analytics.bigdl.dllib.utils.serializer.SerializeContext;
import com.intel.analytics.bigdl.dllib.utils.serializer.SerializeContext$;
import com.intel.analytics.bigdl.dllib.utils.serializer.SerializeResult;
import com.intel.analytics.bigdl.dllib.utils.serializer.StorageType;
import com.intel.analytics.bigdl.dllib.utils.serializer.converters.DataReaderWriter;
import com.intel.analytics.bigdl.dllib.utils.serializer.converters.DataReaderWriter$;
import com.intel.analytics.bigdl.serialization.Bigdl;
import java.io.ByteArrayOutputStream;
import java.io.DataOutputStream;
import java.io.FilterOutputStream;
import java.io.ObjectOutputStream;
import java.io.OutputStream;
import java.security.DigestOutputStream;
import java.security.MessageDigest;
import java.util.List;
import scala.Function1;
import scala.Predef$;
import scala.Serializable;
import scala.StringContext;
import scala.Tuple2;
import scala.collection.IterableLike;
import scala.collection.JavaConverters$;
import scala.collection.Seq;
import scala.collection.mutable.ArrayBuffer;
import scala.collection.mutable.HashMap;
import scala.collection.mutable.HashSet;
import scala.reflect.ClassTag;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;
import scala.runtime.ObjectRef;
import scala.runtime.ScalaRunTime$;

/**
 * Singleton holder (decompiled Scala object) with helpers that serialize a
 * module and write the result to storage: either as a single protobuf file,
 * as a protobuf definition plus a separate binary weight file, or as a
 * human-readable definition without weights and biases.
 */
public final class ModulePersister$ {
    public static final ModulePersister$ MODULE$;

    static {
        // The private constructor assigns MODULE$ as a side effect.
        new ModulePersister$();
    }

    /**
     * Serializes {@code module} and writes it to {@code modelPath}.
     *
     * <p>When {@code weightPath} is {@code null} the module is serialized with
     * {@code ProtoStorageType}, the collected storages are attached to the module
     * builder through {@link #setTensorStorage}, and everything is written to
     * {@code modelPath}. Otherwise the module is serialized with {@code BigDLStorage},
     * the module bytes are written to {@code modelPath}, and the storages whose values
     * are arrays are written to {@code weightPath}.
     *
     * @param modelPath  destination of the serialized module
     * @param weightPath destination of the weight file, or {@code null} to embed the storages
     * @param module     module to persist
     * @param overwrite  forwarded to the underlying file writers
     * @param evidence$9 class tag of the numeric type
     * @param ev         numeric operations for the numeric type
     */
    public <T> void saveToFile(String modelPath, String weightPath, AbstractModule<Activity, Activity, T> module, boolean overwrite, ClassTag<T> evidence$9, TensorNumericMath.TensorNumeric<T> ev) {
        if (weightPath == null) {
            SerializeResult serializeResult = this.serializeModule(module, ProtoStorageType$.MODULE$, evidence$9, ev);
            this.setTensorStorage(serializeResult.bigDLModule(), serializeResult.storages());
            File$.MODULE$.saveBytes(serializeResult.bigDLModule().build().toByteArray(), modelPath, overwrite);
        } else {
            SerializeResult serializeResult = this.serializeModule(module, BigDLStorage$.MODULE$, evidence$9, ev);
            // Only array-valued storages are written to the weight file.
            HashMap tensorStorages = (HashMap)serializeResult.storages().filter((Function1)new Serializable(){
                public static final long serialVersionUID = 0L;

                public final boolean apply(Tuple2<Object, Object> x$1) {
                    return ScalaRunTime$.MODULE$.isArray(x$1._2(), 1);
                }
            });
            File$.MODULE$.saveBytes(serializeResult.bigDLModule().build().toByteArray(), modelPath, overwrite);
            this.saveWeightsToFile(weightPath, (HashMap<Object, Object>)tensorStorages, overwrite);
        }
    }

    /** Default value of the {@code weightPath} argument of {@code saveToFile}: {@code null}. */
    public <T> String saveToFile$default$2() {
        return null;
    }

    /** Default value of the {@code overwrite} argument of {@code saveToFile}: {@code false}. */
    public <T> boolean saveToFile$default$4() {
        return false;
    }

    /**
     * Wraps {@code module} in a {@code ModuleData} with empty predecessor/successor
     * lists, builds a {@code SerializeContext} over a fresh storage map using the
     * given {@code storageType}, and runs {@code ModuleSerializer.serialize} on it.
     *
     * @return the serialization result (module builder plus collected storages)
     */
    private <T> SerializeResult serializeModule(AbstractModule<Activity, Activity, T> module, StorageType storageType, ClassTag<T> evidence$10, TensorNumericMath.TensorNumeric<T> ev) {
        ModuleData<T> bigDLModule = new ModuleData<T>(module, (Seq<String>)new ArrayBuffer(), (Seq<String>)new ArrayBuffer(), evidence$10);
        HashMap storages = new HashMap();
        SerializeContext<T> context = new SerializeContext<T>(bigDLModule, (HashMap<Object, Object>)storages, storageType, SerializeContext$.MODULE$.apply$default$4(), SerializeContext$.MODULE$.apply$default$5(), evidence$10);
        return ModuleSerializer$.MODULE$.serialize(context, evidence$10, ev);
    }

    /**
     * Writes {@code storages} to {@code weightPath} in the following layout:
     * magic number (int), number of storages (int), then for every storage its
     * id (int), data type id (int), element count (int) and the data as written
     * by the matching {@code DataReaderWriter}. Everything up to this point is
     * fed through a message digest; the digest length (int) and digest bytes are
     * appended afterwards with digesting switched off.
     *
     * @param weightPath destination of the weight file
     * @param storages   storage id to data array
     * @param overwrite  forwarded to {@code FileWriter.create}
     */
    private void saveWeightsToFile(String weightPath, HashMap<Object, Object> storages, boolean overwrite) {
        int magicNo = SerConst$.MODULE$.MAGIC_NO();
        int total = storages.size();
        FileWriter fw = null;
        OutputStream out = null;
        DigestOutputStream digestOutputStream = null;
        // Held in an ObjectRef because the per-storage closure below captures it.
        ObjectRef dataOutputStream = ObjectRef.create(null);
        try {
            fw = FileWriter$.MODULE$.apply(weightPath);
            out = fw.create(overwrite);
            MessageDigest digest = MessageDigest.getInstance(SerConst$.MODULE$.DIGEST_TYPE());
            digestOutputStream = new DigestOutputStream(out, digest);
            dataOutputStream.elem = new DataOutputStream(digestOutputStream);
            digestOutputStream.on(true);
            ((DataOutputStream)dataOutputStream.elem).writeInt(magicNo);
            ((DataOutputStream)dataOutputStream.elem).writeInt(total);
            storages.foreach((Function1)new Serializable(dataOutputStream){
                public static final long serialVersionUID = 0L;
                private final ObjectRef dataOutputStream$1;

                public final void apply(Tuple2<Object, Object> storage) {
                    int storageId = storage._1$mcI$sp();
                    Object dataArray = storage._2();
                    DataReaderWriter writer = DataReaderWriter$.MODULE$.apply(dataArray);
                    ((DataOutputStream)this.dataOutputStream$1.elem).writeInt(storageId);
                    ((DataOutputStream)this.dataOutputStream$1.elem).writeInt(writer.dataType().id());
                    ((DataOutputStream)this.dataOutputStream$1.elem).writeInt(Predef$.MODULE$.genericArrayOps(dataArray).size());
                    writer.write((DataOutputStream)this.dataOutputStream$1.elem, dataArray);
                }
                {
                    this.dataOutputStream$1 = dataOutputStream$1;
                }
            });
            // The trailer (digest length + digest) must not be part of the digest itself.
            digestOutputStream.on(false);
            byte[] digestContent = digestOutputStream.getMessageDigest().digest();
            ((DataOutputStream)dataOutputStream.elem).writeInt(digestContent.length);
            ((DataOutputStream)dataOutputStream.elem).write(digestContent);
        }
        finally {
            try {
                // Close only the outermost stream that was created: closing a
                // FilterOutputStream flushes it and closes the stream it wraps,
                // so the sink is never closed underneath a still-open wrapper.
                if (dataOutputStream.elem != null) {
                    ((DataOutputStream)dataOutputStream.elem).close();
                } else if (digestOutputStream != null) {
                    digestOutputStream.close();
                } else if (out != null) {
                    out.close();
                }
            }
            finally {
                // Always release the writer, even if closing the stream failed.
                if (fw != null) {
                    fw.close();
                }
            }
        }
    }

    /** Default value of the {@code overwrite} argument of {@code saveWeightsToFile}: {@code false}. */
    private boolean saveWeightsToFile$default$3() {
        return false;
    }

    /**
     * Collects the tensors found in {@code storages}, re-attaches their
     * {@code TensorStorage} payloads, and stores them on {@code bigDLModule} as a
     * name/attribute list under the {@code SerConst.GLOBAL_STORAGE} attribute.
     *
     * <p>Each storage id is emitted at most once; tensors whose storage id is
     * {@code -1} are skipped. A tensor referring to a storage id that has no
     * {@code TensorStorage} entry in {@code storages} is reported through
     * {@code Log4Error.invalidInputError}.
     *
     * @param bigDLModule module builder that receives the global storage attribute
     * @param storages    map holding both {@code BigDLTensor} and {@code TensorStorage} values
     */
    public void setTensorStorage(Bigdl.BigDLModule.Builder bigDLModule, HashMap<Object, Object> storages) {
        // Storage ids that were already written into the attribute list.
        HashSet storageIds = new HashSet();
        HashMap tensorStorages = (HashMap)storages.filter((Function1)new Serializable(){
            public static final long serialVersionUID = 0L;

            public final boolean apply(Tuple2<Object, Object> x$2) {
                return x$2._2() instanceof Bigdl.TensorStorage;
            }
        });
        ObjectRef nameAttributes = ObjectRef.create((Object)Bigdl.NameAttrList.newBuilder().setName(SerConst$.MODULE$.GLOBAL_STORAGE()));
        ((IterableLike)storages.values().filter((Function1)new Serializable(){
            public static final long serialVersionUID = 0L;

            public final boolean apply(Object x$3) {
                return x$3 instanceof Bigdl.BigDLTensor;
            }
        })).foreach((Function1)new Serializable(storageIds, tensorStorages, nameAttributes){
            public static final long serialVersionUID = 0L;
            private final HashSet storageIds$1;
            private final HashMap tensorStorages$1;
            private final ObjectRef nameAttributes$1;

            public final Object apply(Object storage) {
                Object object;
                Bigdl.BigDLTensor bigdlTensor = (Bigdl.BigDLTensor)storage;
                int storageId = bigdlTensor.getStorage().getId();
                if (!this.storageIds$1.contains((Object)BoxesRunTime.boxToInteger((int)storageId)) && storageId != -1) {
                    // Rebuild the tensor with the full storage payload in place of the reference.
                    Bigdl.BigDLTensor.Builder tensorBuilder = Bigdl.BigDLTensor.newBuilder(bigdlTensor);
                    tensorBuilder.clearStorage();
                    Log4Error$.MODULE$.invalidInputError(this.tensorStorages$1.contains((Object)BoxesRunTime.boxToInteger((int)storageId)), new StringContext((Seq)Predef$.MODULE$.wrapRefArray((Object[])new String[]{"", " does not exist"})).s((Seq)Predef$.MODULE$.genericWrapArray((Object)new Object[]{BoxesRunTime.boxToInteger((int)storageId)})), Log4Error$.MODULE$.invalidInputError$default$3());
                    tensorBuilder.setStorage((Bigdl.TensorStorage)this.tensorStorages$1.get((Object)BoxesRunTime.boxToInteger((int)storageId)).get());
                    Bigdl.AttrValue.Builder attrValueBuilder = Bigdl.AttrValue.newBuilder();
                    attrValueBuilder.setTensorValue(tensorBuilder.build());
                    // The attribute is keyed by the tensor id, while de-duplication uses the storage id.
                    ((Bigdl.NameAttrList.Builder)this.nameAttributes$1.elem).putAttr(((Object)BoxesRunTime.boxToInteger((int)tensorBuilder.getId())).toString(), attrValueBuilder.build());
                    object = BoxesRunTime.boxToBoolean((boolean)this.storageIds$1.add((Object)BoxesRunTime.boxToInteger((int)storageId)));
                } else {
                    object = BoxedUnit.UNIT;
                }
                return object;
            }
            {
                this.storageIds$1 = storageIds$1;
                this.tensorStorages$1 = tensorStorages$1;
                this.nameAttributes$1 = nameAttributes$1;
            }
        });
        Bigdl.AttrValue.Builder attrValueBuilder = Bigdl.AttrValue.newBuilder();
        attrValueBuilder.setNameAttrListValue((Bigdl.NameAttrList.Builder)nameAttributes.elem);
        bigDLModule.putAttr(SerConst$.MODULE$.GLOBAL_STORAGE(), attrValueBuilder.build());
    }

    /**
     * Serializes {@code module} with {@code ProtoStorageType}, strips weights and
     * biases from it and all of its sub-modules, and writes the text form
     * ({@code toString()}) of the resulting message to {@code definitionPath}.
     *
     * @param definitionPath destination of the textual model definition
     * @param module         module whose definition is saved
     * @param overwrite      forwarded to {@code File.saveBytes}
     * @param evidence$11    class tag of the numeric type
     * @param ev             numeric operations for the numeric type
     */
    public <T> void saveModelDefinitionToFile(String definitionPath, AbstractModule<Activity, Activity, T> module, boolean overwrite, ClassTag<T> evidence$11, TensorNumericMath.TensorNumeric<T> ev) {
        Bigdl.BigDLModule.Builder bigDLModel = this.serializeModule(module, ProtoStorageType$.MODULE$, evidence$11, ev).bigDLModule();
        this.com$intel$analytics$bigdl$dllib$utils$serializer$ModulePersister$$cleantWeightAndBias(bigDLModel);
        Bigdl.BigDLModule model = bigDLModel.build();
        // Encode with a fixed charset so the file does not depend on the JVM's platform default.
        byte[] definitionBytes = model.toString().getBytes(java.nio.charset.StandardCharsets.UTF_8);
        File$.MODULE$.saveBytes(definitionBytes, definitionPath, overwrite);
    }

    /** Default value of the {@code overwrite} argument of {@code saveModelDefinitionToFile}: {@code false}. */
    public <T> boolean saveModelDefinitionToFile$default$3() {
        return false;
    }

    /**
     * Clears the weight and bias fields of {@code modelBuilder} and, recursively,
     * of every sub-module. Sub-modules are rebuilt from cleaned copies and
     * re-added in their original order.
     *
     * @param modelBuilder builder that is modified in place
     */
    public void com$intel$analytics$bigdl$dllib$utils$serializer$ModulePersister$$cleantWeightAndBias(Bigdl.BigDLModule.Builder modelBuilder) {
        modelBuilder.clearWeight();
        modelBuilder.clearBias();
        if (modelBuilder.getSubModulesCount() > 0) {
            // Grab the current children before clearing; cleaned copies are appended back below.
            List<Bigdl.BigDLModule> subModules = modelBuilder.getSubModulesList();
            modelBuilder.clearSubModules();
            ((IterableLike)JavaConverters$.MODULE$.asScalaBufferConverter(subModules).asScala()).foreach((Function1)new Serializable(modelBuilder){
                public static final long serialVersionUID = 0L;
                private final Bigdl.BigDLModule.Builder modelBuilder$1;

                public final Bigdl.BigDLModule.Builder apply(Bigdl.BigDLModule sub2) {
                    Bigdl.BigDLModule.Builder subModelBuilder = Bigdl.BigDLModule.newBuilder(sub2);
                    ModulePersister$.MODULE$.com$intel$analytics$bigdl$dllib$utils$serializer$ModulePersister$$cleantWeightAndBias(subModelBuilder);
                    return this.modelBuilder$1.addSubModules(subModelBuilder.build());
                }
                {
                    this.modelBuilder$1 = modelBuilder$1;
                }
            });
        }
    }

    private ModulePersister$() {
        MODULE$ = this;
    }
}

