# AUTOGENERATED! DO NOT EDIT! File to edit: nbs/views.anova.ipynb (unless otherwise specified).

# Public names exported by this module.
__all__ = ['ANOVALandingView', 'ANOVALanding', 'ANOVAUnit', 'ANOVAEstimatesView', 'ANOVAEstimates']

# Cell
from ..models.error_bars import ErrorBarsModel
from .error_bars import ErrorBars
from ..items import ItemList, ItemArray
from ..widget_containers import Box, VBox, Tab, HBox
from ..widgets import Dropdown, Text, BoundedIntText, BoundedFloatText, Checkbox, AlignedLabel,Textarea, Button, \
RadioButtons, GridBox, Grid, FloatText, IntText

# Cell
import ipywidgets as ipyw
from traitlets import observe
import traitlets
from IPython.display import display, Javascript

# Cell
class ANOVALandingView(Box):
    """Form for describing an ANOVA design (factors, levels, covariates).

    Left column: the factor count, one row per factor (label, number of levels,
    up to 6 level labels, between/within radio buttons), the within-subject
    ``r`` input, the covariate count and the "Power Covariates" checkbox.
    Right column: an "Adjusted beta" header and up to 10 covariate rows
    (label + value). Rows and inputs are shown/hidden as the counts change.
    """

    def __init__(self, **kwargs):
        layout = {'flex_flow': 'row wrap', 'justify_content': 'space-between'}
        super().__init__(layout=ipyw.Layout(**layout), **kwargs)

        boxLayout = {'display': 'flex', 'flex_flow': 'column nowrap', 'align_items': 'flex-start'}

        style = {'description_width': 'initial'}

        # Number of factors (1-5).
        self.num_factors = Dropdown(value=None, description='Factors', options=list(range(1, 6)),
                                    layout=ipyw.Layout(width='100px', margin='10px'), style=style)

        self.leftBox = Box(children=(self.num_factors,), layout=ipyw.Layout(**boxLayout))
        self.levelsBox = Box(layout=ipyw.Layout(**boxLayout, margin='10px'))

        # One row per factor (max 5): label, level count (2-6), level labels (max 6), type.
        self.factorLabels = ItemList([Text(value=None, description='Factor', layout=ipyw.Layout(width='100px'), style=style)
                                      for _ in range(5)])
        self.numLevels = ItemList([Dropdown(value=None, description='Levels', style=style, options=list(range(2, 7)),
                                            layout=ipyw.Layout(width='95px'))
                                   for _ in range(5)])
        self.levelLabels = ItemArray(5, 6, children=[Text(value=None, layout=ipyw.Layout(width='62px')) for _ in range(30)])
        self.factor_types = ItemList([RadioButtons(value=None, options=(('Between', 'b'), ('Within', 'w')),
                                                   layout=ipyw.Layout(width='100px'))
                                      for _ in range(5)])
        self.factorBoxes = []

        for i in range(5):
            self.numLevels[i].observe(self.numLevelsChanged, names='value')
            self.factor_types[i].observe(self.factor_typesChanged, names='value')
            # NOTE(review): 'align_items' is not a CSS `display` value; kept as-is.
            box = Box(children=(self.factorLabels[i], self.numLevels[i]) + tuple(self.levelLabels[i, :]) + (self.factor_types[i],),
                      layout=ipyw.Layout(display='align_items'))
            self.factorBoxes.append(box)
            self.levelsBox.children += (box,)

        self.r = BoundedIntText(value=None, description=r'$r_{within-subject factors}$', min=0, max=10, style=style,
                                layout=ipyw.Layout(width='170px', margin='10px'))
        self.num_covariates = BoundedIntText(value=None, description='Covariates', min=0, max=10, style=style,
                                             layout=ipyw.Layout(width='135px', margin='10px'))
        # NOTE(review): `margin` is passed as a widget kwarg rather than through `layout`;
        # confirm the project Checkbox wrapper honours it.
        self.powerCovariates = Checkbox(value=None, description='Power Covariates', style=style, margin='10px')
        self.leftBox.children += (self.levelsBox, self.r, self.num_covariates, self.powerCovariates)

        # Right column: header plus up to 10 covariate rows (label + adjusted beta).
        label = AlignedLabel(value=r'Adjusted $\beta$')
        self.covariatesBox = Box(children=(label,), layout=ipyw.Layout(**boxLayout))
        self.covariateLabels = ItemList([Text(value=None, description='Covariate', layout=ipyw.Layout(width='135px'), style=style)
                                         for _ in range(10)])
        self.covariates = ItemList([BoundedFloatText(value=None, step=0.01, layout=ipyw.Layout(width='60px'), style=style)
                                    for _ in range(10)])
        self.covariateBoxes = []

        for i in range(10):
            box = Box(children=(self.covariateLabels[i], self.covariates[i]), layout=ipyw.Layout(display='align_items'))
            self.covariateBoxes.append(box)
            self.covariatesBox.children += (box,)

        self.children = (self.leftBox, self.covariatesBox)

        # Observers on the count/checkbox widgets.
        self.num_factors.observe(self.num_factorsChanged, names='value')
        self.powerCovariates.observe(self.powerCovariatesChanged, names='value')
        self.num_covariates.observe(self.num_covariatesChanged, names='value')

    def _refresh_r_visibility(self, n_active):
        """Show ``r`` iff at least one of the first ``n_active`` factors is within-subject ('w')."""
        if 'w' in self.factor_types.get_list('value')[:n_active]:
            self.r.visible()
        else:
            self.r.invisible()

    def num_factorsChanged(self, change):
        """Show the first ``change['new']`` factor rows, hide the rest, and re-evaluate ``r``.

        A ``None`` value (unset dropdown) is ignored.
        """
        n = change['new']
        if n is None:
            return
        # Row 0 is never hidden: the dropdown's minimum is one factor.
        for i in range(1, n):
            self.factorBoxes[i].visible()
        for i in range(n, 5):
            self.factorBoxes[i].invisible()
        # Whether `r` is needed depends on which factors are active, so recompute it
        # here too; factor_typesChanged only fires when a factor's type is edited.
        self._refresh_r_visibility(n)

    def powerCovariatesChanged(self, change):
        """Warn the user when covariate powering is switched off."""
        # Deliberately `== False` rather than `not ...`: a None value must not trigger the alert.
        if change['new'] == False:  # noqa: E712
            display(Javascript('alert("If you do not power your covariates, you must enter adjusted means for each factor.");'))

    def num_covariatesChanged(self, change):
        """Show the first ``change['new']`` covariate rows; hide the section when there are none.

        ``0`` and ``None`` both mean "no covariates".
        """
        n = change['new']
        if not n:
            self.powerCovariates.invisible()
            self.covariatesBox.hide()
            return
        self.powerCovariates.visible()
        self.covariatesBox.show()
        for i in range(n):
            self.covariateBoxes[i].visible()
        for i in range(n, 10):
            self.covariateBoxes[i].invisible()

    def numLevelsChanged(self, change):
        """Show the first ``change['new']`` level labels of the changed factor row; hide the rest.

        A ``None`` value (unset dropdown) is ignored.
        """
        n = change['new']
        if n is None:
            return
        # NOTE(review): relies on change['index'] identifying the factor row; presumably
        # supplied by the project's ItemList child widgets - confirm.
        row = change['index']
        # Labels 0 and 1 are never hidden: the level dropdown's minimum is 2.
        for j in range(2, n):
            self.levelLabels[row, j].visible()
        for j in range(n, 6):
            self.levelLabels[row, j].invisible()

    def factor_typesChanged(self, change):
        """Show ``r`` when a factor becomes within-subject; hide it when no active factor is."""
        if change['new'] == 'w':
            self.r.visible()
        elif 'w' not in self.factor_types.get_list('value')[:self.num_factors.value]:
            self.r.invisible()

