#!/usr/bin/env python
"""
# Author: Xiong Lei
# Created Time : Wed 26 Dec 2018 03:46:19 PM CST
# File Name: batch.py
# Description: Data loading, preprocessing and PyTorch Dataset utilities for single-cell (AnnData) data.
"""

import os
import warnings
from glob import glob
from multiprocessing import Pool, cpu_count

import numpy as np
import pandas as pd
import scanpy as sc
import scipy
import scipy.sparse
from anndata import AnnData
from sklearn.preprocessing import MaxAbsScaler, maxabs_scale, minmax_scale, scale
from torch.utils.data import Dataset
from tqdm import tqdm

# NOTE: silences *all* warnings process-wide as an import-time side effect.
# The stdlib module is used directly because the ``np.warnings`` alias was
# removed in NumPy 1.24 (accessing it raises AttributeError on import).
warnings.filterwarnings('ignore')

# Default directory that ``load_dataset`` searches for cached ``<name>.h5ad`` files.
DATA_PATH = os.path.expanduser("~")+'/.scalex/'


def read_mtx(path):
    """Read an mtx-format directory into an AnnData.

    The matrix is the file whose name contains ``mtx`` together with
    ``count``, ``matrix`` or ``data``. It is read with ``sc.read_mtx`` and
    transposed. Observation names come from a file whose name contains
    ``barcode``. Variable names come from a file whose name contains
    ``gene``, ``peaks`` or ``feature``. For both, the last tab-separated
    column is used.

    Matching is done on file names only, never on the directory part of the
    path, so a directory such as ``/data/genes/`` cannot cause false matches.

    Parameters
    ----------
    path : str
        Directory containing the matrix, barcode and feature files.

    Returns
    -------
    AnnData

    Raises
    ------
    ValueError
        If no matching ``.mtx`` file is found in ``path``.
    """
    # Sorted so the chosen files do not depend on filesystem listing order.
    files = sorted(glob(os.path.join(path, '*')))

    matrix_file = None
    for fpath in files:
        name = os.path.basename(fpath)
        if 'mtx' in name and any(key in name for key in ('count', 'matrix', 'data')):
            matrix_file = fpath  # last match wins, as before
    if matrix_file is None:
        raise ValueError('No count/matrix/data .mtx file found under {}'.format(path))

    adata = sc.read_mtx(matrix_file).T
    for fpath in files:
        if fpath == matrix_file:
            # Never re-read the matrix itself as a barcode/feature table.
            continue
        name = os.path.basename(fpath)
        # NOTE(review): the last column is used. For 3-column 10x v3
        # features.tsv files this may be the feature type rather than the gene
        # name. Confirm the expected input layout.
        if 'barcode' in name:
            barcode = pd.read_csv(fpath, sep='\t', header=None).iloc[:, -1].values
            adata.obs = pd.DataFrame(index=barcode)
        if 'gene' in name or 'peaks' in name or 'feature' in name:
            gene = pd.read_csv(fpath, sep='\t', header=None).iloc[:, -1].values
            adata.var = pd.DataFrame(index=gene)
    return adata


def load_file(path):
    """Load one dataset from ``path`` into an AnnData with a sparse ``X``.

    Supported inputs:

    * a directory in mtx format (read with ``read_mtx``);
    * ``.csv`` / ``.csv.gz``: read with ``sc.read_csv`` and transposed;
    * ``.txt`` / ``.tsv`` (optionally ``.gz``): tab-separated, first column
      used as the index, transposed before building the AnnData;
    * ``.h5ad``.

    A dense ``X`` is converted to CSR, and variable names are made unique.

    Raises
    ------
    ValueError
        If ``path`` does not exist or has an unsupported extension.
    """
    if os.path.isdir(path):  # mtx format
        adata = read_mtx(path)
    elif os.path.isfile(path):
        if path.endswith(('.csv', '.csv.gz')):
            adata = sc.read_csv(path).T
        elif path.endswith(('.txt', '.txt.gz', '.tsv', '.tsv.gz')):
            df = pd.read_csv(path, sep='\t', index_col=0).T
            adata = AnnData(df.values, dict(obs_names=df.index.values), dict(var_names=df.columns.values))
        elif path.endswith('.h5ad'):
            adata = sc.read_h5ad(path)
        else:
            # Previously fell through and failed later with UnboundLocalError.
            raise ValueError('Unsupported file format: {}'.format(path))
    else:
        raise ValueError("File {} not exists".format(path))

    if isinstance(adata.X, np.ndarray):
        adata.X = scipy.sparse.csr_matrix(adata.X)
    adata.var_names_make_unique()
    return adata


