# AUTOGENERATED! DO NOT EDIT! File to edit: nbs/model.model_collections.ipynb (unless otherwise specified).

__all__ = ['Collection', 'Selection', 'CollectionList', 'CollectionArray']

# Cell
from superpower_gui import rpy
from ..rpy import Int, IntArray, Float, FloatArray
from .model import Item

# Cell
from traitlets import HasTraits, Tuple, Unicode, Any, observe, List
import numpy as np
from traittypes import Array

# Cell
class Collection(Item):
    '''An ordered container of child items addressable by position, name or slice.

    Assigning to the `children` trait rebuilds the internal bookkeeping: every
    child receives an `item_index` and, when it has no `name` attribute, a
    generated name made of its class name followed by its position. Collections
    registered through `add_collection` can additionally be looked up by their
    own name.'''

    children = Tuple()

    def __init__(self, children=None, **kwargs):
        '''Create the collection, optionally filling it with `children`.

        `children` may be any iterable of items; remaining keyword arguments are
        forwarded to the parent constructor.'''
        # The `children` observer only runs when the trait value changes, so an
        # empty collection would otherwise never get these attributes and
        # `size`, iteration and indexing would raise AttributeError.
        self._children = ()
        self._names = []
        self._indices = {} # name: index
        self._size = 0
        self._collections = {}
        super().__init__(**kwargs)
        if children:
            self.children += tuple(children)

    def __getitem__(self, key):
        '''Return a child by int index, a list of children for a slice, or the
        sub-collection / child registered under a str name.

        Raises KeyError for an unknown name or an unsupported key type.'''
        if isinstance(key, int):
            return self._children[key]
        if isinstance(key, slice):
            return [self._children[i] for i in range(*key.indices(self.size))]
        if not isinstance(key, str):
            raise KeyError('Key must be an int or a str')
        # Registered sub-collections take precedence over children of the same name.
        if key in self._collections:
            return self._collections[key]
        if key in self._indices:
            return self._children[self._indices[key]]
        name = getattr(self, 'name', self.__class__.__name__)
        raise KeyError('Key \'%s\' not found in %s' % (key, name))

    def __iter__(self):
        '''Iterate over the children in insertion order.'''
        return iter(self._children)

    @property
    def size(self):
        '''Number of children currently held.'''
        return self._size

    def get_child_index(self):
        '''Return the `item_index` to give to the next child that is added.'''
        return self._size

    @observe('children')
    def _observe_collection_children(self, change):
        '''Rebuild all bookkeeping from scratch whenever `children` changes.'''
        self._children = ()
        self._names = []
        self._indices = {} # name: index
        self._size = 0
        for child in change['new']:
            self._add_child(child)
        return tuple(self._children)

    def add_collection(self, collection):
        '''Register `collection` under its name and append its children to ours.

        A collection without a `name` attribute is given one built from its
        class name and the current size of this collection.'''
        if not hasattr(collection, 'name'):
            collection.name = collection.__class__.__name__ + str(self._size)
        self._collections[collection.name] = collection
        # Reassigning the trait triggers the observer, which re-indexes everything.
        self.children += collection.children

    def _add_child(self, child):
        '''Tag `child` with its index (and a name if missing) and append it to
        self._children.'''
        child.item_index = self.get_child_index()
        if not hasattr(child, 'name'):
            child.name = child.__class__.__name__ + str(self._size)
        self._names.append(child.name)
        self._indices[child.name] = child.item_index
        self._size += 1
        self._children += (child, )

    def get_links(self, attr):
        '''Return a list of `(child, attr)` pairs, one per child.'''
        return [(child, attr) for child in self._children]

