#!/usr/bin/env pyrocko-python

import sys
import logging
import tempfile
import math
import os.path as op
import shutil
try:
    from urllib.error import HTTPError
except ImportError:
    from urllib2 import HTTPError
import glob
import pipes
from optparse import OptionParser
from collections import defaultdict

import numpy as num

from pyrocko import trace, util, io, cake, catalog, automap, pile, model
from pyrocko import orthodrome, weeding
from pyrocko.client import fdsn
from pyrocko.io import resp, enhanced_sacpz as epz, stationxml

# Factor used below to convert values given in kilometres (e.g. <depth_km>,
# <radius_km>) to metres.
km = 1000.

# Sorted names of the FDSN sites known to pyrocko; used to validate --sites
# and shown in its help text.
g_sites_available = sorted(fdsn.g_site_abbr.keys())

# Event catalog clients. ``geofon`` is the default catalog of the helper
# functions below.
# NOTE(review): ``usgs`` is not referenced anywhere in the visible file.
geofon = catalog.Geofon()
usgs = catalog.USGS(catalog=None)

# Fade-in/out length for restitution, in multiples of 1/fmin.
tfade_factor = 1.0
# Factors applied to fmin and fmax to get the outer corner frequencies of the
# frequency taper used for restitution.
ffade_factors = 0.5, 1.5

# Timeout for FDSN web service requests (presumably in seconds -- confirm
# against pyrocko.client.fdsn).
fdsn.g_timeout = 60.


class starfill(object):
    """Mapping stand-in which answers every key lookup with ``'*'``.

    Applying it to a ``%(name)s`` style filename template with the ``%``
    operator replaces every placeholder by a star, which yields a glob
    pattern.
    """

    def __getitem__(self, k):
        wildcard = '*'
        return wildcard


def nice_seconds_floor(s):
    """Round the duration *s* down to the nearest "nice" number of seconds.

    The nice values range from one second to 48 hours. Durations below the
    smallest nice value, as well as durations at or above the largest one,
    are returned unchanged.
    """
    nice = [1., 10., 60., 600., 3600., 3.*3600., 12*3600., 24*3600., 48*3600.]

    # Pair every nice value with its predecessor; the predecessor of the
    # first one is the input itself.
    for lower, upper in zip([s] + nice[:-1], nice):
        if s < upper:
            return lower

    return s


def get_events(time_range, region=None, catalog=geofon, **kwargs):
    """Query *catalog* for events within *time_range*.

    Without a *region* the catalog is queried once. Otherwise the region is
    passed through ``automap.split_region`` and the catalog is queried once
    per resulting (west, east, south, north) box; the results are
    concatenated. Additional keyword arguments are handed on to
    ``catalog.get_events``.
    """
    if not region:
        return catalog.get_events(time_range, **kwargs)

    found = []
    for west, east, south, north in automap.split_region(region):
        box_events = catalog.get_events(
            time_range=time_range,
            lonmin=west,
            lonmax=east,
            latmin=south,
            latmax=north,
            **kwargs)
        found.extend(box_events)

    return found


def cut_n_dump(traces, win, out_path):
    """Cut *traces* to the time window *win* and save them.

    *win* is a ``(tmin, tmax)`` pair. Traces having no data inside the window
    are dropped. The input traces are left untouched (copies are chopped).
    The result of ``io.save`` for the template *out_path* is returned.
    """
    chopped = []
    for tr in traces:
        try:
            # win is only indexed here, so an unset window is harmless as
            # long as there are no traces to cut.
            chopped.append(tr.chop(win[0], win[1], inplace=False))
        except trace.NoData:
            continue

    return io.save(chopped, out_path)


def get_events_by_name_or_date(event_names_or_dates, catalog=geofon):
    """Resolve a list of event specifiers to event objects.

    Each entry of *event_names_or_dates* is handled as follows:

    * path of an existing file: all events of that catalog file are added,
    * string starting with ``'gfz'``: looked up by name in the GEOFON
      catalog,
    * anything else: parsed as a time string; *catalog* is searched within
      +-60 s around that time and the event closest in time is added.

    A date for which the catalog has no event contributes nothing to the
    result (a warning is logged), so that the caller can detect and report an
    empty result instead of running into an ``IndexError``.

    Returns the list of events found.
    """
    events_out = []
    for stime in event_names_or_dates:
        if op.isfile(stime):
            events_out.extend(model.Event.load_catalog(stime))
        elif stime.startswith('gfz'):
            events_out.append(geofon.get_event(stime))
        else:
            t = util.str_to_time(stime)
            events = get_events(time_range=(t-60., t+60.), catalog=catalog)
            if not events:
                logger.warning('no event found in catalog near "%s"' % stime)
                continue

            # Pick the event whose origin time is closest to the given time.
            events_out.append(min(events, key=lambda ev: abs(ev.time - t)))

    return events_out


class NoArrival(Exception):
    """Raised by a time window definition when no phase arrival exists."""


