#!/usr/bin/env python
# coding: utf-8

# In[1]:


import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import gym
from gym import error, spaces, utils
from gym.utils import seeding
#import statsmodels.api as sm
#import statsmodels.formula.api as smf
import pandas.util.testing as tm
from sklearn.linear_model import LogisticRegression
from scipy.stats import truncnorm
import math

#Gym environment

class DiscreteEnv(gym.Env):
  """Gym environment in which an agent intervenes on a patient covariate.

  Each simulated patient has an actionable covariate ``Xa`` and a
  non-actionable covariate ``Xs``. An action selects three values from
  ``self.row`` that are used as logistic coefficients (intercept, Xa, Xs);
  the resulting per-patient probability controls how ``Xa`` is shifted. The
  reward is the mean risk predicted by a logistic model refitted on the
  shifted covariates.

  NOTE(review): ``step`` returns a 7-tuple
  ``(state, reward, reward4, rho3, rho4, done, info)`` instead of Gym's
  standard 4-tuple; kept as-is because existing callers unpack that shape.
  """

  # Number of steps in one episode; restored by reset().
  _HORIZON = 200

  def __init__(self):
    # Number of simulated patients (rows) used throughout the environment.
    self.size = 2000

    # Synthetic data used to obtain initial theta values and the bounds of
    # the observation space. Note that Y is drawn independently of Xs/Xa.
    self.df = pd.DataFrame(dict(
            Xs=np.random.normal(0, 1, size=self.size),
            Xa=np.random.normal(0, 1, size=self.size),
            Y=np.random.binomial(1, 0.05, self.size)))
    self.model = LogisticRegression().fit(
        self.df[["Xs", "Xa"]], self.df["Y"].astype(int).to_numpy())

    # Feature order follows the ["Xs", "Xa"] fit above:
    # thetas[0] = coef for Xs, thetas[1] = coef for Xa, thetas[2] = intercept.
    self.thetas = np.array([self.model.coef_[0, 0],
                            self.model.coef_[0, 1],
                            self.model.intercept_[0]])

    # Observation-space bounds, ordered [Xa, Xs] to match the columns of
    # self.patients. NOTE(review): these are taken from self.df, not from the
    # patients drawn in reset(), so a state may lie slightly outside the box.
    self.minXa1 = self.df["Xa"].min()
    self.minXs1 = self.df["Xs"].min()
    self.maxXa1 = self.df["Xa"].max()
    self.maxXs1 = self.df["Xs"].max()

    self.min_Xas = np.array([np.float32(self.minXa1), np.float32(self.minXs1)])
    self.max_Xas = np.array([np.float32(self.maxXa1), np.float32(self.maxXs1)])

    # ACTION SPACE: each of the three discrete components indexes into this
    # grid of 10 evenly spaced coefficient values in [-2, 2].
    self.row = np.linspace(-2, 2, 10).tolist()
    self.action_space = spaces.Tuple(
        tuple(spaces.Discrete(len(self.row)) for _ in range(3)))

    # OBSERVATION SPACE: the (Xa, Xs) values of a single patient.
    self.observation_space = spaces.Box(low=self.min_Xas,
                                        high=self.max_Xas,
                                        dtype=np.float32)

    # The state is set by reset(); step() requires reset() to be called first.
    self.state = None

    # Remaining steps in the current episode.
    self.horizon = self._HORIZON

  @staticmethod
  def _sigmoid(z):
    """Return the logistic function ``1 / (1 + exp(-z))`` element-wise."""
    return 1 / (1 + np.exp(-z))

  @staticmethod
  def _fit_thetas(X, y):
    """Fit a logistic regression; return ``[intercept, coef_0, coef_1, ...]``.

    The returned order matches a design matrix whose first column is ones
    followed by the columns of ``X``.
    """
    model = LogisticRegression().fit(X, y)
    return np.concatenate([model.intercept_, model.coef_[0]])

  def seed(self, seed=None):
    """Seed ``self.np_random`` and return ``[seed]``.

    NOTE(review): step() and reset() draw from the global ``np.random``, not
    from ``self.np_random``, so calling this does not make episodes
    reproducible.
    """
    self.np_random, seed = seeding.np_random(seed)
    return [seed]

  def step(self, action):
    """Apply ``action`` to every patient and report the resulting risks.

    Must be called after reset() (it reads ``self.patients``).

    Args:
      action: indexable with three integer entries, each a valid index into
        ``self.row``; they select the (intercept, Xa, Xs) coefficients.

    Returns:
      A 7-tuple ``(state, reward, reward4, rho3, rho4, done, info)``:
        state:   the (Xa, Xs) row of the patient chosen in reset().
        reward:  mean of ``rho3`` (also stored in ``self.mean_r``).
        reward4: mean of ``rho4``.
        rho3:    per-patient risk from a logit refitted on the shifted Xa.
        rho4:    per-patient risk from a logit fitted on the original Xa.
        done:    True once ``self._HORIZON`` steps have been taken.
        info:    an empty dict.
    """
    # Map the discrete indices onto coefficient values. This must be an
    # ndarray: with a plain list, the matrix product below cannot be formed.
    actions = np.array([self.row[action[0]],
                        self.row[action[1]],
                        self.row[action[2]]])

    # Design matrix, shape (size, 3): column 0 is ones, 1 is Xa, 2 is Xs.
    patients1 = np.hstack([np.ones((self.size, 1)), self.patients])
    rho1 = self._sigmoid(patients1 @ actions)  # shape (size,): individual risk

    # Shift Xa by an amount that depends on the action-induced risk rho1.
    Xa = patients1[:, 1]
    Xs = patients1[:, 2]
    shift = 0.5 * (Xa + np.sqrt(1 + Xa ** 2))
    rho1_sq = rho1 ** 2
    Xa_new = (Xa + shift) * (1 - rho1_sq) + (Xa - shift) * rho1_sq

    # Outcomes are resampled each step, independently of the covariates.
    Y = np.random.binomial(1, 0.05, (self.size, 1))
    y = Y.ravel().astype(int)

    # With intervention: refit the logit on (shifted Xa, Xs) and score it.
    # thetas2 = [intercept, coef for Xa, coef for Xs].
    thetas2 = self._fit_thetas(np.column_stack([Xa_new, Xs]), y)
    patients3 = np.column_stack([np.ones(self.size), Xa_new, Xs])
    rho3 = self._sigmoid(patients3 @ thetas2)  # shape (size,)

    self.mean_r = np.mean(rho3)
    reward = self.mean_r

    # NOTE(review): the state is the same patient row every step, because
    # self.patients is never updated with Xa_new — confirm this is intended.
    self.state = self.patients[self.random_indices, :].reshape(2,)

    # Without intervention: logit on the original (Xa, Xs) with the same Y.
    # thetas4 = [intercept, coef for Xa, coef for Xs].
    thetas4 = self._fit_thetas(self.patients, y)
    rho4 = self._sigmoid(patients1 @ thetas4)  # shape (size,)
    reward4 = np.mean(rho4)

    # Consume one step *before* testing, so an episode of length _HORIZON
    # reports done=True on its last step rather than one step later.
    self.horizon -= 1
    done = self.horizon <= 0

    info = {}
    return self.state, reward, reward4, rho3, rho4, done, info

  def reset(self):
    """Draw a fresh patient population, restart the horizon, return the state.

    Sets ``self.patients`` (shape (size, 2); column 0 is Xa, column 1 is Xs)
    and picks one random patient whose (Xa, Xs) row becomes the state.
    """
    self.horizon = self._HORIZON

    # Actionable covariate Xa and non-actionable covariate Xs per patient.
    self.patients = np.random.normal(0, 1, size=(self.size, 2))

    self.random_indices = np.random.choice(self.size, size=1, replace=False)
    self.state = self.patients[self.random_indices, :].reshape(2,)

    return self.state



# In[ ]:




