#!/usr/bin/python3

import sys
from PyQt5.QtWidgets import (QWidget, QPushButton, QGroupBox, QSpinBox, QDoubleSpinBox,
    QGridLayout, QHBoxLayout, QVBoxLayout, QApplication, QMainWindow, QShortcut, QLabel,
    QComboBox, QCheckBox, QProgressBar, QAction, QStackedWidget)

from PyQt5.QtGui import QKeySequence, QColor, QPalette
from PyQt5.QtCore import Qt

import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
from matplotlib.backends.qt_compat import QtCore, QtWidgets, is_pyqt5
from matplotlib.backends.backend_qt5agg import (
    FigureCanvas, NavigationToolbar2QT as NavigationToolbar)
from matplotlib.figure import Figure
from functools import partial
from itertools import cycle

import miepy

import signal
# Install the default OS action for SIGINT in place of Python's handler;
# presumably so that Ctrl+C in the terminal ends the program while the Qt
# event loop is running -- confirm the intent.
signal.signal(signal.SIGINT, signal.SIG_DFL)

# Make 'regular' the default mathtext font style for all matplotlib text.
mpl.rc('mathtext', default='regular')

# Scale factor applied to values given in nanometers (e.g. ``300*nm``);
# presumably converts them to meters for miepy -- TODO confirm the unit.
nm = 1e-9
# Six spaces, prepended to some widget labels to indent them visually.
tab = ' '*6

class dynamic_labels:
    """Grid of (label, widget) pairs whose rows can be shown or hidden by group.

    ``layout`` is the QGridLayout that holds everything. Every pair added
    with ``add_dynamic`` is registered under one of ``group_names`` so the
    whole group can later be toggled with ``show_group`` / ``hide_group``.
    """

    def __init__(self, group_names):
        """Create an empty grid and one empty widget list per group name."""
        self.layout = QGridLayout()
        # Next free row, next free column, and the widest column index reached.
        self.row, self.col, self.col_max = 0, 0, 2

        self.static_labels = []
        self.group_names = group_names
        self.groups = {name: [] for name in group_names}

    def add_static(self, label, widget, same_row=False):
        """Add an always-visible pair; ``same_row`` appends it to the previous row."""
        self._place(QLabel(label), widget, same_row)

    def add_dynamic(self, label, widget, group, hide=True, same_row=False):
        """Add a pair belonging to ``group``, initially hidden unless ``hide`` is false."""
        qlabel = QLabel(label)
        members = self.groups[group]
        members.append(qlabel)
        members.append(widget)

        self._place(qlabel, widget, same_row)

        if hide:
            for item in (qlabel, widget):
                item.hide()

    def show_group(self, group):
        """Show every label and widget registered under ``group``."""
        for member in self.groups[group]:
            member.show()

    def hide_group(self, group):
        """Hide every label and widget registered under ``group``."""
        for member in self.groups[group]:
            member.hide()

    def show_only(self, group):
        """Hide the members of all groups, then show those of ``group``."""
        for name in self.groups:
            self.hide_group(name)
        self.show_group(group)

    def add_stretch(self):
        """Give the spare column and the next free row all the stretch."""
        self.layout.setColumnStretch(self.col_max, 1)
        self.layout.setRowStretch(self.row, 1)

    def _place(self, qlabel, widget, same_row):
        """Put the pair on the previous row (``same_row``) or on a new row."""
        if same_row:
            self._add_column(qlabel, widget)
        else:
            self._add(qlabel, widget)

    def _add(self, label, widget):
        """Start a new row with ``label`` in column 0 and ``widget`` in column 1."""
        current_row = self.row
        self.layout.addWidget(label, current_row, 0)
        self.layout.addWidget(widget, current_row, 1)

        self.row = current_row + 1
        self.col = 2

    def _add_column(self, label, widget):
        """Append the pair to the most recently added row, at the next free columns."""
        target_row = self.row - 1
        first_col = self.col
        self.layout.addWidget(label, target_row, first_col)
        self.layout.addWidget(widget, target_row, first_col + 1)

        self.col = first_col + 2
        if self.col > self.col_max:
            self.col_max = self.col

class spheroid_box:
    """Small widget with two spin boxes for the axes of a spheroid.

    ``Q`` is the container widget, ``box`` its grid layout, and ``axis_xy`` /
    ``axis_z`` the two spin boxes (labelled 'axis a' and 'axis b').
    """

    def __init__(self, parent):
        """Build the two labelled spin boxes.

        Arguments:
            parent   widget passed as the initial parent of both spin boxes
        """
        self.Q = QWidget()
        self.box = QGridLayout(self.Q)

        # The spin boxes previously referenced ``particle_box``, a name that
        # is not defined in this scope (NameError on construction); the
        # ``parent`` argument is what was meant to be used.
        self.axis_xy = QDoubleSpinBox(parent, minimum=1, maximum=2000)
        self.axis_xy.setValue(150)
        self.axis_xy.setSingleStep(10)

        self.axis_z = QDoubleSpinBox(parent, minimum=1, maximum=2000)
        self.axis_z.setValue(100)
        self.axis_z.setSingleStep(10)

        self.box.addWidget(QLabel('axis a'), 0, 0)
        self.box.addWidget(self.axis_xy, 0, 1)
        self.box.addWidget(QLabel('axis b'), 1, 0)
        self.box.addWidget(self.axis_z, 1, 1)


