# Copyright © 2023, University of California, Irvine.
# 
# GP+ Intellectual Property Notice:
# 
# The software known as GP+ is the proprietary material of the University of California, Irvine. 
# Non-profit academic institutions and U.S. government agencies may utilize this software exclusively for educational and research endeavors. 
# All other entities are granted permission for evaluation purposes solely; any additional utilization demands prior written consent from the appropriate authority. 
# The direct sale or redistribution of this software, in any form, without explicit written authorization is strictly prohibited. 
# Users are permitted to make duplicate copies of the software, contingent upon the assurance that no copies are sold or redistributed and they adhere to the stipulated terms herein.
# 
# Being academic research software, GP+ is provided on an "as is" base, devoid of warranties, whether explicit or implicit. 
# The act of downloading or executing any segment of this software inherently signifies compliance with these terms. 
# The University of California, Irvine reserves the right to modify these terms and conditions without prior intimation at any juncture.

import torch
import math
import gpytorch
from typing import Dict,List,Optional
from gpytorch.constraints import Positive
from gpytorch.priors import NormalPrior
from gpytorch.distributions import MultivariateNormal
from ..priors import MollifiedUniformPrior
from GP_Plus_functions.visual.plot_latenth import plot_sep
from GP_Plus_functions.models.gpregression import GPR
from GP_Plus_functions import kernels
from GP_Plus_functions.priors import MollifiedUniformPrior
import numpy as np
from GP_Plus_functions.preprocessing import setlevels
from GP_Plus_functions.visual import plot_ls
import matplotlib.pyplot as plt
import torch
from gpytorch.means import Mean
import torch
from tabulate import tabulate
from GP_Plus_functions.utils import set_seed
from GP_Plus_functions.optim import fit_model_scipy
import torch
import numpy as np
import sobol_seq
import warnings
from torch import Tensor
from gpytorch.means import Mean
from torch.nn.parameter import Parameter, UninitializedParameter
from torch.nn import init
from torch import nn
import torch.nn.functional as F 
from scipy.stats import norm

# Default tensor options applied via `.to(**tkwargs)` throughout this module.
# Both branches of the device expression select "cpu", so tensors are placed
# on the CPU whether or not CUDA is available.
tkwargs = dict(
    dtype=torch.float,
    device=torch.device("cpu" if torch.cuda.is_available() else "cpu"),
)

