###############################################################################
#                                                                             #
#    This program is free software: you can redistribute it and/or modify     #
#    it under the terms of the GNU General Public License as published by     #
#    the Free Software Foundation, either version 3 of the License, or        #
#    (at your option) any later version.                                      #
#                                                                             #
#    This program is distributed in the hope that it will be useful,          #
#    but WITHOUT ANY WARRANTY; without even the implied warranty of           #
#    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the            #
#    GNU General Public License for more details.                             #
#                                                                             #
#    You should have received a copy of the GNU General Public License        #
#    along with this program. If not, see <http://www.gnu.org/licenses/>.     #
#                                                                             #
###############################################################################

import os
import sys
import logging
import operator
import shutil
import tempfile
import ntpath
import pickle
from itertools import combinations
from collections import defaultdict, namedtuple

from biolib.taxonomy import Taxonomy
from biolib.external.execute import check_dependencies

from numpy import (mean as np_mean)

from genometreetk.common import (parse_genome_path,
                                    binomial_species,
                                    canonical_species_name,
                                    read_gtdb_metadata,
                                    read_gtdb_taxonomy,
                                    read_gtdb_ncbi_taxonomy)
                                    
from genometreetk.type_genome_utils import (GenomeRadius,
                                            read_qc_file,
                                            symmetric_ani,
                                            write_clusters,
                                            write_type_radius)
                                    
from genometreetk.ani_cache import ANI_Cache
from genometreetk.mash import Mash

