xref: /aosp_15_r20/external/tensorflow/tensorflow/python/ops/check_ops.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15# pylint: disable=g-short-docstring-punctuation
16"""Asserts and Boolean Checks."""
17
18import collections
19
20import numpy as np
21
22from tensorflow.python.eager import context
23from tensorflow.python.framework import dtypes
24from tensorflow.python.framework import errors
25from tensorflow.python.framework import ops
26from tensorflow.python.framework import sparse_tensor
27from tensorflow.python.framework import tensor_shape
28from tensorflow.python.framework import tensor_util
29from tensorflow.python.ops import array_ops
30from tensorflow.python.ops import control_flow_ops
31from tensorflow.python.ops import math_ops
32from tensorflow.python.util import compat
33from tensorflow.python.util import deprecation
34from tensorflow.python.util import dispatch
35from tensorflow.python.util.tf_export import tf_export
36
# The set of dtypes this module considers numeric, grouped by family.
NUMERIC_TYPES = frozenset((
    # Floating point.
    dtypes.float16, dtypes.bfloat16, dtypes.float32, dtypes.float64,
    # Signed and unsigned integers.
    dtypes.int8, dtypes.int16, dtypes.int32, dtypes.int64,
    dtypes.uint8, dtypes.uint16, dtypes.uint32, dtypes.uint64,
    # Quantized integers.
    dtypes.qint8, dtypes.qint16, dtypes.qint32,
    dtypes.quint8, dtypes.quint16,
    # Complex.
    dtypes.complex64, dtypes.complex128,
))

# Public names exported by this module (order is significant only for
# readability; it mirrors the order the checks are defined in).
__all__ = [
    'assert_negative', 'assert_positive', 'assert_proper_iterable',
    'assert_non_negative', 'assert_non_positive',
    'assert_equal', 'assert_none_equal', 'assert_near', 'assert_integer',
    'assert_less', 'assert_less_equal',
    'assert_greater', 'assert_greater_equal',
    'assert_rank', 'assert_rank_at_least', 'assert_rank_in',
    'assert_same_float_dtype', 'assert_scalar', 'assert_type',
    'assert_shapes',
    'is_non_decreasing', 'is_numeric_tensor', 'is_strictly_increasing',
]
69
70
def _maybe_constant_value_string(t):
  """Returns a printable form of `t`, preferring its statically known value.

  Args:
    t: Any object; typically a string or a `Tensor`.

  Returns:
    `str(t)` for non-Tensors; `str` of the constant value for Tensors whose
    value `tensor_util.constant_value` can determine; otherwise the Tensor
    `t` itself (NOT a string).
  """
  if isinstance(t, ops.Tensor):
    static_value = tensor_util.constant_value(t)
    return t if static_value is None else str(static_value)
  return str(t)
78
79
def _assert_static(condition, data):
  """Raises an InvalidArgumentError with as much information as possible.

  Args:
    condition: Statically evaluated result of the check; truthy means the
      check holds.
    data: Items (strings, Tensors, ...) describing the failure.

  Raises:
    InvalidArgumentError: if `condition` is falsy. The message is the
      newline-joined string form of every item in `data`, using the constant
      value of any Tensor whose value is statically known.
  """
  if not condition:
    # _maybe_constant_value_string returns the Tensor itself when its value is
    # not statically known; coerce every item to str so that str.join does not
    # raise TypeError and mask the real assertion failure.
    data_static = [str(_maybe_constant_value_string(x)) for x in data]
    raise errors.InvalidArgumentError(node_def=None, op=None,
                                      message='\n'.join(data_static))
86
87
88def _shape_and_dtype_str(tensor):
89  """Returns a string containing tensor's shape and dtype."""
90  return 'shape=%s dtype=%s' % (tensor.shape, tensor.dtype.name)
91
92
93def _unary_assert_doc(sym, sym_name):
94  """Common docstring for assert_* ops that evaluate a unary predicate over every element of a tensor.
95
96  Args:
97    sym: Mathematical symbol for the check performed on each element, i.e. "> 0"
98    sym_name: English-language name for the op described by sym
99
100  Returns:
101    Decorator that adds the appropriate docstring to the function for symbol
102    `sym`.
103  """
104
105  def _decorator(func):
106    """Generated decorator that adds the appropriate docstring to the function for symbol `sym`.
107
108    Args:
109      func: Function for a TensorFlow op
110
111    Returns:
112      Version of `func` with documentation attached.
113    """
114    opname = func.__name__
115    cap_sym_name = sym_name.capitalize()
116
117    func.__doc__ = """
118    Assert the condition `x {sym}` holds element-wise.
119
120    When running in graph mode, you should add a dependency on this operation
121    to ensure that it runs. Example of adding a dependency to an operation:
122
123    ```python
124    with tf.control_dependencies([tf.debugging.{opname}(x, y)]):
125      output = tf.reduce_sum(x)
126    ```
127
128    {sym_name} means, for every element `x[i]` of `x`, we have `x[i] {sym}`.
129    If `x` is empty this is trivially satisfied.
130
131    Args:
132      x:  Numeric `Tensor`.
133      data:  The tensors to print out if the condition is False.  Defaults to
134        error message and first few entries of `x`.
135      summarize: Print this many entries of each tensor.
136      message: A string to prefix to the default message.
137      name: A name for this operation (optional).  Defaults to "{opname}".
138
139    Returns:
140      Op that raises `InvalidArgumentError` if `x {sym}` is False.
141      @compatibility(eager)
142        returns None
143      @end_compatibility
144
145    Raises:
146      InvalidArgumentError: if the check can be performed immediately and
147        `x {sym}` is False. The check can be performed immediately during
148        eager execution or if `x` is statically known.
149    """.format(
150        sym=sym, sym_name=cap_sym_name, opname=opname)
151    return func
152
153  return _decorator
154
155
156def _binary_assert_doc(sym, test_var):
157  """Common docstring for most of the v1 assert_* ops that compare two tensors element-wise.
158
159  Args:
160    sym: Binary operation symbol, i.e. "=="
161    test_var: a string that represents the variable in the right-hand side of
162      binary operator of the test case
163
164  Returns:
165    Decorator that adds the appropriate docstring to the function for
166  symbol `sym`.
167  """
168
169  def _decorator(func):
170    """Generated decorator that adds the appropriate docstring to the function for symbol `sym`.
171
172    Args:
173      func: Function for a TensorFlow op
174
175    Returns:
176      A version of `func` with documentation attached.
177    """
178    opname = func.__name__
179
180    func.__doc__ = """
181    Assert the condition `x {sym} y` holds element-wise.
182
183    This condition holds if for every pair of (possibly broadcast) elements
184    `x[i]`, `y[i]`, we have `x[i] {sym} y[i]`.
185    If both `x` and `y` are empty, this is trivially satisfied.
186
187    When running in graph mode, you should add a dependency on this operation
188    to ensure that it runs. Example of adding a dependency to an operation:
189
190    ```python
191    with tf.control_dependencies([tf.compat.v1.{opname}(x, y)]):
192      output = tf.reduce_sum(x)
193    ```
194
195    Args:
196      x:  Numeric `Tensor`.
197      y:  Numeric `Tensor`, same dtype as and broadcastable to `x`.
198      data:  The tensors to print out if the condition is False.  Defaults to
199        error message and first few entries of `x`, `y`.
200      summarize: Print this many entries of each tensor.
201      message: A string to prefix to the default message.
202      name: A name for this operation (optional).  Defaults to "{opname}".
203
204    Returns:
205      Op that raises `InvalidArgumentError` if `x {sym} y` is False.
206
207    Raises:
208      InvalidArgumentError: if the check can be performed immediately and
209        `x {sym} y` is False. The check can be performed immediately during
210        eager execution or if `x` and `y` are statically known.
211
212    @compatibility(TF2)
213    `tf.compat.v1.{opname}` is compatible with eager execution and
214    `tf.function`.
215    Please use `tf.debugging.{opname}` instead when migrating to TF2. Apart
216    from `data`, all arguments are supported with the same argument name.
217
218    If you want to ensure the assert statements run before the
219    potentially-invalid computation, please use `tf.control_dependencies`,
220    as tf.function auto-control dependencies are insufficient for assert
221    statements.
222
223    #### Structural Mapping to Native TF2
224
225    Before:
226
227    ```python
228    tf.compat.v1.{opname}(
229      x=x, y=y, data=data, summarize=summarize,
230      message=message, name=name)
231    ```
232
233    After:
234
235    ```python
236    tf.debugging.{opname}(
237      x=x, y=y, message=message,
238      summarize=summarize, name=name)
239    ```
240
241    #### TF1 & TF2 Usage Example
242
243    TF1:
244
245    >>> g = tf.Graph()
246    >>> with g.as_default():
247    ...   a = tf.compat.v1.placeholder(tf.float32, [2])
248    ...   b = tf.compat.v1.placeholder(tf.float32, [2])
249    ...   result = tf.compat.v1.{opname}(a, b,
250    ...     message='"a {sym} b" does not hold for the given inputs')
251    ...   with tf.compat.v1.control_dependencies([result]):
252    ...     sum_node = a + b
253    >>> sess = tf.compat.v1.Session(graph=g)
254    >>> val = sess.run(sum_node, feed_dict={{a: [1, 2], b:{test_var}}})
255
256
257    TF2:
258
259    >>> a = tf.Variable([1, 2], dtype=tf.float32)
260    >>> b = tf.Variable({test_var}, dtype=tf.float32)
261    >>> assert_op = tf.debugging.{opname}(a, b, message=
262    ...   '"a {sym} b" does not hold for the given inputs')
263    >>> # When working with tf.control_dependencies
264    >>> with tf.control_dependencies([assert_op]):
265    ...   val = a + b
266
267    @end_compatibility
268    """.format(
269        sym=sym, opname=opname, test_var=test_var)
270    return func
271
272  return _decorator
273
274
def _make_assert_msg_data(sym, x, y, summarize, test_op):
  """Builds the pieces of the default eager-mode failure message for _binary_assert.

  Args:
    sym: Mathematical symbol of the element-wise test, e.g. "==".
    x: First assertion input, already passed through `convert_to_tensor()`.
    y: Second assertion input.
    summarize: Maximum number of elements of each tensor to include; when it
      is not positive only the header line is produced.
    test_op: Boolean tensor that is True wherever the assertion holds.

  Returns:
    A list of strings and numpy arrays which, once stringified and joined,
    form the error message.
  """
  pieces = ['Condition x %s y did not hold.' % sym]
  if summarize <= 0:
    return pieces

  # Report the failing positions only when x and y share the same non-scalar
  # shape; with differing (broadcast) shapes the indices would be more
  # confusing than helpful.
  if x.shape == y.shape and x.shape.as_list():
    failed = math_ops.logical_not(test_op)
    failed_indices = array_ops.where(failed).numpy()
    failed_x = array_ops.boolean_mask(x, failed)
    failed_y = array_ops.boolean_mask(y, failed)
    count = min(summarize, failed_indices.shape[0])
    pieces.extend([
        'Indices of first %d different values:' % count,
        failed_indices[:count],
        'Corresponding x values:',
        failed_x.numpy().reshape((-1,))[:count],
        'Corresponding y values:',
        failed_y.numpy().reshape((-1,))[:count],
    ])

  # reshape((-1,)) gives a flat view cheaply.
  flat_x = x.numpy().reshape((-1,))
  flat_y = y.numpy().reshape((-1,))
  shown_x = min(flat_x.size, summarize)
  shown_y = min(flat_y.size, summarize)
  pieces.extend([
      'First %d elements of x:' % shown_x,
      flat_x[:shown_x],
      'First %d elements of y:' % shown_y,
      flat_y[:shown_y],
  ])
  return pieces
327
328
def _pretty_print(data_item, summarize):
  """Renders one entry of an assert_* `data` list for an eager error message.

  Args:
    data_item: A Tensor or any other value from the `data` argument.
    summarize: Maximum number of elements to show for a Tensor entry.

  Returns:
    `str(data_item)` for non-Tensors and zero-dimensional Tensors; for other
    Tensors, the string form of a list of the first `summarize` flattened
    elements (as strings), with '...' appended when elements were dropped.
  """
  if not isinstance(data_item, ops.Tensor):
    return str(data_item)

  value = data_item.numpy()
  # Tensor.numpy() yields a numpy scalar for zero-dimensional tensors.
  if np.isscalar(value):
    return str(value)

  flattened = value.reshape((-1,))
  shown = list(map(str, flattened[:summarize]))
  if flattened.size > len(shown):
    shown.append('...')
  return str(shown)
