#!/usr/bin/env python
# -*- coding: utf-8 -*-

from __future__ import division
import sys
import argparse
import os.path
import itertools
from qmpy import *
from django.db.models import get_app, get_models

def main():
    '''Handles command line args, parses them to understand what sub-scripts to
    run and handles displaying the results.'''

    parser = argparse.ArgumentParser(prog='oqmd',
            formatter_class=argparse.RawTextHelpFormatter,
            description='''
Command line access to OQMD functionality. Allows a user to obtain information
about database entries, do thermodynamic analyisis or request calculations.
''')

    parser.add_argument('task', default='query',
            nargs='*',
            help='''What do you want to do?

Valid options:
    
    add_host
    start_host
    stop_host
    add_user
    add_project
    add_allocation
    utilization
    progress
    failed
    
    describe_database

    add_structure
    add_structures


''')

    parser.add_argument('--formula','-F',
            nargs='*', default='',
            help='''A representation of a composition or compositions. Used
when the task requires a composition. 
Examples:
    Fe2O3 -> {'Fe':2, 'O':3}
    FeO -> {'Fe':1, 'O':1}
    {Fe,Ni}O -> [{'Fe':1, 'O':1}, {'Ni':1, 'O':1}]''')

    parser.add_argument('--data',
            nargs='*', default='',
            help='Supply your own data for thermodynamic analyses.')

    runner, args = parser.parse_known_args(sys.argv[1:])
    if not runner.task:
        return parser.parse_args(['--help'])
    runner.args = args

#==============================================================================#
# module: config
#==============================================================================#
    if runner.task[0] == 'add_host':
        Host.interactive_create()
    if runner.task[0] == 'add_user':
        User.interactive_create()
    if runner.task[0] == 'add_allocation':
        Allocation.interactive_create()
    if runner.task[0] == 'add_project':
        Project.interactive_create()

#==============================================================================#
# module: database
#==============================================================================#

    if runner.task[0] == 'describe_database':
        app = get_app('qmpy')
        for model in get_models(app):
            print 'Model: %s' % model.__name__
            for field in model._meta.fields:
                print '   -', field.name
            print 

    #======================================================================#
    if runner.task[0] == 'add_structure':
        if not kwargs.get('project', False):
            runner.kwargs['project'] = raw_input('project: ')
        if not kwargs.get('keywords', False):
            runner.kwargs['keywords'] = raw_input('keywords: ').split()
        path = os.path.abspath(runner.task[1])
        try:
            entry = Entry.create(path,
                    keywords=runner.kwargs['keywords'],
                    project=runner.kwargs['project'])
            entry.save()
        except:
            print 'Failed to add structure file: %s' % path
            print 'Please verify that it is a valid structure file'

    #======================================================================#
    if runner.task[0] == 'add_structures':
        if 'project' not in runner.kwargs:
            runner.kwargs['project'] = raw_input('project: ')
        if 'keywords' not in runner.kwargs:
            runner.kwargs['keywords'] = raw_input('keywords: ')
        path = os.path.abspath(runner.task[1])
        for file in os.listdir(path):
            try:
                entry = Entry.create(path,
                    keywords=runner.kwargs['project'],
                    project=runner.kwargs['keywords'])
                entry.save()
            except:
                print 'Failed to add structure file: %s' % path
                print 'Please verify that it is a valid structure file'

