# AUTOGENERATED! DO NOT EDIT! File to edit: nbs/0.5_mgmnt.prep.traceability.ipynb (unless otherwise specified).

# Public names exported by `from <this module> import *`: the three helpers defined below.
__all__ = ['get_gt_links', 'get_non_gt', 'add_wmd']

# Cell
# Imports
import pandas as pd
import random
import sentencepiece as sp

from fastprogress.fastprogress import master_bar
from functools import partial
from pathlib import Path
from tqdm.notebook import tqdm

# Cell
def get_gt_links(path, language):
    """Load the ground-truth traceability links stored in ``path``.

    Every ``*.txt`` file in ``path`` contributes links for one
    (system, source type, target type) triple, parsed from the file name:
    the extension and the first/last characters of the stem are dropped, the
    remainder is split on ``'-'`` and fields 0, 2 and 4 are used (a name such
    as ``(sysA--req--src).txt`` would yield ``sysA``, ``req``, ``src``).
    Each line of a file is a source file followed by its whitespace-separated
    target files; every (source, target) pair becomes one link.

    Args:
        path: ``pathlib.Path`` of the directory holding the link files.
        language: unused; kept for interface compatibility.

    Returns:
        DataFrame with columns ``sys``, ``src_type``, ``trgt_type``,
        ``src_file`` and ``trgt_file``, one row per distinct link and a fresh
        0..n-1 index. It is empty (with those columns) when no link is found.
    """
    columns = ['sys', 'src_type', 'trgt_type', 'src_file', 'trgt_file']
    frames = []
    # Sorted so the output does not depend on directory-listing order.
    for fn in sorted(path.glob('*.txt')):
        # Drop the extension and the first/last character of the stem, then
        # take fields 0, 2 and 4 of the '-'-separated remainder.
        content = fn.name.split('.')[0][1:-1].split('-')
        sys_name, src_type, trgt_type = content[0], content[2], content[4]

        with open(fn) as f:
            # NOTE(review): the last line of every file is discarded, presumably
            # a trailer in this dataset's format. Confirm this; otherwise the
            # final link line of each file is lost.
            lines = f.readlines()[:-1]

        links = set()
        for line in lines:
            # split() (rather than split(' ')) removes the trailing newline and
            # empty tokens *before* de-duplication. "A B\n" and "A B" must
            # count as the same link, not as two rows that collapse later.
            tokens = line.split()
            if not tokens:
                continue
            src, trgts = tokens[0], tokens[1:]
            links.update((sys_name, src_type, trgt_type, src, trgt) for trgt in trgts)

        if links:
            # sorted(): deterministic row order instead of set/hash order.
            frames.append(pd.DataFrame(sorted(links), columns=columns))

    if not frames:
        return pd.DataFrame([], columns=columns)

    # Concatenate once (not inside the loop) with a unique index.
    links_df = pd.concat(frames, ignore_index=True)
    # Keep one copy of each duplicated link; keep=False would delete them all.
    return links_df.drop_duplicates().reset_index(drop=True)

