# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Converter construction support.

This module contains a base class for all converters, as well as supporting
structures. These structures are referred to as contexts.

The class hierarchy is as follows:

    <your converter>
      [extends] converter.Base
        [extends] transformer.Base
          [extends] gast.NodeTransformer
          [uses] transformer.SourceInfo
        [uses] converter.EntityContext
          [uses] converter.ProgramContext
          [uses] transformer.SourceInfo

converter.Base is a specialization of transformer.Base for AutoGraph. It's a
very lightweight subclass that adds a `ctx` attribute holding the corresponding
EntityContext object (see below). Note that converters are not reusable:
calling `visit` a second time at the top level raises a ValueError.

converter.EntityContext contains mutable state associated with an entity that
the converter processes.

converter.ProgramContext contains mutable state across related entities. For
example, when converting several functions that call one another, the
ProgramContext should be shared across these entities.

Below is the overall flow at conversion:

    program_ctx = ProgramContext(<entities to convert>, <global settings>, ...)
    while <program_ctx has more entities to convert>:
      entity, source_info = <get next entity from program_ctx>
      entity_ctx = EntityContext(program_ctx, source_info)
      for <each ConverterClass>:
        converter = ConverterClass(entity_ctx)

        # May update entity_ctx and program_ctx
        entity = converter.visit(entity)

      <add entity's dependencies to program_ctx>