353
354
def _binary_assert(sym, opname, op_func, static_func, x, y, data, summarize,
                   message, name):
  """Generic binary elementwise assertion.

  Implements the behavior described in _binary_assert_doc() above.

  Args:
    sym: Mathematical symbol for the test to apply to pairs of tensor elements,
      e.g. "==".
    opname: Name of the assert op in the public API, e.g. "assert_equal".
    op_func: Function that, given the two Tensor inputs to the assertion (x
      and y), returns a Boolean tensor that is True where the assertion holds;
      its result is reduced with reduce_all(). E.g. math_ops.equal for
      assert_equal().
    static_func: Function that, if passed numpy ndarray versions of the two
      inputs to the assertion, will return a Boolean ndarray containing
      True in all positions where the assertion PASSES,
      e.g. np.equal for assert_equal().
    x:  Numeric `Tensor`.
    y:  Numeric `Tensor`, same dtype as and broadcastable to `x`.
    data:  The tensors to print out if the condition is False.  Defaults to
      error message and first few entries of `x`, `y`.
    summarize: Print this many entries of each tensor.
    message: A string to prefix to the default message.
    name: A name for this operation (optional).  Defaults to the value of
      `opname`.

  Returns:
    None when executing eagerly and the assertion holds; otherwise (graph
    mode) the `control_flow_ops.Assert` op. See docstring template in
    _binary_assert_doc().

  Raises:
    InvalidArgumentError: when executing eagerly and the assertion fails, or
      in graph mode when both inputs are statically known and the assertion
      fails.
  """
  with ops.name_scope(name, opname, [x, y, data]):
    x = ops.convert_to_tensor(x, name='x')
    y = ops.convert_to_tensor(y, name='y')

    if context.executing_eagerly():
      test_op = op_func(x, y)
      condition = math_ops.reduce_all(test_op)
      if condition:
        return

      # If we get here, the assertion has failed.
      # Default to printing 3 elements like control_flow_ops.Assert (used
      # by graph mode) does. Also treat negative values as "print
      # everything" for consistency with Tensor::SummarizeValue().
      if summarize is None:
        summarize = 3
      elif summarize < 0:
        # Must be an int: _pretty_print slices numpy arrays with it, and a
        # float slice bound raises TypeError. Code below clamps this to the
        # exact sizes of x and y.
        summarize = int(1e9)

      if data is None:
        data = _make_assert_msg_data(sym, x, y, summarize, test_op)

      if message is not None:
        data = [message] + list(data)

      raise errors.InvalidArgumentError(
          node_def=None,
          op=None,
          message=('\n'.join(_pretty_print(d, summarize) for d in data)))

    else:  # not context.executing_eagerly()
      if data is None:
        data = [
            'Condition x %s y did not hold element-wise:' % sym,
            'x (%s) = ' % x.name, x,
            'y (%s) = ' % y.name, y
        ]
      if message is not None:
        data = [message] + list(data)
      condition = math_ops.reduce_all(op_func(x, y))
      # If both inputs are compile-time constants, fail immediately rather
      # than waiting for the graph to run.
      x_static = tensor_util.constant_value(x)
      y_static = tensor_util.constant_value(y)
      if x_static is not None and y_static is not None:
        condition_static = np.all(static_func(x_static, y_static))
        _assert_static(condition_static, data)
      return control_flow_ops.Assert(condition, data, summarize=summarize)
428
429
@tf_export(
    'debugging.assert_proper_iterable',
    v1=['debugging.assert_proper_iterable', 'assert_proper_iterable'])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints('assert_proper_iterable')
def assert_proper_iterable(values):
  """Static assert that values is a "proper" iterable.

  `Tensor`, `ndarray` and byte/text strings are all technically iterable, but
  are rarely what an op expecting an iterable of `Tensor`s wants. `Ops` can
  call this to reject them (and non-iterables) up front.

  Args:
    values:  Object to be checked.

  Raises:
    TypeError:  If `values` is not iterable or is one of
      `Tensor`, `SparseTensor`, `np.array`, `tf.compat.bytes_or_text_types`.
  """
  values_type = type(values)
  rejected_types = ((ops.Tensor, sparse_tensor.SparseTensor, np.ndarray) +
                    compat.bytes_or_text_types)
  if isinstance(values, rejected_types):
    raise TypeError(
        'Expected argument "values" to be a "proper" iterable.  Found: %s' %
        values_type)
  if hasattr(values, '__iter__'):
    return
  raise TypeError(
      'Expected argument "values" to be iterable.  Found: %s' % values_type)
460
461
@tf_export('debugging.assert_negative', v1=[])
@dispatch.add_dispatch_support
def assert_negative_v2(x, message=None, summarize=None, name=None):
  """Assert the condition `x < 0` holds element-wise.

  Every element `x[i]` of `x` must satisfy `x[i] < 0`. An empty `x` satisfies
  the check trivially.

  When some element of `x` is not negative, `InvalidArgumentError` is raised,
  reporting `message` together with the first `summarize` entries of `x`.

  Args:
    x:  Numeric `Tensor`.
    message: A string to prefix to the default message.
    summarize: Print this many entries of each tensor.
    name: A name for this operation (optional).  Defaults to "assert_negative".

  Returns:
    Op raising `InvalidArgumentError` unless `x` is all negative. Inside a
      `tf.function`, pass it to `tf.control_dependencies` to block followup
      computation until the check has executed.
    @compatibility(eager)
    returns None
    @end_compatibility

  Raises:
    InvalidArgumentError: if the check can be performed immediately and
      `x[i] < 0` is False. The check can be performed immediately during eager
      execution or if `x` is statically known.
  """
  return assert_negative(
      x=x, name=name, message=message, summarize=summarize)
493
494
@tf_export(v1=['debugging.assert_negative', 'assert_negative'])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints('assert_negative')
@_unary_assert_doc('< 0', 'negative')
def assert_negative(x, data=None, summarize=None, message=None, name=None):  # pylint: disable=missing-docstring
  # The public docstring is generated by `_unary_assert_doc`; the check is
  # delegated to assert_less(x, 0).
  message = _message_prefix(message)
  with ops.name_scope(name, 'assert_negative', [x, data]):
    x = ops.convert_to_tensor(x, name='x')
    if data is None:
      # Eager tensors have no meaningful graph name; describe them by shape
      # and dtype instead.
      x_desc = (_shape_and_dtype_str(x) if context.executing_eagerly()
                else x.name)
      data = [message,
              'Condition x < 0 did not hold element-wise:',
              'x (%s) = ' % x_desc,
              x]
    zero = ops.convert_to_tensor(0, dtype=x.dtype)
    return assert_less(x, zero, data=data, summarize=summarize)
514
515
@tf_export('debugging.assert_positive', v1=[])
@dispatch.add_dispatch_support
def assert_positive_v2(x, message=None, summarize=None, name=None):
  """Assert the condition `x > 0` holds element-wise.

  Every element `x[i]` of `x` must satisfy `x[i] > 0`. An empty `x` satisfies
  the check trivially.

  When some element of `x` is not positive, `InvalidArgumentError` is raised,
  reporting `message` together with the first `summarize` entries of `x`.

  Args:
    x:  Numeric `Tensor`.
    message: A string to prefix to the default message.
    summarize: Print this many entries of each tensor.
    name: A name for this operation (optional). Defaults to "assert_positive".

  Returns:
    Op raising `InvalidArgumentError` unless `x` is all positive. Inside a
      `tf.function`, pass it to `tf.control_dependencies` to block followup
      computation until the check has executed.
    @compatibility(eager)
    returns None
    @end_compatibility

  Raises:
    InvalidArgumentError: if the check can be performed immediately and
      `x[i] > 0` is False. The check can be performed immediately during eager
      execution or if `x` is statically known.
  """
  return assert_positive(
      x=x, name=name, message=message, summarize=summarize)
547
548
@tf_export(v1=['debugging.assert_positive', 'assert_positive'])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints('assert_positive')
@_unary_assert_doc('> 0', 'positive')
def assert_positive(x, data=None, summarize=None, message=None, name=None):  # pylint: disable=missing-docstring
  # The public docstring is generated by `_unary_assert_doc`; the check is
  # delegated to assert_less(0, x).
  message = _message_prefix(message)
  with ops.name_scope(name, 'assert_positive', [x, data]):
    x = ops.convert_to_tensor(x, name='x')
    if data is None:
      # Eager tensors have no meaningful graph name; describe them by shape
      # and dtype instead.
      x_desc = (_shape_and_dtype_str(x) if context.executing_eagerly()
                else x.name)
      data = [message,
              'Condition x > 0 did not hold element-wise:',
              'x (%s) = ' % x_desc,
              x]
    zero = ops.convert_to_tensor(0, dtype=x.dtype)
    return assert_less(zero, x, data=data, summarize=summarize)
567
568
@tf_export('debugging.assert_non_negative', v1=[])
@dispatch.add_dispatch_support
def assert_non_negative_v2(x, message=None, summarize=None, name=None):
  """Assert the condition `x >= 0` holds element-wise.

  Every element `x[i]` of `x` must satisfy `x[i] >= 0`. An empty `x`
  satisfies the check trivially.

  When some element of `x` is below zero, `InvalidArgumentError` is raised,
  reporting `message` together with the first `summarize` entries of `x`.

  Args:
    x:  Numeric `Tensor`.
    message: A string to prefix to the default message.
    summarize: Print this many entries of each tensor.
    name: A name for this operation (optional).  Defaults to
      "assert_non_negative".

  Returns:
    Op raising `InvalidArgumentError` unless `x` is all non-negative. Inside a
      `tf.function`, pass it to `tf.control_dependencies` to block followup
      computation until the check has executed.
    @compatibility(eager)
    returns None
    @end_compatibility

  Raises:
    InvalidArgumentError: if the check can be performed immediately and
      `x[i] >= 0` is False. The check can be performed immediately during eager
      execution or if `x` is statically known.
  """
  return assert_non_negative(
      x=x, name=name, message=message, summarize=summarize)
602
603
@tf_export(v1=['debugging.assert_non_negative', 'assert_non_negative'])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints('assert_non_negative')
@_unary_assert_doc('>= 0', 'non-negative')
def assert_non_negative(x, data=None, summarize=None, message=None, name=None):  # pylint: disable=missing-docstring
  # The public docstring is generated by `_unary_assert_doc`; the check is
  # delegated to assert_less_equal(0, x).
  message = _message_prefix(message)
  with ops.name_scope(name, 'assert_non_negative', [x, data]):
    x = ops.convert_to_tensor(x, name='x')
    if data is None:
      # Eager tensors have no meaningful graph name; describe them by shape
      # and dtype instead.
      x_desc = (_shape_and_dtype_str(x) if context.executing_eagerly()
                else x.name)
      data = [message,
              'Condition x >= 0 did not hold element-wise:',
              'x (%s) = ' % x_desc,
              x]
    zero = ops.convert_to_tensor(0, dtype=x.dtype)
    return assert_less_equal(zero, x, data=data, summarize=summarize)
623
624
@tf_export('debugging.assert_non_positive', v1=[])
@dispatch.add_dispatch_support
def assert_non_positive_v2(x, message=None, summarize=None, name=None):
  """Assert the condition `x <= 0` holds element-wise.

  Every element `x[i]` of `x` must satisfy `x[i] <= 0`. An empty `x`
  satisfies the check trivially.

  When some element of `x` is above zero, `InvalidArgumentError` is raised,
  reporting `message` together with the first `summarize` entries of `x`.

  Args:
    x:  Numeric `Tensor`.
    message: A string to prefix to the default message.
    summarize: Print this many entries of each tensor.
    name: A name for this operation (optional).  Defaults to
      "assert_non_positive".

  Returns:
    Op raising `InvalidArgumentError` unless `x` is all non-positive. Inside a
      `tf.function`, pass it to `tf.control_dependencies` to block followup
      computation until the check has executed.
    @compatibility(eager)
    returns None
    @end_compatibility

  Raises:
    InvalidArgumentError: if the check can be performed immediately and
      `x[i] <= 0` is False. The check can be performed immediately during eager
      execution or if `x` is statically known.
  """
  return assert_non_positive(
      x=x, name=name, message=message, summarize=summarize)
