# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Functions called by the generated code to execute an eager-mode op."""

from google.protobuf import text_format
from tensorflow.core.framework import tensor_pb2
from tensorflow.python import pywrap_tfe
from tensorflow.python.eager import core
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.util import compat


def _status_error_with_name(e, name):
  """Returns the public exception corresponding to a non-OK status.

  Args:
    e: The `core._NotOkStatusException` raised by the C execute call. Its
      `message` is extended in place when `name` is given.
    name: Customized name for the operation, or None. When not None it is
      appended to the message so the failing op can be identified.

  Returns:
    The exception built by `core._status_to_exception(e)`; the caller is
    expected to raise it.
  """
  if name is not None:
    e.message += " name: " + name
  return core._status_to_exception(e)  # pylint: disable=protected-access


def _keras_symbolic_inputs_error(inputs):
  """Returns a `core._SymbolicException` if any input is Keras-symbolic.

  Used while handling a `TypeError` from the C execute call, to report the
  likely cause more clearly than the raw `TypeError` would.

  Args:
    inputs: The op inputs that were handed to the execute call.

  Returns:
    A `core._SymbolicException` listing the offending inputs, or None when no
    input is a Keras symbolic tensor.
  """
  # pylint: disable=protected-access
  keras_symbolic_tensors = [
      x for x in inputs if ops._is_keras_symbolic_tensor(x)
  ]
  if not keras_symbolic_tensors:
    return None
  return core._SymbolicException(
      "Inputs to eager execution function cannot be Keras symbolic "
      "tensors, but found {}".format(keras_symbolic_tensors))
  # pylint: enable=protected-access


def quick_execute(op_name, num_outputs, inputs, attrs, ctx, name=None):
  """Execute a TensorFlow operation.

  Args:
    op_name: Name of the TensorFlow operation (see REGISTER_OP in C++ code) to
      execute.
    num_outputs: The number of outputs of the operation to fetch. (Explicitly
      provided instead of being inferred for performance reasons).
    inputs: A list of inputs to the operation. Each entry should be a Tensor, or
      a value which can be passed to the Tensor constructor to create one.
    attrs: A tuple with alternating string attr names and attr values for this
      operation.
    ctx: The value of context.context().
    name: Customized name for the operation.

  Returns:
    List of output Tensor objects. The list is empty if there are no outputs.

  Raises:
    The exception produced by `core._status_to_exception` when the op reports
    a non-OK status (with " name: <name>" appended to its message when `name`
    is given); `core._SymbolicException` when the call raises `TypeError` and
    some input is a Keras symbolic tensor; otherwise the original `TypeError`.
  """
  device_name = ctx.device_name
  # pylint: disable=protected-access
  try:
    ctx.ensure_initialized()
    tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,
                                        inputs, attrs, num_outputs)
  except core._NotOkStatusException as e:
    raise _status_error_with_name(e, name) from None
  except TypeError as e:
    symbolic_error = _keras_symbolic_inputs_error(inputs)
    if symbolic_error is not None:
      raise symbolic_error from e
    raise
  # pylint: enable=protected-access
  return tensors


def execute_with_cancellation(op_name,
                              num_outputs,
                              inputs,
                              attrs,
                              ctx,
                              cancellation_manager,
                              name=None):
  """Execute a TensorFlow operation.

  Args:
    op_name: Name of the TensorFlow operation (see REGISTER_OP in C++ code) to
      execute.
    num_outputs: The number of outputs of the operation to fetch. (Explicitly
      provided instead of being inferred for performance reasons).
    inputs: A list of inputs to the operation. Each entry should be a Tensor, or
      a value which can be passed to the Tensor constructor to create one.
    attrs: A tuple with alternating string attr names and attr values for this
      operation.
    ctx: The value of context.context().
    cancellation_manager: a `CancellationManager` object that can be used to
      cancel the operation.
    name: Customized name for the operation.

  Returns:
    List of output Tensor objects. The list is empty if there are no outputs.

  Raises:
    The exception produced by `core._status_to_exception` when the op reports
    a non-OK status (with " name: <name>" appended to its message when `name`
    is given); `core._SymbolicException` when the call raises `TypeError` and
    some input is a Keras symbolic tensor; otherwise the original `TypeError`.
  """
  device_name = ctx.device_name
  # pylint: disable=protected-access
  try:
    ctx.ensure_initialized()
    tensors = pywrap_tfe.TFE_Py_ExecuteCancelable(ctx._handle, device_name,
                                                  op_name, inputs, attrs,
                                                  cancellation_manager._impl,
                                                  num_outputs)
  except core._NotOkStatusException as e:
    raise _status_error_with_name(e, name) from None
  except TypeError as e:
    symbolic_error = _keras_symbolic_inputs_error(inputs)
    if symbolic_error is not None:
      raise symbolic_error from e
    raise
  # pylint: enable=protected-access
  return tensors


