/*
 * Decompiled with CFR 0.152.
 */
package com.intel.analytics.bigdl.dllib.nn.mkldnn;

import com.intel.analytics.bigdl.dllib.nn.MklInt8Convertible;
import com.intel.analytics.bigdl.dllib.nn.Scale;
import com.intel.analytics.bigdl.dllib.nn.abstractnn.AbstractModule;
import com.intel.analytics.bigdl.dllib.nn.abstractnn.Activity;
import com.intel.analytics.bigdl.dllib.nn.mkldnn.AvgPooling;
import com.intel.analytics.bigdl.dllib.nn.mkldnn.BlasWrapper;
import com.intel.analytics.bigdl.dllib.nn.mkldnn.CAddTable;
import com.intel.analytics.bigdl.dllib.nn.mkldnn.Identity;
import com.intel.analytics.bigdl.dllib.nn.mkldnn.Identity$;
import com.intel.analytics.bigdl.dllib.nn.mkldnn.JoinTable;
import com.intel.analytics.bigdl.dllib.nn.mkldnn.MaxPooling;
import com.intel.analytics.bigdl.dllib.nn.mkldnn.ReLU;
import com.intel.analytics.bigdl.dllib.nn.mkldnn.SpatialBatchNormalization;
import com.intel.analytics.bigdl.dllib.nn.mkldnn.SpatialConvolution;
import com.intel.analytics.bigdl.dllib.tensor.Tensor;
import com.intel.analytics.bigdl.dllib.tensor.Tensor$;
import com.intel.analytics.bigdl.dllib.tensor.TensorNumericMath$TensorNumeric$NumericFloat$;
import com.intel.analytics.bigdl.dllib.utils.Log4Error$;
import com.intel.analytics.bigdl.dllib.utils.Node;
import scala.Function1;
import scala.Predef$;
import scala.Serializable;
import scala.StringContext;
import scala.collection.IterableLike;
import scala.collection.Seq;
import scala.collection.Seq$;
import scala.collection.TraversableLike;
import scala.collection.TraversableOnce;
import scala.collection.generic.GenericTraversableTemplate;
import scala.collection.immutable.Nil$;
import scala.collection.immutable.Set;
import scala.collection.immutable.StringOps;
import scala.collection.mutable.ArrayOps;
import scala.collection.mutable.WrappedArray;
import scala.math.Ordering;
import scala.reflect.ClassTag;
import scala.reflect.ClassTag$;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;
import scala.runtime.RichInt$;

/**
 * Graph-level operator-fusion passes for mkldnn graph {@code Node}s (decompiled form of the Scala
 * singleton {@code object Fusion}).
 *
 * <p>Every pass mutates the graph in place. Modules that are fused away are not removed from the
 * graph; their node's element is replaced with an {@code Identity} module instead. All public passes
 * except {@link #fuseScale} first check {@link #fuse()}, which reads the {@code bigdl.mkldnn.fusion}
 * system property.
 */
public final class Fusion$ {
    /** Scala singleton instance; assigned by the private constructor during static initialization. */
    public static final Fusion$ MODULE$;

    static {
        // Runs the private constructor, which publishes the instance in MODULE$.
        new Fusion$();
    }

    /**
     * Returns whether fusion is enabled.
     *
     * <p>The value comes from system property {@code bigdl.mkldnn.fusion} (default {@code "true"})
     * and is parsed with Scala's {@code StringOps.toBoolean}. The property is re-read on every call;
     * nothing is cached.
     */
    private boolean fuse() {
        return new StringOps(Predef$.MODULE$.augmentString(System.getProperty("bigdl.mkldnn.fusion", "true"))).toBoolean();
    }

    /**
     * Dispatches a single node to the matching fusion pass.
     *
     * <p>A {@code ReLU} node goes to {@code fusionRelu}, and a {@code SpatialBatchNormalization} node
     * goes to {@code fusionBN}. Any other element type is ignored, and so is every node when fusion
     * is disabled.
     *
     * @param node graph node to examine; its element may be replaced with {@code Identity}
     */
    public void fuseModule(Node<AbstractModule<Activity, Activity, Object>> node) {
        if (this.fuse()) {
            AbstractModule<Activity, Activity, Object> abstractModule = node.element();
            if (abstractModule instanceof ReLU) {
                this.fusionRelu(node);
                BoxedUnit boxedUnit = BoxedUnit.UNIT;
            } else if (abstractModule instanceof SpatialBatchNormalization) {
                this.fusionBN(node);
                BoxedUnit boxedUnit = BoxedUnit.UNIT;
            } else {
                BoxedUnit boxedUnit = BoxedUnit.UNIT;
            }
            return;
        }
    }

