import collections
import numpy as np
import random
import json


class DocData:
    """Tokenised corpus with a frequency-filtered vocabulary and
    negative-sampling training arrays.

    After construction the instance exposes:

    * ``docs`` -- list of token lists, one per row of the input table
    * ``cntr`` -- token -> corpus frequency
    * ``token2idx`` / ``idx2token`` -- vocabulary mappings
    * ``input_docs``, ``input_tokens``, ``outputs`` -- parallel numpy arrays
      of (document id, token id, label) samples; only built when
      ``with_generator`` is false
    """

    def __init__(self, filename, min_count=5, ns_rate=2, with_generator=False, column_name='clean_text'):
        """Read the corpus and build the vocabulary (and training arrays).

        Args:
            filename: Despite the name, a table indexable by column name
                (the original code used ``.iloc``, i.e. a pandas DataFrame).
            min_count: Tokens are kept only if their count is strictly
                greater than this value.
            ns_rate: Number of random negative samples per positive sample.
            with_generator: If true, only the vocabulary is built and the
                training arrays are skipped.
            column_name: Column holding the whitespace-separated text.
        """
        self.docs, self.cntr = self.read_docs_file(filename, column_name)
        self.n_docs = len(self.docs)
        self.n_words = sum(len(doc) for doc in self.docs)
        self.vocab_size = np.nan  # placeholder until prepare() sets it
        self.ns_rate = ns_rate
        self.min_count = min_count
        self.prepare(with_generator=with_generator)

    def read_docs_file(self, filename, column_name, lowercase=True):
        """Tokenise every entry of ``filename[column_name]`` on whitespace.

        Args:
            filename: Table indexable by column name (e.g. a DataFrame).
            column_name: Column containing one text string per document.
            lowercase: Lowercase each text before splitting.

        Returns:
            ``(docs, cntr)`` where ``docs`` is a list of token lists and
            ``cntr`` maps each token to its total count over all documents.
        """
        data = []
        cntr = collections.defaultdict(int)
        # Read the column once instead of materialising a row per document.
        for line in filename[column_name]:
            if lowercase:
                line = line.lower()
            tokens = line.split()
            data.append(tokens)
            for token in tokens:
                cntr[token] += 1
        return data, cntr

    def prepare(self, replace=False, with_generator=False):
        """Build vocabulary mappings and, unless ``with_generator`` is set,
        the negative-sampling training arrays.

        For every kept token occurrence one positive sample (label 1) is
        emitted, followed by ``ns_rate`` random negative samples (label 0),
        all tagged with the document id.

        Args:
            replace: If true, ``self.docs`` is deleted once the arrays exist.
            with_generator: If true, return right after the vocabulary is built.
        """
        # Unknown tokens looked up later get the next free index on the fly.
        self.token2idx = collections.defaultdict(lambda: len(self.token2idx))
        kept = [token for token, cnt in self.cntr.items() if cnt > self.min_count]
        self.token2idx.update({token: i for i, token in enumerate(kept)})
        self.idx2token = {i: token for token, i in self.token2idx.items()}
        self.vocab_size = len(self.token2idx)
        if with_generator:
            return

        input_docs, input_tokens, outputs = [], [], []
        labels = [1] + [0] * self.ns_rate
        for doc_id, tokens in enumerate(self.docs):
            # Must be recomputed for EVERY document; previously this sat
            # under `if doc_id % 100 == 0` and stale ids were reused.
            token_ids = [self.token2idx[token] for token in tokens
                         if self.cntr[token] > self.min_count]
            for idx in token_ids:
                input_tokens.append(idx)
                # NOTE: index 0 is never drawn as a negative sample, and
                # randint raises ValueError when vocab_size < 2.
                input_tokens += [random.randint(1, self.vocab_size - 1)
                                 for _ in range(self.ns_rate)]
                input_docs += [doc_id] * (self.ns_rate + 1)
                outputs += labels

        self.input_docs = np.array(input_docs, dtype="int32")
        self.input_tokens = np.array(input_tokens, dtype="int32")
        self.outputs = np.array(outputs)

        if replace:
            del self.docs

    def count_cooccs(self, save_to=None, window=110):
        """Count token co-occurrences within a sliding window.

        Each pair is stored once under its alphabetically sorted key order
        in ``self.cocntr[t1][t2]``.

        Args:
            save_to: Optional path; if given, ``[cntr, cocntr]`` is written
                there as JSON.
            window: A token is paired with the following ``window - 1``
                tokens of the same document.
        """
        self.cocntr = collections.defaultdict(lambda: collections.defaultdict(lambda: 0))
        for tokens in self.docs:
            for i, token1 in enumerate(tokens[:-1]):
                # Slicing clamps at the end of the list by itself.
                for token2 in tokens[i + 1:i + window]:
                    t1, t2 = sorted([token1, token2])
                    self.cocntr[t1][t2] += 1

        if save_to:
            with open(save_to, 'w') as fh:
                json.dump([self.cntr, self.cocntr], fh)

    def load_cooccs(self, filename):
        """Load counts written by ``count_cooccs`` from a JSON file.

        ``self.cocntr`` is replaced by the loaded mapping (plain dicts) and
        ``self.cntr`` is updated with the loaded token counts.
        """
        with open(filename) as fh:
            cntr, self.cocntr = json.load(fh)
        self.cntr.update(cntr)
        
        
        
        


