# AUTOGENERATED! DO NOT EDIT! File to edit: nbs/views.error_bars.ipynb (unless otherwise specified).

__all__ = ['ErrorBarsView', 'StaticErrorBars', 'ErrorBars']

# Cell
from ..widgets import Output
from ..mvc import Model
from ..items import flatIndex, arrayIndex, Item
from ..models.error_bars import ErrorBarsModel
from ..widget_containers import Box

# Cell
import matplotlib.pyplot as plt
import matplotlib.colors as mpl_colors
from matplotlib.lines import Line2D
from matplotlib.patches import Rectangle
import numpy as np

# Cell
class ErrorBarsView(Output):
    """Bar chart with error whiskers drawn with matplotlib inside an Output widget.

    Layout: group ``g`` places its ``model.nBars`` bars at
    x = g * (nBars + 1) + 0 .. nBars - 1, so consecutive groups are separated
    by one empty slot. Every bar gets a thick black line at its height and two
    shorter black whiskers at ``model.tops[i]`` / ``model.bots[i]``. An
    ``errorbar`` with half-span ``model.errors / 2`` supplies the vertical
    segments and cap markers.

    Args:
        model: error-bar model; this view reads ``nGroups``, ``nBars``,
            ``total``, ``labels``, ``title``, ``heights``, ``errors``,
            ``tops``, ``bots`` and ``colors`` from it.
        width, height: figure size passed to ``plt.subplots(figsize=...)``.
        **kwargs: forwarded to ``Output``.
    """
    def __init__(self, model, width=6, height=3, **kwargs):
        super().__init__(**kwargs)

        self.model = model

        self.bar_width = 0.8      # bar width in x data units
        self.whisker_width = 0.6  # whisker length in x data units
        self.color_choices = {name: mpl_colors.to_rgba(name)
                              for name in ('red', 'lightblue', 'green')}

        # Build the figure inside the widget context so it renders into this Output.
        with self:
            self.fig, self.ax = plt.subplots(figsize=[width, height])
            canvas = self.fig.canvas
            canvas.toolbar_visible = False
            canvas.header_visible = False  # hide the figure name above the plot
            canvas.footer_visible = False
            canvas.resizable = False
            canvas.capture_scroll = False

        # x positions: one per bar (xlocs) and one per group centre (xticks).
        slots_per_group = self.model.nBars + 1  # the extra slot is the gap between groups
        self.xticks = []
        self.xlocs = []
        for group in range(self.model.nGroups):
            group_locs = [group * slots_per_group + k for k in range(self.model.nBars)]
            self.xlocs += group_locs
            self.xticks.append(float(sum(group_locs)) / len(group_locs))

        # A single group is labelled per bar; several groups are labelled per group.
        self.ax.set_xticks(self.xlocs if self.model.nGroups == 1 else self.xticks)
        self.ax.set_xticklabels(self.model.labels)
        self.ax.yaxis.grid()
        if self.model.title:
            self.ax.set_title(self.model.title)

        # Padding around the data, as a fraction of the data range.
        self.ax.set_ymargin(0.4)
        if self.model.total == 1:
            self.ax.set_xmargin(0.5)

        self.bars = self.ax.bar(self.xlocs, self.model.heights, self.bar_width,
                                color=self.model.colors)
        self.errbars = self.ax.errorbar(self.xlocs, self.model.heights, self.model.errors / 2,
                                        fmt='none', color='black', capsize=1)
        # The container unpacks as (data line, cap lines, bar-line collections);
        # the data line is not used.
        # NOTE(review): matplotlib appears to list the lower caps before the
        # upper ones, so "top"/"bot" may be swapped here. That is harmless as
        # long as both are always updated together - confirm before relying on it.
        _, (self.top_markers, self.bot_markers), (self.verticals,) = self.errbars
        self.top_marker_data = self.top_markers.get_ydata()
        self.bot_marker_data = self.bot_markers.get_ydata()
        self.segments = self.verticals.get_segments()

        def hline(x_center, span, y, linewidth):
            # Horizontal black segment of length `span` centred on `x_center` at height `y`.
            left = x_center - (span / 2)
            return Line2D([left, left + span], [y] * 2, color='black',
                          solid_capstyle='butt', linewidth=linewidth)

        self.lines = []  # thick line at each bar's height
        self.tops = []   # upper whisker per bar
        self.bots = []   # lower whisker per bar
        for i in range(self.model.total):
            x = self.xlocs[i]
            line = hline(x, self.bar_width, self.model.heights[i], 6)
            top = hline(x, self.whisker_width, self.model.tops[i], 3)
            bot = hline(x, self.whisker_width, self.model.bots[i], 3)
            for artist in (line, top, bot):
                self.ax.add_artist(artist)
            self.lines.append(line)
            self.tops.append(top)
            self.bots.append(bot)