# Cell
class ANOVALanding(ANOVALandingView):
    """ANOVALandingView bound to a model.

    Scalar inputs (``num_factors``, ``r``, ``num_covariates``, ``powerCovariates``)
    are two-way linked to the model with ``traitlets.link``. Per-item edits in the
    widget lists are pushed to the model through its ``change*`` methods. The view
    is initialised from the model's current state.
    """

    def __init__(self, model, **kwargs):
        # Forward kwargs to the view; previously they were accepted but silently dropped.
        super().__init__(**kwargs)

        self.model = model

        # Per-item list edits -> model.
        self.numLevels.observe(self.observeNumLevels, names='value', type='child_change')
        self.levelLabels.observe(self.observeLevelLabels, names='value', type='child_change')
        self.factorLabels.observe(self.observeFactorLabels, names='value', type='child_change')
        self.covariateLabels.observe(self.observeCovariateLabels, names='value', type='child_change')
        self.factor_types.observe(self.observeFactorTypes, names='value', type='child_change')
        self.covariates.observe(self.observeCovariates, names='value', type='child_change')

        # Scalar inputs <-> model (two-way).
        traitlets.link((self.model, 'num_factors'), (self.num_factors, 'value'))
        traitlets.link((self.model, 'r'), (self.r, 'value'))
        traitlets.link((self.model, 'num_covariates'), (self.num_covariates, 'value'))
        traitlets.link((self.model, 'powerCovariates'), (self.powerCovariates, 'value'))

        self.setLanding()

        # Force the visibility observers to run even if the value has not changed.
        self.num_covariates._notify_trait('value', None, self.model.num_covariates)
        self.powerCovariates._notify_trait('value', None, self.model.powerCovariates)

    def setLanding(self):
        """Copy the model's current state into the widgets."""
        self.levelLabels.set_array(self.model.levelLabels)
        self.factor_types.set_list(self.model.factor_types)
        self.covariateLabels.set_list(self.model.covariateLabels)
        self.factorLabels.set_list(self.model.factorLabels)
        self.covariates.set_list(self.model.w_covariates)
        self.numLevels.set_list(self.model.numLevels)

        self.num_factors.value = self.model.num_factors
        self.r.value = self.model.r
        self.num_covariates.value = self.model.num_covariates
        self.powerCovariates.value = self.model.powerCovariates

    def observeCovariates(self, change):
        """Push an edited covariate value to the model."""
        self.model.changeCovariates(change['index'], change['new'])

    def observeNumLevels(self, change):
        """Push an edited level count to the model."""
        self.model.changeNumLevels(change['index'], change['new'])

    def observeLevelLabels(self, change):
        """Push an edited level label to the model; the index is a (factor, level) pair."""
        i, j = change['index']
        self.model.changeLevelLabel(i, j, change['new'])

    def observeFactorLabels(self, change):
        """Push an edited factor label to the model."""
        self.model.changeFactorLabel(change['index'], change['new'])

    def observeFactorTypes(self, change):
        """Push an edited factor type ('b' or 'w') to the model."""
        self.model.changeFactorType(change['index'], change['new'])

    def observeCovariateLabels(self, change):
        """Push an edited covariate label to the model."""
        self.model.changeCovariateLabel(change['index'], change['new'])

