# AUTOGENERATED! DO NOT EDIT! File to edit: nbs/custom_widgets.ratio.ipynb (unless otherwise specified).

__all__ = ['ColorIntWidget', 'ColorIntArray', 'RatioLandingModel', 'Ratios', 'RatioArray', 'RatioLandingView',
           'RatioEstimatesModel', 'RatioEstimatesView']

# Cell
import ipywidgets as ipyw
import numpy as np
import traitlets
from traitlets import Unicode, Bool, validate, TraitError, observe, HasTraits
from fastcore.test import test_eq
from traittypes import Array
import copy

# Cell
from ..view.widgets import Button, Grid, Label, FloatText, IntText, Text, AlignedLabel, IntLabel, Dropdown
from ..view.widget_collections import WidgetList, VBox, HBox, Box, WidgetContainer
from ..model.model import Model
from ..model.model_collections import CollectionList, CollectionArray
import superpower_gui.rpy as rpy
from ..rpy import Int, IntArray, FloatArray, Float, DTYPES

# Cell
class ColorIntWidget(Button):
    """Button that displays an integer and cycles a coloured border on click.

    The button caption always mirrors ``value``. ``color`` is restricted to
    ``'red'``, ``'green'`` or ``'none'`` and is rendered as the button border.
    Each click advances the colour: none -> red -> green -> none.
    """

    value = Int(0)
    color = Unicode('none')

    # Border CSS applied for each permitted ``color``.
    _BORDERS = {
        'red': '3px solid red',
        'green': '3px solid green',
        'none': 'none',
    }
    # Colour that follows each colour when the button is clicked.
    _NEXT_COLOR = {'red': 'green', 'green': 'none', 'none': 'red'}

    def __init__(self, width='auto', height='auto', **kwargs):
        """Create the button.

        Parameters
        ----------
        width, height : str
            CSS dimensions applied to the button layout. They are ignored
            when the caller passes an explicit ``layout`` keyword.
        **kwargs
            Forwarded unchanged to ``Button``.
        """
        # Previously ``width``/``height`` were accepted but silently dropped;
        # remember whether the caller already chose a layout so that an
        # explicit one is never overridden.
        explicit_layout = 'layout' in kwargs
        super().__init__(**kwargs)
        if not explicit_layout:
            self.layout.width = width
            self.layout.height = height
        self.description = str(self.value)
        self.on_click(self._on_click)

    @observe('value')
    def _on_value_change(self, proposal):
        # Keep the caption in sync with the integer value.
        self.description = str(self.value)

    @validate('color')
    def _valid_value(self, proposal):
        # Reject anything outside the three supported colours.
        if proposal['value'] not in ('red', 'green', 'none'):
            raise TraitError('Invalid color: Valid colors are red, green, and none')
        return proposal['value']

    @observe('color')
    def _observe_color(self, b):
        # Render the current colour as the button border.
        border = self._BORDERS.get(self.color)
        if border is not None:
            self.layout.border = border

    def _on_click(self, button):
        # Advance the clicked button to the next colour in the cycle.
        following = self._NEXT_COLOR.get(button.color)
        if following is not None:
            button.color = following

# Cell
class ColorIntArray(CollectionArray):
    """Two-dimensional collection of ``ColorIntWidget`` cells.

    One widget is created per cell. The ``array_map`` handed to
    ``CollectionArray`` pairs each child's ``value`` with ``ns_cells`` and
    each child's ``color`` with ``colors`` (presumably exposing them as
    array-valued traits on the collection -- confirm in ``CollectionArray``).
    """

    def __init__(self, n_row, n_col, width='auto', height='auto', **kwargs):
        cells = []
        for _col in range(n_col):
            for _row in range(n_row):
                cells.append(ColorIntWidget(width, height))
        trait_to_array = {'value': 'ns_cells', 'color': 'colors'}
        super().__init__(n_row, n_col, cells, array_map=trait_to_array, **kwargs)

