# AUTOGENERATED! DO NOT EDIT! File to edit: nbs/models.anova.ipynb (unless otherwise specified).

__all__ = ['ANOVALandingModel', 'ANOVAUnitModel', 'ANOVAModel']

# Cell
from ..mvc import Model
from .error_bars import ErrorBarsModel
from ..rpy import PyFunction
from powerpy.anova import ANOVA_Model

# Cell
from traitlets import Int, List, Float, Bool, observe
import copy
import numpy as np
from traittypes import Array

# Cell
class ANOVALandingModel(Model):

    ''' Landing-page state for a (multi-way) ANOVA / ANCOVA power analysis.

    Variable names follow powerANCOVA_multiway.R.

    Two parallel groups of traits are kept:

    * "working" traits that hold a value for every factor slot (5) and every
      covariate slot (10) set up by `setLanding`: `numLevels`,
      `factorLabels`, `levelLabels`, `factor_types`, `covariateLabels`,
      `w_covariates`;
    * "active" traits that hold only the first `num_factors` /
      `num_covariates` entries of the working ones: `num_levels`,
      `labels_factors`, `labels_levels`, `factorTypes`,
      `labels_covariates`, `covariates`.

    Note the naming is not uniform: for factor types the working trait is
    `factor_types` and the active one is `factorTypes`.

    The `change*` methods update the working value and, when that slot is
    currently active, mirror it into the active trait (in place, which does
    not emit a traitlets change notification).
    '''
    num_factors = Int()
    num_levels = List() # Number of groups being tested. Should be a vector if more than one factor
    labels_factors = List()
    labels_levels = List()
    # Default is 1 rather than 0 so that setLanding's `num_covariates = 0`
    # is a real change the first time and fires changeNumCovariates.
    num_covariates = Int(1)
    labels_covariates = List(allow_none=True)
    factorTypes = List()

    factor_types = List()
    w_covariates = Array(dtype=np.dtype(float))
    covariates = Array(allow_none=True, dtype=np.dtype(float))
    numLevels = List()
    factorLabels = List()
    levelLabels = List()
    covariateLabels = List()
    r = Float()
    powerCovariates = Bool()

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def setLanding(self):
        '''Reset all traits to the landing-page defaults.

        Afterwards there are 5 factor slots with 2 levels each (1 active),
        10 covariate slots valued 0.0 (none active), and r == 0.0.
        Safe to call repeatedly.
        '''
        with self.hold_trait_notifications():
            self.numLevels = [2] * 5
            self.num_levels = [2] * 5
            self.factorLabels = ['A', 'B', 'C', 'D', 'E']
            self.levelLabels = [[str(i) for i in range(1,7)] for j in range(5)]
            self.labels_levels = copy.deepcopy(self.levelLabels)
            self.factor_types = ['b']*5
            self.covariateLabels = ['CV'+str(i) for i in range(1, 11)]
            self.w_covariates = np.full(10, 0.0)
            self.covariates = np.full(10, 0.0)
            self.r = 0.0
            self.powerCovariates = True
            self.num_covariates = 0
        for i in range(5):
            self.changeNumLevels(i, 2)
        self.num_factors = 1
        # Observers only run when the value actually changes. On a repeated
        # call (num_factors already 1 / num_covariates already 0) they would
        # not fire and the active traits would keep the full-length values
        # assigned above, so resynchronise explicitly. Both helpers are
        # idempotent, so running them after the observers is harmless.
        self._syncCovariates(self.num_covariates)
        self._syncFactors(self.num_factors)

    def _syncFactors(self, n):
        '''Set the active factor traits to the first `n` working entries.'''
        self.labels_factors = self.factorLabels[:n]
        self.factorTypes = self.factor_types[:n]
        self.num_levels = self.numLevels[:n]
        self.labels_levels = [self.levelLabels[i][:self.num_levels[i]] for i in range(n)]

    def _syncCovariates(self, n):
        '''Set the active covariate traits to the first `n` working entries.

        With `n == 0` the active covariate traits are set to None.
        '''
        if n == 0:
            self.labels_covariates = None
            self.covariates = None
        else:
            self.labels_covariates = self.covariateLabels[:n]
            self.covariates = self.w_covariates[:n]

    @observe('num_factors')
    def observe_num_factors(self, change):
        '''Rebuild the active factor traits when `num_factors` changes.'''
        self._syncFactors(change['new'])

    @observe('num_covariates')
    def changeNumCovariates(self, change):
        '''Rebuild the active covariate traits when `num_covariates` changes.'''
        self._syncCovariates(change['new'])

    def changeCovariates(self, i, value):
        '''Set covariate value `i` in the working array and, if active, in `covariates`.'''
        self.w_covariates[i] = value
        # NOTE(review): `covariates` is assigned a slice of `w_covariates`,
        # which may share memory with it; the explicit write is kept so the
        # result does not depend on that.
        if self.num_covariates > i:
            self.covariates[i] = value

    def changeNumLevels(self, i, num):
        '''Set the level count of factor slot `i` (and its active labels).'''
        self.numLevels[i] = num
        if self.num_factors > i:
            self.num_levels[i] = num
            # Only as many labels as levelLabels[i] holds (6 after setLanding).
            self.labels_levels[i] = self.levelLabels[i][:num]

    def changeLevelLabel(self, i, j, value):
        '''Rename level `j` of factor slot `i`.'''
        self.levelLabels[i][j] = value
        if self.num_factors > i:
            if self.num_levels[i] > j:
                self.labels_levels[i][j] = value

    def changeCovariateLabel(self, i, value):
        '''Rename covariate slot `i`.'''
        self.covariateLabels[i] = value
        if self.num_covariates > i:
            self.labels_covariates[i] = value

    def changeFactorLabel(self, i, value):
        '''Rename factor slot `i`.'''
        self.factorLabels[i] = value
        if self.num_factors > i:
            self.labels_factors[i] = value

    def changeFactorType(self, i , value):
        '''Set the type code of factor slot `i` (setLanding uses 'b').'''
        self.factor_types[i] = value
        if self.num_factors > i:
            self.factorTypes[i] = value