# Cell
class ANOVAUnit(Box):
    """One ANOVA unit: an interactive error-bar chart plus editable M / SD / N grids.

    Editing M or SD updates the chart. Chart changes to heights/errors are written
    back into M / SD (rounded to one decimal). N edits are stored on ``model.Ns``.

    Args:
        model: unit model. This class reads ``total``, ``nGroups``, ``nBars``,
            ``labels``, ``heights``, ``errors`` and ``Ns`` from it.
        style: ``'single'`` lays the grids out side by side below the chart, switching
            to tabs when ``model.total > 6``. ``'multi'`` always uses tabs, placed in a
            row next to the chart.
        **kwargs: forwarded to ``Box``.

    Raises:
        ValueError: if ``style`` is not ``'single'`` or ``'multi'``.
    """

    def __init__(self, model, style='single', **kwargs):
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if style not in ('single', 'multi'):
            raise ValueError('Style choices are "single" or "multi"')
        super().__init__(**kwargs)
        self.model = model

        total = self.model.total
        self.M = ItemList([FloatText(step=0.1, layout=ipyw.Layout(width='50px')) for _ in range(total)])
        self.SD = ItemList([FloatText(step=0.1, layout=ipyw.Layout(width='50px')) for _ in range(total)])
        self.N = ItemList([IntText(layout=ipyw.Layout(width='50px')) for _ in range(total)])

        n_bars = self.model.nBars
        self.grids = []
        if self.model.nGroups == 1:
            # One 3x2 grid (label column + value column) per bar.
            for i in range(n_bars):
                self.grids.append(self._make_grid(2, self.model.labels[i],
                                                  [self.M[i]], [self.SD[i]], [self.N[i]]))
        else:
            # One 3x(nBars+1) grid per group, holding that group's contiguous cells.
            for i in range(self.model.nGroups):
                start = i * n_bars
                end = start + n_bars
                self.grids.append(self._make_grid(n_bars + 1, self.model.labels[i],
                                                  self.M[start:end], self.SD[start:end], self.N[start:end]))

        self.errorBars = ErrorBars(self.model, layout=ipyw.Layout(display='inline-flex'))

        self.layout.flex_flow = 'column nowrap'
        self.layout.width = 'min-content'
        if style == 'multi' or self.model.total > 6:
            # Tabbed layout: one tab per grid.
            multi = {'width': '100%', 'padding': '0 66px 0 80px'}
            self.tab = Tab(children=self.grids, layout=ipyw.Layout(**multi))
            self.tabBox = Box(children=(self.tab,))
            self.children = (self.errorBars, self.tabBox)
            if style == 'multi':
                # Lay the chart and the tabs out in a row instead of a column.
                self.tabBox.layout.padding = '38px 0 38px 0'
                self.layout.flex_flow = 'row nowrap'
                self.layout.display = 'flex'
        else:
            layout = {'justify_content': 'space-around', 'padding': '0 66px 0 42px'}
            self.box = HBox(children=self.grids, layout=ipyw.Layout(**layout))
            self.children = (self.errorBars, self.box)

        # Keep grid inputs and the chart in sync in both directions.
        self.M.observe(self.observeM, names='value', type='child_change')
        self.SD.observe(self.observeSD, names='value', type='child_change')
        self.N.observe(self.observeN, names='value', type='child_change')
        self.errorBars.observe(self.observeError, names='errors', type='child_change')
        self.errorBars.observe(self.observeHeight, names='heights', type='child_change')

        self.setEstimates()

    def _make_grid(self, n_cols, name, m_widgets, sd_widgets, n_widgets):
        """Build a 3-row grid: labels M / SD / N in column 0, the given widgets from column 1."""
        grid = Grid(3, n_cols, layout=ipyw.Layout(justify_items='flex-end'), name=name)
        grid.add_widget_list(m_widgets, 0, 1)
        grid.add_widget_list(sd_widgets, 1, 1)
        grid.add_widget_list(n_widgets, 2, 1)
        grid[0, 0] = AlignedLabel(value='M')
        grid[1, 0] = AlignedLabel(value='SD')
        grid[2, 0] = AlignedLabel(value='N')
        return grid

    def flatIndex(self, i, j):
        """Return the flat cell index of bar ``j`` in group ``i``."""
        return i * self.model.nBars + j

    def gridIndex(self, i):
        """Return ``(i % nGroups, i % nGroups + 1)``.

        NOTE(review): not referenced within this class; kept for external callers.
        """
        grid = i % self.model.nGroups
        index = grid + 1
        return grid, index

    def setEstimates(self):
        """Load the model's heights, errors and Ns into the M / SD / N inputs."""
        for i in range(self.model.nGroups):
            for j in range(self.model.nBars):
                flat = self.flatIndex(i, j)
                self.M[flat].value = self.model.heights[flat]
                self.SD[flat].value = self.model.errors[flat]
                self.N[flat].value = self.model.Ns[flat]

    def observeM(self, change):
        """An M input changed: move the matching bar height (rounded to 1 decimal)."""
        self.errorBars.setHeight(change['index'], round(change['new'], 1))

    def observeSD(self, change):
        """An SD input changed: update the matching error bar (rounded to 1 decimal)."""
        self.errorBars.setError(change['index'], round(change['new'], 1))

    def observeN(self, change):
        """An N input changed: store it in the model's ``Ns`` list in place."""
        self.model.Ns[change['index']] = change['new']

    def observeHeight(self, change):
        """A chart bar height changed: write it (rounded to 1 decimal) into the matching M input."""
        self.M[change['index']].value = round(change['new'], 1)

    def observeError(self, change):
        """A chart error bar changed: write it (rounded to 1 decimal) into the matching SD input."""
        self.SD[change['index']].value = round(change['new'], 1)