class PhaseWindow(object):
    """Time window spanned by the arrivals of a set of seismic phases.

    *model* must offer an ``arrivals()`` method yielding rays with a travel
    time attribute ``t``; *phases* is handed to it unchanged. *omin* and
    *omax* are offsets added to the earliest and the latest arrival time,
    respectively.
    """

    def __init__(self, model, phases, omin, omax):
        self.model, self.phases = model, phases
        self.omin, self.omax = omin, omax

    def __call__(self, time, distance, depth):
        """Return ``(tmin, tmax)`` for a source at *time* and *depth*.

        *distance* is converted with ``cake.m2d`` before it is passed to the
        model. Raises :py:exc:`NoArrival` if the model yields no ray.
        """
        rays = self.model.arrivals(
            phases=self.phases,
            zstart=depth,
            distances=[distance*cake.m2d])

        tts = [ray.t for ray in rays]
        if not tts:
            raise NoArrival()

        return time + min(tts) + self.omin, time + max(tts) + self.omax


class VelocityWindow(object):
    """Time window bounded by a slowest and (optionally) a fastest velocity.

    Travel times are computed as ``(depth + distance) / velocity``. Without
    *vmax* the window starts at the origin time. *tpad* is subtracted from
    the start and added to the end of the window.
    """

    def __init__(self, vmin, vmax=None, tpad=0.0):
        self.vmin, self.vmax, self.tpad = vmin, vmax, tpad

    def __call__(self, time, distance, depth):
        """Return ``(tmin, tmax)`` for a source at *time* and *depth*."""
        path = depth + distance
        ttmax = path / self.vmin
        ttmin = 0.0 if self.vmax is None else path / self.vmax

        tmin = time + ttmin - self.tpad
        tmax = time + ttmax + self.tpad
        return tmin, tmax


class FixedWindow(object):
    """Time window with fixed bounds, independent of source and receiver."""

    def __init__(self, tmin, tmax):
        self.tmin, self.tmax = tmin, tmax

    def __call__(self, time, distance, depth):
        """Return the stored ``(tmin, tmax)``; all arguments are ignored."""
        window = (self.tmin, self.tmax)
        return window


def dump_commandline(argv, fn):
    """Write the command line *argv* to the file *fn*.

    The arguments are quoted for the shell and joined with single spaces, so
    that the resulting line can be pasted into a shell to repeat the call. A
    newline is appended.
    """
    # shlex.quote() is the documented API. The undocumented pipes.quote()
    # alias is only used as a fallback for Python 2, because the pipes module
    # has been removed in Python 3.13.
    try:
        from shlex import quote
    except ImportError:
        from pipes import quote

    with open(fn, 'w') as f:
        f.write(' '.join(quote(x) for x in argv))
        f.write('\n')


# Per-site registries: site name -> (user, passwd) and site name -> token.
g_user_credentials = {}
g_auth_tokens = {}


def get_user_credentials(site):
    """Return the credentials registered for *site* as keyword arguments.

    The result is a dict with the keys ``'user'``, ``'passwd'`` and
    ``'token'``; entries which have not been registered are ``None``.
    """
    user, passwd = g_user_credentials.get(site, (None, None))
    return {
        'user': user,
        'passwd': passwd,
        'token': g_auth_tokens.get(site),
    }


# Name under which logging is set up in the main script.
program_name = 'grondown'
# One-line description shown by the option parser.
description = '''
Download waveforms from FDSN web services and prepare for Grond
'''.strip()

# The root logger is used for all messages of this script.
logger = logging.getLogger('')

# Usage text shown by the option parser. It lists the four accepted forms of
# positional arguments: explicit origin (9 arguments), origin date and time
# only (6), catalog event name (5) and explicit time window with location
# (6, in combination with --window).
usage = '''
usage: grondown [options] [--] <YYYY-MM-DD> <HH:MM:SS> <lat> <lon> \\
                               <depth_km> ( <radius_km> | <stations> )
                               <fmin_hz> <sampling_rate_hz> <eventname>

       grondown [options] [--] <YYYY-MM-DD> <HH:MM:SS> \\
                               ( <radius_km> | <stations> ) \\
                               <fmin_hz> <sampling_rate_hz> <eventname>

       grondown [options] [--] <catalog-eventname> \\
                               ( <radius_km> | <stations> ) \\
                               <fmin_hz> <sampling_rate_hz> <eventname>

       grondown [options] --window="<YYYY-MM-DD HH:MM:SS, YYYY-MM-DD HH:MM:\
SS>" \\
                               [--] <lat> <lon>
                               ( <radius_km> | <stations> ) <fmin_hz> \\
                               <sampling_rate_hz> <eventname>
'''.strip()