658
659
@tf_export(v1=['debugging.assert_non_positive', 'assert_non_positive'])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints('assert_non_positive')
@_unary_assert_doc('<= 0', 'non-positive')
def assert_non_positive(x, data=None, summarize=None, message=None, name=None):  # pylint: disable=missing-docstring
  # The public docstring is generated by `_unary_assert_doc`; the check is
  # delegated to assert_less_equal(x, 0).
  message = _message_prefix(message)
  with ops.name_scope(name, 'assert_non_positive', [x, data]):
    x = ops.convert_to_tensor(x, name='x')
    if data is None:
      # Eager tensors have no meaningful graph name; describe them by shape
      # and dtype instead.
      if context.executing_eagerly():
        x_desc = _shape_and_dtype_str(x)
      else:
        x_desc = x.name
      # The header and the `x (...)` label must be separate list entries (note
      # the comma); otherwise implicit string-literal concatenation fuses them
      # into one malformed line, unlike the sibling unary asserts.
      data = [
          message,
          'Condition x <= 0 did not hold element-wise:',
          'x (%s) = ' % x_desc, x]
    zero = ops.convert_to_tensor(0, dtype=x.dtype)
    return assert_less_equal(x, zero, data=data, summarize=summarize)
679
680
@tf_export('debugging.assert_equal', 'assert_equal', v1=[])
@dispatch.register_binary_elementwise_assert_api
@dispatch.add_dispatch_support
def assert_equal_v2(x, y, message=None, summarize=None, name=None):
  """Assert the condition `x == y` holds element-wise.

  Every pair of (possibly broadcast) elements `x[i]`, `y[i]` must satisfy
  `x[i] == y[i]`. If both `x` and `y` are empty, the check is trivially
  satisfied.

  When `x` and `y` differ, `InvalidArgumentError` is raised, reporting
  `message` together with the first `summarize` entries of `x` and `y`.

  Args:
    x:  Numeric `Tensor`.
    y:  Numeric `Tensor`, same dtype as and broadcastable to `x`.
    message: A string to prefix to the default message.
    summarize: Print this many entries of each tensor.
    name: A name for this operation (optional).  Defaults to "assert_equal".

  Returns:
    Op that raises `InvalidArgumentError` if `x == y` is False. Inside a
      `tf.function`, pass it to `tf.control_dependencies` to block followup
      computation until the check has executed.
    @compatibility(eager)
    returns None
    @end_compatibility

  Raises:
    InvalidArgumentError: if the check can be performed immediately and
      `x == y` is False. The check can be performed immediately during eager
      execution or if `x` and `y` are statically known.
  """
  return assert_equal(
      x=x, y=y, name=name, message=message, summarize=summarize)
715
716
@tf_export(v1=['debugging.assert_equal', 'assert_equal'])
@dispatch.register_binary_elementwise_assert_api
@dispatch.add_dispatch_support
@_binary_assert_doc('==', '[1, 2]')
def assert_equal(x, y, data=None, summarize=None, message=None, name=None):  # pylint: disable=missing-docstring
  # The public docstring is generated by `_binary_assert_doc`.
  with ops.name_scope(name, 'assert_equal', [x, y, data]):
    # Short-circuit if x and y are the same tensor.
    # Eager mode returns None; graph mode returns a no-op so callers always
    # get something usable with control dependencies.
    if x is y:
      return None if context.executing_eagerly() else control_flow_ops.no_op()
  # NOTE(review): the scope above is exited before _binary_assert opens its
  # own scope with the same default name ('assert_equal'); in graph mode this
  # may make the real assert ops land under a uniquified scope (e.g.
  # 'assert_equal_1/') — confirm whether that is intended before restructuring.
  return _binary_assert('==', 'assert_equal', math_ops.equal, np.equal, x, y,
                        data, summarize, message, name)
728
729
@tf_export('debugging.assert_none_equal', v1=[])
@dispatch.register_binary_elementwise_assert_api
@dispatch.add_dispatch_support
def assert_none_equal_v2(x, y, summarize=None, message=None, name=None):
  """Assert the condition `x != y` holds for all elements.

  Every pair of (possibly broadcast) elements `x[i]`, `y[i]` must satisfy
  `x[i] != y[i]`. If both `x` and `y` are empty, the check is trivially
  satisfied.

  When any elements of `x` and `y` are equal, `InvalidArgumentError` is
  raised, reporting `message` together with the first `summarize` entries of
  `x` and `y`.

  Args:
    x:  Numeric `Tensor`.
    y:  Numeric `Tensor`, same dtype as and broadcastable to `x`.
    summarize: Print this many entries of each tensor.
    message: A string to prefix to the default message.
    name: A name for this operation (optional).  Defaults to
      "assert_none_equal".

  Returns:
    Op that raises `InvalidArgumentError` if `x != y` is ever False. Inside a
      `tf.function`, pass it to `tf.control_dependencies` to block followup
      computation until the check has executed.
    @compatibility(eager)
    returns None
    @end_compatibility

  Raises:
    InvalidArgumentError: if the check can be performed immediately and
      `x != y` is False for any pair of elements in `x` and `y`. The check can
      be performed immediately during eager execution or if `x` and `y` are
      statically known.
  """
  return assert_none_equal(
      x=x, y=y, name=name, message=message, summarize=summarize)
768
769
@tf_export(v1=['debugging.assert_none_equal', 'assert_none_equal'])
@dispatch.register_binary_elementwise_assert_api
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints('assert_none_equal')
@_binary_assert_doc('!=', '[2, 1]')
def assert_none_equal(x, y, data=None, summarize=None, message=None, name=None):  # pylint: disable=missing-docstring
  # The public docstring is generated by `_binary_assert_doc`; all the work
  # is done by the shared binary-assert implementation.
  return _binary_assert(
      sym='!=',
      opname='assert_none_equal',
      op_func=math_ops.not_equal,
      static_func=np.not_equal,
      x=x,
      y=y,
      data=data,
      summarize=summarize,
      message=message,
      name=name)
779
780
@tf_export('debugging.assert_near', v1=[])
@dispatch.register_binary_elementwise_assert_api
@dispatch.add_dispatch_support
def assert_near_v2(x, y, rtol=None, atol=None, message=None, summarize=None,
                   name=None):
  """Assert the condition `x` and `y` are close element-wise.

  This Op checks that `x[i] - y[i] < atol + rtol * tf.abs(y[i])` holds for every
  pair of (possibly broadcast) elements of `x` and `y`. If both `x` and `y` are
  empty, this is trivially satisfied.

  If any elements of `x` and `y` are not close, `message`, as well as the first
  `summarize` entries of `x` and `y` are printed, and `InvalidArgumentError`
  is raised.

  The default `atol` and `rtol` are `10 * eps`, where `eps` is the smallest
  representable positive number such that `1 + eps != 1`.  This default is
  about `1.2e-6` in `32bit`, `2.22e-15` in `64bit`, and `0.00977` in `16bit`.
  See `numpy.finfo`.

  Args:
    x: Float or complex `Tensor`.
    y: Float or complex `Tensor`, same dtype as and broadcastable to `x`.
    rtol:  `Tensor`.  Same `dtype` as, and broadcastable to, `x`.
      The relative tolerance.  Default is `10 * eps`.
    atol:  `Tensor`.  Same `dtype` as, and broadcastable to, `x`.
      The absolute tolerance.  Default is `10 * eps`.
    message: A string to prefix to the default message.
    summarize: Print this many entries of each tensor.
    name: A name for this operation (optional).  Defaults to "assert_near".

  Returns:
    Op that raises `InvalidArgumentError` if `x` and `y` are not close enough.
      This can be used with `tf.control_dependencies` inside of `tf.function`s
      to block followup computation until the check has executed.
    @compatibility(eager)
    returns None
    @end_compatibility

  Raises:
    InvalidArgumentError: if the check can be performed immediately and
      `x` and `y` are not close for some pair of elements. The check can be
      performed immediately during eager execution or if `x` and `y` are
      statically known.

  @compatibility(numpy)
  Similar to `numpy.testing.assert_allclose`, except tolerance depends on data
  type. This is due to the fact that `TensorFlow` is often used with `32bit`,
  `64bit`, and even `16bit` data.
  @end_compatibility
  """
  return assert_near(x=x, y=y, rtol=rtol, atol=atol, summarize=summarize,
                     message=message, name=name)
834
835
@tf_export(v1=['debugging.assert_near', 'assert_near'])
@dispatch.register_binary_elementwise_assert_api
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints('assert_near')
def assert_near(
    x, y, rtol=None, atol=None, data=None, summarize=None, message=None,
    name=None):
  """Assert the condition `x` and `y` are close element-wise.

  Example of adding a dependency to an operation:

  ```python
  with tf.control_dependencies([tf.compat.v1.assert_near(x, y)]):
    output = tf.reduce_sum(x)
  ```

  This condition holds if for every pair of (possibly broadcast) elements
  `x[i]`, `y[i]`, we have

  ```tf.abs(x[i] - y[i]) <= atol + rtol * tf.abs(y[i])```.

  If both `x` and `y` are empty, this is trivially satisfied.

  The default `atol` and `rtol` is `10 * eps`, where `eps` is the smallest
  representable positive number such that `1 + eps != 1`.  This is about
  `1.2e-6` in `32bit`, `2.22e-15` in `64bit`, and `0.00977` in `16bit`.
  See `numpy.finfo`.

  Args:
    x:  Float or complex `Tensor`.
    y:  Float or complex `Tensor`, same `dtype` as, and broadcastable to, `x`.
    rtol:  `Tensor`.  Same `dtype` as, and broadcastable to, `x`.
      The relative tolerance.  Default is `10 * eps`.
    atol:  `Tensor`.  Same `dtype` as, and broadcastable to, `x`.
      The absolute tolerance.  Default is `10 * eps`.
    data:  The tensors to print out if the condition is False.  Defaults to
      error message and first few entries of `x`, `y`.
    summarize: Print this many entries of each tensor.
    message: A string to prefix to the default message.
    name: A name for this operation (optional).  Defaults to "assert_near".

  Returns:
    Op that raises `InvalidArgumentError` if `x` and `y` are not close enough.

  @compatibility(numpy)
  Similar to `numpy.testing.assert_allclose`, except tolerance depends on data
  type. This is due to the fact that `TensorFlow` is often used with `32bit`,
  `64bit`, and even `16bit` data.
  @end_compatibility
  """
  message = _message_prefix(message)
  with ops.name_scope(name, 'assert_near', [x, y, rtol, atol, data]):
    x = ops.convert_to_tensor(x, name='x')
    y = ops.convert_to_tensor(y, name='y', dtype=x.dtype)

    # Tolerances are compared against absolute values, so they use the real
    # counterpart of a complex input dtype.
    dtype = x.dtype
    if dtype.is_complex:
      dtype = dtype.real_dtype
    eps = np.finfo(dtype.as_numpy_dtype).eps
    rtol = 10 * eps if rtol is None else rtol
    atol = 10 * eps if atol is None else atol

    rtol = ops.convert_to_tensor(rtol, name='rtol', dtype=dtype)
    atol = ops.convert_to_tensor(atol, name='atol', dtype=dtype)

    if context.executing_eagerly():
      x_name = _shape_and_dtype_str(x)
      y_name = _shape_and_dtype_str(y)
    else:
      x_name = x.name
      y_name = y.name

    if data is None:
      data = [
          message,
          'x and y not equal to tolerance rtol = %s, atol = %s' % (rtol, atol),
          'x (%s) = ' % x_name, x, 'y (%s) = ' % y_name, y
      ]
    tol = atol + rtol * math_ops.abs(y)
    diff = math_ops.abs(x - y)
    # Non-strict comparison, as stated in the docstring and as in
    # `numpy.testing.assert_allclose`. A strict `<` would reject identical
    # inputs when `atol = rtol = 0`, because `0 < 0` is False.
    condition = math_ops.reduce_all(math_ops.less_equal(diff, tol))
    return control_flow_ops.Assert(condition, data, summarize=summarize)
918
919
@tf_export('debugging.assert_less', 'assert_less', v1=[])
@dispatch.register_binary_elementwise_assert_api
@dispatch.add_dispatch_support
def assert_less_v2(x, y, message=None, summarize=None, name=None):
  """Assert the condition `x < y` holds element-wise.

  For every pair of (possibly broadcast) elements `x[i]` and `y[i]`, this Op
  verifies that `x[i] < y[i]`. Empty `x` and `y` trivially satisfy the check.

  When `x < y` is False for some pair of elements, `message` is printed
  together with the first `summarize` entries of `x` and `y`, and
  `InvalidArgumentError` is raised.

  Args:
    x:  Numeric `Tensor`.
    y:  Numeric `Tensor`, same dtype as and broadcastable to `x`.
    message: A string to prefix to the default message.
    summarize: Print this many entries of each tensor.
    name: A name for this operation (optional).  Defaults to "assert_less".

  Returns:
    Op that raises `InvalidArgumentError` if `x < y` is False.
    This can be used with `tf.control_dependencies` inside of `tf.function`s
    to block followup computation until the check has executed.
    @compatibility(eager)
    returns None
    @end_compatibility

  Raises:
    InvalidArgumentError: if the check can be performed immediately and
      `x < y` is False. The check can be performed immediately during eager
      execution or if `x` and `y` are statically known.
  """
  # Delegate to the v1 implementation, leaving its extra `data` argument unset.
  return assert_less(
      x=x, y=y, message=message, summarize=summarize, name=name)