# Cell
class ANOVAEstimatesView(VBox):
    """Container for the ANOVA estimates page.

    Builds a left (effect-size) panel and a right panel that receives the
    estimate widgets. Only the right panel is currently displayed. ``nBars``,
    ``nGroups`` and ``nGraphs`` are accepted for interface compatibility but
    are not used by the view itself.
    """

    def __init__(self, nBars, nGroups, nGraphs, **kwargs):
        outer_layout = dict(justify_content='space-between', flex_flow='row wrap', width='100%')
        super().__init__(layout=ipyw.Layout(**outer_layout), **kwargs)

        # Left: effect-size text area and button (built but not displayed).
        self.left = VBox(name='left', layout=ipyw.Layout(min_width='250px', width='30%'))
        effects_text = Textarea(layout=ipyw.Layout(width="94%", height="150px"), disabled=True)
        effects_button = Button(
            description='Show effect sizes',
            button_style='success',
            layout=ipyw.Layout(min_width='220px', max_width='220px', margin='2px 2px 2px 2px'),
        )
        self.textarea = effects_text
        self.effectButton = effects_button
        self.left.children = (effects_text, effects_button)

        # Right: populated with estimate widgets by subclasses.
        self.right = VBox(name='right', layout=ipyw.Layout(width='auto'))

        # To include effect sizes, use (self.left, self.right) instead.
        self.children = (self.right,)

# Cell
class ANOVAEstimates(ANOVAEstimatesView):
    """Estimates page: one ANOVAUnit per model unit, appended to the right panel.

    Units use the ``'single'`` layout when the model has exactly one unit and
    ``'multi'`` otherwise.
    """

    def __init__(self, model, **kwargs):
        super().__init__(nBars=model.nBars, nGroups=model.nGroups, nGraphs=model.nUnits, **kwargs)
        self.model = model

        unit_style = 'single' if model.nUnits == 1 else 'multi'

        self.units = []
        for unit_model in (self.model.units[k] for k in range(self.model.nUnits)):
            unit_view = ANOVAUnit(unit_model, style=unit_style)
            self.units.append(unit_view)
            self.right.children += (unit_view,)

    def run(self):
        """Delegate to the model's ``run`` and return its result."""
        return self.model.run()