# Cell
from ..globals import AddStringsMixin
class RatioLandingModel(Model, AddStringsMixin):
    """Model of a two-factor design: factor names, level counts and level labels.

    ``rowLevels``/``colLevels`` hold the labels of the currently active
    levels. ``_rowLevels``/``_colLevels`` are full-length "protected" copies
    that remember the labels of levels that are currently hidden, so they can
    be restored when the level count grows again.
    ``rowLevelNames``/``colLevelNames`` are derived by ``add_strings`` from
    the factor name and the level labels.
    """

    rowName = Unicode('A')
    colName = Unicode('B')
    rowLevels = Array()
    colLevels = Array()
    rowLevelNames = Array()
    colLevelNames = Array()

    def __init__(self, row_max=6, col_max=6, **kwargs):
        """Create the model with ``row_max`` x ``col_max`` levels, all active.

        Parameters
        ----------
        row_max, col_max : int
            Maximum (and initial) number of levels of the row/column factor.
        **kwargs
            Forwarded to the parent initialiser.
        """
        super().__init__(**kwargs)
        # ``object`` replaces the ``np.object`` alias, which was removed from
        # NumPy; both protected arrays are now built the same way.
        self.add_traits(
            row_max = Int(row_max),
            col_max = Int(col_max),
            n_row = Int(row_max),
            n_col = Int(col_max),
            _rowLevels = Array(np.array([str(x+1) for x in range(row_max)], dtype=object)),
            _colLevels = Array(np.array([str(x+1) for x in range(col_max)], dtype=object)),
        )
        self.observe(self._observe_rowLevels, 'rowLevels')
        self.observe(self._observe_colLevels, 'colLevels')
        self.rowLevels = self._rowLevels
        self.colLevels = self._colLevels
        # we use the observe method instead of the magic for two reasons:
        # 1. we don't want the methods to be observing values as they are initialized
        # 2. we want the observations to be inherited by any child class
        self.observe(self._observe_n_row, 'n_row')
        self.observe(self._observe_n_col, 'n_col')
        self.observe(self._observe_rowName, 'rowName')
        self.observe(self._observe_colName, 'colName')

    def _resize_levels(self, visible_name, protected_name, change):
        """Shrink or grow the active level labels after a level-count change.

        ``visible_name`` and ``protected_name`` are the attribute names of the
        active and the protected label arrays. Returns ``change['new']``.
        """
        new_count = change['new']
        old_count = change['old']
        visible = getattr(self, visible_name)
        protected = getattr(self, protected_name)
        if new_count < old_count:
            # Save every label that is about to disappear. Starting at the
            # new count is correct for any previous count, not only when
            # shrinking from the maximum.
            stop = min(len(visible), len(protected))
            for i in range(new_count, stop):
                protected[i] = visible[i]
            setattr(self, visible_name, visible[:new_count])
        elif new_count > old_count:
            # Restore remembered labels for the re-activated levels.
            setattr(self, visible_name, protected[:new_count])
        return new_count

    def _observe_n_row(self, change):
        """Resize ``rowLevels`` when the number of row levels changes."""
        return self._resize_levels('rowLevels', '_rowLevels', change)

    def _observe_n_col(self, change):
        """Resize ``colLevels`` when the number of column levels changes."""
        return self._resize_levels('colLevels', '_colLevels', change)

    def _observe_rowLevels(self, change):
        """Rebuild ``rowLevelNames`` from the row name and the new labels."""
        self.rowLevelNames = self.add_strings(self.rowName, change['new'])
        return change['new']

    def _observe_colLevels(self, change):
        """Rebuild ``colLevelNames`` from the column name and the new labels."""
        self.colLevelNames = self.add_strings(self.colName, change['new'])
        return change['new']

    def _observe_rowName(self, change):
        """Rebuild ``rowLevelNames`` after the row factor is renamed."""
        self.rowLevelNames = self.add_strings(change['new'], self.rowLevels)
        return change['new']

    def _observe_colName(self, change):
        """Rebuild ``colLevelNames`` after the column factor is renamed."""
        # Two separate arguments, matching the other three callers.
        self.colLevelNames = self.add_strings(change['new'], self.colLevels)
        return change['new']