# Cell
class Selection(Item):
    '''A group of candidate items of which one is selected at a time.

    All candidates are kept in the `_children` trait; the public `children`
    trait holds only the currently selected candidate. `selected_index` and
    `selected_name` are kept in sync with each other and with `children`.'''

    children = Tuple()
    _children = Tuple()
    selected_index = Int(default_value=None, allow_none=True)
    selected_name = Unicode(default_value=None, allow_none=True)

    def __init__(self, children, **kwargs):
        '''Create the selection from the candidate `children`; the first
        candidate, if any, becomes the selected one.'''
        # Set up bookkeeping before any trait is assigned: the observers read
        # these attributes, and the `_children` observer does not run at all
        # when the initial tuple is empty.
        self._size = 0
        self._names = []
        self._indices = {} # name: index
        super().__init__(**kwargs)
        self._children = children

    def __getitem__(self, key):
        '''Return a candidate by int index or str name, or a list of candidates
        for a slice.

        Raises KeyError for an unknown name or an unsupported key type.'''
        if isinstance(key, int):
            return self._children[key]
        if isinstance(key, str):
            if key in self._indices:
                return self._children[self._indices[key]]
        elif isinstance(key, slice):
            return [self._children[i] for i in range(*key.indices(self._size))]
        # Reached for unknown names as well as for unsupported key types.
        raise KeyError('Key %s not found in children or collections' % str(key))

    def __iter__(self):
        '''Iterate over all candidates, not only the selected one.'''
        return iter(self._children)

    #@validate('children')
    #TODO: children can only be changed to match selected name and index

    @observe('_children')
    def _observe_selection_children(self, change):
        '''Re-index the candidates and select the first one (or nothing when the
        new tuple is empty).'''
        self._names = []
        self._indices = {} # name: index
        ret = tuple(self._add_child(child, i) for i, child in enumerate(change['new']))
        if ret:
            self.selected_index = 0
            self.selected_name = ret[0].name
            self.children = (ret[0], )
        else:
            # No candidates left: clear any stale selection.
            self.selected_index = None
            self.selected_name = None
            self.children = ()
        # Updated last on purpose: while it still holds the previous count the
        # selection observers above see the state they saw before this change.
        self._size = len(ret)
        return ret

    def _add_child(self, child, index):
        '''Tag `child` with `index` (and a generated name if it has none),
        record it in the lookup tables and return it.'''
        child.item_index = index
        if not hasattr(child, 'name'):
            # Use the child's own index so generated names are unique.
            child.name = child.__class__.__name__ + str(index)
        self._names.append(child.name)
        self._indices[child.name] = child.item_index
        return child

    @observe('selected_index')
    def _observe_selected_index(self, change):
        '''Expose the newly selected candidate and update `selected_name`.'''
        selected_index = change['new']
        if self._size:
            if selected_index is None:
                # None is an allowed value and means "nothing selected".
                self.children = ()
                self.selected_name = None
            else:
                self.children = (self._children[selected_index], )
                self.selected_name = self._names[selected_index]
        return selected_index

    @observe('selected_name')
    def _observe_selected_name(self, change):
        '''Translate the new name into `selected_index`.'''
        selected_name = change['new']
        if self._size:
            if selected_name is None:
                self.selected_index = None
            else:
                self.selected_index = self._indices[selected_name]
        return selected_name

