from nltk.stem import PorterStemmer
import numpy as np
import pandas as pd
from nltk.corpus import stopwords
from nltk import download
from gensim.models import KeyedVectors

class word_mover_distance():
    '''
    Match source documents to their closest target documents by Word Mover's Distance.

    For every entry of ``source_names`` the ``topn`` nearest entries of
    ``target_names`` are found with ``model.wmdistance`` and reported as a
    pandas DataFrame (see ``match``).
    '''

    def __init__(self, source_names, target_names, topn, model=None):
        '''
        Parameters
        ----------
        source_names : sequence
            Documents to match; each is passed as-is to ``model.wmdistance``.
        target_names : sequence
            Candidate documents; each is passed as-is to ``model.wmdistance``.
        topn : int
            Number of candidates reported per source (must be >= 1).
        model : object, optional
            Anything exposing ``wmdistance(doc1, doc2)``, e.g. a gensim
            ``KeyedVectors``. When omitted, the binary word2vec file
            'BioWordVec_PubMed_MIMICIII_d200.vec.bin' is loaded from a path
            relative to the current working directory.

        Raises
        ------
        ValueError
            If ``topn`` is smaller than 1 (no result columns could be built).
        '''
        if topn < 1:
            raise ValueError(f'topn must be >= 1, got {topn!r}')
        if model is None:
            # Only load the (large) default embedding when none is supplied.
            model = KeyedVectors.load_word2vec_format('BioWordVec_PubMed_MIMICIII_d200.vec.bin', binary=True)
        self.source_names = source_names
        self.target_names = target_names
        # Assigned after the default load so the loaded model is not discarded.
        self.model = model
        self.topn = topn

    def match(self):
        '''
        Main match function: return the ``topn`` nearest targets for every source.

        Returns
        -------
        pandas.DataFrame
            One row per source with columns ``source``, ``prediction``,
            ``score`` and, for ``topn > 1``, ``prediction_i`` / ``score_i``
            for i = 2..topn. Candidates that do not exist (fewer targets
            than ``topn``) are ``None``.
        '''
        self.top_wmd_distance()
        return self._make_matchdf()


    # NOTE: disabled preprocessing; re-enabling it also requires
    # ``from string import punctuation`` at module level.
    # def clean_data(self):
    #     download('stopwords') 
    #     self.source_names = [self.preprocess(sent) for sent in self.source_names]
    #     self.target_names = [self.preprocess(sent) for sent in self.target_names]


    # def preprocess(self, sentence):
    #     stop_words = stopwords.words('english')

    #     sentence = sentence.translate(str.maketrans('', '', punctuation))

    #     return [w for w in sentence.lower().split() if w not in stop_words]


    def min_wmd_distance(self, input):
        '''
        Find the ``topn`` targets closest to one source document.

        Parameters
        ----------
        input :
            Source document, passed as-is to ``model.wmdistance``.

        Returns
        -------
        (list, list)
            ``targets``: the nearest target names, nearest first.
            ``scores``: the matching ``1 - distance`` values.
            Both lists are padded with ``None`` up to length ``topn`` when
            there are fewer targets than ``topn``.
        '''
        distances = np.array(
            [self.model.wmdistance(input, target) for target in self.target_names],
            dtype=float,
        )
        # Stable argsort: equal distances keep the original target order and
        # every target index appears at most once, so ties are not duplicated.
        nearest = np.argsort(distances, kind='stable')[:self.topn]

        targets = [self.target_names[idx] for idx in nearest]
        # Convert distance to score; negative whenever the distance exceeds 1.
        scores = [float(1 - distances[idx]) for idx in nearest]

        padding = self.topn - len(targets)
        targets.extend([None] * padding)
        scores.extend([None] * padding)
        return targets, scores


    def top_wmd_distance(self):
        '''
        Compute the nearest targets for every source and store them on the instance.

        Sets ``self.targets`` and ``self.top_scores``: lists with one entry
        per source, each being the corresponding list returned by
        ``min_wmd_distance``.
        '''
        # Plain lists instead of one numpy array: mixing str targets with float
        # scores in np.array coerced the scores to strings, and an empty
        # source list produced a 1-D array that could not be column-indexed.
        per_source = [self.min_wmd_distance(source) for source in self.source_names]
        self.targets = [targets for targets, _ in per_source]
        self.top_scores = [scores for _, scores in per_source]

    def _make_matchdf(self):
        ''' Build the result DataFrame from ``self.targets`` / ``self.top_scores``. '''
        match_list = []
        for source, targets, scores in zip(self.source_names, self.targets, self.top_scores):
            row = [source]
            # Interleave as prediction, score, prediction_2, score_2, ...
            for target, score in zip(targets, scores):
                row.extend((target, score))
            match_list.append(tuple(row))

        colnames = ['source', 'prediction', 'score']
        for i in range(2, self.topn + 1):
            colnames.extend((f'prediction_{i}', f'score_{i}'))

        return pd.DataFrame(match_list, columns=colnames)