    /**
     * Runs the conv+sum fusion ({@code fusionCAddTable}) when {@code node} holds a {@code CAddTable}.
     * Does nothing when fusion is disabled or the element is of another type.
     *
     * @param node graph node to examine; it may take over a convolution module
     */
    public void fuseCAdd(Node<AbstractModule<Activity, Activity, Object>> node) {
        if (this.fuse()) {
            AbstractModule<Activity, Activity, Object> abstractModule = node.element();
            if (abstractModule instanceof CAddTable) {
                this.fusionCAddTable(node);
                BoxedUnit boxedUnit = BoxedUnit.UNIT;
            } else {
                BoxedUnit boxedUnit = BoxedUnit.UNIT;
            }
            return;
        }
    }

    /**
     * Folds a batch-norm node into the convolution(s) that directly precede it.
     *
     * <p>Only a predecessor that is a {@code SpatialConvolution} with neither a fused ReLU nor a
     * fused batch norm qualifies. For each such conv:
     * <ul>
     *   <li>the BN's ReLU flag is transferred to the conv (when set);</li>
     *   <li>the BN parameters are folded into the conv's weight and bias (see
     *       {@code fusionConvBn});</li>
     *   <li>the BN node's element is replaced with {@code Identity}.</li>
     * </ul>
     *
     * <p>Identity predecessors are NOT looked through here, unlike in {@code fusionRelu}.
     *
     * <p>NOTE(review): if the BN has several qualifying conv predecessors, each one is folded with the
     * same BN parameters. Presumably a BN has a single producer in practice; confirm.
     */
    private void fusionBN(Node<AbstractModule<Activity, Activity, Object>> node) {
        SpatialBatchNormalization bn = (SpatialBatchNormalization)node.element();
        node.prevNodes().foreach((Function1)new Serializable(node, bn){
            public static final long serialVersionUID = 0L;
            private final Node node$3;
            private final SpatialBatchNormalization bn$1;

            public final Object apply(Node<AbstractModule<Activity, Activity, Object>> n) {
                BoxedUnit boxedUnit;
                AbstractModule<Activity, Activity, Object> abstractModule = n.element();
                if (abstractModule instanceof SpatialConvolution) {
                    BoxedUnit boxedUnit2;
                    SpatialConvolution spatialConvolution = (SpatialConvolution)abstractModule;
                    // Skip convs that already carry a fused ReLU or BN.
                    if (spatialConvolution.relu() || spatialConvolution.batchNorm()) {
                        boxedUnit2 = BoxedUnit.UNIT;
                    } else {
                        // `object` is never read; only the setReLU side effect matters.
                        Object object = this.bn$1.relu() ? spatialConvolution.setReLU(true) : BoxedUnit.UNIT;
                        Fusion$.MODULE$.com$intel$analytics$bigdl$dllib$nn$mkldnn$Fusion$$fusionConvBn(spatialConvolution, this.bn$1);
                        this.node$3.element_$eq(Identity$.MODULE$.apply$mFc$sp((ClassTag<Object>)ClassTag$.MODULE$.Float(), TensorNumericMath$TensorNumeric$NumericFloat$.MODULE$));
                        boxedUnit2 = BoxedUnit.UNIT;
                    }
                    boxedUnit = boxedUnit2;
                } else {
                    boxedUnit = null;
                }
                return boxedUnit;
            }
            {
                this.node$3 = node$3;
                this.bn$1 = bn$1;
            }
        });
    }

    /**
     * Folds a ReLU node into each of its producers.
     *
     * <p>Each predecessor is first resolved through single-input Identity chains (see
     * {@code findPrevious}). A producer qualifies when it is a {@code SpatialConvolution} or a
     * {@code SpatialBatchNormalization} without a fused ReLU. For each qualifying producer:
     * <ul>
     *   <li>it receives {@code setReLU(true)};</li>
     *   <li>it receives the ReLU's output scales;</li>
     *   <li>the ReLU node's element is replaced with {@code Identity}.</li>
     * </ul>
     */
    private void fusionRelu(Node<AbstractModule<Activity, Activity, Object>> node) {
        node.prevNodes().foreach((Function1)new Serializable(node){
            public static final long serialVersionUID = 0L;
            private final Node node$2;

            public final Object apply(Node<AbstractModule<Activity, Activity, Object>> n) {
                BoxedUnit boxedUnit;
                Node<AbstractModule<Activity, Activity, Object>> notIdentity = Fusion$.MODULE$.com$intel$analytics$bigdl$dllib$nn$mkldnn$Fusion$$findPrevious(n);
                AbstractModule<Activity, Activity, Object> abstractModule = notIdentity.element();
                if (abstractModule instanceof SpatialConvolution) {
                    BoxedUnit boxedUnit2;
                    SpatialConvolution spatialConvolution = (SpatialConvolution)abstractModule;
                    if (spatialConvolution.relu()) {
                        boxedUnit2 = BoxedUnit.UNIT;
                    } else {
                        spatialConvolution.setReLU(true);
                        spatialConvolution.setOutputScales(((ReLU)this.node$2.element()).getOutputScales());
                        this.node$2.element_$eq(Identity$.MODULE$.apply$mFc$sp((ClassTag<Object>)ClassTag$.MODULE$.Float(), TensorNumericMath$TensorNumeric$NumericFloat$.MODULE$));
                        boxedUnit2 = BoxedUnit.UNIT;
                    }
                    boxedUnit = boxedUnit2;
                } else if (abstractModule instanceof SpatialBatchNormalization) {
                    BoxedUnit boxedUnit3;
                    SpatialBatchNormalization spatialBatchNormalization = (SpatialBatchNormalization)abstractModule;
                    if (spatialBatchNormalization.relu()) {
                        boxedUnit3 = BoxedUnit.UNIT;
                    } else {
                        spatialBatchNormalization.setReLU(true);
                        spatialBatchNormalization.setOutputScales(((ReLU)this.node$2.element()).getOutputScales());
                        this.node$2.element_$eq(Identity$.MODULE$.apply$mFc$sp((ClassTag<Object>)ClassTag$.MODULE$.Float(), TensorNumericMath$TensorNumeric$NumericFloat$.MODULE$));
                        boxedUnit3 = BoxedUnit.UNIT;
                    }
                    boxedUnit = boxedUnit3;
                } else {
                    boxedUnit = null;
                }
                return boxedUnit;
            }
            {
                this.node$2 = node$2;
            }
        });
    }

