#!/usr/bin/env python3

# Copyright 2020 University of Groningen
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Perform the parsing and input redirect for the
different subcommands. This is the main executable.
"""

import logging
import argparse
from pathlib import Path
from vermouth.log_helpers import (StyleAdapter, BipolarFormatter,
                                  CountingHandler, TypeAdapter,
                                  ignore_warnings_and_count,)

import martini_sour
from martini_sour import conv_itp, conv_coords, analyze, titrate

# -----------------------------------------------------------------------------
# Logging configuration
# -----------------------------------------------------------------------------
# The package logger is wrapped in vermouth's TypeAdapter first.
# NOTE(review): presumably this adapter provides the ``{type}`` field used by
# the format strings below -- confirm in vermouth.log_helpers.
LOGGER = TypeAdapter(logging.getLogger('martini_sour'))

# Two message layouts; the detailed one additionally shows the logger name.
DETAILED_FORMATTER = logging.Formatter(
    fmt='{levelname:} - {type} - {name} - {message}',
    style='{',
)
PRETTY_FORMATTER = logging.Formatter(
    fmt='{levelname:} - {type} - {message}',
    style='{',
)

# BipolarFormatter is handed both layouts, the DEBUG level and the logger.
# NOTE(review): presumably it switches between the two layouts depending on
# the logger level relative to DEBUG -- confirm in vermouth.log_helpers.
FORMATTER = BipolarFormatter(DETAILED_FORMATTER, PRETTY_FORMATTER,
                             logging.DEBUG, logger=LOGGER)

# Handler that writes the formatted records to the console (stderr).
CONSOLE_HANDLER = logging.StreamHandler()
CONSOLE_HANDLER.setFormatter(FORMATTER)

# Handler used for counting messages; only records at WARNING level or
# above reach it.
COUNTER = CountingHandler()
COUNTER.setLevel(logging.WARNING)

LOGGER.addHandler(CONSOLE_HANDLER)
LOGGER.addHandler(COUNTER)

# From here on LOGGER refers to the StyleAdapter wrapped around the
# TypeAdapter created above; the handlers were attached before re-wrapping.
LOGGER = StyleAdapter(LOGGER)

VERSION = f'martini_sour version {martini_sour.__version__}'


def _parse_bool(value):
    """
    Convert a commandline string into a boolean.

    ``type=bool`` cannot be used with argparse because every non-empty
    string, including ``'False'``, is truthy.

    Parameters
    ----------
    value: str
        The raw commandline token.

    Returns
    -------
    bool

    Raises
    ------
    argparse.ArgumentTypeError
        If `value` is not a recognised spelling of true or false.
    """
    lowered = value.strip().lower()
    if lowered in ('1', 'true', 't', 'yes', 'y', 'on'):
        return True
    if lowered in ('', '0', 'false', 'f', 'no', 'n', 'off'):
        return False
    raise argparse.ArgumentTypeError(
        'expected a boolean (true/false), got {!r}'.format(value))


def _add_verbosity_argument(subparser):
    """
    Add the ``-v`` counting flag shared by all subcommands to `subparser`.
    """
    subparser.add_argument('-v', dest='verbosity', action='count',
                           help='Enable debug logging output. Can be given '
                           'multiple times.', default=0)


def _configure_conv_itp(subparser):
    """
    Register the arguments of the itp generation tool on `subparser`.
    """
    _add_verbosity_argument(subparser)

    file_group = subparser.add_argument_group('Input and output options')
    file_group.add_argument('-f', dest='inpath', required=False, type=Path,
                            help='Input file (ITP)', nargs="*")
    file_group.add_argument('-o', dest='outpath', type=Path,
                            help='Output ITP (ITP)')

    titrate_group = subparser.add_argument_group('Titratable Bead definitions.')
    titrate_group.add_argument('-bases', dest='bases', type=str, nargs='+',
                               help='An enumeration of residues to convert to bases.')
    # Stored under its own destination; it used to share dest='bases' and
    # thereby overwrote the value given with -bases.
    titrate_group.add_argument('-acids', dest='acids', type=str, nargs='+',
                               help='An enumeration of residues to convert to acids.')
    # nargs='?' with const=True accepts both the bare flag (-auto) and an
    # explicit value (-auto True / -auto False).
    titrate_group.add_argument('-auto', dest='auto', type=_parse_bool, nargs='?',
                               const=True, default=False, help='Identify '
                               'acids/bases automatically from known building blocks.')

    subparser.set_defaults(func=conv_itp)


def _configure_conv_coords(subparser):
    """
    Register the arguments of the coordinate generation tool on `subparser`.
    """
    _add_verbosity_argument(subparser)

    file_group = subparser.add_argument_group('Convert regular to titratable beads.')
    file_group.add_argument('-f', dest='input_file', type=Path,
                            help='input coordinate file (.gro, .pdb, etc.)')
    file_group.add_argument('-o', dest='out_file', type=Path,
                            help='output titratable coordinate file (.gro, .pdb, etc.)')

    titrate_group = subparser.add_argument_group('Titratable Bead Selections '
                                                 '(can enter multiple of each)')
    titrate_group.add_argument('-bead', dest='bead_type', action='append',
                               choices=martini_sour.src.conv_coords.BEAD_TYPES.keys(),
                               help='type of the selected beads')
    titrate_group.add_argument('-sel', dest='sel', action='append', help='selection command '
                               '(see MDAnalysis) for the chosen bead type(s)')
    titrate_group.add_argument('-prot', dest='prot', type=str, default='POS',
                               help='name of the proton bead (default: POS)')
    # NOTE(review): -bases/-acids/-auto were only present as commented-out
    # code for this subcommand; presumably they are not supported here yet.

    subparser.set_defaults(func=conv_coords)


def _configure_titrate(subparser):
    """
    Register the arguments for initializing the titration on `subparser`.
    """
    _add_verbosity_argument(subparser)

    subparser.add_argument('-pH', dest='ph', type=str, required=True, nargs='*',
                           help='Either individual values (e.g., -pH 6 6.5 7) or '
                           'a range with format start:end:increment (e.g., -pH 3:8:0.5)')
    subparser.add_argument('-prefix', dest='prefix', type=str, default='',
                           help='Prefix for the pH directory (e.g., pre_6.50)')
    subparser.add_argument('-mdp', dest='mdp_file', action='append', type=Path, required=True,
                           help='Name(s) of mdp file(s). Can be given multiple times.')
    subparser.add_argument('-p', dest='top_file', type=Path, default='system.top',
                           help='Name of the titratable topology (.top) file')
    subparser.add_argument('-c', dest='gro_file', type=Path, default='start.gro',
                           help='Name of the titratable coordinate (.gro) file')
    subparser.add_argument('-o', dest='out_file', type=Path, default='run.sh',
                           help='Name of the output script')

    subparser.set_defaults(func=titrate)


def _configure_analyze(subparser):
    """
    Register the arguments of the analysis tool on `subparser`.
    """
    _add_verbosity_argument(subparser)
    subparser.add_argument('-count', dest='scheme', type=str, default="legacy",
                           help='counting scheme')

    multidir_group = subparser.add_argument_group('Multi-dir analysis (titration)')
    multidir_group.add_argument('-pH', dest='ph', type=str, nargs='*', default=[],
                                help='Either individual values (e.g., -pH 6 6.5 7) or '
                                'a range with format start:end:increment (e.g., -pH 3:8:0.5)')
    multidir_group.add_argument('-dir', dest='dir', type=str, default='',
                                help='Name of the subdirectory with the trajectory')
    multidir_group.add_argument('-prefix', dest='prefix', type=str, default='',
                                help='Prefix for the pH directory (e.g., pre_6.50)')
    multidir_group.add_argument('-fit', dest='fit', action='store_true', default=False,
                                help='Enables the fitting of the titration curve. Sufficient pH '
                                'data points are needed for meaningful results.')

    file_group = subparser.add_argument_group('Input and output files')
    file_group.add_argument('-f', dest='traj', type=Path, default='traj_comp.xtc',
                            help='GROMACS trajectory file (.xtc)')
    file_group.add_argument('-s', dest='tpr', type=Path, default='topol.tpr',
                            help='GROMACS topology file (.tpr)')
    file_group.add_argument('-o', dest='out', type=Path, default='degree_of_deprot.xvg',
                            help='output file name for average degree of deprotonation values')
    file_group.add_argument('-os', dest='std_out', type=Path,
                            help='output file name for STD of degree of deprotonation values')
    file_group.add_argument('-of', dest='fit_out', type=Path, default='pka_fit.xvg',
                            help='output file name for titration curve fit values. '
                            '-fit must be enabled.')
    file_group.add_argument('-op', dest='plot_out', type=Path, default='titr_curve.pdf',
                            help='output file name for titration analysis plot. '
                            '-fit must be enabled.')

    bead_group = subparser.add_argument_group('Titratable Bead Selections')
    bead_group.add_argument('-sel', dest='sel', action='append', help='selection command (see '
                            'MDAnalysis) for the chosen bead type(s). Can enter multiple times.',
                            required=True)
    bead_group.add_argument('-ref', dest='ref', type=str, help='selection command (see '
                            'MDAnalysis) for the remaining titratable beads not selected by '
                            '-sel command.', required=True)
    bead_group.add_argument('-prot', dest='prot', type=str, default='POS',
                            help='name of the proton bead (default: POS)')

    traj_group = subparser.add_argument_group('Trajectory Options')
    traj_group.add_argument('-b', dest='start', type=int, help='first frame (int)', default=0)
    traj_group.add_argument('-e', dest='end', type=int, help='last frame (int)', default=-1)

    subparser.set_defaults(func=analyze)


def main():
    """
    Parse the commandline arguments and call the relevant subprogram.

    One subparser is created for each of the tools ``conv_itp``,
    ``conv_coords``, ``analyze`` and ``titrate``. Every subparser stores the
    function that implements its tool as ``func``; after parsing, the log
    level is set from the ``-v`` count and that function is called with the
    parsed namespace.

    If no subcommand is given, the usage message is printed and the program
    exits through ``parser.error``.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    parser.add_argument('-V', '--version', action='version', version=VERSION)
    subparsers = parser.add_subparsers()

    # Subparsers are created in this order so the help listing is unchanged.
    parser_conv_itp = subparsers.add_parser('conv_itp')
    parser_conv_coords = subparsers.add_parser('conv_coords')
    parser_analyze = subparsers.add_parser('analyze')
    parser_titrate = subparsers.add_parser('titrate')

    _configure_conv_itp(parser_conv_itp)
    _configure_conv_coords(parser_conv_coords)
    _configure_titrate(parser_titrate)
    _configure_analyze(parser_analyze)

    args = parser.parse_args()

    # Without a subcommand neither 'func' nor 'verbosity' is set on the
    # namespace; report a usage error instead of failing with AttributeError.
    if not hasattr(args, 'func'):
        parser.error('no subcommand given; choose from conv_itp, '
                     'conv_coords, analyze or titrate')

    loglevels = {0: logging.INFO, 1: logging.DEBUG, 2: 5}
    # Clamp so that giving -v more often than there are levels selects the
    # most verbose level rather than raising KeyError.
    LOGGER.setLevel(loglevels[min(args.verbosity, max(loglevels))])

    args.func(args)


# Run the commandline interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()
