# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbdev_nbs/03_protein_intensity_estimation.ipynb.

# %% auto 0
__all__ = ['estimate_protein_intensities', 'get_list_of_tuple_w_protein_profiles_and_shifted_peptides',
           'get_list_with_sequential_processing', 'get_list_with_multiprocessing',
           'get_configured_multiprocessing_pool',
           'get_input_specification_tuplelist_idx__df__num_samples_quadratic__min_nonan', 'get_normed_dfs',
           'get_ion_intensity_dataframe_from_list_of_shifted_peptides', 'add_protein_names_to_ion_ints',
           'add_protein_name_to_ion_df', 'get_protein_dataframe_from_list_of_protein_profiles',
           'calculate_peptide_and_protein_intensities', 'get_protein_profile_from_shifted_peptides',
           'get_list_with_protein_value_for_each_sample', 'ProtvalCutter', 'OrphanIonRemover',
           'OrphanIonsForDeletionSelector', 'IonCheckedForOrphan']

# %% ../nbdev_nbs/03_protein_intensity_estimation.ipynb 3
import pandas as pd
import numpy as np
import directlfq.normalization as lfqnorm
import multiprocess
import itertools

def estimate_protein_intensities(normed_df, min_nonan, num_samples_quadratic, num_cores):
    """Derive protein pseudointensities from between-sample normalized data.

    Args:
        normed_df: normalized intensity table. The unique values of index
            level 0 are used as protein identifiers and the columns are
            reused as the columns of the protein table.
        min_nonan: minimum number of non-NaN ion values a sample must have
            for a protein value to be reported in that sample.
        num_samples_quadratic: forwarded unchanged to the per-protein
            normalization step.
        num_cores: values <= 1 select sequential processing; any other value
            (including ``None``) selects the multiprocessing path.

    Returns:
        Tuple ``(protein_df, ion_df)`` with the protein-level table and the
        ion-level table, both converted back from log2 to linear scale.
    """
    proteins = list(normed_df.index.get_level_values(0).unique())
    print(f"{len(proteins)} prots total")

    # One (protein_profile, shifted_ions) tuple per protein, in the order of `proteins`.
    profiles_and_shifted_ions = get_list_of_tuple_w_protein_profiles_and_shifted_peptides(
        proteins, normed_df, num_samples_quadratic, min_nonan, num_cores
    )

    protein_df = get_protein_dataframe_from_list_of_protein_profiles(
        allprots=proteins,
        list_of_tuple_w_protein_profiles_and_shifted_peptides=profiles_and_shifted_ions,
        normed_df=normed_df,
    )
    ion_df = get_ion_intensity_dataframe_from_list_of_shifted_peptides(profiles_and_shifted_ions, proteins)

    return protein_df, ion_df


def get_list_of_tuple_w_protein_profiles_and_shifted_peptides(allprots, normed_df, num_samples_quadratic, min_nonan, num_cores):
    """Compute a ``(protein_profile, shifted_ions)`` tuple for every protein.

    Dispatches to the sequential implementation when ``num_cores`` is given
    and is <= 1; otherwise (including ``num_cores is None``) a process pool
    is used.
    """
    run_sequentially = num_cores is not None and num_cores <= 1
    if run_sequentially:
        return get_list_with_sequential_processing(allprots, normed_df, num_samples_quadratic, min_nonan)
    return get_list_with_multiprocessing(allprots, normed_df, num_samples_quadratic, min_nonan, num_cores)

def get_list_with_sequential_processing(allprots, normed_df, num_samples_quadratic, min_nonan):
    """Process all proteins one after another in the current process.

    Returns:
        List with one ``(protein_profile, shifted_ions)`` tuple per protein,
        in the order of ``allprots``.
    """
    per_protein_arguments = get_input_specification_tuplelist_idx__df__num_samples_quadratic__min_nonan(
        normed_df, allprots, num_samples_quadratic, min_nonan
    )
    # Each entry is (idx, df, num_samples_quadratic, min_nonan) and is unpacked positionally.
    return [calculate_peptide_and_protein_intensities(*arguments) for arguments in per_protein_arguments]
    