    /**
     * Walks backwards from {@code node} past {@code Identity} nodes that have exactly one
     * predecessor.
     *
     * <p>Returns the first node that is not such a pass-through, possibly {@code node} itself. An
     * Identity node with zero or several predecessors stops the walk and is itself returned.
     */
    public Node<AbstractModule<Activity, Activity, Object>> com$intel$analytics$bigdl$dllib$nn$mkldnn$Fusion$$findPrevious(Node<AbstractModule<Activity, Activity, Object>> node) {
        while (node.element() instanceof Identity && node.prevNodes().length() == 1) {
            node = (Node)node.prevNodes().apply(0);
        }
        return node;
    }

    /**
     * Returns the nearest non-{@code Identity} successors of {@code node}.
     *
     * <p>An Identity node is expanded recursively through all of its next nodes. Any other node
     * yields a one-element Seq containing itself. An Identity with no successors contributes nothing
     * (an empty Seq).
     */
    public Seq<Node<AbstractModule<Activity, Activity, Object>>> com$intel$analytics$bigdl$dllib$nn$mkldnn$Fusion$$findNext(Node<AbstractModule<Activity, Activity, Object>> node) {
        return node.element() instanceof Identity ? (Seq)node.nextNodes().flatMap((Function1)new Serializable(){
            public static final long serialVersionUID = 0L;

            public final Seq<Node<AbstractModule<Activity, Activity, Object>>> apply(Node<AbstractModule<Activity, Activity, Object>> n) {
                return Fusion$.MODULE$.com$intel$analytics$bigdl$dllib$nn$mkldnn$Fusion$$findNext(n);
            }
        }, Seq$.MODULE$.canBuildFrom()) : (Seq)Seq$.MODULE$.apply((Seq)Predef$.MODULE$.wrapRefArray((Object[])new Node[]{node}));
    }

