"""Functions for querying device data from the database."""
from concurrent.futures import ProcessPoolExecutor
from typing import Optional

import pandas as pd
import requests
from sqlalchemy.sql import ColumnElement
from sqlmodel import select
from sqlmodel.sql.expression import SelectOfScalar
from tqdm.auto import tqdm

from .. import session, settings, ParentCell
from ..api.device_data import get_data_by_id
from ..models import Cell, Project, DeviceData, Device, Die, Wafer


def _get_device_data_joined_query() -> SelectOfScalar[DeviceData]:
    """Build the base ``DeviceData`` select joined to its related tables.

    The statement selects ``DeviceData`` and joins, in this order, ``Device``,
    ``Die``, ``Cell``, ``ParentCell``, ``Project`` and ``Wafer`` so that filter
    clauses may reference columns of any of those models. ``Die``,
    ``ParentCell`` and ``Wafer`` are joined with ``isouter=True``, so rows
    without a matching die, parent cell or wafer are not dropped by the join.

    Returns:
        The un-filtered select statement; callers add ``.where`` clauses.
    """
    query = select(DeviceData).join(Device)
    query = query.join(Die, isouter=True)
    query = query.join(Cell, Device.cell_id == Cell.id)  # type: ignore[arg-type]
    query = query.join(
        ParentCell,
        Device.parent_cell_id == ParentCell.id,  # type: ignore[arg-type]
        isouter=True,
    )
    query = query.join(Project, Project.id == Cell.project_id)  # type: ignore[arg-type]
    query = query.join(
        Wafer,
        Wafer.id == Die.wafer_id,  # type: ignore[arg-type]
        isouter=True,
    )
    return query


def _get_device_data_and_frame(idx: int) -> tuple[int, pd.DataFrame]:
    """Fetch the data for one device-data id and pair it with that id.

    Args:
        idx: Identifier passed unchanged to ``get_data_by_id``.

    Returns:
        A ``(idx, data)`` tuple, where ``data`` is whatever ``get_data_by_id``
        returned for ``idx``. Echoing the id back lets a caller match each
        result to the record it was requested for.
    """
    frame = get_data_by_id(idx)
    return idx, frame


def get_data_by_query(
    clauses: Optional[list[ColumnElement[bool]]] = None,
    multi_processing: bool = True,
    progress_bar: bool = False,
) -> list[tuple[DeviceData, pd.DataFrame]]:
    """Query the database for device data and return DeviceData and its raw data.

    Args:
        clauses: A list of sql expressions such as `dd.Cell.name == "RibLoss"`.
            Each clause is added to the query with ``.where``. ``None`` (the
            default) or an empty list applies no filter.
        multi_processing: Use multiple processes to download data from the API
            endpoint.
        progress_bar: Show a progress bar.

    Returns:
        One ``(DeviceData, data)`` tuple per distinct ``DeviceData.id`` matched
        by the query, in the order the ids were first returned by the query.
        An empty list when nothing matches.

    Raises:
        requests.HTTPError: Re-raised from a download in multi-processing mode,
            after downloads that have not started yet are cancelled.
            (Presumably raised by ``get_data_by_id``; any exception from the
            download also propagates in single-process mode.)
    """
    statement = _get_device_data_joined_query()
    for clause in clauses or ():
        statement = statement.where(clause)

    # Key the records by id so each downloaded frame can be matched back to
    # its DeviceData row.
    device_data = {dd.id: dd for dd in session.exec(statement).all()}
    if not device_data:
        # Nothing to download: skip setting up a process pool / progress bar.
        return []

    if not multi_processing:
        ids = tqdm(device_data, total=len(device_data)) if progress_bar else device_data
        fetched = [
            _get_device_data_and_frame(idx)  # type:ignore[arg-type]
            for idx in ids
        ]
        return [(device_data[_id], frame) for _id, frame in fetched]

    with ProcessPoolExecutor(max_workers=settings.n_cores) as executor:
        try:
            # ``map`` is lazy about delivering results: a worker's exception
            # surfaces while iterating, so the iteration stays inside ``try``.
            frames = executor.map(_get_device_data_and_frame, device_data)
            stream = tqdm(frames, total=len(device_data)) if progress_bar else frames
            return [(device_data[_id], frame) for _id, frame in stream]
        except requests.HTTPError:
            # Do not wait for the remaining downloads once one has failed.
            executor.shutdown(wait=False, cancel_futures=True)
            raise
