xref: /aosp_15_r20/external/tensorflow/tensorflow/python/autograph/converters/logical_expressions.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Converter for logical expressions, e.g. `a and b -> tf.logical_and(a, b)`."""
16
17import gast
18
19from tensorflow.python.autograph.core import converter
20from tensorflow.python.autograph.pyct import parser
21from tensorflow.python.autograph.pyct import templates
22
# TODO(mdan): Properly extract boolean ops according to lazy eval rules.
# Note that this isn't completely safe either, because tensors may have control
# dependencies.
# Note that for loops, this should be done after the loop was converted to
# tf.while_loop so that the expanded conditionals are properly scoped.
28
# Marker value signalling that an operand is safe to evaluate non-lazily.
# Not referenced in this module; presumably consumed by other converters.
SAFE_BOOLEAN_OPERAND = 'SAFE_BOOLEAN_OPERAND'


# gast logical-operator node type -> name of the AutoGraph overload that
# replaces it. These operators are always converted.
LOGICAL_OPERATORS = {
    gast.And: 'ag__.and_',
    gast.Not: 'ag__.not_',
    gast.Or: 'ag__.or_',
}

# gast equality-operator node type -> name of its AutoGraph overload.
# Consulted only when the EQUALITY_OPERATORS converter feature is enabled.
EQUALITY_OPERATORS = {
    gast.Eq: 'ag__.eq',
    gast.NotEq: 'ag__.not_eq',
}
43
44
class LogicalExpressionTransformer(converter.Base):
  """Converts logical expressions to corresponding TF calls.

  Rewrites `and`, `or` and `not` into calls to the overloads listed in
  LOGICAL_OPERATORS. When the EQUALITY_OPERATORS converter feature is enabled,
  it also rewrites `==` and `!=` using EQUALITY_OPERATORS. Operands of the
  binary boolean overloads are wrapped in zero-argument lambdas, presumably so
  the overloads can decide whether and when to evaluate them (short-circuit).
  """

  def _overload_of(self, operator):
    """Returns the name of the overload replacing `operator`, or None.

    Args:
      operator: a gast operator node instance, e.g. `gast.And()`.

    Returns:
      A dotted function name such as 'ag__.and_', or None when the operator
      should be left untouched. Equality operators are only overloaded when
      the EQUALITY_OPERATORS feature is enabled in the user options.
    """
    op_type = type(operator)
    if op_type in LOGICAL_OPERATORS:
      return LOGICAL_OPERATORS[op_type]
    if self.ctx.user.options.uses(converter.Feature.EQUALITY_OPERATORS):
      if op_type in EQUALITY_OPERATORS:
        return EQUALITY_OPERATORS[op_type]
    return None

  def _as_lambda(self, expr):
    """Wraps the AST `expr` in a zero-argument lambda: `lambda: expr`."""
    return templates.replace_as_expression('lambda: expr', expr=expr)

  def _as_binary_function(self, func_name, arg1, arg2):
    """Builds the call expression `func_name(arg1, arg2)`.

    Args:
      func_name: dotted function name string, parsed into an expression.
      arg1: AST of the first argument.
      arg2: AST of the second argument.
    """
    return templates.replace_as_expression(
        'func_name(arg1, arg2)',
        func_name=parser.parse_expression(func_name),
        arg1=arg1,
        arg2=arg2)

  def _as_binary_operation(self, op, arg1, arg2):
    """Builds the single comparison expression `arg1 <op> arg2`."""
    template = templates.replace_as_expression(
        'arg1 is arg2',  # Note: `is` will be replaced with `op` below.
        arg1=arg1,
        arg2=arg2)
    template.ops[0] = op
    return template

  def _as_unary_function(self, func_name, arg):
    """Builds the call expression `func_name(arg)`."""
    return templates.replace_as_expression(
        'func_name(arg)', func_name=parser.parse_expression(func_name), arg=arg)

  def _process_binop(self, op, left, right):
    """Returns `left <op> right`, using the op's overload if it has one."""
    overload = self._overload_of(op)
    if overload is None:
      return self._as_binary_operation(op, left, right)
    return self._as_binary_function(overload, left, right)

  def visit_Compare(self, node):
    """Converts a (possibly chained) comparison.

    Repeated comparisons are converted to conjunctions:
      a < b < c   ->   ag__.and_(lambda: a < b, lambda: b < c)
    Longer chains nest to the left. Each comparator becomes the left operand
    of the next comparison.

    NOTE: intermediate comparators (`b` above) therefore appear in two
    comparisons and may be evaluated twice in the generated code. Python
    evaluates them only once in a chained comparison.
    """
    node = self.generic_visit(node)

    left = node.left
    op_tree = None
    for op, right in zip(node.ops, node.comparators):
      binary_comparison = self._process_binop(op, left, right)
      if op_tree is None:
        op_tree = binary_comparison
      else:
        op_tree = self._as_binary_function('ag__.and_',
                                           self._as_lambda(op_tree),
                                           self._as_lambda(binary_comparison))
      left = right

    # A well-formed Compare node always has at least one operator.
    assert op_tree is not None
    return op_tree

  def visit_UnaryOp(self, node):
    """Replaces an overloaded unary operator (e.g. `not x`) with a call."""
    node = self.generic_visit(node)

    overload = self._overload_of(node.op)
    if overload is None:
      return node

    return self._as_unary_function(overload, node.operand)

  def visit_BoolOp(self, node):
    """Converts `a and b and c` into right-nested overload calls.

    The result has the form
      ag__.and_(lambda: a, lambda: ag__.and_(lambda: b, lambda: c)).
    The visited node's `values` list is left intact.
    """
    node = self.generic_visit(node)
    overload = self._overload_of(node.op)  # Same for every operand pair.
    *leading_values, right = node.values
    # Fold from the right so the innermost call combines the last two values.
    for left in reversed(leading_values):
      right = self._as_binary_function(
          overload, self._as_lambda(left), self._as_lambda(right))
    return right
127
128
def transform(node, ctx):
  """Converts the logical expressions in `node` to AutoGraph overload calls.

  Args:
    node: the AST to convert.
    ctx: converter context handed to LogicalExpressionTransformer.

  Returns:
    The AST returned by visiting `node` with a fresh transformer.
  """
  return LogicalExpressionTransformer(ctx).visit(node)
132