class ClusterNamedTypes(object):
    """Cluster genomes to selected GTDB type genomes."""

    def __init__(self, ani_sp, af_sp, ani_cache_file, cpus, output_dir):
        """Store clustering thresholds and create the logger and ANI cache.

        ani_sp and af_sp are kept as the species-level ANI and alignment
        fraction thresholds read by the other methods. cpus and output_dir
        are stored unchanged. ani_cache_file and cpus are passed on to
        ANI_Cache.
        """

        # called before any attribute is set, as in the original ordering
        check_dependencies(['fastANI', 'mash'])

        self.logger = logging.getLogger('timestamp')

        self.output_dir = output_dir
        self.cpus = cpus

        # thresholds supplied by the caller
        self.af_sp = af_sp
        self.ani_sp = ani_sp

        # fixed thresholds: the Mash pre-filter level and the ANI above
        # which a neighbouring type genome is reported as an error
        self.min_mash_ani = 90.0
        self.max_ani_neighbour = 97.0

        # record type for a genome placed in a cluster
        self.ClusteredGenome = namedtuple('ClusteredGenome', 'ani af gid')

        self.ani_cache = ANI_Cache(ani_cache_file, cpus)
        
    def _type_genome_radius(self, type_gids, type_genome_ani_file):
        """Calculate circumscription radius for type genomes.

        Every genome in type_gids starts with a radius of self.ani_sp and no
        neighbour. Each row of the ANI file is then considered: when both
        genomes are in type_gids, the ANI exceeds the current radius of
        'Type genome 1', and the AF is at least self.af_sp, the radius of
        'Type genome 1' is raised to that ANI and 'Type genome 2' is recorded
        as its neighbour.

        Parameters
        ----------
        type_gids : collection of str
            Identifiers of the type genomes; may be empty.
        type_genome_ani_file : str
            Tab-separated file whose header contains the columns
            'Type genome 1', 'Type genome 2', 'ANI' and 'AF'.

        Returns
        -------
        dict
            Maps each identifier in type_gids to a GenomeRadius.

        Raises
        ------
        ValueError
            If a required column is missing from the header, or an ANI/AF
            field cannot be converted to a float.
        """

        # set type radius for all type genomes to default values
        type_radius = {gid: GenomeRadius(ani=self.ani_sp,
                                         af=None,
                                         neighbour_gid=None)
                       for gid in type_gids}

        # determine closest ANI neighbour and restrict ANI radius as necessary
        # NOTE(review): only the radius of 'Type genome 1' is updated per row;
        # this presumably relies on the file listing each pair in both
        # directions - confirm against the code producing the file.
        with open(type_genome_ani_file) as f:
            header = f.readline().strip().split('\t')

            type_gid1_index = header.index('Type genome 1')
            type_gid2_index = header.index('Type genome 2')
            ani_index = header.index('ANI')
            af_index = header.index('AF')

            for line in f:
                # tolerate blank lines (e.g. trailing newline at end of file)
                if not line.strip():
                    continue

                line_split = line.strip().split('\t')

                type_gid1 = line_split[type_gid1_index]
                type_gid2 = line_split[type_gid2_index]

                if type_gid1 not in type_gids or type_gid2 not in type_gids:
                    continue

                ani = float(line_split[ani_index])
                af = float(line_split[af_index])

                # only pairs closer than the current radius can tighten it
                if ani <= type_radius[type_gid1].ani:
                    continue

                if af < self.af_sp:
                    # insufficient alignment fraction: never used as a
                    # neighbour, but worth flagging when the ANI alone
                    # would place the genomes in the same species
                    if ani >= self.ani_sp:
                        self.logger.warning('ANI for %s and %s is >%.2f, but AF <%.2f [pair skipped].' % (
                                                type_gid1,
                                                type_gid2,
                                                ani, af))
                    continue

                if ani > self.max_ani_neighbour:
                    self.logger.error('ANI neighbour %s is >%.2f for %s.' % (type_gid2, ani, type_gid1))

                type_radius[type_gid1] = GenomeRadius(ani=ani,
                                                      af=af,
                                                      neighbour_gid=type_gid2)

        if type_radius:
            ani_radii = [d.ani for d in type_radius.values()]
            self.logger.info('ANI circumscription radius: min=%.2f, mean=%.2f, max=%.2f' % (
                                    min(ani_radii),
                                    np_mean(ani_radii),
                                    max(ani_radii)))
        else:
            # min()/max() raise ValueError on an empty sequence
            self.logger.warning('No type genomes specified; ANI circumscription radius statistics not reported.')

        return type_radius
        
    def _calculate_ani(self, type_gids, genome_files, ncbi_taxonomy, type_genome_sketch_file):
        """Calculate ANI between type and non-type genomes.

        Mash sketches are built for the type genomes (unless an existing
        sketch file is supplied) and for the non-type genomes, and Mash
        distances between the two sketches are computed. Every pair of
        distinct genomes with a Mash ANI >= self.min_mash_ani is then passed,
        in both directions, to the ANI cache's fastani_pairs(). The result is
        also pickled to 'ani_af_type_vs_nontype.pkl' in the output directory.

        Parameters
        ----------
        type_gids : collection of str
            Identifiers of the type genomes.
        genome_files : dict
            Keyed by genome identifier; passed through to Mash and the ANI
            cache (presumably maps to FASTA paths - confirm with caller).
        ncbi_taxonomy : object
            Not used by this method; retained for interface compatibility.
        type_genome_sketch_file : str or None
            Existing Mash sketch of the type genomes. A new sketch is written
            to the output directory when this is empty or does not exist.

        Returns
        -------
        object
            Whatever self.ani_cache.fastani_pairs() returns for the pairs.
        """

        mash = Mash(self.cpus)

        # create Mash sketch for type genomes
        if not type_genome_sketch_file or not os.path.exists(type_genome_sketch_file):
            type_genome_list_file = os.path.join(self.output_dir, 'gtdb_type_genomes.lst')
            type_genome_sketch_file = os.path.join(self.output_dir, 'gtdb_type_genomes.msh')
            mash.sketch(type_gids, genome_files, type_genome_list_file, type_genome_sketch_file)

        # create Mash sketch for non-type genomes
        nontype_gids = {gid for gid in genome_files if gid not in type_gids}

        nontype_genome_list_file = os.path.join(self.output_dir, 'gtdb_nontype_genomes.lst')
        nontype_genome_sketch_file = os.path.join(self.output_dir, 'gtdb_nontype_genomes.msh')
        mash.sketch(nontype_gids, genome_files, nontype_genome_list_file, nontype_genome_sketch_file)

        # get Mash distances; the ANI threshold is converted to a distance
        mash_dist_file = os.path.join(self.output_dir, 'gtdb_type_vs_nontype_genomes.dst')
        mash.dist(float(100 - self.min_mash_ani)/100,
                                type_genome_sketch_file,
                                nontype_genome_sketch_file,
                                mash_dist_file)

        # read Mash distances
        mash_ani = mash.read_ani(mash_dist_file)

        # get pairs above Mash threshold, recording both directions
        mash_ani_pairs = []
        for qid in mash_ani:
            for rid in mash_ani[qid]:
                if mash_ani[qid][rid] >= self.min_mash_ani and qid != rid:
                    mash_ani_pairs.append((qid, rid))
                    mash_ani_pairs.append((rid, qid))

        self.logger.info('Identified %d genome pairs with a Mash ANI >= %.1f%%.' % (len(mash_ani_pairs), self.min_mash_ani))

        # calculate ANI between pairs
        self.logger.info('Calculating ANI between %d genome pairs:' % len(mash_ani_pairs))
        ani_af = self.ani_cache.fastani_pairs(mash_ani_pairs, genome_files)

        # keep a copy of the results on disk; the context manager ensures
        # the file is flushed and closed
        ani_af_file = os.path.join(self.output_dir, 'ani_af_type_vs_nontype.pkl')
        with open(ani_af_file, 'wb') as f:
            pickle.dump(ani_af, f)

        return ani_af

    def _cluster(self, ani_af, nontype_gids, type_radius):
        """Cluster non-type genomes to type genomes using species specific ANI thresholds.

        For each non-type genome, the type genome with the highest ANI (ties
        broken by higher AF) among those with AF >= self.af_sp is identified.
        The genome is assigned to that type genome only if the ANI exceeds
        the type genome's ANI radius. Identifiers of genomes falling within
        the radius of more than one type genome are written to
        'multi_reps_gids.lst', and identifiers of genomes outside the radius
        of their closest type genome but inside the radius of another are
        written to 'transitive_nontype_gids.lst', both in the output
        directory.

        Parameters
        ----------
        ani_af : dict
            Nested mapping tested as ani_af[nontype_gid][type_gid] and passed
            to symmetric_ani() to obtain an (ani, af) pair.
        nontype_gids : collection of str
            Identifiers of genomes to cluster; may be empty.
        type_radius : dict
            Maps type genome identifiers to objects with an 'ani' attribute
            giving the ANI radius.

        Returns
        -------
        dict
            Maps every type genome identifier to a list of ClusteredGenome
            records (possibly empty).
        """

        clusters = {rep_id: [] for rep_id in type_radius}
        num_nontype = len(nontype_gids)

        transitive_cases = set()
        num_outside_closest_rep_radius = 0
        num_no_rep = 0
        multi_reps = set()
        num_single_rep = 0
        for idx, nontype_gid in enumerate(nontype_gids):
            if idx % 100 == 0:
                sys.stdout.write('==> Processed %d of %d genomes.\r' % (idx+1, num_nontype))
                sys.stdout.flush()

            # no ANI results at all for this genome
            if nontype_gid not in ani_af:
                num_no_rep += 1
                continue

            closest_type_gid = None
            closest_ani = 0
            closest_af = 0
            num_rep_radii = 0
            for type_gid in type_radius:
                if type_gid not in ani_af[nontype_gid]:
                    continue

                ani, af = symmetric_ani(ani_af, type_gid, nontype_gid)

                if af >= self.af_sp:
                    # count every type genome whose radius contains this genome
                    if ani > type_radius[type_gid].ani:
                        num_rep_radii += 1

                    if ani > closest_ani or (ani == closest_ani and af > closest_af):
                        closest_type_gid = type_gid
                        closest_ani = ani
                        closest_af = af

            if closest_type_gid:
                if closest_ani > type_radius[closest_type_gid].ani:
                    assert(num_rep_radii >= 1)
                    if num_rep_radii > 1:
                        multi_reps.add(nontype_gid)
                    else:
                        num_single_rep += 1

                    clusters[closest_type_gid].append(self.ClusteredGenome(gid=nontype_gid,
                                                                            ani=closest_ani,
                                                                            af=closest_af))
                else:
                    # outside the radius of the closest type genome; within
                    # the radius of some other type genome is the transitive case
                    num_outside_closest_rep_radius += 1
                    if num_rep_radii >= 1:
                        transitive_cases.add(nontype_gid)
            else:
                num_no_rep += 1

        # final progress line reports the full count; using the total rather
        # than the loop index also keeps this valid when there are no genomes
        sys.stdout.write('==> Processed %d of %d genomes.\r' % (num_nontype,
                                                                num_nontype))
        sys.stdout.flush()
        sys.stdout.write('\n')

        num_clustered = sum([len(clusters[type_gid]) for type_gid in clusters])
        assert(num_clustered == num_single_rep + len(multi_reps))
        self.logger.info('Assigned %d genomes to representatives.' % num_clustered)
        self.logger.info(' ... %d genomes satisfied the clustering criteria of only 1 representative.' % num_single_rep)
        self.logger.info(' ... %d genomes satisfied the clustering criteria of >1 representative.' % len(multi_reps))

        num_unclustered = num_nontype - num_clustered
        assert(num_unclustered == num_no_rep + num_outside_closest_rep_radius)
        self.logger.info('There were %d genomes that could not be assigned to a representatives.' % num_unclustered)
        self.logger.info(' ... %d genomes were not assigned as they did not meeting the clustering criteria of any representatives.' % num_no_rep)
        self.logger.info(' ... %d genomes were not assigned as they did not meeting the clustering criteria of the closest genome.' % num_outside_closest_rep_radius)
        self.logger.info(' ..... %d genomes were within the ANI radius of >=1 other representatives (i.e. transitive case).' % len(transitive_cases))

        # identifiers are sorted so the files are reproducible between runs
        with open(os.path.join(self.output_dir, 'multi_reps_gids.lst'), 'w') as fout:
            for gid in sorted(multi_reps):
                fout.write(gid + '\n')

        with open(os.path.join(self.output_dir, 'transitive_nontype_gids.lst'), 'w') as fout:
            for gid in sorted(transitive_cases):
                fout.write(gid + '\n')

        return clusters

    def run(self,
            qc_file,
            metadata_file,
            genome_path_file,
            named_type_genome_file,
            type_genome_ani_file,
            mash_sketch_file,
            species_exception_file):
        """Cluster genomes to selected GTDB type genomes.

        Steps, in order: read the genomes passing QC; read the type genome
        table ('Type genome' and 'NCBI species' columns); compute and write
        the ANI radius of each type genome; read genome paths and drop those
        not passing QC; read the NCBI taxonomy; calculate ANI between type
        and non-type genomes; cluster the non-type genomes; write the
        clusters. Output files are written to self.output_dir.
        """

        log = self.logger

        # genomes that satisfy the quality criteria
        log.info('Reading QC file.')
        passed_qc = read_qc_file(qc_file)
        log.info('Identified %d genomes passing QC.' % len(passed_qc))

        # type genome table: genome identifier -> NCBI species
        species_type_gid = {}
        with open(named_type_genome_file) as f:
            columns = f.readline().strip().split('\t')
            gid_col = columns.index('Type genome')
            species_col = columns.index('NCBI species')

            for row in f:
                fields = row.strip().split('\t')
                species_type_gid[fields[gid_col]] = fields[species_col]
        type_gids = set(species_type_gid)
        log.info('Identified type genomes for %d species.' % len(species_type_gid))

        # ANI circumscription radius of each type genome
        log.info('Determining ANI species circumscription for %d type genomes.' % len(type_gids))
        type_radius = self._type_genome_radius(type_gids, type_genome_ani_file)
        assert len(type_radius) == len(species_type_gid)

        radius_file = os.path.join(self.output_dir, 'gtdb_type_genome_ani_radius.tsv')
        write_type_radius(type_radius, species_type_gid, radius_file)

        # genome FASTA paths, restricted to genomes passing QC
        log.info('Reading path to genome FASTA files.')
        genome_files = parse_genome_path(genome_path_file)
        log.info('Read path for %d genomes.' % len(genome_files))
        failed_gids = [gid for gid in genome_files if gid not in passed_qc]
        for gid in failed_gids:
            del genome_files[gid]
        log.info('Considering %d genomes after removing unwanted User genomes.' % len(genome_files))
        assert len(genome_files) == len(passed_qc)

        # NCBI taxonomy, including manually defined species exceptions
        log.info('Reading NCBI taxonomy from GTDB metadata file.')
        ncbi_taxonomy, ncbi_update_count = read_gtdb_ncbi_taxonomy(metadata_file, species_exception_file)
        log.info('Read NCBI taxonomy for %d genomes with %d manually defined updates.' % (len(ncbi_taxonomy), ncbi_update_count))

        # ANI between type and non-type genomes
        log.info('Calculating ANI between type and non-type genomes.')
        ani_af = self._calculate_ani(type_gids, genome_files, ncbi_taxonomy, mash_sketch_file)

        # assign the remaining genomes to type genomes
        nontype_gids = set(genome_files).difference(type_radius)
        log.info('Clustering %d non-type genomes to type genomes using species specific ANI radii.' % len(nontype_gids))
        clusters = self._cluster(ani_af, nontype_gids, type_radius)

        # write out clusters
        cluster_file = os.path.join(self.output_dir, 'gtdb_type_genome_clusters.tsv')
        write_clusters(clusters, type_radius, species_type_gid, cluster_file)