955
956
@tf_export(v1=['debugging.assert_less', 'assert_less'])
@dispatch.register_binary_elementwise_assert_api
@dispatch.add_dispatch_support
@_binary_assert_doc('<', '[2, 3]')
def assert_less(x, y, data=None, summarize=None, message=None, name=None):
  # The public docstring is presumably generated by `_binary_assert_doc`.
  # Element-wise predicates: the TF op for tensors and its NumPy counterpart.
  tf_less, np_less = math_ops.less, np.less
  return _binary_assert('<', 'assert_less', tf_less, np_less,
                        x, y, data, summarize, message, name)
964
965
@tf_export('debugging.assert_less_equal', v1=[])
@dispatch.register_binary_elementwise_assert_api
@dispatch.add_dispatch_support
def assert_less_equal_v2(x, y, message=None, summarize=None, name=None):
  """Assert the condition `x <= y` holds element-wise.

  For every pair of (possibly broadcast) elements `x[i]` and `y[i]`, this Op
  verifies that `x[i] <= y[i]`. Empty `x` and `y` trivially satisfy the check.

  When `x <= y` is False for some pair of elements (`x` is not less than or
  equal to `y`), `message` is printed together with the first `summarize`
  entries of `x` and `y`, and `InvalidArgumentError` is raised.

  Args:
    x:  Numeric `Tensor`.
    y:  Numeric `Tensor`, same dtype as and broadcastable to `x`.
    message: A string to prefix to the default message.
    summarize: Print this many entries of each tensor.
    name: A name for this operation (optional). Defaults to "assert_less_equal".

  Returns:
    Op that raises `InvalidArgumentError` if `x <= y` is False. This can be
      used with `tf.control_dependencies` inside of `tf.function`s to block
      followup computation until the check has executed.
    @compatibility(eager)
    returns None
    @end_compatibility

  Raises:
    InvalidArgumentError: if the check can be performed immediately and
      `x <= y` is False. The check can be performed immediately during eager
      execution or if `x` and `y` are statically known.
  """
  # Delegate to the v1 implementation, leaving its extra `data` argument unset.
  return assert_less_equal(
      x=x, y=y, message=message, summarize=summarize, name=name)
1002
1003
@tf_export(v1=['debugging.assert_less_equal', 'assert_less_equal'])
@dispatch.register_binary_elementwise_assert_api
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints('assert_less_equal')
@_binary_assert_doc('<=', '[1, 3]')
def assert_less_equal(x, y, data=None, summarize=None, message=None, name=None):
  # The public docstring is presumably generated by `_binary_assert_doc`.
  # Element-wise predicates: the TF op for tensors and its NumPy counterpart.
  tf_less_equal, np_less_equal = math_ops.less_equal, np.less_equal
  return _binary_assert('<=', 'assert_less_equal', tf_less_equal,
                        np_less_equal, x, y, data, summarize, message, name)
1012
1013
@tf_export('debugging.assert_greater', 'assert_greater', v1=[])
@dispatch.register_binary_elementwise_assert_api
@dispatch.add_dispatch_support
def assert_greater_v2(x, y, message=None, summarize=None, name=None):
  """Assert the condition `x > y` holds element-wise.

  For every pair of (possibly broadcast) elements `x[i]` and `y[i]`, this Op
  verifies that `x[i] > y[i]`. Empty `x` and `y` trivially satisfy the check.

  When `x > y` is False for some pair of elements, `message` is printed
  together with the first `summarize` entries of `x` and `y`, and
  `InvalidArgumentError` is raised.

  Args:
    x:  Numeric `Tensor`.
    y:  Numeric `Tensor`, same dtype as and broadcastable to `x`.
    message: A string to prefix to the default message.
    summarize: Print this many entries of each tensor.
    name: A name for this operation (optional).  Defaults to "assert_greater".

  Returns:
    Op that raises `InvalidArgumentError` if `x > y` is False. This can be
      used with `tf.control_dependencies` inside of `tf.function`s to block
      followup computation until the check has executed.
    @compatibility(eager)
    returns None
    @end_compatibility

  Raises:
    InvalidArgumentError: if the check can be performed immediately and
      `x > y` is False. The check can be performed immediately during eager
      execution or if `x` and `y` are statically known.
  """
  # Delegate to the v1 implementation, leaving its extra `data` argument unset.
  return assert_greater(
      x=x, y=y, message=message, summarize=summarize, name=name)
1050
1051
@tf_export(v1=['debugging.assert_greater', 'assert_greater'])
@dispatch.register_binary_elementwise_assert_api
@dispatch.add_dispatch_support
@_binary_assert_doc('>', '[0, 1]')
def assert_greater(x, y, data=None, summarize=None, message=None, name=None):  # pylint: disable=missing-docstring
  # The public docstring is presumably generated by `_binary_assert_doc`.
  # Element-wise predicates: the TF op for tensors and its NumPy counterpart.
  tf_greater, np_greater = math_ops.greater, np.greater
  return _binary_assert('>', 'assert_greater', tf_greater, np_greater,
                        x, y, data, summarize, message, name)
1059
1060
@tf_export('debugging.assert_greater_equal', v1=[])
@dispatch.register_binary_elementwise_assert_api
@dispatch.add_dispatch_support
def assert_greater_equal_v2(x, y, message=None, summarize=None, name=None):
  """Assert the condition `x >= y` holds element-wise.

  For every pair of (possibly broadcast) elements `x[i]` and `y[i]`, this Op
  verifies that `x[i] >= y[i]`. Empty `x` and `y` trivially satisfy the check.

  When `x >= y` is False for some pair of elements (`x` is not greater than or
  equal to `y`), `message` is printed together with the first `summarize`
  entries of `x` and `y`, and `InvalidArgumentError` is raised.

  Args:
    x:  Numeric `Tensor`.
    y:  Numeric `Tensor`, same dtype as and broadcastable to `x`.
    message: A string to prefix to the default message.
    summarize: Print this many entries of each tensor.
    name: A name for this operation (optional).  Defaults to
      "assert_greater_equal".

  Returns:
    Op that raises `InvalidArgumentError` if `x >= y` is False. This can be
      used with `tf.control_dependencies` inside of `tf.function`s to block
      followup computation until the check has executed.
    @compatibility(eager)
    returns None
    @end_compatibility

  Raises:
    InvalidArgumentError: if the check can be performed immediately and
      `x >= y` is False. The check can be performed immediately during eager
      execution or if `x` and `y` are statically known.
  """
  # Delegate to the v1 implementation, leaving its extra `data` argument unset.
  return assert_greater_equal(
      x=x, y=y, message=message, summarize=summarize, name=name)
1098
1099
@tf_export(v1=['debugging.assert_greater_equal', 'assert_greater_equal'])
@dispatch.register_binary_elementwise_assert_api
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints('assert_greater_equal')
@_binary_assert_doc('>=', '[1, 0]')
def assert_greater_equal(x, y, data=None, summarize=None, message=None,
                         name=None):
  # The public docstring is presumably generated by `_binary_assert_doc`.
  # Element-wise predicates: the TF op for tensors and its NumPy counterpart.
  tf_greater_equal, np_greater_equal = (math_ops.greater_equal,
                                        np.greater_equal)
  return _binary_assert('>=', 'assert_greater_equal', tf_greater_equal,
                        np_greater_equal, x, y, data, summarize, message, name)
1109
1110
def _assert_rank_condition(
    x, rank, static_condition, dynamic_condition, data, summarize):
  """Asserts that the rank of `x` satisfies a condition against `rank`.

  When both `rank` and the rank of `x` are known at graph-construction time,
  the condition is decided in Python. Otherwise an `Assert` op is built.

  Args:
    x:  Numeric `Tensor` (or `SparseTensor`, as passed by `assert_rank`).
    rank:  Scalar int32 `Tensor`.
    static_condition:  Python callable `(actual_rank, given_rank)` returning a
      truthy value when the condition holds.
    dynamic_condition:  Callable `(actual_rank, given_rank)` building a boolean
      `Tensor` that is True when the condition holds.
    data:  The tensors to print out if the condition is false.
    summarize: Print this many entries of each tensor.

  Returns:
    A `no_op` when the static check passes. Otherwise an op raising
    `InvalidArgumentError` if `x` fails `dynamic_condition`.

  Raises:
    ValueError:  If `rank` is statically known but not a scalar, or if static
      checks determine `x` fails `static_condition`.
  """
  assert_type(rank, dtypes.int32)

  # Try to resolve `rank` to a constant while building the graph.
  static_rank = tensor_util.constant_value(rank)
  if static_rank is not None:
    if static_rank.ndim != 0:
      raise ValueError('Rank must be a scalar.')
    x_rank = x.get_shape().ndims
    if x_rank is not None:
      if static_condition(x_rank, static_rank):
        return control_flow_ops.no_op(name='static_checks_determined_all_ok')
      raise ValueError('Static rank condition failed', x_rank, static_rank)

  condition = dynamic_condition(array_ops.rank(x), rank)

  # A non-constant `rank` gets a runtime scalar check. This catches calls like
  # assert_rank(x, [n]) where assert_rank(x, n) was intended.
  if static_rank is None:
    rank_check = assert_rank(
        rank, 0, data=['Rank must be a scalar. Received rank: ', rank])
    condition = control_flow_ops.with_dependencies([rank_check], condition)

  return control_flow_ops.Assert(condition, data, summarize=summarize)
1157
1158
@tf_export('debugging.assert_rank', 'assert_rank', v1=[])
@dispatch.add_dispatch_support
def assert_rank_v2(x, rank, message=None, name=None):
  """Assert that `x` has rank equal to `rank`.

  This Op checks that the rank of `x` is equal to `rank`.

  If `x` has a different rank, `message`, as well as the shape of `x` are
  printed, and `InvalidArgumentError` is raised.

  Args:
    x: `Tensor`.
    rank: Scalar integer `Tensor`.
    message: A string to prefix to the default message.
    name: A name for this operation (optional). Defaults to
      "assert_rank".

  Returns:
    Op raising `InvalidArgumentError` unless `x` has specified rank.
    If static checks determine `x` has correct rank, a `no_op` is returned.
    This can be used with `tf.control_dependencies` inside of `tf.function`s
    to block followup computation until the check has executed.
    @compatibility(eager)
    returns None
    @end_compatibility

  Raises:
    InvalidArgumentError: `x` does not have rank `rank`, but the rank cannot
      be statically determined.
    ValueError: If static checks determine `x` has mismatched rank.
  """
  return assert_rank(x=x, rank=rank, message=message, name=name)
1191
1192
@tf_export(v1=['debugging.assert_rank', 'assert_rank'])
@dispatch.add_dispatch_support
def assert_rank(x, rank, data=None, summarize=None, message=None, name=None):
  """Assert `x` has rank equal to `rank`.

  Example of adding a dependency to an operation:

  ```python
  with tf.control_dependencies([tf.compat.v1.assert_rank(x, 2)]):
    output = tf.reduce_sum(x)
  ```

  Args:
    x:  Numeric `Tensor` or `SparseTensor`.
    rank:  Scalar integer `Tensor`.
    data:  The tensors to print out if the condition is False.  Defaults to
      error message and the shape of `x`.
    summarize: Print this many entries of each tensor.
    message: A string to prefix to the default message.
    name: A name for this operation (optional).  Defaults to "assert_rank".

  Returns:
    Op raising `InvalidArgumentError` unless `x` has specified rank.
    If static checks determine `x` has correct rank, a `no_op` is returned.

  Raises:
    ValueError:  If static checks determine `x` has wrong rank, or if `rank`
      is statically known but is not a scalar.
  """
  with ops.name_scope(name, 'assert_rank', (x, rank) + tuple(data or [])):
    if not isinstance(x, sparse_tensor.SparseTensor):
      x = ops.convert_to_tensor(x, name='x')
    rank = ops.convert_to_tensor(rank, name='rank')
    message = _message_prefix(message)

    static_condition = lambda actual_rank, given_rank: actual_rank == given_rank
    dynamic_condition = math_ops.equal

    if context.executing_eagerly() or isinstance(x, sparse_tensor.SparseTensor):
      name = ''
    else:
      name = x.name

    if data is None:
      data = [
          message,
          'Tensor %s must have rank' % name, rank, 'Received shape: ',
          array_ops.shape(x)
      ]

    try:
      assert_op = _assert_rank_condition(x, rank, static_condition,
                                         dynamic_condition, data, summarize)

    except ValueError as e:
      if e.args and e.args[0] == 'Static rank condition failed':
        raise ValueError(
            '%sTensor %s must have rank %d.  Received rank %d, shape %s' %
            (message, name, e.args[2], e.args[1], x.get_shape()))
      # Any other ValueError (e.g. a non-scalar `rank`) propagates unchanged,
      # keeping its original args and traceback.
      raise

  return assert_op