    /**
     * Fuses a two-input {@code CAddTable} into one of its convolution inputs (conv + sum).
     *
     * <p>Choosing the conv: both inputs are resolved through Identity chains. If input 0 (otherwise
     * input 1) is a {@code SpatialConvolution} that passes {@code requirements}, that conv is fused.
     *
     * <p>Rewiring the graph:
     * <ul>
     *   <li>the CAdd node takes over the conv module;</li>
     *   <li>the conv is told via {@code setSumOp} which raw (not Identity-resolved) predecessor
     *       module is the other summand, passing {@code otherNumber + 1} (presumably that summand's
     *       1-based input position);</li>
     *   <li>the conv's original node becomes {@code Identity};</li>
     *   <li>if the first successor of the CAdd node is a ReLU and the conv has no ReLU yet, that ReLU
     *       is fused too.</li>
     * </ul>
     *
     * <p>Propagating scales: the fused conv's output scales are pushed to the Identity-resolved
     * producer of the other summand.
     * <ul>
     *   <li>If that producer is a conv, they become its output scales.</li>
     *   <li>If it is a ReLU, they become the ReLU's output scales. They also become the input scales
     *       of every {@code MklInt8Convertible} consumer of that ReLU (found through Identity nodes),
     *       except this node.</li>
     * </ul>
     *
     * <p>NOTE(review): if input 0 is a convolution that fails {@code requirements}, input 1 is never
     * considered, even when it is an eligible convolution. Confirm this is intended.
     *
     * <p>NOTE(review): {@code node.nextNodes().apply(0)} throws when the CAdd node has no successors.
     * Also, only the first successor is checked for a ReLU.
     *
     * <p>NOTE(review): the fused ReLU is replaced with {@code new Identity()}, whereas the rest of this
     * class uses {@code Identity.apply[Float]}. Check that the two are equivalent.
     */
    private void fusionCAddTable(Node<AbstractModule<Activity, Activity, Object>> node) {
        if (node.element() instanceof CAddTable && node.prevNodes().length() == 2) {
            Node[] previousNodes = (Node[])node.prevNodes().toArray(ClassTag$.MODULE$.apply(Node.class));
            Node<AbstractModule<Activity, Activity, Object>> node1 = this.com$intel$analytics$bigdl$dllib$nn$mkldnn$Fusion$$findPrevious(previousNodes[0]);
            Node<AbstractModule<Activity, Activity, Object>> node2 = this.com$intel$analytics$bigdl$dllib$nn$mkldnn$Fusion$$findPrevious(previousNodes[1]);
            Node<AbstractModule<Activity, Activity, Object>> conv = null;
            // Index (0 or 1) of the input that is NOT the fused convolution.
            int otherNumber = 0;
            if (node1.element() instanceof SpatialConvolution) {
                if (this.requirements(node1)) {
                    conv = node1;
                }
                otherNumber = 1;
            } else if (node2.element() instanceof SpatialConvolution) {
                if (this.requirements(node2)) {
                    conv = node2;
                }
                otherNumber = 0;
            }
            if (conv != null) {
                Node<AbstractModule<Activity, Activity, Object>> prevIsNotIdentity;
                AbstractModule<Activity, Activity, Object> abstractModule;
                // The CAdd node now executes the convolution; the conv's old node becomes a pass-through.
                node.element_$eq(conv.element());
                SpatialConvolution element = (SpatialConvolution)node.element();
                element.setSumOp((AbstractModule)previousNodes[otherNumber].element(), otherNumber + 1);
                conv.element_$eq(Identity$.MODULE$.apply$mFc$sp((ClassTag<Object>)ClassTag$.MODULE$.Float(), TensorNumericMath$TensorNumeric$NumericFloat$.MODULE$));
                Node nexts = (Node)node.nextNodes().apply(0);
                if (nexts.element() instanceof ReLU && !element.relu()) {
                    ((SpatialConvolution)node.element()).setReLU(true);
                    ((SpatialConvolution)node.element()).setOutputScales(((ReLU)nexts.element()).getOutputScales());
                    nexts.element_$eq(new Identity());
                }
                // Propagate the fused conv's output scales to the producer of the other summand.
                if ((abstractModule = (prevIsNotIdentity = this.com$intel$analytics$bigdl$dllib$nn$mkldnn$Fusion$$findPrevious(previousNodes[otherNumber])).element()) instanceof SpatialConvolution) {
                    SpatialConvolution spatialConvolution = (SpatialConvolution)abstractModule;
                    spatialConvolution.setOutputScales(((SpatialConvolution)node.element()).getOutputScales());
                    BoxedUnit boxedUnit = BoxedUnit.UNIT;
                } else if (abstractModule instanceof ReLU) {
                    ReLU reLU = (ReLU)abstractModule;
                    reLU.setOutputScales(((SpatialConvolution)node.element()).getOutputScales());
                    // Every other int8-capable consumer of this ReLU must read the same scales.
                    ((IterableLike)((TraversableLike)prevIsNotIdentity.nextNodes().flatMap((Function1)new Serializable(){
                        public static final long serialVersionUID = 0L;

                        public final Seq<Node<AbstractModule<Activity, Activity, Object>>> apply(Node<AbstractModule<Activity, Activity, Object>> x) {
                            return Fusion$.MODULE$.com$intel$analytics$bigdl$dllib$nn$mkldnn$Fusion$$findNext(x);
                        }
                    }, Seq$.MODULE$.canBuildFrom())).filter((Function1)new Serializable(node){
                        public static final long serialVersionUID = 0L;
                        private final Node node$4;

                        public final boolean apply(Node<AbstractModule<Activity, Activity, Object>> x) {
                            Node<AbstractModule<Activity, Activity, Object>> node = x;
                            Node node2 = this.node$4;
                            return (node == null ? node2 != null : !node.equals(node2)) && x.element() instanceof MklInt8Convertible;
                        }
                        {
                            this.node$4 = node$4;
                        }
                    })).foreach((Function1)new Serializable(node){
                        public static final long serialVersionUID = 0L;
                        private final Node node$4;

                        public final void apply(Node<AbstractModule<Activity, Activity, Object>> x$1) {
                            ((MklInt8Convertible)((Object)x$1.element())).setInputScales(((SpatialConvolution)this.node$4.element()).getOutputScales());
                        }
                        {
                            this.node$4 = node$4;
                        }
                    });
                    BoxedUnit boxedUnit = BoxedUnit.UNIT;
                } else {
                    BoxedUnit boxedUnit = BoxedUnit.UNIT;
                }
            }
        }
    }

    /**
     * Returns {@code true} when the convolution held by {@code node} has no fused sum op yet.
     * The element is cast to {@code SpatialConvolution} without a check; callers test the type first.
     */
    private boolean requirements(Node<AbstractModule<Activity, Activity, Object>> node) {
        SpatialConvolution conv = (SpatialConvolution)node.element();
        return !conv.sum();
    }

