#!/usr/bin/env python3

from restoreEpics import writeChannels, restoreEpics
import traceback
import argparse
import yaml


def grabInputArgs(argv=None):
    """Build the caputt command-line parser and parse the arguments.

    Args:
        argv: Optional list of argument strings to parse. When left as
            ``None`` (the default) argparse reads ``sys.argv[1:]``, which is
            the behaviour the script entry point relies on.

    Returns:
        argparse.Namespace with the attributes:
            chanInfo: list of positional strings (channel/value pairs or a
                single yaml file name); may be empty.
            tramp: float ramping time in seconds, or None if not given.
            wait: True when -w/--wait was passed, otherwise False.
            timeout: float timeout in seconds, or None if not given.
    """
    parser = argparse.ArgumentParser(
        description='This script is a version of nominal caput command with '
                    'added functionality of\nsetting a tramp and reading a set'
                    ' of channels from yaml files if they need to\nbe set all '
                    'together or ramped all together to new values.',
        epilog='Example use:\n'
               'caputt C1:FIRST_CH 8.0\n'
               'caputt C1:FIRST_CH 8.0 -t 5\n'
               'caputt C1:FIRST_CH 8.0 C1:SECOND_CH 5.6 -t 6 -w -o 60\n'
               'caputt channelsFile.yml\n',
        # Raw formatter keeps the explicit newlines in description/epilog.
        formatter_class=argparse.RawDescriptionHelpFormatter)
    parser.add_argument('chanInfo', nargs='*',
                        help='Channel name and value pairs separated by space '
                             'or name of the yaml file containing them.')
    parser.add_argument('-t', '--tramp', type=float,
                        help='Global ramping time in seconds.', default=None)
    parser.add_argument('-w', '--wait', action='store_true',
                        help='Whether to wait until the processing has '
                             'completed. Global and would override any other '
                             'value from parameter file.')
    parser.add_argument('-o', '--timeout', type=float, default=None,
                        help='how long to wait (in seconds) for put to '
                             'complete before giving up. Global and would '
                             'override any other value from parameter file.')

    return parser.parse_args(argv)


if __name__ == '__main__':
    # Script entry point: turn the command line into the chanInfo dictionary
    # handed to writeChannels(), either from a yaml file or from
    # "<channel> <value>" pairs, then apply the global overrides.
    args = grabInputArgs()
    if len(args.chanInfo) == 1:
        # A single positional argument can only be a yaml parameter file; a
        # lone channel name has no value to write.
        paramFile = args.chanInfo[0]
        # Check the file extension (case-insensitively) rather than looking
        # for the text anywhere in the argument.
        if not paramFile.lower().endswith(('.yml', '.yaml')):
            raise RuntimeError('Please provide a value to write on channel '
                               + paramFile)
        with open(paramFile, 'r') as p:
            # NOTE(review): yaml.full_load can build more than plain data
            # types. If parameter files may come from untrusted sources,
            # consider switching to yaml.safe_load.
            chanInfo = yaml.full_load(p)
        # An empty file loads as None and a top-level list is not a mapping;
        # both would otherwise fail later with an unhelpful TypeError when
        # the global overrides below are stored.
        if not isinstance(chanInfo, dict):
            raise RuntimeError('Parameter file ' + paramFile + ' must '
                               'contain a mapping at its top level')
    elif len(args.chanInfo) % 2 == 1:
        raise RuntimeError('You missed a value for one of the channels')
    else:
        # Positional arguments alternate: channel name, value, channel name,
        # value, ... Values are always converted to float.
        names = args.chanInfo[0::2]
        values = args.chanInfo[1::2]
        chanInfo = {'channels': {pvname: {'value': float(value)}
                                 for pvname, value in zip(names, values)}}

    # Command-line options are stored at the top level of chanInfo,
    # replacing any same-named keys loaded from the parameter file.
    if args.tramp is not None:
        chanInfo['tramp'] = args.tramp
    if args.wait:
        chanInfo['wait'] = True
    if args.timeout is not None:
        chanInfo['timeout'] = args.timeout

    try:
        writeChannels(chanInfo)
    except BaseException:
        # BaseException also covers KeyboardInterrupt, so restoreEpics() is
        # called on Ctrl-C as well as on ordinary errors.
        # NOTE(review): restoreEpics() presumably reverts the channels to
        # their earlier values - confirm against the restoreEpics module.
        restoreEpics()
        traceback.print_exc()
        # Report the failure to the calling shell instead of exiting with
        # status 0 after a failed write.
        raise SystemExit(1)