class Example(QMainWindow):
    """Main window of the Mie calculator.

    A matplotlib canvas shows the cross-section spectra of a single particle;
    below it are controls for the particle, the source and the plot.
    ``self.mie`` holds the current solver (``mie_sphere`` or ``non_sphere``)
    and ``self.fig`` the ``plot_region`` that displays its solution.
    """

    # Note shown under the controls for each non-spherical geometry.
    _GEOMETRY_NOTES = {
        'Cylinder': "The light is Z-propagating. The cylinder's orientation is its axis of symmetry (height)",
        'Spheroid': "The light is Z-propagating. The spheroid's orientation is its axis of symmetry (axis b)",
    }

    # The two angles handed to miepy.quaternion.from_spherical_coords for
    # each entry of the orientation combo box.
    _ORIENTATION_ANGLES = {
        'X': (np.pi/2, 0),
        'Y': (np.pi/2, np.pi/2),
        'Z': (0, 0),
    }

    def __init__(self):
        """Solve the default problem, plot it, then build and connect the UI."""
        super().__init__()

        # Selectable materials, keyed by the name shown in the material combo box.
        self.materials = {'Silver': miepy.materials.Ag(),
                          'Gold': miepy.materials.Au(),
                          'Aluminum': miepy.materials.Al(),
                          'Nickel': miepy.materials.Ni(),
                          'Copper': miepy.materials.Cu(),
                          'Cobalt': miepy.materials.Co(),
                          'Platinum': miepy.materials.Pt(),
                          'TiO2': miepy.materials.TiO2(),
                          'Silica': miepy.materials.silica(),}

        self.particles = ['Sphere', 'Spheroid', 'Cylinder'] # Core-shell

        ### central widget
        self._main = QtWidgets.QWidget()
        self.setCentralWidget(self._main)

        ### main layout
        self.vbox = QtWidgets.QVBoxLayout(self._main)
        self.vbox.setSpacing(10)
        self.vbox.setContentsMargins(*[20]*4)
        self.fig = plot_region()

        # Initial problem; these values match the defaults the widgets are
        # given in initUI (150 nm diameter, Silver, 300-1000 nm, plane wave).
        self.Nwav = 300
        wavelengths = np.linspace(300*nm, 1000*nm, self.Nwav)
        source = miepy.sources.plane_wave([1, 0])
        medium = miepy.materials.water()
        self.mie = mie_sphere(75*nm, self.materials['Silver'], wavelengths, source, medium)
        sol = self.mie.solve()
        self.fig.initialize_plot(self.mie.wavelengths, sol)

        QApplication.setStyle("Fusion")

        self.initUI()
        self.initWidgets()

    @staticmethod
    def _make_spin_box(parent, value, step, minimum=1, maximum=2000):
        """Return a QDoubleSpinBox with the given range, initial value and step."""
        box = QDoubleSpinBox(parent, minimum=minimum, maximum=maximum)
        box.setValue(value)
        box.setSingleStep(step)
        return box

    def initUI(self):
        """Create every widget, assemble the window layout and show the window."""
        QShortcut(QKeySequence("Ctrl+W"), self, QApplication.instance().quit)
        QShortcut(QKeySequence("Ctrl+Q"), self, QApplication.instance().quit)

        bottom_layout = QHBoxLayout()
        bottom_layout.addWidget(self._build_particle_box())
        bottom_layout.addWidget(self._build_source_box())
        bottom_layout.addWidget(self._build_options_box())
        compute_layout = self._build_compute_layout()

        menubar = self.menuBar()
        theme_menu = menubar.addMenu('Theme')
        dark_theme = QAction('Dark theme', self, checkable=True)
        dark_theme.triggered.connect(self.set_theme)
        theme_menu.addAction(dark_theme)

        ### Put it all together
        self.addToolBar(NavigationToolbar(self.fig.canvas, self))
        self.vbox.addWidget(self.fig.canvas)
        self.vbox.addStretch(1)
        self.vbox.addLayout(bottom_layout)
        self.vbox.addLayout(compute_layout)
        self.setWindowTitle('Mie Calculator')
        QShortcut(Qt.Key_Return, self, self._compute_on_return)
        self.set_light_theme()
        self.show()

    def _build_particle_box(self):
        """Create the 'Particle properties' group box and its widgets."""
        particle_box = QGroupBox('Particle properties')

        self.particle_layout = dynamic_labels(self.particles + ['Nonsphere', 'Dielectric'])

        self.geometry_box = QComboBox(particle_box)
        self.geometry_box.addItems(self.particles)
        self.particle_layout.add_static('particle type', self.geometry_box)
        self.current_geometry = 'Sphere'

        self.sphere_diameter_box = self._make_spin_box(particle_box, 150, 10)
        self.particle_layout.add_dynamic(tab + 'diameter (nm)', self.sphere_diameter_box, group='Sphere', hide=False)

        self.cylinder_diameter_box = self._make_spin_box(particle_box, 150, 10)
        self.particle_layout.add_dynamic(tab + 'diameter (nm)', self.cylinder_diameter_box, group='Cylinder', hide=True)

        self.cylinder_height_box = self._make_spin_box(particle_box, 100, 10)
        self.particle_layout.add_dynamic(tab + 'height (nm)', self.cylinder_height_box, group='Cylinder', hide=True)

        self.spheroid_axis_a_box = self._make_spin_box(particle_box, 150, 10)
        self.particle_layout.add_dynamic(tab + 'axis 2a (nm)', self.spheroid_axis_a_box, group='Spheroid', hide=True)

        self.spheroid_axis_b_box = self._make_spin_box(particle_box, 100, 10)
        self.particle_layout.add_dynamic(tab + 'axis 2b (nm)', self.spheroid_axis_b_box, group='Spheroid', hide=True)

        self.orientation_box = QComboBox(particle_box)
        self.orientation_box.addItems(['X', 'Y', 'Z'])
        self.particle_layout.add_dynamic(tab + 'orientation', self.orientation_box, group='Nonsphere', hide=True)

        # 'Dielectric' sits at index 0, so index 1 selects the first named
        # material ('Silver'), matching the solver created in __init__.
        self.material_box = QComboBox(particle_box)
        self.material_box.addItem('Dielectric')
        self.material_box.addItems(list(self.materials))
        self.material_box.setCurrentIndex(1)
        self.particle_layout.add_static(tab + 'material', self.material_box)

        self.dielectric_box = self._make_spin_box(particle_box, 3.7, .1, minimum=.01, maximum=100)
        self.particle_layout.add_dynamic('index', self.dielectric_box, group='Dielectric', same_row=True)

        self.medium_box = self._make_spin_box(particle_box, 1.33, .1, minimum=.01, maximum=100)
        self.particle_layout.add_static('medium index', self.medium_box)

        self.particle_layout.add_stretch()

        particle_box.setLayout(self.particle_layout.layout)
        return particle_box

    def _build_source_box(self):
        """Create the 'Source properties' group box and its widgets."""
        source_box = QGroupBox('Source properties')

        self.sources_box = QComboBox(source_box)
        self.sources_box.addItems(['Plane wave', 'Gaussian beam', 'Radial beam', 'Azimuthal beam'])

        self.wavelength_min = self._make_spin_box(source_box, 300, 50)
        self.wavelength_max = self._make_spin_box(source_box, 1000, 50)
        self.beam_width_box = self._make_spin_box(None, 1000, 50)

        self.source_layout = QVBoxLayout()

        layout1 = QHBoxLayout()
        layout1.addWidget(QLabel('wavelength (nm)'))
        layout1.addWidget(self.wavelength_min)
        layout1.addWidget(QLabel('to'))
        layout1.addWidget(self.wavelength_max)
        layout1.addStretch(1)

        layout2 = QHBoxLayout()
        layout2.addWidget(QLabel('source type'))
        layout2.addWidget(self.sources_box)
        layout2.addStretch(1)

        layout3 = QHBoxLayout()
        self.beam_width_label = QLabel(tab + 'beam width (nm)')
        layout3.addWidget(self.beam_width_label)
        layout3.addWidget(self.beam_width_box)
        layout3.addStretch(1)

        # The beam width only applies to beam sources; set_source shows it.
        self.beam_width_label.hide()
        self.beam_width_box.hide()

        self.source_layout.addLayout(layout1)
        self.source_layout.addLayout(layout2)
        self.source_layout.addLayout(layout3)
        self.source_layout.addStretch(1)

        source_box.setLayout(self.source_layout)
        return source_box

    def _build_options_box(self):
        """Create the 'Plotting options' group box and its check boxes."""
        options_box = QGroupBox('Plotting options')

        self.extinction_box = QCheckBox('Extinction')
        self.extinction_box.setChecked(True)
        self.scattering_box = QCheckBox('Scattering')
        self.scattering_box.setChecked(True)
        self.absorption_box = QCheckBox('Absorption')
        self.absorption_box.setChecked(True)
        self.multipole_box = QCheckBox('Multipoles')
        self.autocompute_box = QCheckBox('Auto compute')

        layout = QGridLayout()
        layout.addWidget(self.extinction_box, 0, 0)
        layout.addWidget(self.scattering_box, 1, 0)
        layout.addWidget(self.absorption_box, 2, 0)
        layout.addWidget(self.multipole_box, 3, 0)
        layout.addWidget(self.autocompute_box, 4, 0)
        layout.setColumnStretch(1, 1)
        layout.setRowStretch(5, 1)

        options_box.setLayout(layout)
        return options_box

    def _build_compute_layout(self):
        """Create the bottom row: note label, progress bar and 'Compute' button."""
        compute_layout = QHBoxLayout()
        self.note = QLabel('')
        compute_layout.addWidget(self.note)
        compute_layout.addStretch(1)

        self.progress = QProgressBar(minimumWidth=200)
        self.progress.setRange(0, self.Nwav-1)
        self.progress.setValue(self.Nwav-1)
        compute_layout.addWidget(self.progress)

        self.compute_button = QPushButton('Compute')
        self.compute_button.setDefault(True)
        self.compute_button.clicked.connect(self.compute)
        compute_layout.addWidget(self.compute_button)

        # The button is disabled whenever auto compute is checked; checking
        # the box here (after connecting) puts the button in that state.
        self.autocompute_box.toggled.connect(self.compute_button.setDisabled)
        self.progress.hide()
        self.autocompute_box.setChecked(True)
        return compute_layout

    def _compute_on_return(self):
        """Return-key shortcut: compute, unless auto compute is already on."""
        if not self.autocompute_box.isChecked():
            self.compute()

    def set_theme(self, dark):
        """Switch to the dark theme when ``dark`` is true, else to the light one."""
        if dark:
            self.set_dark_theme()
        else:
            self.set_light_theme()

    def _apply_plot_colors(self, background, text, spine, legend_face, legend_edge):
        """Recolor the matplotlib axes/figure and redraw the plot."""
        # rc settings affect the legend that plot_region.draw() re-creates.
        mpl.rc('text', color=text)
        mpl.rc('legend', facecolor=legend_face, edgecolor=legend_edge)

        ax = self.fig.ax
        ax.set_facecolor(background)
        ax.figure.patch.set_facecolor(background)
        for side in ['top', 'bottom', 'left', 'right']:
            ax.spines[side].set_color(spine)
        ax.tick_params(axis='both', labelcolor=text)
        ax.set_xlabel(ax.get_xlabel(), color=text)
        ax.set_ylabel(ax.get_ylabel(), color=text)
        self.fig.draw()

    def set_light_theme(self):
        """Apply the style's standard palette and light plot colors."""
        QApplication.setPalette(QApplication.style().standardPalette())

        # The spine color is reset explicitly: the dark theme changes it and
        # it would otherwise stay gray after switching back.
        self._apply_plot_colors(background=[240/255]*3,
                                text='black',
                                spine=mpl.rcParams['axes.edgecolor'],
                                legend_face='white',
                                legend_edge=[0.8]*3)

    def set_dark_theme(self):
        """Apply a dark Qt palette and matching plot colors."""
        dark_palette = QPalette()
        dark_palette.setColor(QPalette.Window, QColor(53, 53, 53))
        dark_palette.setColor(QPalette.WindowText, QColor(230, 230, 230))
        dark_palette.setColor(QPalette.Base, QColor(25, 25, 25))
        dark_palette.setColor(QPalette.AlternateBase, QColor(53, 53, 53))
        dark_palette.setColor(QPalette.ToolTipBase, QColor(53, 53, 53))
        dark_palette.setColor(QPalette.ToolTipText, QColor(180, 180, 180))
        dark_palette.setColor(QPalette.Text, QColor(230, 230, 230))
        dark_palette.setColor(QPalette.Button, QColor(53, 53, 53))
        dark_palette.setColor(QPalette.ButtonText, QColor(230, 230, 230))
        dark_palette.setColor(QPalette.BrightText, Qt.red)
        dark_palette.setColor(QPalette.Link, QColor(42, 130, 218))
        dark_palette.setColor(QPalette.Highlight, QColor(42, 130, 218))
        dark_palette.setColor(QPalette.HighlightedText, Qt.black)

        # Dimmer colors for disabled widgets.
        d = 100
        dark_palette.setColor(QPalette.Disabled, QPalette.WindowText, QColor(d, d, d))
        dark_palette.setColor(QPalette.Disabled, QPalette.Text, QColor(d, d, d))
        dark_palette.setColor(QPalette.Disabled, QPalette.ButtonText, QColor(d, d, d))
        dark_palette.setColor(QPalette.Disabled, QPalette.Highlight, QColor(80, 80, 80))
        dark_palette.setColor(QPalette.Disabled, QPalette.HighlightedText, QColor(d, d, d))
        QApplication.setPalette(dark_palette)

        self._apply_plot_colors(background=[53/255]*3,
                                text=[230/255]*3,
                                spine=[100/255]*3,
                                legend_face=[53/255]*3,
                                legend_edge=[100/255]*3)

    def initWidgets(self):
        """Connect the widget signals to the plot and to the solver setters."""
        self.scattering_box.toggled.connect(self.fig.set_scattering_visible)
        self.absorption_box.toggled.connect(self.fig.set_absorption_visible)
        self.extinction_box.toggled.connect(self.fig.set_extinction_visible)
        self.multipole_box.toggled.connect(partial(self.fig.set_multipoles_visible,
                                  boxes=[self.absorption_box, self.extinction_box]))

        self.sphere_diameter_box.valueChanged.connect(self.set_diameter)
        self.material_box.currentTextChanged.connect(self.set_material)
        self.dielectric_box.valueChanged.connect(self.set_dielectric_index)
        self.medium_box.valueChanged.connect(self.set_medium)

        # Without these connections, edits to the cylinder/spheroid sizes
        # were ignored until the geometry was selected again.
        for box in (self.cylinder_diameter_box, self.cylinder_height_box,
                    self.spheroid_axis_a_box, self.spheroid_axis_b_box):
            box.valueChanged.connect(self.set_particle_dimensions)

        self.wavelength_min.valueChanged.connect(self.set_min_wavelength)
        self.wavelength_max.valueChanged.connect(self.set_max_wavelength)

        self.sources_box.currentTextChanged.connect(self.set_source)
        self.beam_width_box.valueChanged.connect(self.set_beam_width)

        self.geometry_box.currentTextChanged.connect(self.set_geometry)
        self.orientation_box.currentTextChanged.connect(self.set_orientation)

    def _compute_if_auto(self):
        """Recompute the spectra when the 'Auto compute' box is checked."""
        if self.autocompute_box.isChecked():
            self.compute()

    def set_diameter(self, diameter):
        """Set the sphere radius from a diameter in nm."""
        self.mie.radius = diameter/2*nm
        self._compute_if_auto()

    def set_dielectric_index(self, index):
        """Use a constant material whose permittivity is ``index**2``."""
        self.mie.material = miepy.constant_material(index**2)
        self._compute_if_auto()

    def set_material(self, name):
        """Select a named material, or the constant 'Dielectric' one.

        The refractive-index box is only shown while 'Dielectric' is selected.
        """
        if name != 'Dielectric':
            self.mie.material = self.materials[name]
            self.particle_layout.hide_group('Dielectric')
        else:
            self.particle_layout.show_group('Dielectric')
            eps = self.dielectric_box.value()**2
            self.mie.material = miepy.constant_material(eps)

        self._compute_if_auto()

    def set_medium(self, index):
        """Set the medium to a constant material with permittivity ``index**2``."""
        self.mie.medium = miepy.constant_material(index**2)
        self._compute_if_auto()

    def set_min_wavelength(self, wmin):
        """Re-sample the wavelengths from ``wmin`` (nm) to the current maximum."""
        # Written in place so every holder of the array sees the new values.
        self.mie.wavelengths[...] = np.linspace(wmin*nm, self.mie.wavelengths[-1], self.Nwav)
        self._compute_if_auto()

    def set_max_wavelength(self, wmax):
        """Re-sample the wavelengths from the current minimum to ``wmax`` (nm)."""
        self.mie.wavelengths[...] = np.linspace(self.mie.wavelengths[0], wmax*nm, self.Nwav)
        self._compute_if_auto()

    def set_beam_width(self, width):
        """Set the width of the current source from a value in nm."""
        self.mie.source.width = width*nm
        self._compute_if_auto()

    def _orientation_from_name(self, name):
        """Return the orientation quaternion for 'X', 'Y' or 'Z'.

        Raises KeyError for any other name.
        """
        return miepy.quaternion.from_spherical_coords(*self._ORIENTATION_ANGLES[name])

    def _build_particle(self, name):
        """Create the miepy particle for the non-spherical geometry ``name``.

        Sizes, material and orientation are read from the current widget and
        solver state. Raises ValueError for an unknown geometry name.
        """
        orientation = self._orientation_from_name(self.orientation_box.currentText())
        material = self.mie.material

        if name == 'Cylinder':
            height = self.cylinder_height_box.value()*nm
            diameter = self.cylinder_diameter_box.value()*nm
            return miepy.cylinder(radius=diameter/2, height=height, material=material,
                                  position=[0,0,0], orientation=orientation)

        if name == 'Spheroid':
            # NOTE(review): the labels read 'axis 2a' / 'axis 2b' (full axis
            # lengths) but the values are passed on unhalved, unlike the
            # cylinder diameter above -- confirm whether miepy.spheroid
            # expects semi-axes here.
            axis_xy = self.spheroid_axis_a_box.value()*nm
            axis_z = self.spheroid_axis_b_box.value()*nm
            return miepy.spheroid(axis_xy=axis_xy, axis_z=axis_z, material=material,
                                  position=[0,0,0], orientation=orientation)

        raise ValueError('unknown non-spherical geometry: {!r}'.format(name))

    def _make_non_sphere_solver(self, name):
        """Return a ``non_sphere`` solver for ``name`` reusing the current
        wavelengths, source and medium."""
        return non_sphere(self._build_particle(name), self.mie.wavelengths,
                          self.mie.source, self.mie.medium)

    def set_particle_dimensions(self, _value=None):
        """Rebuild the cylinder/spheroid after one of its size boxes changed.

        Does nothing while the sphere geometry is selected.
        """
        if self.current_geometry not in self._GEOMETRY_NOTES:
            return

        self.mie = self._make_non_sphere_solver(self.current_geometry)
        self._compute_if_auto()

    def set_geometry(self, name):
        """Switch the particle geometry and replace the solver accordingly.

        Shows the size widgets of the new geometry, hides those of the old
        one, and disables auto compute for non-spherical particles.
        """
        self.particle_layout.hide_group(self.current_geometry)
        self.particle_layout.show_group(name)
        self.current_geometry = name

        if name == 'Sphere':
            if self.should_autocompute():
                self.enable_autocompute()
            self.particle_layout.hide_group('Nonsphere')
            self.note.setText("")
            radius = self.sphere_diameter_box.value()*nm/2
            self.mie = mie_sphere(radius, self.mie.material, self.mie.wavelengths, self.mie.source, self.mie.medium)
        else:
            self.disable_autocompute()
            self.particle_layout.show_group('Nonsphere')
            self.note.setText(self._GEOMETRY_NOTES[name])
            # The particle is created directly with the orientation that is
            # currently selected in the orientation box.
            self.mie = self._make_non_sphere_solver(name)

        self._compute_if_auto()

    def set_orientation(self, name):
        """Orient the current non-spherical particle along 'X', 'Y' or 'Z'."""
        self.mie.particle.orientation = self._orientation_from_name(name)

    def set_source(self, name):
        """Replace the source by the one called ``name``.

        The beam-width widgets are shown, and auto compute disabled, for
        every source except 'Plane wave'.
        """
        if name != 'Plane wave':
            self.beam_width_label.show()
            self.beam_width_box.show()
            self.disable_autocompute()
        else:
            self.beam_width_label.hide()
            self.beam_width_box.hide()
            if self.should_autocompute():
                self.enable_autocompute()

        width = self.beam_width_box.value()

        if name == 'Plane wave':
            source = miepy.sources.plane_wave([1, 0])
        elif name == 'Gaussian beam':
            source = miepy.sources.gaussian_beam(polarization=[1, 0], width=width*nm, amplitude=1)
        elif name == 'Azimuthal beam':
            source = miepy.sources.azimuthal_beam(width=width*nm, amplitude=1)
        elif name == 'Radial beam':
            source = miepy.sources.radial_beam(width=width*nm, amplitude=1)

        self.mie.source = source

        self._compute_if_auto()

    def compute(self):
        """Solve the current problem and update the plot.

        The progress bar is shown only for manual computes (auto compute
        unchecked) and is always hidden again afterwards.
        """
        if not self.autocompute_box.isChecked():
            self.progress.setValue(0)
            self.progress.show()
            progress = self.progress
        else:
            progress = None

        try:
            sol = self.mie.solve(progress=progress)
            self.fig.update_xdata(self.mie.wavelengths, draw=False)
            self.fig.update_ydata(sol)
        finally:
            # Hide the bar even if the solver raised, so it is not left
            # stuck on screen.
            self.progress.hide()

    def should_autocompute(self):
        """Return True when both a plane wave and a sphere are selected."""
        return ((self.sources_box.currentText() == 'Plane wave') and (self.geometry_box.currentText() == 'Sphere'))

    def disable_autocompute(self):
        """Uncheck the 'Auto compute' box and make it non-editable."""
        self.autocompute_box.setChecked(False)
        self.autocompute_box.setDisabled(True)

    def enable_autocompute(self):
        """Check the 'Auto compute' box and make it editable again."""
        self.autocompute_box.setChecked(True)
        self.autocompute_box.setDisabled(False)