    /**
     * Folds batch-norm parameters into a convolution's weight and bias, in place on {@code conv}.
     *
     * <p>For each output channel {@code j} in {@code [0, bn.nOutput())}, the inputs are:
     * <ul>
     *   <li>{@code base = sqrt(var_j + eps)};</li>
     *   <li>{@code alpha} and {@code beta}, read from {@code bn.weightAndBias()}. That tensor is
     *       indexed as {@code nOutput} scale values followed by {@code nOutput} shift values.</li>
     * </ul>
     * The update is:
     * <pre>
     *   W_j = W_j / base * alpha
     *   b_j = alpha / base * b_j + beta - alpha * mean_j / base
     * </pre>
     *
     * <p>The work is done on private copies of the BN running variance/mean and of the conv
     * weight/bias. The copies are then installed with {@code set}. Afterwards the conv's weight
     * scales are flushed and the BN's output scales are copied to the conv. The conv is also
     * flagged with {@code setBatchNorm(true)}.
     *
     * <p>Fails through {@code Log4Error.invalidInputError} when {@code base} is 0.
     */
    public void com$intel$analytics$bigdl$dllib$nn$mkldnn$Fusion$$fusionConvBn(SpatialConvolution conv, SpatialBatchNormalization bn) {
        conv.setBatchNorm(true);
        Tensor<Object> qual$1 = Tensor$.MODULE$.apply(ClassTag$.MODULE$.Float(), TensorNumericMath$TensorNumeric$NumericFloat$.MODULE$);
        int[] x$12 = bn.runningVariance().size();
        int[] x$13 = qual$1.resize$default$2();
        Tensor<Object> originVar = qual$1.resize(x$12, x$13).copy(bn.runningVariance().dense());
        Tensor<Object> qual$2 = Tensor$.MODULE$.apply(ClassTag$.MODULE$.Float(), TensorNumericMath$TensorNumeric$NumericFloat$.MODULE$);
        int[] x$14 = bn.runningMean().size();
        int[] x$15 = qual$2.resize$default$2();
        Tensor<Object> originMean = qual$2.resize(x$14, x$15).copy(bn.runningMean().dense());
        Tensor<Object> qual$3 = Tensor$.MODULE$.apply(ClassTag$.MODULE$.Float(), TensorNumericMath$TensorNumeric$NumericFloat$.MODULE$);
        int[] x$16 = conv.weight().size();
        int[] x$17 = qual$3.resize$default$2();
        Tensor<Object> convWeight = qual$3.resize(x$16, x$17).copy(conv.weight().dense());
        Tensor<Object> qual$4 = Tensor$.MODULE$.apply(ClassTag$.MODULE$.Float(), TensorNumericMath$TensorNumeric$NumericFloat$.MODULE$);
        int[] x$18 = conv.bias().size();
        int[] x$19 = qual$4.resize$default$2();
        Tensor<Object> convBias = qual$4.resize(x$18, x$19).copy(conv.bias().dense());
        Tensor<Object> bnWeight = Tensor$.MODULE$.apply(ClassTag$.MODULE$.Float(), TensorNumericMath$TensorNumeric$NumericFloat$.MODULE$).resizeAs(bn.weightAndBias().dense()).copy(bn.weightAndBias().dense());
        // Per-output-channel fold, j = 0 until bn.nOutput().
        RichInt$.MODULE$.until$extension0(Predef$.MODULE$.intWrapper(0), bn.nOutput()).foreach$mVc$sp((Function1)new Serializable(conv, bn, originVar, originMean, convWeight, convBias, bnWeight){
            public static final long serialVersionUID = 0L;
            private final SpatialConvolution conv$1;
            private final SpatialBatchNormalization bn$2;
            private final Tensor originVar$1;
            private final Tensor originMean$1;
            private final Tensor convWeight$1;
            private final Tensor convBias$1;
            private final Tensor bnWeight$1;

            public final void apply(int j) {
                this.apply$mcVI$sp(j);
            }

            public void apply$mcVI$sp(int j) {
                Tensor<Float> tensor;
                // storageOffset() is 1-based, hence the "- 1" when indexing the raw float[].
                float variance = ((float[])this.originVar$1.storage().array())[j + this.originVar$1.storageOffset() - 1];
                float base = (float)Math.sqrt((double)variance + this.bn$2.eps());
                Log4Error$.MODULE$.invalidInputError((double)base != 0.0, new StringContext((Seq)Predef$.MODULE$.wrapRefArray((Object[])new String[]{"the eps of ", " should be more than 0"})).s((Seq)Predef$.MODULE$.genericWrapArray((Object)new Object[]{this.bn$2.getName()})), Log4Error$.MODULE$.invalidInputError$default$3());
                float alpha = ((float[])this.bnWeight$1.storage().array())[this.bnWeight$1.storageOffset() - 1 + j];
                float beta = ((float[])this.bnWeight$1.storage().array())[this.bnWeight$1.storageOffset() - 1 + this.bn$2.nOutput() + j];
                // Pick the weight slice of output channel j.
                // NOTE(review): for nGroup > 1 this does select(1, group) and then select(2, channel).
                // After the first select, dim 2 is presumably the per-group input-channel axis, not
                // the output-channel axis. Verify against the grouped weight layout
                // (nGroup, nOut/nGroup, ...); it may need to be select(1, channel) instead.
                if (this.conv$1.nGroup() == 1) {
                    tensor = this.convWeight$1.select(1, j + 1);
                } else {
                    int channelPerGroup = this.conv$1.nOutputPlane() / this.conv$1.nGroup();
                    int group = j / channelPerGroup + 1;
                    int channel = j % channelPerGroup + 1;
                    tensor = this.convWeight$1.select(1, group).select(2, channel);
                }
                Tensor<Float> weight = tensor;
                weight.div(BoxesRunTime.boxToFloat((float)base));
                weight.mul(BoxesRunTime.boxToFloat((float)alpha));
                // NOTE(review): unlike the reads above, bias and mean ignore storageOffset(). This is
                // only safe if the freshly allocated copies start at storage offset 1; presumably
                // true for Tensor().resize(...), but confirm.
                float bias = ((float[])this.convBias$1.storage().array())[j];
                float mean2 = ((float[])this.originMean$1.storage().array())[j];
                ((float[])this.convBias$1.storage().array())[j] = alpha / base * bias + beta - alpha * mean2 / base;
            }
            {
                this.conv$1 = conv$1;
                this.bn$2 = bn$2;
                this.originVar$1 = originVar$1;
                this.originMean$1 = originMean$1;
                this.convWeight$1 = convWeight$1;
                this.convBias$1 = convBias$1;
                this.bnWeight$1 = bnWeight$1;
            }
        });
        // Install the folded copies into the convolution and refresh its quantization scales.
        conv.weight().dense().set(convWeight);
        conv.bias().dense().set(convBias);
        conv.flushWeightScales(conv.weight().dense());
        conv.setOutputScales(bn.getOutputScales());
    }