# Cell
class StaticErrorBars(ErrorBarsView):
    """ErrorBarsView with setters that keep the model, the artists and listeners in sync.

    Each setter writes the new value to ``self.model``, updates the matching
    matplotlib artists, and emits a ``child_change`` notification through
    ``notify_change``. The public setters also refresh the canvas; the
    underscore variants leave redrawing to the caller.
    """

    def __init__(self, model, **kwargs):
        super().__init__(model, **kwargs)

    def setColor(self, i, color):
        """Set bar ``i`` to ``color`` (a key of ``color_choices``) and redraw.

        Raises:
            KeyError: if ``color`` is not a key of ``self.color_choices``.
                This is raised before the model is touched.
        """
        self.bars[i].set_facecolor(self.color_choices[color])
        self.model.colors[i] = color
        change = {'new': color, 'index': i, 'name': 'colors', 'type': 'child_change'}
        self.notify_change(change)
        # setHeight/setError refresh the canvas, and colors should too. A color
        # change never moves the data limits, so a deferred redraw is enough.
        self.fig.canvas.draw_idle()

    def _setHeight(self, i, height):
        """Update bar ``i``'s height in model and artists, and notify. No redraw."""
        self.model.setHeight(i, height)
        self.bars[i].set_height(height)
        self.lines[i].set_ydata([height, height])
        self._setHeightError(i, height, self.model.errors[i])
        change = {'new': height, 'index': i, 'name': 'heights', 'type': 'child_change'}
        self.notify_change(change)

    def setHeight(self, i, height):
        """Set bar ``i``'s height and redraw the canvas."""
        self._setHeight(i, height)
        self.redrawCanvas()

    def setError(self, i, error):
        """Set bar ``i``'s error and redraw the canvas."""
        self._setError(i, error)
        self.redrawCanvas()

    def _setError(self, i, error):
        """Update bar ``i``'s error in model and artists, and notify. No redraw."""
        self.model.setError(i, error)
        self._setHeightError(i, self.model.heights[i], error)
        change = {'new': error, 'index': i, 'name': 'errors', 'type': 'child_change'}
        self.notify_change(change)

    def _setHeightError(self, i, height, error):
        """Move bar ``i``'s whiskers, cap markers and vertical segment to the model's tops/bots.

        ``height`` and ``error`` are kept for interface compatibility. The
        positions are read from ``self.model.tops[i]`` and
        ``self.model.bots[i]``, which the model is presumably expected to have
        updated already in ``setHeight``/``setError``.
        """
        top = self.model.tops[i]
        bot = self.model.bots[i]

        # Whiskers. Line2D.set_ydata needs a sequence with one y per x
        # endpoint; passing a scalar was deprecated in matplotlib 3.7 and is
        # rejected by later releases.
        self.tops[i].set_ydata([top, top])
        self.bots[i].set_ydata([bot, bot])

        # Cap markers. They are hidden under the whiskers but still count
        # toward the data limits on relim.
        self.top_marker_data[i] = top
        self.bot_marker_data[i] = bot
        self.top_markers.set_ydata(self.top_marker_data)
        self.bot_markers.set_ydata(self.bot_marker_data)

        # Vertical line: each segment is [[x, y0], [x, y1]]; only the y values change.
        self.segments[i][0][1] = top
        self.segments[i][1][1] = bot
        self.verticals.set_segments(self.segments)

    def redrawCanvas(self):
        """Recompute data limits, autoscale the axes, and redraw the figure."""
        self.ax.relim()
        self.ax.autoscale()
        self.fig.canvas.flush_events()
        self.fig.canvas.draw()