def get_multipole_label(i, j):
    """Return the short legend label of a multipole.

    The label starts with 'e' when ``i`` is 0 and 'm' otherwise; it is
    followed by 'D', 'Q', 'O' or 'H' for ``j`` below 4, or by ``j + 1``
    written as a number for larger ``j``.
    """
    prefix = 'e' if i == 0 else 'm'
    order_names = ['D', 'Q', 'O', 'H']
    suffix = order_names[j] if j < 4 else str(j+1)
    return prefix + suffix

class plot_region:
    """Matplotlib canvas showing the cross-section spectra.

    ``lines`` maps a curve name ('scattering', 'absorption', 'extinction' or
    a multipole label) to its Line2D; ``canvas`` and ``ax`` are the Qt canvas
    and the single axes everything is drawn on.
    """

    def __init__(self):
        """Create the canvas and an empty, labelled set of axes."""
        figure = Figure(figsize=(8, 6), constrained_layout=True)
        self.canvas = FigureCanvas(figure)
        self.canvas.setMinimumHeight(200)
        self.ax = self.canvas.figure.subplots()
        self.ax.set_xlabel('wavelength (nm)')
        self.ax.set_ylabel(r'cross-section ($\mu m^2$)')
        self.ax.grid(lw=.2, color='gray')
        self.lines = {}

    def update_xdata(self, wavelengths, draw=True):
        """Give every line the new x values (``wavelengths/nm``)."""
        for curve in self.lines.values():
            curve.set_xdata(wavelengths/nm)

        if not draw:
            return
        self.draw()

    def update_ydata(self, scattering, draw=True):
        """Update the y values of all lines from a solution dictionary.

        ``scattering`` maps curve names to arrays; the 'multipoles' entry is
        indexed as ``[:, i, j]`` for the line labelled by
        ``get_multipole_label(i, j)``. All values are multiplied by 1e12.
        """
        for name, values in scattering.items():
            if name == 'multipoles':
                continue
            self.lines[name].set_ydata(values*1e12)

        multipoles = scattering['multipoles']
        n_types, n_orders = multipoles.shape[1], multipoles.shape[2]
        for i in range(n_types):
            for j in range(n_orders):
                curve = self.lines[get_multipole_label(i,j)]
                curve.set_ydata(multipoles[:,i,j]*1e12)

        if not draw:
            return
        self.draw()

    def initialize_plot(self, wavelengths, scattering):
        """Create one line per curve of ``scattering`` and draw the canvas.

        Multipole lines are created invisible and take their colors from the
        cycle 'C1' ... 'C8'.
        """
        for name, values in scattering.items():
            if name == 'multipoles':
                continue
            (curve,) = self.ax.plot(wavelengths/nm, values*1e12, label=name)
            self.lines[name] = curve
        self.legend = self.ax.legend()

        multipoles = scattering['multipoles']
        palette = cycle([f'C{i}' for i in range(1,9)])
        n_types, n_orders = multipoles.shape[1], multipoles.shape[2]
        for i in range(n_types):
            for j in range(n_orders):
                label = get_multipole_label(i,j)
                (curve,) = self.ax.plot(wavelengths/nm, multipoles[:,i,j]*1e12,
                                        label=label, visible=False, color=next(palette))
                self.lines[label] = curve

        self.ax.set_ylim(bottom=0)
        self.canvas.draw()

    def _set_line_visible(self, name, value, draw):
        """Set the visibility of the line called ``name``; redraw if asked."""
        self.lines[name].set_visible(value)
        if draw:
            self.draw()

    def set_scattering_visible(self, value, draw=True):
        """Show or hide the scattering curve."""
        self._set_line_visible('scattering', value, draw)

    def set_absorption_visible(self, value, draw=True):
        """Show or hide the absorption curve."""
        self._set_line_visible('absorption', value, draw)

    def set_extinction_visible(self, value, draw=True):
        """Show or hide the extinction curve."""
        self._set_line_visible('extinction', value, draw)

    def set_multipoles_visible(self, value, boxes):
        """Show or hide all multipole curves.

        While they are shown, absorption and extinction are hidden and their
        check boxes (``boxes[0]`` and ``boxes[1]``) disabled; when hidden
        again, the boxes are re-enabled and each checked curve is restored.
        """
        totals = ('scattering', 'absorption', 'extinction')
        for name, curve in self.lines.items():
            if name in totals:
                continue
            curve.set_visible(value)

        if value:
            self.set_absorption_visible(False, False)
            self.set_extinction_visible(False, False)
            for box in boxes:
                box.setDisabled(True)
        else:
            for box in boxes:
                box.setDisabled(False)
            absorption_checkbox, extinction_checkbox = boxes[0], boxes[1]
            if absorption_checkbox.isChecked():
                self.set_absorption_visible(True, False)
            if extinction_checkbox.isChecked():
                self.set_extinction_visible(True, False)

        self.draw()

    def draw(self):
        """Rebuild the legend from the visible lines, rescale and redraw."""
        shown = [(name, curve) for name, curve in self.lines.items() if curve.get_visible()]
        if shown:
            labels, handles = zip(*shown)
            self.ax.legend(handles, labels)
        else:
            # Nothing is visible: create a legend and hide it.
            self.ax.legend().set_visible(False)

        self.ax.relim()
        self.ax.autoscale()
        self.ax.set_ylim(bottom=0)
        self.canvas.draw()