def get_list_with_multiprocessing(allprots, normed_df, num_samples_quadratic, min_nonan, num_cores):
    """Process all proteins in parallel with a process pool.

    Args:
        allprots: protein identifiers, defining the order of the result.
        normed_df: normalized intensity table the per-protein tables are cut from.
        num_samples_quadratic: forwarded to the per-protein calculation.
        min_nonan: forwarded to the per-protein calculation.
        num_cores: forwarded to ``get_configured_multiprocessing_pool``.

    Returns:
        List with one ``(protein_profile, shifted_ions)`` tuple per protein,
        in the order of ``allprots``.
    """
    pool = get_configured_multiprocessing_pool(num_cores)
    try:
        input_specification_tuplelist_idx__df__num_samples_quadratic__min_nonan = get_input_specification_tuplelist_idx__df__num_samples_quadratic__min_nonan(normed_df, allprots, num_samples_quadratic, min_nonan)
        return pool.starmap(calculate_peptide_and_protein_intensities, input_specification_tuplelist_idx__df__num_samples_quadratic__min_nonan)
    finally:
        # Release the worker processes even if building the inputs or a worker
        # raised; join() waits for the workers to exit so none are left behind.
        pool.close()
        pool.join()


def get_configured_multiprocessing_pool(num_cores):
    """Create the ``multiprocess.Pool`` used for the per-protein calculations.

    Args:
        num_cores: number of worker processes. ``None`` means: use the CPU
            count reported by ``multiprocess``, capped at 60.

    Returns:
        The newly created pool. The caller is responsible for closing it.
    """
    multiprocess.freeze_support()
    if num_cores is None:
        # windows upper thread limit
        num_cores = min(multiprocess.cpu_count(), 60)
    worker_pool = multiprocess.Pool(num_cores)
    print(f"using {worker_pool._processes} processes")
    return worker_pool


def get_input_specification_tuplelist_idx__df__num_samples_quadratic__min_nonan(normed_df, allprots, num_samples_quadratic, min_nonan):
    """Build the argument tuples for ``calculate_peptide_and_protein_intensities``.

    Returns:
        A lazy ``zip`` iterator (single use) yielding
        ``(idx, protein_df, num_samples_quadratic, min_nonan)`` for every
        protein, where ``idx`` is the running position starting at 0.
    """
    per_protein_dfs = get_normed_dfs(normed_df, allprots)
    # count() and repeat() are infinite; zip stops when the list of tables is exhausted.
    return zip(
        itertools.count(),
        per_protein_dfs,
        itertools.repeat(num_samples_quadratic),
        itertools.repeat(min_nonan),
    )




def get_normed_dfs(normed_df, allprots):
    """Cut the normalized table into one ion table per protein.

    Tables with more than one row are additionally capped to at most 100 rows
    by ``ProtvalCutter`` and stripped of orphan ions by ``OrphanIonRemover``.
    Tables with a single row are passed through untouched.

    Returns:
        List of DataFrames, one per entry of ``allprots`` and in the same order.
    """
    def _ion_table_for(protein):
        # DataFrame definition to avoid pandas Series objects
        ion_table = pd.DataFrame(normed_df.loc[protein])
        if len(ion_table.index) <= 1:
            return ion_table
        ion_table = ProtvalCutter(ion_table, maximum_df_length=100).get_dataframe()
        return OrphanIonRemover(ion_table).orphan_removed_df

    return [_ion_table_for(protein) for protein in allprots]


def get_ion_intensity_dataframe_from_list_of_shifted_peptides(list_of_tuple_w_protein_profiles_and_shifted_peptides, allprots):
    """Assemble the ion-level intensity table from the per-protein results.

    The second element of every result tuple (the shifted ion table) is
    labelled with its protein, all tables are concatenated, converted from
    log2 to linear scale, and NaN values are replaced by 0.
    """
    shifted_ion_tables = [result[1] for result in list_of_tuple_w_protein_profiles_and_shifted_peptides]
    labelled_ion_tables = add_protein_names_to_ion_ints(shifted_ion_tables, allprots)
    stacked_log2 = pd.concat(labelled_ion_tables)
    linear = 2**stacked_log2
    return linear.replace(np.nan, 0)

def add_protein_names_to_ion_ints(ion_ints, allprots):
    """Label every ion table with the protein found at the same position in ``allprots``.

    Returns:
        New list of ion tables indexed by ``("protein", "ion")``.
    """
    return [
        add_protein_name_to_ion_df(ion_table, allprots[position])
        for position, ion_table in enumerate(ion_ints)
    ]

