"""数据加载器"""
import abc
import datetime
import polars as pl
from pathlib import Path
from typing import List, Union
from tqdm import tqdm

Date = Union[float, str, datetime.date, datetime.datetime]
PathType = Union[str, Path]


class vxDataLoader(abc.ABC):
    """Abstract base class for loading raw data from the original data source.

    Subclasses must implement ``_load_basics`` and ``_load_indicators``.
    ``load`` additionally relies on a per-symbol method
    ``load_one_df(symbol, start_date, end_date)`` that is not defined in this
    class and therefore has to be supplied by the concrete loader.
    """

    # Default on-disk cache location. ``expanduser`` is required because
    # ``Path`` never expands "~" on its own; without it the path would point at
    # a directory literally named "~" under the current working directory.
    __cache_dir__ = Path("~/.vxdataset/cache/").expanduser()

    @classmethod
    def set_cache_dir(cls, cache_dir: PathType) -> None:
        """Set the cache directory, creating it if it does not exist.

        The value is stored on the class it is called on (``cls``).

        Arguments:
            cache_dir {str | Path} -- new cache directory; a leading "~" is
                expanded to the user's home directory.
        """
        # ``Path(...)`` accepts both ``str`` and ``Path``, so no isinstance
        # check is needed; expanding "~" keeps this consistent with the default.
        cache_dir = Path(cache_dir).expanduser()
        cache_dir.mkdir(parents=True, exist_ok=True)
        cls.__cache_dir__ = cache_dir

    @abc.abstractmethod
    def _load_basics(
        self,
        symbols: List[str],
        freq: str = "day",
        start_date: Union[Date, None] = None,
        end_date: Union[Date, None] = None,
    ) -> pl.DataFrame:
        """Load the basic price/volume information.

        Arguments:
            symbols {list} -- list of security codes
            freq {str} -- data frequency (default: "day")
            start_date {Date} -- start date
            end_date {Date} -- end date

        Returns:
            pl.DataFrame: ['date','symbol','freq','open','high','low','close','amount','volume','factor','instert']

        NOTE(review): the column name 'instert' is copied verbatim from the
        original docstring and may be a typo -- confirm against the concrete
        loaders.
        """
        pass

    @abc.abstractmethod
    def _load_indicators(
        self,
        symbols: List[str],
        freq: str = "day",
        start_date: Union[Date, None] = None,
        end_date: Union[Date, None] = None,
    ) -> pl.DataFrame:
        """Load the indicators for securities of a single security type.

        Arguments:
            symbols {list} -- list of security codes; all codes passed here
                must be of the same security type
            freq {str} -- data frequency (default: "day")
            start_date {Date} -- start date
            end_date {Date} -- end date

        Returns:
            pl.DataFrame: the indicator data
        """
        pass

    def load(
        self,
        symbols: List,
        freq: str = "day",
        start_date: Union[Date, None] = None,
        end_date: Union[Date, None] = None,
    ) -> pl.DataFrame:
        """Load the data for ``symbols`` as a single ``pl.DataFrame``.

        One frame is loaded per symbol through ``self.load_one_df`` and the
        frames are concatenated in the order of ``symbols``. A progress bar is
        shown while iterating.

        Arguments:
            symbols {list} -- security codes to load
            freq {str} -- data frequency (default: "day"); currently not
                forwarded to ``load_one_df``
            start_date {Date} -- start of the date range
            end_date {Date} -- end of the date range

        Returns:
            pl.DataFrame: the concatenation of the per-symbol frames
        """
        # NOTE(review): the previous body also read
        # ``<cache_dir>/<freq>/basics.parquet`` into a local that was never
        # used. That dead read was dropped because its fallback could fail
        # before any data was loaded.
        # TODO: wire the basics cache into the result.
        frames = [
            self.load_one_df(symbol, start_date, end_date)
            for symbol in tqdm(symbols)
        ]
        return pl.concat(frames)


if __name__ == "__main__":
    # NOTE(review): vxDataLoader declares abstract methods (_load_basics,
    # _load_indicators), so instantiating it directly raises TypeError.
    # A concrete subclass is needed here for this entry point to run.
    d = vxDataLoader()