1255
1256
@tf_export('debugging.assert_rank_at_least', v1=[])
@dispatch.add_dispatch_support
def assert_rank_at_least_v2(x, rank, message=None, name=None):
  """Assert that `x` has rank of at least `rank`.

  This Op verifies that the rank of `x` is greater than or equal to `rank`.

  When the rank of `x` is lower than `rank`, `message` is printed together with
  the shape of `x`, and `InvalidArgumentError` is raised.

  Args:
    x: `Tensor`.
    rank: Scalar integer `Tensor`.
    message: A string to prefix to the default message.
    name: A name for this operation (optional).  Defaults to
      "assert_rank_at_least".

  Returns:
    Op raising `InvalidArgumentError` unless `x` has specified rank or higher.
    If static checks determine `x` has correct rank, a `no_op` is returned.
    This can be used with `tf.control_dependencies` inside of `tf.function`s
    to block followup computation until the check has executed.
    @compatibility(eager)
    returns None
    @end_compatibility

  Raises:
    InvalidArgumentError: `x` does not have rank at least `rank`, but the rank
      cannot be statically determined.
    ValueError: If static checks determine `x` has mismatched rank.
  """
  # Delegate to the v1 implementation with its `data`/`summarize` defaults.
  return assert_rank_at_least(x=x, rank=rank, name=name, message=message)
1289
1290
@tf_export(v1=['debugging.assert_rank_at_least', 'assert_rank_at_least'])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints('assert_rank_at_least')
def assert_rank_at_least(
    x, rank, data=None, summarize=None, message=None, name=None):
  """Assert `x` has rank equal to `rank` or higher.

  Example of adding a dependency to an operation:

  ```python
  with tf.control_dependencies([tf.compat.v1.assert_rank_at_least(x, 2)]):
    output = tf.reduce_sum(x)
  ```

  Args:
    x:  Numeric `Tensor` or `SparseTensor`.
    rank:  Scalar `Tensor`.
    data:  The tensors to print out if the condition is False.  Defaults to
      error message and the shape of `x`.
    summarize: Print this many entries of each tensor.
    message: A string to prefix to the default message.
    name: A name for this operation (optional).
      Defaults to "assert_rank_at_least".

  Returns:
    Op raising `InvalidArgumentError` unless `x` has specified rank or higher.
    If static checks determine `x` has correct rank, a `no_op` is returned.

  Raises:
    ValueError:  If static checks determine `x` has wrong rank.
  """
  with ops.name_scope(
      name, 'assert_rank_at_least', (x, rank) + tuple(data or [])):
    # SparseTensors are passed through unconverted, as in `assert_rank` and
    # `assert_rank_in`, which use the same rank helpers.
    if not isinstance(x, sparse_tensor.SparseTensor):
      x = ops.convert_to_tensor(x, name='x')
    rank = ops.convert_to_tensor(rank, name='rank')
    message = _message_prefix(message)

    static_condition = lambda actual_rank, given_rank: actual_rank >= given_rank
    dynamic_condition = math_ops.greater_equal

    # As in the sibling rank asserts, SparseTensors are reported without a name.
    if context.executing_eagerly() or isinstance(x, sparse_tensor.SparseTensor):
      name = ''
    else:
      name = x.name

    if data is None:
      data = [
          message,
          'Tensor %s must have rank at least' % name, rank,
          'Received shape: ', array_ops.shape(x)
      ]

    try:
      assert_op = _assert_rank_condition(x, rank, static_condition,
                                         dynamic_condition, data, summarize)

    except ValueError as e:
      if e.args and e.args[0] == 'Static rank condition failed':
        raise ValueError(
            '%sTensor %s must have rank at least %d.  Received rank %d, '
            'shape %s' % (message, name, e.args[2], e.args[1], x.get_shape()))
      raise

  return assert_op
1356
1357
def _static_rank_in(actual_rank, given_ranks):
  """Returns whether `actual_rank` equals any entry of `given_ranks`."""
  return any(candidate == actual_rank for candidate in given_ranks)
1360
1361
def _dynamic_rank_in(actual_rank, given_ranks):
  """Returns a boolean `Tensor`: whether `actual_rank` is in `given_ranks`.

  Args:
    actual_rank: Rank `Tensor` of the checked value.
    given_ranks: Sequence of rank `Tensor`s to compare against.

  Returns:
    The logical OR of `equal(rank, actual_rank)` over `given_ranks`, or a
    constant False `Tensor` when `given_ranks` is empty.
  """
  num_ranks = len(given_ranks)
  if num_ranks == 0:
    return ops.convert_to_tensor(False)
  any_match = math_ops.equal(given_ranks[0], actual_rank)
  for candidate in given_ranks[1:]:
    is_match = math_ops.equal(candidate, actual_rank)
    any_match = math_ops.logical_or(any_match, is_match)
  return any_match
1370
1371
def _assert_ranks_condition(
    x, ranks, static_condition, dynamic_condition, data, summarize):
  """Asserts that the rank of `x` satisfies a condition against `ranks`.

  The check is decided in Python only when every entry of `ranks` and the rank
  of `x` are known at graph-construction time. Otherwise an `Assert` op is
  built.

  Args:
    x:  Numeric `Tensor` (or `SparseTensor`, as passed by `assert_rank_in`).
    ranks:  Tuple of scalar int32 `Tensor`s.
    static_condition:  Python callable `(actual_rank, given_ranks)` returning a
      truthy value when the condition holds.
    dynamic_condition:  Callable `(actual_rank, given_ranks)` building a boolean
      `Tensor` that is True when the condition holds.
    data:  The tensors to print out if the condition is false.
    summarize: Print this many entries of each tensor.

  Returns:
    A `no_op` when the static check passes. Otherwise an op raising
    `InvalidArgumentError` if `x` fails `dynamic_condition`.

  Raises:
    ValueError:  If every rank is statically known but one is not a scalar, or
      if static checks determine `x` fails `static_condition`.
  """
  for given in ranks:
    assert_type(given, dtypes.int32)

  # Constant-fold each rank; `None` marks ranks unknown at trace time.
  static_ranks = tuple(tensor_util.constant_value(given) for given in ranks)
  if all(folded is not None for folded in static_ranks):
    if any(folded.ndim != 0 for folded in static_ranks):
      raise ValueError('Rank must be a scalar.')
    x_rank = x.get_shape().ndims
    if x_rank is not None:
      if static_condition(x_rank, static_ranks):
        return control_flow_ops.no_op(name='static_checks_determined_all_ok')
      raise ValueError(
          'Static rank condition failed', x_rank, static_ranks)

  condition = dynamic_condition(array_ops.rank(x), ranks)

  # Each non-constant rank gets a runtime scalar check. This catches calls
  # where a list was passed in place of a single rank.
  for given, folded in zip(ranks, static_ranks):
    if folded is None:
      rank_check = assert_rank(
          given, 0, data=['Rank must be a scalar. Received rank: ', given])
      condition = control_flow_ops.with_dependencies([rank_check], condition)

  return control_flow_ops.Assert(condition, data, summarize=summarize)
1422
1423
@tf_export('debugging.assert_rank_in', v1=[])
@dispatch.add_dispatch_support
def assert_rank_in_v2(x, ranks, message=None, name=None):
  """Assert that `x` has a rank in `ranks`.

  This Op verifies that the rank of `x` equals one of the values in `ranks`.

  When it does not, `message` is printed together with the shape of `x`, and
  `InvalidArgumentError` is raised.

  Args:
    x: `Tensor`.
    ranks: `Iterable` of scalar `Tensor` objects.
    message: A string to prefix to the default message.
    name: A name for this operation (optional). Defaults to "assert_rank_in".

  Returns:
    Op raising `InvalidArgumentError` unless rank of `x` is in `ranks`.
    If static checks determine `x` has matching rank, a `no_op` is returned.
    This can be used with `tf.control_dependencies` inside of `tf.function`s
    to block followup computation until the check has executed.
    @compatibility(eager)
    returns None
    @end_compatibility

  Raises:
    InvalidArgumentError: `x` does not have rank in `ranks`, but the rank cannot
      be statically determined.
    ValueError: If static checks determine `x` has mismatched rank.
  """
  # Delegate to the v1 implementation with its `data`/`summarize` defaults.
  return assert_rank_in(x=x, ranks=ranks, name=name, message=message)
1455
1456
@tf_export(v1=['debugging.assert_rank_in', 'assert_rank_in'])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints('assert_rank_in')
def assert_rank_in(
    x, ranks, data=None, summarize=None, message=None, name=None):
  """Assert `x` has rank in `ranks`.

  Example of adding a dependency to an operation:

  ```python
  with tf.control_dependencies([tf.compat.v1.assert_rank_in(x, (2, 4))]):
    output = tf.reduce_sum(x)
  ```

  Args:
    x:  Numeric `Tensor` or `SparseTensor`.
    ranks:  Iterable of scalar `Tensor` objects. Any iterable is accepted,
      including a one-shot iterator such as a generator.
    data:  The tensors to print out if the condition is False.  Defaults to
      error message, `ranks`, and the shape of `x`.
    summarize: Print this many entries of each tensor.
    message: A string to prefix to the default message.
    name: A name for this operation (optional).
      Defaults to "assert_rank_in".

  Returns:
    Op raising `InvalidArgumentError` unless rank of `x` is in `ranks`.
    If static checks determine `x` has matching rank, a `no_op` is returned.

  Raises:
    ValueError:  If static checks determine `x` has mismatched rank.
  """
  # Materialize `ranks` once. Building the name-scope values below iterates
  # it, and a one-shot iterator would otherwise be exhausted before the ranks
  # are converted, leaving an empty set of allowed ranks.
  ranks = tuple(ranks)
  with ops.name_scope(
      name, 'assert_rank_in', (x,) + ranks + tuple(data or [])):
    if not isinstance(x, sparse_tensor.SparseTensor):
      x = ops.convert_to_tensor(x, name='x')
    ranks = tuple(ops.convert_to_tensor(rank, name='rank') for rank in ranks)
    message = _message_prefix(message)

    if context.executing_eagerly() or isinstance(x, sparse_tensor.SparseTensor):
      name = ''
    else:
      name = x.name

    if data is None:
      data = [
          message, 'Tensor %s must have rank in' % name
      ] + list(ranks) + [
          'Received shape: ', array_ops.shape(x)
      ]

    try:
      assert_op = _assert_ranks_condition(x, ranks, _static_rank_in,
                                          _dynamic_rank_in, data, summarize)

    except ValueError as e:
      if e.args and e.args[0] == 'Static rank condition failed':
        raise ValueError(
            '%sTensor %s must have rank in %s.  Received rank %d, '
            'shape %s' % (message, name, e.args[2], e.args[1], x.get_shape()))
      raise

  return assert_op
1520
1521
@tf_export('debugging.assert_integer', v1=[])
@dispatch.add_dispatch_support
def assert_integer_v2(x, message=None, name=None):
  """Assert that `x` is of integer dtype.

  If `x` has a non-integer type, a `TypeError` is raised whose message
  includes `message` and the dtype of `x`.

  This can always be checked statically, so this method returns nothing.

  Args:
    x: A `Tensor`.
    message: A string to prefix to the default message.
    name: A name for this operation (optional). Defaults to "assert_integer".

  Raises:
    TypeError:  If `x.dtype` is not a non-quantized integer type.
  """
  assert_integer(x=x, message=message, name=name)