def load_files(root):
    """Load ``root``, or every path matching it when it ends in ``/*``.

    A plain path is delegated to ``load_file``. A pattern whose last path
    component is ``*`` is expanded with ``glob``, and each match is loaded in
    sorted order. The results are concatenated with ``batch_key='sub_batch'``.

    Raises
    ------
    ValueError
        If a ``/*`` pattern matches nothing (plus whatever ``load_file`` raises).
    """
    if root.split('/')[-1] != '*':
        return load_file(root)

    paths = sorted(glob(root))
    if not paths:
        # Without this, AnnData.concatenate() is called with no objects.
        raise ValueError('No files match {}'.format(root))
    adatas = [load_file(p) for p in paths]
    return AnnData.concatenate(*adatas, batch_key='sub_batch', index_unique=None)
    
    
def concat_data(data_list, batch_categories=None, join='inner', 
                batch_key='batch', index_unique=None, save=None):
    """Load several datasets and concatenate them into one AnnData.

    Parameters
    ----------
    data_list : iterable of str
        Paths or ``/*`` patterns, each loaded with ``load_files``.
    batch_categories : sequence, optional
        One label per entry of ``data_list``. Defaults to ``'0'``, ``'1'``, ...
    join : str
        Passed to ``AnnData.concatenate`` (``'inner'`` or ``'outer'``).
    batch_key : str
        ``obs`` column that receives the batch label.
    index_unique : str or None
        Passed to ``AnnData.concatenate``.
    save : str, optional
        If given, the result is written there with gzip compression.

    Returns
    -------
    AnnData

    Raises
    ------
    ValueError
        If ``batch_categories`` has a different length than ``data_list``.
    """
    data_list = list(data_list)
    if batch_categories is None:
        batch_categories = [str(i) for i in range(len(data_list))]
    elif len(batch_categories) != len(data_list):
        # Checked before loading so a mismatch fails fast, and is not stripped
        # under ``python -O`` the way the former ``assert`` was.
        raise ValueError('batch_categories has {} entries but data_list has {}'.format(
            len(batch_categories), len(data_list)))

    adata_list = [load_files(root) for root in data_list]
    concat = AnnData.concatenate(*adata_list, join=join, batch_key=batch_key,
                                 batch_categories=batch_categories, index_unique=index_unique)
    if save:
        concat.write(save, compression='gzip')
    return concat
        
    
def load_dataset(name, path=DATA_PATH):
    """Load the cached dataset ``<path>/<name>.h5ad``.

    ``os.path.join`` is used, so ``path`` works with or without a trailing slash.

    Raises
    ------
    ValueError
        If the ``.h5ad`` file does not exist.
    """
    file_path = os.path.join(path, name + '.h5ad')
    if not os.path.exists(file_path):
        raise ValueError('No such dataset named {} under {}'.format(name, path))
    return sc.read_h5ad(file_path)
    
        
class SingleCellDataset(Dataset):
    """
    Dataset for dataloader.

    Item ``idx`` is the tuple ``(x, domain_id, idx)``:

    * ``x`` is row ``idx`` of ``adata.X``, densified with ``toarray()`` and squeezed;
    * ``domain_id`` is the integer category code of ``adata.obs['batch']``
      for that row (the column must be categorical);
    * ``idx`` is the row index itself.
    """
    def __init__(self, adata):
        self.adata = adata
        self.shape = adata.shape

    def __len__(self):
        return self.adata.X.shape[0]

    def __getitem__(self, idx):
        x = self.adata.X[idx].toarray().squeeze()
        # ``cat.codes`` is indexed by obs_names (strings), so the lookup must be
        # positional. Plain ``[idx]`` relies on pandas' deprecated integer-key
        # fallback.
        domain_id = self.adata.obs['batch'].cat.codes.iloc[idx]
        return x, domain_id, idx
    