#==============================================================================#
# queue
#==============================================================================#

    if 'jobserver' in runner.task:
        js = JobManager('/tmp/oqmd-jobserver.pid')
        getattr(js, runner.task[1])()

    #======================================================================#
    if 'taskserver' in runner.task:
        ts = TaskManager('/tmp/oqmd-taskserver.pid')
        method = getattr(ts, runner.task[1])
        if len(runner.task) > 2:
            kwargs = dict(arg.split('=') for arg in runner.task[3:])
        else:
            kwargs = {}
        method(**kwargs)

    #======================================================================#
    if runner.task[0] == 'start_host':
        h = Host.objects.filter(name=runner.task[1])
        if h.exists():
            h = h[0]
            h.state = 1
            h.save()
            print 'Starting host: %s' % h
        else:
            print 'Host', runner.task[1], 'doesn\'t exist!'

    #======================================================================#
    if runner.task[0] == 'stop_host':
        h = Host.objects.filter(name=runner.task[1])
        if h.exists():
            h = h[0]
            h.state = -2
            h.save()
            print 'Disabling host: %s' % h
        else:
            print 'Host', runner.task[1], 'doesn\'t exist!'

    #======================================================================#
    if runner.task[0] == 'running':
        if len(runner.task) == 1:
            hosts = Host.objects.all().values_list('name', flat=True)
        else:
            hosts = runner.task[1:]

        print 'Currently running:'
        for h in hosts:
            print '!'+''.ljust(78, '=')+'!'
            print h.center(80)
            print '!'+''.ljust(78, '=')+'!'
            print 'Entry'.rjust(60), 'NCPUS'.ljust(20)
            host = Host.objects.get(name=h)
            for j in host.jobs:
                print str(j.entry).rjust(60),
                print str(j.ncpus).ljust(20)

    #======================================================================#
    if runner.task[0] == 'utilization':
        print 'Resource utilization:'
        for h in Host.objects.all():
            print '   - %s : %0.1f %%' % (h.name, h.utilization/h.ncpus*100)

    #======================================================================#
    if runner.task[0] == 'progress':
        if len(runner.task) == 1:
            projects = Project.objects.all().values_list('name', flat=True)
        else:
            projects = runner.task[1:]
        
        for project in projects:
            print 'Project: %s' % project
            tasks = Task.objects.filter(project_set=project)
            rel = Calculation.objects.filter(converged=True,
                        label='relaxation', entry__task__project_set=project)
            sta = Calculation.objects.filter(converged=True, 
                        label='static', entry__task__project_set=project)

            print '  %s tasks for this project' % tasks.count()
            print '     - FAILED:   %s' % tasks.filter(state=-1).count()
            print '     - RUNNING:  %s' % tasks.filter(state=1).count()
            print '     - COMPLETE: %s' % tasks.filter(state=2).count()
            num = float(tasks.count())
            print '  Progress:'
            print '     - Relaxed: {0:.3f}%'.format(rel.count()/num*100)
            print '     - Completed: {0:.3f}%'.format(sta.count()/num*100)
            print 

    #======================================================================#
    if runner.task[0] == 'failed':
        if len(runner.task) == 1:
            tasks = Task.objects.filter(state=-1)
        else:
            projects = runner.task[1:]

        for project in projects:
            print 'Project: %s' % project
            tasks = Task.objects.filter(project_set=project, state=-1)
            for t in tasks:
                print t.entry.name.rjust(10),
                errs = set()
                for err in t.entry.errors.values():
                    errs |= set(err)
                print t.entry.path.ljust(30),
                print ', '.join(errs)

    #======================================================================#
    if runner.task[0] == 'cancel':
        raise NotImplementedError

    #======================================================================#
    if runner.task[0] == 'check_queue':
        raise NotImplementedError

#==============================================================================#
#                       module: adhoc / calculate
#==============================================================================#
    if runner.task[0] in VASP_SETTINGS:
        for struct in runner.task[1:]:
            structure = io.read(struct)
            path = os.path.abspath(struct)
            path = os.path.dirname(path)
            calc = Calculation.setup(
                    input=structure,
                    type=runner.task[0],
                    path=path+'/'+runner.task[0])
            if calc.converged:
                print 'Calculation of %s is done' % struct
            else:
                if calc.instructions:
                    calc.write()
                    print 'Wrote calculation of %s' % struct 

    #======================================================================#
    if runner.task[0] == 'kpoints':
        c = Calculation()
        c.input = io.read(runner.task[1])
        c.settings['kppra'] = int(runner.task[2])
        print c.KPOINTS

