# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================

"""Library for constructing a training loop, suitable for TPUs."""

from typing import Any, Callable, Iterable, List, Optional, Union

from tensorflow.python.compiler.xla import xla
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.tpu import tensor_tracer
from tensorflow.python.tpu import tpu_feed
from tensorflow.python.tpu import tpu_function
from tensorflow.python.types import core as core_types


def while_loop(condition: Callable[..., Any],
               body: Callable[..., Any],
               inputs: Optional[List[Any]] = None,
               infeed_queue: Optional[tpu_feed.InfeedQueue] = None,
               name: Any = None) -> Any:
  """Builds a training loop for TPUs.

  The set of loop-carried tensors corresponds to `inputs`.  Both
  `condition` and `body` take the current value of the loop-carried
  tensors. `body` additionally takes a tuple of infeed from
  `infeed_queue` if `infeed_queue` is not None. `condition` must return a
  single boolean value that determines whether iteration
  continues. `body` must return an updated list of values for the
  loop-carried tensors, optionally followed by Operations that must run on
  every iteration.

  Args:
    condition: a Python function that builds the loop condition. It is called
      with the loop-carried tensors only; infeed is never passed to it.
    body: a Python function that builds the loop body.
    inputs: a list of initial values passed into the training loop, or None
      (equivalent to an empty list).
    infeed_queue: if not None, the infeed queue from which to append a tuple of
      arguments as inputs to `body`.
    name: (Deprecated) Does nothing.

  Returns:
    The final values of the loop-carried tensors.

  Raises:
    TypeError: if body or condition has the wrong signature, or if the dtypes
      returned by `body` do not match the dtypes of `inputs`.
    ValueError: if `infeed_queue` is given but no TPU shard context is active,
      or if `body` returns an Operation before a Tensor.
  """
  del name
  # Converts inputs to Tensors.
  inputs = [] if inputs is None else [ops.convert_to_tensor(x) for
                                      x in inputs]
  input_types = [x.dtype for x in inputs]
  input_arity = len(inputs)

  body_arg_error = xla.check_function_argument_count(
      body, input_arity, infeed_queue)
  if body_arg_error is not None:
    input_names = [i.name for i in inputs]
    if infeed_queue is None:
      raise TypeError(
          f"Supplied loop body function cannot be called with the specified "
          f"inputs. You specified {input_arity} inputs: {input_names}, but "
          f"the loop body needs {body_arg_error}")
    raise TypeError(
        f"Supplied loop body function cannot be called with the specified "
        f"inputs. You specified {input_arity} inputs: {input_names} and "
        f"{infeed_queue.number_of_tuple_elements} additional inputs from "
        f"infeed, but the computation needs {body_arg_error}")

  # The condition never receives infeed, so its arity is checked against the
  # loop-carried inputs alone.
  condition_arg_error = xla.check_function_argument_count(
      condition, input_arity, None)
  if condition_arg_error is not None:
    input_names = [i.name for i in inputs]
    condition_message = (
        f"Supplied loop condition function cannot be called with the "
        f"specified inputs. You specified {input_arity} inputs: "
        f"{input_names}, but the loop condition needs {condition_arg_error}")
    if infeed_queue is not None:
      condition_message += (
          ". Note that infeed is not passed to the loop condition.")
    raise TypeError(condition_message)

  def condition_wrapper(*args):
    """Calls `condition`, hiding the dummy value used by arity-0 loops."""
    # Discards the dummy output added for arity-0 loops.
    if input_arity == 0:
      args = []
    return condition(*args)

  def body_wrapper(*args):
    """Wrapper around `body` that handles infeed queues and control deps."""
    loop_values = list(args)

    # Discards the dummy output added for arity-0 loops.
    if input_arity == 0:
      loop_values = []

    # Runs `body` with the dequeue_ops appended.
    if infeed_queue:
      number_of_shards = tpu_function.get_tpu_context().number_of_shards
      if number_of_shards is None:
        raise ValueError("Can't build training loop with infeed when there is "
                         "no tpu_shard_context. Are you building a loop or "
                         "graph directly rather than from inside tpu.rewrite, "
                         "tpu.batch_parallel, tpu.shard, or tpu.replicate?")
      infeed_queue.set_number_of_shards(number_of_shards)
      dequeue_ops = list(infeed_queue.generate_dequeue_op())
    else:
      dequeue_ops = []
    outputs = body(*(loop_values + dequeue_ops))

    # If the computation only returned one value, make it a tuple.
    if not isinstance(outputs, (list, tuple)):
      outputs = (outputs,)

    outputs = [
        o if isinstance(o, ops.Operation) else ops.convert_to_tensor(o)
        for o in outputs
    ]

    # Separates the returned Operations and Tensors.
    output_operations = [o for o in outputs if isinstance(o, ops.Operation)]
    output_tensors = [o for o in outputs
                      if not isinstance(o, ops.Operation)]

    # The body must return all Tensors before any Operation. Elements are
    # compared by identity so the check does not depend on how `==` is
    # overloaded for Tensors.
    expected_order = output_tensors + output_operations
    if any(actual is not expected
           for actual, expected in zip(outputs, expected_order)):
      raise ValueError(
          "TPU training loop body must return zero or more Tensor values "
          "followed by zero or more Operations.")

    output_types = [op.dtype for op in output_tensors]
    if input_types != output_types:
      raise TypeError(
          "Mismatch between input types and output types for training loop "
          f"body: {input_types} vs {output_types}")

    # Add the dequeue operations to output_operations to ensure they are run
    # by the loop, even if the programmer's loop body does not use them.
    output_operations += dequeue_ops

    # Add a dummy output, if needed. It is kept inside a list so that the
    # arity-0 case has the same structure as every other case; the calls below
    # receive a sequence of tensors rather than a bare scalar tensor.
    if not output_tensors:
      output_tensors = [array_ops.constant(0)]

    if output_operations:
      # TODO(phawkins): in principle this is too restrictive since it serializes
      # the training loop steps. In practice it does not matter since this loop
      # will be compiled by XLA.
      output_tensors = control_flow_ops.tuple(output_tensors,
                                              control_inputs=output_operations)

    if tensor_tracer.TensorTracer.is_enabled():
      num_replicas = tpu_function.get_tpu_context().number_of_shards
      if num_replicas is None:
        num_replicas = 1
      tt = tensor_tracer.TensorTracer()
      output_tensors = tt.trace_tpu(ops.get_default_graph(),
                                    output_tensors, None,
                                    num_replicas)
    return output_tensors

  # If the body has arity 0, add a dummy loop-carried value to which we can add
  # control dependencies from any side-effecting operations.
  if input_arity == 0:
    inputs = [array_ops.constant(0)]
  return control_flow_ops.while_loop(
      condition_wrapper, body_wrapper, inputs, name="", parallel_iterations=1)