def add_protein_name_to_ion_df(ion_df, protein):
    """Return a copy of ``ion_df`` indexed by ``("protein", "ion")``.

    Args:
        ion_df: ion table whose index is named ``"ion"`` (the index is turned
            into a column by ``reset_index`` and then looked up by that name).
        protein: protein identifier written into the new ``"protein"`` level.

    Returns:
        New DataFrame with a two-level index ``("protein", "ion")``. The
        DataFrame passed in is left unmodified.
    """
    # assign() works on a copy, so the caller's table (which may be a slice of
    # the normalized input) does not silently gain a "protein" column.
    labelled = ion_df.assign(protein=protein)
    return labelled.reset_index().set_index(["protein", "ion"])


def get_protein_dataframe_from_list_of_protein_profiles(allprots, list_of_tuple_w_protein_profiles_and_shifted_peptides, normed_df):
    """Assemble the protein-level intensity table from the per-protein results.

    Proteins whose profile is ``None`` are left out. The remaining log2
    profiles are converted to linear scale and NaN values are replaced by 0.

    Args:
        allprots: protein identifiers, aligned by position with the results.
        list_of_tuple_w_protein_profiles_and_shifted_peptides: per-protein
            result tuples; only the first element (the profile) is used here.
        normed_df: only its ``columns`` are used, as the columns of the output.

    Returns:
        DataFrame with a ``"protein"`` column followed by one column per
        column of ``normed_df``.
    """
    profiles = [result[0] for result in list_of_tuple_w_protein_profiles_and_shifted_peptides]

    # Keep (protein, profile) pairs only where a profile could be derived.
    retained = [
        (protein, profiles[position])
        for position, protein in enumerate(allprots)
        if profiles[position] is not None
    ]
    retained_proteins = [protein for protein, _ in retained]
    retained_profiles = [profile for _, profile in retained]

    protein_index = pd.Index(data=retained_proteins, name="protein")
    log2_table = pd.DataFrame(retained_profiles, index = protein_index, columns = normed_df.columns)
    linear_table = 2**log2_table
    return linear_table.replace(np.nan, 0).reset_index()


def calculate_peptide_and_protein_intensities(idx, peptide_intensity_df, num_samples_quadratic, min_nonan):
    """Derive the protein profile and the shifted ion table for one protein.

    Args:
        idx: running number of the protein; only used for progress output
            (every 100th protein is printed).
        peptide_intensity_df: log2 ion intensities of the protein.
        num_samples_quadratic: forwarded to ``NormalizationManagerProtein``.
        min_nonan: forwarded to the protein profile calculation.

    Returns:
        Tuple ``(protein_profile, shifted_peptides)``. ``protein_profile`` is
        whatever ``get_protein_profile_from_shifted_peptides`` returns for
        the shifted table.
    """
    if idx % 100 == 0:
        print(f"prot {idx}")

    # Total linear-scale intensity of the input, taken before any shifting.
    summed_pepint = np.nansum(2**peptide_intensity_df)

    has_at_least_two_columns = peptide_intensity_df.shape[1] >= 2
    if has_at_least_two_columns:
        normalizer = lfqnorm.NormalizationManagerProtein(peptide_intensity_df, num_samples_quadratic = num_samples_quadratic)
        shifted_peptides = normalizer.complete_dataframe
    else:
        # With fewer than two columns the table is used as it is.
        shifted_peptides = peptide_intensity_df

    protein_profile = get_protein_profile_from_shifted_peptides(shifted_peptides, summed_pepint, min_nonan)
    return protein_profile, shifted_peptides


def get_protein_profile_from_shifted_peptides(normalized_peptide_profile_df, summed_pepints, min_nonan):
    """Compute the log2 protein profile and rescale it to the summed ion intensity.

    The per-sample protein values are shifted by a common log2 offset so that
    their linear-scale sum equals ``summed_pepints``.

    Returns:
        NumPy array with one log2 value per sample (NaN where no value could
        be derived), or ``None`` if no sample has a value at all.
    """
    log2_profile = np.array(get_list_with_protein_value_for_each_sample(normalized_peptide_profile_df, min_nonan))
    linear_total = np.nansum(2**log2_profile)
    if linear_total == 0:
        # this means all elements of the profile are nans
        return None
    log2_offset = np.log2(summed_pepints/linear_total)
    return log2_profile + log2_offset