def execute_with_callbacks(op_name, num_outputs, inputs, attrs, ctx, name=None):
  """Monkey-patch to execute to enable execution callbacks.

  Runs the op via `quick_execute`, then calls every callback in
  `ctx.op_callbacks` with `(op_name, tuple(inputs), attrs, tensors, name)`.
  """
  tensors = quick_execute(op_name, num_outputs, inputs, attrs, ctx, name)
  for callback in ctx.op_callbacks:
    callback(op_name, tuple(inputs), attrs, tensors, name)

  return tensors


execute = quick_execute


def must_record_gradient():
  """Import backprop if you want gradients recorded."""
  return False


def record_gradient(unused_op_name, unused_inputs, unused_attrs,
                    unused_outputs):
  """Import backprop if you want gradients recorded."""
  pass


def make_float(v, arg_name):
  """Validates that `v` is a real number and returns it as a Python float.

  Args:
    v: The attr value to convert.
    arg_name: Name of the attr, used in error messages.

  Returns:
    `float(v)`.

  Raises:
    TypeError: if `v` is not an instance of `compat.real_types`.
  """
  if not isinstance(v, compat.real_types):
    raise TypeError("Expected float for argument '%s' not %s." %
                    (arg_name, repr(v)))
  return float(v)


def make_int(v, arg_name):
  """Converts `v` to a Python int for the attr argument `arg_name`.

  Text and bytes strings are rejected up front, even when they spell a number,
  so that e.g. "3" or b"3" is not silently parsed as 3. Other values go through
  `int()`, so a float such as 2.5 is truncated toward zero.

  Args:
    v: The attr value to convert.
    arg_name: Name of the attr, used in error messages.

  Returns:
    `int(v)`.

  Raises:
    TypeError: if `v` is a string or `int(v)` fails.
  """
  if isinstance(v, compat.bytes_or_text_types):
    raise TypeError("Expected int for argument '%s' not %s." %
                    (arg_name, repr(v)))
  try:
    return int(v)
  except (ValueError, TypeError) as e:
    raise TypeError("Expected int for argument '%s' not %s." %
                    (arg_name, repr(v))) from e


def make_str(v, arg_name):
  """Validates that `v` is a text or bytes string and returns it as bytes.

  Args:
    v: The attr value to convert.
    arg_name: Name of the attr, used in error messages.

  Returns:
    `compat.as_bytes(v)`.

  Raises:
    TypeError: if `v` is not an instance of `compat.bytes_or_text_types`.
  """
  if not isinstance(v, compat.bytes_or_text_types):
    raise TypeError("Expected string for argument '%s' not %s." %
                    (arg_name, repr(v)))
  return compat.as_bytes(v)  # Convert unicode strings to bytes.


def make_bool(v, arg_name):
  """Validates that `v` is exactly a Python bool and returns it unchanged.

  Args:
    v: The attr value to check.
    arg_name: Name of the attr, used in error messages.

  Returns:
    `v`.

  Raises:
    TypeError: if `v` is not a `bool` (ints such as 0/1 are rejected).
  """
  if not isinstance(v, bool):
    raise TypeError("Expected bool for argument '%s' not %s." %
                    (arg_name, repr(v)))
  return v