# Cell
class CollectionList(Collection):
    '''A Collection in which every child is the same type of widget.

    `list_map` maps a child trait name to the name of an aggregate trait that
    is added to the collection. With the default `{'value': 'values'}` the
    collection gains a `values` trait holding the `value` of every child in
    child order. Assigning to the aggregate updates every child, and a change
    in a child updates the aggregate in place and notifies observers with a
    change of type 'child_change'.'''
    def __init__(self, children=None, list_map=None, **kwargs):
        '''Create the list.

        children: optional iterable of same-typed items.
        list_map: dict of child trait name -> aggregate trait name; defaults to
            {'value': 'values'}.
        Remaining keyword arguments are forwarded to Collection.'''
        super().__init__(**kwargs)
        if list_map is None:
            list_map = {'value': 'values'}
        children = tuple(children) if children is not None else ()
        if children:
            for name in list_map:
                assert hasattr(children[0], name), "Child does not have attribute %s provided in list_map." % name
        # Must be in place before `children` is assigned: the observer uses them.
        self._list_map = list_map
        self._list_handlers = {}
        if children:
            self.children = children

    def _add_child(self, child):
        '''Give the child a handler registry, then add it as usual.'''
        child._child_handlers = {}
        super()._add_child(child)

    @observe('children')
    def _observe_collection_children(self, change):
        '''Re-index the children and (re)create the aggregate traits.'''
        super()._observe_collection_children(change)
        if self._children:
            self.add_trait_lists()

    def apply_all(self, method_name, *args, **kwargs):
        '''Call the method named `method_name` on every child with the given
        arguments.'''
        for child in self._children:
            getattr(child, method_name)(*args, **kwargs)

    def observe_all(self, handler, names=None, type='change'):
        '''Register `handler` on every child, for `names` or for all traits when
        `names` is not given.'''
        for child in self._children:
            if names:
                child.observe(handler, names, type)
            else:
                child.observe(handler, type=type)

    def unobserve_all(self, handler, names=None, type='change'):
        '''Remove `handler` from every child; mirror of `observe_all`.'''
        for child in self._children:
            if names:
                child.unobserve(handler, names, type)
            else:
                child.unobserve(handler, type=type)

    def add_trait_lists(self):
        '''Create one aggregate trait per `list_map` entry and wire it to the
        children in both directions.

        The kind of aggregate trait is chosen from the current value of the
        first child's trait: bool, integer, float and str values get a typed
        array trait, anything else a plain Array.'''
        # check that trait exists in child
        child_traits = self._children[0].traits()
        for name, list_name in self._list_map.items():
            assert name in child_traits, 'Name %s doesn\'t exist in widget.' % name
            sample = getattr(self._children[0], name)
            # bool is tested first because bool is a subclass of int.
            if isinstance(sample, (bool, np.bool_)):
                trait_list = Array(dtype = np.dtype(bool))
            elif isinstance(sample, (int, np.integer)):
                trait_list = IntArray()
            elif isinstance(sample, (float, np.floating)):
                trait_list = FloatArray()
            elif isinstance(sample, str): # if we dont use objects, the string length can't grow
                trait_list = Array(dtype = np.dtype(object))
            else:
                trait_list = Array()
            # Drop the handler left over from a previous child set so the
            # aggregate is not pushed to the children more than once per change.
            stale = self._list_handlers.pop(list_name, None)
            if stale is not None:
                self.unobserve(stale, list_name)
            # add the new trait
            self.add_traits(**{list_name: trait_list})
            setattr(self, list_name, self._init_list(name))
            # create, save, and observe child handlers
            for child in self._children:
                child._child_handlers[name] = observe(name, type='change')
                child._child_handlers[name]._init_call(self._child_handler_factory(name))
                child.observe(child._child_handlers[name], name)
            # create, save, and observe list handler
            self._list_handlers[list_name] = observe(name, type='change')
            self._list_handlers[list_name]._init_call(self._list_handler_factory(name))
            self.observe(self._list_handlers[list_name], list_name)

    def _init_list(self, name):
        '''Return the current value of trait `name` for every child.'''
        return [getattr(child, name) for child in self._children]

    def observe_list(self, name):
        '''Re-attach the aggregate handler for child trait `name`.'''
        # Handlers are stored under the aggregate trait's name, not the child's.
        list_name = self._list_map[name]
        self.observe(self._list_handlers[list_name], list_name, type='change')

    def unobserve_list(self, name):
        '''Detach the aggregate handler for child trait `name`.'''
        list_name = self._list_map[name]
        self.unobserve(self._list_handlers[list_name], list_name, type='change')

    def observe_children(self, name):
        '''Re-attach every child's handler for trait `name`.'''
        for child in self._children:
            child.observe(child._child_handlers[name], name, type='change')

    def unobserve_children(self, name):
        '''Detach every child's handler for trait `name`, where one exists.'''
        for child in self._children:
            if name in child._child_handlers:
                child.unobserve(child._child_handlers[name], name, type='change')

    def _list_handler_factory(self, name):
        '''Return a one-argument callback bound to child trait `name`.'''
        def wrapper(change):
            return self._list_handler(name, change)
        return wrapper

    def _list_handler(self, name, change):
        '''Push a new aggregate value down to the children.'''
        # Child handlers are suspended so the writes below do not bounce back
        # into the aggregate; `finally` re-attaches them even if a write fails.
        self.unobserve_children(name)
        try:
            # set child values to match the trait_list
            for i, child in enumerate(self._children):
                setattr(child, name, change['new'][i])
        finally:
            self.observe_children(name)
        return change['new']

    def _child_handler_factory(self, name):
        '''Return a one-argument callback bound to child trait `name`.'''
        def wrapper(change):
            return self._child_handler(name, change)
        return wrapper

    def _child_handler(self, name, change):
        '''Copy a child's new value into the aggregate and notify observers.'''
        # get corresponding list
        list_name = self._list_map[name]
        _list = getattr(self, list_name)
        # set appropriate element of list to new value
        child_index = change['owner'].item_index
        _list[child_index] = change['new']
        # The aggregate is modified in place, so no ordinary 'change' event is
        # produced; a custom 'child_change' event is emitted instead.
        list_change = {
            'name': list_name,
            'list': _list,
            'old': change['old'],
            'new': change['new'],
            'index': child_index,
            'owner': self,
            'type': 'child_change',
        }
        self.notify_change(list_change)
        return change['new']