    /**
     * Marks a convolution as receiving non-negative input ({@code negativeInput = false}).
     *
     * <p>This happens when every nearest producer qualifies. Producers are found through
     * Identity/pooling/JoinTable nodes (see {@code findAllNonIdentityPrevs}). A producer qualifies
     * when it is a ReLU, or a convolution with a fused ReLU.
     *
     * <p>Does nothing when fusion is disabled or the node is not a {@code SpatialConvolution}.
     *
     * <p>NOTE(review): {@code forall} over an empty producer list is {@code true}, so a conv with no
     * predecessors is also marked non-negative. Confirm this is intended.
     */
    public void setNegativeInputOfConv(Node<AbstractModule<Activity, Activity, Object>> node) {
        if (this.fuse() && node.element() instanceof SpatialConvolution) {
            boolean successFromReLU = ((IterableLike)((TraversableLike)node.prevNodes().flatMap((Function1)new Serializable(){
                public static final long serialVersionUID = 0L;

                public final Seq<Node<AbstractModule<Activity, Activity, Object>>> apply(Node<AbstractModule<Activity, Activity, Object>> x) {
                    return Fusion$.MODULE$.com$intel$analytics$bigdl$dllib$nn$mkldnn$Fusion$$findAllNonIdentityPrevs(x);
                }
            }, Seq$.MODULE$.canBuildFrom())).map((Function1)new Serializable(){
                public static final long serialVersionUID = 0L;

                public final boolean apply(Node<AbstractModule<Activity, Activity, Object>> x) {
                    AbstractModule<Activity, Activity, Object> abstractModule = x.element();
                    boolean bl = abstractModule instanceof SpatialConvolution ? ((SpatialConvolution)x.element()).relu() : abstractModule instanceof ReLU;
                    return bl;
                }
            }, Seq$.MODULE$.canBuildFrom())).forall((Function1)new Serializable(){
                public static final long serialVersionUID = 0L;

                public final boolean apply(boolean x$2) {
                    return x$2;
                }
            });
            if (successFromReLU) {
                ((SpatialConvolution)node.element()).negativeInput_$eq(false);
            }
            return;
        }
    }

