1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15# pylint: disable=g-short-docstring-punctuation 16"""Asserts and Boolean Checks.""" 17 18import collections 19 20import numpy as np 21 22from tensorflow.python.eager import context 23from tensorflow.python.framework import dtypes 24from tensorflow.python.framework import errors 25from tensorflow.python.framework import ops 26from tensorflow.python.framework import sparse_tensor 27from tensorflow.python.framework import tensor_shape 28from tensorflow.python.framework import tensor_util 29from tensorflow.python.ops import array_ops 30from tensorflow.python.ops import control_flow_ops 31from tensorflow.python.ops import math_ops 32from tensorflow.python.util import compat 33from tensorflow.python.util import deprecation 34from tensorflow.python.util import dispatch 35from tensorflow.python.util.tf_export import tf_export 36 37NUMERIC_TYPES = frozenset([ 38 dtypes.float16, dtypes.float32, dtypes.float64, dtypes.int8, dtypes.int16, 39 dtypes.int32, dtypes.int64, dtypes.uint8, dtypes.uint16, dtypes.uint32, 40 dtypes.uint64, dtypes.qint8, dtypes.qint16, dtypes.qint32, dtypes.quint8, 41 dtypes.quint16, dtypes.complex64, dtypes.complex128, dtypes.bfloat16 42]) 43 44__all__ = [ 45 'assert_negative', 46 'assert_positive', 47 'assert_proper_iterable', 48 'assert_non_negative', 49 
'assert_non_positive', 50 'assert_equal', 51 'assert_none_equal', 52 'assert_near', 53 'assert_integer', 54 'assert_less', 55 'assert_less_equal', 56 'assert_greater', 57 'assert_greater_equal', 58 'assert_rank', 59 'assert_rank_at_least', 60 'assert_rank_in', 61 'assert_same_float_dtype', 62 'assert_scalar', 63 'assert_type', 64 'assert_shapes', 65 'is_non_decreasing', 66 'is_numeric_tensor', 67 'is_strictly_increasing', 68] 69 70 71def _maybe_constant_value_string(t): 72 if not isinstance(t, ops.Tensor): 73 return str(t) 74 const_t = tensor_util.constant_value(t) 75 if const_t is not None: 76 return str(const_t) 77 return t 78 79 80def _assert_static(condition, data): 81 """Raises a InvalidArgumentError with as much information as possible.""" 82 if not condition: 83 data_static = [_maybe_constant_value_string(x) for x in data] 84 raise errors.InvalidArgumentError(node_def=None, op=None, 85 message='\n'.join(data_static)) 86 87 88def _shape_and_dtype_str(tensor): 89 """Returns a string containing tensor's shape and dtype.""" 90 return 'shape=%s dtype=%s' % (tensor.shape, tensor.dtype.name) 91 92 93def _unary_assert_doc(sym, sym_name): 94 """Common docstring for assert_* ops that evaluate a unary predicate over every element of a tensor. 95 96 Args: 97 sym: Mathematical symbol for the check performed on each element, i.e. "> 0" 98 sym_name: English-language name for the op described by sym 99 100 Returns: 101 Decorator that adds the appropriate docstring to the function for symbol 102 `sym`. 103 """ 104 105 def _decorator(func): 106 """Generated decorator that adds the appropriate docstring to the function for symbol `sym`. 107 108 Args: 109 func: Function for a TensorFlow op 110 111 Returns: 112 Version of `func` with documentation attached. 113 """ 114 opname = func.__name__ 115 cap_sym_name = sym_name.capitalize() 116 117 func.__doc__ = """ 118 Assert the condition `x {sym}` holds element-wise. 
119 120 When running in graph mode, you should add a dependency on this operation 121 to ensure that it runs. Example of adding a dependency to an operation: 122 123 ```python 124 with tf.control_dependencies([tf.debugging.{opname}(x, y)]): 125 output = tf.reduce_sum(x) 126 ``` 127 128 {sym_name} means, for every element `x[i]` of `x`, we have `x[i] {sym}`. 129 If `x` is empty this is trivially satisfied. 130 131 Args: 132 x: Numeric `Tensor`. 133 data: The tensors to print out if the condition is False. Defaults to 134 error message and first few entries of `x`. 135 summarize: Print this many entries of each tensor. 136 message: A string to prefix to the default message. 137 name: A name for this operation (optional). Defaults to "{opname}". 138 139 Returns: 140 Op that raises `InvalidArgumentError` if `x {sym}` is False. 141 @compatibility(eager) 142 returns None 143 @end_compatibility 144 145 Raises: 146 InvalidArgumentError: if the check can be performed immediately and 147 `x {sym}` is False. The check can be performed immediately during 148 eager execution or if `x` is statically known. 149 """.format( 150 sym=sym, sym_name=cap_sym_name, opname=opname) 151 return func 152 153 return _decorator 154 155 156def _binary_assert_doc(sym, test_var): 157 """Common docstring for most of the v1 assert_* ops that compare two tensors element-wise. 158 159 Args: 160 sym: Binary operation symbol, i.e. "==" 161 test_var: a string that represents the variable in the right-hand side of 162 binary operator of the test case 163 164 Returns: 165 Decorator that adds the appropriate docstring to the function for 166 symbol `sym`. 167 """ 168 169 def _decorator(func): 170 """Generated decorator that adds the appropriate docstring to the function for symbol `sym`. 171 172 Args: 173 func: Function for a TensorFlow op 174 175 Returns: 176 A version of `func` with documentation attached. 
177 """ 178 opname = func.__name__ 179 180 func.__doc__ = """ 181 Assert the condition `x {sym} y` holds element-wise. 182 183 This condition holds if for every pair of (possibly broadcast) elements 184 `x[i]`, `y[i]`, we have `x[i] {sym} y[i]`. 185 If both `x` and `y` are empty, this is trivially satisfied. 186 187 When running in graph mode, you should add a dependency on this operation 188 to ensure that it runs. Example of adding a dependency to an operation: 189 190 ```python 191 with tf.control_dependencies([tf.compat.v1.{opname}(x, y)]): 192 output = tf.reduce_sum(x) 193 ``` 194 195 Args: 196 x: Numeric `Tensor`. 197 y: Numeric `Tensor`, same dtype as and broadcastable to `x`. 198 data: The tensors to print out if the condition is False. Defaults to 199 error message and first few entries of `x`, `y`. 200 summarize: Print this many entries of each tensor. 201 message: A string to prefix to the default message. 202 name: A name for this operation (optional). Defaults to "{opname}". 203 204 Returns: 205 Op that raises `InvalidArgumentError` if `x {sym} y` is False. 206 207 Raises: 208 InvalidArgumentError: if the check can be performed immediately and 209 `x {sym} y` is False. The check can be performed immediately during 210 eager execution or if `x` and `y` are statically known. 211 212 @compatibility(TF2) 213 `tf.compat.v1.{opname}` is compatible with eager execution and 214 `tf.function`. 215 Please use `tf.debugging.{opname}` instead when migrating to TF2. Apart 216 from `data`, all arguments are supported with the same argument name. 217 218 If you want to ensure the assert statements run before the 219 potentially-invalid computation, please use `tf.control_dependencies`, 220 as tf.function auto-control dependencies are insufficient for assert 221 statements. 
222 223 #### Structural Mapping to Native TF2 224 225 Before: 226 227 ```python 228 tf.compat.v1.{opname}( 229 x=x, y=y, data=data, summarize=summarize, 230 message=message, name=name) 231 ``` 232 233 After: 234 235 ```python 236 tf.debugging.{opname}( 237 x=x, y=y, message=message, 238 summarize=summarize, name=name) 239 ``` 240 241 #### TF1 & TF2 Usage Example 242 243 TF1: 244 245 >>> g = tf.Graph() 246 >>> with g.as_default(): 247 ... a = tf.compat.v1.placeholder(tf.float32, [2]) 248 ... b = tf.compat.v1.placeholder(tf.float32, [2]) 249 ... result = tf.compat.v1.{opname}(a, b, 250 ... message='"a {sym} b" does not hold for the given inputs') 251 ... with tf.compat.v1.control_dependencies([result]): 252 ... sum_node = a + b 253 >>> sess = tf.compat.v1.Session(graph=g) 254 >>> val = sess.run(sum_node, feed_dict={{a: [1, 2], b:{test_var}}}) 255 256 257 TF2: 258 259 >>> a = tf.Variable([1, 2], dtype=tf.float32) 260 >>> b = tf.Variable({test_var}, dtype=tf.float32) 261 >>> assert_op = tf.debugging.{opname}(a, b, message= 262 ... '"a {sym} b" does not hold for the given inputs') 263 >>> # When working with tf.control_dependencies 264 >>> with tf.control_dependencies([assert_op]): 265 ... val = a + b 266 267 @end_compatibility 268 """.format( 269 sym=sym, opname=opname, test_var=test_var) 270 return func 271 272 return _decorator 273 274 275def _make_assert_msg_data(sym, x, y, summarize, test_op): 276 """Subroutine of _binary_assert that generates the components of the default error message when running in eager mode. 277 278 Args: 279 sym: Mathematical symbol for the test to apply to pairs of tensor elements, 280 i.e. "==" 281 x: First input to the assertion after applying `convert_to_tensor()` 282 y: Second input to the assertion 283 summarize: Value of the "summarize" parameter to the original assert_* call; 284 tells how many elements of each tensor to print. 
285 test_op: TensorFlow op that returns a Boolean tensor with True in each 286 position where the assertion is satisfied. 287 288 Returns: 289 List of tensors and scalars that, when stringified and concatenated, 290 will produce the error message string. 291 """ 292 # Prepare a message with first elements of x and y. 293 data = [] 294 295 data.append('Condition x %s y did not hold.' % sym) 296 297 if summarize > 0: 298 if x.shape == y.shape and x.shape.as_list(): 299 # If the shapes of x and y are the same (and not scalars), 300 # Get the values that actually differed and their indices. 301 # If shapes are different this information is more confusing 302 # than useful. 303 mask = math_ops.logical_not(test_op) 304 indices = array_ops.where(mask) 305 indices_np = indices.numpy() 306 x_vals = array_ops.boolean_mask(x, mask) 307 y_vals = array_ops.boolean_mask(y, mask) 308 num_vals = min(summarize, indices_np.shape[0]) 309 data.append('Indices of first %d different values:' % num_vals) 310 data.append(indices_np[:num_vals]) 311 data.append('Corresponding x values:') 312 data.append(x_vals.numpy().reshape((-1,))[:num_vals]) 313 data.append('Corresponding y values:') 314 data.append(y_vals.numpy().reshape((-1,))[:num_vals]) 315 316 # reshape((-1,)) is the fastest way to get a flat array view. 317 x_np = x.numpy().reshape((-1,)) 318 y_np = y.numpy().reshape((-1,)) 319 x_sum = min(x_np.size, summarize) 320 y_sum = min(y_np.size, summarize) 321 data.append('First %d elements of x:' % x_sum) 322 data.append(x_np[:x_sum]) 323 data.append('First %d elements of y:' % y_sum) 324 data.append(y_np[:y_sum]) 325 326 return data 327 328 329def _pretty_print(data_item, summarize): 330 """Format a data item for use in an error message in eager mode. 331 332 Args: 333 data_item: One of the items in the "data" argument to an assert_* function. 334 Can be a Tensor or a scalar value. 335 summarize: How many elements to retain of each tensor-valued entry in data. 
336 337 Returns: 338 An appropriate string representation of data_item 339 """ 340 if isinstance(data_item, ops.Tensor): 341 arr = data_item.numpy() 342 if np.isscalar(arr): 343 # Tensor.numpy() returns a scalar for zero-dimensional tensors 344 return str(arr) 345 else: 346 flat = arr.reshape((-1,)) 347 lst = [str(x) for x in flat[:summarize]] 348 if len(lst) < flat.size: 349 lst.append('...') 350 return str(lst) 351 else: 352 return str(data_item) 353 354 355def _binary_assert(sym, opname, op_func, static_func, x, y, data, summarize, 356 message, name): 357 """Generic binary elementwise assertion. 358 359 Implements the behavior described in _binary_assert_doc() above. 360 Args: 361 sym: Mathematical symbol for the test to apply to pairs of tensor elements, 362 i.e. "==" 363 opname: Name of the assert op in the public API, i.e. "assert_equal" 364 op_func: Function that, if passed the two Tensor inputs to the assertion (x 365 and y), will return the test to be passed to reduce_all() i.e. 366 static_func: Function that, if passed numpy ndarray versions of the two 367 inputs to the assertion, will return a Boolean ndarray with containing 368 True in all positions where the assertion PASSES. 369 i.e. np.equal for assert_equal() 370 x: Numeric `Tensor`. 371 y: Numeric `Tensor`, same dtype as and broadcastable to `x`. 372 data: The tensors to print out if the condition is False. Defaults to 373 error message and first few entries of `x`, `y`. 374 summarize: Print this many entries of each tensor. 375 message: A string to prefix to the default message. 376 name: A name for this operation (optional). Defaults to the value of 377 `opname`. 378 379 Returns: 380 See docstring template in _binary_assert_doc(). 
381 """ 382 with ops.name_scope(name, opname, [x, y, data]): 383 x = ops.convert_to_tensor(x, name='x') 384 y = ops.convert_to_tensor(y, name='y') 385 386 if context.executing_eagerly(): 387 test_op = op_func(x, y) 388 condition = math_ops.reduce_all(test_op) 389 if condition: 390 return 391 392 # If we get here, the assertion has failed. 393 # Default to printing 3 elements like control_flow_ops.Assert (used 394 # by graph mode) does. Also treat negative values as "print 395 # everything" for consistency with Tensor::SummarizeValue(). 396 if summarize is None: 397 summarize = 3 398 elif summarize < 0: 399 summarize = 1e9 # Code below will find exact size of x and y. 400 401 if data is None: 402 data = _make_assert_msg_data(sym, x, y, summarize, test_op) 403 404 if message is not None: 405 data = [message] + list(data) 406 407 raise errors.InvalidArgumentError( 408 node_def=None, 409 op=None, 410 message=('\n'.join(_pretty_print(d, summarize) for d in data))) 411 412 else: # not context.executing_eagerly() 413 if data is None: 414 data = [ 415 'Condition x %s y did not hold element-wise:' % sym, 416 'x (%s) = ' % x.name, x, 417 'y (%s) = ' % y.name, y 418 ] 419 if message is not None: 420 data = [message] + list(data) 421 condition = math_ops.reduce_all(op_func(x, y)) 422 x_static = tensor_util.constant_value(x) 423 y_static = tensor_util.constant_value(y) 424 if x_static is not None and y_static is not None: 425 condition_static = np.all(static_func(x_static, y_static)) 426 _assert_static(condition_static, data) 427 return control_flow_ops.Assert(condition, data, summarize=summarize) 428 429 430@tf_export( 431 'debugging.assert_proper_iterable', 432 v1=['debugging.assert_proper_iterable', 'assert_proper_iterable']) 433@dispatch.add_dispatch_support 434@deprecation.deprecated_endpoints('assert_proper_iterable') 435def assert_proper_iterable(values): 436 """Static assert that values is a "proper" iterable. 
437 438 `Ops` that expect iterables of `Tensor` can call this to validate input. 439 Useful since `Tensor`, `ndarray`, byte/text type are all iterables themselves. 440 441 Args: 442 values: Object to be checked. 443 444 Raises: 445 TypeError: If `values` is not iterable or is one of 446 `Tensor`, `SparseTensor`, `np.array`, `tf.compat.bytes_or_text_types`. 447 """ 448 unintentional_iterables = ( 449 (ops.Tensor, sparse_tensor.SparseTensor, np.ndarray) 450 + compat.bytes_or_text_types 451 ) 452 if isinstance(values, unintentional_iterables): 453 raise TypeError( 454 'Expected argument "values" to be a "proper" iterable. Found: %s' % 455 type(values)) 456 457 if not hasattr(values, '__iter__'): 458 raise TypeError( 459 'Expected argument "values" to be iterable. Found: %s' % type(values)) 460 461 462@tf_export('debugging.assert_negative', v1=[]) 463@dispatch.add_dispatch_support 464def assert_negative_v2(x, message=None, summarize=None, name=None): 465 """Assert the condition `x < 0` holds element-wise. 466 467 This Op checks that `x[i] < 0` holds for every element of `x`. If `x` is 468 empty, this is trivially satisfied. 469 470 If `x` is not negative everywhere, `message`, as well as the first `summarize` 471 entries of `x` are printed, and `InvalidArgumentError` is raised. 472 473 Args: 474 x: Numeric `Tensor`. 475 message: A string to prefix to the default message. 476 summarize: Print this many entries of each tensor. 477 name: A name for this operation (optional). Defaults to "assert_negative". 478 479 Returns: 480 Op raising `InvalidArgumentError` unless `x` is all negative. This can be 481 used with `tf.control_dependencies` inside of `tf.function`s to block 482 followup computation until the check has executed. 483 @compatibility(eager) 484 returns None 485 @end_compatibility 486 487 Raises: 488 InvalidArgumentError: if the check can be performed immediately and 489 `x[i] < 0` is False. 
The check can be performed immediately during eager 490 execution or if `x` is statically known. 491 """ 492 return assert_negative(x=x, message=message, summarize=summarize, name=name) 493 494 495@tf_export(v1=['debugging.assert_negative', 'assert_negative']) 496@dispatch.add_dispatch_support 497@deprecation.deprecated_endpoints('assert_negative') 498@_unary_assert_doc('< 0', 'negative') 499def assert_negative(x, data=None, summarize=None, message=None, name=None): # pylint: disable=missing-docstring 500 message = _message_prefix(message) 501 with ops.name_scope(name, 'assert_negative', [x, data]): 502 x = ops.convert_to_tensor(x, name='x') 503 if data is None: 504 if context.executing_eagerly(): 505 name = _shape_and_dtype_str(x) 506 else: 507 name = x.name 508 data = [ 509 message, 510 'Condition x < 0 did not hold element-wise:', 511 'x (%s) = ' % name, x] 512 zero = ops.convert_to_tensor(0, dtype=x.dtype) 513 return assert_less(x, zero, data=data, summarize=summarize) 514 515 516@tf_export('debugging.assert_positive', v1=[]) 517@dispatch.add_dispatch_support 518def assert_positive_v2(x, message=None, summarize=None, name=None): 519 """Assert the condition `x > 0` holds element-wise. 520 521 This Op checks that `x[i] > 0` holds for every element of `x`. If `x` is 522 empty, this is trivially satisfied. 523 524 If `x` is not positive everywhere, `message`, as well as the first `summarize` 525 entries of `x` are printed, and `InvalidArgumentError` is raised. 526 527 Args: 528 x: Numeric `Tensor`. 529 message: A string to prefix to the default message. 530 summarize: Print this many entries of each tensor. 531 name: A name for this operation (optional). Defaults to "assert_positive". 532 533 Returns: 534 Op raising `InvalidArgumentError` unless `x` is all positive. This can be 535 used with `tf.control_dependencies` inside of `tf.function`s to block 536 followup computation until the check has executed. 
537 @compatibility(eager) 538 returns None 539 @end_compatibility 540 541 Raises: 542 InvalidArgumentError: if the check can be performed immediately and 543 `x[i] > 0` is False. The check can be performed immediately during eager 544 execution or if `x` is statically known. 545 """ 546 return assert_positive(x=x, summarize=summarize, message=message, name=name) 547 548 549@tf_export(v1=['debugging.assert_positive', 'assert_positive']) 550@dispatch.add_dispatch_support 551@deprecation.deprecated_endpoints('assert_positive') 552@_unary_assert_doc('> 0', 'positive') 553def assert_positive(x, data=None, summarize=None, message=None, name=None): # pylint: disable=missing-docstring 554 message = _message_prefix(message) 555 with ops.name_scope(name, 'assert_positive', [x, data]): 556 x = ops.convert_to_tensor(x, name='x') 557 if data is None: 558 if context.executing_eagerly(): 559 name = _shape_and_dtype_str(x) 560 else: 561 name = x.name 562 data = [ 563 message, 'Condition x > 0 did not hold element-wise:', 564 'x (%s) = ' % name, x] 565 zero = ops.convert_to_tensor(0, dtype=x.dtype) 566 return assert_less(zero, x, data=data, summarize=summarize) 567 568 569@tf_export('debugging.assert_non_negative', v1=[]) 570@dispatch.add_dispatch_support 571def assert_non_negative_v2(x, message=None, summarize=None, name=None): 572 """Assert the condition `x >= 0` holds element-wise. 573 574 This Op checks that `x[i] >= 0` holds for every element of `x`. If `x` is 575 empty, this is trivially satisfied. 576 577 If `x` is not >= 0 everywhere, `message`, as well as the first `summarize` 578 entries of `x` are printed, and `InvalidArgumentError` is raised. 579 580 Args: 581 x: Numeric `Tensor`. 582 message: A string to prefix to the default message. 583 summarize: Print this many entries of each tensor. 584 name: A name for this operation (optional). Defaults to 585 "assert_non_negative". 586 587 Returns: 588 Op raising `InvalidArgumentError` unless `x` is all non-negative. 
This can 589 be used with `tf.control_dependencies` inside of `tf.function`s to block 590 followup computation until the check has executed. 591 @compatibility(eager) 592 returns None 593 @end_compatibility 594 595 Raises: 596 InvalidArgumentError: if the check can be performed immediately and 597 `x[i] >= 0` is False. The check can be performed immediately during eager 598 execution or if `x` is statically known. 599 """ 600 return assert_non_negative(x=x, summarize=summarize, message=message, 601 name=name) 602 603 604@tf_export(v1=['debugging.assert_non_negative', 'assert_non_negative']) 605@dispatch.add_dispatch_support 606@deprecation.deprecated_endpoints('assert_non_negative') 607@_unary_assert_doc('>= 0', 'non-negative') 608def assert_non_negative(x, data=None, summarize=None, message=None, name=None): # pylint: disable=missing-docstring 609 message = _message_prefix(message) 610 with ops.name_scope(name, 'assert_non_negative', [x, data]): 611 x = ops.convert_to_tensor(x, name='x') 612 if data is None: 613 if context.executing_eagerly(): 614 name = _shape_and_dtype_str(x) 615 else: 616 name = x.name 617 data = [ 618 message, 619 'Condition x >= 0 did not hold element-wise:', 620 'x (%s) = ' % name, x] 621 zero = ops.convert_to_tensor(0, dtype=x.dtype) 622 return assert_less_equal(zero, x, data=data, summarize=summarize) 623 624 625@tf_export('debugging.assert_non_positive', v1=[]) 626@dispatch.add_dispatch_support 627def assert_non_positive_v2(x, message=None, summarize=None, name=None): 628 """Assert the condition `x <= 0` holds element-wise. 629 630 This Op checks that `x[i] <= 0` holds for every element of `x`. If `x` is 631 empty, this is trivially satisfied. 632 633 If `x` is not <= 0 everywhere, `message`, as well as the first `summarize` 634 entries of `x` are printed, and `InvalidArgumentError` is raised. 635 636 Args: 637 x: Numeric `Tensor`. 638 message: A string to prefix to the default message. 
639 summarize: Print this many entries of each tensor. 640 name: A name for this operation (optional). Defaults to 641 "assert_non_positive". 642 643 Returns: 644 Op raising `InvalidArgumentError` unless `x` is all non-positive. This can 645 be used with `tf.control_dependencies` inside of `tf.function`s to block 646 followup computation until the check has executed. 647 @compatibility(eager) 648 returns None 649 @end_compatibility 650 651 Raises: 652 InvalidArgumentError: if the check can be performed immediately and 653 `x[i] <= 0` is False. The check can be performed immediately during eager 654 execution or if `x` is statically known. 655 """ 656 return assert_non_positive(x=x, summarize=summarize, message=message, 657 name=name) 658 659 660@tf_export(v1=['debugging.assert_non_positive', 'assert_non_positive']) 661@dispatch.add_dispatch_support 662@deprecation.deprecated_endpoints('assert_non_positive') 663@_unary_assert_doc('<= 0', 'non-positive') 664def assert_non_positive(x, data=None, summarize=None, message=None, name=None): # pylint: disable=missing-docstring 665 message = _message_prefix(message) 666 with ops.name_scope(name, 'assert_non_positive', [x, data]): 667 x = ops.convert_to_tensor(x, name='x') 668 if data is None: 669 if context.executing_eagerly(): 670 name = _shape_and_dtype_str(x) 671 else: 672 name = x.name 673 data = [ 674 message, 675 'Condition x <= 0 did not hold element-wise:' 676 'x (%s) = ' % name, x] 677 zero = ops.convert_to_tensor(0, dtype=x.dtype) 678 return assert_less_equal(x, zero, data=data, summarize=summarize) 679 680 681@tf_export('debugging.assert_equal', 'assert_equal', v1=[]) 682@dispatch.register_binary_elementwise_assert_api 683@dispatch.add_dispatch_support 684def assert_equal_v2(x, y, message=None, summarize=None, name=None): 685 """Assert the condition `x == y` holds element-wise. 686 687 This Op checks that `x[i] == y[i]` holds for every pair of (possibly 688 broadcast) elements of `x` and `y`. 
If both `x` and `y` are empty, this is 689 trivially satisfied. 690 691 If `x` and `y` are not equal, `message`, as well as the first `summarize` 692 entries of `x` and `y` are printed, and `InvalidArgumentError` is raised. 693 694 Args: 695 x: Numeric `Tensor`. 696 y: Numeric `Tensor`, same dtype as and broadcastable to `x`. 697 message: A string to prefix to the default message. 698 summarize: Print this many entries of each tensor. 699 name: A name for this operation (optional). Defaults to "assert_equal". 700 701 Returns: 702 Op that raises `InvalidArgumentError` if `x == y` is False. This can be 703 used with `tf.control_dependencies` inside of `tf.function`s to block 704 followup computation until the check has executed. 705 @compatibility(eager) 706 returns None 707 @end_compatibility 708 709 Raises: 710 InvalidArgumentError: if the check can be performed immediately and 711 `x == y` is False. The check can be performed immediately during eager 712 execution or if `x` and `y` are statically known. 713 """ 714 return assert_equal(x=x, y=y, summarize=summarize, message=message, name=name) 715 716 717@tf_export(v1=['debugging.assert_equal', 'assert_equal']) 718@dispatch.register_binary_elementwise_assert_api 719@dispatch.add_dispatch_support 720@_binary_assert_doc('==', '[1, 2]') 721def assert_equal(x, y, data=None, summarize=None, message=None, name=None): # pylint: disable=missing-docstring 722 with ops.name_scope(name, 'assert_equal', [x, y, data]): 723 # Short-circuit if x and y are the same tensor. 
724 if x is y: 725 return None if context.executing_eagerly() else control_flow_ops.no_op() 726 return _binary_assert('==', 'assert_equal', math_ops.equal, np.equal, x, y, 727 data, summarize, message, name) 728 729 730@tf_export('debugging.assert_none_equal', v1=[]) 731@dispatch.register_binary_elementwise_assert_api 732@dispatch.add_dispatch_support 733def assert_none_equal_v2(x, y, summarize=None, message=None, name=None): 734 """Assert the condition `x != y` holds for all elements. 735 736 This Op checks that `x[i] != y[i]` holds for every pair of (possibly 737 broadcast) elements of `x` and `y`. If both `x` and `y` are empty, this is 738 trivially satisfied. 739 740 If any elements of `x` and `y` are equal, `message`, as well as the first 741 `summarize` entries of `x` and `y` are printed, and `InvalidArgumentError` 742 is raised. 743 744 Args: 745 x: Numeric `Tensor`. 746 y: Numeric `Tensor`, same dtype as and broadcastable to `x`. 747 summarize: Print this many entries of each tensor. 748 message: A string to prefix to the default message. 749 name: A name for this operation (optional). Defaults to 750 "assert_none_equal". 751 752 Returns: 753 Op that raises `InvalidArgumentError` if `x != y` is ever False. This can 754 be used with `tf.control_dependencies` inside of `tf.function`s to block 755 followup computation until the check has executed. 756 @compatibility(eager) 757 returns None 758 @end_compatibility 759 760 Raises: 761 InvalidArgumentError: if the check can be performed immediately and 762 `x != y` is False for any pair of elements in `x` and `y`. The check can 763 be performed immediately during eager execution or if `x` and `y` are 764 statically known. 
765 """ 766 return assert_none_equal(x=x, y=y, summarize=summarize, message=message, 767 name=name) 768 769 770@tf_export(v1=['debugging.assert_none_equal', 'assert_none_equal']) 771@dispatch.register_binary_elementwise_assert_api 772@dispatch.add_dispatch_support 773@deprecation.deprecated_endpoints('assert_none_equal') 774@_binary_assert_doc('!=', '[2, 1]') 775def assert_none_equal( 776 x, y, data=None, summarize=None, message=None, name=None): 777 return _binary_assert('!=', 'assert_none_equal', math_ops.not_equal, 778 np.not_equal, x, y, data, summarize, message, name) 779 780 781@tf_export('debugging.assert_near', v1=[]) 782@dispatch.register_binary_elementwise_assert_api 783@dispatch.add_dispatch_support 784def assert_near_v2(x, y, rtol=None, atol=None, message=None, summarize=None, 785 name=None): 786 """Assert the condition `x` and `y` are close element-wise. 787 788 This Op checks that `x[i] - y[i] < atol + rtol * tf.abs(y[i])` holds for every 789 pair of (possibly broadcast) elements of `x` and `y`. If both `x` and `y` are 790 empty, this is trivially satisfied. 791 792 If any elements of `x` and `y` are not close, `message`, as well as the first 793 `summarize` entries of `x` and `y` are printed, and `InvalidArgumentError` 794 is raised. 795 796 The default `atol` and `rtol` is `10 * eps`, where `eps` is the smallest 797 representable positive number such that `1 + eps != 1`. This is about 798 `1.2e-6` in `32bit`, `2.22e-15` in `64bit`, and `0.00977` in `16bit`. 799 See `numpy.finfo`. 800 801 Args: 802 x: Float or complex `Tensor`. 803 y: Float or complex `Tensor`, same dtype as and broadcastable to `x`. 804 rtol: `Tensor`. Same `dtype` as, and broadcastable to, `x`. 805 The relative tolerance. Default is `10 * eps`. 806 atol: `Tensor`. Same `dtype` as, and broadcastable to, `x`. 807 The absolute tolerance. Default is `10 * eps`. 808 message: A string to prefix to the default message. 809 summarize: Print this many entries of each tensor. 
810 name: A name for this operation (optional). Defaults to "assert_near". 811 812 Returns: 813 Op that raises `InvalidArgumentError` if `x` and `y` are not close enough. 814 This can be used with `tf.control_dependencies` inside of `tf.function`s 815 to block followup computation until the check has executed. 816 @compatibility(eager) 817 returns None 818 @end_compatibility 819 820 Raises: 821 InvalidArgumentError: if the check can be performed immediately and 822 `x != y` is False for any pair of elements in `x` and `y`. The check can 823 be performed immediately during eager execution or if `x` and `y` are 824 statically known. 825 826 @compatibility(numpy) 827 Similar to `numpy.testing.assert_allclose`, except tolerance depends on data 828 type. This is due to the fact that `TensorFlow` is often used with `32bit`, 829 `64bit`, and even `16bit` data. 830 @end_compatibility 831 """ 832 return assert_near(x=x, y=y, rtol=rtol, atol=atol, summarize=summarize, 833 message=message, name=name) 834 835 836@tf_export(v1=['debugging.assert_near', 'assert_near']) 837@dispatch.register_binary_elementwise_assert_api 838@dispatch.add_dispatch_support 839@deprecation.deprecated_endpoints('assert_near') 840def assert_near( 841 x, y, rtol=None, atol=None, data=None, summarize=None, message=None, 842 name=None): 843 """Assert the condition `x` and `y` are close element-wise. 844 845 Example of adding a dependency to an operation: 846 847 ```python 848 with tf.control_dependencies([tf.compat.v1.assert_near(x, y)]): 849 output = tf.reduce_sum(x) 850 ``` 851 852 This condition holds if for every pair of (possibly broadcast) elements 853 `x[i]`, `y[i]`, we have 854 855 ```tf.abs(x[i] - y[i]) <= atol + rtol * tf.abs(y[i])```. 856 857 If both `x` and `y` are empty, this is trivially satisfied. 858 859 The default `atol` and `rtol` is `10 * eps`, where `eps` is the smallest 860 representable positive number such that `1 + eps != 1`. 
This is about 861 `1.2e-6` in `32bit`, `2.22e-15` in `64bit`, and `0.00977` in `16bit`. 862 See `numpy.finfo`. 863 864 Args: 865 x: Float or complex `Tensor`. 866 y: Float or complex `Tensor`, same `dtype` as, and broadcastable to, `x`. 867 rtol: `Tensor`. Same `dtype` as, and broadcastable to, `x`. 868 The relative tolerance. Default is `10 * eps`. 869 atol: `Tensor`. Same `dtype` as, and broadcastable to, `x`. 870 The absolute tolerance. Default is `10 * eps`. 871 data: The tensors to print out if the condition is False. Defaults to 872 error message and first few entries of `x`, `y`. 873 summarize: Print this many entries of each tensor. 874 message: A string to prefix to the default message. 875 name: A name for this operation (optional). Defaults to "assert_near". 876 877 Returns: 878 Op that raises `InvalidArgumentError` if `x` and `y` are not close enough. 879 880 @compatibility(numpy) 881 Similar to `numpy.testing.assert_allclose`, except tolerance depends on data 882 type. This is due to the fact that `TensorFlow` is often used with `32bit`, 883 `64bit`, and even `16bit` data. 
884 @end_compatibility 885 """ 886 message = _message_prefix(message) 887 with ops.name_scope(name, 'assert_near', [x, y, rtol, atol, data]): 888 x = ops.convert_to_tensor(x, name='x') 889 y = ops.convert_to_tensor(y, name='y', dtype=x.dtype) 890 891 dtype = x.dtype 892 if dtype.is_complex: 893 dtype = dtype.real_dtype 894 eps = np.finfo(dtype.as_numpy_dtype).eps 895 rtol = 10 * eps if rtol is None else rtol 896 atol = 10 * eps if atol is None else atol 897 898 rtol = ops.convert_to_tensor(rtol, name='rtol', dtype=dtype) 899 atol = ops.convert_to_tensor(atol, name='atol', dtype=dtype) 900 901 if context.executing_eagerly(): 902 x_name = _shape_and_dtype_str(x) 903 y_name = _shape_and_dtype_str(y) 904 else: 905 x_name = x.name 906 y_name = y.name 907 908 if data is None: 909 data = [ 910 message, 911 'x and y not equal to tolerance rtol = %s, atol = %s' % (rtol, atol), 912 'x (%s) = ' % x_name, x, 'y (%s) = ' % y_name, y 913 ] 914 tol = atol + rtol * math_ops.abs(y) 915 diff = math_ops.abs(x - y) 916 condition = math_ops.reduce_all(math_ops.less(diff, tol)) 917 return control_flow_ops.Assert(condition, data, summarize=summarize) 918 919 920@tf_export('debugging.assert_less', 'assert_less', v1=[]) 921@dispatch.register_binary_elementwise_assert_api 922@dispatch.add_dispatch_support 923def assert_less_v2(x, y, message=None, summarize=None, name=None): 924 """Assert the condition `x < y` holds element-wise. 925 926 This Op checks that `x[i] < y[i]` holds for every pair of (possibly 927 broadcast) elements of `x` and `y`. If both `x` and `y` are empty, this is 928 trivially satisfied. 929 930 If `x` is not less than `y` element-wise, `message`, as well as the first 931 `summarize` entries of `x` and `y` are printed, and `InvalidArgumentError` is 932 raised. 933 934 Args: 935 x: Numeric `Tensor`. 936 y: Numeric `Tensor`, same dtype as and broadcastable to `x`. 937 message: A string to prefix to the default message. 938 summarize: Print this many entries of each tensor. 
    name: A name for this operation (optional). Defaults to "assert_less".

  Returns:
    Op that raises `InvalidArgumentError` if `x < y` is False.
    This can be used with `tf.control_dependencies` inside of `tf.function`s
    to block followup computation until the check has executed.
    @compatibility(eager)
    returns None
    @end_compatibility

  Raises:
    InvalidArgumentError: if the check can be performed immediately and
      `x < y` is False. The check can be performed immediately during eager
      execution or if `x` and `y` are statically known.
  """
  return assert_less(x=x, y=y, summarize=summarize, message=message, name=name)


# V1 endpoint. No inline docstring: the `_binary_assert_doc` decorator is
# expected to attach one for the `<` check.
@tf_export(v1=['debugging.assert_less', 'assert_less'])
@dispatch.register_binary_elementwise_assert_api
@dispatch.add_dispatch_support
@_binary_assert_doc('<', '[2, 3]')
def assert_less(x, y, data=None, summarize=None, message=None, name=None):
  return _binary_assert('<', 'assert_less', math_ops.less, np.less, x, y, data,
                        summarize, message, name)


@tf_export('debugging.assert_less_equal', v1=[])
@dispatch.register_binary_elementwise_assert_api
@dispatch.add_dispatch_support
def assert_less_equal_v2(x, y, message=None, summarize=None, name=None):
  """Assert the condition `x <= y` holds element-wise.

  This Op checks that `x[i] <= y[i]` holds for every pair of (possibly
  broadcast) elements of `x` and `y`. If both `x` and `y` are empty, this is
  trivially satisfied.

  If `x` is not less than or equal to `y` element-wise, `message`, as well as
  the first `summarize` entries of `x` and `y` are printed, and
  `InvalidArgumentError` is raised.

  Args:
    x: Numeric `Tensor`.
    y: Numeric `Tensor`, same dtype as and broadcastable to `x`.
    message: A string to prefix to the default message.
    summarize: Print this many entries of each tensor.
    name: A name for this operation (optional). Defaults to "assert_less_equal".

  Returns:
    Op that raises `InvalidArgumentError` if `x <= y` is False. This can be
    used with `tf.control_dependencies` inside of `tf.function`s to block
    followup computation until the check has executed.
    @compatibility(eager)
    returns None
    @end_compatibility

  Raises:
    InvalidArgumentError: if the check can be performed immediately and
      `x <= y` is False. The check can be performed immediately during eager
      execution or if `x` and `y` are statically known.
  """
  return assert_less_equal(x=x, y=y,
                           summarize=summarize, message=message, name=name)


# V1 endpoint. No inline docstring: the `_binary_assert_doc` decorator is
# expected to attach one for the `<=` check.
@tf_export(v1=['debugging.assert_less_equal', 'assert_less_equal'])
@dispatch.register_binary_elementwise_assert_api
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints('assert_less_equal')
@_binary_assert_doc('<=', '[1, 3]')
def assert_less_equal(x, y, data=None, summarize=None, message=None, name=None):
  return _binary_assert('<=', 'assert_less_equal', math_ops.less_equal,
                        np.less_equal, x, y, data, summarize, message, name)


@tf_export('debugging.assert_greater', 'assert_greater', v1=[])
@dispatch.register_binary_elementwise_assert_api
@dispatch.add_dispatch_support
def assert_greater_v2(x, y, message=None, summarize=None, name=None):
  """Assert the condition `x > y` holds element-wise.

  This Op checks that `x[i] > y[i]` holds for every pair of (possibly
  broadcast) elements of `x` and `y`. If both `x` and `y` are empty, this is
  trivially satisfied.

  If `x` is not greater than `y` element-wise, `message`, as well as the first
  `summarize` entries of `x` and `y` are printed, and `InvalidArgumentError` is
  raised.

  Args:
    x: Numeric `Tensor`.
    y: Numeric `Tensor`, same dtype as and broadcastable to `x`.
    message: A string to prefix to the default message.
    summarize: Print this many entries of each tensor.
    name: A name for this operation (optional). Defaults to "assert_greater".

  Returns:
    Op that raises `InvalidArgumentError` if `x > y` is False. This can be
    used with `tf.control_dependencies` inside of `tf.function`s to block
    followup computation until the check has executed.
    @compatibility(eager)
    returns None
    @end_compatibility

  Raises:
    InvalidArgumentError: if the check can be performed immediately and
      `x > y` is False. The check can be performed immediately during eager
      execution or if `x` and `y` are statically known.
  """
  return assert_greater(x=x, y=y, summarize=summarize, message=message,
                        name=name)


# V1 endpoint. No inline docstring: the `_binary_assert_doc` decorator is
# expected to attach one for the `>` check.
@tf_export(v1=['debugging.assert_greater', 'assert_greater'])
@dispatch.register_binary_elementwise_assert_api
@dispatch.add_dispatch_support
@_binary_assert_doc('>', '[0, 1]')
def assert_greater(x, y, data=None, summarize=None, message=None, name=None):  # pylint: disable=missing-docstring
  return _binary_assert('>', 'assert_greater', math_ops.greater, np.greater, x,
                        y, data, summarize, message, name)


@tf_export('debugging.assert_greater_equal', v1=[])
@dispatch.register_binary_elementwise_assert_api
@dispatch.add_dispatch_support
def assert_greater_equal_v2(x, y, message=None, summarize=None, name=None):
  """Assert the condition `x >= y` holds element-wise.

  This Op checks that `x[i] >= y[i]` holds for every pair of (possibly
  broadcast) elements of `x` and `y`. If both `x` and `y` are empty, this is
  trivially satisfied.

  If `x` is not greater than or equal to `y` element-wise, `message`, as well
  as the first `summarize` entries of `x` and `y` are printed, and
  `InvalidArgumentError` is raised.

  Args:
    x: Numeric `Tensor`.
    y: Numeric `Tensor`, same dtype as and broadcastable to `x`.
    message: A string to prefix to the default message.
    summarize: Print this many entries of each tensor.
    name: A name for this operation (optional). Defaults to
      "assert_greater_equal".

  Returns:
    Op that raises `InvalidArgumentError` if `x >= y` is False. This can be
    used with `tf.control_dependencies` inside of `tf.function`s to block
    followup computation until the check has executed.
    @compatibility(eager)
    returns None
    @end_compatibility

  Raises:
    InvalidArgumentError: if the check can be performed immediately and
      `x >= y` is False. The check can be performed immediately during eager
      execution or if `x` and `y` are statically known.
  """
  return assert_greater_equal(x=x, y=y, summarize=summarize, message=message,
                              name=name)


# V1 endpoint. No inline docstring: the `_binary_assert_doc` decorator is
# expected to attach one for the `>=` check.
@tf_export(v1=['debugging.assert_greater_equal', 'assert_greater_equal'])
@dispatch.register_binary_elementwise_assert_api
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints('assert_greater_equal')
@_binary_assert_doc('>=', '[1, 0]')
def assert_greater_equal(x, y, data=None, summarize=None, message=None,
                         name=None):
  return _binary_assert('>=', 'assert_greater_equal', math_ops.greater_equal,
                        np.greater_equal, x, y, data, summarize, message, name)


def _assert_rank_condition(
    x, rank, static_condition, dynamic_condition, data, summarize):
  """Assert `x` has a rank that satisfies a given condition.

  Args:
    x: Numeric `Tensor`.
    rank: Scalar `Tensor`.
    static_condition: A python function that takes `[actual_rank, given_rank]`
      and returns `True` if the condition is satisfied, `False` otherwise.
    dynamic_condition: An `op` that takes [actual_rank, given_rank] and returns
      `True` if the condition is satisfied, `False` otherwise.
    data: The tensors to print out if the condition is false. Defaults to
      error message and first few entries of `x`.