# Cell
class ANOVAUnitModel(ErrorBarsModel):
    '''One bar-chart panel of an ANOVA design: `nGroups` groups of `nBars` bars.

    Per-bar values (`heights`, `errors` and `Ns`) are stored flat in
    group-major order: bar `k` of group `j` lives at `flatIndex(j, k)`.
    '''

    Ns = List()  # per-bar sample sizes, flat, same order as heights/errors

    def flatIndex(self, i, j):
        '''Flat position of bar `j` within group `i`.'''
        return i * self.nBars + j

    def setEstimates(self):
        '''Fill in default estimates for every bar.

        Defaults: height 1 + 0.1 * bar index, error 1, sample size 10.
        '''
        super().setEstimates()
        for j in range(self.nGroups):
            for k in range(self.nBars):
                idx = self.flatIndex(j, k)
                self.setHeight(idx, 1 + 0.1 * k)
                self.setError(idx, 1)
        # Assign a fresh list rather than appending, so calling setEstimates
        # more than once does not keep growing Ns.
        self.Ns = [10] * (self.nGroups * self.nBars)

    def _byGroup(self, flat):
        '''Split a flat group-major sequence into `nGroups` slices of `nBars`.'''
        return [flat[g * self.nBars:(g + 1) * self.nBars] for g in range(self.nGroups)]

    def getMeans(self):
        '''Bar heights, one slice per group.'''
        return self._byGroup(self.heights)

    def getSDs(self):
        '''Bar errors, one slice per group.'''
        return self._byGroup(self.errors)

    def getNs(self):
        '''Sample sizes, one slice per group.'''
        return self._byGroup(self.Ns)

# Cell
class ANOVAModel(ANOVALandingModel):
    '''Landing state plus the per-panel bar-chart units used to run the ANOVA power function.

    The last factor provides the bars, and the second-to-last provides the
    groups. Every combination of the remaining (outer) factors gets its own
    `ANOVAUnitModel` panel.
    '''

    nBars = Int()      # levels of the last factor
    nGroups = Int()    # levels of the second-to-last factor (1 for a single factor)
    nUnits = Int()     # len(self.units)
    labels = List()    # per factor: list of 'Factor:level' strings
    means = List()
    SDs = List()
    N = List()
    sims = Int()

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.statmodel = ANOVA_Model()
        # NOTE(review): 'corr' is not declared as a trait in this file;
        # confirm getTraits()/PyFunction supply or tolerate it.
        self.function = PyFunction(self.statmodel, ['pVal'], ['N', 'num_levels', 'means', 'num_factors', 'SDs',
                                                   'factorTypes', 'corr'])

    def run(self, **kwargs):
        '''Collect means/SDs/Ns from the units and run the power function.

        With several units the values are nested one level deeper (one entry
        per unit). With no units, the previous values are left in place.
        '''
        if self.nUnits > 1:
            self.means = [unit.getMeans() for unit in self.units]
            self.SDs = [unit.getSDs() for unit in self.units]
            self.N = [unit.getNs() for unit in self.units]
        elif self.nUnits == 1:
            self.means = self.units[0].getMeans()
            self.SDs = self.units[0].getSDs()
            self.N = self.units[0].getNs()
        return self.function.run(**self.getTraits())

    def flatIndex(self, i, j):
        '''Flat position of bar `j` within group `i`.'''
        return i * self.nBars + j

    def setEstimates(self):
        '''Rebuild `labels` and the panel units for the active factors.

        * 1 or 2 factors: a single untitled panel.
        * 3 or more factors: one panel per combination of the outer factors,
          titled 'a X b X ...', in nested-loop order.

        Every panel is filled with default estimates.
        '''
        from itertools import product

        self.labels = [[f + ':' + l for l in self.labels_levels[i]]
                       for i, f in enumerate(self.labels_factors)]

        self.sims = 1
        self.nBars = self.num_levels[-1]
        self.nGroups = self.num_levels[-2] if self.num_factors >= 2 else 1
        self.units = []

        # Index of the factor whose levels label the groups inside a panel.
        # For a single factor this is the only factor.
        groupFactor = max(self.num_factors - 2, 0)

        if self.num_factors < 3:
            # Both 1- and 2-factor designs fit in one panel. (Previously an
            # `elif` skipped this for 2 factors, leaving no units at all.)
            unit = ANOVAUnitModel(nBars=self.nBars, nGroups=self.nGroups)
            unit.setEstimates()
            unit.title = None
            unit.labels = self.labels[groupFactor]
            self.units.append(unit)
        else:
            # product() yields combinations in the same order as nested loops.
            for combo in product(*self.labels[:groupFactor]):
                unit = ANOVAUnitModel(nBars=self.nBars, nGroups=self.nGroups)
                unit.setEstimates()
                unit.title = ' X '.join(combo)
                unit.labels = self.labels[groupFactor]
                self.units.append(unit)

        self.nUnits = len(self.units)