class mie_sphere:
    """Cross-section solver for a single sphere over a set of wavelengths."""

    def __init__(self, radius, material, wavelengths, source, medium=None):
        """Store the sphere radius, its material, the wavelength array, the
        source and the (optional) medium."""
        self.radius, self.material = radius, material
        self.wavelengths = wavelengths
        self.source, self.medium = source, medium

    def _material_covers(self, wavelength):
        """Return False only when the material is a ``data_material`` whose
        wavelength data does not include ``wavelength``."""
        if type(self.material) is not miepy.materials.create.data_material:
            return True

        data_range = self.material.wavelength
        return not (wavelength < data_range[0] or wavelength > data_range[-1])

    def solve(self, progress=None):
        """Compute the cross-sections at every wavelength.

        Arguments:
            progress   optional object whose ``setValue(i)`` is called after
                       each computed wavelength index

        Returns a dict with 'scattering', 'absorption' and 'extinction'
        arrays shaped like the wavelengths, and 'multipoles' with two extra
        axes of length 2. Wavelengths the material data does not cover are
        skipped and keep NaN.
        """
        shape = self.wavelengths.shape
        scat = np.full(shape, np.nan, dtype=float)
        absorb = np.full(shape, np.nan, dtype=float)
        extinct = np.full(shape, np.nan, dtype=float)
        multipoles = np.full(shape + (2, 2), np.nan, dtype=float)

        for idx, wavelength in enumerate(self.wavelengths):
            if not self._material_covers(wavelength):
                continue

            cluster = miepy.sphere_cluster(position=[0,0,0],
                                           radius=self.radius,
                                           material=self.material,
                                           source=self.source,
                                           wavelength=wavelength,
                                           lmax=2,
                                           medium=self.medium)

            scat[idx], absorb[idx], extinct[idx] = cluster.cross_sections()
            multipoles[idx] = cluster.cross_sections_per_multipole().scattering

            if progress is not None:
                progress.setValue(idx)

        return {'scattering': scat, 'absorption': absorb,
                'extinction': extinct, 'multipoles': multipoles}