# Cell
class Ratios(np.ndarray):
    """Array of ratios that keeps its sum constant when one entry is changed.

    Entries are stored as floats rounded to ``decimals`` places. Assigning to
    an element (``ratios[i] = value``) moves whole steps of ``1/precision``
    between that element and the others, one step at a time in round-robin
    order, so the overall sum is preserved.

    Attributes set by the constructor:
        precision: ``10**decimals``; number of steps per unit.
        decimals:  number of decimal places kept.
        total:     sum of all entries expressed in steps (int).
        step:      size of one step, ``1/precision``.
        n:         number of entries.
        j, k:      round-robin cursors used by ``inc`` and ``dec``.

    Only element assignment with a single integer index is supported. Arrays
    obtained by view casting from a plain ndarray carry ``None`` for all of
    the attributes above.
    """

    def __new__(subtype, size=1, decimals=3, **kwargs):
        """Create ``size`` equal ratios, each ``1/size`` truncated to ``decimals`` places."""
        precision = 10**decimals
        fill_value = round(int(1/size * precision)/precision, decimals)
        obj = np.full(size, fill_value, dtype=DTYPES['float']).view(subtype)
        obj.precision = precision
        obj.decimals = decimals
        # Stored as an exact integer number of steps so that comparisons in
        # ``inc`` are not affected by floating-point noise.
        obj.total = int(round(float(np.sum(np.asarray(obj))) * precision))
        obj.step = 1/obj.precision
        obj.n = size
        obj.j = 0
        obj.k = 0
        return obj

    def __array_finalize__(self, obj):
        # ``obj`` is None only inside an explicit ndarray.__new__ call; for
        # view casting and slicing it is the source array, whose bookkeeping
        # attributes (if any) are copied over.
        if obj is None:
            return
        for attr in ('precision', 'decimals', 'total', 'step', 'n', 'j', 'k'):
            setattr(self, attr, getattr(obj, attr, None))

    def __setitem__(self, i, value):
        """Set entry ``i`` to ``value`` by trading steps with the other entries."""
        # Work with an integer number of steps: a float counter can be left
        # marginally above zero and trigger one transfer too many.
        steps = int(round(float(value - self[i]) * self.precision))
        if steps > 0:
            self.inc(i, steps)
        elif steps < 0:
            self.dec(i, -steps)

    def _units(self):
        """Return the entries as an integer array of steps."""
        # Round instead of truncating: ``x * precision`` may land a hair
        # below the intended integer.
        scaled = np.asarray(self, dtype=float) * self.precision
        return np.rint(scaled).astype(int)

    def _store_units(self, units):
        """Write integer step counts back into the array as rounded ratios."""
        values = np.around(units / self.precision, self.decimals)
        # Bypass the overridden ``__setitem__`` (``ndarray.itemset`` no longer
        # exists in NumPy 2).
        np.ndarray.__setitem__(self, Ellipsis, values)

    def inc(self, i, diff):
        """Raise entry ``i`` by ``diff`` steps, taking them from the other entries.

        The increase is capped so that entry ``i`` never exceeds ``total`` and
        no donor entry is pushed below zero.
        """
        units = self._units()
        donors = units.copy()
        donors[i] = 0
        available = int(donors[donors > 0].sum())
        headroom = int(round(float(self.total))) - int(units[i])
        remaining = min(int(diff), headroom, available)
        # ``remaining <= available`` guarantees a non-empty donor exists on
        # every pass, so the loop terminates (also for a single entry).
        while remaining > 0:
            if self.j != i and units[self.j] > 0:
                units[self.j] -= 1
                units[i] += 1
                remaining -= 1
            self._inc_j()
        self._store_units(units)

    def dec(self, i, diff):
        """Lower entry ``i`` by ``diff`` steps, handing them to the other entries.

        The decrease is capped at the current value of entry ``i``. With fewer
        than two entries there is nobody to receive the steps, so nothing
        changes.
        """
        if self.n < 2:
            return
        units = self._units()
        remaining = min(int(diff), int(units[i]))
        while remaining > 0:
            if self.k != i:
                units[self.k] += 1
                units[i] -= 1
                remaining -= 1
            self._inc_k()
        self._store_units(units)

    def _inc_j(self):
        # Advance the donor cursor used by ``inc``.
        self.j = (self.j + 1) % self.n

    def _inc_k(self):
        # Advance the receiver cursor used by ``dec``.
        self.k = (self.k + 1) % self.n

# Cell
class RatioArray(traitlets.TraitType):
    """Trait type that only accepts ``Ratios`` instances.

    The default value is a ``Ratios(size, decimals)`` created once, when the
    trait itself is constructed.
    """

    info_text = 'an ndarray of dependent ratios'

    def __init__(self, size=1, decimals=3, **kwargs):
        # The default must be in place before the base initialiser runs.
        initial = Ratios(size, decimals)
        self.default_value = initial
        super().__init__(**kwargs)

    def validate(self, obj, value):
        """Return ``value`` unchanged if it is a ``Ratios``; otherwise report a trait error."""
        if not isinstance(value, Ratios):
            return self.error(obj, value)
        return value