1541
1542
@tf_export(v1=['debugging.assert_integer', 'assert_integer'])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints('assert_integer')
def assert_integer(x, message=None, name=None):
  """Assert that `x` is of integer dtype.

  Example of adding a dependency to an operation:

  ```python
  with tf.control_dependencies([tf.compat.v1.assert_integer(x)]):
    output = tf.reduce_sum(x)
  ```

  Only the dtype of `x` is inspected, so the check is always resolved while
  the op is being built.

  Args:
    x: `Tensor` whose basetype is integer and is not quantized.
    message: A string to prefix to the default message.
    name: A name for this operation (optional).  Defaults to "assert_integer".

  Raises:
    TypeError:  If `x.dtype` is anything other than non-quantized integer.

  Returns:
    A `no_op` that does nothing.  Type can be determined statically.
  """
  with ops.name_scope(name, 'assert_integer', [x]):
    x = ops.convert_to_tensor(x, name='x')
    if x.dtype.is_integer:
      return control_flow_ops.no_op('statically_determined_was_integer')
    # In eager mode the tensor is reported generically; in graph mode by its
    # tensor name.
    tensor_name = 'tensor' if context.executing_eagerly() else x.name
    raise TypeError(
        '%sExpected "x" to be integer type.  Found: %s of dtype %s' %
        (_message_prefix(message), tensor_name, x.dtype))
1580
1581
@tf_export('debugging.assert_type', v1=[])
@dispatch.add_dispatch_support
def assert_type_v2(tensor, tf_type, message=None, name=None):
  """Asserts that the given `Tensor` is of the specified type.

  The check can always be performed statically, so this method returns
  nothing.

  Example:

  >>> a = tf.Variable(1.0)
  >>> tf.debugging.assert_type(a, tf_type=tf.float32)

  >>> b = tf.constant(21)
  >>> tf.debugging.assert_type(b, tf_type=tf.bool)
  Traceback (most recent call last):
  ...
  TypeError: ...

  >>> c = tf.SparseTensor(indices=[[0, 0], [1, 2]], values=[1, 2],
  ...  dense_shape=[3, 4])
  >>> tf.debugging.assert_type(c, tf_type=tf.int32)

  Args:
    tensor: A `Tensor`, `SparseTensor` or `tf.Variable`.
    tf_type: A tensorflow type (`dtypes.float32`, `tf.int64`, `dtypes.bool`,
      etc).
    message: A string to prefix to the default message.
    name:  A name for this operation. Defaults to "assert_type"

  Raises:
    TypeError: If the tensor's data type doesn't match `tf_type`.
  """
  # Delegate to the v1 implementation; its `no_op` result is not returned.
  assert_type(tensor=tensor, tf_type=tf_type, name=name, message=message)
