# AUTOGENERATED! DO NOT EDIT! File to edit: ../../nbs-dev/03_cnn_virus_architecture.ipynb.

# %% auto 0
__all__ = ['create_model_original']

# %% ../../nbs-dev/03_cnn_virus_architecture.ipynb 3
# Imports all dependencies
from pathlib import Path
from typing import Callable, Optional, Tuple, Union

import tensorflow as tf
import tensorflow.keras
from tensorflow.keras.layers import Convolution1D, Dense, Flatten, Dropout, Activation, BatchNormalization, Input
from tensorflow.keras.layers import MaxPooling1D, Concatenate
from tensorflow.keras.models import Sequential, Model, load_model

from ..core import ProjectFileSystem

# %% ../../nbs-dev/03_cnn_virus_architecture.ipynb 8
def create_model_original(
    load_parameters: bool = True,  # Load pretrained weights when True
    path2parameters: Optional[Union[str, Path]] = None,  # Path (str or Path) to pretrained weights, defaults to project CNN Virus weights
    ) -> tf.keras.Model:  # New instance of an original paper architecture
    """Build a CNN model as per CNN Virus paper, optionally loading pretrained weights.

    The model is a functional Keras model named ``CNN_Virus`` with a single input
    ``input-seq`` of per-sample shape ``(50, 5)`` (assumed: 50 positions x 5 channels
    of an encoded read -- TODO confirm encoding) and two softmax outputs, in order:

    - ``labels``: 187-way softmax over the shared 1024-d feature vector
      (presumably the virus class -- confirm against the dataset).
    - ``pos``: 10-way softmax computed from the shared features concatenated
      with ``labels`` (presumably a position bin -- confirm).

    When ``load_parameters`` is True, the weights are read from ``path2parameters``,
    or by default from ``ProjectFileSystem().data / 'saved/cnn_virus_original/pretrained_model.h5'``.
    The weights file is assumed to have been saved from exactly this graph, so keep
    layer names and creation order unchanged.

    Raises:
        FileNotFoundError: if ``load_parameters`` is True and the weights file does not exist.
    """

    print("Creating CNN Model (Original)")

    # --- Convolutional feature extractor -------------------------------------------------
    input_seq = Input(shape=(50, 5), name='input-seq')
    layer1 = Convolution1D(512, 5, padding="same", activation="relu", kernel_initializer="he_uniform", name="conv-1")(input_seq)
    layer2 = BatchNormalization(momentum=0.6, name='bn-1')(layer1)
    layer3 = MaxPooling1D(pool_size=2, padding='same', name='maxpool-1')(layer2)
    layer4 = Convolution1D(512, 5, padding="same", activation="relu", kernel_initializer="he_uniform", name="conv-2")(layer3)
    layer5 = BatchNormalization(momentum=0.6, name='bn-2')(layer4)
    layer6 = MaxPooling1D(pool_size=2, padding='same', name='maxpool-2')(layer5)
    layer7 = Convolution1D(1024, 7, padding="same", activation="relu", kernel_initializer="he_uniform", name="conv-3")(layer6)
    layer8 = Convolution1D(1024, 7, padding="same", activation="relu", kernel_initializer="he_uniform", name="conv-4")(layer7)
    layer9 = BatchNormalization(momentum=0.6, name='bn-3')(layer8)
    layer10 = MaxPooling1D(pool_size=2, padding='same', name='maxpool-3')(layer9)

    # --- Shared dense features (1024-d); note dense-1 has no activation, as in the original ---
    layer11 = Flatten(name='flatten')(layer10)
    layer12 = Dense(1024, kernel_initializer="he_uniform", name='dense-1')(layer11)
    layer13 = BatchNormalization(momentum=0.6, name='bn-4')(layer12)
    layer14 = Dropout(0.2, name='do-1')(layer13)

    # --- Head 1: 187-way classification ---------------------------------------------------
    labels = Dense(187, activation='softmax', kernel_initializer="he_uniform", name="labels")(layer14)

    # --- Head 2: 10-way output, conditioned on the shared features AND the labels head -----
    output_con = Concatenate(name='concat')([layer14, labels])
    layer15 = Dense(1024, kernel_initializer="he_uniform", name='dense-2')(output_con)
    layer16 = BatchNormalization(momentum=0.6, name='bn-5')(layer15)
    pos = Dense(10, activation='softmax', kernel_initializer="he_uniform", name="pos")(layer16)

    model = Model(inputs=input_seq, outputs=[labels, pos], name="CNN_Virus")

    if not load_parameters:
        print("Created randomly initialized model")
        return model

    # Load pretrained weights
    if path2parameters is None:
        path2parameters = ProjectFileSystem().data / 'saved/cnn_virus_original/pretrained_model.h5'
    # Accept plain strings as well as Path objects: `.is_file()` / `.name` need a Path.
    path2parameters = Path(path2parameters)
    if not path2parameters.is_file():
        raise FileNotFoundError(f"Could not find pretrained model at {path2parameters}")
    print(f"Loading parameters from {path2parameters.name}")
    # Pass a str for compatibility with Keras versions whose load_weights inspects the
    # filepath with str methods (assumed; harmless where os.PathLike is accepted).
    model.load_weights(str(path2parameters))
    print("Created pretrained model")
    return model