#=============================================================================
# thermodynamic analysis
#==============================================================================#
    #======================================================================#
    if runner.task[0] == 'reaction':
        for formula in runner.formula:
            s = PhaseSpace(formula+'-'+runner.task[1])
            for reaction in s.get_reactions(runner.task[1], 
                    electrons=kwargs.get('electrons', 1.0)):
                print reaction

    #======================================================================#
    if runner.task[0] == 'reaction_plot':
        for formula in runner.formula:
            print formula, '+', runner.task[1]
            s = PhaseSpace(formula+'-'+runner.task[1])
            s.plot_reactions(runner.task[1], 
                    electrons=kwargs.get('electrons', 1.0))
            plt.show()

    #======================================================================#
    if runner.task[0] == 'meta_stability':
        raise NotImplementedError

    #======================================================================#
    if runner.task[0] == 'compounds':
        raise NotImplementedError

    #======================================================================#
    if runner.task[0] == 'single_point':
        for rformula in runner.formula:
            formulae = parse_formula_regex(rformula)
            for formula in formulae:
                formula = unit_comp(formula)
                print format_comp(formula), ':'
                bstring = '-'.join(formula.keys())
                space = PhaseSpace(bstring)
                en, phases = space.gclp(formula)
                keys = sorted( phases.keys(), key=lambda x: -phases[x])
                print '   Hull formation energy: %.3f eV/atom' % en
                print '   Hull phases:', 
                pstr = ''
                for k in keys:
                    if phases[k] == 1:
                        pstr += k.name 
                    else:
                        pstr += '%s %s + ' % (phases[k], k.name)
                pstr = pstr.rstrip('+ ')
                print pstr,

                print '(',
                for k in keys:
                    print k.description,
                print ')'

                print '   Hull phases (LaTeX):',
                pstr = ''
                for k in keys:
                    if phases[k] == 1:
                        pstr += k.latex
                    else:
                        pstr += '%s %s + ' % (phases[k], k.latex)
                pstr = pstr.rstrip('+ ')
                print pstr,

                print '(',
                for k in keys:
                    print k.description,
                print ')'

    #======================================================================#
    if runner.task[0] == 'graph':
        for formula in runner.formula:
            bounds = formula.split('-')
            bounds = [ parse_formula_regex(b) for b in bounds ]
            for region in itertools.product(*bounds):
                s = PhaseSpace(region)
                graph_plot(s)
            plt.show()

    #======================================================================#
    if runner.task[0] == 'phase_diagram':
        unstable = ( 'unstable' in args )
        for formula in runner.formula:
            bounds = formula.split('-')
            bounds = [ parse_formula_regex(b) for b in bounds ]
            for region in itertools.product(*bounds):
                s = PhaseSpace(region)
                r = s.phase_diagram
                r.plot_in_matplotlib()
            plt.show()

    #======================================================================#
    if runner.task[0] == 'phase_diagram_script':
        unstable = ( 'unstable' in args )
        for formula in runner.formula:
            bounds = formula.split('-')
            bounds = [ parse_formula_regex(b) for b in bounds ]
            for region in itertools.product(*bounds):
                s = PhaseSpace(region)
                print_script(s, unstable=unstable)

#==============================================================================#
# structure manipulation
#==============================================================================#
    if runner.task[0] == 'defects':
        raise NotImplementedError

    #======================================================================#
    if runner.task[0] == 'symmetry':
        raise NotImplementedError

    #======================================================================#
    if runner.task[0] == 'surface':
        raise NotImplementedError

    #======================================================================#
    if runner.task[0] == 'vacancy':
        raise NotImplementedError

    #======================================================================#
    if runner.task[0] == 'substitution':
        raise NotImplementedError

    #======================================================================#
    if runner.task[0] == 'interstitial':
        raise NotImplementedError

    #======================================================================#
    if runner.task[0] == 'replace':
        raise NotImplementedError

    #======================================================================#
    if runner.task[0] == 'reshape':
        s = io.read(runner.task[1])
        remains = map(float, runner.task[2:])
        s.transform(remains)
        print io.write(s)

#==============================================================================#
# composition search
#==============================================================================#

    if runner.task[0] in ['query', 'search']:
        for formula in runner.formula:
            if runner.task == ['summary']:
                if '-' in formula:
                    comps = Composition.get_space(formula)
                    for comp in comps:
                        comp.get_distinct(calculable=False)
                        print comp.summary
                else:
                    comp = Composition.get(formula)
                    print comp.summary
            elif runner.task == ['count']:
                phases = set()
                a = PhaseSpace(load=[])
                a._data.load_oqmd(fit='standard')
                for rformula in runner.formula:
                    formulae = parse_formula_regex(rformula)
                    for formula in formulae:
                        tot = sum(formula.values())
                        formula = dict( (k, v/tot) for k, v in formula.items() )
                        bstring = '-'.join(formula.keys())
                        space = PhaseSpace(bstring, pdata=a._data)
                        phases |= set(space.phases)
                print '\n'.join([ p.name for p in phases])

            else:
                result = query(type=runner.task[0], 
                        formula=runner.formula, 
                        search=runner.search,
                        columns=runner.task[1:])
                print result

if __name__ == '__main__':
    # Run the command line tool, then tidy up the log file that may have
    # been left in the current working directory.
    main()
    leftover_log = 'gurobi.log'
    log_is_present = os.path.exists(leftover_log)
    if log_is_present:
        os.remove(leftover_log)
