#!/usr/bin/env python3

import argparse
import os.path
from mnist import MNIST
from mnist import img_packer

# Names of the EMNIST splits handled by this script, in processing order.
DATASETS = ("balanced", "byclass", "bymerge",
            "digits", "letters", "mnist")


def repack_emnist(mn, dest, prefix="", datasets=DATASETS, pack=None):
    """Re-pack the training and test images of each EMNIST split.

    For every name in *datasets* the split is selected on *mn*, its
    training and test images are loaded (labels are discarded), and each
    image set is handed to *pack* together with *dest* and an output
    file name built as ``prefix + <file name reported by mn>``.

    Args:
        mn: Loader object providing ``select_emnist``, ``load_training``,
            ``load_testing`` and the ``train_img_fname`` /
            ``test_img_fname`` attributes (the script passes an
            ``mnist.MNIST`` instance).
        dest: Destination passed through to *pack* unchanged.
        prefix: String prepended to each output file name.
        datasets: Iterable of split names to process.
        pack: Callable invoked as ``pack(dest, fname, images, gzip=True)``.
            Defaults to ``img_packer`` imported from ``mnist``.
    """
    if pack is None:
        pack = img_packer

    for dt_name in datasets:
        mn.select_emnist(dt_name)

        print("========procesing {} dataset========".format(dt_name))

        # The file names must be read *after* select_emnist(): in
        # python-mnist that call rewrites train_img_fname/test_img_fname
        # for the chosen split.  Reading them once before the loop would
        # hand every split the same two output names.
        tra_img, _ = mn.load_training()
        pack(dest, prefix + mn.train_img_fname, tra_img, gzip=True)

        tes_img, _ = mn.load_testing()
        pack(dest, prefix + mn.test_img_fname, tes_img, gzip=True)


if __name__ == '__main__':

    parser = argparse.ArgumentParser()
    parser.add_argument("--data", default="./emnist_data",
                        help="Path to MNIST data dir")
    parser.add_argument("--output", default=None,
                        help="Where to save result")

    args = parser.parse_args()

    mn = MNIST(args.data)

    if args.output:
        # Separate output location: keep the loader's file names as-is.
        repack_emnist(mn, args.output)
    else:
        # No output given: write into the data dir with an 'rf_' prefix
        # (presumably so the results do not collide with the source
        # files living in the same directory).
        repack_emnist(mn, args.data, prefix='rf_')