def standard_scale(X, vmax=5):
    """Standardize, clip and rescale ``X``, returning a CSR matrix.

    Steps: sklearn ``scale`` (per-column standardization), values above
    ``vmax`` clipped to ``vmax``, then sklearn ``minmax_scale``.

    Parameters
    ----------
    X : array-like or scipy sparse matrix
        Sparse input is densified, and anything else goes through
        ``np.asarray``. This also accepts ``np.matrix`` and nested lists,
        which the former ``type(X) != np.ndarray`` check broke on.
    vmax : float
        Upper clipping bound applied after standardization.
    """
    X = X.toarray() if scipy.sparse.issparse(X) else np.asarray(X)
    X = scale(X)  # returns a new array (copy=True), so clipping below is safe
    X[X > vmax] = vmax
    X = minmax_scale(X)
    return scipy.sparse.csr_matrix(X)


def filter_data(adata, min_genes=600, min_cells=3):
    """Filter ``adata`` in place with scanpy; returns None.

    Cells are filtered first with ``min_genes``, then genes with
    ``min_cells``, so the gene filter sees only the cells that survived.
    """
    pp = sc.pp
    pp.filter_cells(adata, min_genes=min_genes)
    pp.filter_genes(adata, min_cells=min_cells)

    
def preprocessing(adata, 
        min_genes=600, 
        min_cells=3, 
        target_sum=1e4, 
        n_top_genes=None, # or gene list
        transform=maxabs_scale,
        split=True,
    ):
    """
    Filter, normalize, log-transform, scale and optionally select genes.

    Parameters
    ----------
    adata : AnnData
        Input data. ``adata.obs['batch']`` is used for HVG selection and by
        ``transform_data``. A dense ``X`` is converted to CSR in place first.
    min_genes, min_cells : int
        Thresholds passed to ``filter_data``.
    target_sum : float
        Passed to ``sc.pp.normalize_total``.
    n_top_genes : int, str, sequence or None
        * ``int > 0``: keep that many batch-aware highly variable genes;
        * ``str``: path to a gene-name file loaded with ``np.loadtxt``;
        * sequence of names: reindex to exactly these genes (``reindex``);
        * ``None`` or ``int <= 0``: keep all genes.
    transform : callable
        Scaling function applied via ``transform_data``.
    split : bool
        Passed to ``transform_data``.

    Returns
    -------
    AnnData
        Processed data. ``.raw`` holds the transformed data over all genes,
        captured before gene selection.
    """
    if isinstance(adata.X, np.ndarray):
        adata.X = scipy.sparse.csr_matrix(adata.X)

    # Drop ERCC spike-ins and mitochondrial genes. A boolean mask avoids
    # name-based indexing, which is ambiguous if var_names repeat.
    prefixes = ('ERCC', 'MT-', 'mt-')
    keep = np.array([not str(gene).startswith(prefixes) for gene in adata.var_names], dtype=bool)
    adata = adata[:, keep].copy()

    filter_data(adata, min_genes=min_genes, min_cells=min_cells)
    sc.pp.normalize_total(adata, target_sum=target_sum, exclude_highly_expressed=True)
    sc.pp.log1p(adata)

    hvg_mask = None
    gene_list = None
    if isinstance(n_top_genes, (int, np.integer)):
        if n_top_genes > 0:
            print('Find highly variable genes')
            # HVGs are chosen on log-normalized data, before scaling.
            df = sc.pp.highly_variable_genes(adata, n_top_genes=n_top_genes, batch_key='batch', inplace=False)
            hvg_mask = df['highly_variable'].values
    elif n_top_genes is not None:
        gene_list = np.loadtxt(n_top_genes, dtype=str) if isinstance(n_top_genes, str) else n_top_genes

    # Always scale and capture ``raw``. Previously, ``None`` or ``int <= 0``
    # left ``raw`` unbound and crashed at ``adata.raw = raw``.
    adata = transform_data(adata, transform, split=split)
    raw = adata.copy()
    if hvg_mask is not None:
        adata._inplace_subset_var(hvg_mask)
    elif gene_list is not None:
        adata = reindex(adata, gene_list)

    adata.raw = raw
    return adata
    