def make_type(v, arg_name):
  """Converts `v` to the DataType enum of its base dtype.

  Args:
    v: Anything accepted by `dtypes.as_dtype`.
    arg_name: Name of the attr, used in error messages.

  Returns:
    `dtypes.as_dtype(v).base_dtype.as_datatype_enum`.

  Raises:
    TypeError: if `dtypes.as_dtype(v)` raises TypeError.
  """
  try:
    v = dtypes.as_dtype(v).base_dtype
  except TypeError as e:
    raise TypeError("Expected DataType for argument '%s' not %s." %
                    (arg_name, repr(v))) from e
  return v.as_datatype_enum


def make_shape(v, arg_name):
  """Convert v into a list.

  Args:
    v: A TensorShapeProto, a list of ints, or a tensor_shape.TensorShape.
    arg_name: String, for error messages.

  Returns:
    None if the rank is unknown, otherwise a list of ints (or Nones in the
    position where the dimension is unknown).

  Raises:
    TypeError: if `tensor_shape.as_shape(v)` raises TypeError.
    ValueError: if `tensor_shape.as_shape(v)` raises ValueError.
  """
  try:
    shape = tensor_shape.as_shape(v)
  except TypeError as e:
    raise TypeError(
        "Error converting %s to a TensorShape: %s." % (arg_name, e)) from e
  except ValueError as e:
    raise ValueError(
        "Error converting %s to a TensorShape: %s." % (arg_name, e)) from e
  if shape.ndims is None:
    return None
  return shape.as_list()


def make_tensor(v, arg_name):
  """Ensure v is a TensorProto.

  Args:
    v: A `tensor_pb2.TensorProto`, returned as-is, or a `str` holding a
      TensorProto in protobuf text format, which is parsed into a new proto.
    arg_name: Name of the attr, used in error messages.

  Returns:
    A `tensor_pb2.TensorProto`.

  Raises:
    TypeError: if `v` is neither a TensorProto nor a `str`.
  """
  if isinstance(v, tensor_pb2.TensorProto):
    return v
  if isinstance(v, str):
    pb = tensor_pb2.TensorProto()
    text_format.Merge(v, pb)
    return pb
  raise TypeError(
      "Don't know how to convert %s to a TensorProto for argument '%s'." %
      (repr(v), arg_name))


def args_to_matching_eager(l, ctx, allowed_dtypes, default_dtype=None):
  """Convert sequence `l` to eager same-type Tensors.

  Args:
    l: Sequence of values to convert.
    ctx: The value of context.context(), forwarded to `ops.convert_to_tensor`.
    allowed_dtypes: Dtypes the op accepts. When non-empty and no element of
      `l` is already an EagerTensor, the first element is converted with
      default inference and that result is discarded if its dtype is not in
      this collection.
    default_dtype: Preferred dtype used when inferring from raw values; also
      returned as-is when `l` is empty.

  Returns:
    A `(dtype, tensors)` pair. Normally `dtype` is a DataType enum. When `l`
    is empty and `default_dtype` is not None, `default_dtype` itself is
    returned (not converted to an enum) along with `[]`. When every element of
    `l` is already an EagerTensor, `l` itself is returned as `tensors`.

  Raises:
    core._SymbolicException: if a converted tensor is a Keras symbolic tensor.
  """
  if (not l) and (default_dtype is not None):
    return default_dtype, []  # List is empty; assume default dtype.
  # NOTE: if `l` is empty and `default_dtype` is None, the for-else below
  # reaches `l[0]` and raises IndexError.
  EagerTensor = ops.EagerTensor  # pylint: disable=invalid-name
  for x in l:
    if not isinstance(x, EagerTensor):
      break
  else:  # note: intentional for-else
    # Fast path: everything is already an EagerTensor.
    return l[0]._datatype_enum(), l  # pylint: disable=protected-access

  # Is some input already a Tensor with a dtype?
  dtype = None
  for t in l:
    if isinstance(t, EagerTensor):
      dtype = t.dtype
      break

  if dtype is None:
    # Infer a dtype based on the first value, and use that dtype for the
    # remaining values.

    ret = []
    for t in l:
      tensor = None
      # First see if we can get a valid dtype with the default conversion
      # and see if it matches an allowed dtypes. Some ops like ConcatV2 may
      # not list allowed dtypes, in which case we should skip this.
      if dtype is None and allowed_dtypes:
        tensor = ops.convert_to_tensor(t, ctx=ctx)
        # If we did not match an allowed dtype, try again with the default
        # dtype. This could be because we have an empty tensor and thus we
        # picked the wrong type.
        if tensor.dtype not in allowed_dtypes:
          tensor = None

      if tensor is None:
        tensor = ops.convert_to_tensor(
            t, dtype, preferred_dtype=default_dtype, ctx=ctx)

      ret.append(tensor)
      if dtype is None:
        dtype = tensor.dtype
  else:
    ret = [ops.convert_to_tensor(t, dtype, ctx=ctx) for t in l]

  # TODO(slebedev): consider removing this as it leaks a Keras concept.
  # pylint: disable=protected-access
  keras_symbolic_tensors = [x for x in ret if ops._is_keras_symbolic_tensor(x)]
  if keras_symbolic_tensors:
    raise core._SymbolicException(
        "Using symbolic output of a Keras layer during eager execution "
        "{}".format(keras_symbolic_tensors))
  # pylint: enable=protected-access
  return dtype.as_datatype_enum, ret


