1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Converter for logical expressions, e.g. `a and b -> tf.logical_and(a, b)`.""" 16 17import gast 18 19from tensorflow.python.autograph.core import converter 20from tensorflow.python.autograph.pyct import parser 21from tensorflow.python.autograph.pyct import templates 22 23# TODO(mdan): Properly extract boolean ops according to lazy eval rules. 24# Note that this isn't completely safe either, because tensors may have control 25# dependencies. 26# Note that for loops that should be done after the loop was converted to 27# tf.while_loop so that the expanded conditionals are properly scoped. 28 29# Used to signal that an operand is safe for non-lazy evaluation. 
# Used to signal that an operand is safe for non-lazy evaluation.
# NOTE(review): not referenced by the code in this file; presumably consumed
# elsewhere in autograph — confirm before removing.
SAFE_BOOLEAN_OPERAND = 'SAFE_BOOLEAN_OPERAND'


# gast logical operator types -> name of the autograph overload replacing them.
LOGICAL_OPERATORS = {
    gast.And: 'ag__.and_',
    gast.Not: 'ag__.not_',
    gast.Or: 'ag__.or_',
}

# gast equality operator types -> name of the autograph overload replacing
# them. Only used when the EQUALITY_OPERATORS feature is enabled (see
# `LogicalExpressionTransformer._overload_of`).
EQUALITY_OPERATORS = {
    gast.Eq: 'ag__.eq',
    gast.NotEq: 'ag__.not_eq',
}


class LogicalExpressionTransformer(converter.Base):
  """Converts logical expressions to corresponding TF calls.

  `and`/`or`/`not` are always rewritten to `ag__.and_`/`ag__.or_`/`ag__.not_`.
  `==`/`!=` are rewritten to `ag__.eq`/`ag__.not_eq` only when the
  EQUALITY_OPERATORS feature is enabled. All other operators are kept as-is.
  """

  def _overload_of(self, operator):
    """Returns the overload function name for `operator`, or None.

    Args:
      operator: a gast operator node (e.g. `gast.And()`, `gast.Eq()`).

    Returns:
      The dotted name (str) of the `ag__` function that replaces the
      operator, or None if the operator should be left untouched.
    """
    op_type = type(operator)
    if op_type in LOGICAL_OPERATORS:
      return LOGICAL_OPERATORS[op_type]
    # Equality overloads are opt-in via the converter feature flag.
    if self.ctx.user.options.uses(converter.Feature.EQUALITY_OPERATORS):
      if op_type in EQUALITY_OPERATORS:
        return EQUALITY_OPERATORS[op_type]
    return None

  def _as_lambda(self, expr):
    """Wraps `expr` in a zero-argument lambda, deferring its evaluation."""
    return templates.replace_as_expression('lambda: expr', expr=expr)

  def _as_binary_function(self, func_name, arg1, arg2):
    """Builds the call expression `func_name(arg1, arg2)`.

    Args:
      func_name: dotted function name as a string, e.g. 'ag__.and_'.
      arg1: first argument AST node.
      arg2: second argument AST node.
    """
    return templates.replace_as_expression(
        'func_name(arg1, arg2)',
        func_name=parser.parse_expression(func_name),
        arg1=arg1,
        arg2=arg2)

  def _as_binary_operation(self, op, arg1, arg2):
    """Builds a single comparison `arg1 <op> arg2` from a gast compare op."""
    template = templates.replace_as_expression(
        'arg1 is arg2',  # Note: `is` will be replaced with `op` below.
        arg1=arg1,
        arg2=arg2)
    # The template parses to a Compare node; swap its placeholder operator.
    template.ops[0] = op
    return template

  def _as_unary_function(self, func_name, arg):
    """Builds the call expression `func_name(arg)`."""
    return templates.replace_as_expression(
        'func_name(arg)', func_name=parser.parse_expression(func_name), arg=arg)

  def _process_binop(self, op, left, right):
    """Converts one `left <op> right` comparison, overloading it if needed."""
    overload = self._overload_of(op)
    if overload is None:
      return self._as_binary_operation(op, left, right)
    return self._as_binary_function(overload, left, right)

  def visit_Compare(self, node):
    """Converts (possibly chained) comparisons.

    Repeated comparisons are converted to conjunctions:
      a < b < c  ->  ag__.and_(lambda: a < b, lambda: b < c)
    Longer chains nest left-associatively. The middle operand node is reused
    as the left side of the next comparison, so it appears in both generated
    comparisons (whereas plain Python evaluates it only once).
    """
    node = self.generic_visit(node)

    left = node.left
    op_tree = None
    # Iterate the pairs directly instead of draining a list with pop(0),
    # which was quadratic in the chain length.
    for op, right in zip(node.ops, node.comparators):
      comparison = self._process_binop(op, left, right)
      if op_tree is None:
        op_tree = comparison
      else:
        op_tree = self._as_binary_function('ag__.and_',
                                           self._as_lambda(op_tree),
                                           self._as_lambda(comparison))
      left = right

    # A Compare node always carries at least one operator.
    assert op_tree is not None
    return op_tree

  def visit_UnaryOp(self, node):
    """Converts `not x` to `ag__.not_(x)`; other unary ops are unchanged."""
    node = self.generic_visit(node)

    overload = self._overload_of(node.op)
    if overload is None:
      return node

    return self._as_unary_function(overload, node.operand)

  def visit_BoolOp(self, node):
    """Converts `and`/`or` chains to nested lazy overload calls.

    `a and b and c` becomes
      ag__.and_(lambda: a, lambda: ag__.and_(lambda: b, lambda: c))
    i.e. operands are folded from the right, each wrapped in a lambda.
    """
    node = self.generic_visit(node)

    values = node.values
    # The operator is the same for every operand pair; look it up once.
    overload = self._overload_of(node.op)
    result = values[-1]
    # Fold right-to-left over a reversed slice instead of popping from
    # `node.values`, which previously emptied the input node's list.
    for operand in reversed(values[:-1]):
      result = self._as_binary_function(
          overload, self._as_lambda(operand), self._as_lambda(result))
    return result


def transform(node, ctx):
  """Applies `LogicalExpressionTransformer` to `node` and returns the result.

  Args:
    node: the AST to convert.
    ctx: converter context passed to the transformer.
  """
  transformer = LogicalExpressionTransformer(ctx)
  return transformer.visit(node)