# Cell
def get_non_gt(path, language, gt, n = 1, random_state = None):
    """Sample candidate non-links (negative examples) from the ground truth.

    The source side and the target side of ``gt`` are shuffled independently
    (each via ``DataFrame.sample(frac=n)``) and then paired position by
    position. A pair is kept only when all of these hold:

    * the two halves come from different ``gt`` row labels;
    * both halves belong to the same system;
    * the resulting link is not itself a ground-truth link.

    Args:
        path: unused; kept for interface compatibility.
        language: unused; kept for interface compatibility.
        gt: DataFrame of ground-truth links with at least the columns
            ``sys``, ``src_type``, ``trgt_type``, ``src_file`` and
            ``trgt_file`` (as produced by ``get_gt_links``).
        n: fraction of ``gt`` rows sampled on each side (``frac`` of
            ``DataFrame.sample``); 1 shuffles every row.
        random_state: optional int seed, or numpy ``RandomState``/``Generator``,
            for reproducible sampling. ``None`` (the default) keeps using
            numpy's global random state, as before.

    Returns:
        DataFrame with the same five columns, one row per distinct non-link.
    """
    import numpy as np

    columns = ['sys', 'src_type', 'trgt_type', 'src_file', 'trgt_file']
    # Hash the ground truth once: each candidate is then an O(1) lookup.
    # The previous concat + drop_duplicates(keep=False) test was O(len(gt))
    # per candidate and rejected every candidate if gt held a duplicate row.
    gt_links = set(gt[columns].itertuples(index=False, name=None))

    if isinstance(random_state, (int, np.integer)):
        # One shared generator, so the two shuffles below differ. Seeding both
        # calls with the same int would give identical permutations, and every
        # pair would then be skipped as a self-pair.
        random_state = np.random.RandomState(random_state)

    # Sample from the dataframe to introduce randomness and reduce number of links
    srcs = gt[['sys', 'src_type', 'src_file']].sample(frac=n, random_state=random_state)
    trgts = gt[['sys', 'trgt_type', 'trgt_file']].sample(frac=n, random_state=random_state)

    # dict keys: de-duplicate while preserving first-seen order.
    non_links = {}
    for (src_id, src_row), (trgt_id, trgt_row) in zip(srcs.iterrows(), trgts.iterrows()):
        # Row *labels* are compared, so with a non-unique gt index this also
        # skips pairs of different rows that happen to share a label.
        if src_id == trgt_id or src_row['sys'] != trgt_row['sys']:
            continue
        link = (
            src_row['sys'], src_row['src_type'], trgt_row['trgt_type'],
            src_row['src_file'], trgt_row['trgt_file'],
        )
        if link not in gt_links:
            non_links[link] = None

    return pd.DataFrame(list(non_links), columns=columns)

# Cell
def add_wmd(links_df, wmd_df):
    """Attach the ``wmd`` score (presumably Word Mover's Distance) to links.

    Each row of ``wmd_df`` is matched to every row of ``links_df`` with the
    same ``src_file`` and ``trgt_file``. One output row is produced per match,
    carrying that ``wmd`` value. Links without a score and scores without a
    link are dropped (inner join).

    Args:
        links_df: DataFrame of links with at least ``src_file`` and
            ``trgt_file`` columns (typically ``sys``, ``src_type``,
            ``trgt_type``, ``src_file``, ``trgt_file``).
        wmd_df: DataFrame with ``src_file``, ``trgt_file`` and ``wmd`` columns.

    Returns:
        DataFrame with columns ``sys``, ``src_type``, ``trgt_type``,
        ``src_file``, ``trgt_file``, ``wmd`` followed by any extra columns of
        ``links_df``. It is sorted by ``src_file`` then ``trgt_file`` (stable)
        and has a fresh 0..n-1 index.
    """
    keys = ['src_file', 'trgt_file']
    out_cols = ['sys', 'src_type', 'trgt_type', 'src_file', 'trgt_file', 'wmd']

    scores = wmd_df[keys + ['wmd']]
    # A pre-existing 'wmd' on the links would clash in the merge (wmd_x/wmd_y);
    # the score from wmd_df always wins, as it did before.
    links = links_df.drop(columns='wmd', errors='ignore')
    # Inner merge replaces the per-row mask + concat loop. The loop was
    # O(len(links) * len(wmd)), and its `new_row['wmd'] = [row.wmd]` raised
    # ValueError when one file pair matched more than one link.
    merged = links.merge(scores, on=keys, how='inner')

    # Same column layout as before: the six standard columns first, then extras.
    extras = [c for c in merged.columns if c not in out_cols]
    merged = merged.reindex(columns=out_cols + extras)

    # mergesort is stable, so ties keep a deterministic order.
    return merged.sort_values(keys, kind='mergesort').reset_index(drop=True)