class DocDatalstm:
    """Tokenised corpus prepared for a sequence model.

    Builds the same vocabulary and (document id, token id, label)
    negative-sampling stream as a flat corpus loader, then trims the stream
    to a multiple of ``max_len`` and replaces the labels by one mean label
    per consecutive chunk of ``max_len`` samples.

    After construction the instance exposes ``docs``, ``cntr``,
    ``token2idx``, ``idx2token`` and, unless ``with_generator`` is set,
    ``input_docs``, ``input_tokens`` (flat int32 arrays) and ``outputs``
    (one float per chunk).
    """

    def __init__(self, filename, min_count=5, ns_rate=2, with_generator=False, max_len=100, column_name='clean_text'):
        """Read the corpus and build the vocabulary (and training arrays).

        Args:
            filename: Despite the name, a table indexable by column name
                (the original code used ``.iloc``, i.e. a pandas DataFrame).
            min_count: Tokens are kept only if their count is strictly
                greater than this value.
            ns_rate: Number of random negative samples per positive sample.
            with_generator: If true, only the vocabulary is built.
            max_len: Chunk length used to trim the sample stream and to
                average the labels.
            column_name: Column holding the whitespace-separated text.
        """
        self.docs, self.cntr = self.read_docs_file(filename, column_name)
        self.n_docs = len(self.docs)
        self.n_words = sum(len(doc) for doc in self.docs)
        self.vocab_size = np.nan  # placeholder until prepare() sets it
        self.ns_rate = ns_rate
        self.min_count = min_count
        self.max_len_text = max_len
        self.prepare(with_generator=with_generator)

    def read_docs_file(self, filename, column_name, lowercase=True):
        """Tokenise every entry of ``filename[column_name]`` on whitespace.

        Args:
            filename: Table indexable by column name (e.g. a DataFrame).
            column_name: Column containing one text string per document.
            lowercase: Lowercase each text before splitting.

        Returns:
            ``(docs, cntr)`` where ``docs`` is a list of token lists and
            ``cntr`` maps each token to its total count over all documents.
        """
        data = []
        cntr = collections.defaultdict(int)
        # Read the column once instead of materialising a row per document.
        for line in filename[column_name]:
            if lowercase:
                line = line.lower()
            tokens = line.split()
            data.append(tokens)
            for token in tokens:
                cntr[token] += 1
        return data, cntr

    def prepare(self, replace=False, with_generator=False):
        """Prepare training data and vocabulary mappings from documents.

        For every kept token occurrence one positive sample (label 1) is
        emitted, followed by ``ns_rate`` random negatives (label 0). The
        leading ``len % max_len_text`` samples are then dropped so the
        stream divides evenly, and ``outputs`` becomes the mean label of
        each consecutive chunk of ``max_len_text`` samples.

        Args:
            replace: If true, ``self.docs`` is deleted once the arrays exist.
            with_generator: If true, return right after the vocabulary is built.
        """
        # Unknown tokens looked up later get the next free index on the fly.
        self.token2idx = collections.defaultdict(lambda: len(self.token2idx))
        kept = [token for token, cnt in self.cntr.items() if cnt > self.min_count]
        self.token2idx.update({token: i for i, token in enumerate(kept)})
        self.idx2token = {i: token for token, i in self.token2idx.items()}
        self.vocab_size = len(self.token2idx)

        if with_generator:
            return

        input_docs, input_tokens, outputs = [], [], []
        labels = [1] + [0] * self.ns_rate
        for doc_id, tokens in enumerate(self.docs):
            # Must be recomputed for EVERY document; previously this sat
            # under `if doc_id % 100 == 0` and stale ids were reused.
            token_ids = [self.token2idx[token] for token in tokens
                         if self.cntr[token] > self.min_count]
            for idx in token_ids:
                input_tokens.append(idx)
                # NOTE: index 0 is never drawn as a negative sample, and
                # randint raises ValueError when vocab_size < 2.
                input_tokens += [random.randint(1, self.vocab_size - 1)
                                 for _ in range(self.ns_rate)]
                input_docs += [doc_id] * (self.ns_rate + 1)
                outputs += labels

        chunk = self.max_len_text
        # Drop the leading remainder in one slice delete (repeated pop(0)
        # was quadratic in the stream length).
        extra = len(input_docs) % chunk
        if extra:
            del input_docs[:extra]
            del input_tokens[:extra]
            del outputs[:extra]

        # One target per chunk: the fraction of positive labels in it.
        self.outdata = [sum(outputs[start:start + chunk]) / chunk
                        for start in range(0, len(outputs), chunk)]

        self.input_docs = np.array(input_docs, dtype="int32")
        self.input_tokens = np.array(input_tokens, dtype="int32")
        self.outputs = np.array(self.outdata)

        if replace:
            del self.docs

    def count_cooccs(self, save_to=None, window=110):
        """Count word co-occurrences for PMI coherence evaluation.

        Each pair is stored once under its alphabetically sorted key order
        in ``self.cocntr[t1][t2]``.

        Args:
            save_to: Optional path; if given, ``[cntr, cocntr]`` is written
                there as JSON.
            window: A token is paired with the following ``window - 1``
                tokens of the same document.
        """
        self.cocntr = collections.defaultdict(lambda: collections.defaultdict(lambda: 0))
        for tokens in self.docs:
            for i, token1 in enumerate(tokens[:-1]):
                # Slicing clamps at the end of the list by itself.
                for token2 in tokens[i + 1:i + window]:
                    t1, t2 = sorted([token1, token2])
                    self.cocntr[t1][t2] += 1

        if save_to:
            with open(save_to, 'w') as fh:
                json.dump([self.cntr, self.cocntr], fh)

    def load_cooccs(self, filename):
        """Load word co-occurrence counts for PMI coherence evaluation.

        ``self.cocntr`` is replaced by the loaded mapping (plain dicts) and
        ``self.cntr`` is updated with the loaded token counts.
        """
        with open(filename) as fh:
            cntr, self.cocntr = json.load(fh)
        self.cntr.update(cntr)