# coding=utf-8
# Copyright 2020 The TensorFlow Datasets Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Lint as: python3
"""Defined Reader and ReadInstruction to read tfrecord files."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import copy
import functools
import math
import os
import re

import attr

import numpy as np
import tensorflow.compat.v2 as tf
from tensorflow_datasets.core import _sharded_files
from tensorflow_datasets.core import api_utils
from tensorflow_datasets.core import example_parser
from tensorflow_datasets.core import naming
from tensorflow_datasets.core import utils


_BUFFER_SIZE = 8<<20  # 8 MiB per file.

_SUB_SPEC_RE = re.compile(r'''
^
 (?P<split>\w+)
 (\[
  ((?P<from>-?\d+)
   (?P<from_pct>%)?)?
  :
  ((?P<to>-?\d+)
   (?P<to_pct>%)?)?
 \])?
$
''', re.X)

_ADDITION_SEP_RE = re.compile(r'\s*\+\s*')


def _default_options():
  """Builds the default `tf.data.Options` applied to every dataset read here.

  Returns:
    A `tf.data.Options` with a private threadpool of 16 threads, intra-op
    parallelism of 1, and default optimizations plus map fusion and map
    parallelization enabled.
  """
  opts = tf.data.Options()

  threading_opts = opts.experimental_threading
  threading_opts.max_intra_op_parallelism = 1
  threading_opts.private_threadpool_size = 16

  optimization_opts = opts.experimental_optimization
  optimization_opts.apply_default_optimizations = True
  optimization_opts.map_fusion = True
  optimization_opts.map_parallelization = True

  return opts


def _get_dataset_from_filename(filename_skip_take, do_skip, do_take):
  """Returns the record dataset for a single (filename, skip, take) entry.

  Args:
    filename_skip_take: mapping with the keys 'filename', 'skip' and 'take'
      (one element of the instructions dataset interleaved in `_read_files`).
    do_skip: if True, `ds.skip(skip)` is applied.
    do_take: if True, `ds.take(take)` is applied.

  Returns:
    A dataset of the serialized records of `filename`, optionally restricted
    with skip/take.
  """
  filename = filename_skip_take['filename']
  skip = filename_skip_take['skip']
  take = filename_skip_take['take']

  # NOTE: `tf` is imported as `tensorflow.compat.v2`, so this builds the TF2
  # `TFRecordDataset`. Parallelism across files comes from the interleave in
  # `_read_files`, hence a single reader per file.
  records = tf.data.TFRecordDataset(
      filename,
      buffer_size=_BUFFER_SIZE,
      num_parallel_reads=1,
  )
  if do_skip:
    records = records.skip(skip)
  if do_take:
    records = records.take(take)
  return records


@attr.s(frozen=True)
class FileInstructions(object):
  """The file instructions associated with a split ReadInstruction.

  Instances are immutable (`frozen=True`).

  Attributes:
    num_examples: `int`, The total number of examples
    file_instructions: List[dict(filename, skip, take)], the files information.
      The filenames contain the relative path, not absolute.
      skip/take indicate which examples are read in the shard:
      `ds.skip().take()`
  """
  # Total number of examples selected across all file instructions.
  num_examples = attr.ib()
  # List of dicts with keys 'filename' (relative), 'skip' and 'take'.
  file_instructions = attr.ib()


def make_file_instructions(name, split_infos, instruction):
  """Resolves a read instruction into per-file read instructions.

  Args:
    name: Name of the dataset.
    split_infos: `List[SplitInfo]`, Dataset splits information
    instruction: `ReadInstruction` or `str` (a spec parsed with
      `ReadInstruction.from_spec`).

  Returns:
    file_instructions: `FileInstructions` instance.
  """
  name2shard_lengths = {}
  for info in split_infos:
    name2shard_lengths[info.name] = info.shard_lengths
  name2len = {
      split_name: sum(shard_lengths)
      for split_name, shard_lengths in name2shard_lengths.items()
  }

  if isinstance(instruction, ReadInstruction):
    read_instruction = instruction
  else:
    read_instruction = ReadInstruction.from_spec(instruction)

  # One absolute (non-negative) instruction per `+`-separated sub-spec.
  return _make_file_instructions_from_absolutes(
      name=name,
      name2shard_lengths=name2shard_lengths,
      absolute_instructions=read_instruction.to_absolute(name2len),
  )


def _make_file_instructions_from_absolutes(
    name,
    name2shard_lengths,
    absolute_instructions,
):
  """Returns the files instructions from the absolute instructions list.

  Args:
    name: Name of the dataset, used to build the shard filenames.
    name2shard_lengths: dict mapping each split name to its shard lengths.
    absolute_instructions: list of `_AbsoluteInstruction`, one per sub-spec.

  Returns:
    `FileInstructions` whose `file_instructions` concatenates, in order, the
    per-file instructions of every absolute instruction, and whose
    `num_examples` is the total number of selected examples.

  Raises:
    ValueError: if a referenced split has no recorded shard lengths.
  """
  # For each split, return the files instruction (skip/take)
  file_instructions = []
  num_examples = 0
  for abs_instr in absolute_instructions:
    shard_lengths = name2shard_lengths[abs_instr.splitname]
    if not shard_lengths:
      raise ValueError(
          'Shard empty. This might means that dataset hasn\'t been generated '
          'yet and info not restored from GCS, or that legacy dataset is used.')
    filenames = naming.filenames_for_dataset_split(
        dataset_name=name,
        split=abs_instr.splitname,
        num_shards=len(shard_lengths),
        filetype_suffix='tfrecord')
    from_ = 0 if abs_instr.from_ is None else abs_instr.from_
    to = sum(shard_lengths) if abs_instr.to is None else abs_instr.to
    # Follow Python slicing semantics: a slice ending before it starts (e.g.
    # `train[80:20]`, or `train[-10:5]` once negatives are resolved) selects
    # nothing. Without this guard, `to - from_` would add a negative count to
    # `num_examples` and the reversed range would reach the shard resolver.
    if to < from_:
      continue
    num_examples += to - from_
    single_file_instructions = _sharded_files.get_read_instructions(
        from_, to, filenames, shard_lengths)
    file_instructions.extend(single_file_instructions)
  return FileInstructions(
      num_examples=int(num_examples),  # int() due to proto shard_length `long`
      file_instructions=file_instructions,
  )


def _read_files(
    files,
    parse_fn,
    read_config,
    shuffle_files,
    num_examples):
  """Builds the parsed `tf.data.Dataset` for the given file instructions.

  Args:
    files: List[dict(filename, skip, take)], one entry per file to read.
      Filenames are absolute paths. skip/take select which records of the
      file are read, as in `ds.skip(skip).take(take)`.
    parse_fn (callable): function applied to each serialized record.
    read_config: `tfds.ReadConfig`, Additional options to configure the
      input pipeline (e.g. seed, num parallel reads,...).
    shuffle_files (bool): If True, the file instructions are shuffled before
      being interleaved.
    num_examples: `int`, if truthy (and the installed TF provides it), it is
      asserted as the dataset cardinality with
      `tf.data.experimental.assert_cardinality`.

  Returns:
    A `tf.data.Dataset` of parsed examples.
  """
  # Let the user transform the instructions, giving direct control over the
  # interleave order.
  sort_fn = read_config.experimental_interleave_sort_fn
  if sort_fn is not None:
    files = sort_fn(files)

  # Only add skip/take ops to the per-file datasets when some file needs them.
  do_skip = any(f['skip'] > 0 for f in files)
  do_take = any(f['take'] > -1 for f in files)

  # Transpose list[dict] into dict[list] so it can be sliced by tf.data.
  columns = {}
  for key, values in utils.zip_dict(*files):
    if key == 'filename':
      columns[key] = list(values)
    else:
      # skip/take must be converted to int64 explicitly.
      columns[key] = np.array(values, dtype=np.int64)

  instructions_ds = tf.data.Dataset.from_tensor_slices(columns)
  if shuffle_files:
    # Shuffle the order of the instructions/shards (not the records).
    instructions_ds = instructions_ds.shuffle(
        len(columns['filename']),
        seed=read_config.shuffle_seed,
        reshuffle_each_iteration=read_config.shuffle_reshuffle_each_iteration,
    )

  read_one_file = functools.partial(
      _get_dataset_from_filename, do_skip=do_skip, do_take=do_take)
  ds = instructions_ds.interleave(
      read_one_file,
      cycle_length=read_config.interleave_parallel_reads,
      block_length=read_config.interleave_block_length,
      num_parallel_calls=tf.data.experimental.AUTOTUNE,
  )

  # Forward the known number of examples to the dataset. `hasattr` keeps
  # backward compatibility with TF <= 2.1, which lacks `assert_cardinality`.
  if num_examples and hasattr(tf.data.experimental, 'assert_cardinality'):
    ds = ds.apply(tf.data.experimental.assert_cardinality(num_examples))

  # TODO(tfds): Should merge the default options with read_config to allow users
  # to overwrite the default options.
  ds = ds.with_options(_default_options())  # Default performance options.
  ds = ds.with_options(read_config.options)  # Additional user options.

  # TODO(pierrot): `parse_example` uses `tf.io.parse_single_example`. It might
  # be faster to use `parse_example`, after batching.
  # https://www.tensorflow.org/api_docs/python/tf/io/parse_example
  return ds.map(parse_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)


class Reader(object):
  """Builds `tf.data.Dataset` objects out of read instruction(s).

  This class should not typically be exposed to the TFDS user.
  """

  def __init__(self, path, example_specs):
    """Initializes Reader.

    Args:
      path (str): directory holding the tfrecord files.
      example_specs: spec used to build the `ExampleParser`.
    """
    self._path = path
    self._parser = example_parser.ExampleParser(example_specs)

  def read(
      self,
      name,
      instructions,
      split_infos,
      read_config,
      shuffle_files,
  ):
    """Returns tf.data.Dataset instance(s).

    Args:
      name (str): name of the dataset.
      instructions (ReadInstruction, List[], Dict[]): instruction(s) to read.
        Instructions can be strings, in which case they are parsed as
        `ReadInstruction` specs.
      split_infos (list of SplitInfo proto): the available splits for dataset.
      read_config: `tfds.ReadConfig`, the input pipeline options
      shuffle_files (bool): If True, input files are shuffled before being read.

    Returns:
      A single tf.data.Dataset for a single instruction; otherwise a
      dict/list of tf.data.Dataset mirroring the structure of `instructions`.

    Raises:
      AssertionError: if an instruction selects no file.
    """

    def _instruction_to_dataset(instruction):
      resolved = make_file_instructions(name, split_infos, instruction)
      if not resolved.file_instructions:
        raise AssertionError(
            'Instruction "%s" corresponds to no data!' % instruction)
      return self.read_files(
          files=tuple(resolved.file_instructions),
          read_config=read_config,
          shuffle_files=shuffle_files,
          num_examples=resolved.num_examples,
      )

    # Each leaf of `instructions` becomes one dataset; structure is preserved.
    return tf.nest.map_structure(_instruction_to_dataset, instructions)

  def read_files(
      self,
      files,
      read_config,
      shuffle_files,
      num_examples=None,
  ):
    """Returns a single tf.data.Dataset for a set of file instructions.

    Args:
      files: List[dict(filename, skip, take)], the files information.
        Filenames are relative to the reader path; skip/take select which
        records of the shard are read: `ds.skip().take()`.
      read_config: `tfds.ReadConfig`, the input pipeline options
      shuffle_files (bool): If True, input files are shuffled before being read.
      num_examples: `int`, if defined, asserted as the dataset cardinality with
        `tf.data.experimental.assert_cardinality` (when available).

    Returns:
      a tf.data.Dataset instance.
    """
    # Work on deep copies so the caller's dicts keep their relative filenames.
    files = copy.deepcopy(files)
    for file_instruction in files:
      file_instruction['filename'] = os.path.join(
          self._path, file_instruction['filename'])
    return _read_files(
        files=files,
        read_config=read_config,
        parse_fn=self._parser.parse_example,
        shuffle_files=shuffle_files,
        num_examples=num_examples,
    )


@attr.s(frozen=True)
class _AbsoluteInstruction(object):
  """A machine friendly slice: defined absolute positive boundaries.

  Built by `_rel_to_abs_instr`; the slice is half-open: [from_, to).
  """
  splitname = attr.ib()  # type: Text
  # Non-negative starting index, or None when the slice starts at 0.
  from_ = attr.ib()
  # Non-negative ending index (excluded), or None to read until the split end.
  to = attr.ib()


@attr.s(frozen=True)
class _RelativeInstruction(object):
  """Represents a single parsed slicing instruction, can use % and negatives.

  Boundaries are validated at construction: when `unit == '%'`, every
  non-None boundary must lie in the inclusive range [-100, 100].
  """
  splitname = attr.ib()  # type: Text
  from_ = attr.ib()  # int (starting index) or None if no lower boundary.
  to = attr.ib()  # int (ending index) or None if no upper boundary.
  unit = attr.ib(validator=attr.validators.in_(['%', 'abs']))  # str
  rounding = attr.ib(validator=attr.validators.in_([
      'closest', 'pct1_dropremainder']))  # str

  @from_.validator
  @to.validator
  def check_boundary_pct(self, unused_attribute, value):
    """Rejects percent boundaries outside of [-100, 100].

    Raises:
      AssertionError: if `unit` is '%' and `abs(value) > 100`.
    """
    # -100 and 100 themselves are valid (e.g. `train[-100%:]`), so the message
    # states inclusive bounds, matching the `abs(value) > 100` check.
    if self.unit == '%' and value is not None and abs(value) > 100:
      raise AssertionError(
          'Percent slice boundaries must be >= -100 and <= 100.')


def _str_to_relative_instruction(spec):
  """Parses one sub-spec (e.g. `train[10%:-5%]`) into a `ReadInstruction`.

  A `%` on either boundary makes the whole slice percent-based; otherwise the
  boundaries are absolute indices. Rounding is always 'closest'.

  Raises:
    AssertionError: if `spec` does not match `_SUB_SPEC_RE`.
  """
  match = _SUB_SPEC_RE.match(spec)
  if match is None:
    raise AssertionError('Unrecognized instruction format: %s' % spec)

  from_str = match.group('from')
  to_str = match.group('to')
  is_pct = bool(match.group('from_pct') or match.group('to_pct'))
  return ReadInstruction(
      split_name=match.group('split'),
      rounding='closest',
      from_=int(from_str) if from_str else None,
      to=int(to_str) if to_str else None,
      unit='%' if is_pct else 'abs',
  )


def _pct_to_abs_pct1(boundary, num_examples):
  # Using math.trunc here, since -99.5% should give -99%, not -100%.
  if num_examples < 100:
    msg = ('Using "pct1_dropremainder" rounding on a split with less than 100 '
           'elements is forbidden: it always results in an empty dataset.')
    raise AssertionError(msg)
  return boundary * math.trunc(num_examples / 100.)


def _pct_to_abs_closest(boundary, num_examples):
  return int(round(boundary * num_examples / 100.))


def _rel_to_abs_instr(rel_instr, name2len):
  """Returns _AbsoluteInstruction instance for given RelativeInstruction.

  Negative boundaries are resolved relative to the split size. A start of 0
  and an end equal to the split size are stored as None (a negative start
  that resolves to 0 is kept as 0).

  Args:
    rel_instr: RelativeInstruction instance.
    name2len: dict {split_name: num_examples}.

  Returns:
    `_AbsoluteInstruction` with non-negative (or None) boundaries.

  Raises:
    ValueError: if the split is not in `name2len`.
    AssertionError: if a boundary exceeds the split size in absolute value,
      or (with 'pct1_dropremainder' rounding) if the split has fewer than 100
      examples and a percent boundary must be converted.
  """
  if rel_instr.rounding == 'closest':
    pct_to_abs = _pct_to_abs_closest
  else:
    pct_to_abs = _pct_to_abs_pct1

  split = rel_instr.splitname
  if split not in name2len:
    raise ValueError('Unknown split "{}". Should be one of {}.'.format(
        split, list(name2len)))
  num_examples = name2len[split]
  is_pct = rel_instr.unit == '%'

  def _resolve(boundary, default):
    # Missing boundaries take the default; percent ones are converted.
    if boundary is None:
      return default
    return pct_to_abs(boundary, num_examples) if is_pct else boundary

  from_ = _resolve(rel_instr.from_, 0)
  to = _resolve(rel_instr.to, num_examples)

  if abs(from_) > num_examples or abs(to) > num_examples:
    msg = 'Requested slice [%s:%s] incompatible with %s examples.' % (
        from_ or '', to or '', num_examples)
    raise AssertionError(msg)

  if from_ < 0:
    from_ += num_examples
  elif from_ == 0:
    from_ = None

  if to < 0:
    to += num_examples
  elif to == num_examples:
    to = None

  return _AbsoluteInstruction(split, from_, to)


class ReadInstruction(object):
  """Reading instruction for a dataset.

  A `ReadInstruction` is a sum of one or more relative slicing instructions,
  each selecting a (possibly percent-based or negative-indexed) slice of a
  split. It is resolved into absolute boundaries with `to_absolute`.

  Examples of usage:

  ```
  # The following lines are equivalent:
  ds = tfds.load('mnist', split='test[:33%]')
  ds = tfds.load('mnist', split=tfds.core.ReadInstruction.from_spec(
      'test[:33%]'))
  ds = tfds.load('mnist', split=tfds.core.ReadInstruction(
      'test', to=33, unit='%'))
  ds = tfds.load('mnist', split=tfds.core.ReadInstruction(
      'test', from_=0, to=33, unit='%'))

  # The following lines are equivalent:
  ds = tfds.load('mnist', split='test[:33%]+train[1:-1]')
  ds = tfds.load('mnist', split=tfds.core.ReadInstruction.from_spec(
      'test[:33%]+train[1:-1]'))
  ds = tfds.load('mnist', split=(
      tfds.core.ReadInstruction('test', to=33, unit='%') +
      tfds.core.ReadInstruction('train', from_=1, to=-1, unit='abs')))

  # 10-fold validation:
  tests = tfds.load(
      'mnist',
      [tfds.core.ReadInstruction('train', from_=k, to=k+10, unit='%')
       for k in range(0, 100, 10)])
  trains = tfds.load(
      'mnist',
      [tfds.core.ReadInstruction('train', to=k, unit='%') +
       tfds.core.ReadInstruction('train', from_=k+10, unit='%')
       for k in range(0, 100, 10)])
  ```

  """

  def _init(self, relative_instructions):
    # Private initializer, shared by `__init__` and the factory
    # `_read_instruction_from_relative_instructions`.
    self._relative_instructions = relative_instructions

  @classmethod
  def _read_instruction_from_relative_instructions(cls, relative_instructions):
    """Returns ReadInstruction obj initialized with relative_instructions."""
    # Use __new__ to bypass the public __init__, whose signature describes a
    # single slice and is not convenient here.
    result = cls.__new__(cls)
    result._init(relative_instructions)  # pylint: disable=protected-access
    return result

  @api_utils.disallow_positional_args(allowed=['split_name'])
  def __init__(
      self,
      split_name,
      rounding='closest',
      from_=None,
      to=None,
      unit=None,
  ):
    """Initialize ReadInstruction.

    `from_`, `to` and `unit` are an alternative to writing the slice inside a
    string spec: if any of them is used, slicing cannot be specified as
    string.

    Args:
      split_name (str): name of the split to read. Eg: 'train'.
      rounding (str): The rounding behaviour to use when percent slicing is
        used. Ignored when slicing with absolute indices.
        Possible values:
         - 'closest' (default): The specified percentages are rounded to the
           closest value. Use this if you want specified percents to be as
           much exact as possible.
         - 'pct1_dropremainder': the specified percentages are treated as
           multiple of 1%. Use this option if you want consistency. Eg:
             len(5%) == 5 * len(1%).
           Using this option, one might not be able to use the full set of
           examples, if the number of those is not a multiple of 100.
      from_ (int): optional start of the slice (included). None means the
        beginning of the split; negative values count from the end.
      to (int): optional end of the slice (excluded). None means the end of
        the split; negative values count from the end.
      unit (str): one of the following, required when `from_` or `to` is set:
        '%': to set the slicing unit as percents of the split size.
        'abs': to set the slicing unit as absolute numbers.
    """
    # `unit` may be omitted only when the full split is read; otherwise the
    # `_RelativeInstruction` validator on `unit` fails.
    if from_ is None and to is None and unit is None:
      unit = '%'
    # This constructor is not always called. See factory method
    # `_read_instruction_from_relative_instructions`. Common init instructions
    # MUST be placed in the _init method.
    self._init(
        [_RelativeInstruction(split_name, from_, to, unit, rounding)])

  @classmethod
  def from_spec(cls, spec):
    """Creates a ReadInstruction instance out of a string spec.

    Args:
      spec (str): split(s) + optional slice(s) to read. A slice can be
            specified, using absolute numbers (int) or percentages (int). E.g.
              `test`: test split.
              `test + validation`: test split + validation split.
              `test[10:]`: test split, minus its first 10 records.
              `test[:10%]`: first 10% records of test split.
              `test[:-5%]+train[40%:60%]`: first 95% of test + middle 20% of
                                           train.

    Returns:
      ReadInstruction instance.

    Raises:
      AssertionError: if a sub-spec is malformed (including an empty spec).
    """
    spec = str(spec)  # Need to convert to str in case of NamedSplit instance.
    # `re.split` always returns at least one element, so `subs` is never
    # empty; empty or malformed sub-specs are rejected by
    # `_str_to_relative_instruction`.
    subs = _ADDITION_SEP_RE.split(spec)
    instruction = _str_to_relative_instruction(subs[0])
    # `sum` chains `ReadInstruction.__add__` starting from the first one.
    return sum([_str_to_relative_instruction(sub) for sub in subs[1:]],
               instruction)

  def __add__(self, other):
    """Returns a new ReadInstruction obj, result of appending other to self.

    Raises:
      AssertionError: if `other` is not a ReadInstruction, or if the rounding
        values differ.
    """
    if not isinstance(other, ReadInstruction):
      msg = 'ReadInstruction can only be added to another ReadInstruction obj.'
      raise AssertionError(msg)
    other_ris = other._relative_instructions  # pylint: disable=protected-access
    # All parts of an instruction share one rounding (enforced here), so
    # comparing the first parts is enough.
    if self._relative_instructions[0].rounding != other_ris[0].rounding:
      raise AssertionError('It is forbidden to sum ReadInstruction instances '
                           'with different rounding values.')
    return self._read_instruction_from_relative_instructions(
        self._relative_instructions + other_ris)

  def __str__(self):
    return 'ReadInstruction(%s)' % self._relative_instructions

  def __repr__(self):
    # Make instructions readable when shown inside containers or tracebacks.
    return str(self)

  def to_absolute(self, name2len):
    """Translate instruction into a list of absolute instructions.

    Those absolute instructions are then to be added together.

    Args:
      name2len: dict associating split names to number of examples.

    Returns:
      list of _AbsoluteInstruction instances (corresponds to the + in spec).
    """
    return [_rel_to_abs_instr(rel_instr, name2len)
            for rel_instr in self._relative_instructions]
