# mypy: allow-untyped-defs
import functools
import inspect
import itertools
import logging
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import torch
import torch.utils._pytree as pytree


log = logging.getLogger(__name__)
trace_shape_events_log = torch._logging.getArtifactLogger(
    __name__, "trace_shape_events"
)


__all__ = [
    "ShapeEnvEvent",
    "record_shapeenv_event",
    "replay_shape_env_events",
    "FakeTensorMeta",
    "shape_env_check_state_equal",
    "NotEqualError",
]

# [Note: Recording ShapeEnv Events]
# =================================
#
# What is a ShapeEnv event?
# -------------------------
# A ShapeEnv event is any function call (ShapeEnv method or standalone
# function) that modifies the state of the ShapeEnv instance. Such calls
# are recorded alongside their positional and keyword arguments, so that
# they can later be replayed on a different ShapeEnv instance.
#
# See [Note: ShapeEnv State Equality] for what is considered the state
# of a ShapeEnv instance.
#
# What is it for?
# ---------------
# ShapeEnv event recording is used for reconstructing a ShapeEnv at an
# arbitrary point in time.
#
# Being able to replay events like this is useful mainly for translation
# validation bisection: if a ValidationException has been raised, we can
# find the earliest point in time where translation validation fails.
#
# Besides that, it also allows us to inspect the reconstructed instance
# and, for example, check the guards that would actually be issued at
# that point.
#
# What kind of arguments can be stored in an event?
# -------------------------------------------------
# There is no specific rule for what cannot be used as an argument.
# That said, pay special attention to the following cases:
#
# 1. Tensor inputs: there are some tests that check whether the inputs
#    were garbage collected after execution. These will fail if there is
#    an event holding a reference to those inputs.
#
# 2. ShapeEnv arguments: any argument of ShapeEnv type is automatically
#    replaced by the new ShapeEnv instance given at replay time.
#
# 3. SymTypes arguments: they also hold references to a ShapeEnv. So,
#    whenever we see them, we create a new instance, replacing the
#    ShapeEnv reference.
#
# 4. FX nodes: specifically, FX nodes from the FX graph for symbolic
#    shapes. They must be replaced when replaying the event in
#    ShapeEnvEvent.run, since they have to reference nodes from the given
#    instance, and not from the recorded one.
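

# A minimal usage sketch of the recording machinery described above, kept as a
# comment so that it is not executed at import time. Assumptions: event
# recording is enabled for this ShapeEnv (it is tied to translation
# validation); create_unbacked_symint is used only as an example of a method
# decorated with @record_shapeenv_event; check_equal is the ShapeEnv method
# referred to in [Note: ShapeEnv State Equality] below.
#
#     from torch.fx.experimental.symbolic_shapes import ShapeEnv
#
#     env = ShapeEnv()                   # recorded as the constructor event
#     s0 = env.create_unbacked_symint()  # recorded: appended to env.events
#
#     # Rebuild an equivalent ShapeEnv purely from the recorded events
#     # (replay_shape_env_events is defined later in this module).
#     replayed = replay_shape_env_events(env.events)
#
#     # Raises NotEqualError if the replayed state diverges from the original.
#     env.check_equal(replayed)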


# Event class for reconstructing ShapeEnv at an arbitrary point in time.
#
# Represents a method call that mutates ShapeEnv in a way that affects the
# issued guards, when ShapeEnv.produce_guards is called.
@dataclass
class ShapeEnvEvent:
    # ShapeEnv method.
    f: Callable

    # Arguments and keyword arguments called with.
    args: Optional[List[Any]] = None
    kwargs: Optional[Dict[str, Any]] = None

    # List of tracked_fakes at the time the method was called.
    tracked_fakes: Optional[List[Any]] = None

    # Name of the captured event.
    # Used for special handling of particular methods.
    name: Optional[str] = None

    # Replay itself, but using shape_env as self.
    def run(self, shape_env=None) -> Any:
        from torch.fx.experimental.symbolic_shapes import (
            is_symbolic,
            ShapeEnv,
            SymTypes,
        )

        # Special handling for the constructor event.
        if self.f is ShapeEnv:
            assert shape_env is None and self.args is None and self.kwargs is not None
            return ShapeEnv(**self.kwargs)

        assert shape_env is not None
        args = list(self.args or [])
        kwargs = dict(self.kwargs or {})

        # Replace any argument of type ShapeEnv by the given one.
        args, kwargs = pytree.tree_map_only(
            ShapeEnv, lambda _: shape_env, (args, kwargs)
        )

        # Replace any argument of type SymTypes by a new instance,
        # replacing its ShapeEnv reference.
        args, kwargs = pytree.tree_map_only(
            lambda x: isinstance(x, SymTypes) and is_symbolic(x),
            lambda a: type(a)(a.node.with_shape_env(shape_env)),
            (args, kwargs),
        )

        # Convert an FX node from the recorded graph into the node of the
        # same name in the new shape_env's graph.
        def maybe_convert_node(x: Any) -> Any:
            if not isinstance(x, torch.fx.Node):
                # Don't do anything to x if it's not an FX node.
                return x

            # If, at some point, we created an FX node, it means that translation validation is on.
            # It also means we are building an FX graph for symbolic shapes at shape_env.graph, and
            # we are tracking node names at shape_env.name_to_node.
            assert hasattr(shape_env, "name_to_node")
            name_to_node = shape_env.name_to_node  # type: ignore[attr-defined]
            assert x.name in name_to_node
            return name_to_node[x.name]

        # Replace the value of a specific argument (given by position or
        # keyword) with the result of fn.
        def replacearg(index: int, key: str, fn: Callable):
            if index < len(args):
                args[index] = fn(args[index])
            if key in kwargs:
                kwargs[key] = fn(kwargs[key])

        if self.is_create_fx_call_function():
            # ShapeEnv._create_fx_call_function:
            # The "args" parameter is a tuple of FX nodes from the FX graph of the
            # old ShapeEnv. They must be replaced, since a "call_function" FX node
            # with this tuple as argument will be added to the FX graph of the new
            # shape_env.
            replacearg(
                index=2,
                key="args",
                fn=lambda args: tuple(maybe_convert_node(a) for a in args),
            )
        if self.is_evaluate_expr() or self.is_defer_runtime_assert():
            # ShapeEnv.evaluate_expr and ShapeEnv.defer_runtime_assert:
            # The "fx_node" parameter is an (optional) FX node that represents the
            # evaluated expression. It must be replaced, since it will be part of a
            # "call_function" FX node for torch._assert, which will be added to the
            # FX graph of the new shape_env.
            replacearg(index=3, key="fx_node", fn=maybe_convert_node)

        # Actually call the method with the converted arguments.
        return self.f(*args, **kwargs)

    def __str__(self) -> str:
        name = self.name if self.name is not None else self.f.__name__
        return f"event: {name} ({self.args}, {self.kwargs})"

    def is_create_fx_call_function(self) -> bool:
        return self.name == "_create_fx_call_function"

    def is_evaluate_expr(self) -> bool:
        return self.name == "evaluate_expr"

    def is_defer_runtime_assert(self) -> bool:
        return self.name == "defer_runtime_assert"


NEST = 0


# Extracts the ShapeEnv instance from args and kwargs.
# Specifically, it looks for:
# 1. ShapeEnv arguments
# 2. SymInt, SymFloat, or SymBool arguments
# If we find more than one object of any of the above types, we also
# check that they all share the same ShapeEnv instance.
def _extract_shape_env_and_assert_equal(args, kwargs):
    from torch.fx.experimental.symbolic_shapes import is_symbolic, ShapeEnv, SymTypes

    def assert_equal(old: Optional[ShapeEnv], new: ShapeEnv) -> ShapeEnv:
        if old is not None:
            assert old is new, "call with different ShapeEnv"
        return new

    shape_env = None
    for val in itertools.chain(args, kwargs.values()):
        if isinstance(val, ShapeEnv):
            shape_env = assert_equal(shape_env, val)
        if isinstance(val, SymTypes) and is_symbolic(val):
            shape_env = assert_equal(shape_env, val.node.shape_env)

    return shape_env


# Decorator for recording the given function as a replayable event.
#
# This decorator should be used on every function that mutates the state of
# ShapeEnv in some way that affects the resulting issued guards (i.e. when
# ShapeEnv.produce_guards is called).
#
# save_tracked_fakes: saves a snapshot of the TrackedFake list.
# This is used when calling ShapeEnv.produce_guards at arbitrary points in time.
#
# When to save the list of TrackedFake?
# =====================================
# We should save the list of TrackedFake whenever the translation validation
# bisection may actually stop and call produce_guards right after the recorded
# function was replayed. In other words, since the bisection bisects through
# torch._assert calls, we should save it in all methods that add a
# torch._assert call to the symbolic shapes FX graph.
#
# At the moment, there are 2 methods that save the list:
# - ShapeEnv.evaluate_expr
# - ShapeEnv.defer_runtime_assert
def record_shapeenv_event(*, save_tracked_fakes: bool = False) -> Callable:
    def decorator(fn: Callable) -> Callable:
        assert callable(fn)
        args = inspect.getfullargspec(fn).args
        assert args and args[0] == "self", (
            "record_shapeenv_event should only wrap methods on ShapeEnv; refactor your "
            "code so that it calls into a method on ShapeEnv"
        )
        name = fn.__name__

        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            from torch.fx.experimental.symbolic_shapes import ShapeEnv

            assert isinstance(args[0], ShapeEnv)

            global NEST

            trace_shape_events_log.debug(
                "%scall %s(*%r, **%r)", " " * NEST, name, args[1:], kwargs
            )
            NEST += 1

            def retlog(r):
                trace_shape_events_log.debug("%s-> %s", " " * (NEST - 1), r)
                return r

            try:
                if args[0].is_recording:  # type: ignore[has-type]
                    # If ShapeEnv is already recording an event, call the wrapped
                    # function directly.
                    #
                    # NB: here, we skip the check of whether all ShapeEnv instances
                    # are equal, in favor of a faster dispatch.
                    return retlog(fn(*args, **kwargs))

                # Retrieve an instance of ShapeEnv.
                # Assumption: the collection of args and kwargs must not reference
                # different ShapeEnv instances.
                self = _extract_shape_env_and_assert_equal(args, kwargs)

                # If we are calling this function without any ShapeEnv instance
                # alive in its arguments, we don't record and just call the original.
                if self is None:
                    return retlog(fn(*args, **kwargs))

                # Otherwise, start recording and call the function.
                with self._recording():
                    # Take a snapshot of the current tracked_fakes.
                    tracked_fakes = (
                        self._snapshot_tracked_fakes() if save_tracked_fakes else None
                    )
                    # Record the event for 'fn'.
                    event = ShapeEnvEvent(
                        fn, list(args), kwargs, tracked_fakes, name=fn.__name__
                    )
                    # Play the event on this ShapeEnv.
                    # NB: It's important to append the event first, because running
                    # the event can trigger internal events that must be ordered
                    # after this event. However, if an exception happens, we do
                    # NOT want to have the event in the list, so pop it off the
                    # record if an error happened.
                    self.events.append(event)
                    try:
                        return retlog(event.run(self))
                    except Exception:
                        self.events.pop()
                        raise

            except Exception:
                log.error(  # noqa: G201
                    "failed while running %s(*%s, **%s)",
                    name,
                    args[1:],
                    kwargs,
                    exc_info=log.isEnabledFor(logging.INFO),
                )
                raise

            finally:
                NEST -= 1

        return wrapper

    return decorator
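

# A schematic example of how this decorator is applied to ShapeEnv methods.
# The method shown is illustrative only (the real set of decorated methods
# lives in torch.fx.experimental.symbolic_shapes):
#
#     class ShapeEnv:
#         @record_shapeenv_event()
#         def _add_replacement(self, symbol, expr) -> None:
#             # Mutates state that affects produce_guards, so every call is
#             # appended to self.events and can later be replayed.
#             self.replacements[symbol] = expr
#
# Methods that add a torch._assert call to the symbolic shapes FX graph
# (currently ShapeEnv.evaluate_expr and ShapeEnv.defer_runtime_assert) pass
# save_tracked_fakes=True, so that produce_guards can be called right after
# the corresponding event is replayed.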


# Replays the ShapeEnvEvents list.
# It assumes the first event is the constructor call.
# FX nodes stored in the events are translated into nodes of the newly
# created ShapeEnv's graph (see ShapeEnvEvent.run).
def replay_shape_env_events(events):
    from torch.fx.experimental.symbolic_shapes import ShapeEnv

    constructor_event = events[0]
    assert constructor_event.f == ShapeEnv

    # Constructs the new ShapeEnv.
    shape_env = constructor_event.run()

    for event in events[1:]:
        try:
            # Actually replays each event.
            # FX node conversion (see ShapeEnvEvent.run) must happen for every
            # event, since the node list might change after each event is
            # replayed.
            event.run(shape_env)
        except Exception:
            log.error("failed when running event: %s", event)
            raise

    return shape_env


# FakeTensor metadata.
# This is to be used in place of FakeTensor placeholders when calling
# ShapeEnv.produce_guards.
@dataclass
class FakeTensorMeta:
    tensor_size: Tuple[Union[int, torch.SymInt], ...]
    tensor_stride: Tuple[Union[int, torch.SymInt], ...]
    tensor_storage_offset: Union[int, torch.SymInt]
    is_nested: bool

    def size(self) -> Tuple[Union[int, torch.SymInt], ...]:
        return self.tensor_size

    def stride(self) -> Tuple[Union[int, torch.SymInt], ...]:
        return self.tensor_stride

    def storage_offset(self) -> Union[int, torch.SymInt]:
        return self.tensor_storage_offset

    def dim(self) -> int:
        return len(self.tensor_size)

    @staticmethod
    def from_fake(fake) -> "FakeTensorMeta":
        return FakeTensorMeta(
            fake.size(), fake.stride(), fake.storage_offset(), fake.is_nested
        )
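

# Small illustration of how FakeTensorMeta duck-types the tensor methods used
# by produce_guards (the sizes are arbitrary; any real or fake tensor exposing
# size/stride/storage_offset/is_nested works as input to from_fake):
#
#     t = torch.empty(2, 3)
#     meta = FakeTensorMeta.from_fake(t)
#     meta.size()            # torch.Size([2, 3])
#     meta.stride()          # (3, 1)
#     meta.storage_offset()  # 0
#     meta.dim()             # 2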


# [Note: ShapeEnv State Equality]
# ===============================
#
# What is considered ShapeEnv state?
# ----------------------------------
# We consider the state of a ShapeEnv instance to be everything that is not
# in the inlined tuple inside the remove_nonstate_variables function, i.e.
# the fields of ShapeEnv that modify the flow of execution of the program.
#
# So, for example: the replacements field might influence how an expression
# is simplified. That, in turn, may result in a guard being statically known
# (i.e. not added).
#
# On the other hand, var_to_stack only changes what is printed on the screen,
# i.e. it is used only for debugging purposes. Therefore, we should not
# consider it when comparing states.
#
# What to do on NotEqualError?
# ----------------------------
# Here are a few possible causes for getting a NotEqualError raised:
#
# 1. New field that does not belong in the ShapeEnv state.
#    For example: the log field of type ShapeEnvLoggerAdapter. Different
#    ShapeEnv instances will always have different ShapeEnvLoggerAdapter
#    instances, i.e. equality comparison would fail.
#    Solution: add it to the inlined tuple inside the remove_nonstate_variables
#    function inside the check_equal method.
#
# 2. New field that is not directly comparable across instances.
#    For example: the guards field of type List[ShapeGuard]. More specifically,
#    the ShapeGuard type holds an expression and stack information for
#    debugging purposes. When replaying the event on a new ShapeEnv instance,
#    the stack would be different, which would trigger this error.
#    Solution: add a special case to the map_value function inside the
#    check_equal function.
#
# 3. Mutation of ShapeEnv in a function that is not recorded.
#    If a mutation of ShapeEnv state happens inside a function that is not
#    recorded (and no caller in the stack is recorded either), then the
#    replayed ShapeEnv won't capture it.
#    Solution: decorate the function with record_shapeenv_event.
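

# Schematic sketch of how the state comparison defined below is driven (the
# field names and the identity map_value are illustrative; ShapeEnv.check_equal
# supplies the actual non-state field names and value mapping):
#
#     shape_env_check_state_equal(
#         env1,
#         env2,
#         # Fields excluded from the comparison (see the note above).
#         ("log", "var_to_stack"),
#         # Optionally normalize a field value before comparing it.
#         lambda field, value: value,
#     )
#
# On mismatch, NotEqualError is raised, listing each differing field together
# with its (stringified) value on both sides.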


# Checks whether the states of two ShapeEnv instances are equal with respect
# to the guards returned by ShapeEnv.produce_guards.
def shape_env_check_state_equal(env1, env2, non_state_variable_names, map_value):
    # Collect and remove variables that don't necessarily represent the state
    # of a ShapeEnv. Note: we copy the dictionary so that we don't modify the
    # instance itself.
    env1_vars = vars(env1).copy()
    env2_vars = vars(env2).copy()

    for v in non_state_variable_names:
        if v in env1_vars:
            env1_vars.pop(v)
        if v in env2_vars:
            env2_vars.pop(v)

    # Function for transforming the mismatched values into strings.
    # Needed, since the order of dict and set entries might not be the same
    # every time.
    def value_to_str(value: Any) -> str:
        if isinstance(value, dict):
            return (
                "{"
                + ", ".join(f"{k}: {value[k]}" for k in sorted(value.keys(), key=str))
                + "}"
            )
        if isinstance(value, set):
            return "{" + ", ".join(f"{v}" for v in sorted(value)) + "}"
        return str(value)

    # Compares env1_vars with env2_vars.
    # Here, we allow the value of each field to be mapped, so that we appropriately
    # compare the two values.
    def compare_vars(
        map_value: Callable[[str, Any], Any]
    ) -> List[Tuple[str, str, str]]:
        env1_set, env2_set = set(env1_vars), set(env2_vars)

        # First, compare the set of keys in each vars dictionary.
        if env1_set != env2_set:
            raise NotEqualError(
                "field set mismatch:",
                [
                    (
                        "found unique fields:",
                        str(sorted(env1_set - env2_set)),
                        str(sorted(env2_set - env1_set)),
                    ),
                ],
            )

        # Then, sort the keys, and compare the mapped values of each key.
        sorted_keys = list(env1_set)
        sorted_keys.sort()

        mapped_dict = [
            (k, map_value(k, env1_vars[k]), map_value(k, env2_vars[k]))
            for k in sorted_keys
        ]

        # Return a list of tuples representing the fields that did not match,
        # alongside their respective mapped values.
        return [
            (f"{k}: values don't match.", value_to_str(val1), value_to_str(val2))
            for k, val1, val2 in mapped_dict
            if val1 != val2
        ]

    # Accumulate the mismatching fields.
    errors = compare_vars(map_value)

    if len(errors) > 0:
        raise NotEqualError("field values don't match:", errors)


class NotEqualError(Exception):
    def __init__(
        self,
        msg: str,
        mismatched: List[Tuple[str, str, str]],
    ) -> None:
        details = "\n".join(
            [
                "\n".join(
                    [
                        f"==> {inner_msg}",
                        f"  > Left: {str1}",
                        f"  > Right: {str2}",
                    ]
                )
                for inner_msg, str1, str2 in mismatched
            ]
        )

        super().__init__(
            f"""\
ShapeEnv not equal: {msg}

{details}
"""
        )