1615
1616
1617@tf_export(v1=['debugging.assert_type', 'assert_type'])
1618@dispatch.add_dispatch_support
1619@deprecation.deprecated_endpoints('assert_type')
1620def assert_type(tensor, tf_type, message=None, name=None):
1621  """Statically asserts that the given `Tensor` is of the specified type.
1622
1623  Args:
1624    tensor: A `Tensor` or `SparseTensor`.
1625    tf_type: A tensorflow type (`dtypes.float32`, `tf.int64`, `dtypes.bool`,
1626      etc).
1627    message: A string to prefix to the default message.
1628    name:  A name to give this `Op`.  Defaults to "assert_type"
1629
1630  Raises:
1631    TypeError: If the tensors data type doesn't match `tf_type`.
1632
1633  Returns:
1634    A `no_op` that does nothing.  Type can be determined statically.
1635  """
1636  tf_type = dtypes.as_dtype(tf_type)
1637  with ops.name_scope(name, 'assert_type', [tensor]):
1638    if not isinstance(tensor, sparse_tensor.SparseTensor):
1639      tensor = ops.convert_to_tensor(tensor, name='tensor')
1640    if tensor.dtype != tf_type:
1641      raise TypeError(
1642          f'{_message_prefix(message)}{getattr(tensor, "name", "tensor")}'
1643          f' must be of type {tf_type!r}; got {tensor.dtype!r}')
1644
1645    return control_flow_ops.no_op('statically_determined_correct_type')
1646
1647
1648def _dimension_sizes(x):
1649  """Gets the dimension sizes of a tensor `x`.
1650
1651  If a size can be determined statically it is returned as an integer,
1652  otherwise as a tensor.
1653
1654  If `x` is a scalar it is treated as rank 1 size 1.
1655
1656  Args:
1657    x: A `Tensor`.
1658
1659  Returns:
1660    Dimension sizes.
1661  """
1662  dynamic_shape = array_ops.shape(x)
1663  rank = x.get_shape().rank
1664  rank_is_known = rank is not None
1665  if rank_is_known and rank == 0:
1666    return (1,)
1667  if rank_is_known and rank > 0:
1668    static_shape = x.get_shape().as_list()
1669    sizes = [
1670        int(size) if size is not None else dynamic_shape[i]
1671        for i, size in enumerate(static_shape)
1672    ]
1673    return sizes
1674  has_rank_zero = math_ops.equal(array_ops.rank(x), 0)
1675  return control_flow_ops.cond(
1676      has_rank_zero, lambda: array_ops.constant([1]), lambda: dynamic_shape)
1677
1678
1679def _symbolic_dimension_sizes(symbolic_shape):
1680  # If len(symbolic_shape) == 0 construct a tuple
1681  if not symbolic_shape:
1682    return tuple([1])
1683
1684  return symbolic_shape
1685
1686
1687def _has_known_value(dimension_size):
1688  not_none = dimension_size is not None
1689  try:
1690    int(dimension_size)
1691    can_be_parsed_as_int = True
1692  except (ValueError, TypeError):
1693    can_be_parsed_as_int = False
1694  return not_none and can_be_parsed_as_int
1695
1696
1697def _is_symbol_for_any_size(symbol):
1698  return symbol in [None, '.']
1699
1700
1701_TensorDimSizes = collections.namedtuple(
1702    '_TensorDimSizes',
1703    ['x', 'unspecified_dim', 'actual_sizes', 'symbolic_sizes'])
1704
1705
1706@tf_export('debugging.assert_shapes', v1=[])
1707@dispatch.add_dispatch_support
1708def assert_shapes_v2(shapes, data=None, summarize=None, message=None,
1709                     name=None):
1710  """Assert tensor shapes and dimension size relationships between tensors.
1711
1712  This Op checks that a collection of tensors shape relationships
1713  satisfies given constraints.
1714
1715  Example:
1716
1717  >>> n = 10
1718  >>> q = 3
1719  >>> d = 7
1720  >>> x = tf.zeros([n,q])
1721  >>> y = tf.ones([n,d])
1722  >>> param = tf.Variable([1.0, 2.0, 3.0])
1723  >>> scalar = 1.0
1724  >>> tf.debugging.assert_shapes([
1725  ...  (x, ('N', 'Q')),
1726  ...  (y, ('N', 'D')),
1727  ...  (param, ('Q',)),
1728  ...  (scalar, ()),
1729  ... ])
1730
1731  >>> tf.debugging.assert_shapes([
1732  ...   (x, ('N', 'D')),
1733  ...   (y, ('N', 'D'))
1734  ... ])
1735  Traceback (most recent call last):
1736  ...
1737  ValueError: ...
1738
1739  If `x`, `y`, `param` or `scalar` does not have a shape that satisfies
1740  all specified constraints, `message`, as well as the first `summarize` entries
1741  of the first encountered violating tensor are printed, and
1742  `InvalidArgumentError` is raised.
1743
1744  Size entries in the specified shapes are checked against other entries by
1745  their __hash__, except:
1746    - a size entry is interpreted as an explicit size if it can be parsed as an
1747      integer primitive.
1748    - a size entry is interpreted as *any* size if it is None or '.'.
1749
1750  If the first entry of a shape is `...` (type `Ellipsis`) or '*' that indicates
1751  a variable number of outer dimensions of unspecified size, i.e. the constraint
1752  applies to the inner-most dimensions only.
1753
1754  Scalar tensors and specified shapes of length zero (excluding the 'inner-most'
1755  prefix) are both treated as having a single dimension of size one.
1756
1757  Args:
1758    shapes: dictionary with (`Tensor` to shape) items, or a list of
1759      (`Tensor`, shape) tuples. A shape must be an iterable.
1760    data: The tensors to print out if the condition is False.  Defaults to error
1761      message and first few entries of the violating tensor.
1762    summarize: Print this many entries of the tensor.
1763    message: A string to prefix to the default message.
1764    name: A name for this operation (optional).  Defaults to "assert_shapes".
1765
1766  Raises:
1767    ValueError:  If static checks determine any shape constraint is violated.
1768  """
1769  assert_shapes(
1770      shapes, data=data, summarize=summarize, message=message, name=name)
1771
1772
1773@tf_export(v1=['debugging.assert_shapes'])
1774@dispatch.add_dispatch_support
1775def assert_shapes(shapes, data=None, summarize=None, message=None, name=None):
1776  """Assert tensor shapes and dimension size relationships between tensors.
1777
1778  This Op checks that a collection of tensors shape relationships
1779  satisfies given constraints.
1780
1781  Example:
1782
1783  >>> n = 10
1784  >>> q = 3
1785  >>> d = 7
1786  >>> x = tf.zeros([n,q])
1787  >>> y = tf.ones([n,d])
1788  >>> param = tf.Variable([1.0, 2.0, 3.0])
1789  >>> scalar = 1.0
1790  >>> tf.debugging.assert_shapes([
1791  ...  (x, ('N', 'Q')),
1792  ...  (y, ('N', 'D')),
1793  ...  (param, ('Q',)),
1794  ...  (scalar, ()),
1795  ... ])
1796
1797  >>> tf.debugging.assert_shapes([
1798  ...   (x, ('N', 'D')),
1799  ...   (y, ('N', 'D'))
1800  ... ])
1801  Traceback (most recent call last):
1802  ...
1803  ValueError: ...
1804
1805  Example of adding a dependency to an operation:
1806
1807  ```python
1808  with tf.control_dependencies([tf.assert_shapes(shapes)]):
1809    output = tf.matmul(x, y, transpose_a=True)
1810  ```
1811
1812  If `x`, `y`, `param` or `scalar` does not have a shape that satisfies
1813  all specified constraints, `message`, as well as the first `summarize` entries
1814  of the first encountered violating tensor are printed, and
1815  `InvalidArgumentError` is raised.
1816
1817  Size entries in the specified shapes are checked against other entries by
1818  their __hash__, except:
1819    - a size entry is interpreted as an explicit size if it can be parsed as an
1820      integer primitive.
1821    - a size entry is interpreted as *any* size if it is None or '.'.
1822
1823  If the first entry of a shape is `...` (type `Ellipsis`) or '*' that indicates
1824  a variable number of outer dimensions of unspecified size, i.e. the constraint
1825  applies to the inner-most dimensions only.
1826
1827  Scalar tensors and specified shapes of length zero (excluding the 'inner-most'
1828  prefix) are both treated as having a single dimension of size one.
1829
1830  Args:
1831    shapes: A list of (`Tensor`, `shape`) tuples, wherein `shape` is the
1832      expected shape of `Tensor`. See the example code above. The `shape` must
1833      be an iterable. Each element of the iterable can be either a concrete
1834      integer value or a string that abstractly represents the dimension.
1835      For example,
1836        - `('N', 'Q')` specifies a 2D shape wherein the first and second
1837          dimensions of shape may or may not be equal.
1838        - `('N', 'N', 'Q')` specifies a 3D shape wherein the first and second
1839          dimensions are equal.
1840        - `(1, 'N')` specifies a 2D shape wherein the first dimension is
1841          exactly 1 and the second dimension can be any value.
1842      Note that the abstract dimension letters take effect across different
1843      tuple elements of the list. For example,
1844      `tf.debugging.assert_shapes([(x, ('N', 'A')), (y, ('N', 'B'))]` asserts
1845      that both `x` and `y` are rank-2 tensors and their first dimensions are
1846      equal (`N`).
1847      `shape` can also be a `tf.TensorShape`.
1848    data: The tensors to print out if the condition is False.  Defaults to error
1849      message and first few entries of the violating tensor.
1850    summarize: Print this many entries of the tensor.
1851    message: A string to prefix to the default message.
1852    name: A name for this operation (optional).  Defaults to "assert_shapes".
1853
1854  Returns:
1855    Op raising `InvalidArgumentError` unless all shape constraints are
1856    satisfied.
1857    If static checks determine all constraints are satisfied, a `no_op` is
1858    returned.
1859
1860  Raises:
1861    ValueError:  If static checks determine any shape constraint is violated.
1862  """
1863  # If the user manages to assemble a dict containing tensors (possible in
1864  # Graph mode only), make sure we still accept that.
1865  if isinstance(shapes, dict):
1866    shapes = shapes.items()
1867
1868  message_prefix = _message_prefix(message)
1869  with ops.name_scope(name, 'assert_shapes', [shapes, data]):
1870    # Shape specified as None implies no constraint
1871    shape_constraints = [(x if isinstance(x, sparse_tensor.SparseTensor) else
1872                          ops.convert_to_tensor(x), s)
1873                         for x, s in shapes if s is not None]
1874
1875    executing_eagerly = context.executing_eagerly()
1876
1877    def tensor_name(x):
1878      if executing_eagerly or isinstance(x, sparse_tensor.SparseTensor):
1879        return _shape_and_dtype_str(x)
1880      return x.name
1881
1882    tensor_dim_sizes = []
1883    for tensor, symbolic_shape in shape_constraints:
1884      is_iterable = (
1885          hasattr(symbolic_shape, '__iter__') or
1886          hasattr(symbolic_shape, '__getitem__')  # For Python 2 compat.
1887      )
1888      if not is_iterable:
1889        raise ValueError(
1890            '%s'
1891            'Tensor %s.  Specified shape must be an iterable.  '
1892            'An iterable has the attribute `__iter__` or `__getitem__`.  '
1893            'Received specified shape: %s' %
1894            (message_prefix, tensor_name(tensor), symbolic_shape))
1895
1896      # We convert this into a tuple to handle strings, lists and numpy arrays
1897      symbolic_shape_tuple = tuple(symbolic_shape)
1898
1899      tensors_specified_innermost = False
1900      for i, symbol in enumerate(symbolic_shape_tuple):
1901        if symbol not in [Ellipsis, '*']:
1902          continue
1903
1904        if i != 0:
1905          raise ValueError(
1906              '%s'
1907              'Tensor %s specified shape index %d.  '
1908              'Symbol `...` or `*` for a variable number of '
1909              'unspecified dimensions is only allowed as the first entry' %
1910              (message_prefix, tensor_name(tensor), i))
1911
1912        tensors_specified_innermost = True
1913
1914      # Only include the size of the specified dimensions since the 0th symbol
1915      # is either ellipsis or *
1916      tensor_dim_sizes.append(
1917          _TensorDimSizes(
1918              tensor, tensors_specified_innermost, _dimension_sizes(tensor),
1919              _symbolic_dimension_sizes(
1920                  symbolic_shape_tuple[1:]
1921                  if tensors_specified_innermost else symbolic_shape_tuple)))
1922
1923    rank_assertions = []
1924    for sizes in tensor_dim_sizes:
1925      rank = len(sizes.symbolic_sizes)
1926      rank_zero_or_one = rank in [0, 1]
1927      if sizes.unspecified_dim:
1928        if rank_zero_or_one:
1929          # No assertion of rank needed as `x` only need to have rank at least
1930          # 0. See elif rank_zero_or_one case comment.
1931          continue
1932        assertion = assert_rank_at_least(
1933            x=sizes.x,
1934            rank=rank,
1935            data=data,
1936            summarize=summarize,
1937            message=message,
1938            name=name)
1939      elif rank_zero_or_one:
1940        # Rank 0 is treated as rank 1 size 1, i.e. there is
1941        # no distinction between the two in terms of rank.
1942        # See _dimension_sizes.
1943        assertion = assert_rank_in(
1944            x=sizes.x,
1945            ranks=[0, 1],
1946            data=data,
1947            summarize=summarize,
1948            message=message,
1949            name=name)
1950      else:
1951        assertion = assert_rank(
1952            x=sizes.x,
1953            rank=rank,
1954            data=data,
1955            summarize=summarize,
1956            message=message,
1957            name=name)
1958      rank_assertions.append(assertion)
1959
1960    size_assertions = []
1961    size_specifications = {}
1962    for sizes in tensor_dim_sizes:
1963      for i, size_symbol in enumerate(sizes.symbolic_sizes):
1964
1965        if _is_symbol_for_any_size(size_symbol):
1966          # Size specified as any implies no constraint
1967          continue
1968
1969        if sizes.unspecified_dim:
1970          tensor_dim = i - len(sizes.symbolic_sizes)
1971        else:
1972          tensor_dim = i
1973
1974        if size_symbol in size_specifications or _has_known_value(size_symbol):
1975          if _has_known_value(size_symbol):
1976            specified_size = int(size_symbol)
1977            size_check_message = 'Specified explicitly'
1978          else:
1979            specified_size, specified_by_y, specified_at_dim = (
1980                size_specifications[size_symbol])
1981            size_check_message = (
1982                'Specified by tensor %s dimension %d' %
1983                (tensor_name(specified_by_y), specified_at_dim))
1984
1985          # This is extremely subtle. If actual_sizes is dynamic, we must
1986          # make sure a control dependency is inserted here so that this slice
1987          # can not execute until the rank is asserted to be enough for the
1988          # slice to not fail.
1989          with ops.control_dependencies(rank_assertions):
1990            actual_size = sizes.actual_sizes[tensor_dim]
1991          if _has_known_value(actual_size) and _has_known_value(specified_size):
1992            if int(actual_size) != int(specified_size):
1993              raise ValueError(
1994                  '%s%s.  Tensor %s dimension %s must have size %d.  '
1995                  'Received size %d, shape %s' %
1996                  (message_prefix, size_check_message, tensor_name(sizes.x),
1997                   tensor_dim, specified_size, actual_size,
1998                   sizes.x.get_shape()))
1999            # No dynamic assertion needed
2000            continue
2001
2002          condition = math_ops.equal(
2003              ops.convert_to_tensor(actual_size),
2004              ops.convert_to_tensor(specified_size))
2005          data_ = data
2006          if data is None:
2007            data_ = [
2008                message_prefix, size_check_message,
2009                'Tensor %s dimension' % tensor_name(sizes.x), tensor_dim,
2010                'must have size', specified_size, 'Received shape: ',
2011                array_ops.shape(sizes.x)
2012            ]
2013          size_assertions.append(
2014              control_flow_ops.Assert(condition, data_, summarize=summarize))
2015        else:
2016          # Not sure if actual_sizes is a constant, but for safety, guard
2017          # on rank. See explanation above about actual_sizes need for safety.
2018          with ops.control_dependencies(rank_assertions):
2019            size = sizes.actual_sizes[tensor_dim]
2020          size_specifications[size_symbol] = (size, sizes.x, tensor_dim)
2021
2022  # Ensure both assertions actually occur.
2023  with ops.control_dependencies(rank_assertions):
2024    shapes_assertion = control_flow_ops.group(size_assertions)
2025
2026  return shapes_assertion
2027
2028
2029# pylint: disable=line-too-long
2030def _get_diff_for_monotonic_comparison(x):
2031  """Gets the difference x[1:] - x[:-1]."""
2032  x = array_ops.reshape(x, [-1])
2033  if not is_numeric_tensor(x):
2034    raise TypeError('Expected x to be numeric, instead found: %s' % x)
2035
2036  # If x has less than 2 elements, there is nothing to compare.  So return [].
2037  is_shorter_than_two = math_ops.less(array_ops.size(x), 2)
2038  short_result = lambda: ops.convert_to_tensor([], dtype=x.dtype)
2039
2040  # With 2 or more elements, return x[1:] - x[:-1]
2041  s_len = array_ops.shape(x) - 1
2042  diff = lambda: array_ops.strided_slice(x, [1], [1] + s_len)- array_ops.strided_slice(x, [0], s_len)
2043  return control_flow_ops.cond(is_shorter_than_two, short_result, diff)
2044
2045
2046@tf_export(
2047    'debugging.is_numeric_tensor',
2048    v1=['debugging.is_numeric_tensor', 'is_numeric_tensor'])
2049@deprecation.deprecated_endpoints('is_numeric_tensor')
2050def is_numeric_tensor(tensor):
2051  """Returns `True` if the elements of `tensor` are numbers.
2052
2053  Specifically, returns `True` if the dtype of `tensor` is one of the following:
2054
2055  * `tf.float16`
2056  * `tf.float32`
2057  * `tf.float64`
2058  * `tf.int8`
2059  * `tf.int16`
2060  * `tf.int32`
2061  * `tf.int64`
2062  * `tf.uint8`
2063  * `tf.uint16`
2064  * `tf.uint32`
2065  * `tf.uint64`
2066  * `tf.qint8`
2067  * `tf.qint16`
2068  * `tf.qint32`
2069  * `tf.quint8`
2070  * `tf.quint16`
2071  * `tf.complex64`
2072  * `tf.complex128`
2073  * `tf.bfloat16`
2074
2075  Returns `False` if `tensor` is of a non-numeric type or if `tensor` is not
2076  a `tf.Tensor` object.
2077  """
2078  return isinstance(tensor, ops.Tensor) and tensor.dtype in NUMERIC_TYPES
2079
2080
2081@tf_export(
2082    'math.is_non_decreasing',
2083    v1=[
2084        'math.is_non_decreasing', 'debugging.is_non_decreasing',
2085        'is_non_decreasing'
2086    ])
2087@dispatch.add_dispatch_support
2088@deprecation.deprecated_endpoints('debugging.is_non_decreasing',
2089                                  'is_non_decreasing')
2090def is_non_decreasing(x, name=None):
2091  """Returns `True` if `x` is non-decreasing.
2092
2093  Elements of `x` are compared in row-major order.  The tensor `[x[0],...]`
2094  is non-decreasing if for every adjacent pair we have `x[i] <= x[i+1]`.
2095  If `x` has less than two elements, it is trivially non-decreasing.
2096
2097  See also:  `is_strictly_increasing`
2098
2099  >>> x1 = tf.constant([1.0, 1.0, 3.0])
2100  >>> tf.math.is_non_decreasing(x1)
2101  <tf.Tensor: shape=(), dtype=bool, numpy=True>
2102  >>> x2 = tf.constant([3.0, 1.0, 2.0])
2103  >>> tf.math.is_non_decreasing(x2)
2104  <tf.Tensor: shape=(), dtype=bool, numpy=False>
2105
2106  Args:
2107    x: Numeric `Tensor`.
2108    name: A name for this operation (optional).  Defaults to "is_non_decreasing"
2109
2110  Returns:
2111    Boolean `Tensor`, equal to `True` iff `x` is non-decreasing.
2112
2113  Raises:
2114    TypeError: if `x` is not a numeric tensor.
2115  """
2116  with ops.name_scope(name, 'is_non_decreasing', [x]):
2117    diff = _get_diff_for_monotonic_comparison(x)
2118    # When len(x) = 1, diff = [], less_equal = [], and reduce_all([]) = True.
2119    zero = ops.convert_to_tensor(0, dtype=diff.dtype)
2120    return math_ops.reduce_all(math_ops.less_equal(zero, diff))
2121
2122
2123@tf_export(
2124    'math.is_strictly_increasing',
2125    v1=[
2126        'math.is_strictly_increasing', 'debugging.is_strictly_increasing',
2127        'is_strictly_increasing'
2128    ])
2129@dispatch.add_dispatch_support
2130@deprecation.deprecated_endpoints('debugging.is_strictly_increasing',
2131                                  'is_strictly_increasing')
2132def is_strictly_increasing(x, name=None):
2133  """Returns `True` if `x` is strictly increasing.
2134
2135  Elements of `x` are compared in row-major order.  The tensor `[x[0],...]`
2136  is strictly increasing if for every adjacent pair we have `x[i] < x[i+1]`.
2137  If `x` has less than two elements, it is trivially strictly increasing.
2138
2139  See also:  `is_non_decreasing`
2140
2141  >>> x1 = tf.constant([1.0, 2.0, 3.0])
2142  >>> tf.math.is_strictly_increasing(x1)
2143  <tf.Tensor: shape=(), dtype=bool, numpy=True>
2144  >>> x2 = tf.constant([3.0, 1.0, 2.0])
2145  >>> tf.math.is_strictly_increasing(x2)
2146  <tf.Tensor: shape=(), dtype=bool, numpy=False>
2147
2148  Args:
2149    x: Numeric `Tensor`.
2150    name: A name for this operation (optional).
2151      Defaults to "is_strictly_increasing"
2152
2153  Returns:
2154    Boolean `Tensor`, equal to `True` iff `x` is strictly increasing.
2155
2156  Raises:
2157    TypeError: if `x` is not a numeric tensor.
2158  """
2159  with ops.name_scope(name, 'is_strictly_increasing', [x]):
2160    diff = _get_diff_for_monotonic_comparison(x)
2161    # When len(x) = 1, diff = [], less = [], and reduce_all([]) = True.
2162    zero = ops.convert_to_tensor(0, dtype=diff.dtype)
2163    return math_ops.reduce_all(math_ops.less(zero, diff))
2164
2165
2166def _assert_same_base_type(items, expected_type=None):
2167  r"""Asserts all items are of the same base type.
2168
2169  Args:
2170    items: List of graph items (e.g., `Variable`, `Tensor`, `SparseTensor`,
2171        `Operation`, or `IndexedSlices`). Can include `None` elements, which
2172        will be ignored.
2173    expected_type: Expected type. If not specified, assert all items are
2174        of the same base type.
2175
2176  Returns:
2177    Validated type, or none if neither expected_type nor items provided.
2178
2179  Raises:
2180    ValueError: If any types do not match.
2181  """
2182  original_expected_type = expected_type
2183  mismatch = False
2184  for item in items:
2185    if item is not None:
2186      item_type = item.dtype.base_dtype
2187      if not expected_type:
2188        expected_type = item_type
2189      elif expected_type != item_type:
2190        mismatch = True
2191        break
2192  if mismatch:
2193    # Loop back through and build up an informative error message (this is very
2194    # slow, so we don't do it unless we found an error above).
2195    expected_type = original_expected_type
2196    original_item_str = None
2197    for item in items:
2198      if item is not None:
2199        item_type = item.dtype.base_dtype
2200        if not expected_type:
2201          expected_type = item_type
2202          original_item_str = item.name if hasattr(item, 'name') else str(item)
2203        elif expected_type != item_type:
2204          raise ValueError('%s, type=%s, must be of the same type (%s)%s.' % (
2205              item.name if hasattr(item, 'name') else str(item),
2206              item_type, expected_type,
2207              (' as %s' % original_item_str) if original_item_str else ''))
2208    return expected_type  # Should be unreachable
2209  else:
2210    return expected_type
2211
2212
2213@tf_export(
2214    'debugging.assert_same_float_dtype',
2215    v1=['debugging.assert_same_float_dtype', 'assert_same_float_dtype'])
2216@dispatch.add_dispatch_support
2217@deprecation.deprecated_endpoints('assert_same_float_dtype')
2218def assert_same_float_dtype(tensors=None, dtype=None):
2219  """Validate and return float type based on `tensors` and `dtype`.
2220
2221  For ops such as matrix multiplication, inputs and weights must be of the
2222  same float type. This function validates that all `tensors` are the same type,
2223  validates that type is `dtype` (if supplied), and returns the type. Type must
2224  be a floating point type. If neither `tensors` nor `dtype` is supplied,
2225  the function will return `dtypes.float32`.
2226
2227  Args:
2228    tensors: Tensors of input values. Can include `None` elements, which will be
2229        ignored.
2230    dtype: Expected type.
2231
2232  Returns:
2233    Validated type.
2234
2235  Raises:
2236    ValueError: if neither `tensors` nor `dtype` is supplied, or result is not
2237        float, or the common type of the inputs is not a floating point type.
2238  """
2239  if tensors:
2240    dtype = _assert_same_base_type(tensors, dtype)
2241  if not dtype:
2242    dtype = dtypes.float32
2243  elif not dtype.is_floating:
2244    raise ValueError('Expected floating point type, got %s.' % dtype)
2245  return dtype
2246
2247
2248@tf_export('debugging.assert_scalar', v1=[])
2249@dispatch.add_dispatch_support
2250def assert_scalar_v2(tensor, message=None, name=None):
2251  """Asserts that the given `tensor` is a scalar.
2252
2253  This function raises `ValueError` unless it can be certain that the given
2254  `tensor` is a scalar. `ValueError` is also raised if the shape of `tensor` is
2255  unknown.
2256
2257  This is always checked statically, so this method returns nothing.
2258
2259  Args:
2260    tensor: A `Tensor`.
2261    message: A string to prefix to the default message.
2262    name:  A name for this operation. Defaults to "assert_scalar"
2263
2264  Raises:
2265    ValueError: If the tensor is not scalar (rank 0), or if its shape is
2266      unknown.
2267  """
2268  assert_scalar(tensor=tensor, message=message, name=name)
2269
2270
2271@tf_export(v1=['debugging.assert_scalar', 'assert_scalar'])
2272@dispatch.add_dispatch_support
2273@deprecation.deprecated_endpoints('assert_scalar')
2274def assert_scalar(tensor, name=None, message=None):
2275  """Asserts that the given `tensor` is a scalar (i.e. zero-dimensional).
2276
2277  This function raises `ValueError` unless it can be certain that the given
2278  `tensor` is a scalar. `ValueError` is also raised if the shape of `tensor` is
2279  unknown.
2280
2281  Args:
2282    tensor: A `Tensor`.
2283    name:  A name for this operation. Defaults to "assert_scalar"
2284    message: A string to prefix to the default message.
2285
2286  Returns:
2287    The input tensor (potentially converted to a `Tensor`).
2288
2289  Raises:
2290    ValueError: If the tensor is not scalar (rank 0), or if its shape is
2291      unknown.
2292  """
2293  with ops.name_scope(name, 'assert_scalar', [tensor]) as name_scope:
2294    tensor = ops.convert_to_tensor(tensor, name=name_scope)
2295    shape = tensor.get_shape()
2296    message = _message_prefix(message)
2297    if shape.ndims != 0:
2298      if context.executing_eagerly():
2299        raise ValueError('%sExpected scalar shape, saw shape: %s.'
2300                         % (message, shape,))
2301      else:
2302        raise ValueError('%sExpected scalar shape for %s, saw shape: %s.'
2303                         % (message, tensor.name, shape))
2304    return tensor
2305
2306
2307def _message_prefix(message):
2308  if message:
2309    return '%s.  ' % message
2310  return ''
2311
2312
2313@tf_export('ensure_shape')
2314@dispatch.add_dispatch_support
2315def ensure_shape(x, shape, name=None):
2316  """Updates the shape of a tensor and checks at runtime that the shape holds.
2317
2318  When executed, this operation asserts that the input tensor `x`'s shape
2319  is compatible with the `shape` argument.
2320  See `tf.TensorShape.is_compatible_with` for details.
2321
2322  >>> x = tf.constant([[1, 2, 3],
2323  ...                  [4, 5, 6]])
2324  >>> x = tf.ensure_shape(x, [2, 3])
2325
2326  Use `None` for unknown dimensions:
2327
2328  >>> x = tf.ensure_shape(x, [None, 3])
2329  >>> x = tf.ensure_shape(x, [2, None])
2330
2331  If the tensor's shape is not compatible with the `shape` argument, an error
2332  is raised:
2333
2334  >>> x = tf.ensure_shape(x, [5])
2335  Traceback (most recent call last):
2336  ...
2337  tf.errors.InvalidArgumentError: Shape of tensor dummy_input [3] is not
2338    compatible with expected shape [5]. [Op:EnsureShape]
2339
2340  During graph construction (typically tracing a `tf.function`),
2341  `tf.ensure_shape` updates the static-shape of the **result** tensor by
2342  merging the two shapes. See `tf.TensorShape.merge_with` for details.
2343
2344  This is most useful when **you** know a shape that can't be determined
2345  statically by TensorFlow.
2346
2347  The following trivial `tf.function` prints the input tensor's
2348  static-shape before and after `ensure_shape` is applied.
2349
2350  >>> @tf.function
2351  ... def f(tensor):
2352  ...   print("Static-shape before:", tensor.shape)
2353  ...   tensor = tf.ensure_shape(tensor, [None, 3])
2354  ...   print("Static-shape after:", tensor.shape)
2355  ...   return tensor
2356
2357  This lets you see the effect of `tf.ensure_shape` when the function is traced:
2358  >>> cf = f.get_concrete_function(tf.TensorSpec([None, None]))
2359  Static-shape before: (None, None)
2360  Static-shape after: (None, 3)
2361
2362  >>> cf(tf.zeros([3, 3])) # Passes
2363  >>> cf(tf.constant([1, 2, 3])) # fails
2364  Traceback (most recent call last):
2365  ...
2366  InvalidArgumentError:  Shape of tensor x [3] is not compatible with expected shape [3,3].
2367
2368  The above example raises `tf.errors.InvalidArgumentError`, because `x`'s
2369  shape, `(3,)`, is not compatible with the `shape` argument, `(None, 3)`
2370
2371  Inside a `tf.function` or `v1.Graph` context it checks both the buildtime and
2372  runtime shapes. This is stricter than `tf.Tensor.set_shape` which only
2373  checks the buildtime shape.
2374
2375  Note: This differs from `tf.Tensor.set_shape` in that it sets the static shape
2376  of the resulting tensor and enforces it at runtime, raising an error if the
2377  tensor's runtime shape is incompatible with the specified shape.
2378  `tf.Tensor.set_shape` sets the static shape of the tensor without enforcing it
2379  at runtime, which may result in inconsistencies between the statically-known
2380  shape of tensors and the runtime value of tensors.
2381
2382  For example, of loading images of a known size:
2383
2384  >>> @tf.function
2385  ... def decode_image(png):
2386  ...   image = tf.image.decode_png(png, channels=3)
2387  ...   # the `print` executes during tracing.
2388  ...   print("Initial shape: ", image.shape)
2389  ...   image = tf.ensure_shape(image,[28, 28, 3])
2390  ...   print("Final shape: ", image.shape)
2391  ...   return image
2392
2393  When tracing a function, no ops are being executed, shapes may be unknown.
2394  See the [Concrete Functions Guide](https://www.tensorflow.org/guide/concrete_function)
2395  for details.
2396
2397  >>> concrete_decode = decode_image.get_concrete_function(
2398  ...     tf.TensorSpec([], dtype=tf.string))
2399  Initial shape:  (None, None, 3)
2400  Final shape:  (28, 28, 3)
2401
2402  >>> image = tf.random.uniform(maxval=255, shape=[28, 28, 3], dtype=tf.int32)
2403  >>> image = tf.cast(image,tf.uint8)
2404  >>> png = tf.image.encode_png(image)
2405  >>> image2 = concrete_decode(png)
2406  >>> print(image2.shape)
2407  (28, 28, 3)
2408
2409  >>> image = tf.concat([image,image], axis=0)
2410  >>> print(image.shape)
2411  (56, 28, 3)
2412  >>> png = tf.image.encode_png(image)
2413  >>> image2 = concrete_decode(png)
2414  Traceback (most recent call last):
2415  ...
2416  tf.errors.InvalidArgumentError:  Shape of tensor DecodePng [56,28,3] is not
2417    compatible with expected shape [28,28,3].
2418
2419  Caution: if you don't use the result of `tf.ensure_shape`, the check may not
2420  run.
2421
2422  >>> @tf.function
2423  ... def bad_decode_image(png):
2424  ...   image = tf.image.decode_png(png, channels=3)
2425  ...   # the `print` executes during tracing.
2426  ...   print("Initial shape: ", image.shape)
2427  ...   # BAD: forgot to use the returned tensor.
2428  ...   tf.ensure_shape(image,[28, 28, 3])
2429  ...   print("Final shape: ", image.shape)
2430  ...   return image
2431
2432  >>> image = bad_decode_image(png)
2433  Initial shape:  (None, None, 3)
2434  Final shape:  (None, None, 3)
2435  >>> print(image.shape)
2436  (56, 28, 3)
2437
2438  Args:
2439    x: A `Tensor`.
2440    shape: A `TensorShape` representing the shape of this tensor, a
2441      `TensorShapeProto`, a list, a tuple, or None.
2442    name: A name for this operation (optional). Defaults to "EnsureShape".
2443
2444  Returns:
2445    A `Tensor`. Has the same type and contents as `x`.
2446
2447  Raises:
2448    tf.errors.InvalidArgumentError: If `shape` is incompatible with the shape
2449    of `x`.
2450  """
2451  if not isinstance(shape, tensor_shape.TensorShape):
2452    shape = tensor_shape.TensorShape(shape)
2453
2454  return array_ops.ensure_shape(x, shape, name=name)
2455
2456
@ops.RegisterGradient('EnsureShape')
def _ensure_shape_grad(op, grad):
  """Gradient function registered for the `EnsureShape` op.

  `EnsureShape` produces a tensor with the same type and contents as its
  input; it only checks and records the shape. The gradient with respect to
  that input is therefore the incoming gradient, passed through untouched.

  Args:
    op: The forward `EnsureShape` operation. It is not consulted.
    grad: The gradient flowing back into the op's output.

  Returns:
    `grad`, unchanged, as the gradient for the op's input.
  """
  del op  # Not needed: the gradient is a straight pass-through.
  return grad
2461