xref: /aosp_15_r20/external/tensorflow/tensorflow/python/framework/smart_cond.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""smart_cond and related utilities."""
16
17from tensorflow.python.framework import ops
18from tensorflow.python.framework import tensor_util
19from tensorflow.python.ops import control_flow_ops
20from tensorflow.python.util.tf_export import tf_export
21
22
@tf_export("__internal__.smart_cond.smart_cond", v1=[])
def smart_cond(pred, true_fn=None, false_fn=None, name=None):
  """Selects between `true_fn` and `false_fn`, statically when possible.

  When the value of `pred` can be determined now (a Python bool, 1/0, or a
  tensor whose value `smart_constant_value` can fold), only the selected
  callable is invoked and its result returned directly; no conditional op is
  built. Otherwise the choice is deferred to runtime via `tf.cond`.

  Args:
    pred: A scalar determining whether to return the result of `true_fn` or
      `false_fn`.
    true_fn: The callable to be performed if pred is true.
    false_fn: The callable to be performed if pred is false.
    name: Optional name prefix when using `tf.cond`.

  Returns:
    Tensors returned by the call to either `true_fn` or `false_fn`.

  Raises:
    TypeError: If `true_fn` or `false_fn` is not callable, or if `pred` is
      not a Tensor, a Python bool, or 1/0.
  """
  if not callable(true_fn):
    raise TypeError(f"Argument `true_fn` must be callable. Received {true_fn}")
  if not callable(false_fn):
    raise TypeError(
        f"Argument `false_fn` must be callable. Received {false_fn}")

  static_pred = smart_constant_value(pred)
  if static_pred is None:
    # Value only known at runtime: emit a real conditional over both branches.
    return control_flow_ops.cond(
        pred, true_fn=true_fn, false_fn=false_fn, name=name)

  # Value known now: call just the chosen branch.
  chosen_fn = true_fn if static_pred else false_fn
  return chosen_fn()
58
59
def smart_constant_value(pred):
  """Return the bool value for `pred`, or None if `pred` had a dynamic value.

  For tensors, the value is first folded with `tensor_util.constant_value`;
  if that yields None, `tensor_util.try_evaluate_constant` is tried as a
  fallback. Python bools and the integers 1/0 are always static.

  Args:
    pred: A scalar: a Python bool, the integer 1 or 0, or a tensor.

  Returns:
    True or False if `pred` has a constant boolean value, None otherwise.
    For tensors this is whatever the `tensor_util` folding helpers return
    (NOTE(review): likely a numpy value rather than a Python bool — confirm).

  Raises:
    TypeError: If `pred` is not a Tensor, a Python bool, or 1/0. This
      includes unhashable inputs such as lists or numpy arrays.
  """
  if isinstance(pred, ops.Tensor):
    pred_value = tensor_util.constant_value(pred)
    # TODO(skyewm): consider folding this into tensor_util.constant_value.
    if pred_value is None:
      pred_value = tensor_util.try_evaluate_constant(pred)
    return pred_value

  # Checked before the 1/0 test so this branch is actually reachable (True and
  # False are themselves members of {0, 1}); the returned value is unchanged.
  if isinstance(pred, bool):
    return pred

  # Accept 1/0 as valid boolean values. The set-membership test hashes `pred`,
  # which raises TypeError for unhashable inputs (lists, numpy arrays, ...).
  # Treat those as invalid so callers get the descriptive error below rather
  # than a bare "unhashable type" message.
  try:
    is_zero_or_one = pred in {0, 1}
  except TypeError:
    is_zero_or_one = False
  if is_zero_or_one:
    return bool(pred)

  raise TypeError("Argument `pred` must be a Tensor, or a Python bool, or 1 "
                  f"or 0. Received: pred={pred} of type "
                  f"{type(pred).__name__}")
89
90
def smart_case(pred_fn_pairs, default=None, exclusive=False, name="smart_case"):
  """Like tf.case, but predicates with a static value are resolved eagerly.

  The work is delegated to `control_flow_ops._case_helper`, using
  `smart_cond` as the branching primitive and allowing Python predicates.
  Any predicate that is a bool (or 1/0) or has a constant value therefore has
  its callable invoked or skipped directly, without emitting a conditional
  op. All other predicates are handled exactly as `tf.case` would.

  Args:
    pred_fn_pairs: Dict or list of pairs of a boolean scalar tensor and a
                   callable which returns a list of tensors.
    default: Optional callable that returns a list of tensors.
    exclusive: True iff at most one predicate is allowed to evaluate to `True`.
    name: A name for this operation (optional).

  Returns:
    The tensors returned by the first pair whose predicate evaluated to True, or
    those returned by `default` if none does.

  Raises:
    TypeError: If `pred_fn_pairs` is not a list/dictionary.
    TypeError: If `pred_fn_pairs` is a list but does not contain 2-tuples.
    TypeError: If `fns[i]` is not callable for any i, or `default` is not
               callable.
  """
  case_helper = control_flow_ops._case_helper  # pylint: disable=protected-access
  # smart_cond replaces tf.cond as the branch builder so static predicates
  # short-circuit; allow_python_preds admits plain Python bools.
  return case_helper(
      smart_cond,
      pred_fn_pairs,
      default,
      exclusive,
      name,
      allow_python_preds=True,
  )
118