1124 summarize: Print this many entries of each tensor. 1125 1126 Returns: 1127 Op raising `InvalidArgumentError` if `x` fails dynamic_condition. 1128 1129 Raises: 1130 ValueError: If static checks determine `x` fails static_condition. 1131 """ 1132 assert_type(rank, dtypes.int32) 1133 1134 # Attempt to statically defined rank. 1135 rank_static = tensor_util.constant_value(rank) 1136 if rank_static is not None: 1137 if rank_static.ndim != 0: 1138 raise ValueError('Rank must be a scalar.') 1139 1140 x_rank_static = x.get_shape().ndims 1141 if x_rank_static is not None: 1142 if not static_condition(x_rank_static, rank_static): 1143 raise ValueError( 1144 'Static rank condition failed', x_rank_static, rank_static) 1145 return control_flow_ops.no_op(name='static_checks_determined_all_ok') 1146 1147 condition = dynamic_condition(array_ops.rank(x), rank) 1148 1149 # Add the condition that `rank` must have rank zero. Prevents the bug where 1150 # someone does assert_rank(x, [n]), rather than assert_rank(x, n). 1151 if rank_static is None: 1152 this_data = ['Rank must be a scalar. Received rank: ', rank] 1153 rank_check = assert_rank(rank, 0, data=this_data) 1154 condition = control_flow_ops.with_dependencies([rank_check], condition) 1155 1156 return control_flow_ops.Assert(condition, data, summarize=summarize) 1157 1158 1159@tf_export('debugging.assert_rank', 'assert_rank', v1=[]) 1160@dispatch.add_dispatch_support 1161def assert_rank_v2(x, rank, message=None, name=None): 1162 """Assert that `x` has rank equal to `rank`. 1163 1164 This Op checks that the rank of `x` is equal to `rank`. 1165 1166 If `x` has a different rank, `message`, as well as the shape of `x` are 1167 printed, and `InvalidArgumentError` is raised. 1168 1169 Args: 1170 x: `Tensor`. 1171 rank: Scalar integer `Tensor`. 1172 message: A string to prefix to the default message. 1173 name: A name for this operation (optional). Defaults to 1174 "assert_rank". 
1175 1176 Returns: 1177 Op raising `InvalidArgumentError` unless `x` has specified rank. 1178 If static checks determine `x` has correct rank, a `no_op` is returned. 1179 This can be used with `tf.control_dependencies` inside of `tf.function`s 1180 to block followup computation until the check has executed. 1181 @compatibility(eager) 1182 returns None 1183 @end_compatibility 1184 1185 Raises: 1186 InvalidArgumentError: if the check can be performed immediately and 1187 `x` does not have rank `rank`. The check can be performed immediately 1188 during eager execution or if the shape of `x` is statically known. 1189 """ 1190 return assert_rank(x=x, rank=rank, message=message, name=name) 1191 1192 1193@tf_export(v1=['debugging.assert_rank', 'assert_rank']) 1194@dispatch.add_dispatch_support 1195def assert_rank(x, rank, data=None, summarize=None, message=None, name=None): 1196 """Assert `x` has rank equal to `rank`. 1197 1198 Example of adding a dependency to an operation: 1199 1200 ```python 1201 with tf.control_dependencies([tf.compat.v1.assert_rank(x, 2)]): 1202 output = tf.reduce_sum(x) 1203 ``` 1204 1205 Args: 1206 x: Numeric `Tensor`. 1207 rank: Scalar integer `Tensor`. 1208 data: The tensors to print out if the condition is False. Defaults to 1209 error message and the shape of `x`. 1210 summarize: Print this many entries of each tensor. 1211 message: A string to prefix to the default message. 1212 name: A name for this operation (optional). Defaults to "assert_rank". 1213 1214 Returns: 1215 Op raising `InvalidArgumentError` unless `x` has specified rank. 1216 If static checks determine `x` has correct rank, a `no_op` is returned. 1217 1218 Raises: 1219 ValueError: If static checks determine `x` has wrong rank. 
1220 """ 1221 with ops.name_scope(name, 'assert_rank', (x, rank) + tuple(data or [])): 1222 if not isinstance(x, sparse_tensor.SparseTensor): 1223 x = ops.convert_to_tensor(x, name='x') 1224 rank = ops.convert_to_tensor(rank, name='rank') 1225 message = _message_prefix(message) 1226 1227 static_condition = lambda actual_rank, given_rank: actual_rank == given_rank 1228 dynamic_condition = math_ops.equal 1229 1230 if context.executing_eagerly() or isinstance(x, sparse_tensor.SparseTensor): 1231 name = '' 1232 else: 1233 name = x.name 1234 1235 if data is None: 1236 data = [ 1237 message, 1238 'Tensor %s must have rank' % name, rank, 'Received shape: ', 1239 array_ops.shape(x) 1240 ] 1241 1242 try: 1243 assert_op = _assert_rank_condition(x, rank, static_condition, 1244 dynamic_condition, data, summarize) 1245 1246 except ValueError as e: 1247 if e.args[0] == 'Static rank condition failed': 1248 raise ValueError( 1249 '%sTensor %s must have rank %d. Received rank %d, shape %s' % 1250 (message, name, e.args[2], e.args[1], x.get_shape())) 1251 else: 1252 raise ValueError(e.args[0]) 1253 1254 return assert_op 1255 1256 1257@tf_export('debugging.assert_rank_at_least', v1=[]) 1258@dispatch.add_dispatch_support 1259def assert_rank_at_least_v2(x, rank, message=None, name=None): 1260 """Assert that `x` has rank of at least `rank`. 1261 1262 This Op checks that the rank of `x` is greater or equal to `rank`. 1263 1264 If `x` has a rank lower than `rank`, `message`, as well as the shape of `x` 1265 are printed, and `InvalidArgumentError` is raised. 1266 1267 Args: 1268 x: `Tensor`. 1269 rank: Scalar integer `Tensor`. 1270 message: A string to prefix to the default message. 1271 name: A name for this operation (optional). Defaults to 1272 "assert_rank_at_least". 1273 1274 Returns: 1275 Op raising `InvalidArgumentError` unless `x` has specified rank or higher. 1276 If static checks determine `x` has correct rank, a `no_op` is returned. 
1277 This can be used with `tf.control_dependencies` inside of `tf.function`s 1278 to block followup computation until the check has executed. 1279 @compatibility(eager) 1280 returns None 1281 @end_compatibility 1282 1283 Raises: 1284 InvalidArgumentError: `x` does not have rank at least `rank`, but the rank 1285 cannot be statically determined. 1286 ValueError: If static checks determine `x` has mismatched rank. 1287 """ 1288 return assert_rank_at_least(x=x, rank=rank, message=message, name=name) 1289 1290 1291@tf_export(v1=['debugging.assert_rank_at_least', 'assert_rank_at_least']) 1292@dispatch.add_dispatch_support 1293@deprecation.deprecated_endpoints('assert_rank_at_least') 1294def assert_rank_at_least( 1295 x, rank, data=None, summarize=None, message=None, name=None): 1296 """Assert `x` has rank equal to `rank` or higher. 1297 1298 Example of adding a dependency to an operation: 1299 1300 ```python 1301 with tf.control_dependencies([tf.compat.v1.assert_rank_at_least(x, 2)]): 1302 output = tf.reduce_sum(x) 1303 ``` 1304 1305 Args: 1306 x: Numeric `Tensor`. 1307 rank: Scalar `Tensor`. 1308 data: The tensors to print out if the condition is False. Defaults to 1309 error message and first few entries of `x`. 1310 summarize: Print this many entries of each tensor. 1311 message: A string to prefix to the default message. 1312 name: A name for this operation (optional). 1313 Defaults to "assert_rank_at_least". 1314 1315 Returns: 1316 Op raising `InvalidArgumentError` unless `x` has specified rank or higher. 1317 If static checks determine `x` has correct rank, a `no_op` is returned. 1318 1319 Raises: 1320 ValueError: If static checks determine `x` has wrong rank. 
1321 """ 1322 with ops.name_scope( 1323 name, 'assert_rank_at_least', (x, rank) + tuple(data or [])): 1324 x = ops.convert_to_tensor(x, name='x') 1325 rank = ops.convert_to_tensor(rank, name='rank') 1326 message = _message_prefix(message) 1327 1328 static_condition = lambda actual_rank, given_rank: actual_rank >= given_rank 1329 dynamic_condition = math_ops.greater_equal 1330 1331 if context.executing_eagerly(): 1332 name = '' 1333 else: 1334 name = x.name 1335 1336 if data is None: 1337 data = [ 1338 message, 1339 'Tensor %s must have rank at least' % name, rank, 1340 'Received shape: ', array_ops.shape(x) 1341 ] 1342 1343 try: 1344 assert_op = _assert_rank_condition(x, rank, static_condition, 1345 dynamic_condition, data, summarize) 1346 1347 except ValueError as e: 1348 if e.args[0] == 'Static rank condition failed': 1349 raise ValueError( 1350 '%sTensor %s must have rank at least %d. Received rank %d, ' 1351 'shape %s' % (message, name, e.args[2], e.args[1], x.get_shape())) 1352 else: 1353 raise 1354 1355 return assert_op 1356 1357 1358def _static_rank_in(actual_rank, given_ranks): 1359 return actual_rank in given_ranks 1360 1361 1362def _dynamic_rank_in(actual_rank, given_ranks): 1363 if len(given_ranks) < 1: 1364 return ops.convert_to_tensor(False) 1365 result = math_ops.equal(given_ranks[0], actual_rank) 1366 for given_rank in given_ranks[1:]: 1367 result = math_ops.logical_or( 1368 result, math_ops.equal(given_rank, actual_rank)) 1369 return result 1370 1371 1372def _assert_ranks_condition( 1373 x, ranks, static_condition, dynamic_condition, data, summarize): 1374 """Assert `x` has a rank that satisfies a given condition. 1375 1376 Args: 1377 x: Numeric `Tensor`. 1378 ranks: Scalar `Tensor`. 1379 static_condition: A python function that takes 1380 `[actual_rank, given_ranks]` and returns `True` if the condition is 1381 satisfied, `False` otherwise. 
1382 dynamic_condition: An `op` that takes [actual_rank, given_ranks] 1383 and return `True` if the condition is satisfied, `False` otherwise. 1384 data: The tensors to print out if the condition is false. Defaults to 1385 error message and first few entries of `x`. 1386 summarize: Print this many entries of each tensor. 1387 1388 Returns: 1389 Op raising `InvalidArgumentError` if `x` fails dynamic_condition. 1390 1391 Raises: 1392 ValueError: If static checks determine `x` fails static_condition. 1393 """ 1394 for rank in ranks: 1395 assert_type(rank, dtypes.int32) 1396 1397 # Attempt to statically defined rank. 1398 ranks_static = tuple([tensor_util.constant_value(rank) for rank in ranks]) 1399 if not any(r is None for r in ranks_static): 1400 for rank_static in ranks_static: 1401 if rank_static.ndim != 0: 1402 raise ValueError('Rank must be a scalar.') 1403 1404 x_rank_static = x.get_shape().ndims 1405 if x_rank_static is not None: 1406 if not static_condition(x_rank_static, ranks_static): 1407 raise ValueError( 1408 'Static rank condition failed', x_rank_static, ranks_static) 1409 return control_flow_ops.no_op(name='static_checks_determined_all_ok') 1410 1411 condition = dynamic_condition(array_ops.rank(x), ranks) 1412 1413 # Add the condition that `rank` must have rank zero. Prevents the bug where 1414 # someone does assert_rank(x, [n]), rather than assert_rank(x, n). 1415 for rank, rank_static in zip(ranks, ranks_static): 1416 if rank_static is None: 1417 this_data = ['Rank must be a scalar. Received rank: ', rank] 1418 rank_check = assert_rank(rank, 0, data=this_data) 1419 condition = control_flow_ops.with_dependencies([rank_check], condition) 1420 1421 return control_flow_ops.Assert(condition, data, summarize=summarize) 1422 1423 1424@tf_export('debugging.assert_rank_in', v1=[]) 1425@dispatch.add_dispatch_support 1426def assert_rank_in_v2(x, ranks, message=None, name=None): 1427 """Assert that `x` has a rank in `ranks`. 
1428 1429 This Op checks that the rank of `x` is in `ranks`. 1430 1431 If `x` has a different rank, `message`, as well as the shape of `x` are 1432 printed, and `InvalidArgumentError` is raised. 1433 1434 Args: 1435 x: `Tensor`. 1436 ranks: `Iterable` of scalar `Tensor` objects. 1437 message: A string to prefix to the default message. 1438 name: A name for this operation (optional). Defaults to "assert_rank_in". 1439 1440 Returns: 1441 Op raising `InvalidArgumentError` unless rank of `x` is in `ranks`. 1442 If static checks determine `x` has matching rank, a `no_op` is returned. 1443 This can be used with `tf.control_dependencies` inside of `tf.function`s 1444 to block followup computation until the check has executed. 1445 @compatibility(eager) 1446 returns None 1447 @end_compatibility 1448 1449 Raises: 1450 InvalidArgumentError: `x` does not have rank in `ranks`, but the rank cannot 1451 be statically determined. 1452 ValueError: If static checks determine `x` has mismatched rank. 1453 """ 1454 return assert_rank_in(x=x, ranks=ranks, message=message, name=name) 1455 1456 1457@tf_export(v1=['debugging.assert_rank_in', 'assert_rank_in']) 1458@dispatch.add_dispatch_support 1459@deprecation.deprecated_endpoints('assert_rank_in') 1460def assert_rank_in( 1461 x, ranks, data=None, summarize=None, message=None, name=None): 1462 """Assert `x` has rank in `ranks`. 1463 1464 Example of adding a dependency to an operation: 1465 1466 ```python 1467 with tf.control_dependencies([tf.compat.v1.assert_rank_in(x, (2, 4))]): 1468 output = tf.reduce_sum(x) 1469 ``` 1470 1471 Args: 1472 x: Numeric `Tensor`. 1473 ranks: Iterable of scalar `Tensor` objects. 1474 data: The tensors to print out if the condition is False. Defaults to 1475 error message and first few entries of `x`. 1476 summarize: Print this many entries of each tensor. 1477 message: A string to prefix to the default message. 1478 name: A name for this operation (optional). 1479 Defaults to "assert_rank_in". 
1480 1481 Returns: 1482 Op raising `InvalidArgumentError` unless rank of `x` is in `ranks`. 1483 If static checks determine `x` has matching rank, a `no_op` is returned. 1484 1485 Raises: 1486 ValueError: If static checks determine `x` has mismatched rank. 1487 """ 1488 with ops.name_scope( 1489 name, 'assert_rank_in', (x,) + tuple(ranks) + tuple(data or [])): 1490 if not isinstance(x, sparse_tensor.SparseTensor): 1491 x = ops.convert_to_tensor(x, name='x') 1492 ranks = tuple([ops.convert_to_tensor(rank, name='rank') for rank in ranks]) 1493 message = _message_prefix(message) 1494 1495 if context.executing_eagerly() or isinstance(x, sparse_tensor.SparseTensor): 1496 name = '' 1497 else: 1498 name = x.name 1499 1500 if data is None: 1501 data = [ 1502 message, 'Tensor %s must have rank in' % name 1503 ] + list(ranks) + [ 1504 'Received shape: ', array_ops.shape(x) 1505 ] 1506 1507 try: 1508 assert_op = _assert_ranks_condition(x, ranks, _static_rank_in, 1509 _dynamic_rank_in, data, summarize) 1510 1511 except ValueError as e: 1512 if e.args[0] == 'Static rank condition failed': 1513 raise ValueError( 1514 '%sTensor %s must have rank in %s. Received rank %d, ' 1515 'shape %s' % (message, name, e.args[2], e.args[1], x.get_shape())) 1516 else: 1517 raise 1518 1519 return assert_op 1520 1521 1522@tf_export('debugging.assert_integer', v1=[]) 1523@dispatch.add_dispatch_support 1524def assert_integer_v2(x, message=None, name=None): 1525 """Assert that `x` is of integer dtype. 1526 1527 If `x` has a non-integer type, `message`, as well as the dtype of `x` are 1528 printed, and `InvalidArgumentError` is raised. 1529 1530 This can always be checked statically, so this method returns nothing. 1531 1532 Args: 1533 x: A `Tensor`. 1534 message: A string to prefix to the default message. 1535 name: A name for this operation (optional). Defaults to "assert_integer". 1536 1537 Raises: 1538 TypeError: If `x.dtype` is not a non-quantized integer type. 
1539 """ 1540 assert_integer(x=x, message=message, name=name) 1541 1542 1543@tf_export(v1=['debugging.assert_integer', 'assert_integer']) 1544@dispatch.add_dispatch_support 1545@deprecation.deprecated_endpoints('assert_integer') 1546def assert_integer(x, message=None, name=None): 1547 """Assert that `x` is of integer dtype. 1548 1549 Example of adding a dependency to an operation: 1550 1551 ```python 1552 with tf.control_dependencies([tf.compat.v1.assert_integer(x)]): 1553 output = tf.reduce_sum(x) 1554 ``` 1555 1556 Args: 1557 x: `Tensor` whose basetype is integer and is not quantized. 1558 message: A string to prefix to the default message. 1559 name: A name for this operation (optional). Defaults to "assert_integer". 1560 1561 Raises: 1562 TypeError: If `x.dtype` is anything other than non-quantized integer. 1563 1564 Returns: 1565 A `no_op` that does nothing. Type can be determined statically. 1566 """ 1567 with ops.name_scope(name, 'assert_integer', [x]): 1568 x = ops.convert_to_tensor(x, name='x') 1569 if not x.dtype.is_integer: 1570 if context.executing_eagerly(): 1571 name = 'tensor' 1572 else: 1573 name = x.name 1574 err_msg = ( 1575 '%sExpected "x" to be integer type. Found: %s of dtype %s' 1576 % (_message_prefix(message), name, x.dtype)) 1577 raise TypeError(err_msg) 1578 1579 return control_flow_ops.no_op('statically_determined_was_integer') 1580 1581 1582@tf_export('debugging.assert_type', v1=[]) 1583@dispatch.add_dispatch_support 1584def assert_type_v2(tensor, tf_type, message=None, name=None): 1585 """Asserts that the given `Tensor` is of the specified type. 1586 1587 This can always be checked statically, so this method returns nothing. 1588 1589 Example: 1590 1591 >>> a = tf.Variable(1.0) 1592 >>> tf.debugging.assert_type(a, tf_type= tf.float32) 1593 1594 >>> b = tf.constant(21) 1595 >>> tf.debugging.assert_type(b, tf_type=tf.bool) 1596 Traceback (most recent call last): 1597 ... 1598 TypeError: ... 
1599 1600 >>> c = tf.SparseTensor(indices=[[0, 0], [1, 2]], values=[1, 2], 1601 ... dense_shape=[3, 4]) 1602 >>> tf.debugging.assert_type(c, tf_type= tf.int32) 1603 1604 Args: 1605 tensor: A `Tensor`, `SparseTensor` or `tf.Variable` . 1606 tf_type: A tensorflow type (`dtypes.float32`, `tf.int64`, `dtypes.bool`, 1607 etc). 1608 message: A string to prefix to the default message. 1609 name: A name for this operation. Defaults to "assert_type" 1610 1611 Raises: 1612 TypeError: If the tensor's data type doesn't match `tf_type`. 1613 """ 1614 assert_type(tensor=tensor, tf_type=tf_type, message=message, name=name) 1615 1616 1617@tf_export(v1=['debugging.assert_type', 'assert_type']) 1618@dispatch.add_dispatch_support 1619@deprecation.deprecated_endpoints('assert_type') 1620def assert_type(tensor, tf_type, message=None, name=None): 1621 """Statically asserts that the given `Tensor` is of the specified type. 1622 1623 Args: 1624 tensor: A `Tensor` or `SparseTensor`. 1625 tf_type: A tensorflow type (`dtypes.float32`, `tf.int64`, `dtypes.bool`, 1626 etc). 1627 message: A string to prefix to the default message. 1628 name: A name to give this `Op`. Defaults to "assert_type" 1629 1630 Raises: 1631 TypeError: If the tensors data type doesn't match `tf_type`. 1632 1633 Returns: 1634 A `no_op` that does nothing. Type can be determined statically. 1635 """ 1636 tf_type = dtypes.as_dtype(tf_type) 1637 with ops.name_scope(name, 'assert_type', [tensor]): 1638 if not isinstance(tensor, sparse_tensor.SparseTensor): 1639 tensor = ops.convert_to_tensor(tensor, name='tensor') 1640 if tensor.dtype != tf_type: 1641 raise TypeError( 1642 f'{_message_prefix(message)}{getattr(tensor, "name", "tensor")}' 1643 f' must be of type {tf_type!r}; got {tensor.dtype!r}') 1644 1645 return control_flow_ops.no_op('statically_determined_correct_type') 1646 1647 1648def _dimension_sizes(x): 1649 """Gets the dimension sizes of a tensor `x`. 
1650 1651 If a size can be determined statically it is returned as an integer, 1652 otherwise as a tensor. 1653 1654 If `x` is a scalar it is treated as rank 1 size 1. 1655 1656 Args: 1657 x: A `Tensor`. 1658 1659 Returns: 1660 Dimension sizes. 1661 """ 1662 dynamic_shape = array_ops.shape(x) 1663 rank = x.get_shape().rank 1664 rank_is_known = rank is not None 1665 if rank_is_known and rank == 0: 1666 return (1,) 1667 if rank_is_known and rank > 0: 1668 static_shape = x.get_shape().as_list() 1669 sizes = [ 1670 int(size) if size is not None else dynamic_shape[i] 1671 for i, size in enumerate(static_shape) 1672 ] 1673 return sizes 1674 has_rank_zero = math_ops.equal(array_ops.rank(x), 0) 1675 return control_flow_ops.cond( 1676 has_rank_zero, lambda: array_ops.constant([1]), lambda: dynamic_shape) 1677 1678 1679def _symbolic_dimension_sizes(symbolic_shape): 1680 # If len(symbolic_shape) == 0 construct a tuple 1681 if not symbolic_shape: 1682 return tuple([1]) 1683 1684 return symbolic_shape 1685 1686 1687def _has_known_value(dimension_size): 1688 not_none = dimension_size is not None 1689 try: 1690 int(dimension_size) 1691 can_be_parsed_as_int = True 1692 except (ValueError, TypeError): 1693 can_be_parsed_as_int = False 1694 return not_none and can_be_parsed_as_int 1695 1696 1697def _is_symbol_for_any_size(symbol): 1698 return symbol in [None, '.'] 1699 1700 1701_TensorDimSizes = collections.namedtuple( 1702 '_TensorDimSizes', 1703 ['x', 'unspecified_dim', 'actual_sizes', 'symbolic_sizes']) 1704 1705 1706@tf_export('debugging.assert_shapes', v1=[]) 1707@dispatch.add_dispatch_support 1708def assert_shapes_v2(shapes, data=None, summarize=None, message=None, 1709 name=None): 1710 """Assert tensor shapes and dimension size relationships between tensors. 1711 1712 This Op checks that a collection of tensors shape relationships 1713 satisfies given constraints. 

  Example:

  >>> n = 10
  >>> q = 3
  >>> d = 7
  >>> x = tf.zeros([n,q])
  >>> y = tf.ones([n,d])
  >>> param = tf.Variable([1.0, 2.0, 3.0])
  >>> scalar = 1.0
  >>> tf.debugging.assert_shapes([
  ...  (x, ('N', 'Q')),
  ...  (y, ('N', 'D')),
  ...  (param, ('Q',)),
  ...  (scalar, ()),
  ... ])

  >>> tf.debugging.assert_shapes([
  ...   (x, ('N', 'D')),
  ...   (y, ('N', 'D'))
  ... ])
  Traceback (most recent call last):
  ...
  ValueError: ...

  If `x`, `y`, `param` or `scalar` does not have a shape that satisfies
  all specified constraints, `message`, as well as the first `summarize` entries
  of the first encountered violating tensor are printed, and
  `InvalidArgumentError` is raised.

  Size entries in the specified shapes are checked against other entries by
  their __hash__, except:
    - a size entry is interpreted as an explicit size if it can be parsed as an
      integer primitive.
    - a size entry is interpreted as *any* size if it is None or '.'.

  If the first entry of a shape is `...` (type `Ellipsis`) or '*' that indicates
  a variable number of outer dimensions of unspecified size, i.e. the constraint
  applies to the inner-most dimensions only.

  Scalar tensors and specified shapes of length zero (excluding the 'inner-most'
  prefix) are both treated as having a single dimension of size one.

  Args:
    shapes: dictionary with (`Tensor` to shape) items, or a list of
      (`Tensor`, shape) tuples. A shape must be an iterable.
    data: The tensors to print out if the condition is False. Defaults to error
      message and first few entries of the violating tensor.
    summarize: Print this many entries of the tensor.
    message: A string to prefix to the default message.
    name: A name for this operation (optional). Defaults to "assert_shapes".

  Raises:
    ValueError: If static checks determine any shape constraint is violated.
  """
  # The v2 endpoint discards the op returned by `assert_shapes`.
  assert_shapes(
      shapes, data=data, summarize=summarize, message=message, name=name)


@tf_export(v1=['debugging.assert_shapes'])
@dispatch.add_dispatch_support
def assert_shapes(shapes, data=None, summarize=None, message=None, name=None):
  """Assert tensor shapes and dimension size relationships between tensors.

  This Op checks that the shape relationships of a collection of tensors
  satisfy the given constraints.

  Example:

  >>> n = 10
  >>> q = 3
  >>> d = 7
  >>> x = tf.zeros([n,q])
  >>> y = tf.ones([n,d])
  >>> param = tf.Variable([1.0, 2.0, 3.0])
  >>> scalar = 1.0
  >>> tf.debugging.assert_shapes([
  ...  (x, ('N', 'Q')),
  ...  (y, ('N', 'D')),
  ...  (param, ('Q',)),
  ...  (scalar, ()),
  ... ])

  >>> tf.debugging.assert_shapes([
  ...   (x, ('N', 'D')),
  ...   (y, ('N', 'D'))
  ... ])
  Traceback (most recent call last):
  ...
  ValueError: ...

  Example of adding a dependency to an operation:

  ```python
  with tf.control_dependencies(
      [tf.compat.v1.debugging.assert_shapes(shapes)]):
    output = tf.matmul(x, y, transpose_a=True)
  ```

  If `x`, `y`, `param` or `scalar` does not have a shape that satisfies
  all specified constraints, `message`, as well as the first `summarize` entries
  of the first encountered violating tensor are printed, and
  `InvalidArgumentError` is raised.

  Size entries in the specified shapes are checked against other entries by
  their __hash__, except:
    - a size entry is interpreted as an explicit size if it can be parsed as an
      integer primitive.
    - a size entry is interpreted as *any* size if it is None or '.'.

  If the first entry of a shape is `...` (type `Ellipsis`) or '*' that indicates
  a variable number of outer dimensions of unspecified size, i.e. the constraint
  applies to the inner-most dimensions only.

  Scalar tensors and specified shapes of length zero (excluding the 'inner-most'
  prefix) are both treated as having a single dimension of size one.

  Args:
    shapes: A list of (`Tensor`, `shape`) tuples, wherein `shape` is the
      expected shape of `Tensor`. See the example code above. The `shape` must
      be an iterable. Each element of the iterable can be either a concrete
      integer value or a string that abstractly represents the dimension.
      For example,
        - `('N', 'Q')` specifies a 2D shape wherein the first and second
          dimensions of shape may or may not be equal.
        - `('N', 'N', 'Q')` specifies a 3D shape wherein the first and second
          dimensions are equal.
        - `(1, 'N')` specifies a 2D shape wherein the first dimension is
          exactly 1 and the second dimension can be any value.
      Note that the abstract dimension letters take effect across different
      tuple elements of the list. For example,
      `tf.debugging.assert_shapes([(x, ('N', 'A')), (y, ('N', 'B'))])` asserts
      that both `x` and `y` are rank-2 tensors and their first dimensions are
      equal (`N`).
      `shape` can also be a `tf.TensorShape`.
    data: The tensors to print out if the condition is False. Defaults to error
      message and first few entries of the violating tensor.
    summarize: Print this many entries of the tensor.
    message: A string to prefix to the default message.
    name: A name for this operation (optional). Defaults to "assert_shapes".

  Returns:
    Op raising `InvalidArgumentError` unless all shape constraints are
    satisfied.
    If static checks determine all constraints are satisfied, a `no_op` is
    returned.

  Raises:
    ValueError: If static checks determine any shape constraint is violated.
