import typing
from pathlib import Path

import pandas as pd

import mzcn
from mzcn.engine.base_task import BaseTask


def load_data(
    stage: str = 'train',
    task: typing.Union[str, BaseTask] = 'ranking',
    return_classes: bool = False
) -> typing.Union[mzcn.DataPack, typing.Tuple[mzcn.DataPack, list]]:
    """
    Load toy data.

    :param stage: One of `train`, `dev`, and `test`.
    :param task: Could be one of `ranking`, `classification` or a
        :class:`mzcn.engine.BaseTask` instance.
    :param return_classes: `True` to return classes for classification task,
        `False` otherwise.

    :raises ValueError: If `stage` is not one of `train`, `dev`, `test`, or
        if `task` is neither a ranking nor a classification task. Both are
        checked before any file is read.

    :return: A DataPack unless `task` is `classification` and `return_classes`
        is `True`: a tuple of `(DataPack, classes)` in that case.

    Example:
        >>> import mzcn as mz
        >>> stages = 'train', 'dev', 'test'
        >>> tasks = 'ranking', 'classification'
        >>> for stage in stages:
        ...     for task in tasks:
        ...         _ = mz.datasets.toy.load_data(stage, task)
    """
    if stage not in ('train', 'dev', 'test'):
        raise ValueError(f"{stage} is not a valid stage. "
                         "Must be one of `train`, `dev`, and `test`.")

    # Resolve the task kind up front so an unsupported task is rejected with
    # the documented ValueError before the CSV is read and packed.
    is_ranking = task == 'ranking' or isinstance(task, mzcn.tasks.Ranking)
    is_classification = not is_ranking and (
        task == 'classification'
        or isinstance(task, mzcn.tasks.Classification)
    )
    if not (is_ranking or is_classification):
        raise ValueError(f"{task} is not a valid task. "
                         "Must be one of `Ranking` and `Classification`.")

    # The toy split files (train.csv / dev.csv / test.csv) sit next to this
    # module; the first CSV column is used as the DataFrame index.
    path = Path(__file__).parent.joinpath(f'{stage}.csv')
    data_pack = mzcn.pack(pd.read_csv(path, index_col=0), task)

    if is_classification and return_classes:
        # Class labels of the toy classification data.
        return data_pack, [False, True]
    return data_pack


def load_embedding():
    """
    Load the toy embedding bundled with this module.

    Reads `embedding.2d.txt` from the directory containing this file and
    passes it to :func:`mzcn.embedding.load_from_file` in `glove` mode.

    :return: Whatever :func:`mzcn.embedding.load_from_file` returns for that
        file.
    """
    embedding_file = Path(__file__).parent / 'embedding.2d.txt'
    return mzcn.embedding.load_from_file(embedding_file, mode='glove')