    /**
     * Gives every convolution that feeds a {@code JoinTable} the same output scales, so that the
     * concatenated result is quantized consistently.
     *
     * <p>Preconditions for acting:
     * <ul>
     *   <li>fusion is enabled;</li>
     *   <li>the node is a JoinTable;</li>
     *   <li>at least one preceding conv reports {@code needQuantize()}. Preceding convs are found
     *       through Identity/pooling/JoinTable nodes.</li>
     * </ul>
     * All preceding convs (not only the quantized ones) must then share a single output dim mask;
     * otherwise {@code Log4Error.invalidInputError} is raised.
     *
     * <p>Choosing the shared scales: if there are conv consumers (found through Identity successors),
     * the first one's input scales are used. Otherwise a single row is used, holding the element-wise
     * maximum of the convs' flattened output scales.
     *
     * <p>NOTE(review): the transpose in the no-consumer branch presumably requires all flattened scale
     * arrays to have the same length. Confirm this holds for every mask.
     */
    public void setScalesPrevousJoinTable(Node<AbstractModule<Activity, Activity, Object>> node) {
        if (this.fuse() && node.element() instanceof JoinTable) {
            Seq preConvs = (Seq)((TraversableLike)((TraversableLike)node.prevNodes().flatMap((Function1)new Serializable(){
                public static final long serialVersionUID = 0L;

                public final Seq<Node<AbstractModule<Activity, Activity, Object>>> apply(Node<AbstractModule<Activity, Activity, Object>> x) {
                    return Fusion$.MODULE$.com$intel$analytics$bigdl$dllib$nn$mkldnn$Fusion$$findAllNonIdentityPrevs(x);
                }
            }, Seq$.MODULE$.canBuildFrom())).filter((Function1)new Serializable(){
                public static final long serialVersionUID = 0L;

                public final boolean apply(Node<AbstractModule<Activity, Activity, Object>> x$3) {
                    return x$3.element() instanceof SpatialConvolution;
                }
            })).map((Function1)new Serializable(){
                public static final long serialVersionUID = 0L;

                public final SpatialConvolution apply(Node<AbstractModule<Activity, Activity, Object>> x$4) {
                    return (SpatialConvolution)x$4.element();
                }
            }, Seq$.MODULE$.canBuildFrom());
            if (preConvs.exists((Function1)new Serializable(){
                public static final long serialVersionUID = 0L;

                public final boolean apply(SpatialConvolution x$5) {
                    return x$5.needQuantize();
                }
            })) {
                Set masks = ((TraversableOnce)preConvs.map((Function1)new Serializable(){
                    public static final long serialVersionUID = 0L;

                    public final int apply(SpatialConvolution x$6) {
                        return x$6.getOutputDimMask();
                    }
                }, Seq$.MODULE$.canBuildFrom())).toSet();
                Log4Error$.MODULE$.invalidInputError(masks.size() == 1, new StringContext((Seq)Predef$.MODULE$.wrapRefArray((Object[])new String[]{"all preceding convolutions must have the same mask"})).s((Seq)Nil$.MODULE$), Log4Error$.MODULE$.invalidInputError$default$3());
                Seq nextConvs = (Seq)((TraversableLike)node.nextNodes().flatMap((Function1)new Serializable(){
                    public static final long serialVersionUID = 0L;

                    public final Seq<Node<AbstractModule<Activity, Activity, Object>>> apply(Node<AbstractModule<Activity, Activity, Object>> node) {
                        return Fusion$.MODULE$.com$intel$analytics$bigdl$dllib$nn$mkldnn$Fusion$$findNext(node);
                    }
                }, Seq$.MODULE$.canBuildFrom())).filter((Function1)new Serializable(){
                    public static final long serialVersionUID = 0L;

                    public final boolean apply(Node<AbstractModule<Activity, Activity, Object>> x$7) {
                        return x$7.element() instanceof SpatialConvolution;
                    }
                });
                // No downstream conv: one row holding, per position, the max over all preceding
                // convs' flattened output scales. Otherwise reuse the first downstream conv's input
                // scales.
                float[][] scales = nextConvs.isEmpty() ? (float[][])((Object[])new float[][]{(float[])((TraversableOnce)((TraversableLike)((GenericTraversableTemplate)preConvs.map((Function1)new Serializable(){
                    public static final long serialVersionUID = 0L;

                    public final float[] apply(SpatialConvolution x$8) {
                        return (float[])Predef$.MODULE$.refArrayOps((Object[])x$8.getOutputScales()).flatten((Function1)new Serializable(this){
                            public static final long serialVersionUID = 0L;

                            public final WrappedArray<Object> apply(float[] xs) {
                                return Predef$.MODULE$.wrapFloatArray(xs);
                            }
                        }, ClassTag$.MODULE$.Float());
                    }
                }, Seq$.MODULE$.canBuildFrom())).transpose((Function1)new Serializable(){
                    public static final long serialVersionUID = 0L;

                    public final ArrayOps<Object> apply(float[] xs) {
                        return Predef$.MODULE$.floatArrayOps(xs);
                    }
                })).map((Function1)new Serializable(){
                    public static final long serialVersionUID = 0L;

                    public final float apply(Seq<Object> x$9) {
                        return BoxesRunTime.unboxToFloat((Object)x$9.max((Ordering)Ordering.Float$.MODULE$));
                    }
                }, Seq$.MODULE$.canBuildFrom())).toArray(ClassTag$.MODULE$.Float())}) : ((MklInt8Convertible)((IterableLike)nextConvs.map((Function1)new Serializable(){
                    public static final long serialVersionUID = 0L;

                    public final SpatialConvolution apply(Node<AbstractModule<Activity, Activity, Object>> x$10) {
                        return (SpatialConvolution)x$10.element();
                    }
                }, Seq$.MODULE$.canBuildFrom())).head()).getInputScales();
                preConvs.foreach((Function1)new Serializable(scales){
                    public static final long serialVersionUID = 0L;
                    private final float[][] scales$1;

                    public final void apply(SpatialConvolution conv) {
                        conv.setOutputScales(this.scales$1);
                    }
                    {
                        this.scales$1 = scales$1;
                    }
                });
                return;
            }
            return;
        }
    }