1862 """ 1863 # If the user manages to assemble a dict containing tensors (possible in 1864 # Graph mode only), make sure we still accept that. 1865 if isinstance(shapes, dict): 1866 shapes = shapes.items() 1867 1868 message_prefix = _message_prefix(message) 1869 with ops.name_scope(name, 'assert_shapes', [shapes, data]): 1870 # Shape specified as None implies no constraint 1871 shape_constraints = [(x if isinstance(x, sparse_tensor.SparseTensor) else 1872 ops.convert_to_tensor(x), s) 1873 for x, s in shapes if s is not None] 1874 1875 executing_eagerly = context.executing_eagerly() 1876 1877 def tensor_name(x): 1878 if executing_eagerly or isinstance(x, sparse_tensor.SparseTensor): 1879 return _shape_and_dtype_str(x) 1880 return x.name 1881 1882 tensor_dim_sizes = [] 1883 for tensor, symbolic_shape in shape_constraints: 1884 is_iterable = ( 1885 hasattr(symbolic_shape, '__iter__') or 1886 hasattr(symbolic_shape, '__getitem__') # For Python 2 compat. 1887 ) 1888 if not is_iterable: 1889 raise ValueError( 1890 '%s' 1891 'Tensor %s. Specified shape must be an iterable. ' 1892 'An iterable has the attribute `__iter__` or `__getitem__`. ' 1893 'Received specified shape: %s' % 1894 (message_prefix, tensor_name(tensor), symbolic_shape)) 1895 1896 # We convert this into a tuple to handle strings, lists and numpy arrays 1897 symbolic_shape_tuple = tuple(symbolic_shape) 1898 1899 tensors_specified_innermost = False 1900 for i, symbol in enumerate(symbolic_shape_tuple): 1901 if symbol not in [Ellipsis, '*']: 1902 continue 1903 1904 if i != 0: 1905 raise ValueError( 1906 '%s' 1907 'Tensor %s specified shape index %d. 
' 1908 'Symbol `...` or `*` for a variable number of ' 1909 'unspecified dimensions is only allowed as the first entry' % 1910 (message_prefix, tensor_name(tensor), i)) 1911 1912 tensors_specified_innermost = True 1913 1914 # Only include the size of the specified dimensions since the 0th symbol 1915 # is either ellipsis or * 1916 tensor_dim_sizes.append( 1917 _TensorDimSizes( 1918 tensor, tensors_specified_innermost, _dimension_sizes(tensor), 1919 _symbolic_dimension_sizes( 1920 symbolic_shape_tuple[1:] 1921 if tensors_specified_innermost else symbolic_shape_tuple))) 1922 1923 rank_assertions = [] 1924 for sizes in tensor_dim_sizes: 1925 rank = len(sizes.symbolic_sizes) 1926 rank_zero_or_one = rank in [0, 1] 1927 if sizes.unspecified_dim: 1928 if rank_zero_or_one: 1929 # No assertion of rank needed as `x` only need to have rank at least 1930 # 0. See elif rank_zero_or_one case comment. 1931 continue 1932 assertion = assert_rank_at_least( 1933 x=sizes.x, 1934 rank=rank, 1935 data=data, 1936 summarize=summarize, 1937 message=message, 1938 name=name) 1939 elif rank_zero_or_one: 1940 # Rank 0 is treated as rank 1 size 1, i.e. there is 1941 # no distinction between the two in terms of rank. 1942 # See _dimension_sizes. 
1943 assertion = assert_rank_in( 1944 x=sizes.x, 1945 ranks=[0, 1], 1946 data=data, 1947 summarize=summarize, 1948 message=message, 1949 name=name) 1950 else: 1951 assertion = assert_rank( 1952 x=sizes.x, 1953 rank=rank, 1954 data=data, 1955 summarize=summarize, 1956 message=message, 1957 name=name) 1958 rank_assertions.append(assertion) 1959 1960 size_assertions = [] 1961 size_specifications = {} 1962 for sizes in tensor_dim_sizes: 1963 for i, size_symbol in enumerate(sizes.symbolic_sizes): 1964 1965 if _is_symbol_for_any_size(size_symbol): 1966 # Size specified as any implies no constraint 1967 continue 1968 1969 if sizes.unspecified_dim: 1970 tensor_dim = i - len(sizes.symbolic_sizes) 1971 else: 1972 tensor_dim = i 1973 1974 if size_symbol in size_specifications or _has_known_value(size_symbol): 1975 if _has_known_value(size_symbol): 1976 specified_size = int(size_symbol) 1977 size_check_message = 'Specified explicitly' 1978 else: 1979 specified_size, specified_by_y, specified_at_dim = ( 1980 size_specifications[size_symbol]) 1981 size_check_message = ( 1982 'Specified by tensor %s dimension %d' % 1983 (tensor_name(specified_by_y), specified_at_dim)) 1984 1985 # This is extremely subtle. If actual_sizes is dynamic, we must 1986 # make sure a control dependency is inserted here so that this slice 1987 # can not execute until the rank is asserted to be enough for the 1988 # slice to not fail. 1989 with ops.control_dependencies(rank_assertions): 1990 actual_size = sizes.actual_sizes[tensor_dim] 1991 if _has_known_value(actual_size) and _has_known_value(specified_size): 1992 if int(actual_size) != int(specified_size): 1993 raise ValueError( 1994 '%s%s. Tensor %s dimension %s must have size %d. 
' 1995 'Received size %d, shape %s' % 1996 (message_prefix, size_check_message, tensor_name(sizes.x), 1997 tensor_dim, specified_size, actual_size, 1998 sizes.x.get_shape())) 1999 # No dynamic assertion needed 2000 continue 2001 2002 condition = math_ops.equal( 2003 ops.convert_to_tensor(actual_size), 2004 ops.convert_to_tensor(specified_size)) 2005 data_ = data 2006 if data is None: 2007 data_ = [ 2008 message_prefix, size_check_message, 2009 'Tensor %s dimension' % tensor_name(sizes.x), tensor_dim, 2010 'must have size', specified_size, 'Received shape: ', 2011 array_ops.shape(sizes.x) 2012 ] 2013 size_assertions.append( 2014 control_flow_ops.Assert(condition, data_, summarize=summarize)) 2015 else: 2016 # Not sure if actual_sizes is a constant, but for safety, guard 2017 # on rank. See explanation above about actual_sizes need for safety. 2018 with ops.control_dependencies(rank_assertions): 2019 size = sizes.actual_sizes[tensor_dim] 2020 size_specifications[size_symbol] = (size, sizes.x, tensor_dim) 2021 2022 # Ensure both assertions actually occur. 2023 with ops.control_dependencies(rank_assertions): 2024 shapes_assertion = control_flow_ops.group(size_assertions) 2025 2026 return shapes_assertion 2027 2028 2029# pylint: disable=line-too-long 2030def _get_diff_for_monotonic_comparison(x): 2031 """Gets the difference x[1:] - x[:-1].""" 2032 x = array_ops.reshape(x, [-1]) 2033 if not is_numeric_tensor(x): 2034 raise TypeError('Expected x to be numeric, instead found: %s' % x) 2035 2036 # If x has less than 2 elements, there is nothing to compare. So return []. 
2037 is_shorter_than_two = math_ops.less(array_ops.size(x), 2) 2038 short_result = lambda: ops.convert_to_tensor([], dtype=x.dtype) 2039 2040 # With 2 or more elements, return x[1:] - x[:-1] 2041 s_len = array_ops.shape(x) - 1 2042 diff = lambda: array_ops.strided_slice(x, [1], [1] + s_len)- array_ops.strided_slice(x, [0], s_len) 2043 return control_flow_ops.cond(is_shorter_than_two, short_result, diff) 2044 2045 2046@tf_export( 2047 'debugging.is_numeric_tensor', 2048 v1=['debugging.is_numeric_tensor', 'is_numeric_tensor']) 2049@deprecation.deprecated_endpoints('is_numeric_tensor') 2050def is_numeric_tensor(tensor): 2051 """Returns `True` if the elements of `tensor` are numbers. 2052 2053 Specifically, returns `True` if the dtype of `tensor` is one of the following: 2054 2055 * `tf.float16` 2056 * `tf.float32` 2057 * `tf.float64` 2058 * `tf.int8` 2059 * `tf.int16` 2060 * `tf.int32` 2061 * `tf.int64` 2062 * `tf.uint8` 2063 * `tf.uint16` 2064 * `tf.uint32` 2065 * `tf.uint64` 2066 * `tf.qint8` 2067 * `tf.qint16` 2068 * `tf.qint32` 2069 * `tf.quint8` 2070 * `tf.quint16` 2071 * `tf.complex64` 2072 * `tf.complex128` 2073 * `tf.bfloat16` 2074 2075 Returns `False` if `tensor` is of a non-numeric type or if `tensor` is not 2076 a `tf.Tensor` object. 2077 """ 2078 return isinstance(tensor, ops.Tensor) and tensor.dtype in NUMERIC_TYPES 2079 2080 2081@tf_export( 2082 'math.is_non_decreasing', 2083 v1=[ 2084 'math.is_non_decreasing', 'debugging.is_non_decreasing', 2085 'is_non_decreasing' 2086 ]) 2087@dispatch.add_dispatch_support 2088@deprecation.deprecated_endpoints('debugging.is_non_decreasing', 2089 'is_non_decreasing') 2090def is_non_decreasing(x, name=None): 2091 """Returns `True` if `x` is non-decreasing. 2092 2093 Elements of `x` are compared in row-major order. The tensor `[x[0],...]` 2094 is non-decreasing if for every adjacent pair we have `x[i] <= x[i+1]`. 2095 If `x` has less than two elements, it is trivially non-decreasing. 
2096 2097 See also: `is_strictly_increasing` 2098 2099 >>> x1 = tf.constant([1.0, 1.0, 3.0]) 2100 >>> tf.math.is_non_decreasing(x1) 2101 <tf.Tensor: shape=(), dtype=bool, numpy=True> 2102 >>> x2 = tf.constant([3.0, 1.0, 2.0]) 2103 >>> tf.math.is_non_decreasing(x2) 2104 <tf.Tensor: shape=(), dtype=bool, numpy=False> 2105 2106 Args: 2107 x: Numeric `Tensor`. 2108 name: A name for this operation (optional). Defaults to "is_non_decreasing" 2109 2110 Returns: 2111 Boolean `Tensor`, equal to `True` iff `x` is non-decreasing. 2112 2113 Raises: 2114 TypeError: if `x` is not a numeric tensor. 2115 """ 2116 with ops.name_scope(name, 'is_non_decreasing', [x]): 2117 diff = _get_diff_for_monotonic_comparison(x) 2118 # When len(x) = 1, diff = [], less_equal = [], and reduce_all([]) = True. 2119 zero = ops.convert_to_tensor(0, dtype=diff.dtype) 2120 return math_ops.reduce_all(math_ops.less_equal(zero, diff)) 2121 2122 2123@tf_export( 2124 'math.is_strictly_increasing', 2125 v1=[ 2126 'math.is_strictly_increasing', 'debugging.is_strictly_increasing', 2127 'is_strictly_increasing' 2128 ]) 2129@dispatch.add_dispatch_support 2130@deprecation.deprecated_endpoints('debugging.is_strictly_increasing', 2131 'is_strictly_increasing') 2132def is_strictly_increasing(x, name=None): 2133 """Returns `True` if `x` is strictly increasing. 2134 2135 Elements of `x` are compared in row-major order. The tensor `[x[0],...]` 2136 is strictly increasing if for every adjacent pair we have `x[i] < x[i+1]`. 2137 If `x` has less than two elements, it is trivially strictly increasing. 2138 2139 See also: `is_non_decreasing` 2140 2141 >>> x1 = tf.constant([1.0, 2.0, 3.0]) 2142 >>> tf.math.is_strictly_increasing(x1) 2143 <tf.Tensor: shape=(), dtype=bool, numpy=True> 2144 >>> x2 = tf.constant([3.0, 1.0, 2.0]) 2145 >>> tf.math.is_strictly_increasing(x2) 2146 <tf.Tensor: shape=(), dtype=bool, numpy=False> 2147 2148 Args: 2149 x: Numeric `Tensor`. 2150 name: A name for this operation (optional). 
2151 Defaults to "is_strictly_increasing" 2152 2153 Returns: 2154 Boolean `Tensor`, equal to `True` iff `x` is strictly increasing. 2155 2156 Raises: 2157 TypeError: if `x` is not a numeric tensor. 2158 """ 2159 with ops.name_scope(name, 'is_strictly_increasing', [x]): 2160 diff = _get_diff_for_monotonic_comparison(x) 2161 # When len(x) = 1, diff = [], less = [], and reduce_all([]) = True. 2162 zero = ops.convert_to_tensor(0, dtype=diff.dtype) 2163 return math_ops.reduce_all(math_ops.less(zero, diff)) 2164 2165 2166def _assert_same_base_type(items, expected_type=None): 2167 r"""Asserts all items are of the same base type. 2168 2169 Args: 2170 items: List of graph items (e.g., `Variable`, `Tensor`, `SparseTensor`, 2171 `Operation`, or `IndexedSlices`). Can include `None` elements, which 2172 will be ignored. 2173 expected_type: Expected type. If not specified, assert all items are 2174 of the same base type. 2175 2176 Returns: 2177 Validated type, or none if neither expected_type nor items provided. 2178 2179 Raises: 2180 ValueError: If any types do not match. 2181 """ 2182 original_expected_type = expected_type 2183 mismatch = False 2184 for item in items: 2185 if item is not None: 2186 item_type = item.dtype.base_dtype 2187 if not expected_type: 2188 expected_type = item_type 2189 elif expected_type != item_type: 2190 mismatch = True 2191 break 2192 if mismatch: 2193 # Loop back through and build up an informative error message (this is very 2194 # slow, so we don't do it unless we found an error above). 2195 expected_type = original_expected_type 2196 original_item_str = None 2197 for item in items: 2198 if item is not None: 2199 item_type = item.dtype.base_dtype 2200 if not expected_type: 2201 expected_type = item_type 2202 original_item_str = item.name if hasattr(item, 'name') else str(item) 2203 elif expected_type != item_type: 2204 raise ValueError('%s, type=%s, must be of the same type (%s)%s.' 
% ( 2205 item.name if hasattr(item, 'name') else str(item), 2206 item_type, expected_type, 2207 (' as %s' % original_item_str) if original_item_str else '')) 2208 return expected_type # Should be unreachable 2209 else: 2210 return expected_type 2211 2212 2213@tf_export( 2214 'debugging.assert_same_float_dtype', 2215 v1=['debugging.assert_same_float_dtype', 'assert_same_float_dtype']) 2216@dispatch.add_dispatch_support 2217@deprecation.deprecated_endpoints('assert_same_float_dtype') 2218def assert_same_float_dtype(tensors=None, dtype=None): 2219 """Validate and return float type based on `tensors` and `dtype`. 2220 2221 For ops such as matrix multiplication, inputs and weights must be of the 2222 same float type. This function validates that all `tensors` are the same type, 2223 validates that type is `dtype` (if supplied), and returns the type. Type must 2224 be a floating point type. If neither `tensors` nor `dtype` is supplied, 2225 the function will return `dtypes.float32`. 2226 2227 Args: 2228 tensors: Tensors of input values. Can include `None` elements, which will be 2229 ignored. 2230 dtype: Expected type. 2231 2232 Returns: 2233 Validated type. 2234 2235 Raises: 2236 ValueError: if neither `tensors` nor `dtype` is supplied, or result is not 2237 float, or the common type of the inputs is not a floating point type. 2238 """ 2239 if tensors: 2240 dtype = _assert_same_base_type(tensors, dtype) 2241 if not dtype: 2242 dtype = dtypes.float32 2243 elif not dtype.is_floating: 2244 raise ValueError('Expected floating point type, got %s.' % dtype) 2245 return dtype 2246 2247 2248@tf_export('debugging.assert_scalar', v1=[]) 2249@dispatch.add_dispatch_support 2250def assert_scalar_v2(tensor, message=None, name=None): 2251 """Asserts that the given `tensor` is a scalar. 2252 2253 This function raises `ValueError` unless it can be certain that the given 2254 `tensor` is a scalar. `ValueError` is also raised if the shape of `tensor` is 2255 unknown. 
2256 2257 This is always checked statically, so this method returns nothing. 2258 2259 Args: 2260 tensor: A `Tensor`. 2261 message: A string to prefix to the default message. 2262 name: A name for this operation. Defaults to "assert_scalar" 2263 2264 Raises: 2265 ValueError: If the tensor is not scalar (rank 0), or if its shape is 2266 unknown. 2267 """ 2268 assert_scalar(tensor=tensor, message=message, name=name) 2269 2270 2271@tf_export(v1=['debugging.assert_scalar', 'assert_scalar']) 2272@dispatch.add_dispatch_support 2273@deprecation.deprecated_endpoints('assert_scalar') 2274def assert_scalar(tensor, name=None, message=None): 2275 """Asserts that the given `tensor` is a scalar (i.e. zero-dimensional). 2276 2277 This function raises `ValueError` unless it can be certain that the given 2278 `tensor` is a scalar. `ValueError` is also raised if the shape of `tensor` is 2279 unknown. 2280 2281 Args: 2282 tensor: A `Tensor`. 2283 name: A name for this operation. Defaults to "assert_scalar" 2284 message: A string to prefix to the default message. 2285 2286 Returns: 2287 The input tensor (potentially converted to a `Tensor`). 2288 2289 Raises: 2290 ValueError: If the tensor is not scalar (rank 0), or if its shape is 2291 unknown. 2292 """ 2293 with ops.name_scope(name, 'assert_scalar', [tensor]) as name_scope: 2294 tensor = ops.convert_to_tensor(tensor, name=name_scope) 2295 shape = tensor.get_shape() 2296 message = _message_prefix(message) 2297 if shape.ndims != 0: 2298 if context.executing_eagerly(): 2299 raise ValueError('%sExpected scalar shape, saw shape: %s.' 2300 % (message, shape,)) 2301 else: 2302 raise ValueError('%sExpected scalar shape for %s, saw shape: %s.' 2303 % (message, tensor.name, shape)) 2304 return tensor 2305 2306 2307def _message_prefix(message): 2308 if message: 2309 return '%s. 
' % message 2310 return '' 2311 2312 2313@tf_export('ensure_shape') 2314@dispatch.add_dispatch_support 2315def ensure_shape(x, shape, name=None): 2316 """Updates the shape of a tensor and checks at runtime that the shape holds. 2317 2318 When executed, this operation asserts that the input tensor `x`'s shape 2319 is compatible with the `shape` argument. 2320 See `tf.TensorShape.is_compatible_with` for details. 2321 2322 >>> x = tf.constant([[1, 2, 3], 2323 ... [4, 5, 6]]) 2324 >>> x = tf.ensure_shape(x, [2, 3]) 2325 2326 Use `None` for unknown dimensions: 2327 2328 >>> x = tf.ensure_shape(x, [None, 3]) 2329 >>> x = tf.ensure_shape(x, [2, None]) 2330 2331 If the tensor's shape is not compatible with the `shape` argument, an error 2332 is raised: 2333 2334 >>> x = tf.ensure_shape(x, [5]) 2335 Traceback (most recent call last): 2336 ... 2337 tf.errors.InvalidArgumentError: Shape of tensor dummy_input [3] is not 2338 compatible with expected shape [5]. [Op:EnsureShape] 2339 2340 During graph construction (typically tracing a `tf.function`), 2341 `tf.ensure_shape` updates the static-shape of the **result** tensor by 2342 merging the two shapes. See `tf.TensorShape.merge_with` for details. 2343 2344 This is most useful when **you** know a shape that can't be determined 2345 statically by TensorFlow. 2346 2347 The following trivial `tf.function` prints the input tensor's 2348 static-shape before and after `ensure_shape` is applied. 2349 2350 >>> @tf.function 2351 ... def f(tensor): 2352 ... print("Static-shape before:", tensor.shape) 2353 ... tensor = tf.ensure_shape(tensor, [None, 3]) 2354 ... print("Static-shape after:", tensor.shape) 2355 ... 
return tensor 2356 2357 This lets you see the effect of `tf.ensure_shape` when the function is traced: 2358 >>> cf = f.get_concrete_function(tf.TensorSpec([None, None])) 2359 Static-shape before: (None, None) 2360 Static-shape after: (None, 3) 2361 2362 >>> cf(tf.zeros([3, 3])) # Passes 2363 >>> cf(tf.constant([1, 2, 3])) # fails 2364 Traceback (most recent call last): 2365 ... 2366 InvalidArgumentError: Shape of tensor x [3] is not compatible with expected shape [3,3]. 2367 2368 The above example raises `tf.errors.InvalidArgumentError`, because `x`'s 2369 shape, `(3,)`, is not compatible with the `shape` argument, `(None, 3)` 2370 2371 Inside a `tf.function` or `v1.Graph` context it checks both the buildtime and 2372 runtime shapes. This is stricter than `tf.Tensor.set_shape` which only 2373 checks the buildtime shape. 2374 2375 Note: This differs from `tf.Tensor.set_shape` in that it sets the static shape 2376 of the resulting tensor and enforces it at runtime, raising an error if the 2377 tensor's runtime shape is incompatible with the specified shape. 2378 `tf.Tensor.set_shape` sets the static shape of the tensor without enforcing it 2379 at runtime, which may result in inconsistencies between the statically-known 2380 shape of tensors and the runtime value of tensors. 2381 2382 For example, of loading images of a known size: 2383 2384 >>> @tf.function 2385 ... def decode_image(png): 2386 ... image = tf.image.decode_png(png, channels=3) 2387 ... # the `print` executes during tracing. 2388 ... print("Initial shape: ", image.shape) 2389 ... image = tf.ensure_shape(image,[28, 28, 3]) 2390 ... print("Final shape: ", image.shape) 2391 ... return image 2392 2393 When tracing a function, no ops are being executed, shapes may be unknown. 2394 See the [Concrete Functions Guide](https://www.tensorflow.org/guide/concrete_function) 2395 for details. 2396 2397 >>> concrete_decode = decode_image.get_concrete_function( 2398 ... 
tf.TensorSpec([], dtype=tf.string)) 2399 Initial shape: (None, None, 3) 2400 Final shape: (28, 28, 3) 2401 2402 >>> image = tf.random.uniform(maxval=255, shape=[28, 28, 3], dtype=tf.int32) 2403 >>> image = tf.cast(image,tf.uint8) 2404 >>> png = tf.image.encode_png(image) 2405 >>> image2 = concrete_decode(png) 2406 >>> print(image2.shape) 2407 (28, 28, 3) 2408 2409 >>> image = tf.concat([image,image], axis=0) 2410 >>> print(image.shape) 2411 (56, 28, 3) 2412 >>> png = tf.image.encode_png(image) 2413 >>> image2 = concrete_decode(png) 2414 Traceback (most recent call last): 2415 ... 2416 tf.errors.InvalidArgumentError: Shape of tensor DecodePng [56,28,3] is not 2417 compatible with expected shape [28,28,3]. 2418 2419 Caution: if you don't use the result of `tf.ensure_shape` the check may not 2420 run. 2421 2422 >>> @tf.function 2423 ... def bad_decode_image(png): 2424 ... image = tf.image.decode_png(png, channels=3) 2425 ... # the `print` executes during tracing. 2426 ... print("Initial shape: ", image.shape) 2427 ... # BAD: forgot to use the returned tensor. 2428 ... tf.ensure_shape(image,[28, 28, 3]) 2429 ... print("Final shape: ", image.shape) 2430 ... return image 2431 2432 >>> image = bad_decode_image(png) 2433 Initial shape: (None, None, 3) 2434 Final shape: (None, None, 3) 2435 >>> print(image.shape) 2436 (56, 28, 3) 2437 2438 Args: 2439 x: A `Tensor`. 2440 shape: A `TensorShape` representing the shape of this tensor, a 2441 `TensorShapeProto`, a list, a tuple, or None. 2442 name: A name for this operation (optional). Defaults to "EnsureShape". 2443 2444 Returns: 2445 A `Tensor`. Has the same type and contents as `x`. 2446 2447 Raises: 2448 tf.errors.InvalidArgumentError: If `shape` is incompatible with the shape 2449 of `x`. 
2450 """ 2451 if not isinstance(shape, tensor_shape.TensorShape): 2452 shape = tensor_shape.TensorShape(shape) 2453 2454 return array_ops.ensure_shape(x, shape, name=name) 2455 2456 2457@ops.RegisterGradient('EnsureShape') 2458def _ensure_shape_grad(op, grad): 2459 del op # Unused. 2460 return grad 2461