# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""smart_cond and related utilities."""

from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.util.tf_export import tf_export


@tf_export("__internal__.smart_cond.smart_cond", v1=[])
def smart_cond(pred, true_fn=None, false_fn=None, name=None):
  """Return either `true_fn()` if predicate `pred` is true else `false_fn()`.

  If `pred` is a bool or has a constant value, we return either `true_fn()`
  or `false_fn()`, otherwise we use `tf.cond` to select between them at run
  time.

  Args:
    pred: A scalar determining whether to return the result of `true_fn` or
      `false_fn`.
    true_fn: The callable to be performed if pred is true.
    false_fn: The callable to be performed if pred is false.
    name: Optional name prefix when using `tf.cond`.

  Returns:
    Tensors returned by the call to either `true_fn` or `false_fn`.

  Raises:
    TypeError: If `true_fn` or `false_fn` is not callable, or if `pred` is not
      a Tensor, a Python bool, or 1 or 0.
  """
  if not callable(true_fn):
    raise TypeError(f"Argument `true_fn` must be callable. Received {true_fn}")
  if not callable(false_fn):
    raise TypeError(
        f"Argument `false_fn` must be callable. Received {false_fn}")

  pred_value = smart_constant_value(pred)
  if pred_value is None:
    # No statically known value: defer the choice to a runtime conditional.
    return control_flow_ops.cond(pred, true_fn=true_fn, false_fn=false_fn,
                                 name=name)
  # The predicate is statically known, so only the selected branch is called.
  return true_fn() if pred_value else false_fn()


def smart_constant_value(pred):
  """Return the bool value for `pred`, or None if `pred` had a dynamic value.

  Args:
    pred: A scalar, either a Python bool, the integer 1 or 0, or a tensor.

  Returns:
    True or False if `pred` has a constant boolean value, None otherwise.
    NOTE(review): for a Tensor the statically evaluated value is returned
    without conversion, so it is presumably a NumPy value rather than a
    Python `bool` -- confirm before relying on identity checks like `is True`.

  Raises:
    TypeError: If `pred` is not a Tensor, a Python bool, or 1 or 0 (this
      includes unhashable inputs such as lists).
  """
  if isinstance(pred, ops.Tensor):
    pred_value = tensor_util.constant_value(pred)
    # TODO(skyewm): consider folding this into tensor_util.constant_value.
    if pred_value is None:
      pred_value = tensor_util.try_evaluate_constant(pred)
    return pred_value

  # Test for `bool` before the 1/0 check: True/False compare equal to 1/0, so
  # with the opposite order this branch could never be reached.
  if isinstance(pred, bool):
    return pred

  try:
    # Accept 1/0 as valid boolean values.
    is_zero_or_one = pred in {0, 1}
  except TypeError:
    # Set membership needs a hashable value; unhashable inputs (lists, arrays,
    # ...) are invalid too and get the descriptive error below instead of a
    # bare "unhashable type" message.
    is_zero_or_one = False
  if is_zero_or_one:
    return bool(pred)

  raise TypeError("Argument `pred` must be a Tensor, or a Python bool, or 1 "
                  f"or 0. Received: pred={pred} of type "
                  f"{type(pred).__name__}")


def smart_case(pred_fn_pairs, default=None, exclusive=False, name="smart_case"):
  """Like tf.case, except attempts to statically evaluate predicates.

  If any predicate in `pred_fn_pairs` is a bool or has a constant value, the
  associated callable will be called or omitted depending on its value.
  Otherwise this functions like tf.case.

  Args:
    pred_fn_pairs: Dict or list of pairs of a boolean scalar tensor and a
      callable which returns a list of tensors.
    default: Optional callable that returns a list of tensors.
    exclusive: True iff at most one predicate is allowed to evaluate to `True`.
    name: A name for this operation (optional).

  Returns:
    The tensors returned by the first pair whose predicate evaluated to True, or
    those returned by `default` if none does.

  Raises:
    TypeError: If `pred_fn_pairs` is not a list/dictionary.
    TypeError: If `pred_fn_pairs` is a list but does not contain 2-tuples.
    TypeError: If `fns[i]` is not callable for any i, or `default` is not
      callable.
  """
  # Delegate to the shared tf.case implementation, using `smart_cond` as the
  # per-predicate conditional so statically known predicates are folded, and
  # allowing plain Python predicates.
  return control_flow_ops._case_helper(  # pylint: disable=protected-access
      smart_cond, pred_fn_pairs, default, exclusive, name,
      allow_python_preds=True)