# Copyright 2021 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for serialization and deserialization of `jax.Array`s (gda_serialization)."""

import math
import pathlib

from absl.testing import absltest
import jax
from jax._src import test_util as jtu
from jax._src import config as jax_config
from jax.config import config
from jax._src import array
from jax._src.sharding import NamedSharding, GSPMDSharding
from jax.sharding import PartitionSpec as P
from jax.experimental.gda_serialization import serialization
import numpy as np
import tensorstore as ts

# Parse command-line flags via absl into the JAX config before any test runs
# (presumably so JAX options can be overridden from the test command line).
config.parse_flags_with_absl()


class CheckpointTest(jtu.JaxTestCase):
  """Checkpointing tests for `jax.Array`s using the `serialization` module."""

  @jax_config.jax_array(True)
  def test_checkpointing_jax_array(self):
    """Writes three arrays and restores them, one with a different sharding.

    The second array is written with `P('x', 'y')` but restored with `P('x')`,
    and the third array is empty (global shape `(0,)`).
    """
    global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
    inp_shape = (8, 2)
    pspec = P('x', 'y')
    num = math.prod(inp_shape)

    # First Array
    global_input_data1 = np.arange(num).reshape(inp_shape)
    a1 = array.make_array_from_callback(
        inp_shape, NamedSharding(global_mesh, pspec),
        lambda idx: global_input_data1[idx])
    ckpt_dir1 = pathlib.Path(self.create_tempdir('first').full_path)

    # Second Array
    global_input_data2 = np.arange(num, num + num).reshape(inp_shape)
    a2 = array.make_array_from_callback(
        inp_shape, NamedSharding(global_mesh, pspec),
        lambda idx: global_input_data2[idx])
    ckpt_dir2 = pathlib.Path(self.create_tempdir('second').full_path)

    # Third Array: empty, replicated (P(None)) over an 8-device 1D mesh.
    def cb3(_):
      return np.array([])
    global_mesh1d = jtu.create_global_mesh((8,), ('x',))
    a3 = array.make_array_from_callback(
        (0,), NamedSharding(global_mesh1d, P(None)), cb3)
    ckpt_dir3 = pathlib.Path(self.create_tempdir('third').full_path)

    ckpt_paths = [str(ckpt_dir1), str(ckpt_dir2), str(ckpt_dir3)]
    tspecs = jax.tree_util.tree_map(
        serialization.get_tensorstore_spec, ckpt_paths)

    serialization.run_serialization([a1, a2, a3], tspecs)

    m1, m2, m3 = serialization.run_deserialization(
        [NamedSharding(global_mesh, pspec),
         NamedSharding(global_mesh, P('x')),
         NamedSharding(global_mesh1d, P(None))],
        tspecs)

    self.assertIsInstance(m1, array.ArrayImpl)
    self.assertArraysEqual(np.asarray(m1.addressable_shards[0].data),
                           np.array([[0], [2]]))
    self.assertArraysEqual(np.asarray(m1.addressable_shards[1].data),
                           np.array([[1], [3]]))
    self.assertEqual(m1.addressable_shards[0].data.shape, (2, 1))
    self.assertEqual(m1.dtype, np.int32)

    # Restored with P('x') only, so the first two shards hold the same rows.
    self.assertIsInstance(m2, array.ArrayImpl)
    self.assertArraysEqual(np.asarray(m2.addressable_shards[0].data),
                           np.array([[16, 17], [18, 19]]))
    self.assertArraysEqual(np.asarray(m2.addressable_shards[1].data),
                           np.array([[16, 17], [18, 19]]))
    self.assertEqual(m2.addressable_shards[0].data.shape, (2, 2))
    self.assertEqual(m2.dtype, np.int32)

    # Every shard covers the whole (empty) array; replica ids follow order.
    self.assertIsInstance(m3, array.ArrayImpl)
    for i, shard in enumerate(m3.addressable_shards):
      self.assertEqual(shard.index, (slice(None),))
      self.assertEqual(shard.replica_id, i)
      self.assertArraysEqual(np.asarray(shard.data), np.array([]))
    self.assertEqual(m3.dtype, np.float32)

  @jax_config.jax_array(True)
  def test_checkpointing_with_bigger_shape_jax_array(self):
    """Restores an (8, 2) checkpoint into a larger (12, 2) float32 array.

    Regions beyond the stored data are expected to read back as zeros. A second
    restore uses a fully replicated sharding at the original shape.
    """
    global_mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
    global_input_shape = (8, 2)
    num = math.prod(global_input_shape)

    global_input_data1 = np.arange(num, dtype=np.int32).reshape(global_input_shape)
    def cb1(index):
      return global_input_data1[index]
    arr = array.make_array_from_callback(
        global_input_shape, NamedSharding(global_mesh, P('x', 'y')), cb1)
    ckpt_dir1 = pathlib.Path(self.create_tempdir('first').full_path)

    ckpt_paths = [str(ckpt_dir1)]
    tspecs = jax.tree_util.tree_map(
        serialization.get_tensorstore_spec, ckpt_paths)

    serialization.run_serialization([arr], tspecs)

    ds = NamedSharding(jtu.create_global_mesh((4, 2), ('x', 'y')), P('x', 'y'))

    m1, = serialization.run_deserialization([ds], tspecs, [(12, 2)],
                                            [np.float32])

    # Keyed by device id. Rows 8-11 do not exist in the checkpoint and are
    # expected to be zero-filled.
    expected_data = {
        0: np.array([[0], [2], [4]], dtype=np.float32),
        1: np.array([[1], [3], [5]], dtype=np.float32),
        2: np.array([[6], [8], [10]], dtype=np.float32),
        3: np.array([[7], [9], [11]], dtype=np.float32),
        4: np.array([[12], [14], [0]], dtype=np.float32),
        5: np.array([[13], [15], [0]], dtype=np.float32),
        6: np.array([[0], [0], [0]], dtype=np.float32),
        7: np.array([[0], [0], [0]], dtype=np.float32),
    }

    for shard in m1.addressable_shards:
      self.assertArraysEqual(np.asarray(shard.data),
                             expected_data[shard.device.id])

    new_ds = GSPMDSharding.get_replicated(list(global_mesh.devices.flat))
    m2, = serialization.run_deserialization([new_ds], tspecs, [(8, 2)], [np.float32])
    for shard in m2.addressable_shards:
      self.assertArraysEqual(np.asarray(shard.data),
                             global_input_data1.astype('float32'))

  @jax_config.jax_array(True)
  def test_checkpointing_scalar_jax_array(self):
    """Round-trips a 0-d (scalar) array, restoring it as float32."""
    # Note the trailing comma: ('x',) is a tuple, ('x') would be a plain str.
    global_mesh = jtu.create_global_mesh((2,), ('x',))
    global_input_shape = ()
    data = np.array(4)
    s = NamedSharding(global_mesh, P(None))
    array1 = array.make_array_from_callback(
        global_input_shape, s, lambda idx: data[idx])
    ckpt_dir1 = pathlib.Path(self.create_tempdir('first').full_path)

    ckpt_paths = [str(ckpt_dir1)]
    tspecs = jax.tree_util.tree_map(
        serialization.get_tensorstore_spec, ckpt_paths)

    serialization.run_serialization([array1], tspecs)
    ds = NamedSharding(jtu.create_global_mesh((2,), ('x',)), P(None))

    m1, = serialization.run_deserialization(
        [ds],
        tspecs,
        [()],
        [np.float32]
    )

    for shard in m1.addressable_shards:
      self.assertArraysEqual(np.asarray(shard.data), data.astype(np.float32))

  @jax_config.jax_array(True)
  def test_deserialize_tensorstore_array_jax_array(self):
    """Deserializes straight from an in-memory `ts.array` spec.

    No checkpoint is written with `run_serialization` first.
    """
    global_mesh = jtu.create_global_mesh((2,), ('x',))
    data = np.arange(1024)
    tspec = ts.array(data).spec()
    m1, = serialization.run_deserialization(
        [NamedSharding(global_mesh, P(None))],
        [tspec]
    )
    for shard in m1.addressable_shards:
      self.assertArraysEqual(np.asarray(shard.data), data)

  def test_spec_has_metadata(self):
    """A 'metadata' key is detected both nested and at the top level."""
    # Here 'metadata' appears only inside the nested dict under 'e'.
    spec = {
        'a': {
            'b': 1,
            'c': 2,
        },
        'd': 3,
        'e': {
            'a': 2,
            'metadata': 3
        },
        'f': 4
    }
    self.assertTrue(serialization._spec_has_metadata(spec))
    self.assertTrue(
        serialization._spec_has_metadata({
            'driver': 'zarr',
            'kvstore': 'gfile',
            'metadata': {
                'chunks': 4,
                'shape': (32, 64)
            },
            'one_more': 'thing'
        }))

  def test_spec_has_no_metadata(self):
    """A nested spec without any 'metadata' key is reported as such."""
    spec = {
        'a': {
            'b': 1,
            'c': 2,
        },
        'd': 3,
        'e': {
            'a': 2,
        },
        'f': 4
    }
    self.assertFalse(serialization._spec_has_metadata(spec))

  def test_empty_spec_has_no_metadata(self):
    """An empty spec dict has no metadata."""
    spec = {}
    self.assertFalse(serialization._spec_has_metadata(spec))

if __name__ == '__main__':
  # Hand absltest JAX's own test loader rather than absltest's default one.
  jax_loader = jtu.JaxTestLoader()
  absltest.main(testLoader=jax_loader)
