# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for TPUStrategy."""

import os

from absl.testing import parameterized

from tensorflow.python.checkpoint import checkpoint as util
from tensorflow.python.checkpoint import checkpoint_management
from tensorflow.python.compiler.xla.experimental import xla_sharding
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import packed_distributed_variable as packed
from tensorflow.python.distribute import strategy_test_lib
from tensorflow.python.distribute import tpu_replicated_variable
from tensorflow.python.distribute import tpu_strategy as tpu_lib
from tensorflow.python.distribute import tpu_values
from tensorflow.python.distribute.cluster_resolver import tpu_cluster_resolver
from tensorflow.python.eager import def_function
from tensorflow.python.eager import remote
from tensorflow.python.eager import test
from tensorflow.python.framework import config
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import summary_test_util
from tensorflow.python.module import module
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import summary_ops_v2 as summary_ops
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.ops import variables
from tensorflow.python.platform import flags
from tensorflow.python.tpu import device_assignment as device_assignment_lib
from tensorflow.python.tpu import tpu_replication
from tensorflow.python.tpu import tpu_strategy_util

# Command-line flags identifying the TPU under test. They are read by
# `get_tpu_cluster_resolver` below.
FLAGS = flags.FLAGS
flags.DEFINE_string(
    name="tpu", default="", help="Name of TPU to connect to.")
flags.DEFINE_string(
    name="project", default=None, help="Name of GCP project with TPU.")
flags.DEFINE_string(
    name="zone", default=None, help="Name of GCP zone with TPU.")


def get_tpu_cluster_resolver():
  """Returns a `TPUClusterResolver` built from the tpu/zone/project flags."""
  return tpu_cluster_resolver.TPUClusterResolver(
      tpu=FLAGS.tpu, zone=FLAGS.zone, project=FLAGS.project)


def get_tpu_strategy(enable_spmd=False):
  """Connects to the TPU cluster and builds a `TPUStrategyV2` for the tests.

  Args:
    enable_spmd: Forwarded to `TPUStrategyV2` as
      `experimental_spmd_xla_partitioning`.

  Returns:
    A `(strategy, num_replicas)` tuple, where `num_replicas` is half the
    number of cores reported by the TPU system metadata.
  """
  # Used both to derive the replica count and as the last dimension of the
  # computation shape handed to the device assignment.
  cores_per_replica = 2

  cluster_resolver = get_tpu_cluster_resolver()
  remote.connect_to_cluster(cluster_resolver)
  tpu_topology = tpu_strategy_util.initialize_tpu_system(cluster_resolver)

  core_count = cluster_resolver.get_tpu_system_metadata().num_cores
  num_replicas = core_count // cores_per_replica

  assignment = device_assignment_lib.DeviceAssignment.build(
      tpu_topology,
      num_replicas=num_replicas,
      computation_shape=[1, 1, 1, cores_per_replica])
  tpu_strategy = tpu_lib.TPUStrategyV2(
      cluster_resolver,
      experimental_device_assignment=assignment,
      experimental_spmd_xla_partitioning=enable_spmd)
  return tpu_strategy, num_replicas


