"""
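This module defines two ``QWidget``-based napari widgets:
``IntensityPlotControlWidget`` (settings and action buttons) and
``IntensityPlotWidget`` (a PyQtGraph plot of the mean intensity of a
square pixel window around the clicked point, for every slice).

It was generated from the napari plugin template, which describes four
ways of declaring napari widgets: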

- a pure Python function flagged with `autogenerate: true`
    in the plugin manifest. Type annotations are used by
    magicgui to generate widgets for each parameter. Best
    suited for simple processing tasks - usually taking
    in and/or returning a layer.
- a `magic_factory` decorated function. The `magic_factory`
    decorator allows us to customize aspects of the resulting
    GUI, including the widgets associated with each parameter.
    Best used when you have a very simple processing task,
    but want some control over the autogenerated widgets. If you
    find yourself needing to define lots of nested functions to achieve
    your functionality, maybe look at the `Container` widget!
- a `magicgui.widgets.Container` subclass. This provides lots
    of flexibility and customization options while still supporting
    `magicgui` widgets and convenience methods for creating widgets
    from type annotations. If you want to customize your widgets and
    connect callbacks, this is the best widget option for you.
- a `QWidget` subclass. This provides maximal flexibility but requires
    full specification of widget layouts, callbacks, events, etc.

References:
- Widget specification: https://napari.org/stable/plugins/guides.html?#widgets
- magicgui docs: https://pyapp-kit.github.io/magicgui/

Replace code below according to your needs.
"""
import napari
import numpy as np
import pyqtgraph as pg
from pyqtgraph.exporters import ImageExporter  # 追加
from qtpy.QtWidgets import QWidget, QVBoxLayout, QSpinBox, QLabel, QHBoxLayout, QLineEdit, QFileDialog, QCheckBox, QPushButton, QFrame
from napari.utils.notifications import show_info
import os
import csv

class IntensityPlotControlWidget(QWidget):
    """Settings panel that drives every docked ``IntensityPlotWidget``.

    The panel owns the averaging-square size, the save directory and the
    CSV/PNG toggles. Each change is forwarded to all ``IntensityPlotWidget``
    instances currently docked in the viewer window. The buttons invoke
    save / hide-all-layers / auto-scale on those same widgets.

    NOTE(review): settings are only pushed when they change, so an
    ``IntensityPlotWidget`` docked after this panel keeps its own defaults
    until the next change made here.
    """

    def __init__(self, viewer: napari.viewer.Viewer):
        super().__init__()
        self.viewer = viewer
        self.square_size = 3
        self.save_csv = True
        self.save_png = False

        layout = QVBoxLayout()
        self.setLayout(layout)

        # Square size: side length (in pixels) of the averaging window.
        self.add_separator(layout)
        layout.addWidget(QLabel("Square Size"))
        self.square_spinbox = QSpinBox()
        self.square_spinbox.setRange(1, 511)
        self.square_spinbox.setValue(self.square_size)
        # A step of 2 starting from an odd value keeps the arrows on odd sizes.
        self.square_spinbox.setSingleStep(2)
        # Only emit valueChanged once typing is finished. With keyboard
        # tracking on, typing "21" emitted 2 first, which was immediately
        # "corrected" to 1 and turned the final entry into 11.
        self.square_spinbox.setKeyboardTracking(False)
        self.square_spinbox.valueChanged.connect(self.update_square_size)
        layout.addWidget(self.square_spinbox)

        # Save directory: free-text field plus a directory picker.
        self.add_separator(layout)
        self.save_path_input = QLineEdit()
        # Default directory shown until the user picks another one.
        self.save_path_input.setText(os.path.expanduser("~/Desktop"))
        # Forward manually typed paths too, not only ones chosen in the dialog.
        self.save_path_input.editingFinished.connect(self.update_save_directory)
        layout.addWidget(QLabel("Save Directory:"))
        layout.addWidget(self.save_path_input)

        select_dir_button = QPushButton("Select Directory")
        select_dir_button.clicked.connect(self.select_directory)
        layout.addWidget(select_dir_button)

        # Output format toggles (CSV on, PNG off by default).
        self.csv_checkbox = QCheckBox("Save as CSV")
        self.csv_checkbox.setChecked(self.save_csv)
        self.csv_checkbox.stateChanged.connect(self.update_save_csv)
        layout.addWidget(self.csv_checkbox)

        self.png_checkbox = QCheckBox("Save as PNG")
        self.png_checkbox.setChecked(self.save_png)
        self.png_checkbox.stateChanged.connect(self.update_save_png)
        layout.addWidget(self.png_checkbox)

        # Action buttons, each labelled with the viewer shortcut that the
        # IntensityPlotWidget binds for the same action.
        self._add_button_row(layout, "Save to CSV/PNG", self.save_to_csv, "Ctrl+S")
        self.add_separator(layout)
        self._add_button_row(layout, "Hide All Layers", self.hide_all_layers, "Ctrl+D")
        self._add_button_row(layout, "Auto Scale", self.focus_on_visible_layer, "Ctrl+A")
        self.add_separator(layout)

    def add_separator(self, layout):
        """Append a 2-px high horizontal separator line to *layout*."""
        separator = QWidget()
        separator.setFixedHeight(2)
        separator.setStyleSheet("background-color: #414851;")
        layout.addWidget(separator)

    def _add_button_row(self, layout, text, slot, shortcut):
        """Append a row holding a push button and a label showing *shortcut*."""
        row = QHBoxLayout()
        button = QPushButton(text)
        button.clicked.connect(slot)
        row.addWidget(button)
        row.addWidget(QLabel(shortcut))
        layout.addLayout(row)

    def _plot_widgets(self):
        """Yield every ``IntensityPlotWidget`` docked in the viewer window."""
        # NOTE: ``_dock_widgets`` is private napari API and may change.
        # A snapshot is iterated in case docks change while we loop.
        for dock in list(self.viewer.window._dock_widgets.values()):
            inner = dock.widget()
            if isinstance(inner, IntensityPlotWidget):
                yield inner

    def select_directory(self):
        """Let the user pick a save directory and forward it to the plots."""
        dir_path = QFileDialog.getExistingDirectory(self, "Select Directory")
        if dir_path:
            self.save_path_input.setText(dir_path)
            self.update_save_directory()

    def update_square_size(self, value):
        """Store the square size and forward it to every plot widget.

        Even values are replaced by the next smaller odd value. The spin box
        is updated with signals blocked, so the corrected value is
        propagated once rather than twice through a re-entrant call.
        """
        if value % 2 == 0:
            value -= 1
            was_blocked = self.square_spinbox.blockSignals(True)
            self.square_spinbox.setValue(value)
            self.square_spinbox.blockSignals(was_blocked)
            show_info("Square size must be odd. Adjusted to the nearest odd number.")

        self.square_size = value
        for plot_widget in self._plot_widgets():
            plot_widget.update_square(value)

    def update_save_directory(self):
        """Forward the directory currently in the text field to every plot."""
        save_dir = self.get_save_directory()
        for plot_widget in self._plot_widgets():
            plot_widget.update_save_directory(save_dir)

    def get_save_directory(self):
        """Return the save directory text as entered (not validated here)."""
        return self.save_path_input.text()

    def update_save_csv(self, state):
        """Store the CSV toggle and forward it to every plot widget.

        *state* is the ``stateChanged`` int; 2 is ``Qt.CheckState.Checked``.
        """
        self.save_csv = state == 2
        for plot_widget in self._plot_widgets():
            plot_widget.update_save_csv(self.save_csv)

    def update_save_png(self, state):
        """Store the PNG toggle and forward it to every plot widget.

        *state* is the ``stateChanged`` int; 2 is ``Qt.CheckState.Checked``.
        """
        self.save_png = state == 2
        for plot_widget in self._plot_widgets():
            plot_widget.update_save_png(self.save_png)

    def hide_all_layers(self):
        """Ask every plot widget to hide all viewer layers."""
        for plot_widget in self._plot_widgets():
            plot_widget.hide_all_layers()

    def focus_on_visible_layer(self):
        """Ask every plot widget to zoom onto the visible image layer."""
        for plot_widget in self._plot_widgets():
            plot_widget.focus_on_visible_layer()

    def save_to_csv(self):
        """Ask every plot widget to save its profile as CSV and/or PNG."""
        for plot_widget in self._plot_widgets():
            plot_widget.save_to_csv()
                
class IntensityPlotWidget(QWidget):
    """Plot the mean intensity of a square pixel window across slices.

    On every mouse press in the viewer:

    - a lime rectangle is drawn on the ``'Clicked Point'`` shapes layer;
    - if exactly one image layer is visible, the mean intensity of a
      ``square`` x ``square`` window around the pixel under the cursor is
      plotted for every index of that layer's first axis ("Slice").

    The last profile can be written to CSV and/or PNG.

    Viewer key bindings: Ctrl+S saves, Ctrl+D hides all layers, and Ctrl+A
    zooms onto the visible image layer.

    NOTE(review): the mouse callback, the layer-inserted handler and the key
    bindings are never disconnected, so they outlive a closed widget.
    """

    def __init__(self, viewer: napari.viewer.Viewer):
        super().__init__()
        self.viewer = viewer
        self.square = 3  # side length (pixels) of the averaging window
        self.intensity_data = None  # last plotted profile, one value per slice
        self.layer_name = None  # name of the layer the profile came from
        self.clicked_coords = (0, 0)  # (x, y) data pixel of the last profile
        self.save_directory = os.path.expanduser("~/Desktop")
        self.save_csv = True
        self.save_png = False

        self._build_plot()
        layout = QVBoxLayout()
        self.setLayout(layout)
        layout.addWidget(self.plot)

        # on_click draws the rectangle AND refreshes the plot, so it is the
        # only mouse callback needed. Also registering update_plot made each
        # click compute the profile (and show any warning) twice. The
        # membership test only stops this same instance registering twice.
        if self.on_click not in self.viewer.mouse_drag_callbacks:
            self.viewer.mouse_drag_callbacks.append(self.on_click)

        # Recreate the 'Clicked Point' layer whenever an image layer is added
        # (presumably so the rectangle is drawn above the new image).
        self.viewer.layers.events.inserted.connect(self.on_new_layer)
        self.create_clicked_point_layer()

        # overwrite=True: napari's bind_key raises ValueError when the key is
        # already bound, e.g. by a previous instance of this widget.
        self.viewer.bind_key('Ctrl+S', self.save_to_csv, overwrite=True)
        self.viewer.bind_key('Ctrl+D', self.hide_all_layers, overwrite=True)
        self.viewer.bind_key('Ctrl+A', self.focus_on_visible_layer, overwrite=True)

    # ------------------------------------------------------------------ setup

    def _build_plot(self):
        """Create ``self.plot``: white axes and labels on a transparent background."""
        self.plot = pg.PlotWidget()
        self.plot.setBackground("#FFFFFF00")
        self.plot.showAxis('top')
        self.plot.showAxis('right')
        for side in ('left', 'bottom', 'top', 'right'):
            self.plot.getAxis(side).setPen(pg.mkPen(color='w', width=2))
        font = pg.QtGui.QFont('Arial', 18)
        for side in ('bottom', 'left'):
            self.plot.getAxis(side).setTickFont(font)
            self.plot.getAxis(side).setTextPen('w')
        # Top/right axes only frame the plot: make their tick text transparent.
        for side in ('top', 'right'):
            self.plot.getAxis(side).setTextPen("#FFFFFF00")
        self._set_axis_labels('white')

    def _set_axis_labels(self, color):
        """(Re)apply the 'Slice' and 'Intensity' axis labels in *color*."""
        style = {'font-size': '18pt'}
        self.plot.getAxis('bottom').setLabel('Slice', color=color, font='Arial', **style)
        self.plot.getAxis('left').setLabel('Intensity', color=color, font='Arial', **style)

    # --------------------------------------------------------------- setters

    def update_square(self, new_size):
        """Set the averaging-window side length (expected to be odd)."""
        self.square = new_size

    def update_save_directory(self, directory):
        """Set the directory that CSV/PNG files are written to."""
        self.save_directory = directory

    def update_save_csv(self, enabled):
        """Enable or disable CSV output."""
        self.save_csv = enabled

    def update_save_png(self, enabled):
        """Enable or disable PNG output."""
        self.save_png = enabled

    # --------------------------------------------------- clicked-point layer

    def create_clicked_point_layer(self):
        """Remove any 'Clicked Point' layer and add a fresh, hidden one."""
        if 'Clicked Point' in self.viewer.layers:
            self.viewer.layers.remove('Clicked Point')

        self.shapes_layer = self.viewer.add_shapes(name='Clicked Point', visible=False)

    def _ensure_clicked_point_layer(self):
        """Make ``self.shapes_layer`` refer to a layer that is in the viewer.

        Reuses an existing Shapes layer named 'Clicked Point' (e.g. one made
        by another instance) instead of recreating it. Otherwise a new layer
        is created, for example after the user deleted ours.
        """
        if self.shapes_layer in self.viewer.layers:
            return
        if 'Clicked Point' in self.viewer.layers:
            existing = self.viewer.layers['Clicked Point']
            if isinstance(existing, napari.layers.Shapes):
                self.shapes_layer = existing
                return
        self.create_clicked_point_layer()

    def on_new_layer(self, event):
        """Recreate the 'Clicked Point' layer when an image layer is inserted."""
        new_layer = event.value
        if isinstance(new_layer, napari.layers.Image):
            self.create_clicked_point_layer()

    # ------------------------------------------------------------- plotting

    def _visible_image_layers(self):
        """Return the image layers that are currently visible."""
        return [
            layer for layer in self.viewer.layers
            if isinstance(layer, napari.layers.Image) and layer.visible
        ]

    @staticmethod
    def _window_mean(data, y, x, square):
        """Mean of a ``square``-wide window centred on (y, x) for every slice.

        The window is clipped to the image, so points near an edge average
        only the pixels that exist. (Previously, negative indices silently
        wrapped to the opposite edge, and indices past the edge raised
        IndexError.) For odd *square* the window is square x square pixels.

        Returns a float64 array of length ``data.shape[0]``, or ``None`` when
        the window lies entirely outside the image.
        """
        half = square // 2
        n_y, n_x = data.shape[-2:]
        y0, y1 = max(y - half, 0), min(y + half + 1, n_y)
        x0, x1 = max(x - half, 0), min(x + half + 1, n_x)
        if y0 >= y1 or x0 >= x1:
            return None
        # np.asarray also materialises lazy arrays (e.g. dask) for the window.
        window = np.asarray(data[:, y0:y1, x0:x1], dtype=np.float64)
        return window.mean(axis=(1, 2))

    def update_plot(self, viewer, event):
        """Plot the window-mean intensity profile at the mouse position.

        Requires exactly one visible 3D (slice, y, x) image layer. Otherwise
        an info message is shown and the current plot is kept.
        """
        visible_image_layers = self._visible_image_layers()
        if not visible_image_layers:
            show_info("No image layer is visible.")
            return
        if len(visible_image_layers) > 1:
            show_info("More than one image layer is visible.")
            return

        layer = visible_image_layers[0]
        # For multiscale layers, sample the full-resolution level.
        data = layer.data[0] if layer.multiscale else layer.data
        if data.ndim != 3:
            show_info("The visible image layer must be 3D (slice, y, x).")
            return

        # cursor.position is in world coordinates. Convert it so that scaled
        # or translated layers are sampled at the pixel under the cursor.
        data_coords = layer.world_to_data(self.viewer.cursor.position)
        y, x = (int(round(c)) for c in data_coords[-2:])

        intensity = self._window_mean(data, y, x, self.square)
        if intensity is None:
            show_info("Clicked point is outside the image.")
            return

        # Only record state once a profile was actually computed.
        self.layer_name = layer.name
        self.clicked_coords = (x, y)
        self.intensity_data = intensity
        self.plot.clear()
        self.plot.plot(intensity, pen=pg.mkPen(color='w', width=1.5), antialias=True)

    def create_rectangle(self, center):
        """Return the 4 (y, x) corners of the rectangle drawn around *center*.

        Only the last two coordinates of *center* are used. The rectangle
        extends an extra half pixel beyond the ``square`` window on every
        side. NOTE(review): this presumably keeps the outline from covering
        the averaged pixels; confirm before changing.
        """
        y, x = (round(coord) for coord in center[-2:])

        top_left = [y - self.square / 2 - 0.5, x - self.square / 2 - 0.5]
        bottom_right = [y + self.square / 2 + 0.5, x + self.square / 2 + 0.5]

        return [top_left, [top_left[0], bottom_right[1]], bottom_right, [bottom_right[0], top_left[1]]]

    def on_click(self, viewer, event):
        """On mouse press, redraw the rectangle at the cursor and update the plot."""
        if event.type != 'mouse_press':
            return

        self._ensure_clicked_point_layer()
        self.shapes_layer.visible = False
        self.shapes_layer.data = []

        rectangle = self.create_rectangle(self.viewer.cursor.position)
        self.shapes_layer.add_rectangles([rectangle],
                                         edge_color='lime', face_color='transparent')
        self.shapes_layer.visible = True

        self.update_plot(viewer, event)

    # --------------------------------------------------------------- saving

    def save_to_csv(self, event=None):
        """Save the last profile as CSV and/or PNG, depending on the toggles.

        Files are named ``<layer>_y<y>_x<x>.csv`` / ``.png`` inside
        ``save_directory``. *event* is unused; it receives the viewer when
        this is called through a key binding.
        """
        if self.intensity_data is None:
            show_info("No data to save.")
            return

        if not self.layer_name:
            show_info("Layer name not available.")
            return

        if not os.path.isdir(self.save_directory):
            show_info("Invalid save directory. Please set a valid directory in the square control widget.")
            return

        x, y = self.clicked_coords
        # Layer names may contain path separators or characters that are
        # invalid in file names; replace them so the path stays in the dir.
        safe_name = ''.join('_' if ch in '\\/:*?"<>|' else ch for ch in self.layer_name)
        base_path = os.path.join(self.save_directory, f"{safe_name}_y{y}_x{x}")

        if self.save_csv:
            self._write_csv(base_path + ".csv")
        if self.save_png:
            self._export_png(base_path + ".png")

    def _write_csv(self, csv_path):
        """Write ``intensity_data`` as (Slice, Intensity) rows to *csv_path*."""
        try:
            with open(csv_path, 'w', newline='') as csvfile:
                writer = csv.writer(csvfile)
                writer.writerow(['Slice', 'Intensity'])
                writer.writerows(enumerate(self.intensity_data))
        except OSError as err:
            show_info(f"Failed to save {csv_path}: {err}")
            return
        show_info(f"Data saved to {csv_path}")

    def _export_png(self, png_path):
        """Export the plot to *png_path* with black axes, labels and line.

        The white on-screen styling is restored in ``finally``, so a failed
        export does not leave the plot black. NOTE(review): no background
        colour is set here, so the PNG presumably inherits the widget's
        transparent background.
        """
        plot_item = self.plot.plotItem
        axes = [self.plot.getAxis(side) for side in ('left', 'bottom', 'top', 'right')]
        left_axis, bottom_axis = axes[0], axes[1]
        original_axis_pens = [axis.pen() for axis in axes]
        original_left_text_pen = left_axis.textPen()
        original_bottom_text_pen = bottom_axis.textPen()
        data_items = plot_item.listDataItems()

        black_pen = pg.mkPen(color='k', width=2)
        try:
            for axis in axes:
                axis.setPen(black_pen)
            left_axis.setTextPen('k')
            bottom_axis.setTextPen('k')
            self._set_axis_labels('black')
            for item in data_items:
                item.setPen(pg.mkPen('k', width=1.5))

            saved = ImageExporter(plot_item).export(png_path)
        finally:
            for axis, pen in zip(axes, original_axis_pens):
                axis.setPen(pen)
            left_axis.setTextPen(original_left_text_pen)
            bottom_axis.setTextPen(original_bottom_text_pen)
            self._set_axis_labels('white')
            for item in data_items:
                item.setPen(pg.mkPen('w', width=1.5))

        # pyqtgraph returns QImage.save()'s bool for file exports; treat
        # anything other than an explicit False as success.
        if saved is False:
            show_info(f"Failed to save plot to {png_path}")
        else:
            show_info(f"Plot saved to {png_path}")

    # -------------------------------------------------------- view actions

    def hide_all_layers(self, event=None):
        """Make every layer in the viewer invisible."""
        for layer in self.viewer.layers:
            layer.visible = False

    def focus_on_visible_layer(self, event=None):
        """Centre and zoom the camera on the single visible image layer."""
        visible_image_layers = self._visible_image_layers()
        if len(visible_image_layers) > 1:
            show_info("More than one image layer is visible.")
            return
        if not visible_image_layers:
            show_info("No image layer is visible.")
            return

        layer = visible_image_layers[0]
        # The camera works in world coordinates, so use the world extent
        # (rows: min, max). The last two axes are the displayed (y, x).
        extent = layer.extent.world
        self.viewer.camera.center = np.mean(extent, axis=0)[-2:]
        span = np.ptp(extent, axis=0)[-2:]

        # ``qt_viewer`` is deprecated in newer napari in favour of ``_qt_viewer``.
        window = self.viewer.window
        qt_viewer = getattr(window, '_qt_viewer', None)
        if qt_viewer is None:
            qt_viewer = window.qt_viewer
        self.viewer.camera.zoom = min(qt_viewer.canvas.size) / max(span)
        self.viewer.layers.selection.active = layer