class GP_Plus(GPR):
    """The GP_Plus which extends GPs to learn nonlinear and probabilistic nmanifold, handle categorical inputs, and  ... ...

    :note: Binary categorical variables should not be treated as qualitative inputs. There is no 
        benefit from applying a latent variable treatment for such variables. Instead, treat them
        as numerical inputs.

    :param train_x: The training inputs (size N x d). Qualitative inputs need to be encoded as 
        integers 0,...,L-1 where L is the number of levels. For best performance, scale the 
        numerical variables to the unit hypercube.
    """
    def __init__(
        self,
        train_x:torch.Tensor,
        train_y:torch.Tensor,
        qual_ind_lev = {},
        multiple_noise = False,
        lv_dim:int=2,
        quant_correlation_class:str='Rough_RBF',
        noise:float=1e-4,
        fix_noise:bool=False,
        lb_noise:float=1e-8,
        NN_layers:list = [],
        encoding_type = 'one-hot',
        manifold_type='deterministic',
        uniform_encoding_columns = 2,
        lv_columns = [] ,
        base='single',
        NN_layers_base=[],
        base_function_size=None,
        Calibration_ID=[],
        seed_number=1,
        mean_prior_cal=0,
        std_prior_cal=1,
        Calibration_type='deterministic'
    ) -> None:
        """Validate the arguments, build the correlation kernel and set up the model.

        The constructor (1) validates/coerces the arguments, (2) splits the input
        columns into qualitative and quantitative ones, (3) builds the product
        correlation kernel and hands it to ``GPR.__init__``, (4) registers the
        calibration parameters, (5) builds the latent-variable maps for the
        qualitative inputs and (6) builds the mean function(s).

        :param train_x: Training inputs. Converted to ``torch.Tensor`` (with a
            warning) if needed; NaN entries are replaced by ``fill_nan_with_mean``.
        :param train_y: Training outputs. Converted to ``torch.Tensor`` (with a
            warning) if needed.
        :param qual_ind_lev: Dictionary whose keys are used as the column indices
            of the qualitative inputs and whose values are used as their number
            of levels.
        :param multiple_noise: If True, one noise index is created per level of
            the LAST qualitative variable and passed to ``GPR`` as
            ``noise_indices``. Requires at least one qualitative input.
        :param lv_dim: Latent dimension per qualitative kernel. Forced to 0 when
            there are no qualitative inputs.
        :param quant_correlation_class: One of ``'Rough_RBF'``, ``'RBFKernel'``,
            ``'mattern52'``, ``'mattern32'``. ``'Rough_RBF'`` uses
            ``kernels.RBFKernel`` with a base-10 lengthscale transform.
        :param noise: Forwarded to ``GPR.__init__``.
        :param fix_noise: Forwarded to ``GPR.__init__``.
        :param lb_noise: Forwarded to ``GPR.__init__``.
        :param NN_layers: Number of neurons in each layer of the network that
            maps the manifold.
        :param encoding_type: Must be ``'one-hot'``.
        :param manifold_type: ``'deterministic'`` or ``'probabilistic'``.
        :param uniform_encoding_columns: Stored on the instance; not otherwise
            used in this constructor.
        :param lv_columns: Qualitative columns that each get their own manifold.
            The remaining qualitative columns are grouped into one extra entry.
        :param base: Mean-function type: ``'single'``, ``'multiple_polynomial'``,
            ``'multiple_constant'`` or ``'neural_network'``.
        :param NN_layers_base: Number of neurons in each layer of the mean-function
            network (used when ``base='neural_network'``).
        :param base_function_size: Input size of the per-source linear mean for
            ``base='multiple_polynomial'``; defaults to ``train_x.shape[1]-1``.
        :param Calibration_ID: Column numbers in the dataset that the calibration
            parameters are assigned to.
        :param seed_number: Stored as ``self.seed`` (used by ``forward``).
        :param mean_prior_cal: Prior mean for the calibration parameters.
        :param std_prior_cal: Prior standard deviation for the calibration
            parameters.
        :param Calibration_type: ``'probabilistic'`` registers a
            ``LinearVariational`` per calibration parameter; any other value
            registers a ``ConstantMean`` with a normal prior.
            NOTE(review): this argument is not validated.
        :raises ValueError: If an argument fails validation, or if
            ``multiple_noise`` is requested without any qualitative input.
        :raises RuntimeError: If the requested kernel class is not found in
            ``kernels``.
        """
        
        self.mean_prior_cal=mean_prior_cal
        self.std_prior_cal=std_prior_cal
        
        ## The checks lists:
        if not isinstance(train_x, torch.Tensor):
            original_type = type(train_x).__name__
            warnings.warn(f"'train_x' was not a torch.Tensor (type: {original_type}). It is converted to torch.Tensor to proceed with the emulation.")
            train_x = torch.tensor(train_x)

        if not isinstance(train_y, torch.Tensor):
            original_type = type(train_y).__name__
            warnings.warn(f"'train_y' was not a torch.Tensor (type: {original_type}). It is converted to torch.Tensor to proceed with the emulation.")
            train_y = torch.tensor(train_y)


        if not isinstance(qual_ind_lev, dict):
            raise ValueError("qual_ind_lev should be a dictionary")

        if multiple_noise not in [True, False]:
            raise ValueError("multiple_noise should be either True or False")

        if not isinstance(lv_dim, int):
            raise ValueError("lv_dim should be an integer")

        if quant_correlation_class not in ['Rough_RBF', 'RBFKernel', 'mattern52', 'mattern32']:
            raise ValueError("quant_correlation_class should be 'Rough_RBF', 'RBFKernel', 'mattern52', or 'mattern32'")

        if fix_noise not in [True, False]:
            raise ValueError("fix_noise should be either True or False")

        if not isinstance(NN_layers, list) or not all(isinstance(i, int) for i in NN_layers):
            raise ValueError("NN_layers should be a list with integers representing the number of neurons in each layer for mapping the manifold")

        if encoding_type != 'one-hot':
            raise ValueError("encoding_type should be 'one-hot'")

        if manifold_type not in ['deterministic', 'probabilistic']:
            raise ValueError("manifold_type should be either 'deterministic' or 'probabilistic'")

        if not isinstance(lv_columns, list) or not all(isinstance(i, int) for i in lv_columns):
            raise ValueError("lv_columns should be a list with integers showing the number of categorical inputs to be considered in a separate manifold in each layer")

        if base not in ['single', 'multiple_polynomial', 'multiple_constant', 'neural_network']:
            raise ValueError("base should be 'single', 'multiple_polynomial', 'multiple_constant', or 'neural_network'")

        if not isinstance(NN_layers_base, list) or not all(isinstance(i, int) for i in NN_layers_base):
            raise ValueError("NN_layers_base should be a list with integers representing the number of neurons in each layer for the mean function")

        if not isinstance(Calibration_ID, list) or not all(isinstance(i, int) for i in Calibration_ID):
            raise ValueError("Calibration_ID should be a list where each entry shows the column number in the dataset that the calibration parameters are assigned to")
        
        # Replace NaN entries in place before anything else reads train_x.
        train_x=self.fill_nan_with_mean(train_x,Calibration_ID)
        ###############################################################################################
        ###############################################################################################
        self.seed=seed_number
        self.unique_indd=None
        self.Calibration_ID=Calibration_ID
        self.calibration_source_index=0    ## It is supposed the calibration parameter is for high fidelity needs
        # Split the columns: dict keys are qualitative, everything else is quantitative.
        qual_index = list(qual_ind_lev.keys())
        all_index = set(range(train_x.shape[-1]))
        quant_index = list(all_index.difference(qual_index))
        num_levels_per_var = list(qual_ind_lev.values())
        #------------------- lm columns --------------------------
        # Qualitative columns not listed in lv_columns are grouped into ONE extra
        # (list-valued) entry that shares a single manifold.
        lm_columns = list(set(qual_index).difference(lv_columns))
        if len(lm_columns) > 0:
            qual_kernel_columns = [*lv_columns, lm_columns]
        else:
            qual_kernel_columns = lv_columns

        #########################
        if len(qual_index) > 0:
            train_x = setlevels(train_x, qual_index=qual_index)
        #
        if multiple_noise:
            # Without a qualitative variable there is no level count to index;
            # fail with a clear message instead of an IndexError on [-1].
            if len(num_levels_per_var) == 0:
                raise ValueError(
                    "multiple_noise=True requires at least one qualitative input in qual_ind_lev"
                )
            noise_indices = list(range(0,num_levels_per_var[-1]))
        else:
            noise_indices = []

        # A single qualitative variable with fewer than 2 levels carries no
        # categorical information, so it is treated as a quantitative column.
        if len(qual_index) == 1 and num_levels_per_var[0] < 2:
            temp = quant_index.copy()
            temp.append(qual_index[0])
            quant_index = temp.copy()
            qual_index = []
            lv_dim = 0
        elif len(qual_index) == 0:
            lv_dim = 0

        quant_correlation_class_name = quant_correlation_class

        if len(qual_index) == 0:
            lv_dim = 0

        # 'Rough_RBF' is an RBFKernel with a different lengthscale transform/prior.
        if quant_correlation_class_name == 'Rough_RBF':
            quant_correlation_class = 'RBFKernel'
        if len(qual_index) > 0:
            ####################### Defined multiple kernels for seperate variables ###################
            # One fixed (lengthscale 1, not trainable) RBF kernel per manifold, each
            # acting on its own block of lv_dim latent columns.
            qual_kernels = []
            for i in range(len(qual_kernel_columns)):
                qual_kernels.append(kernels.RBFKernel(
                    active_dims=torch.arange(lv_dim) + lv_dim * i) )
                qual_kernels[i].initialize(**{'lengthscale':1.0})
                qual_kernels[i].raw_lengthscale.requires_grad_(False)

        if len(quant_index) == 0:
            # Only qualitative inputs: the kernel is the product of the latent kernels.
            correlation_kernel = qual_kernels[0]
            for i in range(1, len(qual_kernels)):
                correlation_kernel *= qual_kernels[i]
        else:
            try:
                quant_correlation_class = getattr(kernels,quant_correlation_class)
            except AttributeError as err:
                # getattr failed, so quant_correlation_class is still the name string.
                raise RuntimeError(
                    "%s not an allowed kernel" % quant_correlation_class
                ) from err
            # NOTE(review): only 'RBFKernel' and 'Rough_RBF' build quant_kernel below.
            # 'mattern52'/'mattern32' pass validation, but no branch here constructs
            # a kernel for them, so quant_kernel would be undefined further down.
            if quant_correlation_class_name == 'RBFKernel':
                quant_kernel = quant_correlation_class(
                    ard_num_dims=len(quant_index),
                    active_dims=len(qual_kernel_columns) * lv_dim+torch.arange(len(quant_index)),
                    lengthscale_constraint= Positive(transform= torch.exp,inv_transform= torch.log)
                )
            elif quant_correlation_class_name == 'Rough_RBF':
                quant_kernel = quant_correlation_class(
                    ard_num_dims=len(quant_index),
                    active_dims=len(qual_kernel_columns)*lv_dim+torch.arange(len(quant_index)),
                    lengthscale_constraint= Positive(transform= lambda x: 2.0**(-0.5) * torch.pow(10,-x/2),inv_transform= lambda x: -2.0*torch.log10(x/2.0))
                )
                #####################
            if quant_correlation_class_name == 'RBFKernel':
                
                quant_kernel.register_prior(
                    'lengthscale_prior', MollifiedUniformPrior(math.log(0.1),math.log(10)),'raw_lengthscale'
                )
                
            elif quant_correlation_class_name == 'Rough_RBF':
                quant_kernel.register_prior(
                    'lengthscale_prior',NormalPrior(-3.0,3.0),'raw_lengthscale'
                )
                
            if len(qual_index) > 0:
                # Mixed inputs: product of all latent kernels times the quantitative kernel.
                temp = qual_kernels[0]
                for i in range(1, len(qual_kernels)):
                    temp *= qual_kernels[i]
                correlation_kernel = temp*quant_kernel #+ qual_kernel + quant_kernel
            else:
                correlation_kernel = quant_kernel

        super(GP_Plus,self).__init__(
            train_x=train_x,train_y=train_y,noise_indices=noise_indices,
            correlation_kernel=correlation_kernel,
            noise=noise,fix_noise=fix_noise,lb_noise=lb_noise
        )

        self.Calibration_type=Calibration_type
        self.forward_np=5
        # For every calibration column: register the learnable parameter module,
        # remember which rows belong to the calibration source (last column equal to
        # calibration_source_index) and zero those entries of train_x in place.
        for n in self.Calibration_ID:
            if self.Calibration_type=='probabilistic':
                setattr(self,'Theta_'+str(n), LinearVariational(batch_shape=torch.Size([]),mean_prior=self.mean_prior_cal,std_prior=0*self.std_prior_cal).to(**tkwargs)) 
                setattr(self,'calibration_element'+str(n), torch.where(train_x[:, -1]==self.calibration_source_index)[0]) 
            else:
                setattr(self,'Theta_'+str(n), gpytorch.means.ConstantMean(prior=NormalPrior(self.mean_prior_cal,self.std_prior_cal))) 
                setattr(self,'calibration_element'+str(n), torch.where(train_x[:, -1]==self.calibration_source_index)[0]) 
            train_x[getattr(self,'calibration_element'+str(n)),n]=torch.zeros_like(train_x[getattr(self,'calibration_element'+str(n)),n])
        
        # register index and transforms
        self.register_buffer('quant_index',torch.tensor(quant_index))
        self.register_buffer('qual_index',torch.tensor(qual_index))

        self.qual_kernel_columns = qual_kernel_columns
        # latent variable mapping
        self.num_levels_per_var = num_levels_per_var
        self.lv_dim = lv_dim
        self.uniform_encoding_columns = uniform_encoding_columns
        self.encoding_type = encoding_type
        self.manifold_type=manifold_type
        self.perm =[]
        self.zeta = []
        self.random_zeta=[]
        self.perm_dict = []
        self.A_matrix = []
        self.epsilon=None
        self.epsilon_f=None
        self.embeddings_Dtrain=[]
        # Number of training rows; forward() compares input sizes against it.
        self.count=train_x.size()[0]
        if len(qual_kernel_columns) > 0:
            for i in range(len(qual_kernel_columns)):
                # An int entry is a single column; a list entry is a group of columns
                # whose level counts are summed to give the encoder input size.
                if type(qual_kernel_columns[i]) == int:
                    num = self.num_levels_per_var[qual_index.index(qual_kernel_columns[i])]
                    cat = [num]
                else:
                    cat = [self.num_levels_per_var[qual_index.index(k)] for k in qual_kernel_columns[i]]
                    num = sum(cat)

                zeta, perm, perm_dict = self.zeta_matrix(num_levels=cat, lv_dim = self.lv_dim)
                self.zeta.append(zeta)
                self.perm.append(perm)
                self.perm_dict.append(perm_dict)       
                ###################################  latent map (manifold) #################################   
                if self.manifold_type=='probabilistic':
                    # Replaces the A_matrix attribute with a single encoder on every
                    # iteration (the list initialised above is not used in this mode).
                    setattr(self,'A_matrix', Variational_Encoder(self, input_size= num, num_classes=5, 
                        layers =NN_layers, name = str(qual_kernel_columns[i])).to(**tkwargs))
                else:
                    model_temp = FFNN(self, input_size= num, num_classes=lv_dim, 
                        layers = NN_layers, name ='latent'+ str(qual_kernel_columns[i])).to(**tkwargs)
                    self.A_matrix.append(model_temp.to(**tkwargs))
        ###################################  Mean Function #################################   
        self.base=base
        i=0
        if self.base=='single':
            self.mean_module = gpytorch.means.ConstantMean(prior=NormalPrior(0.,1))
        elif self.base=='multiple_constant':
            if base_function_size is None:
                base_function_size=train_x.shape[1]-1
            # The last column of train_x is read as the source index; source 0 gets
            # a zero mean and every other source gets its own constant mean.
            self.num_sources=int(torch.max(train_x[:,-1]))
            for i in range(self.num_sources +1):
                if i==0:
                    setattr(self,'mean_module_'+str(i), gpytorch.means.ZeroMean())
                else:
                    #Constant 
                    setattr(self,'mean_module_'+str(i), gpytorch.means.ConstantMean(prior=NormalPrior(0.,.1))) 
        elif self.base=='multiple_polynomial':
            if base_function_size is None:
                base_function_size=train_x.shape[1]-1
            # Same layout as above, with a linear mean (with prior) per non-zero source.
            self.num_sources=int(torch.max(train_x[:,-1]))
            for i in range(self.num_sources +1):
                if i==0:
                    setattr(self,'mean_module_'+str(i), gpytorch.means.ZeroMean())
                else:
                    setattr(self,'mean_module_'+str(i), LinearMean_with_prior(input_size=base_function_size, batch_shape=torch.Size([]), bias=True)) 
        elif self.base=='neural_network':
            ############################################### One NN for ALL
            setattr(self,'mean_module_NN_All', FFNN_as_Mean(self, input_size= train_x.shape[1]+1, num_classes=1,layers =NN_layers_base, name = str('mean_module_'+str(i)+'_')).to(**tkwargs)) 


    def forward(self,x:torch.Tensor) -> MultivariateNormal:
        """Evaluate the GP mean and covariance at ``x``.

        One or more forward passes are run (5 when the manifold or the
        calibration is probabilistic, otherwise 1). Each pass encodes the
        qualitative columns, maps them through the latent map, substitutes the
        calibration parameters into the calibration-source rows, and evaluates
        the mean function and the covariance module. The passes are then
        combined by moment matching: the ensemble mean is the average of the
        per-pass means, and the ensemble covariance is the average of
        ``covar + mean mean^T`` minus the outer product of the ensemble mean.

        :param x: Input tensor whose last dimension indexes the input columns.
        :returns: ``MultivariateNormal`` built from the ensemble mean and covariance.
        """
        if self.manifold_type=='probabilistic' or self.Calibration_type=='probabilistic':
            # Re-seed so the random draws below are reproducible across calls.
            set_seed(self.seed)
            if self.training:
                Numper_of_pass=5 #20
            else:
                Numper_of_pass=5 #30
        else:
            Numper_of_pass=1
        
        # Accumulators for the first and second moments over the passes.
        Sigma_sum=torch.zeros(x.size(0),x.size(0), dtype=torch.float64)
        mean_x_sum=torch.zeros(x.size(0), dtype=torch.float64)

        for NP in range(Numper_of_pass):
            x_forward_raw=x.clone()
            nd_flag = 0
            if x.dim() > 2:
                # Flatten leading dimensions; restored on x_new after the mapping.
                xsize = x.shape
                x = x.reshape(-1, x.shape[-1])
                nd_flag = 1
            x_new= x
            if len(self.qual_kernel_columns) > 0:
                embeddings = []
                # NOTE(review): `temp` is overwritten on every iteration, so only the
                # encoding of the LAST entry of qual_kernel_columns is used below, and
                # `i` keeps its final value after the loop.
                for i in range(len(self.qual_kernel_columns)):
                    temp= self.transform_categorical(x=x[:,self.qual_kernel_columns[i]].clone().type(torch.int64).to(tkwargs['device']), 
                        perm_dict = self.perm_dict[i], zeta = self.zeta[i])
                dimm=x_forward_raw.size()[0]
                if self.manifold_type=='probabilistic': 
                    # Convert to list of tuples
                    x_raw=torch.zeros(temp.size(0),2)
                    # Find unique rows
                    unique_rows, indices = torch.unique(temp, dim=0, return_inverse=True)
                    # if self.unique_indd is None:
                    #     self.unique_indd= unique_rows
                    self.unique_indd= unique_rows
                    # if self.unique_indd is not None and torch.isnan(self.unique_indd).any():
                    #     self.unique_indd = unique_rows
                    # Encode each distinct category combination once, then scatter the
                    # result back to the rows through `indices`.
                    temp= unique_rows
                    dimm=unique_rows.size()[0]
                    if self.training:
                        epsilon=torch.normal(mean=0,std=1,size=[dimm,2])## use np instead of torch 
                        embeddings.append(getattr(self,'A_matrix')(x=temp.float().to(**tkwargs),epsilon=epsilon))
                    else:
                        if x.size()[0]==self.count:
                            # Input has as many rows as the training set: draw fresh
                            # embeddings and cache them for later calls.
                            epsilon=torch.normal(mean=0,std=1,size=[dimm,2])
                            embeddings.append(getattr(self,'A_matrix')(x=temp.float().to(**tkwargs),epsilon=epsilon))
                            self.embeddings_Dtrain.append(embeddings[0])
                        else:
                            # Otherwise reuse the embedding cached for this pass number.
                            embeddings.append(self.embeddings_Dtrain[NP])
                    for i, index in enumerate(indices):
                        x_raw[i] = embeddings[0][index]
                    embeddings=x_raw
                    x_new= torch.cat([embeddings,x[...,self.quant_index.long()]],dim=-1)
                else:
                    embeddings.append(self.A_matrix[i](temp.float().to(**tkwargs)))
                    x_new= torch.cat([embeddings[0],x[...,self.quant_index.long()]],dim=-1)
                
                ## For Calibration
                # NOTE(review): `embeddings[0].size(1)` is used as the column offset of
                # the quantitative block. In the probabilistic-manifold branch above
                # `embeddings` is a tensor, so `embeddings[0]` is a single row there.
                if len(self.Calibration_ID)>0:
                    if self.training:
                        for n in self.Calibration_ID:
                            if self.Calibration_type=='probabilistic':
                                # One standard-normal draw per calibration-source row,
                                # passed through the variational layer Theta_n.
                                s=torch.ones_like(x_new[getattr(self,'calibration_element'+str(n)),embeddings[0].size(1)+n]).shape
                                epsilon=torch.normal(mean=0,std=1,size=[s[0],1])
                                Theta=(getattr(self,'Theta_'+str(n))(epsilon.clone().reshape(-1,1)))
                                x_new[getattr(self,'calibration_element'+str(n)),embeddings[0].size(1)+n]=\
                                    torch.ones_like(x_new[getattr(self,'calibration_element'+str(n)),embeddings[0].size(1)+n])*(Theta.reshape(-1))
                            else:
                                # Deterministic calibration: write Theta_n into the
                                # calibration-source rows. This lives under `else`
                                # (mirroring the eval branch below) so that it does not
                                # overwrite the sampled probabilistic value and so that
                                # Theta_n is actually used during training.
                                x_new[getattr(self,'calibration_element'+str(n)),embeddings[0].size(1)+n]=\
                                    torch.ones_like(x_new[getattr(self,'calibration_element'+str(n)),embeddings[0].size(1)+n])*(getattr(self,'Theta_'+str(n))(x[i,-1].clone().flatten().reshape(-1,1)))
                    else:
                        for n in self.Calibration_ID:
                            if self.Calibration_type=='probabilistic':
                                epsilon=torch.normal(mean=0,std=1,size=[1,1])
                                calibration_element=torch.where(x[:, -1]==self.calibration_source_index)[0]
                                x_new[calibration_element,embeddings[0].size(1)+n]=\
                                        torch.ones_like(x_new[calibration_element,embeddings[0].size(1)+n])*(getattr(self,'Theta_'+str(n))(epsilon.clone().reshape(-1,1)))
                            else:
                                calibration_element=torch.where(x[:, -1]==self.calibration_source_index)[0]
                                x_new[calibration_element,embeddings[0].size(1)+n]=\
                                    torch.ones_like(x_new[calibration_element,embeddings[0].size(1)+n])*(getattr(self,'Theta_'+str(n))(x[i,-1].clone().reshape(-1,1)))               
            if nd_flag == 1:
                x_new = x_new.reshape(*xsize[:-1], -1)
        #################### Multiple bases (General Case) ####################################  
            def multi_mean(x,x_forward_raw):
                # Evaluate the mean function selected by self.base. For the
                # 'multiple_*' bases the source index of each row is read from the
                # last column of the raw (un-mapped) input.
                mean_x=torch.zeros_like(x[:,-1])
                if self.base=='single':
                    mean_x=self.mean_module(x)
                elif self.base=='multiple_constant':
                    for i in range(len(mean_x)):
                        qq=int(x_forward_raw[i,-1])                        
                        mean_x[i] = getattr(self, 'mean_module_' + str(qq))(x[i, -1].clone().reshape(-1, 1))
                elif self.base=='multiple_polynomial':
                    for i in range(len(mean_x)):
                        qq=int(x_forward_raw[i,-1])
                        # mean_x[i]=getattr(self,'mean_module_'+str(qq))(torch.cat((torch.tensor((x[i,-1].clone().double().reshape(-1,1).float())**2),torch.tensor(x[i,-1].clone().double()).reshape(-1,1).float()),1))
                        # Features are [x_last**2, x_last] of the mapped input row.
                        mean_x[i] = getattr(self, 'mean_module_' + str(qq))(torch.cat((x[i, -1].clone().detach().double().reshape(-1, 1).float() ** 2,
                                x[i, -1].clone().detach().double().reshape(-1, 1).float()),1))
                
                elif self.base=='neural_network':
                    mean_x = getattr(self, 'mean_module_NN_All')(x.clone()).reshape(-1)
                return mean_x 
            mean_x = multi_mean(x_new,x_forward_raw)
            covar_x = self.covar_module(x_new)
            mean_x_sum+=mean_x
            # Accumulate the second moment E[f f^T] = Sigma + mu mu^T of this pass.
            Sigma_sum += covar_x.evaluate()+ torch.outer(mean_x, mean_x)

        # End of the loop for forward pasess ----> Compute ensemble mean and covariance
        k = Numper_of_pass
        ensemble_mean = mean_x_sum/k
        ensemble_covar= Sigma_sum/k
        ensemble_covar -= torch.outer(ensemble_mean, ensemble_mean)
        ensemble_covar=gpytorch.lazy.NonLazyTensor(ensemble_covar)
        return MultivariateNormal(ensemble_mean,ensemble_covar)
    
        

        ########### Fit ############
    def fit(self,add_prior:bool=True,num_restarts:int=64,theta0_list:Optional[List[np.ndarray]]=None,jac:bool=True, options:Optional[Dict]=None,n_jobs:int=-1,method = 'L-BFGS-B',constraint=False,bounds=False,regularization_parameter:Optional[List[int]]=None):
        """Fit the model by delegating to ``fit_model_scipy``.

        All arguments are forwarded positionally, in the order of this
        signature, to ``fit_model_scipy`` with the model itself as the first
        argument. Nothing is returned.

        :param options: Forwarded as-is; a fresh empty dict is used when omitted.
        :param regularization_parameter: Forwarded as-is; ``[0, 0]`` is used
            when omitted.
        """
        # Build the defaults per call so they are never shared between calls.
        if options is None:
            options = {}
        if regularization_parameter is None:
            regularization_parameter = [0, 0]
        fit_model_scipy(
            self, add_prior, num_restarts, theta0_list, jac, options, n_jobs,
            method, constraint, bounds, regularization_parameter,
        )

    def fill_nan_with_mean(self,train_x,cal_ID):
        # Check if there are any NaNs in the tensor
        if torch.isnan(train_x).any():
            if len(cal_ID)==0:
                print("There are NaN values in the data, which will be filled with column-wise mean values.")
            else:
                print("There are NaN values in the data, which will be estimated in calibration process")

            # Compute the mean of non-NaN elements column-wise
            col_means = torch.nanmean(train_x, dim=0)

            # Find indices where NaNs are located
            nan_indices = torch.isnan(train_x)

            # Replace NaNs with the corresponding column-wise mean
            train_x[nan_indices] = col_means.repeat(train_x.shape[0], 1)[nan_indices]

        return train_x
    ############################  Prediction and Visualization  ###############################
    
    def predict(self, Xtest,return_std=True, include_noise = True):
        """Call the parent class ``predict`` with gradient tracking disabled.

        :param Xtest: Forwarded to the parent ``predict``.
        :param return_std: Forwarded to the parent ``predict``.
        :param include_noise: Forwarded to the parent ``predict``.
        :returns: Whatever the parent ``predict`` returns.
        """
        parent_predict = super().predict
        with torch.no_grad():
            result = parent_predict(Xtest, return_std=return_std, include_noise=include_noise)
        return result
    
    def predict_with_grad(self, Xtest,return_std=True, include_noise = True):
        """Call the parent class ``predict`` without disabling gradient tracking.

        :param Xtest: Forwarded to the parent ``predict``.
        :param return_std: Forwarded to the parent ``predict``.
        :param include_noise: Forwarded to the parent ``predict``.
        :returns: Whatever the parent ``predict`` returns.
        """
        prediction = super().predict(
            Xtest,
            return_std=return_std,
            include_noise=include_noise,
        )
        return prediction
    
    def noise_value(self):
        noise = self.likelihood.noise_covar.noise.detach() * self.y_std**2
        return noise

    def score(self, Xtest, ytest, plot_MSE = True, title = None, seperate_levels = False):
        """Predict on ``Xtest``, print the MSE and noise estimate, optionally plot.

        :param Xtest: Test inputs passed to ``self.predict``.
        :param ytest: Test outputs; flattened before comparison with the prediction.
        :param plot_MSE: If True, draw a predicted-versus-true scatter plot.
        :param title: Optional title for the plot.
        :param seperate_levels: If True and the model has qualitative inputs,
            call ``score`` again once per level of the first qualitative
            variable, on the rows of ``Xtest`` matching that level.
        :returns: The prediction for the full ``Xtest``.
        """
        plt.rcParams.update({'font.size': 14})
        ypred = self.predict(Xtest, return_std=False)
        squared_error = (ytest.reshape(-1)-ypred)**2
        mse = squared_error.mean()

        print('################MSE######################')
        print(f'MSE = {mse:.5f}')
        print('#########################################')
        print('################Noise####################')
        noise = self.likelihood.noise_covar.noise.detach() * self.y_std**2
        print(f'The estimated noise parameter (varaince) is {noise}')
        print(f'The estimated noise std is {np.sqrt(noise)}')
        print('#########################################')

        if plot_MSE:
            plt.figure(figsize=(8,6))
            y_true_np = ytest.cpu().numpy()
            # Red dots: predictions against the truth.
            plt.plot(y_true_np, ypred.cpu().numpy(), 'ro', label = 'Data')
            # Blue line: truth against itself, labelled with the rounded MSE.
            mse_label = 'MSE = ' + str(np.round(mse.detach().item(),3))
            plt.plot(y_true_np, ytest.cpu().numpy(), 'b', label = mse_label)
            plt.xlabel(r'Y_True')
            plt.ylabel(r'Y_predict')
            plt.legend()
            if title is not None:
                plt.title(title)

        if seperate_levels and len(self.qual_index) > 0:
            level_count = self.num_levels_per_var[0]
            for level in range(level_count):
                rows = torch.where(Xtest[:,self.qual_index] == level)[0]
                self.score(
                    Xtest[rows,...], ytest[rows],
                    plot_MSE=True,
                    title = 'results' + ' Only Source ' + str(level),
                    seperate_levels=False,
                )
        return ypred


    # def score_probabilistic(self, Xtest, ytest, num_Forward_pass=30, plot_MSE = True, title = None, seperate_levels = False):
    #     plt.rcParams.update({'font.size': 14})

    #     # for i in 
    #     ypred=torch.zeros_like(ytest)
    #     for i in range(num_Forward_pass):
    #         ypred += self.predict(Xtest, return_std=False)/num_Forward_pass
            
    #     mse = ((ytest.reshape(-1)-ypred)**2).mean()
    #     print('################MSE######################')
    #     print(f'MSE = {mse:.5f}')
    #     print('#########################################')
    #     print('################Noise####################')
    #     noise = self.likelihood.noise_covar.noise.detach() * self.y_std**2
    #     print(f'The estimated noise parameter (varaince) is {noise}')
    #     print(f'The estimated noise std is {np.sqrt(noise)}')
    #     print('#########################################')

    #     if plot_MSE:
    #         _ = plt.figure(figsize=(8,6))
    #         _ = plt.plot(ytest.cpu().numpy(), ypred.cpu().numpy(), 'ro', label = 'Data')
    #         _ = plt.plot(ytest.cpu().numpy(), ytest.cpu().numpy(), 'b', label = 'MSE = ' + str(np.round(mse.detach().item(),3)))
    #         _ = plt.xlabel(r'Y_True')
    #         _ = plt.ylabel(r'Y_predict')
    #         _ = plt.legend()
    #         if title is not None:
    #             _ = plt.title(title)

    #     if seperate_levels and len(self.qual_index) > 0:
    #         for i in range(self.num_levels_per_var[0]):
    #             index = torch.where(Xtest[:,self.qual_index] == i)[0]
    #             _ = self.score(Xtest[index,...], ytest[index], 
    #                 plot_MSE=True, title = 'results' + ' Only Source ' + str(i), seperate_levels=False)
    #     return ypred

    # def Metric_probabilistic(self,Xtest,ytest):
    #     self.eval()
    #     likelihood=self.likelihood
    #     likelihood.fidel_indices=self.train_inputs[0][:,-1]
    #     output=self(Xtest)
    #     likelihood.fidel_indices=Xtest[:,-1]
    #     ytest_sc = (ytest-self.y_mean)/self.y_std
    #     with torch.no_grad():
    #         trained_pred_dist = likelihood(output)
    #     final_mse =gpytorch.metrics.mean_squared_error(trained_pred_dist, ytest_sc, squared=True)
    #     def interval_score(y_true, trained_pred_dist, alpha = 0.05):
    #         mu_low, mu_up = trained_pred_dist.confidence_region()
    #         out = mu_up - mu_low
    #         out += (y_true > mu_up)* 2/alpha * (y_true - mu_up)
    #         out += (y_true <mu_low)* 2/alpha * (mu_low - y_true)
    #         return out
    #     IS=interval_score(ytest_sc, trained_pred_dist, alpha = 0.05).mean()
    #     ## bake to original scale:
    #     return IS*torch.abs(self.y_std), final_mse*(self.y_std)**2
    
    def evaluation_2(self,Xtest,ytest,n_FP=1):
        self.eval()
        likelihood=self.likelihood
        likelihood.fidel_indices=self.train_inputs[0][:,-1]
        output=self(Xtest)
        likelihood.fidel_indices=Xtest[:,-1]
        ytest_sc = (ytest-self.y_mean)/self.y_std
        mean_temp=[]
        var_temp=[]
        for i in range (n_FP):
            with torch.no_grad():
                trained_pred_dist = likelihood(output)
                mean_temp.append(trained_pred_dist.mean)
                var_temp.append(trained_pred_dist.variance)
            
        sum_list = [mean**2 + var for mean, var in zip(mean_temp, var_temp)]
        sum_tensors = sum(sum_list)/n_FP
        mean_ensamble=sum(mean_temp)/n_FP
        var_ensamble=sum_tensors -mean_ensamble**2
        std_ensamble=var_ensamble.sqrt()
        mu_low, mu_up=mean_ensamble-1.96*std_ensamble, mean_ensamble+1.96*std_ensamble
        final_mse=((ytest_sc.reshape(-1)-mean_ensamble)**2).mean()
        def interval_score(y_true,mu_low, mu_up, alpha = 0.05):
            out = mu_up - mu_low
            out += (y_true > mu_up)* 2/alpha * (y_true - mu_up)
            out += (y_true <mu_low)* 2/alpha * (mu_low - y_true)
            return out
        IS=interval_score(ytest_sc,mu_low, mu_up, alpha = 0.05).mean()

        NIS=IS*torch.abs(self.y_std)/ytest.std()
        NRMSE=torch.sqrt((final_mse*(self.y_std)**2)/ytest.std()**2)
        return NIS, NRMSE
    

    # def Metric_3(self,Xtest,ytest,n_FP=1):
    #     self.eval()
    #     likelihood=self.likelihood
    #     y_per,ystd=self.predict(Xtest,return_std=True, include_noise = True)
    #     print(f'mmmssseee {((ytest.reshape(-1)-y_per)**2).mean()}')
    #     likelihood.fidel_indices=self.train_inputs[0][:,-1]
    #     output=self(Xtest)
    #     likelihood.fidel_indices=Xtest[:,-1]
    #     ytest_sc = (ytest-self.y_mean)/self.y_std
    #     mean_temp=[]
    #     var_temp=[]
    #     for i in range (n_FP):
    #         with torch.no_grad():
    #             trained_pred_dist = likelihood(output)
    #             mean_temp.append(trained_pred_dist.mean)
    #             var_temp.append(trained_pred_dist.variance)
            
    #     sum_list = [mean**2 + var for mean, var in zip(mean_temp, var_temp)]
    #     sum_tensors = sum(sum_list)/n_FP
    #     mean_ensamble=sum(mean_temp)/n_FP
    #     var_ensamble=sum_tensors -mean_ensamble**2
    #     std_ensamble=var_ensamble.sqrt()
    #     mu_low, mu_up=mean_ensamble-1.96*std_ensamble, mean_ensamble+1.96*std_ensamble
    #     final_mse=((ytest_sc.reshape(-1)-mean_ensamble)**2).mean()
    #     def interval_score(y_true,mu_low, mu_up, alpha = 0.05):
    #         out = mu_up - mu_low
    #         out += (y_true > mu_up)* 2/alpha * (y_true - mu_up)
    #         out += (y_true <mu_low)* 2/alpha * (mu_low - y_true)
    #         return out
        
        # IS=interval_score(ytest_sc,mu_low, mu_up, alpha = 0.05).mean()
        # NIS=IS*torch.abs(self.y_std)/ytest.std()
        # NRMSE=torch.sqrt((final_mse*(self.y_std)**2)/ytest.std())

        # IS=interval_score(ytest_sc,mu_low, mu_up, alpha = 0.05).mean()
        # NIS=IS*torch.abs(self.y_std)
        # NRMSE=torch.sqrt((final_mse*(self.y_std)**2))
        
        # To make sure
        # from GP_Plus_functions.utils.interval_score import interval_score as IS_old        
        # IS, accuracy =IS_old((output.mean*self.y_std + self.y_mean)+1.96* ystd, (output.mean*self.y_std + self.y_mean) - 1.96 * ystd, ytest)
        # NRMSE=torch.sqrt(((ytest.reshape(-1)-y_per)**2).mean()/ytest.std())
        # NIS=IS/self.y_std
        # # To make sure
        # return NIS, NRMSE
    
    def evaluation(self,Xtest,ytest):
        """Print a table of predictive metrics for the model on a test set.

        The metrics are computed against the standardized response
        ``(ytest - self.y_mean) / self.y_std``. MSE, MAE and the interval score
        are then rescaled with ``self.y_std`` before printing; the NLPD is
        printed as computed.

        NOTE(review): another ``evaluation`` method is defined further down in
        this class. Python keeps the last definition, so that later method is
        the one bound at runtime.

        Args:
            Xtest: Test inputs, passed to the model call ``self(Xtest)``.
            ytest: Test responses on the original scale.

        Returns:
            None. The metric table is printed.
        """
        ytest_sc = (ytest-self.y_mean)/self.y_std
        self.eval()
        likelihood = self.likelihood
        output = self(Xtest)
        # The likelihood is handed the test inputs before it is applied.
        likelihood.x = Xtest
        with torch.no_grad():
            trained_pred_dist = likelihood(output)

        # Negative Log Predictive Density (NLPD)
        final_nlpd = gpytorch.metrics.negative_log_predictive_density(trained_pred_dist,ytest_sc)
        # Mean Squared Error (MSE)
        final_mse = gpytorch.metrics.mean_squared_error(trained_pred_dist, ytest_sc, squared=True)
        # Mean Absolute Error (MAE)
        final_mae = gpytorch.metrics.mean_absolute_error(trained_pred_dist, ytest_sc)

        def interval_score(y_true, trained_pred_dist, alpha = 0.05):
            # Interval width plus a 2/alpha penalty for each observation that
            # falls outside the bounds returned by confidence_region().
            mu_low, mu_up = trained_pred_dist.confidence_region()
            out = mu_up - mu_low
            out += (y_true > mu_up).float()* 2/alpha * (y_true - mu_up)
            out += (y_true <mu_low).float()* 2/alpha * (mu_low - y_true)
            return out

        IS = interval_score(ytest_sc, trained_pred_dist, alpha = 0.05).mean()

        # Back to the original scale (the NLPD is left unchanged).
        final_mse = final_mse*(self.y_std)**2
        final_mae = final_mae*torch.abs(self.y_std)
        IS = IS*torch.abs(self.y_std)

        RRMSE = torch.sqrt(final_mse/torch.var(ytest))
        table_data = [
            ['Negative Log-Likelihood (NLL)', final_nlpd],
            ['Mean Squared Error (MSE)', final_mse],
            ['Mean Absolute Error  (MAE)', final_mae],
            ['Relative Root Mean Square Error (RRMSE)', RRMSE],
            ['Interval Score (IS)', IS]
        ]
        table = tabulate(table_data, headers=['Metric', 'Value'], tablefmt='fancy_grid', colalign=("left", "left"))
        print(table)

    def rearrange_one_hot(self,tensor):
        """Rebuild a one-hot matrix from row-wise maxima and reverse the row order.

        Args:
            tensor: 2-D tensor; each row contributes a single ``1`` at the
                column holding its largest entry.

        Returns:
            A tensor shaped like ``tensor`` with one ``1`` per row, flipped
            along dimension 0.
        """
        # Column 0 of a descending argsort is the position of each row's largest entry.
        top_columns = torch.argsort(tensor, dim=1, descending=True)[:, 0]

        one_hot = torch.zeros_like(tensor)
        for row, col in enumerate(top_columns):
            one_hot[row, col] = 1

        return torch.flip(one_hot, dims=[0])
    def visualize_latent(self,rpearts=500):
        """Plot the learned latent positions of the categorical levels.

        Behaviour depends on ``self.manifold_type``:

        * ``'deterministic'``: for every entry of ``self.qual_kernel_columns``
          the matching ``self.zeta[i]`` is mapped through ``self.A_matrix[i]``
          and handed to ``plot_sep``.
        * ``'probabilistic'``: the unique encoded rows are each repeated
          ``rpearts`` times, pushed through ``self.A_matrix`` together with a
          shared standard-normal ``epsilon`` and drawn as a scatter plot.

        Any other ``manifold_type`` does nothing.

        Args:
            rpearts: Number of noise samples drawn per level in the
                probabilistic branch.

        Returns:
            None.
        """
        if self.manifold_type=='deterministic':
            if len(self.qual_kernel_columns) > 0:
                for i in range(len(self.qual_kernel_columns)):
                    zeta = self.zeta[i]
                    dimm=zeta.size()[0]
                    # NOTE(review): zeta_epsilon is never used below; the draw still advances the RNG.
                    zeta_epsilon=torch.normal(mean=0,std=1,size=[dimm,2])

                    A = getattr(self,'A_matrix')
                    # tkwargs is not defined in this method; presumably a module-level dict of dtype/device -- confirm.
                    positions = A[i](x=zeta.float().to(**tkwargs))
                    # Column-wise maximum of the level permutations for this variable.
                    level = torch.max(self.perm[i], axis = 0)[0].tolist()
                    perm = self.perm[i]
                    plot_sep(positions = positions, levels = level, perm = perm, constraints_flag=False)
        elif self.manifold_type=='probabilistic':
            # NOTE(review): temp is overwritten on every pass, so only the last
            # qualitative column reaches the code after the loop.
            for i in range(len(self.qual_kernel_columns)):
                temp= self.transform_categorical(x=self.train_inputs[0][:,self.qual_kernel_columns[i]].clone().type(torch.int64).to(tkwargs['device']), perm_dict = self.perm_dict[i], zeta = self.zeta[i])
            unique_rows, indices = torch.unique(temp, dim=0, return_inverse=True)
            xp=self.rearrange_one_hot(unique_rows)
            z_p_list =[]
            label=[]
            # One noise matrix shared by all levels.
            epsilon=torch.normal(mean=0,std=1,size=[rpearts,2])
            # Only the level count of the first categorical variable is used here.
            for i in range(self.num_levels_per_var[0]):
                x_0=xp[i]
                x_0=x_0.repeat(rpearts, 1)
                z_p = getattr(self, 'A_matrix')(x=x_0, epsilon=epsilon)
                z_p_list.append(z_p)
                # Label tensor has the same shape as z_p and is filled with the level index.
                label.append(i*torch.ones_like(z_p))
            z_p_all = torch.cat(z_p_list, dim=0)
            label_ground_truth=torch.cat(label, dim=0)
            #########################
            # These rcParams assignments change matplotlib's global settings.
            plt.rcParams['font.family'] = 'Times New Roman'
            plt.rcParams['font.size'] = 25
            # plt.rcParams['figure.dpi']=150
            tab20 = plt.get_cmap('tab10')
            colors = tab20.colors
            # The tab10 palette fetched above is replaced by this fixed list of five colours.
            colors=['deeppink','gold','darkorange','gray','orangered']
            plt.figure(figsize=(8,6))

            # Assuming z_p_all is a torch.Tensor
            z_p_all_np = z_p_all.detach().numpy()
            unique_labels = np.unique(label_ground_truth)
            markers = ['X','o','s',"v", 'p']

            # colors and markers hold five entries each, so a sixth distinct label would raise IndexError.
            for idx, label in enumerate(unique_labels):
                mask = (label_ground_truth == label)
                plt.scatter(z_p_all_np[mask[:,0], 0], z_p_all_np[mask[:,0], 1], 
                            c=colors[idx], 
                            marker=markers[idx], 
                            alpha=.6,
                            s=250, 
                            label=f'Label {label}')

            # NOTE(review): this list is never passed to matplotlib, so no legend is drawn.
            legend=[ 'HF', 'LF1','LF2','LF3']
            plt.xlabel(r'$z_1$',labelpad=0,rotation=0,usetex=True)
            plt.ylabel(r'$z_2$',labelpad=14,rotation=0,usetex=True)
            plt.tight_layout()
            plt.show()

    
    def evaluation(self,Xtest,ytest):
        """Print a table of predictive metrics for the model on a test set.

        The metrics are computed against the standardized response
        ``(ytest - self.y_mean) / self.y_std``. MSE, MAE and the interval score
        are then rescaled with ``self.y_std`` before printing; the NLPD is
        printed as computed.

        NOTE(review): an earlier ``evaluation`` method with the same signature
        exists above in this class; this later definition replaces it.

        Args:
            Xtest: Test inputs, passed to the model call ``self(Xtest)``.
            ytest: Test responses on the original scale.

        Returns:
            None. The metric table is printed.
        """
        self.eval()
        likelihood = self.likelihood
        ytest_sc = (ytest-self.y_mean)/self.y_std

        with torch.no_grad():
            trained_pred_dist = likelihood(self(Xtest))

        # Negative Log Predictive Density (NLPD)
        final_nlpd = gpytorch.metrics.negative_log_predictive_density(trained_pred_dist,ytest_sc)
        # Mean Squared Error (MSE)
        final_mse = gpytorch.metrics.mean_squared_error(trained_pred_dist, ytest_sc, squared=True)
        # Mean Absolute Error (MAE)
        final_mae = gpytorch.metrics.mean_absolute_error(trained_pred_dist, ytest_sc)

        def interval_score(y_true, trained_pred_dist, alpha = 0.05):
            # Interval width plus a 2/alpha penalty for each observation that
            # falls outside the bounds returned by confidence_region().
            mu_low, mu_up = trained_pred_dist.confidence_region()
            out = mu_up - mu_low
            out += (y_true > mu_up)* 2/alpha * (y_true - mu_up)
            out += (y_true <mu_low)* 2/alpha * (mu_low - y_true)
            return out

        IS = interval_score(ytest_sc, trained_pred_dist, alpha = 0.05).mean()

        # Back to the original scale (the NLPD is left unchanged).
        final_mse = final_mse*(self.y_std)**2
        final_mae = final_mae*torch.abs(self.y_std)
        IS = IS*torch.abs(self.y_std)

        RRMSE = torch.sqrt(final_mse/torch.var(ytest))
        table_data = [
            ['Negative Log-Likelihood (NLL)', final_nlpd],
            ['Mean Squared Error (MSE)', final_mse],
            ['Mean Absolute Error  (MAE)', final_mae],
            ['Relative Root Mean Square Error (RRMSE)', RRMSE],
            ['Interval Score (IS)', IS]
        ]
        # Print the table
        table = tabulate(table_data, headers=['Metric', 'Value'], tablefmt='fancy_grid', colalign=("left", "left"))
        print(table)

    def calibration_result(self,mean_train,std_train):
        """Report the estimated calibration parameters on the original scale.

        For every index ``n`` in ``self.Calibration_ID`` the module stored as
        ``self.Theta_<n>`` is read and un-standardized with ``std_train[n]`` and
        ``mean_train[n]``.

        * ``self.Calibration_type == 'probabilistic'``: prints the estimated
          mean and standard deviation and plots the matching normal density
          over mean +/- 5 STD in a new figure.
        * otherwise: prints the value taken from the first training row as
          ground truth, and the absolute error of the estimated constant.

        Args:
            mean_train: Per-column means, indexed by ``n``.
            std_train: Per-column standard deviations, indexed by ``n``.

        Returns:
            None.
        """
        for n in self.Calibration_ID:
            theta_name = 'Theta_'+str(n)

            if self.Calibration_type != 'probabilistic':
                # Deterministic calibration: compare against the value stored in the first training row.
                GT = self.train_inputs[0][0,n]*std_train[n] +mean_train[n]
                print("=================Ground Truth===================")
                print(f"The Ground Truth value of Theta_{n!s} is {GT!s}")

                print("=================GP + Results===================")
                estimate = getattr(self,theta_name).constant.detach()*std_train[n] +mean_train[n]
                print(f"Error of Estimated Calibration parameter for Theta_{n!s} is {torch.abs(GT-estimate)!s}")
                continue

            theta = getattr(self,theta_name)
            mean = (theta.weights*std_train[n] +mean_train[n])[0].detach().numpy()
            STD = torch.abs((theta.bias)*std_train[n])[0].detach().numpy()
            print(f"For Calibration parameter Theta_{n!s} Estimated Mean is {mean!s} and Estimated STD is {STD!s}")

            grid = np.linspace(mean-5*STD,mean+5*STD, 1000)
            density = norm.pdf(grid, mean, STD).squeeze()
            plt.figure()
            plt.plot(grid, density, label=theta_name)
            # Italic LaTeX styled title with the hat only on Theta
            plt.title(r'$\mathit{\hat{\Theta}}_{' + str(n) + '}$')
            plt.xlabel('Value')
            plt.ylabel('Density')
            plt.grid(True)


    @classmethod
    def show(cls):
        """Call ``plt.show()``; thin convenience wrapper that needs no instance."""
        plt.show()
        
    def get_params(self, name = None):
        """Print and return the model's named parameters, or a single one.

        Args:
            name: ``None`` to get every parameter, or one of ``'Mean'``,
                ``'Sigma'``, ``'Noise'``, ``'Omega'``. ``'Omega'`` selects the
                parameter whose name contains ``'raw_lengthscale'`` and which
                has more than one element (the last such parameter if several
                match).

        Returns:
            A ``dict`` mapping parameter names to parameters when ``name`` is
            ``None``; otherwise the selected parameter.

        Raises:
            ValueError: If ``name`` is not a supported alias, or ``'Omega'`` is
                requested and no parameter matches.
            KeyError: If the parameter behind a fixed alias does not exist on
                this model.
        """
        print('###################Parameters###########################')
        params = dict(self.named_parameters())
        if name is None:
            print(params)
            return params

        fixed_keys = {
            'Mean': 'mean_module.constant',
            'Sigma': 'covar_module.raw_outputscale',
            'Noise': 'likelihood.noise_covar.raw_noise',
        }
        if name in fixed_keys:
            key = fixed_keys[name]
        elif name == 'Omega':
            matches = [
                n for n, value in params.items()
                if 'raw_lengthscale' in n and value.numel() > 1
            ]
            if not matches:
                raise ValueError(
                    "No 'raw_lengthscale' parameter with more than one element was found"
                )
            # Keep the last match, as the original scan over the parameters did.
            key = matches[-1]
        else:
            raise ValueError(
                "Unknown parameter name %r; expected one of 'Mean', 'Sigma', 'Noise', 'Omega'" % (name,)
            )

        print(params[key])
        return params[key]
    

    def sample_y(self, size = 1, X = None, plot = False):
        """Draw samples from the model's predictive distribution.

        Args:
            size: Number of draws.
            X: Inputs at which to sample; defaults to the training inputs
                ``self.train_inputs[0]`` when ``None``.
            plot: When true, scatter the predictive mean (red) and the draws
                (blue), ordered by increasing predictive mean, in a new figure.

        Returns:
            The draws returned by ``sample(sample_shape=torch.Size([size]))``
            on the predictive distribution.
        """
        # Identity test: `X == None` is element-wise for array-like inputs and
        # cannot be used as an `if` condition.
        if X is None:
            X = self.train_inputs[0]

        self.eval()
        out = self.likelihood(self(X))
        draws = out.sample(sample_shape = torch.Size([size]))

        if plot:
            # The NumPy conversion is only needed for plotting; .cpu() keeps it
            # working when the prediction lives on another device.
            mean_np = out.loc.detach().cpu().numpy()
            index = np.argsort(mean_np)
            positions = np.arange(len(X))
            plt.figure(figsize=(12,6))
            plt.scatter(list(range(len(X))), mean_np[index], color = 'red', s = 20, marker = 'o')
            plt.scatter(np.repeat(positions.reshape(1,-1), size, axis = 0),
                draws.detach().cpu().numpy()[:,index], color = 'blue', s = 1, alpha = 0.5, marker = '.')
        return draws

    def get_latent_space(self):
        """Return the detached latent positions of the categorical levels.

        ``self.zeta`` is converted to a float64 tensor, moved with ``tkwargs``
        and passed through ``self.nn_model``.

        Returns:
            The detached output of ``self.nn_model``, or ``None`` (after
            printing a message) when ``self.qual_index`` is empty.
        """
        if len(self.qual_index) == 0:
            print('No categorical Variable, No latent positions')
            return None

        zeta = torch.tensor(self.zeta, dtype = torch.float64).to(**tkwargs)
        return self.nn_model(zeta).detach()



    def LMMAPPING(self, num_features:int, type = 'Linear',lv_dim = 2):
        """Build the bias-free linear map from encoded levels to latent space.

        Args:
            num_features: Input width of the map.
            type: Kind of mapping; only ``'Linear'`` is supported.
            lv_dim: Output width (latent dimension).

        Returns:
            A ``torch.nn.Linear(num_features, lv_dim, bias=False)`` module.

        Raises:
            ValueError: If ``type`` is anything other than ``'Linear'``.
        """
        if type != 'Linear':
            raise ValueError('Only Linear type for now')    

        return torch.nn.Linear(num_features, lv_dim, bias = False)

    def zeta_matrix(self,
        num_levels:List[int],
        lv_dim:int,
        batch_shape=torch.Size()
    ) -> tuple:
        """Enumerate every combination of categorical levels and one-hot encode it.

        Args:
            num_levels: Number of levels of each categorical variable.
            lv_dim: Requested latent dimension. It is only validated here; the
                returned tensors do not depend on it.
            batch_shape: Unused in this method.

        Returns:
            A tuple ``(perm_one_hot, perm, perm_dic)``:

            * ``perm_one_hot``: per-variable one-hot encodings of ``perm``
              concatenated along dimension 1.
            * ``perm``: int64 tensor with one row per level combination.
            * ``perm_dic``: maps ``str(row.tolist())`` of each row of ``perm``
              to its row index.

        Raises:
            ValueError: If any variable has exactly one level.
            RuntimeWarning: If ``lv_dim`` is 1, or larger than the number of
                levels of some variable.
        """
        if any(level == 1 for level in num_levels):
            raise ValueError('Categorical variable has only one level!')

        # NOTE(review): RuntimeWarning is *raised* here, so it aborts the call
        # instead of being emitted as a warning -- confirm this is intended.
        if lv_dim == 1:
            raise RuntimeWarning('1D latent variables are difficult to optimize!')

        for level in num_levels:
            # NOTE(review): the message talks about num_levels-1 but the test
            # compares against num_levels; kept as-is so accepted inputs do not change.
            if lv_dim > level:
                # Report the dimension that was actually requested.
                raise RuntimeWarning(
                    'The LV dimension can atmost be num_levels-1. '
                    'Setting it to %s in place of %s' %(level-1,lv_dim)
                )

        from itertools import product
        levels = [torch.arange(l) for l in num_levels]
        perm = torch.tensor(list(product(*levels)), dtype=torch.int64)

        #-------------Mapping-------------------------
        perm_dic = {}
        for i, row in enumerate(perm):
            perm_dic.setdefault(str(row.tolist()), i)

        #-------------One_hot_encoding------------------
        # Shift every column so that its smallest level is 0.
        for ii in range(perm.shape[-1]):
            column_min = perm[...,ii].min()
            if column_min != 0:
                perm[...,ii] -= column_min

        perm_one_hot = torch.concat(
            [torch.nn.functional.one_hot(perm[:,i]) for i in range(perm.size()[1])],
            axis=1,
        )

        return perm_one_hot, perm, perm_dic

    #################################### transformation functions####################################

    def transform_categorical(self, x:torch.Tensor,perm_dict = [], zeta = []) -> None:
        """Look up the encoded row of ``zeta`` for every categorical sample.

        Args:
            x: Categorical inputs; a 1-D tensor is treated as a single column.
            perm_dict: Maps ``str(row.tolist())`` of a row of ``x`` to a row
                index of ``zeta``.
            zeta: Tensor holding one encoded row per level combination.

        Returns:
            ``zeta[index, :]`` with one row per sample when
            ``self.encoding_type == 'one-hot'``; otherwise ``None``.
        """
        if x.dim() == 1:
            x = x.reshape(-1,1)

        # Outside training the levels are re-based with setlevels (categorical should start from 0).
        if not self.training:
            x = setlevels(x)

        if self.encoding_type != 'one-hot':
            return None

        row_ids = [perm_dict[str(sample.tolist())] for sample in x]
        return zeta[row_ids,:]

    def transform_categorical_random_varible_for_latent(self,x_raw, x:torch.Tensor,perm_dict = [], zeta = []) -> None:
        """Look up rows of a random standard-normal matrix for every sample.

        Each call appends a fresh ``(zeta.size()[0], 2)`` standard-normal draw
        to ``self.random_zeta``. The newest draw is used when
        ``x_raw.requires_grad`` is true; otherwise the first stored draw is
        reused.

        Args:
            x_raw: Tensor whose ``requires_grad`` flag selects the draw.
            x: Categorical inputs; a 1-D tensor is treated as a single column.
            perm_dict: Maps ``str(row.tolist())`` of a row of ``x`` to a row index.
            zeta: Only its first dimension is read, to size the random draw.

        Returns:
            The selected rows of the chosen draw when
            ``self.encoding_type == 'one-hot'``; otherwise ``None``.
        """
        n_rows = zeta.size()[0]
        self.random_zeta.append(torch.normal(mean=0,std=1,size=[n_rows,2]))

        noise = self.random_zeta[-1] if x_raw.requires_grad else self.random_zeta[0]

        if x.dim() == 1:
            x = x.reshape(-1,1)
        # Outside training the levels are re-based with setlevels (categorical should start from 0).
        if not self.training:
            x = setlevels(x)

        if self.encoding_type != 'one-hot':
            return None

        row_ids = [perm_dict[str(sample.tolist())] for sample in x]
        return noise[row_ids,:]
        
    def final_transform_categorical_random_varible_for_latent(self,x_raw, x:torch.Tensor,perm_dict = [], zeta = []) -> None:
        """Look up rows of a random standard-normal matrix for every sample.

        Each call appends a fresh ``(zeta.size()[0], 2)`` standard-normal draw
        to ``self.random_zeta``. The draw used for the lookup is:

        * the newest one when ``x_raw.requires_grad`` is true;
        * all zeros when ``x_raw`` has exactly 400 rows;
        * the first stored draw otherwise.

        Args:
            x_raw: Tensor whose ``requires_grad`` flag and row count select the draw.
            x: Categorical inputs; a 1-D tensor is treated as a single column.
            perm_dict: Maps ``str(row.tolist())`` of a row of ``x`` to a row index.
            zeta: Only its first dimension is read, to size the random draw.

        Returns:
            The selected rows of the chosen draw.
        """
        if x.dim() == 1:
            x = x.reshape(-1,1)
        if not self.training:
            x = setlevels(x)
        if self.encoding_type == 'one-hot':
            row_ids = [perm_dict[str(sample.tolist())] for sample in x]

        n_rows = zeta.size()[0]
        self.random_zeta.append(torch.normal(mean=0,std=1,size=[n_rows,2]))

        if x_raw.requires_grad:
            noise = self.random_zeta[-1]
        elif x_raw.size()[0]==400:
            # NOTE(review): hard-coded row count; presumably tied to one specific data set -- confirm.
            noise = 0*self.random_zeta[0]
        else:
            noise = self.random_zeta[0]

        # row_ids only exists for the 'one-hot' encoding.
        return noise[row_ids,:]
    

    def Sobol(self, N=10000):
        """Estimate main and total Sobol sensitivity indices of the fitted model.

        A Sobol sequence with ``2*p`` columns is split into two sample
        matrices, scaled to the per-column min/max of the training inputs, and
        the model's ``predict`` is evaluated on both matrices and on the ``p``
        matrices obtained by swapping one column at a time.

        Args:
            N: Size of the Sobol sequence used for evaluating the indices. A
                warning is issued when ``N < 1e5`` (which includes the default).

        Returns:
            A tuple ``(S, ST)`` of NumPy arrays of shape ``(dy, p)`` with
            ``dy = 1``: the main and the total sensitivity indices.
        """
        # Local import so the warning works even if the module never imported `warnings`.
        import warnings

        if N<1e5:
            warnings.warn('Increase N for accuracy!')

        p = self.train_inputs[0].shape[1]
        dy = 1# self.train_targets.shape[1]

        # sequence = torch.from_numpy( sobol_seq.i4_sobol_generate(2*p, N)).to(**self.tkwargs)
        sequence = torch.from_numpy( sobol_seq.i4_sobol_generate(2*p, N))

        def normalize_sobol_sequence(sequence, train_inputs,p):
            # Split the 2p columns into two independent unit-cube sample blocks.
            unit_1 = sequence[:,p:]
            unit_2 = sequence[:,:p]

            # Scale both blocks to the range covered by the training inputs.
            mins = train_inputs.min(dim=0)[0]
            maxs = train_inputs.max(dim=0)[0]
            sequence_1 = mins + (maxs - mins) * unit_1
            sequence_2 = mins + (maxs - mins) * unit_2

            # Categorical inputs: map the unit samples onto integer levels 0..L-1.
            for j, i in enumerate(self.qual_index):
                top_level = self.num_levels_per_var[j]-1
                sequence_1[:,i] = (unit_1[:,i]*top_level).round()
                sequence_2[:,i] = (unit_2[:,i]*top_level).round()

            # Return after *all* categorical columns were handled; this also
            # covers models without any categorical input.
            return sequence_1,sequence_2

        A,B = normalize_sobol_sequence(sequence, self.train_inputs[0],p)

        # AB[:, :, i] equals A with its i-th column replaced by the one from B.
        AB = torch.zeros((N,p,p))
        for i in range(p):
            AB[:,:,i] = A
            AB[:,i,i] = B[:,i]

        FA = self.predict(A,return_std=False).detach().cpu().numpy().reshape(-1,1)
        FB = self.predict(B,return_std=False).detach().cpu().numpy().reshape(-1,1)

        FAB = np.zeros((N, p, dy))
        for i in range(p):
            prediction = self.predict(AB[:, :, i],return_std=False).detach().cpu().numpy()
            FAB[:, i, :] = prediction.reshape(-1,1)

        S = np.zeros((p, dy))
        ST = np.zeros((p, dy))
        for i in range(p):
            swapped = FAB[:, i, :]
            S[i, :] = np.sum(FB * (swapped - FA), axis=0) / N
            ST[i, :] = np.sum((FA - swapped)**2, axis=0) / (2 * N)

        # Normalize by the variance of all base predictions.
        varY = np.var(np.concatenate([FA,FB]), axis=0)
        S = (S / varY).T
        ST = (ST / varY).T

        return S, ST

