xref: /aosp_15_r20/external/tensorflow/tensorflow/python/tpu/tpu_function.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================

"""Helper library for functions used during TPU compilation."""

import contextlib
import threading
20
21
class TpuContext(threading.local):
  """Thread-local holder for state about the TPU computation being built."""

  def __init__(self):
    """Initializes a TpuContext with no sharded computation in progress."""
    self._number_of_shards = None

  @property
  def number_of_shards(self):
    """The shard count of the computation being built, or None if none is."""
    return self._number_of_shards

  def set_number_of_shards(self, number_of_shards):
    """Records `number_of_shards` as the current computation's shard count."""
    self._number_of_shards = number_of_shards
35
36
# The Tpu context holds the number of shards when a sharded computation is
# being built, or None if no computation is being built. It is thread-local
# (see TpuContext), so concurrent builds on different threads don't collide.
_current_tpu_context = TpuContext()
40
41
@contextlib.contextmanager
def tpu_shard_context(number_of_shards):
  """A context manager setting the current number of shards.

  Args:
    number_of_shards: The shard count to record on the thread-local
      TpuContext for the duration of the `with` block.

  Yields:
    None.

  Raises:
    NotImplementedError: If a shard context is already active on this
      thread — nesting is not supported.
  """
  if _current_tpu_context.number_of_shards is not None:
    # Bug fix: the first literal previously ended "nested." with no trailing
    # space, so the concatenated message read "nested.If you're ...".
    raise NotImplementedError(
        "tpu_shard_context cannot be nested. "
        "If you're using TPUEstimator with inference_on_tpu, "
        "make sure you have set "
        "export_saved_model_api_version=ExportSavedModelApiVersion.V2 in "
        "the creation of TPUEstimator.")
  try:
    _current_tpu_context.set_number_of_shards(number_of_shards)
    yield
  finally:
    # Always clear the shard count, even if the body raises, so a failed
    # build does not leave this thread permanently "nested".
    _current_tpu_context.set_number_of_shards(None)
57
58
def get_tpu_context():
  """Returns the thread-local TpuContext holding the current shard count."""
  return _current_tpu_context
61
62
# Decorator for a TPU computation func that was passed to tpu.rewrite(). If
# there is an embedded training loop in this func, trace tools will generate
# step markers for each iteration.
def on_device_training_loop(func):
  """Marks `func` so trace tools emit per-iteration step markers.

  The attribute value comes from xla.DebugOptions.StepMarkerLocation.

  Args:
    func: The TPU computation function to annotate.

  Returns:
    The same `func`, annotated in place.
  """
  func.step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP"
  return func
70