if __name__ == '__main__':
    parser = OptionParser(
        usage=usage,
        description=description)

    parser.add_option(
        '--force',
        dest='force',
        action='store_true',
        default=False,
        help='allow recreation of output <directory>')

    parser.add_option(
        '--debug',
        dest='debug',
        action='store_true',
        default=False,
        help='print debugging information to stderr')

    parser.add_option(
        '--dry-run',
        dest='dry_run',
        action='store_true',
        default=False,
        help='show available stations/channels and exit '
             '(do not download waveforms)')

    parser.add_option(
        '--continue',
        dest='continue_',
        action='store_true',
        default=False,
        help='continue download after a accident')

    parser.add_option(
        '--local-catalog',
        dest='local_catalog',
        metavar='FILENAME',
        help='local catalog to use in combination with <catalog-eventname>')

    parser.add_option(
        '--local-data',
        dest='local_data',
        action='append',
        help='add file/directory with local data')

    parser.add_option(
        '--local-stations',
        dest='local_stations',
        action='append',
        help='add local stations file')

    parser.add_option(
        '--local-responses-resp',
        dest='local_responses_resp',
        action='append',
        help='add file/directory with local responses in RESP format')

    parser.add_option(
        '--local-responses-pz',
        dest='local_responses_pz',
        action='append',
        help='add file/directory with local pole-zero responses')

    parser.add_option(
        '--local-responses-stationxml',
        dest='local_responses_stationxml',
        help='add file with local response information in StationXML format')

    parser.add_option(
        '--window',
        dest='window',
        default='full',
        help='set time window to choose [full, p, body, '
             '"<time-start>,<time-end>"] (time format is YYYY-MM-DD HH:MM:SS)')

    parser.add_option(
        '--out-components',
        choices=['enu', 'rtu'],
        dest='out_components',
        default='rtu',
        help='set output component orientations to radial-transverse-up [rtu] '
             '(default) or east-north-up [enu]')

    # The default is substituted into the help text by optparse (%default),
    # so that the documented and the effective default cannot diverge.
    parser.add_option(
        '--padding-factor',
        type=float,
        default=3.0,
        dest='padding_factor',
        help='extend time window on either side, in multiples of 1/<fmin_hz> '
             '(default: %default)')

    parser.add_option(
        '--credentials',
        dest='user_credentials',
        action='append',
        default=[],
        metavar='SITE,USER,PASSWD',
        help='user credentials for specific site to access restricted data '
             '(this option can be repeated)')

    parser.add_option(
        '--token',
        dest='auth_tokens',
        metavar='SITE,FILENAME',
        action='append',
        default=[],
        help='user authentication token for specific site to access '
             'restricted data (this option can be repeated)')

    parser.add_option(
        '--sites',
        dest='sites',
        metavar='SITE1,SITE2,...',
        default='geofon,iris,orfeus',
        help='sites to query (available: %s, default: "%%default"'
        % ', '.join(g_sites_available))

    parser.add_option(
        '--network',
        dest='network',
        metavar='NET1,NET2,...',
        help='networks to query in the FDSN station queries')

    parser.add_option(
        '--band-codes',
        dest='priority_band_code',
        metavar='V,L,M,B,H,S,E,...',
        default='V,L,M,B,H,S,E',
        help='select and prioritize band codes (default: %default)')

    parser.add_option(
        '--instrument-codes',
        dest='priority_instrument_code',
        metavar='H,L,G,...',
        default='H,L',
        help='select and prioritize instrument codes (default: %default)')

    parser.add_option(
        '--radius-min',
        dest='radius_min',
        metavar='VALUE',
        default=0.0,
        type=float,
        help='minimum radius [km]')

    parser.add_option(
        '--nstations-wanted',
        dest='nstations_wanted',
        metavar='N',
        type=int,
        help='number of stations to select initially')

    (options, args) = parser.parse_args(sys.argv[1:])

    if len(args) not in (9, 6, 5):
        parser.print_help()
        sys.exit(1)

    if options.debug:
        util.setup_logging(program_name, 'debug')
    else:
        util.setup_logging(program_name, 'info')

    if options.local_responses_pz and options.local_responses_resp:
        logger.critical('cannot use local responses in PZ and RESP '
                        'format at the same time')
        sys.exit(1)

    n_resp_opt = 0
    for resp_opt in (
            options.local_responses_pz,
            options.local_responses_resp,
            options.local_responses_stationxml):

        if resp_opt:
            n_resp_opt += 1

    if n_resp_opt > 1:
        logger.critical('can only handle local responses from either PZ or '
                        'RESP or StationXML. Cannot yet merge different '
                        'response formats.')
        sys.exit(1)

    # Station information for local RESP responses is taken from the files
    # given with --local-stations, so that option is required here.
    if options.local_responses_resp and not options.local_stations:
        logger.critical('--local-responses-resp can only be used '
                        'when --local-stations is also given.')
        sys.exit(1)

    # Interpret the positional arguments. Depending on their number, the
    # origin is given explicitly (9), looked up by date/time or taken as a
    # bare location (6), or looked up by catalog event name (5). ``iarg`` is
    # the index of the ( <radius_km> | <stations> ) argument.
    try:
        ename = ''
        magnitude = None
        depth = None
        mt = None

        # Pre-set, so that these names are bound for every usage form.
        event = None
        sname_or_date = None

        if len(args) == 9:
            time = util.str_to_time(args[0] + ' ' + args[1])
            lat = float(args[2])
            lon = float(args[3])
            depth = float(args[4])*km
            iarg = 5

        elif len(args) == 6:
            if args[1].find(':') == -1:
                # No time of day given: first two arguments are <lat> <lon>,
                # there is no event (used together with --window).
                lat = float(args[0])
                lon = float(args[1])
                time = None
            else:
                sname_or_date = args[0] + ' ' + args[1]

            iarg = 2

        elif len(args) == 5:
            sname_or_date = args[0]
            iarg = 1

        if sname_or_date is not None:
            if options.local_catalog:
                events = [
                    ev for ev in model.load_events(options.local_catalog)
                    if ev.name == sname_or_date]
            else:
                events = get_events_by_name_or_date([sname_or_date],
                                                    catalog=geofon)
            if len(events) == 0:
                logger.critical('no event found')
                sys.exit(1)
            elif len(events) > 1:
                logger.critical('more than one event found')
                sys.exit(1)

            event = events[0]
            time = event.time
            lat = event.lat
            lon = event.lon
            depth = event.depth
            ename = event.name
            magnitude = event.magnitude
            mt = event.moment_tensor

        sarg = args[iarg]
        if op.exists(sarg):
            # A stations file has been given instead of a radius.
            stations_want = model.load_stations(sarg)
            if time is None:
                # No event: no maximum radius can be derived.
                radius_max = None
            else:
                # With an explicitly given origin there is no catalog event,
                # so a temporary one is used to measure the distances.
                origin = event
                if origin is None:
                    origin = model.Event(time=time, lat=lat, lon=lon)

                radius_max = max(
                    orthodrome.distance_accurate50m(origin, station)
                    for station in stations_want)
        else:
            stations_want = None
            radius_max = float(sarg)*km

        fmin = float(args[iarg+1])
        sample_rate = float(args[iarg+2])

        eventname = args[iarg+3]
        event_dir = op.join('data', 'events', eventname)
        output_dir = op.join(event_dir, 'waveforms')

    except Exception as e:
        logger.critical('an error occurred: %s' % e)
        sys.exit(1)

    if depth is None:
        logger.warning('No event depth given. Assuming zero.')
        depth = 0.0

    if options.force and op.isdir(event_dir):
        if not options.continue_:
            shutil.rmtree(event_dir)

    if op.exists(event_dir) and not options.continue_:
        logger.critical(
            'directory "%s" exists. Delete it first or use the --force option'
            % event_dir)
        sys.exit(1)

    util.ensuredir(output_dir)

    if time is not None:
        event = model.Event(
            time=time, lat=lat, lon=lon, depth=depth, name=ename,
            magnitude=magnitude, moment_tensor=mt)

    if options.window == 'full':
        if event is None:
            logger.critical('need event for --window=full')
            sys.exit(1)

        low_velocity = 1500.
        timewindow = VelocityWindow(
            low_velocity, tpad=options.padding_factor/fmin)

        tmin, tmax = timewindow(time, radius_max, depth)

    elif options.window in ('p', 'body'):
        if event is None:
            logger.critical('need event for --window=%s' % options.window)
            sys.exit(1)

        if options.window == 'p':
            phases = list(map(cake.PhaseDef, 'P p'.split()))

        elif options.window == 'body':
            phases = list(map(cake.PhaseDef, 'P p S s SS'.split()))

        emod = cake.load_model()

        tpad = options.padding_factor / fmin
        timewindow = PhaseWindow(emod, phases, -tpad, tpad)

        arrivaltimes = []
        for dist in num.linspace(options.radius_min, radius_max, 20):
            try:
                arrivaltimes.extend(timewindow(time, dist, depth))
            except NoArrival:
                pass

        if not arrivaltimes:
            logger.error('required phase arrival not found')
            sys.exit(1)

        tmin = min(arrivaltimes)
        tmax = max(arrivaltimes)

    else:
        try:
            stmin, stmax = options.window.split(',')
            tmin = util.str_to_time(stmin.strip())
            tmax = util.str_to_time(stmax.strip())

            timewindow = FixedWindow(tmin, tmax)

        except ValueError:
            logger.critical('invalid argument to --window: "%s"'
                            % options.window)
            sys.exit(1)

    if event is not None:
        event.name = eventname

    tlen = tmax - tmin
    tfade = tfade_factor / fmin

    tpad = tfade

    tmin -= tpad
    tmax += tpad

    tinc = None

    # Band and instrument codes are single characters. Anything else cannot
    # match a channel, so it is treated as a fatal usage error like the other
    # invalid options.
    priority_band_code = options.priority_band_code.split(',')
    for s in priority_band_code:
        if len(s) != 1:
            logger.critical('invalid band code: %s' % s)
            sys.exit(1)

    priority_instrument_code = options.priority_instrument_code.split(',')
    for s in priority_instrument_code:
        if len(s) != 1:
            logger.critical('invalid instrument code: %s' % s)
            sys.exit(1)

    target_sample_rate = sample_rate

    fmax = target_sample_rate

    # target_sample_rate = None
    # priority_band_code = ['H', 'B', 'M', 'L', 'V', 'E', 'S']

    priority_units = ['M/S', 'M', 'M/S**2']

    output_units = 'M'

    sites = [x.strip() for x in options.sites.split(',') if x.strip()]

    for site in sites:
        if site not in g_sites_available:
            logger.critical('unknown FDSN site: %s' % site)
            sys.exit(1)

    for s in options.user_credentials:
        try:
            site, user, passwd = s.split(',')
            g_user_credentials[site] = user, passwd
        except ValueError:
            logger.critical('invalid format for user credentials: "%s"' % s)
            sys.exit(1)

    # Read the authentication tokens given as SITE,FILENAME pairs.
    for s in options.auth_tokens:
        # Parsing is checked separately from reading, because the filename is
        # not available for the error message when the argument is malformed.
        try:
            site, token_filename = s.split(',')
        except ValueError:
            logger.critical('invalid format for auth token: "%s"' % s)
            sys.exit(1)

        try:
            with open(token_filename, 'r') as f:
                g_auth_tokens[site] = f.read()
        except (ValueError, OSError, IOError):
            logger.critical('cannot get token from file: %s' % token_filename)
            sys.exit(1)

    fn_template0 = \
        'data_%(network)s.%(station)s.%(location)s.%(channel)s_%(tmin)s.mseed'

    fn_template_raw = op.join(output_dir, 'raw',  fn_template0)
    fn_stations_raw = op.join(output_dir, 'stations.raw.txt')
    fn_template_rest = op.join(output_dir, 'rest',  fn_template0)
    fn_commandline = op.join(output_dir, 'grondown.command')

    ftap = (ffade_factors[0]*fmin, fmin, fmax, ffade_factors[1]*fmax)

    # chapter 1: download

    sxs = []
    for site in sites:
        try:
            extra_args = {
                'iris': dict(matchtimeseries=True),
            }.get(site, {})

            extra_args.update(
                includerestricted=(
                    site in g_user_credentials or site in g_auth_tokens))

            if not stations_want:
                extra_args.update(
                    latitude=lat,
                    longitude=lon,
                    minradius=options.radius_min*km*cake.m2d,
                    maxradius=radius_max*cake.m2d,
                    channel=','.join('%s??' % s for s in priority_band_code))

                if site == 'geonet':
                    extra_args.update(
                        starttime=tmin,
                        endtime=tmax)
                else:
                    extra_args.update(
                        startbefore=tmax,
                        endafter=tmin)

            else:
                selection = []
                for station in stations_want:
                    for channel in station.get_channels():
                        selection.append(
                            station.nsl() + (channel.name, tmin, tmax))

                extra_args.update(selection=selection)

            if options.network:
                extra_args['network'] = options.network

            logger.info('downloading channel information (%s)' % site)
            sx = fdsn.station(
                site=site,
                format='text',
                level='channel',
                **extra_args)

        except fdsn.EmptyResult:
            logger.error('No stations matching given criteria. (%s)' % site)
            sx = None

        sxs.append(sx)

    if all(sx is None for sx in sxs) and not options.local_data:
        sys.exit(1)

    nsl_to_sites = defaultdict(list)
    nsl_to_station = {}
    for sx, site in zip(sxs, sites):
        if sx is not None:
            site_stations = sx.get_pyrocko_stations()
        else:
            site_stations = []

        for s in site_stations:
            nsl = s.nsl()
            nsl_to_sites[nsl].append(site)
            if nsl not in nsl_to_station:
                nsl_to_station[nsl] = s  # using first site with this station

            logger.debug(str(s))

    logger.info('number of stations found: %i' % len(nsl_to_station))

    # station weeding

    nsls_selected = None
    if options.nstations_wanted:
        stations_all = [
            nsl_to_station[nsl_] for nsl_ in sorted(nsl_to_station.keys())]

        for s in stations_all:
            s.set_event_relative_data(event)

        stations_selected = weeding.weed_stations(
            stations_all, options.nstations_wanted)[0]

        nsls_selected = set(s.nsl() for s in stations_selected)
        logger.info('number of stations selected: %i' % len(nsls_selected))

    if tinc is None:
        tinc = 3600.

    have_data = set()

    if options.continue_:
        fns = glob.glob(fn_template_raw % starfill())
        p = pile.make_pile(fns)
    else:
        fns = []

    have_data_site = {}
    could_have_data_site = {}
    for site in sites:
        have_data_site[site] = set()
        could_have_data_site[site] = set()

    available_through = defaultdict(set)
    it = 0
    nt = int(math.ceil((tmax - tmin) / tinc))
    for it in range(nt):
        tmin_win = tmin + it * tinc
        tmax_win = min(tmin + (it + 1) * tinc, tmax)
        logger.info('time window %i/%i (%s - %s)' % (it+1, nt,
                                                     util.tts(tmin_win),
                                                     util.tts(tmax_win)))

        have_data_this_window = set()
        if options.continue_:
            trs_avail = p.all(tmin=tmin_win, tmax=tmax_win, load_data=False)
            for tr in trs_avail:
                have_data_this_window.add(tr.nslc_id)
                have_data.add(tr.nslc_id)

        for site, sx in zip(sites, sxs):
            if sx is None:
                continue

            selection = []
            channels = sx.choose_channels(
                target_sample_rate=target_sample_rate,
                priority_band_code=priority_band_code,
                priority_units=priority_units,
                priority_instrument_code=priority_instrument_code,
                timespan=(tmin_win, tmax_win))

            for nslc in sorted(channels.keys()):
                if nsls_selected is not None and nslc[:3] not in nsls_selected:
                    continue

                could_have_data_site[site].add(nslc)

                if nslc not in have_data_this_window:
                    channel = channels[nslc]
                    if event:
                        lat_, lon_ = event.lat, event.lon
                    else:
                        lat_, lon_ = lat, lon

                    dist = orthodrome.distance_accurate50m_numpy(
                        lat_, lon_,
                        channel.latitude.value, channel.longitude.value)

                    if event:
                        depth_ = event.depth or 0.0
                        time_ = event.time
                    else:
                        depth_ = None
                        time_ = None

                    try:
                        tmin_, tmax_ = timewindow(time_, dist, depth_)
                    except NoArrival:
                        continue

                    tmin_this = tmin_ - tpad
                    tmax_this = tmax_ + tpad

                    tmin_req = max(tmin_win, tmin_this)
                    tmax_req = min(tmax_win, tmax_this)

                    if channel.sample_rate:
                        deltat = 1.0 / channel.sample_rate.value
                    else:
                        deltat = 1.0

                    if tmin_req < tmax_req:
                        # extend time window by some samples because otherwise
                        # sometimes gaps are produced
                        selection.append(
                            nslc + (
                                tmin_req-deltat*10.0,
                                tmax_req+deltat*10.0))

            if options.dry_run:
                # The time span entries are bound to throw-away names on
                # purpose: unpacking them into ``tmin``/``tmax`` would
                # overwrite the overall time span, which is still needed to
                # set up the remaining time windows.
                for (net, sta, loc, cha, _tmin_sel, _tmax_sel) in selection:
                    available_through[net, sta, loc, cha].add(site)

            else:
                # Request the waveforms in batches of at most ``neach``
                # channels.
                neach = 100
                nbatches = ((len(selection)-1) // neach) + 1
                for i in range(0, len(selection), neach):
                    selection_now = selection[i:i+neach]

                    sbatch = ''
                    if nbatches > 1:
                        sbatch = ' (batch %i/%i)' % (
                            (i//neach) + 1, nbatches)

                    # The context manager closes the temporary file also
                    # when an unexpected exception passes through.
                    with tempfile.NamedTemporaryFile() as f:
                        try:
                            logger.info(
                                'downloading data (%s)%s' % (site, sbatch))

                            data = fdsn.dataselect(
                                site=site, selection=selection_now,
                                **get_user_credentials(site))

                            while True:
                                buf = data.read(1024)
                                if not buf:
                                    break
                                f.write(buf)

                            # Make sure everything is on disk before the
                            # file is read back by name.
                            f.flush()

                            trs = io.load(f.name)
                            for tr in trs:
                                try:
                                    tr.chop(tmin_win, tmax_win)
                                    have_data.add(tr.nslc_id)
                                    have_data_this_window.add(tr.nslc_id)
                                    have_data_site[site].add(tr.nslc_id)
                                except trace.NoData:
                                    pass

                            fns2 = io.save(trs, fn_template_raw)
                            for fn in fns2:
                                if fn in fns:
                                    logger.warning('overwriting file %s', fn)
                            fns.extend(fns2)

                        except fdsn.EmptyResult:
                            pass

                        except HTTPError:
                            logger.warning(
                                'an error occurred while downloading data '
                                'for channels \n  %s' % '\n  '.join(
                                    '.'.join(x[:4]) for x in selection_now))

    if options.dry_run:
        nslcs = sorted(available_through.keys())

        all_channels = defaultdict(set)
        all_stations = defaultdict(set)

        def plural_s(x):
            return '' if x == 1 else 's'

        for nslc in nslcs:
            sites = tuple(sorted(available_through[nslc]))
            logger.info('selected: %s.%s.%s.%s from site%s %s' % (
                nslc + (plural_s(len(sites)), '+'.join(sites))))

            all_channels[sites].add(nslc)
            all_stations[sites].add(nslc[:3])

        nchannels_all = 0
        nstations_all = 0
        for sites in sorted(
                all_channels.keys(),
                key=lambda sites: (-len(sites), sites)):

            nchannels = len(all_channels[sites])
            nstations = len(all_stations[sites])
            nchannels_all += nchannels
            nstations_all += nstations
            logger.info(
                'selected (%s): %i channel%s (%i station%s)' % (
                    '+'.join(sites),
                    nchannels,
                    plural_s(nchannels),
                    nstations,
                    plural_s(nstations)))

        logger.info(
            'selected total: %i channel%s (%i station%s)' % (
                nchannels_all,
                plural_s(nchannels_all),
                nstations_all,
                plural_s(nstations_all)))

        logger.info('dry run done.')
        sys.exit(0)

    for nslc in have_data:
        # if we are in continue mode, we have to guess where the data came from
        if not any(nslc in have_data_site[site] for site in sites):
            for site in sites:
                if nslc in could_have_data_site[site]:
                    have_data_site[site].add(nslc)

    sxs = {}
    for site in sites:
        selection = []
        for nslc in sorted(have_data_site[site]):
            selection.append(nslc + (tmin-tpad, tmax+tpad))

        if selection:
            neach = 100
            nbatches = ((len(selection)-1) // neach) + 1
            sxs_batches = []
            ibatch = 0
            for ibatch in range(nbatches):
                selection_batch = selection[ibatch*neach:(ibatch+1)*neach]

                logger.info(
                    'downloading response information for %i channels  '
                    '(%s) (batch %i/%i)' % (
                        len(selection_batch),
                        site, ibatch+1, nbatches))

                sxs_batch = fdsn.station(
                    site=site, level='response', selection=selection_batch)

                sxs_batch.dump_xml(
                    filename=op.join(
                        output_dir,
                        'stations.%s.%i.xml' % (site, ibatch)))

                sxs_batches.append(sxs_batch)

            sxs[site] = stationxml.primitive_merge(sxs_batches)
            sxs[site].dump_xml(
                filename=op.join(
                    output_dir,
                    'stations.%s.xml' % site))

    # chapter 1.5: inject local data

    if options.local_data:
        have_data_site['local'] = set()
        plocal = pile.make_pile(options.local_data, fileformat='detect')

        # Channels for which data is already available from the download (or
        # from a previous run). The set is frozen before the loop: checking
        # against ``have_data`` while adding to it would drop all but the
        # first time window of every local channel.
        have_data_before = set(have_data)

        for traces in plocal.chopper_grouped(
                gather=lambda tr: tr.nslc_id,
                tmin=tmin,
                tmax=tmax,
                tinc=tinc):

            if not traces:
                continue

            # Traces are grouped by their nslc_id, so the first trace is
            # representative for the whole group.
            nslc = traces[0].nslc_id
            if nslc in have_data_before:
                continue

            fns.extend(io.save(traces, fn_template_raw))
            have_data_site['local'].add(nslc)
            have_data.add(nslc)

        sites.append('local')

    if options.local_responses_pz:
        sxs['local'] = epz.make_stationxml(
            epz.iload(options.local_responses_pz))

    if options.local_responses_resp:
        local_stations = []
        for fn in options.local_stations:
            local_stations.extend(
                model.load_stations(fn))

        sxs['local'] = resp.make_stationxml(
            local_stations, resp.iload(options.local_responses_resp))

    if options.local_responses_stationxml:
        sxs['local'] = stationxml.load_xml(
            filename=options.local_responses_stationxml)

    # chapter 1.6: dump raw data stations file

    nsl_to_station = {}
    for site in sites:
        if site in sxs:
            stations = sxs[site].get_pyrocko_stations(timespan=(tmin, tmax))
            for s in stations:
                nsl = s.nsl()
                if nsl not in nsl_to_station:
                    nsl_to_station[nsl] = s

    stations = [
        nsl_to_station[nsl_] for nsl_ in sorted(nsl_to_station.keys())]

    util.ensuredirs(fn_stations_raw)
    model.dump_stations(stations, fn_stations_raw)

    dump_commandline(sys.argv, fn_commandline)

    # chapter 2: restitution

    if not fns:
        logger.error('no data available')
        sys.exit(1)

    # Restitute the raw traces window by window. Consecutive windows of the
    # same channel are joined again before they are cut and written out.
    p = pile.make_pile(fns, show_progress=False)
    # NOTE(review): the return value of this call is discarded.
    p.get_deltatmin()
    otinc = None
    if otinc is None:
        otinc = nice_seconds_floor(p.get_deltatmin() * 500000.)
    # NOTE(review): the value computed above is unconditionally overridden
    # here, so the processing window length is always 3600.
    otinc = 3600.
    # Extend the processed time span outwards to whole multiples of otinc.
    otmin = math.floor(p.tmin / otinc) * otinc
    otmax = math.ceil(p.tmax / otinc) * otinc
    otpad = tpad*2

    # From here on, fns collects the file names of the restituted traces.
    fns = []
    # rest_traces_b and win_b hold the restituted traces and the window of the
    # previous iteration; they are written out one iteration late, after they
    # have been joined with the traces of the current iteration.
    rest_traces_b = []
    win_b = None
    for traces_a in p.chopper_grouped(
            gather=lambda tr: tr.nslc_id,
            tmin=otmin,
            tmax=otmax,
            tinc=otinc,
            tpad=otpad):

        rest_traces_a = []
        win_a = None
        for tr in traces_a:
            # presumably wmin/wmax are the unpadded bounds of the current
            # chopper window -- TODO confirm against pyrocko.pile
            win_a = tr.wmin, tr.wmax

            # The window start did not advance (presumably a new channel
            # group has begun): write out what is pending and start afresh.
            if win_b and win_b[0] >= win_a[0]:
                fns.extend(cut_n_dump(rest_traces_b, win_b, fn_template_rest))
                rest_traces_b = []
                win_b = None

            # Use the response from the first site which can provide one;
            # collect the reasons for the sites which cannot.
            response = None
            failure = []
            for site in sites:
                try:
                    if site not in sxs:
                        continue
                    response = sxs[site].get_pyrocko_response(
                        tr.nslc_id,
                        timespan=(tr.tmin, tr.tmax),
                        fake_input_units=output_units)

                    break

                except stationxml.NoResponseInformation:
                    failure.append('%s: no response information' % site)

                except stationxml.MultipleResponseInformation:
                    failure.append('%s: multiple response information' % site)

            # From here on, failure is a string; empty means success. It is
            # also empty if no site had any response information loaded.
            if response is None:
                failure = ', '.join(failure)

            else:
                failure = ''
                try:
                    rest_tr = tr.transfer(tfade, ftap, response, invert=True)
                    rest_traces_a.append(rest_tr)

                except (trace.TraceTooShort, trace.NoData):
                    failure = 'trace too short'

            if failure:
                logger.warning('failed to restitute trace %s.%s.%s.%s (%s)' %
                            (tr.nslc_id + (failure,)))

        if rest_traces_b:
            # Join the traces of the previous and the current window, then
            # write out the part belonging to the previous window.
            rest_traces = trace.degapper(rest_traces_b + rest_traces_a,
                                         deoverlap='crossfade_cos')

            fns.extend(cut_n_dump(rest_traces, win_b, fn_template_rest))
            rest_traces_a = []
            if win_a:
                # Keep the part belonging to the current window (plus
                # padding at its end) for the next iteration.
                for tr in rest_traces:
                    try:
                        rest_traces_a.append(
                            tr.chop(win_a[0], win_a[1]+otpad,
                                    inplace=False))
                    except trace.NoData:
                        pass

        rest_traces_b = rest_traces_a
        win_b = win_a

    # Write out what is left over from the last iteration.
    fns.extend(cut_n_dump(rest_traces_b, win_b, fn_template_rest))

    # chapter 3: rotated restituted traces for inspection

    if not event:
        sys.exit(0)

    fn_template1 = \
        'DISPL.%(network)s.%(station)s.%(location)s.%(channel)s'

    fn_waveforms = op.join(output_dir, 'prepared',  fn_template1)
    fn_stations = op.join(output_dir, 'stations.prepared.txt')
    fn_event = op.join(event_dir, 'event.txt')

    nsl_to_station = {}
    for site in sites:
        if site in sxs:
            stations = sxs[site].get_pyrocko_stations(timespan=(tmin, tmax))
            for s in stations:
                nsl = s.nsl()
                if nsl not in nsl_to_station:
                    nsl_to_station[nsl] = s

    p = pile.make_pile(fns, show_progress=False)

    deltat = None
    if sample_rate is not None:
        deltat = 1.0 / sample_rate

    # Downsample the restituted traces of every station to the target
    # sampling interval, project them to the requested output components and
    # save the result. Stations for which something has been saved are
    # collected in used_stations.
    used_stations = []
    for nsl, s in nsl_to_station.items():
        s.set_event_relative_data(event)
        traces = p.all(trace_selector=lambda tr: tr.nslc_id[:3] == nsl)

        # Only traces which are long enough and which could be brought to
        # the target sampling interval take part in the projection.
        keep = []
        for tr in traces:
            if deltat is not None:
                if tr.data_len() * tr.deltat < 10*deltat:
                    continue

                try:
                    tr.downsample_to(deltat, snap=True, allow_upsample_max=5)
                except util.UnavailableDecimation as e:
                    logger.warning('Cannot downsample %s.%s.%s.%s: %s'
                                   % (tr.nslc_id + (e,)))
                    continue

            keep.append(tr)

        if options.out_components == 'rtu':
            pios = s.guess_projections_to_rtu(out_channels=('R', 'T', 'Z'))
        elif options.out_components == 'enu':
            pios = s.guess_projections_to_enu(out_channels=('E', 'N', 'Z'))
        else:
            assert False

        station_used = False
        for (proj, in_channels, out_channels) in pios:

            proc = trace.project(keep, proj, in_channels, out_channels)
            for tr in proc:
                for ch in out_channels:
                    if ch.name == tr.channel:
                        s.add_channel(ch)

            if proc:
                io.save(proc, fn_waveforms)
                station_used = True

        # List every station only once, regardless of how many projections
        # have produced output for it.
        if station_used:
            used_stations.append(s)

    stations = list(used_stations)
    util.ensuredirs(fn_stations)
    model.dump_stations(stations, fn_stations)
    model.dump_events([event], fn_event)

    logger.info('prepared waveforms from %i stations' % len(stations))