######################################################################## Other Classes Used in GP_Plus  #####################################################
class FFNN(nn.Module):
    """Bias-free feed-forward network whose weights are registered on a host model.

    Every weight matrix is also registered as a parameter of ``GP_Plus`` and
    given a ``NormalPrior(0, 1)`` there.

    Args:
        GP_Plus: Host module offering ``register_parameter`` and ``register_prior``.
        input_size: Width of the network input.
        num_classes: Width of the network output.
        layers: Hidden-layer widths; an empty sequence yields one ``Linear_MAP`` layer.
        name: Prefix (or, without hidden layers, the full name) under which
            the weights are registered on ``GP_Plus``.
    """
    def __init__(self, GP_Plus, input_size, num_classes, layers,name):
        super().__init__()
        self.hidden_num = len(layers)

        if self.hidden_num == 0:
            # No hidden layer: a single linear map from input to output.
            self.fci = Linear_MAP(input_size, num_classes, bias = False)
            GP_Plus.register_parameter(name, self.fci.weight)
            GP_Plus.register_prior(name = 'latent_prior_'+name, prior=gpytorch.priors.NormalPrior(0,1) , param_or_closure=name)
            return

        prefix = str(name)
        # NOTE(review): the prior names below do not include `name`, so two
        # FFNNs registered on the same host would reuse them -- confirm.
        self.fci = nn.Linear(input_size, layers[0], bias=False) 
        GP_Plus.register_parameter(prefix+'fci', self.fci.weight)
        GP_Plus.register_prior(name = 'latent_prior_fci', prior=gpytorch.priors.NormalPrior(0.,1), param_or_closure=prefix+'fci')

        for i in range(1,self.hidden_num):
            tag = 'h' + str(i)
            hidden = nn.Linear(layers[i-1], layers[i], bias=False)
            setattr(self, tag, hidden)
            GP_Plus.register_parameter(prefix+tag, hidden.weight )
            GP_Plus.register_prior(name = 'latent_prior'+str(i), prior=gpytorch.priors.NormalPrior(0.,1), param_or_closure=prefix+tag)

        self.fce = nn.Linear(layers[-1], num_classes, bias= False)
        GP_Plus.register_parameter(prefix+'fce', self.fce.weight)
        GP_Plus.register_prior(name = 'latent_prior_fce', prior=gpytorch.priors.NormalPrior(0.,1), param_or_closure=prefix+'fce')

    def forward(self, x, transform = lambda x: x):
        """Map ``x`` through the network.

        With hidden layers, ``tanh`` follows the input layer and every hidden
        layer, and the output layer is linear. Without hidden layers, ``x`` goes
        through the single ``Linear_MAP`` layer, whose weight is passed through
        ``transform`` first.
        """
        if self.hidden_num == 0:
            return self.fci(x, transform)

        x = torch.tanh(self.fci(x))
        for i in range(1,self.hidden_num):
            x = torch.tanh( getattr(self, 'h' + str(i))(x) )
        return self.fce(x)
    
