__version__ = "1.0.4"
# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/factor_reader_core.ipynb.

# %% auto 0
__all__ = ['TokenUnavailableError', 'do_on_dfs_class', 'concat_nosqls', 'FactorReader']

# %% ../nbs/factor_reader_core.ipynb 2
import requests
import pandas as pd
from typing import Iterable,Union
from functools import wraps,reduce
import json
import datetime
from tenacity import retry, stop_after_attempt,wait_fixed
import tqdm

# %% ../nbs/factor_reader_core.ipynb 4
# Best-effort update check, run once at import time: compare the installed
# version with the latest release on PyPI and print an upgrade hint if they
# differ. The installed version comes from ``__version__``; the previous
# ``get_config`` lookup referenced a name this module never imports, so it
# always failed and forced a second, identical PyPI request.
try:
    response = requests.get("https://pypi.org/pypi/factor-reader/json", timeout=2)
    latest_version = response.json()["info"]["version"]
    now_version = __version__
    if latest_version != now_version:
        print(f'''您当前使用的是{now_version}，最新版本为{latest_version}
              建议您使用`pip install factor_reader --upgrade`
              或`pip install -i https://pypi.tuna.tsinghua.edu.cn/simple factor_reader --upgrade`命令进行更新
              ''')
except Exception:
    # Deliberately silent: being offline or getting a malformed PyPI reply
    # must never prevent the package from being imported.
    pass

# %% ../nbs/factor_reader_core.ipynb 5
class TokenUnavailableError(PermissionError):
    """Raised when the factor service rejects the supplied token.

    Parameters
    ----------
    error_info : str
        Human-readable message returned by the service. It is stored on the
        instance as ``error_info`` and is also what ``str(exc)`` returns.
    """

    def __init__(self, error_info):
        # Hand the message (not the exception itself) to the base class so
        # that ``args`` holds the message and ``repr()`` is well defined.
        super().__init__(error_info)
        self.error_info = error_info

    def __str__(self):
        # __str__ must return a str even if a non-string payload was given.
        return str(self.error_info)


def do_on_dfs_class(func):
    """Method decorator that maps the method over a list or tuple argument.

    When the first argument after ``self`` is a ``list`` or ``tuple``, the
    wrapped method is called once per element (all other arguments are passed
    through unchanged) and the results are returned as a list in input order.
    Any other value is forwarded to the method as a single call.
    """

    @wraps(func)
    def wrapper(self, df=None, *args, **kwargs):
        # Scalar input: a single plain call.
        if not isinstance(df, (list, tuple)):
            return func(self, df, *args, **kwargs)

        # Sequence input: one call per element, results always in a list.
        results = []
        for item in df:
            results.append(func(self, item, *args, **kwargs))
        return results

    return wrapper

def concat_nosqls(func):
    """Method decorator that combines the frames returned for several factors.

    The wrapped method is called as ``func(self, fac_key, *args, **kwargs)``
    and may return one DataFrame or a list/tuple of DataFrames:

    * list/tuple of frames without a ``date`` column: each frame gets a
      ``(name, date)`` MultiIndex, where ``name`` is the factor name resolved
      from the matching entry of ``fac_key``, and the frames are concatenated;
    * list/tuple of frames with a ``date`` column: the frames are outer-merged
      on ``["date", "code"]``;
    * single frame without a ``date`` column: it gets the ``(name, date)``
      MultiIndex;
    * anything else is returned unchanged.

    The instance must provide ``_deal_key_name(key)`` returning
    ``(key, name)``.
    """

    @wraps(func)
    def wrapper(self, fac_key, *args, **kwargs):
        dfs = func(self, fac_key, *args, **kwargs)

        def double_index(res, name):
            # ``name`` is an already-resolved factor name. It must not be
            # looked up again: the lookup matches by substring, so resolving
            # a name a second time can land on a different factor.
            res = res.reset_index().assign(name=name)
            return res.set_index(['name', 'date'])

        if isinstance(dfs, (list, tuple)):
            if "date" in dfs[0].columns:
                return reduce(
                    lambda x, y: pd.merge(x, y, on=["date", "code"], how="outer"), dfs
                )
            frames = [
                double_index(frame, self._deal_key_name(key)[1])
                for frame, key in zip(dfs, fac_key)
            ]
            return pd.concat(frames)

        if "date" not in dfs.columns:
            return double_index(dfs, self._deal_key_name(fac_key)[1])
        return dfs

    return wrapper