def convert_to_mixed_eager_tensors(values, ctx):
  """Converts each value independently to a tensor with its own dtype.

  Args:
    values: Iterable of values to convert.
    ctx: The value of context.context(), forwarded to `ops.convert_to_tensor`.

  Returns:
    A `(types, tensors)` pair: the per-tensor DataType enums and the tensors.
  """
  v = [ops.convert_to_tensor(t, ctx=ctx) for t in values]
  types = [t._datatype_enum() for t in v]  # pylint: disable=protected-access
  return types, v


def args_to_mixed_eager_tensors(lists, ctx):
  """Converts a list of same-length lists of values to eager tensors.

  Position `i` of every list is converted to a common dtype. That dtype is
  taken from the first list holding an EagerTensor at position `i`. If there
  is none, it is inferred by converting `lists[0][i]` with default rules.

  Args:
    lists: Sequence of at least two equal-length sequences of values.
    ctx: The value of context.context(), forwarded to `ops.convert_to_tensor`.

  Returns:
    A `(types, lists_ret)` pair. `types[i]` is the DataType enum used for
    position `i`. `lists_ret[j]` holds the converted tensors of `lists[j]`.

  Raises:
    ValueError: if the lists do not all have the same length.
  """
  assert len(lists) > 1

  # Generate an error if len(lists[i]) is not the same for all i.
  num_elements = len(lists[0])
  for l in lists[1:]:
    if len(l) != num_elements:
      raise ValueError(
          "Expected list arguments to be the same length: %d != %d (%r vs. %r)."
          % (len(lists[0]), len(l), lists[0], l))
  lists_ret = [[] for _ in lists]

  # Convert the first element of each list first, then the second element, etc.
  types = []
  for i in range(num_elements):
    # If any list has a Tensor, use that dtype.
    dtype = None
    for l in lists:
      if isinstance(l[i], ops.EagerTensor):
        dtype = l[i].dtype
        break
    start = 0
    if dtype is None:
      # Convert the first one and use its dtype for the rest.
      first = ops.convert_to_tensor(lists[0][i], ctx=ctx)
      lists_ret[0].append(first)
      dtype = first.dtype
      start = 1
    # Convert the remaining entries at position i to the chosen dtype.
    for j in range(start, len(lists)):
      lists_ret[j].append(
          ops.convert_to_tensor(lists[j][i], dtype=dtype, ctx=ctx))
    types.append(dtype.as_datatype_enum)
  return types, lists_ret