############################################
class FFNN_as_Mean(gpytorch.Module):
    """Feed-forward network built from ``Linear_class`` layers with bias.

    Args:
        GP_Plus: Accepted for signature compatibility; not used here.
        input_size: Width of the network input.
        num_classes: Width of the network output.
        layers: Hidden-layer widths; an empty sequence yields a single layer.
        name: Accepted for signature compatibility; not used here.
    """
    def __init__(self, GP_Plus, input_size, num_classes, layers,name):
        super().__init__()
        self.dropout = nn.Dropout(0.2)
        self.hidden_num = len(layers)

        if self.hidden_num == 0:
            # Single layer straight from input to output.
            self.fci = Linear_class(input_size, num_classes, bias=True, dtype = torch.float32,name='fci') #Linear_MAP(input_size, num_classes, bias = True)
            return

        self.fci = Linear_class(input_size, layers[0], bias=True, name='fci') 
        for i in range(1,self.hidden_num):
            tag = 'h' + str(i)
            setattr(self, tag, Linear_class(layers[i-1], layers[i], bias=True,name=tag))
        self.fce = Linear_class(layers[-1], num_classes, bias=True,name='fce')

    def forward(self, x, transform = lambda x: x):
        """Map ``x`` through the network; ``transform`` is accepted but unused.

        With hidden layers, ``tanh`` follows the input layer; every hidden
        layer is followed by ``tanh`` and dropout, and the output layer is linear.
        """
        if self.hidden_num == 0:
            return self.fci(x)

        x = torch.tanh(self.fci(x))
        for i in range(1,self.hidden_num):
            x = torch.tanh( getattr(self, 'h' + str(i))(x) )
            x = self.dropout(x)
        return self.fce(x)
    