class TPUStrategyModelParallelismTest(
    strategy_test_lib.DistributionTestBase,
    strategy_test_lib.TwoDeviceDistributionTestBase,
    parameterized.TestCase):
  """Tests logical-device placement and SPMD partitioning with TPUStrategy.

  Every test builds its own strategy through `get_tpu_strategy`, which creates
  a device assignment with `computation_shape=[1, 1, 1, 2]`.
  """

  @parameterized.named_parameters([("packed", True), ("unpacked", False)])
  def test_spmd_variable_structure(self, enable_packing):
    """Checks reads and the per-replica structure of SPMD variables.

    Args:
      enable_packing: If True, asserts that packed variables are enabled on
        the strategy by default and leaves them on; if False, switches
        packing off before the variables are created.
    """
    strategy, num_replicas = get_tpu_strategy(enable_spmd=True)

    # pylint: disable=protected-access
    if enable_packing:
      self.assertTrue(strategy._enable_packed_variable_in_eager_mode,
                      "packed variables should be enabled by default")
    else:
      strategy._enable_packed_variable_in_eager_mode = False
    # pylint: enable=protected-access

    tensor = constant_op.constant([[0., 1.], [2., 3.]])

    # Test TPUMirroredVariable and TPUSyncOnReadVariable
    with strategy.scope():
      v = variables.Variable(
          tensor, name="v", synchronization=vs.VariableSynchronization.ON_READ)
      w = variables.Variable(
          tensor, name="w", synchronization=vs.VariableSynchronization.ON_WRITE)

    def check_read(x):
      """Reads `x` under `strategy.run` and compares each local result."""

      @def_function.function
      def fn():
        return x.read_value()

      results = strategy.run(fn)
      results = strategy.experimental_local_results(results)

      for i in range(num_replicas):
        self.assertAllClose(results[i], tensor)

    def check_structure(values):
      """Verifies the types and packed-variable sharing of `values`."""
      # Guard against a vacuous pass: with no values, none of the checks
      # below would run.
      self.assertNotEmpty(values)

      # pylint: disable=protected-access
      shared_packed_var = getattr(values[0], "_packed_var", None)
      for i, value in enumerate(values):
        self.assertIsInstance(
            value, tpu_replicated_variable.TPUReplicatedVariable)
        packed_var = getattr(value, "_packed_var", None)
        if not enable_packing:
          self.assertIsNone(packed_var)
        elif i == 0:
          self.assertIsInstance(packed_var, packed.PackedDistributedVariable)
        else:
          self.assertIs(packed_var, shared_packed_var,
                        "all vals should share the same packed var instance")

      if enable_packing:
        # The packed variable must wrap exactly the resources held by the
        # individual values, in the same order.
        resources = [var for value in values for var in value._vars]
        dist_vars = shared_packed_var._distributed_variables
        self.assertLen(resources, len(dist_vars))
        for dist_var, resource in zip(dist_vars, resources):
          self.assertIs(dist_var, resource)
      # pylint: enable=protected-access

    check_read(v)
    check_structure(v.values)
    check_read(w)
    check_structure(w.values)

  def test_logical_device_assignment(self):
    """Checks variable and op placement on logical devices 0 and 1."""
    strategy, num_replicas = get_tpu_strategy()
    with strategy.scope():
      v = variables.Variable(2.)
      with strategy.extended.experimental_logical_device(1):
        w = variables.Variable(3.)

    self.assertLen(strategy.experimental_local_results(v), num_replicas)
    self.assertLen(strategy.experimental_local_results(w), num_replicas)
    self.assertEqual("/job:localhost/replica:0/task:0/device:TPU:0",
                     strategy.experimental_local_results(v)[0].device)
    self.assertEqual("/job:localhost/replica:0/task:0/device:TPU:1",
                     strategy.experimental_local_results(w)[0].device)

    # Filled by a Python side effect inside `f`; the assertion below expects
    # exactly one recorded (y.device, z.device) pair.
    logical_devices = []

    @def_function.function
    def f(x):
      replica_ctx = distribute_lib.get_replica_context()
      with replica_ctx.experimental_logical_device(0):
        y = v * x
      with replica_ctx.experimental_logical_device(1):
        z = w * y
      logical_devices.append((y.device, z.device))
      return z

    result = strategy.run(f, args=(5.,))

    self.assertEqual(
        [("/device:TPU_REPLICATED_CORE:0", "/device:TPU_REPLICATED_CORE:1")],
        logical_devices)

    with self.cached_session():
      self.evaluate(variables.global_variables_initializer())
      # Each replica computes w * (v * x) = 3 * (2 * 5) = 30.
      self.assertEqual(30. * num_replicas,
                       self.evaluate(strategy.reduce("SUM", result, axis=None)))

  # NOTE: "paritioned" is a typo for "partitioned". The method name is kept
  # as-is so that any existing test filters referring to it keep working.
  def test_paritioned_model_checkpointing(self):
    """Saves and restores a model whose variables sit on two logical devices."""

    class PartitionedModel(module.Module):
      """Computes `w * (v * x)` with `v` and `w` on logical devices 0 and 1."""

      def __init__(self, v, w):
        super().__init__()

        assert distribute_lib.has_strategy()
        strategy = distribute_lib.get_strategy()

        with strategy.extended.experimental_logical_device(0):
          self.v = variables.Variable(v)
        with strategy.extended.experimental_logical_device(1):
          self.w = variables.Variable(w)

      def __call__(self, x):
        replica_ctx = distribute_lib.get_replica_context()
        with replica_ctx.experimental_logical_device(0):
          y = self.v * x
        with replica_ctx.experimental_logical_device(1):
          z = self.w * y
        return z

      def change_weights_op(self, v_new, w_new):
        """Returns a grouped op assigning new values to both variables."""
        return control_flow_ops.group(
            [self.v.assign(v_new), self.w.assign(w_new)])

    strategy, num_replicas = get_tpu_strategy()
    with strategy.scope():
      model = PartitionedModel(2., 3.)

    checkpoint_dir = self.get_temp_dir()
    checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
    checkpoint = util.Checkpoint(model=model)

    with self.cached_session() as sess:
      self.evaluate(variables.global_variables_initializer())
      checkpoint.save(file_prefix=checkpoint_prefix)

      # After the save, overwrite the weights: 4 * (1 * 5) = 20 per replica.
      self.evaluate(model.change_weights_op(1., 4.))
      result = strategy.run(def_function.function(model), args=(5.0,))
      self.assertEqual(20. * num_replicas,
                       self.evaluate(strategy.reduce("SUM", result, axis=None)))

      status = checkpoint.restore(
          checkpoint_management.latest_checkpoint(checkpoint_dir))
      status.run_restore_ops(sess)  # must run restore op in non-eager mode.
      status.assert_consumed()
      status.assert_existing_objects_matched()
      # Restored weights are the saved ones: 3 * (2 * 5) = 30 per replica.
      result = strategy.run(def_function.function(model), args=(5.0,))
      self.assertEqual(30. * num_replicas,
                       self.evaluate(strategy.reduce("SUM", result, axis=None)))

  def test_spmd_cannot_assign_tensor_to_logical_device(self):
    """Expects `experimental_assign_to_logical_device` to raise under SPMD."""
    strategy, _ = get_tpu_strategy(enable_spmd=True)
    x = constant_op.constant([0, 1])
    with self.assertRaises(ValueError):
      strategy.experimental_assign_to_logical_device(x, 0)

  def test_spmd_variable_created_from_callable(self):
    """Checks all components of the first value agree for a random init."""
    initializer = lambda: random_ops.random_normal(shape=(16, 16))
    strategy, _ = get_tpu_strategy(enable_spmd=True)
    with strategy.scope():
      w = variables.Variable(initializer)
    value0 = w.values[0]
    # Every entry of `value0.variables` must equal the first one, even though
    # the initializer is random.
    for v in value0.variables:
      self.assertAllEqual(v, value0.variables[0])

  def test_spmd_variable_read(self):
    """Reads an SPMD variable inside `strategy.run` via a matmul."""
    batch_size = 32
    num_feature_in = 16
    num_feature_out = 8

    x = random_ops.random_uniform((batch_size, num_feature_in),
                                  dtype=dtypes.float32)
    w_init = random_ops.random_uniform((num_feature_in, num_feature_out),
                                       dtype=dtypes.float32)

    strategy, num_replicas = get_tpu_strategy(enable_spmd=True)
    with strategy.scope():
      w = variables.Variable(w_init, dtype=dtypes.float32)

    # Both the first underlying component and the variable itself must report
    # the full shape.
    self.assertEqual(w.values[0].variables[0].shape.as_list(),
                     [num_feature_in, num_feature_out])

    self.assertEqual(w.shape.as_list(), [num_feature_in, num_feature_out])

    def step_fn(batch_features):
      predict = math_ops.matmul(batch_features, w)
      return predict

    @def_function.function
    def train_fn(batch_features):
      return strategy.run(step_fn, args=(batch_features,))

    result = train_fn(x)
    self.assertAllClose(
        strategy.reduce("SUM", result, axis=None),
        math_ops.matmul(x, w_init) * num_replicas,
        rtol=5e-03,
        atol=5e-03)

  def test_spmd_variable_read_init_scope(self):
    """Reads an SPMD variable from within `ops.init_scope()`."""
    strategy, _ = get_tpu_strategy(enable_spmd=True)
    with strategy.scope():
      v = variables.Variable(array_ops.ones((4, 4), dtype=dtypes.float32))

    @def_function.function
    def read_v():
      with ops.init_scope():
        return v.read_value()

    result = strategy.reduce("MEAN", strategy.run(read_v), axis=None)
    self.assertAllClose(result, v.read_value())

  def test_spmd_variable_update(self):
    """Runs `assign`, `assign_sub` and `assign_add` through `strategy.run`."""
    batch_size = 1024
    num_feature_in = 256

    x = random_ops.random_uniform((batch_size, num_feature_in),
                                  dtype=dtypes.float32)
    w_init = random_ops.random_uniform((batch_size, num_feature_in),
                                       dtype=dtypes.float32)

    strategy, num_replicas = get_tpu_strategy(enable_spmd=True)
    with strategy.scope():
      w = variables.Variable(w_init, dtype=dtypes.float32)

    self.assertIsInstance(w, tpu_values.TPUMirroredVariable)
    self.assertTrue(
        w._is_replicated_or_sharded_to_logical_cores())  # pylint: disable=protected-access

    def make_strategy_run(fn):
      """Wraps `fn` in a `tf.function` that calls it via `strategy.run`."""

      def run(value):
        return strategy.run(fn, args=(value,))

      return def_function.function(run)

    # `x` tracks the value the variable is expected to hold after each update.
    result = make_strategy_run(w.assign)(x)
    self.assertAllClose(
        strategy.reduce("SUM", result, axis=None), x * num_replicas)

    delta = random_ops.random_uniform((batch_size, num_feature_in),
                                      dtype=dtypes.float32)
    result = make_strategy_run(w.assign_sub)(delta)
    x -= delta
    self.assertAllClose(
        strategy.reduce("SUM", result, axis=None), x * num_replicas)

    delta = random_ops.random_uniform((batch_size, num_feature_in),
                                      dtype=dtypes.float32)
    result = make_strategy_run(w.assign_add)(delta)
    x += delta
    self.assertAllClose(
        strategy.reduce("SUM", result, axis=None), x * num_replicas)

  def test_spmd_variable_eager_update(self):
    """Calls `assign`, `assign_sub` and `assign_add` outside `strategy.run`."""
    batch_size = 32
    num_feature_in = 16

    x = random_ops.random_uniform((batch_size, num_feature_in),
                                  dtype=dtypes.float32)
    w_init = random_ops.random_uniform((batch_size, num_feature_in),
                                       dtype=dtypes.float32)

    strategy, _ = get_tpu_strategy(enable_spmd=True)
    with strategy.scope():
      w = variables.Variable(w_init, dtype=dtypes.float32)

    w.assign(x)
    result = w.numpy()
    self.assertAllClose(result, x)

    x1 = random_ops.random_uniform((batch_size, num_feature_in),
                                   dtype=dtypes.float32)
    w.assign_sub(x1)
    result = w.numpy()
    self.assertAllClose(result, x - x1)

    x2 = random_ops.random_uniform((batch_size, num_feature_in),
                                   dtype=dtypes.float32)
    # Reset to `x` first so the expected value does not depend on `x1`.
    w.assign(x)
    w.assign_add(x2)
    result = w.numpy()
    self.assertAllClose(result, x + x2)

  def test_spmd_model_checkpointing(self):
    """Saves and restores a linear model created with SPMD enabled."""

    class LinearModel(module.Module):
      """Computes `matmul(x, w)` for a single weight variable `w`."""

      def __init__(self, w):
        super().__init__()
        self.w = variables.Variable(w)

      def __call__(self, x):
        return math_ops.matmul(x, self.w)

      def change_weights_op(self, w_new):
        """Returns the op assigning `w_new` to the weight variable."""
        return self.w.assign(w_new)

    batch_size = 32
    num_feature_in = 16
    num_feature_out = 8
    w1 = random_ops.random_uniform((num_feature_in, num_feature_out),
                                   dtype=dtypes.float32)
    w2 = random_ops.random_uniform((num_feature_in, num_feature_out),
                                   dtype=dtypes.float32)
    x = random_ops.random_uniform((batch_size, num_feature_in),
                                  dtype=dtypes.float32)

    strategy, num_replicas = get_tpu_strategy(enable_spmd=True)
    with strategy.scope():
      model = LinearModel(w1)

    checkpoint_dir = self.get_temp_dir()
    checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
    checkpoint = util.Checkpoint(model=model)

    @def_function.function
    def step_fn(x):
      x = strategy.experimental_split_to_logical_devices(x, [1, 2])
      return model(x)

    with self.cached_session() as sess:
      self.evaluate(variables.global_variables_initializer())
      # The checkpoint captures the initial weights `w1`.
      checkpoint.save(file_prefix=checkpoint_prefix)

      self.evaluate(model.change_weights_op(w2))
      result = strategy.run(step_fn, args=(x,))
      self.assertAllClose(
          math_ops.matmul(x, w2) * num_replicas,
          self.evaluate(strategy.reduce("SUM", result, axis=None)),
          rtol=5e-3,
          atol=5e-3)

      status = checkpoint.restore(
          checkpoint_management.latest_checkpoint(checkpoint_dir))
      status.run_restore_ops(sess)  # must run restore op in non-eager mode.
      status.assert_consumed()
      status.assert_existing_objects_matched()
      # After the restore the model must compute with `w1` again.
      result = strategy.run(step_fn, args=(x,))
      self.assertAllClose(
          math_ops.matmul(x, w1) * num_replicas,
          self.evaluate(strategy.reduce("SUM", result, axis=None)),
          rtol=5e-3,
          atol=5e-3)

  def test_spmd_with_summary(self):
    """Writes scalar summaries from a function run under SPMD."""
    original_device_placement = config.get_soft_device_placement()
    # Soft device placement is a global setting; restore it through a cleanup
    # so that a failing assertion below cannot leak it into other tests.
    self.addCleanup(config.set_soft_device_placement,
                    original_device_placement)
    config.set_soft_device_placement(True)

    strategy, _ = get_tpu_strategy(enable_spmd=True)
    summary_dir = self.get_temp_dir()
    writer = summary_ops.create_file_writer_v2(summary_dir)
    const_multiple = 2
    num_iters = 10
    # `step` starts at 1 and is incremented once per iteration, so this is
    # both its final value and the number of events checked below.
    expected_event_count = num_iters + 1

    with strategy.scope():
      step = variables.Variable(1, dtype=dtypes.int64)

    @def_function.function
    def run():
      with writer.as_default():
        with summary_ops.record_if(True):
          summary_ops.scalar("result", step * const_multiple, step=step)
          step.assign_add(1)

    for _ in range(num_iters):
      strategy.run(run, args=())

    for val in step.values:
      for var in val.variables:
        self.assertAllEqual(expected_event_count, var)

    events = summary_test_util.events_from_logdir(summary_dir)
    self.assertLen(events, expected_event_count)

    # Event[0] is generic metadata and summary_ops data starts at event[1].
    for logged_step in range(1, expected_event_count):
      self.assertEqual(events[logged_step].summary.value[0].simple_value,
                       logged_step * const_multiple)

  # Tests SPMD with outside compilation. One test case is for replicated
  # sharding of the input tensor and one case is for split sharding of the input
  # tensor.
  @parameterized.parameters([False, True])
  def test_spmd_with_outside_comp(self, split):
    """Runs an outside-compilation increment between two on-device ones.

    Args:
      split: If True, the input is first passed through
        `experimental_split_to_logical_devices` with partition `[1, 2]`;
        otherwise it is used as given.
    """
    strategy, num_replicas = get_tpu_strategy(enable_spmd=True)

    def host_inc(x):
      return x + 1

    @def_function.function
    def fn(x):
      if split:
        x = strategy.experimental_split_to_logical_devices(x, [1, 2])
      y = x + 1
      z = tpu_replication.outside_compilation(host_inc, y)
      a = z + 1
      return a

    arg = constant_op.constant(0, shape=(2, 2), dtype=dtypes.int64)
    result = strategy.run(fn, args=(arg,))
    # Three increments in total: two inside `fn` and one in `host_inc`.
    self.assertAllEqual(
        (arg + 3) * num_replicas,
        self.evaluate(strategy.reduce("SUM", result, axis=None)))

  # Tests auto_to_manual_spmd_partition and manual_to_auto_spmd_partition.
  # The internal versions of these ops are XlaSpmdFullToShardShape and
  # XlaSpmdShardToFullShape.
  def test_manual_sharding_ops(self):
    """Round-trips a split tensor through manual sharding and back."""
    strategy, num_replicas = get_tpu_strategy(enable_spmd=True)

    @def_function.function
    def fn(x):
      x_split = strategy.experimental_split_to_logical_devices(x, [1, 2])
      split_sharding = xla_sharding.get_op_sharding(x_split.op)
      x_manual = xla_sharding.auto_to_manual_spmd_partition(
          x_split, split_sharding
      )
      y_manual = x_manual + 1
      # (2, 2) is the full shape, matching the shape of `arg` below.
      y_split = xla_sharding.manual_to_auto_spmd_partition(
          y_manual, split_sharding, (2, 2)
      )
      return y_split

    arg = constant_op.constant(0, shape=(2, 2), dtype=dtypes.int64)
    result = strategy.run(fn, args=(arg,))
    self.assertAllEqual(
        (arg + 1) * num_replicas,
        self.evaluate(strategy.reduce("SUM", result, axis=None)),
    )

  def test_spmd_with_map_outside_comp(self):
    """Applies `experimental_map_outside_compilation` to a split tensor."""
    strategy, num_replicas = get_tpu_strategy(enable_spmd=True)

    def host_inc(x):
      return x + 1

    @def_function.function
    def fn(a):
      b = strategy.experimental_split_to_logical_devices(a, [2, 1])
      c = tpu_replication.experimental_map_outside_compilation(host_inc, b)
      d = strategy.experimental_split_to_logical_devices(c, [2, 1])
      return d

    arg = constant_op.constant(
        [[0, 1], [2, 3]], shape=(2, 2), dtype=dtypes.int64
    )
    result = strategy.run(fn, args=(arg,))
    expected = (arg + 1) * num_replicas
    self.assertAllEqual(
        expected, self.evaluate(strategy.reduce("SUM", result, axis=None))
    )


if __name__ == "__main__":
  test.main()
