- cluster_resolver/
- coordinator/
- experimental/
- failure_handling/
- integration_test/
- parallel_device/
- v1/
- BUILD
- central_storage_strategy.py
- checkpoint_utils_test.py
- checkpointing_test.py
- collective_all_reduce_strategy.py
- collective_all_reduce_strategy_test.py
- collective_util.py
- collective_util_test.py
- combinations.py
- combinations_test.py
- cross_device_ops.py
- cross_device_ops_test.py
- cross_device_utils.py
- cross_device_utils_test.py
- custom_training_loop_gradient_test.py
- custom_training_loop_input_test.py
- device_util.py
- device_util_test.py
- distribute_config.py
- distribute_coordinator.py
- distribute_coordinator_context.py
- distribute_coordinator_test.py
- distribute_lib.py
- distribute_lib_test.py
- distribute_utils.py
- distribute_utils_test.py
- distributed_table_test.py
- distributed_variable_test.py
- estimator_training.py
- input_lib.py
- input_lib_test.py
- input_lib_type_spec_test.py
- input_ops.py
- input_ops_test.py
- input_util.py
- merge_call_interim.py
- metrics_v1_test.py
- mirrored_run.py
- mirrored_strategy.py
- mirrored_strategy_test.py
- mirrored_values_test.py
- mirrored_variable_test.py
- moving_averages_test.py
- multi_process_lib.py
- multi_process_runner.py
- multi_process_runner_no_init_test.py
- multi_process_runner_test.py
- multi_worker_continuous_run_test.py
- multi_worker_test_base.py
- multi_worker_test_base_test.py
- multi_worker_util.py
- multi_worker_util_test.py
- numpy_dataset.py
- numpy_dataset_test.py
- one_device_strategy.py
- one_device_strategy_test.py
- packed_distributed_variable.py
- packed_distributed_variable_test.py
- parameter_server_strategy.py
- parameter_server_strategy_test.py
- parameter_server_strategy_v2.py
- parameter_server_strategy_v2_test.py
- per_replica_test.py
- ps_values.py
- ps_values_test.py
- random_generator_test.py
- README.md
- reduce_util.py
- remote_mirrored_strategy_eager_test.py
- sharded_variable.py
- sharded_variable_test.py
- shared_variable_creator.py
- shared_variable_creator_test.py
- single_loss_example.py
- step_fn.py
- strategy_combinations.py
- strategy_combinations_test.py
- strategy_common_test.py
- strategy_gather_test.py
- strategy_test_lib.py
- summary_op_util.py
- template_mirrored_strategy_test.py
- test_util.py
- test_util_test.py
- tf_function_test.py
- tpu_replicated_variable.py
- tpu_replicated_variable_test.py
- tpu_strategy.py
- tpu_strategy_compilation_test.py
- tpu_strategy_model_parallelism_test.py
- tpu_strategy_test.py
- tpu_util.py
- tpu_values.py
- values.py
- values_test.py
- values_util.py
- values_v2.py
- values_v2_test.py
- vars_test.py
- warm_starting_util_test.py
- zero_batch_test.py

# TensorFlow Distribute Libraries

## Overview

`tf.distribute.Strategy` is a TensorFlow API for distributing training across multiple GPUs, multiple machines, or TPUs. Using this API, you can distribute existing models and training code with minimal code changes.

It can be used with TensorFlow's high-level APIs, `tf.keras` and `tf.estimator`, with only a few lines of code changed. It does so by making the underlying components of TensorFlow strategy-aware, including variables, layers, models, optimizers, metrics, summaries, and checkpoints.
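For example, a variable created under a strategy's scope is replaced with a distributed variable. A minimal sketch of this behavior (assuming at least one available device):

```python
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()

# Inside the scope, variable creation is intercepted by the strategy:
# MirroredStrategy keeps one copy of the variable per replica in sync.
with strategy.scope():
  v = tf.Variable(1.0)

# `v` is a distributed (mirrored) variable rather than a plain tf.Variable.
print(type(v).__name__)
```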

## Documentation

- [Distributed Training Guide](https://www.tensorflow.org/guide/distributed_training)
- [Distributed Training With Keras Tutorial](https://www.tensorflow.org/tutorials/distribute/keras)
- [Distributed Training With Custom Training Loops Tutorial](https://www.tensorflow.org/tutorials/distribute/custom_training)
- [Multi-worker Training With Keras Tutorial](https://www.tensorflow.org/tutorials/distribute/multi_worker_with_keras)
- [Multi-worker Training With Estimator Tutorial](https://www.tensorflow.org/tutorials/distribute/multi_worker_with_estimator)
- [Save and Load with Distribution Strategy](https://www.tensorflow.org/tutorials/distribute/save_and_load)

## Simple Examples

### Using `compile`/`fit` with GPUs

```python
import tensorflow as tf

# Create the strategy instance. It will automatically detect all the GPUs.
mirrored_strategy = tf.distribute.MirroredStrategy()

# Create and compile the Keras model under strategy.scope().
with mirrored_strategy.scope():
  model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(1,))])
  model.compile(loss='mse', optimizer='sgd')

# Call model.fit and model.evaluate as before.
dataset = tf.data.Dataset.from_tensors(([1.], [1.])).repeat(100).batch(10)
model.fit(dataset, epochs=2)
model.evaluate(dataset)
```
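`MirroredStrategy` can also be given an explicit device list instead of auto-detecting them. A small sketch (the device names are illustrative and depend on your machine):

```python
# Restrict the strategy to an explicit device list instead of auto-detection.
# The device names below are illustrative; substitute your own.
two_gpu_strategy = tf.distribute.MirroredStrategy(devices=["/gpu:0", "/gpu:1"])
print("Number of replicas:", two_gpu_strategy.num_replicas_in_sync)
```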

### Custom training loop with TPUs

```python
import tensorflow as tf

# Connect to the TPU cluster and create the strategy instance.
# The resolver arguments depend on your environment; tpu='' targets a
# locally attached TPU.
resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
tf.config.experimental_connect_to_cluster(resolver)
tf.tpu.experimental.initialize_tpu_system(resolver)
tpu_strategy = tf.distribute.TPUStrategy(resolver)

# Create the model under strategy.scope().
with tpu_strategy.scope():
  model = tf.keras.layers.Dense(1, name="dense")

# Create the custom training loop body as a tf.function.
@tf.function
def train_step(iterator):
  def step_fn(inputs):
    images, targets = inputs
    with tf.GradientTape() as tape:
      outputs = model(images)
      loss = tf.reduce_sum(outputs - targets)
    grads = tape.gradient(loss, model.variables)
    return grads

  return tpu_strategy.run(step_fn, args=(next(iterator),))

# Run the loop body once on the dataset.
dataset = tf.data.Dataset.from_tensors(([1.], [1.])).repeat(100).batch(10)
input_iterator = iter(tpu_strategy.experimental_distribute_dataset(dataset))
train_step(input_iterator)
```
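The step above only computes gradients. A real loop would also apply them and typically reduce the per-replica losses for logging. A minimal sketch extending the example (the optimizer and its hyperparameters are illustrative, not part of the original snippet):

```python
# Create the optimizer under the same scope as the model.
with tpu_strategy.scope():
  optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)

@tf.function
def train_step_and_update(iterator):
  def step_fn(inputs):
    images, targets = inputs
    with tf.GradientTape() as tape:
      outputs = model(images)
      loss = tf.reduce_sum(outputs - targets)
    grads = tape.gradient(loss, model.variables)
    # Each replica applies its gradients; the strategy keeps copies in sync.
    optimizer.apply_gradients(zip(grads, model.variables))
    return loss

  per_replica_loss = tpu_strategy.run(step_fn, args=(next(iterator),))
  # Sum the per-replica losses into a single scalar for logging.
  return tpu_strategy.reduce(
      tf.distribute.ReduceOp.SUM, per_replica_loss, axis=None)
```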

## Testing

Tests here should cover all distribution strategies to ensure feature parity. This can be done using the test decorators in `strategy_combinations.py`, as sketched below.
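A rough sketch of such a test, using the in-repo `combinations` and `strategy_combinations` modules (the test class and the particular strategies chosen here are illustrative):

```python
from absl.testing import parameterized

from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import strategy_combinations
from tensorflow.python.ops import variables
from tensorflow.python.platform import test


class MyFeatureTest(test.TestCase, parameterized.TestCase):

  # The decorator runs the test body once per (strategy, mode) combination.
  @combinations.generate(
      combinations.combine(
          distribution=[
              strategy_combinations.default_strategy,
              strategy_combinations.one_device_strategy,
              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
          ],
          mode=["eager"]))
  def testVariableCreationInScope(self, distribution):
    # Replace the body with the feature under test; this just checks scoping.
    with distribution.scope():
      v = variables.Variable(1.0)
    self.assertIsNotNone(v)


if __name__ == "__main__":
  test.main()
```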