############################################
class Linear_VAE(Mean):
    """Linear layer whose weight and bias carry Normal priors.

    The parameters are registered under names prefixed with ``name``
    (``'<name>weight'`` and ``'<name>bias'``), with a ``NormalPrior(0, 0.2)``
    on the weight and a ``NormalPrior(0, 0.05)`` on the bias. With
    ``bias=False`` a ``None`` parameter called ``'bias'`` is registered instead.

    Args:
        in_features: Size of each input sample.
        out_features: Size of each output sample.
        bias: Whether to learn an additive bias.
        name: Prefix for the registered parameter and prior names.
        device: Passed to ``torch.empty`` when creating the parameters.
        dtype: Passed to ``torch.empty`` when creating the parameters.
    """
    __constants__ = ['in_features', 'out_features']
    in_features: int
    out_features: int
    weight: Tensor

    def __init__(self, in_features: int, out_features: int, bias: bool = True, name=None,
                 device=None, dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super(Linear_VAE, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.name=str(name)
        self.register_parameter(name=str(self.name)+'weight',  parameter= Parameter(torch.empty((out_features, in_features), **factory_kwargs)))
        self.register_prior(name =str(self.name)+ 'prior_m_weight_fci', prior=gpytorch.priors.NormalPrior(0.,.2), param_or_closure=str(self.name)+'weight')

        if bias:
            self.register_parameter(name=str(self.name)+'bias',  parameter=Parameter(torch.empty(out_features, **factory_kwargs)))
            self.register_prior(name= str(self.name)+'prior_m_bias_fci', prior=gpytorch.priors.NormalPrior(0.,.05), param_or_closure=str(self.name)+'bias')
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def _prefixed_weight(self):
        """Return the weight parameter stored under ``'<name>weight'``."""
        return getattr(self, str(self.name)+'weight')

    def _prefixed_bias(self):
        """Return the bias stored under ``'<name>bias'``, or ``None`` if the layer has no bias."""
        # The prefixed attribute only exists when the layer was built with bias=True.
        return getattr(self, str(self.name)+'bias', None)

    def reset_parameters(self) -> None:
        """Initialise the weight (Kaiming uniform) and, if present, the bias (uniform)."""
        weight = self._prefixed_weight()
        bias = self._prefixed_bias()
        init.kaiming_uniform_(weight, a=math.sqrt(5))
        if bias is not None:
            fan_in, _ = init._calculate_fan_in_and_fan_out(weight)
            bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
            init.uniform_(bias, -bound, bound)

    def forward(self, input) -> Tensor:
        """Apply the linear map in double precision."""
        bias = self._prefixed_bias()
        # Input and parameters are cast to double before the product.
        return F.linear(
            input.double(),
            self._prefixed_weight().double(),
            None if bias is None else bias.double(),
        )

    def extra_repr(self) -> str:
        return 'in_features={}, out_features={}, bias={}'.format(
            self.in_features, self.out_features, self._prefixed_bias() is not None
        )

############################################
class LinearVariational(Mean):
    """Learnable location/scale pair applied to an external noise sample.

    ``weights`` acts as the location and ``|bias|`` as the scale in
    ``forward``. Both are ``(*batch_shape, 1, 1)`` parameters initialised
    from a standard normal, with Normal priors of unit scale centred at
    ``mean_prior`` and ``std_prior`` respectively.

    Args:
        batch_shape: Leading batch dimensions of the parameters.
        mean_prior: Centre of the prior on ``weights``.
        std_prior: Centre of the prior on ``bias``.
    """
    def __init__(self, batch_shape=torch.Size(),mean_prior=1,std_prior=1):
        super().__init__()
        for param_name, prior_centre in (("weights", mean_prior), ("bias", std_prior)):
            self.register_parameter(name=param_name, parameter=torch.nn.Parameter(torch.randn(*batch_shape,1,1)))
            self.register_prior(name = param_name+'_prior', prior=gpytorch.priors.NormalPrior(prior_centre,1), param_or_closure=param_name)

    def forward(self, epsilon):
        """Return ``weights + |bias| * epsilon``."""
        scale = torch.abs(self.bias)
        return self.weights + scale *epsilon

##########################################
class LinearMean_with_prior(Mean):
    """Linear mean function with ``NormalPrior(0, 0.5)`` on its parameters.

    Args:
        input_size: Number of input features.
        batch_shape: Leading batch dimensions of the parameters.
        bias: Whether to add a learnable bias term.
    """
    def __init__(self, input_size, batch_shape=torch.Size(), bias=True):
        super().__init__()
        self.register_parameter(name="weights", parameter=torch.nn.Parameter(torch.randn(*batch_shape, input_size, 1)))
        self.register_prior(name = 'weights_prior', prior=gpytorch.priors.NormalPrior(0.,.5), param_or_closure='weights')

        if not bias:
            self.bias = None
            return

        self.register_parameter(name="bias", parameter=torch.nn.Parameter(torch.randn(*batch_shape, 1)))
        self.register_prior(name = 'bias_prior', prior=gpytorch.priors.NormalPrior(0.,.5), param_or_closure='bias')

    def forward(self, x):
        """Return ``x @ weights`` with the last dimension squeezed, plus the bias if present."""
        projected = x.matmul(self.weights).squeeze(-1)
        return projected if self.bias is None else projected + self.bias
    
############################################
class Variational_Encoder(gpytorch.Module):
    """Encoder that turns network outputs and noise into 2-D latent samples.

    The network is built from ``Linear_VAE`` layers. ``forward`` reads the
    first five output columns as ``L22, L21, L11, Mu_2, Mu_1`` and combines
    them with a two-column noise sample.

    Args:
        GP_Plus: Accepted for signature compatibility; not used here.
        input_size: Width of the network input.
        num_classes: Width of the network output.
        layers: Hidden-layer widths; an empty sequence yields a single layer.
        name: Accepted for signature compatibility; not used here.
    """
    def __init__(self, GP_Plus, input_size, num_classes, layers,name):
        super().__init__()
        self.dropout = nn.Dropout(0.2)
        self.hidden_num = len(layers)

        if self.hidden_num == 0:
            self.fci = Linear_VAE(input_size, num_classes, bias=True, dtype = torch.float32,name='fci') 
            return

        self.fci = Linear_VAE(input_size, layers[0], bias=True, name='fci') 
        for i in range(1,self.hidden_num):
            tag = 'h' + str(i)
            setattr(self, tag, Linear_VAE(layers[i-1], layers[i], bias=True,name=tag))
        self.fce = Linear_VAE(layers[-1], num_classes, bias=True,name='fce')

    @staticmethod
    def _sample_latent(output, epsilon):
        """Combine the five network outputs with the noise into two latent coordinates."""
        eps_1, eps_2 = epsilon[:, 0:1], epsilon[:, 1:2]
        L22, L21, L11 = output[:, 0:1], output[:, 1:2], output[:, 2:3]
        Mu_2, Mu_1 = output[:, 3:4], output[:, 4:5]
        # Lower-triangular factor with |.| on the diagonal entries.
        z_1 = Mu_1 + torch.abs(L11) * eps_1
        z_2 = Mu_2 + L21 * eps_1 + torch.abs(L22) * eps_2
        return torch.cat((z_1,z_2),1)

    def forward(self, x,epsilon, transform = lambda x: x):
        """Encode ``x`` and sample its latent position; ``transform`` is accepted but unused.

        With hidden layers, the input layer is applied without an activation,
        every hidden layer is followed by ``tanh``, and the output layer is linear.
        """
        if self.hidden_num > 0:
            hidden = self.fci(x)
            for i in range(1,self.hidden_num):
                hidden = torch.tanh( getattr(self, 'h' + str(i))(hidden) )
            output = self.fce(hidden)
        else:
            output = self.fci(x)

        return self._sample_latent(output, epsilon)
    
############################################
class Linear_class(Mean):
    """Linear layer whose weight and bias carry tight Normal priors.

    The parameters are registered under names prefixed with ``name``
    (``'<name>weight'`` and ``'<name>bias'``), with a ``NormalPrior(0, 0.01)``
    on the weight and a ``NormalPrior(0, 0.001)`` on the bias. With
    ``bias=False`` a ``None`` parameter called ``'bias'`` is registered instead.

    Args:
        in_features: Size of each input sample.
        out_features: Size of each output sample.
        bias: Whether to learn an additive bias.
        name: Prefix for the registered parameter and prior names.
        device: Passed to ``torch.empty`` when creating the parameters.
        dtype: Passed to ``torch.empty`` when creating the parameters.
    """
    __constants__ = ['in_features', 'out_features']
    in_features: int
    out_features: int
    weight: Tensor

    def __init__(self, in_features: int, out_features: int, bias: bool = True, name=None,
                 device=None, dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super(Linear_class, self).__init__()
        self.in_features = in_features
        self.out_features = out_features

        self.name=str(name)
        self.register_parameter(name=str(self.name)+'weight',  parameter= Parameter(torch.empty((out_features, in_features), **factory_kwargs)))
        self.register_prior(name =str(self.name)+ 'prior_m_weight_fci', prior=gpytorch.priors.NormalPrior(0.,0.01), param_or_closure=str(self.name)+'weight')
        if bias:
            self.register_parameter(name=str(self.name)+'bias',  parameter=Parameter(torch.empty(out_features, **factory_kwargs)))
            self.register_prior(name= str(self.name)+'prior_m_bias_fci', prior=gpytorch.priors.NormalPrior(0.,.001), param_or_closure=str(self.name)+'bias')
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def _prefixed_weight(self):
        """Return the weight parameter stored under ``'<name>weight'``."""
        return getattr(self, str(self.name)+'weight')

    def _prefixed_bias(self):
        """Return the bias stored under ``'<name>bias'``, or ``None`` if the layer has no bias."""
        # The prefixed attribute only exists when the layer was built with bias=True.
        return getattr(self, str(self.name)+'bias', None)

    def reset_parameters(self) -> None:
        """Initialise the weight (Kaiming uniform) and, if present, the bias (uniform)."""
        weight = self._prefixed_weight()
        bias = self._prefixed_bias()
        init.kaiming_uniform_(weight, a=math.sqrt(5))
        if bias is not None:
            fan_in, _ = init._calculate_fan_in_and_fan_out(weight)
            bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
            init.uniform_(bias, -bound, bound)

    def forward(self, input) -> Tensor:
        """Apply the linear map to ``input`` flattened into a single row."""
        # NOTE(review): reshape(1, -1) folds the whole input into one row, so
        # in_features must equal input.numel() -- confirm this is intended.
        return F.linear(input.reshape(1,-1), self._prefixed_weight(), self._prefixed_bias())

    def extra_repr(self) -> str:
        return 'in_features={}, out_features={}, bias={}'.format(
            self.in_features, self.out_features, self._prefixed_bias() is not None
        )
    
############################################
class Linear_MAP(nn.Linear):
    """``nn.Linear`` whose weight can be passed through a function at call time."""
    def __init__(self, in_features: int, out_features: int, bias: bool = True, device=None, dtype=None) -> None:
        super().__init__(in_features, out_features, bias=bias, device=device, dtype=dtype)

    def forward(self, input, transform = lambda x: x):
        """Apply the layer using ``transform(self.weight)`` as the weight matrix.

        Args:
            input: Input tensor for the linear map.
            transform: Callable applied to the weight before the product;
                the identity by default.
        """
        mapped_weight = transform(self.weight)
        return F.linear(input, mapped_weight, self.bias)