# %% ../nbs/factor_reader_core.ipynb 6
class FactorReader():
    """Client for the remote factor-data HTTP service.

    Constructing a reader downloads the factor catalogue once and caches the
    data keys and factor names, so that user input can later be resolved to a
    ``(data key, factor name)`` pair by ``_deal_key_name``.
    """

    # Exact response body the service sends when the token is rejected.
    _TOKEN_ERROR = '您的token暂不可用，请检查核验后重试。或联系方正金工团队开通可用token~'

    def __init__(self,token:str) -> None:
        """Store the token and fetch the factor catalogue.

        Parameters
        ----------
        token : str
            Access token; it is embedded in every request URL.
        """
        self.token=token
        self.host='http://43.143.223.158:1837/'
        infos=self.show_all_factors_information()
        # Catalogue columns: 数据键名 holds the data keys, 因子名称 the factor names.
        self.keys=list(infos.数据键名)
        self.names=list(infos.因子名称)
        self.keys_names = dict(zip(self.keys, self.names))
        self.names_keys = dict(zip(self.names, self.keys))

    def _deal_key_name(self,input_key):
        """Resolve user input to a ``(data key, factor name)`` pair.

        An exact data key wins. Otherwise the first catalogue name that
        contains ``input_key`` as a substring is used.

        Raises
        ------
        ValueError
            If ``input_key`` is neither a data key nor part of any factor name.
        """
        if input_key in self.keys:
            return input_key, self.keys_names[input_key]
        names_related = [i for i in self.names if input_key in i]
        if names_related:
            fac_name = names_related[0]
            return self.names_keys[fac_name], fac_name
        raise ValueError(
            "输入的fac_key参数有误，请输入因子名称或因子键名，可通过`factor_reader.show_all_factors_information`函数来查看可用的因子名称和键名"
        )

    def _to_frame(self, text):
        """Convert a raw response body into a DataFrame.

        The body is decoded with ``json.loads`` and the result is handed to
        ``pd.read_json``. It is wrapped in ``StringIO`` because recent pandas
        versions deprecate passing literal JSON text to ``read_json``.

        Raises
        ------
        TokenUnavailableError
            If the body is the service's token-rejection message.
        """
        from io import StringIO

        if text == self._TOKEN_ERROR:
            raise TokenUnavailableError(self._TOKEN_ERROR)
        return pd.read_json(StringIO(json.loads(text)))

    # NOTE(review): the retry policy has no exception filter, so a rejected
    # token is presumably retried for all 100 attempts as well — confirm
    # whether that is intended before narrowing it.
    @retry(stop=stop_after_attempt(100),wait=wait_fixed(3))
    def show_all_factors_information(self) -> pd.DataFrame:
        """Download the factor catalogue from ``<host><token>/info``.

        Retried up to 100 times with a fixed 3-second wait between attempts.

        Returns
        -------
        pd.DataFrame
            The catalogue as sent by the service.
        """
        url=self.host+self.token+'/info'
        return self._to_frame(requests.get(url).text)

    @concat_nosqls
    @do_on_dfs_class
    def read_factor(
        self,
        fac_key: str,
        trade_date: Union[int,str,datetime.datetime] = None,
        start_date: Union[int,str,datetime.datetime] = None,
        end_date: Union[int,str,datetime.datetime] = None,
        sql_like: bool = False,
    ) -> pd.DataFrame:
        """Read factor data by data key or factor name.

        Parameters
        ----------
        fac_key : str
            Data key or factor name. Because of the decorators, a list or
            tuple of keys/names is also accepted; the per-factor results are
            then combined into one frame.
        trade_date : Union[int,str,datetime.datetime], optional
            Read a single day, e.g. 20230106, '20230106', '2023-01-13' or
            pd.Timestamp('2023-01-13'). When given, ``start_date`` and
            ``end_date`` are ignored, by default None
        start_date : Union[int,str,datetime.datetime], optional
            First date to read, same formats as ``trade_date``, by default None
        end_date : Union[int,str,datetime.datetime], optional
            Last date to read, same formats as ``trade_date``, by default None
        sql_like : bool, optional
            Return a long, SQL-style table with date, stock code and factor
            value columns, by default False

        Returns
        -------
        pd.DataFrame
            With ``sql_like=False`` the data is pivoted to dates x stock codes
            and then indexed by ``(name, date)`` by the ``concat_nosqls``
            decorator. With ``sql_like=True`` it is a long table whose value
            column is named after the factor.

        Raises
        ------
        ValueError
            If none of ``trade_date``, ``start_date``, ``end_date`` is given,
            or if ``fac_key`` cannot be resolved.
        """
        def date_ok(x):
            # Numbers such as 20230106 are converted to text first so that
            # they are parsed as a calendar date rather than as a number.
            if isinstance(x,int):
                x=str(x)
            elif isinstance(x,float):
                x=str(int(x))
            return pd.Timestamp(x).strftime('%Y%m%d')

        def daily_cuts(first,last):
            # Spans longer than 30 days are read one calendar day at a time;
            # return the sorted, de-duplicated 'YYYYMMDD' strings for that,
            # or None when a single request is enough.
            first,last=pd.Timestamp(first),pd.Timestamp(last)
            if not (last-first>pd.Timedelta(days=30)):
                return None
            days=[first,*pd.date_range(first,last),last]
            return sorted({d.strftime('%Y%m%d') for d in days})

        # Resolve the factor.
        fac_key,fac_name=self._deal_key_name(fac_key)
        print(f"正在读取{fac_name}的数据")

        # Normalise the dates; unused URL slots carry the literal text 'None'.
        if trade_date is not None:
            trade_date=date_ok(trade_date)
            start_date='None'
            end_date='None'
            cuts=None
        elif (start_date is not None) and (end_date is not None):
            start_date=date_ok(start_date)
            end_date=date_ok(end_date)
            trade_date='None'
            cuts=daily_cuts(start_date,end_date)
        elif start_date is not None:
            start_date=date_ok(start_date)
            trade_date='None'
            end_date='None'
            # Open-ended range: measure the span up to the current time.
            # (This branch used to read a variable that was never assigned.)
            cuts=daily_cuts(start_date,datetime.datetime.now())
        elif end_date is not None:
            end_date=date_ok(end_date)
            trade_date='None'
            start_date='None'
            cuts=None
        else:
            raise ValueError('请至少指定trade_date、start_date、end_date参数中的一个')

        @retry(stop=stop_after_attempt(100),wait=wait_fixed(3))
        def read_in(trade_date):
            url=self.host+self.token+'/'+fac_key+'/'+trade_date+'/'+start_date+'/'+end_date+'/'+str(int(sql_like))
            res=requests.get(url).text
            # The last raw response body is kept on the instance.
            self.res=res
            res=self._to_frame(res)
            if 'timestamp' in res.columns:
                res=res.drop(columns=['timestamp'])
            if sql_like:
                return res.rename(columns={'fac':self.keys_names[fac_key]})
            return res.drop_duplicates(subset=['date','code']).pivot(index='date',columns='code',values='fac')

        if cuts is None:
            return read_in(trade_date=trade_date)

        print('您读取的数据过长，正在分段读取，请稍候……')
        ress=[read_in(trade_in) for trade_in in tqdm.tqdm(cuts)]
        print('读取完成，正在拼接，请稍等')
        return pd.concat(ress)
