blob: 768e46e79e0941d8c338c742fe2dcad9b0a3b48b [file] [log] [blame]
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the 'License');
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an 'AS IS' BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ======================================
"""Hook for asynchronous checkpointing.
This hook dispatches checkpoint writing operations in a separate thread to
allow execution to continue on the main thread.
"""
import os
import threading
import time
from typing import Any, List, Optional, Text
from tensorflow.core.util import event_pb2
from tensorflow.python.client import session as session_lib
from tensorflow.python.framework import meta_graph
from tensorflow.python.framework import ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.saved_model.pywrap_saved_model import metrics
from tensorflow.python.training import basic_session_run_hooks
from tensorflow.python.training import monitored_session
from tensorflow.python.training import saver as saver_lib
from tensorflow.python.training import session_run_hook
from tensorflow.python.training import training_util
from tensorflow.python.training.summary_io import SummaryWriterCache
# Captures the timestamp of the first Saver object instantiation or end of a
# save operation. Can be accessed by multiple Saver instances.
_END_TIME_OF_LAST_WRITE = None
_END_TIME_OF_LAST_WRITE_LOCK = threading.Lock()
# API label for cell names used in TF1 async checkpoint metrics.
_ASYNC_CHECKPOINT_V1 = "async_checkpoint_v1"
def _get_duration_microseconds(start_time_seconds, end_time_seconds) -> int:
"""Returns the duration between start and end time in microseconds."""
return max(int((end_time_seconds - start_time_seconds) * 1000000), 0)
class AsyncCheckpointSaverHook(basic_session_run_hooks.CheckpointSaverHook):
"""Saves checkpoints every N steps or seconds."""
def __init__(self,
checkpoint_dir: Text,
save_secs: Optional[int] = None,
save_steps: Optional[int] = None,
saver: Optional[saver_lib.Saver] = None,
checkpoint_basename: Text = "model.ckpt",
scaffold: Optional[monitored_session.Scaffold] = None,
listeners: Optional[List[
basic_session_run_hooks.CheckpointSaverListener]] = None):
"""Initializes a `CheckpointSaverHook`.
Args:
checkpoint_dir: `str`, base directory for the checkpoint files.
save_secs: `int`, save every N secs.
save_steps: `int`, save every N steps.
saver: `Saver` object, used for saving.
checkpoint_basename: `str`, base name for the checkpoint files.
scaffold: `Scaffold`, use to get saver object.
listeners: List of `CheckpointSaverListener` subclass instances. Used for
callbacks that run immediately before or after this hook saves the
checkpoint.
Raises:
ValueError: One of `save_steps` or `save_secs` should be set.
ValueError: At most one of `saver` or `scaffold` should be set.
"""
save_path = os.path.join(checkpoint_dir, checkpoint_basename)
logging.info("Create AsyncCheckpointSaverHook saving to path\n%s",
save_path)
if listeners:
logging.info(" with %d listener(s).", len(listeners))
if saver is not None and scaffold is not None:
raise ValueError("You cannot provide both saver and scaffold.")
self._saver = saver
self._save_thread = None
self._write_graph_thread = None
self._checkpoint_dir = checkpoint_dir
self._save_path = save_path
self._scaffold = scaffold
self._timer = basic_session_run_hooks.SecondOrStepTimer(
every_secs=save_secs, every_steps=save_steps)
self._listeners = listeners or []
self._steps_per_run = 1
self._summary_writer = None
self._global_step_tensor = None
self._last_checkpoint_step = None
# Initialize the first timestamp for _END_TIME_OF_LAST_WRITE.
global _END_TIME_OF_LAST_WRITE
with _END_TIME_OF_LAST_WRITE_LOCK:
if _END_TIME_OF_LAST_WRITE is None:
_END_TIME_OF_LAST_WRITE = time.time()
def _set_steps_per_run(self, steps_per_run):
self._steps_per_run = steps_per_run
def begin(self):
self._summary_writer = SummaryWriterCache.get(self._checkpoint_dir)
self._global_step_tensor = training_util._get_or_create_global_step_read() # pylint: disable=protected-access
if self._global_step_tensor is None:
raise RuntimeError(
"Global step should be created to use CheckpointSaverHook.")
for l in self._listeners:
l.begin()
def after_create_session(self, session: session_lib.Session, coord: Any):
global_step = session.run(self._global_step_tensor)
# We do write graph and saver_def at the first call of before_run.
# We cannot do this in begin, since we let other hooks to change graph and
# add variables in begin. Graph is finalized after all begin calls.
def _write_graph_fn(self):
training_util.write_graph(
ops.get_default_graph().as_graph_def(add_shapes=True),
self._checkpoint_dir, "graph.pbtxt")
self._write_graph_thread = threading.Thread(target=_write_graph_fn,
args=[self])
self._write_graph_thread.start()
saver_def = self._get_saver().saver_def if self._get_saver() else None
graph = ops.get_default_graph()
meta_graph_def = meta_graph.create_meta_graph_def(
graph_def=graph.as_graph_def(add_shapes=True), saver_def=saver_def)
if self._summary_writer is None:
raise ValueError("Summary writer is not initialised")
self._summary_writer.add_graph(graph)
self._summary_writer.add_meta_graph(meta_graph_def)
# The checkpoint saved here is the state at step "global_step".
self._save(session, global_step)
self._timer.update_last_triggered_step(global_step)
def before_run(self, run_context: Any): # pylint: disable=unused-argument
return session_run_hook.SessionRunArgs(self._global_step_tensor)
def after_run(self, run_context: session_run_hook.SessionRunContext,
run_values: Any):
global_step = run_context.session.run(self._global_step_tensor)
if self._timer.should_trigger_for_step(global_step):
self._timer.update_last_triggered_step(global_step)
logging.info("Triggering checkpoint. %s", global_step)
if self._save(run_context.session, global_step):
run_context.request_stop()
def end(self, session: session_lib.Session):
if self._save_thread:
logging.info("Waiting for any pending checkpoints to finish.")
self._save_thread.join()
if self._write_graph_thread:
logging.info("Waiting for any pending write_graph to finish.")
self._write_graph_thread.join()
last_step = session.run(self._global_step_tensor)
if self._last_checkpoint_step != last_step:
self._save(session, last_step, asynchronous=False)
for l in self._listeners:
l.end(session, last_step)
def _save(self, session, step, asynchronous=True):
"""Saves the latest checkpoint, returns should_stop."""
def _save_fn():
"""Run the saver process."""
logging.info("Saving checkpoints for %d into %s.", step, self._save_path)
start_time = time.time()
for l in self._listeners:
l.before_save(session, step)
self._get_saver().save(session, self._save_path, global_step=step)
if self._summary_writer is None:
raise ValueError("Summary writer is not initialised")
self._summary_writer.add_session_log(
event_pb2.SessionLog(
status=event_pb2.SessionLog.CHECKPOINT,
checkpoint_path=self._save_path), step)
for l in self._listeners:
l.after_save(session, step)
# Measure the async checkpoint write duration, i.e., non-blocking time.
end_time = time.time()
metrics.AddAsyncCheckpointWriteDuration(
api_label=_ASYNC_CHECKPOINT_V1,
microseconds=_get_duration_microseconds(start_time, end_time))
# Measure the elapsed time since the last checkpoint.
# Due to the nature of async checkpoint, here it actually captures the
# duration between the start_time of the previous checkpoint and the start
# time of this checkpoint. As a result, the duration of the final async
# checkpoint is excluded, which is fine since it does not take much time.
global _END_TIME_OF_LAST_WRITE
with _END_TIME_OF_LAST_WRITE_LOCK:
metrics.AddTrainingTimeSaved(
api_label=_ASYNC_CHECKPOINT_V1,
microseconds=_get_duration_microseconds(_END_TIME_OF_LAST_WRITE,
start_time))
_END_TIME_OF_LAST_WRITE = start_time
logging.info("Checkpoint actual writing time: (%.3f sec)",
end_time - start_time)
logging.info("Checkpoint finished for %d into %s.", step, self._save_path)
# Measure the checkpoint write duration that is blocking the main thread.
blocking_start_time = time.time()
def end_of_blocking_time():
blocking_end_time = time.time()
metrics.AddCheckpointWriteDuration(
api_label=_ASYNC_CHECKPOINT_V1,
microseconds=_get_duration_microseconds(blocking_start_time,
blocking_end_time))
if not asynchronous:
self._last_checkpoint_step = step
_save_fn()
end_of_blocking_time()
return
if self._save_thread is not None:
self._save_thread.join(timeout=0.1)
if self._save_thread.is_alive():
logging.info("Saver thread still in progress, skipping checkpoint.")
end_of_blocking_time()
return
self._last_checkpoint_step = step
self._save_thread = threading.Thread(target=_save_fn)
self._save_thread.start()
end_of_blocking_time()
def _get_saver(self):
if self._saver is not None:
return self._saver
elif self._scaffold is not None:
return self._scaffold.saver
# Get saver from the SAVERS collection if present.
collection_key = ops.GraphKeys.SAVERS
savers = ops.get_collection(collection_key)
if not savers:
raise RuntimeError(
"No items in collection {}. Please add a saver to the collection "
"or provide a saver or scaffold.".format(collection_key))
elif len(savers) > 1:
raise RuntimeError(
"More than one item in collection {}. "
"Please indicate which one to use by passing it to the constructor."
.format(collection_key))
self._saver = savers[0]
return savers[0]