class non_sphere:
    """Cross-section solver for a single arbitrary particle over a set of
    wavelengths."""

    def __init__(self, particle, wavelengths, source, medium=None):
        """Store the particle, the wavelength array, the source and the
        (optional) medium; ``material`` starts as the particle's material."""
        self.particle, self.wavelengths = particle, wavelengths
        self.source, self.medium = source, medium
        self.material = particle.material

    def _material_covers(self, wavelength):
        """Return False only when the material is a ``data_material`` whose
        wavelength data does not include ``wavelength``."""
        if type(self.material) is not miepy.materials.create.data_material:
            return True

        data_range = self.material.wavelength
        return not (wavelength < data_range[0] or wavelength > data_range[-1])

    def solve(self, progress=None):
        """Compute the cross-sections at every wavelength.

        Arguments:
            progress   optional object whose ``setValue(i)`` is called after
                       each computed wavelength index

        Returns a dict with 'scattering', 'absorption' and 'extinction'
        arrays shaped like the wavelengths, and 'multipoles' with two extra
        axes of length 2. Wavelengths the material data does not cover are
        skipped and keep NaN.
        """
        # ``material`` may have been reassigned since construction; push it
        # onto the particle before solving.
        self.particle.material = self.material

        shape = self.wavelengths.shape
        scat = np.full(shape, np.nan, dtype=float)
        absorb = np.full(shape, np.nan, dtype=float)
        extinct = np.full(shape, np.nan, dtype=float)
        multipoles = np.full(shape + (2, 2), np.nan, dtype=float)

        for idx, wavelength in enumerate(self.wavelengths):
            if not self._material_covers(wavelength):
                continue

            cluster = miepy.cluster(particles=self.particle,
                                    source=self.source,
                                    wavelength=wavelength,
                                    lmax=2,
                                    medium=self.medium)

            scat[idx], absorb[idx], extinct[idx] = cluster.cross_sections()
            multipoles[idx] = cluster.cross_sections_per_multipole().scattering

            if progress is not None:
                progress.setValue(idx)

        return {'scattering': scat, 'absorption': absorb,
                'extinction': extinct, 'multipoles': multipoles}
        
if __name__ == '__main__':
    # Create the Qt application and the main window (held in a variable for
    # the lifetime of the event loop), then exit with the loop's return code.
    application = QApplication(sys.argv)
    window = Example()
    exit_code = application.exec_()
    sys.exit(exit_code)
