# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================

"""Library for constructing a training loop, suitable for TPUs."""

from typing import Any, Callable, Iterable, List, Optional, Union

from tensorflow.python.compiler.xla import xla
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.tpu import tensor_tracer
from tensorflow.python.tpu import tpu_feed
from tensorflow.python.tpu import tpu_function
from tensorflow.python.types import core as core_types


def while_loop(condition: Callable[..., Any],
               body: Callable[..., Any],
               inputs: Optional[List[Any]] = None,
               infeed_queue: Optional[tpu_feed.InfeedQueue] = None,
               name: Any = None) -> Any:
  """Builds a training loop for TPUs.

  The set of loop-carried tensors corresponds to `inputs`.  Both
  `condition` and `body` take the current value of the loop-carried
  tensors. `body` additionally takes a tuple of infeed values from
  `infeed_queue` if `infeed_queue` is not None. `condition` must return a
  single boolean value that determines whether iteration
  continues. `body` must return an updated list of values for the
  loop-carried tensors.

  Args:
    condition: a Python function that builds the loop condition.
    body: a Python function that builds the loop body.
    inputs: a list of initial values passed into the training loop, or None
      (equivalent to an empty list).
    infeed_queue: if not None, the infeed queue from which to append a tuple of
      arguments as inputs to `body`.
    name: (Deprecated) Does nothing.

  Returns:
    The final values of the loop-carried tensors.

  Raises:
    TypeError: if body or condition has the wrong signature.
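
  Example:
    An illustrative sketch (assumes the loop is built inside a TPU computation,
    e.g. under `tpu.rewrite` or `tpu.replicate`, so it is compiled by XLA):

      def condition(i):
        return i < 10

      def body(i):
        return i + 1

      final_i = while_loop(condition, body, inputs=[0])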
  """
  del name
  # Converts inputs to Tensors.
  inputs = [] if inputs is None else [ops.convert_to_tensor(x) for
                                      x in inputs]
  input_types = [x.dtype for x in inputs]
  input_arity = len(inputs)

  body_arg_error = xla.check_function_argument_count(
      body, input_arity, infeed_queue)
  if body_arg_error is not None:
    if infeed_queue is None:
      raise TypeError(
          f"Supplied loop body function cannot be called with the specified "
          f"inputs. You specified {input_arity} inputs: "
          f"{[i.name for i in inputs]}, but the loop body needs "
          f"{body_arg_error}")
    else:
      raise TypeError(
          f"Supplied loop body function cannot be called with the specified "
          f"inputs. You specified {input_arity} inputs: "
          f"{[i.name for i in inputs]} and "
          f"{infeed_queue.number_of_tuple_elements} additional inputs from "
          f"infeed, but the computation needs {body_arg_error}")
  condition_arg_error = xla.check_function_argument_count(
      condition, input_arity, None)
  if condition_arg_error is not None:
    if infeed_queue is None:
      raise TypeError(
          f"Supplied loop condition function cannot be called with the "
          f"specified inputs. You specified {input_arity} inputs: "
          f"{[i.name for i in inputs]}, but the loop "
          f"condition needs {condition_arg_error}")
    else:
      raise TypeError(
          f"Supplied loop condition function cannot be called with the "
          f"specified inputs. You specified {input_arity} inputs: "
          f"{[i.name for i in inputs]}, but the loop "
          f"condition needs {condition_arg_error}. Note that infeed is not "
          f"passed to the loop condition.")

  def condition_wrapper(*inputs):
    # Discards the dummy output added for arity-0 loops.
    if input_arity == 0:
      inputs = []
    return condition(*inputs)

  def body_wrapper(*inputs):
    """Wrapper around `body` that handles infeed queues and control deps."""
    inputs = list(inputs)

    # Discards the dummy output added for arity-0 loops.
    if input_arity == 0:
      inputs = []

    # Runs `body` with the dequeue_ops appended.
    if infeed_queue:
      number_of_shards = tpu_function.get_tpu_context().number_of_shards
      if number_of_shards is None:
        raise ValueError("Can't build training loop with infeed when there is "
                         "no tpu_shard_context. Are you building a loop or "
                         "graph directly rather than from inside tpu.rewrite, "
                         "tpu.batch_parallel, tpu.shard, or tpu.replicate?")
      infeed_queue.set_number_of_shards(number_of_shards)
      dequeue_ops = [d for d in infeed_queue.generate_dequeue_op()]
    else:
      dequeue_ops = []
    outputs = body(*(inputs + dequeue_ops))

    # If the computation only returned one value, make it a tuple.
    if not isinstance(outputs, (list, tuple)):
      outputs = (outputs,)

    outputs = [
        o if isinstance(o, ops.Operation) else ops.convert_to_tensor(o)
        for o in outputs
    ]

    # Separates the returned Operations and Tensors.
    output_operations = [o for o in outputs if isinstance(o, ops.Operation)]
    output_tensors = [o for o in outputs
                      if not isinstance(o, ops.Operation)]

    if outputs != output_tensors + output_operations:
      raise ValueError(
          "TPU training loop body must return zero or more Tensor values "
          "followed by zero or more Operations.")

    output_types = [op.dtype for op in output_tensors]
    if input_types != output_types:
      raise TypeError(
          "Mismatch between input types and output types for training loop "
          "body: {} vs {}".format(input_types, output_types))

    # Add the dequeue operations to output_operations to ensure they are run
    # by the loop, even if the programmer's loop body does not use them.
    output_operations += dequeue_ops

    # Add a dummy output, if needed.
    if not output_tensors:
      output_tensors = array_ops.constant(0)

    if output_operations:
      # TODO(phawkins): in principle this is too restrictive since it
      # serializes the training loop steps. In practice it does not matter
      # since this loop will be compiled by XLA.
      output_tensors = control_flow_ops.tuple(output_tensors,
                                              control_inputs=output_operations)

    if tensor_tracer.TensorTracer.is_enabled():
      num_replicas = tpu_function.get_tpu_context().number_of_shards
      if num_replicas is None:
        num_replicas = 1
      tt = tensor_tracer.TensorTracer()
      output_tensors = tt.trace_tpu(ops.get_default_graph(),
                                    output_tensors, None,
                                    num_replicas)
    return output_tensors

  # If the body has arity 0, add a dummy loop-carried value to which we can add
  # control dependencies from any side-effecting operations.
  if input_arity == 0:
    inputs = [array_ops.constant(0)]
  return control_flow_ops.while_loop(
      condition_wrapper, body_wrapper, inputs, name="", parallel_iterations=1)


def repeat(
    n: int,
    body: Callable[..., Union[core_types.TensorLike, Iterable]],  # pylint:disable=g-bare-generic
    inputs: Optional[List[core_types.TensorLike]] = None,
    infeed_queue: Optional[tpu_feed.InfeedQueue] = None,
    name: Any = None) -> List[core_types.TensorLike]:
  """Builds a training loop that executes a fixed number of iterations.

  The set of loop-carried tensors corresponds to `inputs`.
  `body` must be a function that takes and returns the values of the
  loop-carried tensors.

  Args:
    n: the number of loop iterations.
    body: a Python function that builds the loop body.
    inputs: a list of initial values passed into the training loop, or None
      (equivalent to an empty list).
    infeed_queue: if not None, the infeed queue from which to append a tuple of
      arguments as inputs to `body`.
    name: (Deprecated) Does nothing.

  Returns:
    The final values of the loop-carried tensors.

  Raises:
    ValueError: if there is a type error.
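
  Example:
    An illustrative sketch (assumes the loop is built inside a TPU computation,
    e.g. under `tpu.rewrite` or `tpu.replicate`): run ten steps that each add
    1.0 to a single loop-carried value. `repeat` returns the list of final
    loop-carried tensors, so the result is unpacked from a one-element list.

      def body(total):
        return total + 1.0

      [final_total] = repeat(10, body, inputs=[0.0])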
  """
  def _convert_to_list(xs):
    if not isinstance(xs, (list, tuple)):
      return [xs]
    else:
      return list(xs)

  def cond(i, *args):
    del args
    return i < n

  def body_wrapper(i, *args):
    return [i + 1] + _convert_to_list(body(*args))

  # Prepend an iteration counter to the loop-carried values; `cond` and
  # `body_wrapper` use it to stop after `n` iterations.
  inputs = [0] if inputs is None else [0] + _convert_to_list(inputs)
  outputs = while_loop(
      cond, body_wrapper, inputs=inputs, infeed_queue=infeed_queue, name=name)
  outputs = _convert_to_list(outputs)
  if len(outputs) == 1:
    # Returns the Op rather than an empty list.
    return outputs[0].op
  else:
    # Drops the iteration counter and returns the caller's loop-carried values.
    return outputs[1:]