def _column_max(X):
    """Return the per-column maximum of dense or sparse ``X`` as a 1-D ndarray."""
    # Sparse ``max`` accounts for implicit zeros, so the result matches the
    # dense maximum without densifying the whole matrix.
    col_max = X.max(axis=0)
    if scipy.sparse.issparse(col_max):
        return col_max.toarray().ravel()
    return np.asarray(col_max).ravel()


def transform_data(adata, transform, split=True):
    """Apply ``transform`` to ``adata.X``, per batch or over all rows.

    Before transforming, the per-feature maximum of the data is stored in
    ``adata.var``: ``'max_<batch>'`` for each batch when splitting,
    otherwise ``'max'``.

    Parameters
    ----------
    adata : AnnData
        Modified in place. ``adata.obs['batch']`` is read only when ``split``
        is true.
    transform : callable
        Maps a block of rows to a same-shaped block (e.g. ``maxabs_scale``).
    split : bool
        If true and there is more than one batch, each batch's rows are
        transformed separately.

    Returns
    -------
    AnnData
        The same ``adata`` object.
    """
    if split and len(adata.obs['batch'].unique()) > 1:
        for b in adata.obs['batch'].unique():
            idx = np.where(adata.obs['batch'] == b)[0]
            adata.var['max_' + str(b)] = _column_max(adata.X[idx])
            adata.X[idx] = transform(adata.X[idx])
    else:
        adata.var['max'] = _column_max(adata.X)
        adata.X = transform(adata.X)
    return adata
        

def reindex(adata, genes):
    """Return a new AnnData whose variables are exactly ``genes``, in order.

    Columns for genes present in ``adata.var_names`` are copied from
    ``adata.X``. Genes that are absent become all-zero columns.

    Parameters
    ----------
    adata : AnnData
    genes : sequence of str
        Target gene names. A list, tuple or ndarray all work; the former
        ``genes[idx]`` failed for plain lists.

    Returns
    -------
    AnnData
        New object with a CSR ``X`` of shape ``(adata.shape[0], len(genes))``,
        ``obs`` taken from ``adata``, and ``genes`` as variable names.
    """
    genes = np.asarray(genes)
    idx = np.array([i for i, g in enumerate(genes) if g in adata.var_names], dtype=np.intp)
    print('There are {} gene in selected genes'.format(len(idx)))

    # Sparse 0/1 selector placing the k found columns at their target
    # positions. A matrix product avoids slow column assignment into an
    # empty CSR matrix.
    selector = scipy.sparse.csr_matrix(
        (np.ones(len(idx)), (np.arange(len(idx)), idx)),
        shape=(len(idx), len(genes)),
    )
    found = scipy.sparse.csr_matrix(adata[:, genes[idx]].X)
    new_X = scipy.sparse.csr_matrix(found @ selector)
    adata = AnnData(new_X, obs=adata.obs, var={'var_names': genes})
    return adata
   
    
def down_sample(adata, cat='celltype', size=500, random_state=None):
    """Randomly keep at most ``size`` observations per category of ``adata.obs[cat]``.

    Parameters
    ----------
    adata : AnnData
    cat : str
        Column of ``adata.obs``. It must be categorical, because
        ``.cat.categories`` is read.
    size : int
        Maximum observations kept per category. Smaller categories are kept
        whole.
    random_state : None, int, numpy.random.RandomState or numpy.random.Generator
        ``None`` (default) draws from the global ``np.random`` state, as
        before. An int seeds a fresh ``RandomState`` for reproducible
        sampling.

    Returns
    -------
    ``adata`` indexed by the sampled observation names, grouped in
    ``categories`` order.
    """
    if random_state is None:
        choice = np.random.choice
    elif isinstance(random_state, (np.random.RandomState, np.random.Generator)):
        choice = random_state.choice
    else:
        choice = np.random.RandomState(random_state).choice

    labels = adata.obs[cat]
    indices = []
    for c in labels.cat.categories:
        # Select names directly instead of building an AnnData subset per category.
        names = adata.obs_names[(labels == c).values]
        indices.extend(choice(names, size=min(size, len(names)), replace=False))
    return adata[indices]