def get_list_with_protein_value_for_each_sample(normalized_peptide_profile_df, min_nonan):
    """Return one protein value per sample (column).

    Args:
        normalized_peptide_profile_df: ion table with one column per sample.
        min_nonan: minimum number of non-NaN values a column needs in order
            to get a protein value.

    Returns:
        List with one entry per column, in column order: the NaN-ignoring
        median of the column if it has at least ``min_nonan`` non-NaN
        values, otherwise ``np.nan``.
    """
    def _value_for_sample(sample):
        values = normalized_peptide_profile_df.loc[:, sample].to_numpy()
        # count_nonzero counts inside NumPy instead of iterating over the
        # boolean array element by element with the builtin sum().
        num_valid = np.count_nonzero(~np.isnan(values))
        if num_valid >= min_nonan:
            return np.nanmedian(values)
        return np.nan

    return [_value_for_sample(sample) for sample in normalized_peptide_profile_df.columns]


# %% ../nbdev_nbs/03_protein_intensity_estimation.ipynb 5
import pandas as pd
from numba import njit

class ProtvalCutter():
    """Cap the number of rows of a per-protein intensity table.

    If the table has more than ``maximum_df_length`` rows, the row labels are
    ordered by ascending number of NaN values per row (ties keep their
    original order because ``sorted`` is stable) and only the first
    ``maximum_df_length`` labels are retained. Shorter tables are handed
    back unchanged.
    """
    def __init__(self, protvals_df, maximum_df_length = 100):
        self._protvals_df = protvals_df
        self._maximum_df_length = maximum_df_length
        self._dataframe_too_long = None  # set to a bool by the check below
        self._sorted_idx = None  # only filled when the table is too long
        self._check_if_df_too_long_and_sort_index_if_so()

    def _check_if_df_too_long_and_sort_index_if_so(self):
        row_count = len(self._protvals_df.index)
        self._dataframe_too_long = row_count > self._maximum_df_length
        if not self._dataframe_too_long:
            return
        self._determine_nansorted_df_index()

    def _determine_nansorted_df_index(self):
        table = self._protvals_df

        def nan_count_of(row_label):
            return self._get_num_nas_in_row(table.loc[row_label].to_numpy())

        # Row labels ordered from fewest to most missing values.
        self._sorted_idx = sorted(table.index, key=nan_count_of)

    @staticmethod
    @njit
    def _get_num_nas_in_row(row):
        # Compiled with numba; counts the NaN entries of one row.
        return sum(np.isnan(row))

    def get_dataframe(self):
        """Return the (possibly shortened) table."""
        if not self._dataframe_too_long:
            return self._protvals_df
        return self._get_shortened_dataframe()

    def _get_shortened_dataframe(self):
        retained_labels = self._sorted_idx[:self._maximum_df_length]
        return self._protvals_df.loc[retained_labels]


# %% ../nbdev_nbs/03_protein_intensity_estimation.ipynb 6
import numpy as np
import pandas as pd

class OrphanIonRemover(): #removes ions that do not have any overlap with any of the other ions
    """Drop orphan ions from a per-protein intensity table.

    Every ion (row) is classified by ``IonCheckedForOrphan``; which of the
    orphan ions are actually dropped is decided by
    ``OrphanIonsForDeletionSelector``. The result is available as
    ``orphan_removed_df``; the input table itself is not modified.
    """
    def __init__(self, protvals_df : pd.DataFrame):
        self._protvals_df = protvals_df

        # Boolean table (True where a value is present) and its column sums.
        self._provals_is_not_na_df = None
        self._count_of_nonans_per_position = None

        self._orphan_ions = []
        self._non_orphan_ions = []

        self.orphan_removed_df = None

        # The steps below depend on each other and must run in this order.
        self._define_protvals_is_not_na_df()
        self._define_count_of_nonans_per_position()
        self._define_orphan_ions_and_non_orphan_ions()
        self._define_orphan_removed_df()

    def _define_protvals_is_not_na_df(self):
        self._provals_is_not_na_df = self._protvals_df.notna()

    def _define_count_of_nonans_per_position(self):
        # Number of ions with a value, per column.
        self._count_of_nonans_per_position = self._provals_is_not_na_df.sum(axis=0)

    def _define_orphan_ions_and_non_orphan_ions(self):
        presence_table = self._provals_is_not_na_df
        for ion in presence_table.index:
            presence_of_ion = presence_table.loc[ion].to_numpy()
            checked_ion = IonCheckedForOrphan(ion, self._count_of_nonans_per_position, presence_of_ion)
            self._append_to_orphan_or_non_orphan_list(checked_ion)

    def _append_to_orphan_or_non_orphan_list(self, orphan_checked_ion):
        if orphan_checked_ion.is_orphan:
            destination = self._orphan_ions
        else:
            destination = self._non_orphan_ions
        destination.append(orphan_checked_ion)

    def _define_orphan_removed_df(self):
        selector = OrphanIonsForDeletionSelector(self._orphan_ions, self._non_orphan_ions)
        # The selected labels are passed straight to DataFrame.drop.
        self.orphan_removed_df = self._protvals_df.drop(selector.ion_accessions_for_deletion, axis='index')



