#!/usr/bin/env python

import argparse
import csv
import json
import logging

from os.path import isfile
from samrand.sampler import sample
from samrand.reader import read_csv, read_json
from sys import exit, stdout

# Command-line interface. Each option gets a distinct, descriptive metavar so
# that `--help` output is unambiguous: single-letter metavars such as "s" were
# shared between --size and --strata and read like short flags that do not
# exist. Metavars only affect help/usage text, never the parsed dest names.
parser = argparse.ArgumentParser(
    description='Sample your dataset randomly')
parser.add_argument(
    '--dataset', metavar='PATH', type=str, required=True,
    help='The file containing your dataset.')
parser.add_argument(
    '--size', metavar='N', type=int, required=True,
    help='The required sample size.')
parser.add_argument(
    '--header', action='store_true',
    help='When using a CSV dataset file, use this flag to indicate whether the first row is a header.')
parser.add_argument(
    '--replacement', action='store_true',
    help='Extract samples with replacement. Not including this flag means without replacement.')
parser.add_argument(
    '--stratify', action='store_true',
    help="Balance the extracted sample so that it reflects the population's distribution.")
parser.add_argument(
    '--strata', metavar='JSON_ARRAY', type=str,
    help='When using stratification, use this parameter to indicate which fields should be used as a basis for stratification. Accepts JSON arrays of column indices starting with 0.')
parser.add_argument(
    '--output', metavar='FORMAT', type=str,
    help='The output format of the samples. Default is JSON. Can be one of [CSV|JSON].')

# Configure INFO-level logging on the root logger, then fetch the
# application's named logger and announce that it is ready.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger('SamRand')
logger.info('Logger initialized.')

args = vars(parser.parse_args())

# Check dataset parameters are valid. argparse already guarantees --dataset was
# supplied (it is declared required), so only the path itself needs checking.
if not isfile(args['dataset']):
    logger.fatal('The path you specified does not point to an existing dataset.')
    exit(1)

# Pick a reader by file extension, case-insensitively. Unsupported extensions
# are rejected here with a clear message; otherwise read_result would never be
# bound and the script would crash later with a NameError.
lowered_path = args['dataset'].lower()
if not lowered_path.endswith(('.csv', '.json')):
    logger.fatal('Unsupported dataset format: the dataset file must end in .csv or .json.')
    exit(1)

# Read input file and convert to something usable
logger.info('Reading dataset...')
if lowered_path.endswith('.csv'):
    read_result = read_csv(args['dataset'], args['header'])
else:
    read_result = read_json(args['dataset'])

# Check remaining parameters
if args['size'] <= 0:
    logger.fatal('The sample size must be greater than 0.')
    exit(1)

# --strata is documented as a JSON array of column indices (e.g. "[0, 2]").
# When omitted, no strata are passed to the sampler. Malformed JSON or a
# non-array value is reported as a fatal usage error instead of escaping as a
# traceback or reaching the sampler with an unexpected type.
strata = args['strata']
if strata is None:
    strata = []
else:
    try:
        strata = json.loads(strata)
    except json.JSONDecodeError:
        logger.fatal('The strata parameter must be valid JSON, e.g. "[0, 2]".')
        exit(1)
    if not isinstance(strata, list):
        logger.fatal('The strata parameter must be a JSON array, e.g. "[0, 2]".')
        exit(1)

# Sample the dataset
result = sample(read_result, args['size'], args['stratify'], strata, args['replacement'])

# Write the sample to stdout. CSV is chosen only when --output is "csv" in any
# letter case; anything else, including omitting --output, produces JSON.
requested_format = (args['output'] or '').upper()
if requested_format == 'CSV':
    # Rows are written exactly as sample() returned them, so any header row
    # present in the sample is emitted as the first CSV line.
    csv.writer(stdout).writerows(result)
else:
    if not args['header']:
        # No header: key each record by column position. json.dumps converts
        # these integer keys to strings ("0", "1", ...).
        # NOTE(review): the column count is taken from read_result[0][1]; if
        # that is the second data row, a one-row dataset raises IndexError —
        # confirm against samrand.reader.
        field_names = list(range(len(read_result[0][1])))
    else:
        # With --header, read_result[1] supplies the field names, and the first
        # sampled row is treated as the header itself and dropped.
        # NOTE(review): relies on sample() keeping the header at index 0 — confirm.
        field_names = read_result[1]
        result = result[1:]
    records = [
        {name: row[position] for position, name in enumerate(field_names)}
        for row in result
    ]
    print(json.dumps(records))