# Cell
class RatioLandingView(RatioLandingModel, VBox):
    """Widget view for editing the two factors of a ``RatioLandingModel``.

    The view consists of two horizontal boxes, one per factor. Each box holds
    a text field for the factor name, a dropdown for the number of levels
    (2 up to the maximum) and one text field per level label. Widget changes
    are written back to the model traits.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        self.children = (HBox(name='rowBox'), HBox(name='colBox'))

        # ---- row factor ----
        rowNameWidget = Text(
            name = 'rowNameWidget',
            layout = ipyw.Layout(description_width = 'initial', max_width = '160px'),
            value=self.rowName,
            description='Factor:'
        )
        rowNameWidget.observe(self._observe_rowNameWidget, 'value')
        n_row_widget = Dropdown(
            name = 'n_row_widget',
            layout = ipyw.Layout(description_width='initial', max_width='135px'),
            options=[i for i in range(2, self.row_max+1)],
            value=self.n_row,
            description='Levels:'
        )
        n_row_widget.observe(self._observe_n_row_widget, names='value')
        self['rowBox'].children += (rowNameWidget, n_row_widget)

        self.rowLevelsWidgets = CollectionList(
            self._make_level_widgets(self.rowLevels[0], self._rowLevels, self.row_max)
        )
        self.rowLevelsWidgets.observe_all(self._observe_rowLevelsWidgets, 'value')
        self['rowBox'].children += self.rowLevelsWidgets.children

        # ---- column factor ----
        colNameWidget = Text (
            name = 'colNameWidget',
            layout = ipyw.Layout(description_width = 'initial', max_width = '160px'),
            value=self.colName,
            description='Factor:'
        )
        colNameWidget.observe(self._observe_colNameWidget, 'value')
        n_col_widget = Dropdown(
            # Was 'n_row_widget' (copy-paste), clashing with the row dropdown.
            name = 'n_col_widget',
            layout = ipyw.Layout(description_width='initial', max_width='135px'),
            options=[i for i in range(2, self.col_max+1)],
            value=self.n_col,
            description='Levels:'
        )
        n_col_widget.observe(self._observe_n_col_widget, names='value')
        self['colBox'].children += (colNameWidget, n_col_widget)

        self.colLevelsWidgets = CollectionList(
            self._make_level_widgets(self.colLevels[0], self._colLevels, self.col_max)
        )
        self.colLevelsWidgets.observe_all(self._observe_colLevelsWidgets, 'value')
        self['colBox'].children += self.colLevelsWidgets.children

    def _make_level_widgets(self, first_value, protected_levels, count):
        """Build ``count`` label text fields; only the first one carries a caption."""
        widgets = [Text(
            value=first_value,
            description = 'Names:',
            layout = ipyw.Layout (
                description_width= 'initial',
                max_width= '150px' )
        )]
        widgets += [ Text(
            value = protected_levels[i],
            layout=ipyw.Layout(description_width = 'initial', max_width = '62px')
        ) for i in range(1, count)]
        return widgets

    def _observe_n_row_widget(self, change):
        """Show/hide row label fields to match the dropdown, then update the model."""
        # The first two fields are never hidden because the dropdown starts at 2.
        for i in range(2, change['new']):
            self.rowLevelsWidgets[i].visible()
        # Use the model's maximum instead of a hard-coded 6.
        for i in range(change['new'], self.row_max):
            self.rowLevelsWidgets[i].invisible()
        self.n_row = super()._observe_n_row(change)

    def _observe_n_col_widget(self, change):
        """Show/hide column label fields to match the dropdown, then update the model."""
        for i in range(2, change['new']):
            self.colLevelsWidgets[i].visible()
        for i in range(change['new'], self.col_max):
            self.colLevelsWidgets[i].invisible()
        self.n_col = super()._observe_n_col(change)

    def _observe_rowLevelsWidgets(self, change):
        """Copy the row label fields into ``rowLevels``."""
        self.rowLevels = self.rowLevelsWidgets.values
        return change['new']

    def _observe_colLevelsWidgets(self, change):
        """Copy the column label fields into ``colLevels``."""
        self.colLevels = self.colLevelsWidgets.values
        return change['new']

    def _observe_rowNameWidget(self, change):
        """Copy the edited row factor name into ``rowName``."""
        self.rowName = change['new']
        return change['new']

    def _observe_colNameWidget(self, change):
        """Copy the edited column factor name into ``colName``."""
        self.colName = change['new']
        return change['new']

# Cell
class RatioEstimatesModel(RatioLandingModel):
    """Model that splits a total sample size over a row x column table.

    Each cell's share is the product of its row ratio and column ratio
    (``percents``). Cell counts (``ns_cells``) are ``percents * total``
    converted to integers; ``rowTotals``, ``colTotals`` and ``N`` are sums of
    those counts.
    """

    total = Int()
    N = Int()
    colors = Array()
    ns_cells = IntArray()
    # NOTE(review): ``percents`` receives a 2-D array in ``_ratios_changed``
    # although it is declared with the scalar-looking ``Float`` -- confirm
    # whether ``FloatArray`` was intended.
    percents = Float()
    rowRatios = RatioArray()
    colRatios = RatioArray()
    rowTotals = IntArray()
    colTotals = IntArray()

    def __init__(self, n_row=6, n_col=6, precision=3, **kwargs):
        """Create the model with equal ratios and ten observations per cell.

        Parameters
        ----------
        n_row, n_col : int
            Number of row/column levels; passed to the parent as its maxima.
        precision : int
            Decimal places used for the ratios and for ``percents``.
        **kwargs
            Forwarded to the parent initialiser.
        """
        super().__init__(n_row, n_col, **kwargs)
        self.rowRatios = Ratios(n_row, precision)
        self.colRatios = Ratios(n_col, precision)
        self.precision = precision
        # ``object`` replaces the ``np.object`` alias, which was removed from NumPy.
        self.colors = np.array([['none'] * self.n_col] * self.n_row, dtype=object)
        self.total = self.n_row * self.n_col * 10
        self._ratios_changed()
        # Registered last so the initial assignment above does not trigger it.
        self.observe(self._observe_total, 'total')

    def _observe_total(self, change):
        """Recompute the counts when ``total`` changes."""
        self._on_change(change['new'])
        return change['new']

    def _ratios_changed(self):
        """Recompute ``percents`` from the ratios, then refresh the counts."""
        percents = np.outer(self.rowRatios, self.colRatios)
        self.percents = np.around(percents, self.precision)
        self._on_change(self.total)

    def _on_change(self, total):
        """Derive cell counts, marginal totals and ``N`` for the given ``total``."""
        ns_cells = self.percents * total
        self.ns_cells = ns_cells.astype(DTYPES['int'])
        self.rowTotals = np.sum(self.ns_cells, axis=1, dtype=DTYPES['int'])
        self.colTotals = np.sum(self.ns_cells, axis=0, dtype=DTYPES['int'])
        self.N = np.sum(self.ns_cells)

    def change_row_ratio(self, i, diff):
        """Change row ratio ``i`` by ``diff`` and recompute the table."""
        self.rowRatios[i] += diff
        self._ratios_changed()

    def change_col_ratio(self, i, diff):
        """Change column ratio ``i`` by ``diff`` and recompute the table."""
        self.colRatios[i] += diff
        self._ratios_changed()

# Cell
class RatioEstimatesView(Grid, RatioEstimatesModel):
    """Grid view of a ``RatioEstimatesModel``.

    Layout (grid row, grid column):
        column 0 / row 0: editable row / column ratios
        column 1 / row 1: row / column level names
        from (2, 2):      the cell counts as ``ColorIntWidget`` buttons
        last column/row:  row / column totals, with the editable grand total
                          in the bottom-right corner
    """

    def __init__(self, n_row=6, n_col=6, precision=3, width='auto', height='auto', **kwargs):
        """Build the grid.

        Parameters
        ----------
        n_row, n_col : int
            Number of row/column levels.
        precision : int
            Decimal places of the ratios; also sets the step of the ratio fields.
        width, height : str
            CSS dimensions used for every cell widget.
        **kwargs
            Forwarded to the parent initialisers.
        """
        # ``precision`` is forwarded so the model rounds with the same
        # precision that the ratio fields step with.
        super().__init__(n_rows=n_row+3, n_columns=n_col+3, n_row=n_row, n_col=n_col,
                         precision=precision, **kwargs)

        widget_list = [FloatText(step=1/10**precision, layout=ipyw.Layout(width=width, height=height)) for i in range(self.n_row)]
        self.rowRatioWidgets = WidgetList(widget_list)
        self.add_widget_list(self.rowRatioWidgets, 2, 0, direction='col')
        self.rowRatioWidgets.values = np.copy(self.rowRatios)
        self.rowRatioWidgets.observe_all(self._observe_rowRatioWidgets, 'value')

        widget_list = [FloatText(step=1/10**precision, layout=ipyw.Layout(width=width, height=height)) for i in range(self.n_col)]
        self.colRatioWidgets = WidgetList(widget_list)
        self.add_widget_list(self.colRatioWidgets, 0, 2, direction='row')
        self.colRatioWidgets.values = np.copy(self.colRatios)
        self.colRatioWidgets.observe_all(self._observe_colRatioWidgets, 'value')

        self.colorIntArray = ColorIntArray(n_row, n_col, width, height, name='colorInts')
        self.add_widget_array(self.colorIntArray, 2, 2)
        self.colorIntArray.ns_cells = self.ns_cells
        traitlets.link((self, 'ns_cells'), (self.colorIntArray, 'ns_cells'))
        traitlets.link((self, 'colors'), (self.colorIntArray, 'colors'))

        widget_list = [AlignedLabel(layout=ipyw.Layout(width=width, height=height)) for i in range(self.n_row)]
        self.rowLevelNamesWidgets = WidgetList(widget_list)
        self.add_widget_list(self.rowLevelNamesWidgets, 2, 1, direction='col')
        traitlets.link( (self, 'rowLevelNames'), (self.rowLevelNamesWidgets, 'values'))

        widget_list = [AlignedLabel(layout=ipyw.Layout(width=width, height=height)) for i in range(self.n_col)]
        self.colLevelNamesWidgets = WidgetList(widget_list)
        self.add_widget_list(self.colLevelNamesWidgets, 1, 2, direction='row')
        traitlets.link( (self, 'colLevelNames'), (self.colLevelNamesWidgets, 'values'))

        widget_list = [IntLabel(layout=ipyw.Layout(width=width, height=height)) for i in range(self.n_row)]
        self.rowTotalsWidgets = WidgetList(widget_list)
        self.add_widget_list(self.rowTotalsWidgets, 2, self.n_col + 2, direction='col')
        traitlets.link((self, 'rowTotals'), (self.rowTotalsWidgets, 'values'))

        # One label per column (was sized with ``n_row``, which is only
        # correct for square tables).
        widget_list = [IntLabel(layout=ipyw.Layout(width=width, height=height)) for i in range(self.n_col)]
        self.colTotalsWidgets = WidgetList(widget_list)
        self.add_widget_list(self.colTotalsWidgets, self.n_row + 2, 2, direction='row')
        traitlets.link((self, 'colTotals'), (self.colTotalsWidgets, 'values'))

        self.totalWidget = IntText(layout=ipyw.Layout(width=width, height=height))
        self.add_widget_list([self.totalWidget], self.n_row + 2, self.n_col+2)
        traitlets.link((self, 'total'), (self.totalWidget, 'value'))

    def _observe_rowRatioWidgets(self, change):
        """Apply an edited row ratio to the model and refresh all row ratio fields."""
        diff = change['new'] - change['old']
        # Detach while the fields are rewritten so the refresh does not
        # re-enter this handler; always re-attach, even if the update fails.
        self.rowRatioWidgets.unobserve_all(self._observe_rowRatioWidgets, 'value')
        try:
            self.change_row_ratio(change['owner'].item_index, diff)
            self.rowRatioWidgets.values = np.copy(self.rowRatios)
        finally:
            self.rowRatioWidgets.observe_all(self._observe_rowRatioWidgets, 'value')
        self._on_widget_change()

    def _observe_colRatioWidgets(self, change):
        """Apply an edited column ratio to the model and refresh all column ratio fields."""
        diff = change['new'] - change['old']
        self.colRatioWidgets.unobserve_all(self._observe_colRatioWidgets, 'value')
        try:
            self.change_col_ratio(change['owner'].item_index, diff)
            self.colRatioWidgets.values = np.copy(self.colRatios)
        finally:
            self.colRatioWidgets.observe_all(self._observe_colRatioWidgets, 'value')
        self._on_widget_change()

    def _on_widget_change(self):
        """Push the current cell counts to the button grid."""
        self.colorIntArray.ns_cells = self.ns_cells