class OrphanIonsForDeletionSelector():
    """Decide which orphan ions are to be deleted.

    Rules:
      * at least one non-orphan ion exists: every orphan ion is deleted;
      * only orphan ions exist (two or more): the orphan with the largest
        ``num_nonans`` is kept, all others are deleted;
      * otherwise (no non-orphans and at most one orphan): nothing is deleted.

    The result is exposed as ``ion_accessions_for_deletion`` and is always a
    list, so it can be handed to ``DataFrame.drop`` without a ``None`` check.

    Args:
        orphan_ions: ion objects providing ``ion_accession`` and
            ``num_nonans``. NOTE: this list is sorted in place when only
            orphan ions exist.
        non_orphan_ions: ion objects classified as non-orphan; only the
            length of this list is used.
    """
    def __init__(self, orphan_ions : list, non_orphan_ions : list):
        self._orphan_ions = orphan_ions
        self._non_orphan_ions = non_orphan_ions

        # Default to "delete nothing". Leaving this as None made
        # DataFrame.drop(None, axis='index') raise a ValueError for tables
        # with a single ion.
        self.ion_accessions_for_deletion = []

        self._define_orphan_ions_for_deletion()

    def _define_orphan_ions_for_deletion(self):
        if len(self._non_orphan_ions)>0:
            self.ion_accessions_for_deletion = self._get_accessions_of_list_of_ions(self._orphan_ions)
        elif len(self._orphan_ions)>1:
            self._sort_list_of_ions_by_num_nonans_descending(self._orphan_ions)
            # The first entry after sorting is the ion that is kept.
            orphan_ions_to_delete = self._orphan_ions[1:]
            self.ion_accessions_for_deletion = self._get_accessions_of_list_of_ions(orphan_ions_to_delete)

    def _get_accessions_of_list_of_ions(self, ions_checked_for_orphan : list):
        return [ion_checked_for_orphan.ion_accession for ion_checked_for_orphan in ions_checked_for_orphan]

    def _sort_list_of_ions_by_num_nonans_descending(self, ions : list):
        # In-place and stable: ions with equal num_nonans keep their relative order.
        ions.sort(key=lambda x: x.num_nonans, reverse=True)
    




class IonCheckedForOrphan():
    """Classify a single ion as orphan or non-orphan.

    Args:
        ion_accession: label of the ion.
        count_of_nonans_per_position: per position (column), the number of
            ions that have a value there.
        is_nonan_per_position_for_ion: boolean mask, True at the positions
            where this ion has a value.

    Attributes:
        is_orphan: True if the largest count among the positions covered by
            this ion equals 1, i.e. no position is shared with another ion.
        num_nonans: sum of the per-position counts over the positions covered
            by this ion.
    """
    def __init__(self, ion_accession, count_of_nonans_per_position : np.array, is_nonan_per_position_for_ion : np.array):
        self.ion_accession = ion_accession

        self._count_of_nonans_per_position = count_of_nonans_per_position
        self._is_nonan_per_position_for_ion = is_nonan_per_position_for_ion

        # Counts restricted to the positions covered by this ion.
        self._count_of_nonans_per_position_for_ion = None

        self.is_orphan = None
        self.num_nonans = None

        # Order matters: both results are derived from the restricted counts.
        self._define_count_of_nonans_per_position_for_ion()
        self._check_if_is_orphan()
        self._define_num_nonans()

    def _define_count_of_nonans_per_position_for_ion(self):
        covered_positions = self._is_nonan_per_position_for_ion
        self._count_of_nonans_per_position_for_ion = self._count_of_nonans_per_position[covered_positions]

    def _check_if_is_orphan(self):
        highest_count = np.max(self._count_of_nonans_per_position_for_ion)
        self.is_orphan = highest_count == 1

    def _define_num_nonans(self):
        self.num_nonans = np.sum(self._count_of_nonans_per_position_for_ion)