# Cell
class CollectionArray(CollectionList):
    '''A CollectionList with two dimensions.

    Children are stored flat in row-major order and are accessed with a
    `(row, column)` tuple of ints and/or slices, similar to the numpy library.
    `array_map` maps a child trait name to the name of an aggregate trait on
    the collection; with the default `{'value': 'values'}` the trait `values`
    holds a rows x columns nested structure of child values.'''
    def __init__(self, n_rows, n_columns, children=None, array_map=None, **kwargs):
        '''Create an `n_rows` x `n_columns` array, optionally filled with
        `children` given in row-major order.'''
        # Needed by get_child_index() as soon as children are assigned.
        self.n_rows = n_rows
        self.n_columns = n_columns
        kwargs['list_map'] = {'value': 'values'} if array_map is None else array_map
        super().__init__(children, **kwargs)

    def get_child_index(self):
        '''Return the (row, column) index for the next child that is added.'''
        return self._2D_index(self._size)

    def __getitem__(self, index):
        '''Return the child at `(row, column)`, or a flat row-major list of
        children when either component is a slice.

        Negative ints count from the end of their axis. Raises IndexError when
        `index` is not a tuple or a component is out of range.'''
        if not isinstance(index, tuple):
            raise IndexError('Array index must be a tuple')
        x, y = index
        int_types = (int, np.integer)
        if isinstance(x, int_types) and isinstance(y, int_types):
            try:
                row = self._normalize_index(x, self.n_rows)
                col = self._normalize_index(y, self.n_columns)
            except IndexError:
                raise IndexError("index (%i, %i) not in range" % (x, y)) from None
            return self._children[self._flat_index(row, col)]
        # Resolve both axes once, up front; the column range must be the same
        # for every row.
        rows = self._to_range(x, self.n_rows)
        cols = self._to_range(y, self.n_columns)
        return [self._children[self._flat_index(r, c)] for r in rows for c in cols]

    def _normalize_index(self, i, size):
        '''Return `i` as a non-negative int within `range(size)`; raise
        IndexError otherwise.'''
        pos = int(i)
        if pos < 0:
            pos += size
        if not 0 <= pos < size:
            raise IndexError("index %i not in range" % i)
        return pos

    def _to_range(self, s, end):
        '''Turn an int or slice into the range of positions it selects on an
        axis of length `end`.

        Raises IndexError for an out-of-range int, for a slice whose stop lies
        beyond `end`, and for any other kind of index.'''
        if isinstance(s, (int, np.integer)):
            pos = self._normalize_index(s, end)
            return range(pos, pos + 1)
        if isinstance(s, slice):
            # An open-ended slice (stop is None) is valid and runs to the end.
            if s.stop is not None and s.stop > end:
                raise IndexError("slice not in range")
            return range(*s.indices(end))
        raise IndexError('Array index must be made of ints or slices')

    def _flat_index(self, i, j):
        '''Row-major position of element (i, j) in self._children.'''
        return i * self.n_columns + j

    def _2D_index(self, index):
        '''Inverse of `_flat_index`: (row, column) for a flat position.'''
        row, col = divmod(index, self.n_columns)
        return (row, col)

    def _list_handler(self, name, change):
        '''Push a new aggregate array down to the children.'''
        # Child handlers are suspended so the writes below do not bounce back
        # into the aggregate; `finally` re-attaches them even if a write fails.
        self.unobserve_children(name)
        try:
            for i in range(self.n_rows):
                for j in range(self.n_columns):
                    setattr(self[i, j], name, change['new'][i][j])
        finally:
            self.observe_children(name)
        return change['new']

    def _child_handler(self, name, change):
        '''Copy a child's new value into the aggregate array and notify
        observers.'''
        # get corresponding list
        ary_name = self._list_map[name]
        ary = getattr(self, ary_name)
        # set appropriate element of list to new value
        child_index = change['owner'].item_index
        ary[child_index[0]][child_index[1]] = change['new']
        # The array is modified in place, so no ordinary 'change' event is
        # produced; a custom 'child_change' event is emitted instead.
        ary_change = {
            'name': ary_name,
            'array': ary,
            'old': change['old'],
            'new': change['new'],
            'index': child_index,
            'owner': self,
            'type': 'child_change',
        }
        self.notify_change(ary_change)
        return change['new']

    def _init_list(self, trait):
        '''Return the current value of `trait` for every child as a list of
        rows.'''
        return [[getattr(self[i, j], trait) for j in range(self.n_columns)] for i in range(self.n_rows)]