# Cell
class ErrorBars(StaticErrorBars):
    """Interactive error-bar chart driven by mouse events on the figure canvas.

    - Press on the thick line at a bar's height and drag to change the height.
    - Press on the top or bottom whisker and drag to change the error.
    - Click inside a bar's body to cycle its color
      lightblue -> red -> green -> lightblue. Any other color goes to lightblue.
    """
    def __init__(self, model, **kwargs):
        super().__init__(model, **kwargs)

        self.cidpress = self.fig.canvas.mpl_connect(
            'button_press_event', self.onPress)
        self.cidrelease = self.fig.canvas.mpl_connect(
            'button_release_event', self.onRelease)
        self.cidmotion = None     # connection id of the active drag's motion handler
        self.picked = None        # index of the bar being dragged, or None
        self.event_start = None   # mouse y (data coords) where the drag started
        self.height_start = None  # bar height when a height drag started
        self.error_start = None   # bar error when a whisker drag started

    def onPress(self, event):
        """Forward a mouse press inside the axes to every bar."""
        if event.inaxes != self.ax:
            return
        for i in range(self.model.total):
            self._onPress(i, event)

    def _startDrag(self, i, event, handler):
        """Begin dragging bar ``i`` and route motion events to ``handler``."""
        # Drop any handler left over from an earlier drag so it can't leak.
        self._stopDrag()
        self.event_start = event.ydata
        self.picked = i
        self.cidmotion = self.fig.canvas.mpl_connect('motion_notify_event', handler)

    def _stopDrag(self):
        """Disconnect the active motion handler, if any, and clear the picked bar."""
        if self.cidmotion is not None:
            self.fig.canvas.mpl_disconnect(self.cidmotion)
            self.cidmotion = None
        self.picked = None

    def _onPress(self, i, event):
        """Handle a press for bar ``i``: start a drag or cycle the bar's color.

        Does nothing unless the press x lies strictly inside the bar's
        horizontal span. Returns None.
        """
        bar = self.bars[i]
        x0 = bar.get_x()
        x1 = x0 + self.bar_width
        if not (x0 < event.xdata < x1):
            return
        if self.lines[i].contains(event)[0]:
            self.height_start = self.model.heights[i]
            self._startDrag(i, event, self.onLineMotion)
        elif self.tops[i].contains(event)[0]:
            self.error_start = self.model.errors[i]
            self._startDrag(i, event, self.onTopMotion)
        elif self.bots[i].contains(event)[0]:
            self.error_start = self.model.errors[i]
            self._startDrag(i, event, self.onBotMotion)
        else:
            # Sort the bar's vertical extent so bars dragged below their base
            # (negative height) can still be clicked.
            lo, hi = sorted((bar.get_y(), bar.get_y() + bar.get_height()))
            if lo < event.ydata < hi:
                current = bar.get_facecolor()
                if current == self.color_choices['lightblue']:
                    self.setColor(i, 'red')
                elif current == self.color_choices['red']:
                    self.setColor(i, 'green')
                else:
                    self.setColor(i, 'lightblue')

    def _dragDelta(self, event):
        """Return the vertical mouse delta since the drag started, or None to ignore the event."""
        # ydata is None outside the axes, so skip those events instead of raising.
        if event.inaxes != self.ax or self.picked is None or self.event_start is None:
            return None
        return event.ydata - self.event_start

    def onLineMotion(self, event):
        """Drag handler for a bar's height line: shift the height by the mouse delta."""
        dy = self._dragDelta(event)
        if dy is None:
            return
        self._setHeight(self.picked, self.height_start + dy)

    def onTopMotion(self, event):
        """Drag handler for the top whisker: moving it up by dy grows the error by 2*dy."""
        # The view draws the whiskers at +/- error/2 around the height, so
        # moving one whisker by dy changes the full error by 2*dy.
        dy = self._dragDelta(event)
        if dy is None:
            return
        self._setError(self.picked, self.error_start + dy * 2)

    def onBotMotion(self, event):
        """Drag handler for the bottom whisker: moving it down by dy grows the error by 2*dy."""
        dy = self._dragDelta(event)
        if dy is None:
            return
        self._setError(self.picked, self.error_start - dy * 2)

    def onRelease(self, event):
        """End any active drag and redraw with rescaled axes.

        This runs even when the button is released outside the axes, so a drag
        that leaves the plot does not keep its motion handler connected.
        """
        if self.picked is None:
            return
        self._stopDrag()
        self.redrawCanvas()