    /**
     * Folds a {@code Scale} layer, wrapped in a {@code BlasWrapper}, into the batch-norm layers that
     * precede it. The Scale node's element is then replaced with {@code Identity}.
     *
     * <p>Applies only when every predecessor is a {@code SpatialBatchNormalization}. Let scaleWeight
     * and scaleBias be the first and second tensors of {@code scale.parameters()._1()}. For each BN,
     * taking its scale/shift as the first and second {@code nOutput}-long halves of
     * {@code weightAndBias()}:
     * <pre>
     *   weight = weight * scaleWeight
     *   bias   = bias * scaleWeight + scaleBias
     * </pre>
     *
     * <p>NOTE(review): unlike the other passes, this one is not gated by {@code fuse()}.
     *
     * <p>NOTE(review): with zero predecessors {@code forall} is {@code true}, so the Scale would be
     * replaced by Identity without being folded anywhere. Confirm this cannot happen.
     */
    public void fuseScale(Node<AbstractModule<Activity, Activity, Object>> node) {
        BlasWrapper blasWrapper;
        AbstractModule<Activity, Activity, Object> abstractModule = node.element();
        if (abstractModule instanceof BlasWrapper && (blasWrapper = (BlasWrapper)abstractModule).module() instanceof Scale) {
            boolean isValid = node.prevNodes().forall((Function1)new Serializable(){
                public static final long serialVersionUID = 0L;

                public final boolean apply(Node<AbstractModule<Activity, Activity, Object>> x$11) {
                    return x$11.element() instanceof SpatialBatchNormalization;
                }
            });
            if (!isValid) return;
            node.prevNodes().foreach((Function1)new Serializable(node){
                public static final long serialVersionUID = 0L;
                private final Node node$1;

                public final Tensor<Object> apply(Node<AbstractModule<Activity, Activity, Object>> prevNode) {
                    SpatialBatchNormalization bn = (SpatialBatchNormalization)prevNode.element();
                    Tensor<Object> weightAndBias = bn.weightAndBias().dense();
                    // Presumably narrow() returns views, so the cmul/add below update weightAndBias
                    // in place. TODO confirm.
                    Tensor<Object> weight = weightAndBias.narrow(1, 1, bn.nOutput());
                    Tensor<Object> bias = weightAndBias.narrow(1, bn.nOutput() + 1, bn.nOutput());
                    Scale scale = (Scale)((BlasWrapper)this.node$1.element()).module();
                    Tensor scaleWeight = ((Tensor[])scale.parameters()._1())[0];
                    Tensor scaleBias = ((Tensor[])scale.parameters()._1())[1];
                    weight.cmul(scaleWeight);
                    bias.cmul(scaleWeight);
                    bias.add((Object)scaleBias);
                    return bn.weightAndBias().dense().set(weightAndBias);
                }
                {
                    this.node$1 = node$1;
                }
            });
            node.element_$eq(Identity$.MODULE$.apply$mFc$sp((ClassTag<Object>)ClassTag$.MODULE$.Float(), TensorNumericMath$TensorNumeric$NumericFloat$.MODULE$));
            BoxedUnit boxedUnit = BoxedUnit.UNIT;
            return;
        } else {
            BoxedUnit boxedUnit = BoxedUnit.UNIT;
        }
    }

    /**
     * Returns the nearest producers of {@code node}, treating {@code Identity}, {@code MaxPooling},
     * {@code AvgPooling} and {@code JoinTable} nodes as pass-throughs.
     *
     * <p>A pass-through node is expanded recursively through all of its predecessors. Any other node
     * yields a one-element Seq containing itself. A pass-through node with no predecessors
     * contributes nothing.
     */
    public Seq<Node<AbstractModule<Activity, Activity, Object>>> com$intel$analytics$bigdl$dllib$nn$mkldnn$Fusion$$findAllNonIdentityPrevs(Node<AbstractModule<Activity, Activity, Object>> node) {
        return node.element() instanceof Identity || node.element() instanceof MaxPooling || node.element() instanceof AvgPooling || node.element() instanceof JoinTable ? (Seq)node.prevNodes().flatMap((Function1)new Serializable(){
            public static final long serialVersionUID = 0L;

            public final Seq<Node<AbstractModule<Activity, Activity, Object>>> apply(Node<AbstractModule<Activity, Activity, Object>> node) {
                return Fusion$.MODULE$.com$intel$analytics$bigdl$dllib$nn$mkldnn$Fusion$$findAllNonIdentityPrevs(node);
            }
        }, Seq$.MODULE$.canBuildFrom()) : (Seq)Seq$.MODULE$.apply((Seq)Predef$.MODULE$.wrapRefArray((Object[])new Node[]{node}));
    }

    /** Singleton constructor; publishes this instance as {@link #MODULE$}. */
    private Fusion$() {
        MODULE$ = this;
    }
}