def repeat(
    n: int,
    body: Callable[..., Union[core_types.TensorLike, Iterable]],  # pylint:disable=g-bare-generic
    inputs: Optional[List[core_types.TensorLike]] = None,
    infeed_queue: Optional[tpu_feed.InfeedQueue] = None,
    name: Any = None) -> List[core_types.TensorLike]:
  """Builds a training loop that executes a fixed number of iterations.

  The set of loop-carried tensors correspond to `inputs`.
  `body` must be a function that takes and returns the values of the
  loop-carried tensors.

  The loop is built with `while_loop`, using an iteration counter that is
  prepended to the loop-carried values and stripped from the result.

  Args:
    n: the number of loop iterations
    body: a Python function that builds the loop body.
    inputs: a list of initial values passed into the training loop or None
      (equivalent to an empty list).
    infeed_queue: if not None, the infeed queue from which to append a tuple of
      arguments as inputs to `body`.
    name: (Deprecated) Does nothing.

  Returns:
    The final values of the loop-carried tensors. If there are no
    loop-carried tensors, the Operation that produces the final loop counter
    is returned instead of an empty list.

  Raises:
    TypeError: if `body` has the wrong signature or returns values whose
      dtypes do not match `inputs`.
    ValueError: if `while_loop` rejects the loop, e.g. infeed is requested
      outside of a TPU shard context.
  """
  def _convert_to_list(xs):
    """Returns `xs` as a new list, wrapping a single non-sequence value."""
    if isinstance(xs, (list, tuple)):
      return list(xs)
    return [xs]

  def cond(i, *args):
    """Continues while the counter `i` is below `n`."""
    del args
    return i < n

  def body_wrapper(i, *args):
    """Advances the counter and runs `body` on the remaining loop values."""
    return [i + 1] + _convert_to_list(body(*args))

  # The counter always occupies position 0 of the loop-carried values.
  inputs = [0] if inputs is None else [0] + _convert_to_list(inputs)
  outputs = while_loop(
      cond, body_wrapper, inputs=inputs, infeed_queue=infeed_queue, name=name)
  outputs = _convert_to_list(outputs)
  if len(outputs) == 1:
    # Returns the Op rather than an empty list.
    return outputs[0].op
  else:
    return outputs[1:]