Note that pyct contains a small number of transformers used for static analysis.
These extend transformer.Base, rather than converter.Base, to avoid a
dependency on AutoGraph.
"""

import enum

from tensorflow.python.autograph.pyct import anno
from tensorflow.python.autograph.pyct import ast_util
from tensorflow.python.autograph.pyct import parser
from tensorflow.python.autograph.pyct import templates
from tensorflow.python.autograph.pyct import transformer
from tensorflow.python.util.tf_export import tf_export

# TODO(mdan): These contexts can be refactored into first class objects.
# For example, we could define Program and Entity abstractions that hold on
# to the actual entity and have conversion methods.

# TODO(mdan): Add a test specific to this converter.


@tf_export('autograph.experimental.Feature')
class Feature(enum.Enum):
  """This enumeration represents optional conversion options.

  These conversion options are experimental. They are subject to change without
  notice and offer no guarantees.

  _Example Usage_

  ```python
  optionals = tf.autograph.experimental.Feature.EQUALITY_OPERATORS
  @tf.function(experimental_autograph_options=optionals)
  def f(i):
    if i == 0:  # EQUALITY_OPERATORS allows the use of == here.
      tf.print('i is zero')
  ```

  Attributes:
    ALL: Enable all features.
    AUTO_CONTROL_DEPS: Insert control dependencies in the generated code.
    ASSERT_STATEMENTS: Convert Tensor-dependent assert statements to tf.Assert.
    BUILTIN_FUNCTIONS: Convert builtin functions applied to Tensors to
      their TF counterparts.
    EQUALITY_OPERATORS: Whether to convert the equality operator ('==') to
      tf.math.equal.
    LISTS: Convert list idioms, like initializers, slices, append, etc.
    NAME_SCOPES: Insert name scopes that name ops according to context, like the
      function they were defined in.
  """

  ALL = 'ALL'

  AUTO_CONTROL_DEPS = 'AUTO_CONTROL_DEPS'
  ASSERT_STATEMENTS = 'ASSERT_STATEMENTS'
  BUILTIN_FUNCTIONS = 'BUILTIN_FUNCTIONS'
  EQUALITY_OPERATORS = 'EQUALITY_OPERATORS'
  LISTS = 'LISTS'
  NAME_SCOPES = 'NAME_SCOPES'

  @classmethod
  def all(cls):
    """Returns a tuple that enables all options (including `ALL` itself)."""
    return tuple(cls.__members__.values())

  @classmethod
  def all_but(cls, exclude):
    """Returns a tuple that enables all but the excluded options.

    Args:
      exclude: a single Feature, or a list, tuple, set or frozenset of
        Features, to leave out of the result. `ALL` is always left out.

    Returns:
      Tuple[Feature, ...], the remaining individual features, in definition
      order.
    """
    # frozenset is not a subclass of set, so it must be listed explicitly;
    # otherwise it would be wrapped as a single element and exclude nothing.
    if not isinstance(exclude, (list, tuple, set, frozenset)):
      exclude = (exclude,)
    excluded = frozenset(exclude) | {cls.ALL}
    # Iterate in definition order so the result is deterministic.
    return tuple(f for f in cls.all() if f not in excluded)


STANDARD_OPTIONS = None  # Forward definition.


class ConversionOptions(object):
  """Container for global conversion flags.

  Instances are hashable and compare by value, so they should be treated as
  immutable once constructed.

  Attributes:
    recursive: bool, whether to recursively convert any user functions or
      classes that the converted function may use.
    user_requested: bool, whether the conversion was explicitly requested by
      the user, as opposed to being performed as a result of other logic. This
      value always auto-resets to False in child conversions.
    internal_convert_user_code: bool, internal flag. In child conversions (see
      `call_options`) it is set to the value of `recursive`. Presumably it
      controls whether user code is converted at all; confirm against callers.
    optional_features: FrozenSet[Feature], controls the use of optional
      features in the conversion process. See Feature for available options.
  """

  def __init__(self,
               recursive=False,
               user_requested=False,
               internal_convert_user_code=True,
               optional_features=Feature.ALL):
    """Creates a set of conversion options.

    Args:
      recursive: bool, see class docstring.
      user_requested: bool, see class docstring.
      internal_convert_user_code: bool, see class docstring.
      optional_features: None, a single Feature, or an iterable of Features.
        None means no optional features. The value is stored as a frozenset.
    """
    self.recursive = recursive
    self.user_requested = user_requested
    # TODO(mdan): Rename to conversion_recursion_depth?
    self.internal_convert_user_code = internal_convert_user_code

    # Normalize to a frozenset so that instances can be hashed and compared.
    if optional_features is None:
      optional_features = ()
    elif isinstance(optional_features, Feature):
      optional_features = (optional_features,)
    optional_features = frozenset(optional_features)
    self.optional_features = optional_features

  def as_tuple(self):
    """Returns the fields of this object as a tuple, used for hash/equality."""
    return (self.recursive, self.user_requested,
            self.internal_convert_user_code, self.optional_features)

  def __hash__(self):
    return hash(self.as_tuple())

  def __eq__(self, other):
    # Defer to the other operand instead of failing on foreign types.
    if not isinstance(other, ConversionOptions):
      return NotImplemented
    return self.as_tuple() == other.as_tuple()

  def __str__(self):
    # Sorted so that the output does not depend on set iteration order.
    features = ', '.join(sorted(str(f) for f in self.optional_features))
    return ('ConversionOptions[recursive={}, user_requested={}, '
            'internal_convert_user_code={}, optional_features=({})]').format(
                self.recursive, self.user_requested,
                self.internal_convert_user_code, features)

  def uses(self, feature):
    """Returns True if `feature` is enabled, either directly or via ALL."""
    return (Feature.ALL in self.optional_features or
            feature in self.optional_features)

  def call_options(self):
    """Returns the corresponding options to be used for recursive conversion."""
    return ConversionOptions(
        recursive=self.recursive,
        user_requested=False,
        internal_convert_user_code=self.recursive,
        optional_features=self.optional_features)

  def to_ast(self):
    """Returns a representation of this object as an AST node.

    The AST node encodes a constructor that would create an object with the
    same contents. Options equal to STANDARD_OPTIONS are encoded as the
    expression `ag__.STD` instead.

    Returns:
      ast.Node
    """
    if self == STANDARD_OPTIONS:
      return parser.parse_expression('ag__.STD')

    template = """
      ag__.ConversionOptions(
          recursive=recursive_val,
          user_requested=user_requested_val,
          optional_features=optional_features_val,
          internal_convert_user_code=internal_convert_user_code_val)
    """

    def list_of_features(values):
      # Sorted so that the generated code does not depend on set iteration
      # order, which is not guaranteed to be stable.
      names = ['ag__.{}'.format(str(v)) for v in sorted(values, key=str)]
      # A trailing comma keeps a single feature a tuple literal rather than a
      # parenthesized expression; an empty set yields `()`.
      trailing = ',' if names else ''
      return parser.parse_expression('({}{})'.format(', '.join(names),
                                                     trailing))

    expr_ast = templates.replace(
        template,
        recursive_val=parser.parse_expression(str(self.recursive)),
        user_requested_val=parser.parse_expression(str(self.user_requested)),
        internal_convert_user_code_val=parser.parse_expression(
            str(self.internal_convert_user_code)),
        optional_features_val=list_of_features(self.optional_features))
    return expr_ast[0].value


STANDARD_OPTIONS = ConversionOptions(
    recursive=True,
    user_requested=False,
    internal_convert_user_code=True,
    optional_features=None)


class ProgramContext(object):
  """ProgramContext keeps track of converting function hierarchies.

  Attributes:
    options: ConversionOptions
    autograph_module: Deprecated. Do not use.
  """

  def __init__(self, options, autograph_module=None):
    self.options = options
    self.autograph_module = autograph_module


class Base(transformer.Base):
  """All converters should inherit from this class.

  Attributes:
    ctx: EntityContext
  """

  def __init__(self, ctx):
    super(Base, self).__init__(ctx)

    # Set once the first top-level visit starts; converters are single-use.
    self._used = False
    # Nesting depth of in-progress `visit` calls; 0 means a top-level call.
    self._ast_depth = 0

  def get_definition_directive(self, node, directive, arg, default):
    """Returns the unique directive argument for a symbol.

    See lang/directives.py for details on directives.

    Example:
      # Given a directive in the code:
      ag.foo_directive(bar, baz=1)

      # One can write for an AST node Name(id='bar'):
      get_definition_directive(node, ag.foo_directive, 'baz', None)

    Args:
      node: ast.AST, the node representing the symbol for which the directive
        argument is needed.
      directive: Callable[..., Any], the directive to search.
      arg: str, the directive argument to return.
      default: Any, returned when no definition reaching `node` carries `arg`
        for `directive`.

    Returns:
      The value of `arg` recorded for `directive` on the definitions reaching
      `node` (from its ORIG_DEFINITIONS annotation), or `default` if there is
      none. When several definitions carry a value, they must all match, and
      the first one is returned.

    Raises:
      ValueError: if conflicting annotations have been found
    """
    defs = anno.getanno(node, anno.Static.ORIG_DEFINITIONS, ())
    if not defs:
      return default

    arg_values_found = []
    for def_ in defs:
      if (directive in def_.directives and arg in def_.directives[directive]):
        arg_values_found.append(def_.directives[directive][arg])

    if not arg_values_found:
      return default

    if len(arg_values_found) == 1:
      return arg_values_found[0]

    # If multiple annotations reach the symbol, they must all match. If they do,
    # return any of them.
    first_value = arg_values_found[0]
    for other_value in arg_values_found[1:]:
      if not ast_util.matches(first_value, other_value):
        qn = anno.getanno(node, anno.Basic.QN)
        raise ValueError(
            '%s has ambiguous annotations for %s(%s): %s, %s' %
            (qn, directive.__name__, arg, parser.unparse(other_value).strip(),
             parser.unparse(first_value).strip()))
    return first_value

  def visit(self, node):
    """Visits `node`, allowing only one top-level traversal per converter.

    Nested calls made while a traversal is in progress are permitted.

    Raises:
      ValueError: if this converter was already used for a top-level visit.
    """
    if not self._ast_depth:
      if self._used:
        raise ValueError('converter objects cannot be reused')
      self._used = True

    self._ast_depth += 1
    try:
      return super(Base, self).visit(node)
    finally:
      self._ast_depth -= 1