# mypy: ignore-errors

import copy
import logging
import os
import pickle
import random
from contextlib import contextmanager
from functools import partial
from typing import Callable, Union

import sympy

import torch
import torch.fx as fx
import torch.nn as nn
import torch.utils._pytree as pytree
from torch import SymInt
from torch._decomp import get_decompositions
from torch.fx.experimental.symbolic_shapes import bind_symbols

from .aot_autograd import aot_function, aot_module, make_boxed_compiler
from .compile_utils import strip_overloads
from .partitioners import (
    default_partition,
    draw_graph,
    min_cut_rematerialization_partition,
)


log = logging.getLogger(__name__)


# These canonicalizations are needed here (and not as decompositions), as the
# ops we're trying to canonicalize to are CompositeImplicitAutograd.
def _canonicalize(fx_g):
    for node in fx_g.graph.find_nodes(
        op="call_function", target=torch.ops.aten._to_copy
    ):
        node.target = torch.ops.aten.to
    fx_g.recompile()
    return fx_g


@contextmanager
def _disable_jit_autocast():
    old_jit_autocast_flag = torch._C._jit_set_autocast_mode(False)
    try:
        yield
    finally:
        torch._C._jit_set_autocast_mode(old_jit_autocast_flag)


@make_boxed_compiler
def ts_compile(fx_g: fx.GraphModule, inps) -> Callable:
    """
    Compiles the :attr:`fx_g` with the TorchScript compiler.

    .. warning::
        This API is experimental and likely to change.

    Args:
        fx_g(fx.GraphModule): The input Fx graph module to be compiled.

    Returns:
        Torch scripted model.
    """

    with _disable_jit_autocast():
        strip_overloads(fx_g)

        for node in fx_g.graph.find_nodes(
            op="call_function", target=torch.ops.aten._to_copy
        ):
            if len(node.args) == 1 and len(node.kwargs) == 1 and "dtype" in node.kwargs:
                node.target = torch.ops.aten.to

        for node in fx_g.graph.nodes:
            new_kwargs = {}
            for k, v in node.kwargs.items():
                if isinstance(v, torch.device):
                    v = v.type
                new_kwargs[k] = v
            node.kwargs = new_kwargs

        fx_g.graph.lint()

        fx_g.recompile()

        f = torch.jit.script(fx_g)

        torch._C._jit_pass_remove_mutation(f.graph)

        f = torch.jit.freeze(f.eval())
        f = torch.jit.optimize_for_inference(f)
        if not any(isinstance(t, torch._subclasses.FakeTensor) for t in inps):
            f(*inps)
    return f


def _draw_graph_compile(fx_g, _, name, clear_meta=True):
    print(fx_g.code)
    draw_graph(fx_g, name, clear_meta=clear_meta)
    return fx_g


def draw_graph_compile(name):
    return make_boxed_compiler(partial(_draw_graph_compile, name=name))


@make_boxed_compiler
def nop(fx_g: fx.GraphModule, _) -> Callable:
    """
    Returns the :attr:`fx_g` Fx graph module as it is. This is a no-op compiler
    and can be used to check accuracy.

    .. warning::
        This API is experimental and likely to change.

    """
    return fx_g
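

# Usage sketch (illustrative only, kept as a comment so nothing runs at import
# time): the boxed compilers above are passed to ``aot_function`` as the
# forward/backward compilers.  ``fn`` and its input are placeholder names.
#
#     def fn(x):
#         return torch.sin(x).sum()
#
#     compiled_fn = aot_function(fn, fw_compiler=ts_compile, bw_compiler=nop)
#     compiled_fn(torch.randn(4, requires_grad=True)).backward()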


class DebugInterpreter(fx.Interpreter):
    def run(self, *args):
        self.symbol_mapping = bind_symbols(self.module, *args)
        return super().run(*args)

    def run_node(self, n):
        def subst_symint(ni):
            if not isinstance(ni, SymInt):
                return ni
            r = sympy.expand(ni.node.expr.xreplace(self.symbol_mapping))
            assert r.is_number, r
            return int(r)

        def subst_symint_tuple(nis):
            return tuple(subst_symint(ni) for ni in nis)

        def check_significant_strides(a, b):
            if subst_symint(a.numel()) > 0:
                for idx in range(a.ndim):
                    if (
                        subst_symint(a.stride(idx)) != b.stride(idx)
                        and subst_symint(a.size(idx)) > 1
                    ):
                        return False
            return True

        def check(nv, rv, desc):
            assert callable(desc)
            assert nv.dtype == rv.dtype, f"{desc()}: {nv.dtype} != {rv.dtype}"
            assert (
                subst_symint_tuple(nv.size()) == rv.size()
            ), f"{desc()}: {nv.size()} aka {subst_symint_tuple(nv.size())} != {rv.size()}"
            same_strides = check_significant_strides(nv, rv)
            assert (
                same_strides
            ), f"{desc()}: {nv.stride()} aka {subst_symint_tuple(nv.stride())} != {rv.stride()}"

        r = super().run_node(n)
        if "val" in n.meta:
            n_vals, n_spec = pytree.tree_flatten(n.meta["val"])
            r_vals, r_spec = pytree.tree_flatten(r)
            # TODO: There is some sort of problem where we record that an
            # operator returned a tuple/list, and then later it turns out the
            # real version of the operator returned a list/tuple. Need to
            # figure out what's actually going on here; the error itself is
            # harmless enough as we only getitem out the outputs.
            # assert n_spec == r_spec, f"{n_spec} != {r_spec}"
            assert len(n_vals) == len(r_vals), f"{len(n_vals)} != {len(r_vals)}"
            for i, nv, rv in zip(range(len(n_vals)), n_vals, r_vals):
                if not isinstance(rv, torch.Tensor):
                    continue
                check(nv, rv, lambda: f"output {i} where {self.symbol_mapping}")
        return r


@make_boxed_compiler
def debug_nop(fx_g: fx.GraphModule, _) -> Callable:
    """
    Returns a (slow) interpreter over the FX graph module that also checks
    various debugging properties (e.g., that tracing strides matched real
    strides).
    """
    return DebugInterpreter(fx_g).run


@make_boxed_compiler
def simple_ts_compile(fx_g, _):
    strip_overloads(fx_g)
    f = torch.jit.script(fx_g)
    f = torch.jit.freeze(f.eval())
    return f


def nnc_jit(f):
    return aot_function(f, simple_ts_compile)


aten = torch.ops.aten
default_decompositions = {
    aten.detach,
    aten.gelu_backward,
    aten.leaky_relu_backward,
    aten.sigmoid_backward,
    aten.threshold_backward,
    aten.hardtanh_backward,
    aten.hardsigmoid_backward,
    aten.hardswish_backward,
    aten.tanh_backward,
    aten.silu_backward,
    aten.elu_backward,
    aten.cudnn_batch_norm,
    aten.cudnn_batch_norm_backward,
    aten.masked_fill.Scalar,
    aten.masked_fill.Tensor,
    aten.elu,
    aten.leaky_relu,
    aten.hardtanh,
    aten.hardswish,
    aten.hardsigmoid,
    aten.conj_physical,
    aten.is_same_size,
}

default_decompositions = get_decompositions(default_decompositions)


@make_boxed_compiler
def print_compile(fx_g, _):
    print(fx_g.code)
    return fx_g


def memory_efficient_fusion(
    fn: Union[Callable, nn.Module],
    **kwargs,
):
    """
    Wrapper function over :func:`aot_function` and :func:`aot_module` to perform
    memory efficient fusion. It uses the
    :func:`min_cut_rematerialization_partition` partitioner to perform efficient
    recomputation. It uses NVFuser to compile the generated forward and backward
    graphs.

    .. warning::
        This API is experimental and likely to change.

    Args:
        fn (Union[Callable, nn.Module]): A Python function or a ``nn.Module``
            that takes one or more arguments. Must return one or more Tensors.
        **kwargs: Any other overrides you want to make to the settings

    Returns:
        Returns a ``Callable`` or ``nn.Module`` that retains the eager behavior
        of the original :attr:`fn`, but whose forward and backward graphs have
        gone through recomputation optimizations, and the graphs have been
        compiled with nvfuser.

    """
    config = {
        "fw_compiler": ts_compile,
        "bw_compiler": ts_compile,
        "partition_fn": min_cut_rematerialization_partition,
        "decompositions": default_decompositions,
    }
    config.update(kwargs)
    if isinstance(fn, torch.nn.Module):
        return aot_module(fn, **config)
    else:
        return aot_function(fn, **config)
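

# Usage sketch (illustrative only, kept as a comment so nothing runs at import
# time): ``MyModule`` and the CUDA device are placeholders; any callable or
# ``nn.Module`` that returns Tensors works, and the compiler/partitioner
# defaults can be overridden through ``**kwargs``.
#
#     mod = MyModule().cuda()
#     fused_mod = memory_efficient_fusion(mod)
#     out = fused_mod(torch.randn(8, 16, device="cuda"))
#     out.sum().backward()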


def debug_compile(fx_g, inps):
    fx_g.to_folder("foo")
    print(
        f"""
##############################################################
# To minimize FX graph, copy and paste the below and run it #
##############################################################

import torch
import torch.fx as fx
from functorch.compile import minifier, check_nvfuser_subprocess, check_nvfuser_correctness_subprocess

inps = {[(i.shape, i.dtype) for i in inps]}
inps = [torch.ones(shape, dtype=dtype, device='cuda') for (shape, dtype) in inps]
from foo import FxModule
mod = FxModule().cuda()

with torch.jit.fuser("fuser2"):
    # check_nvfuser_subprocess can be replaced with check_nvfuser_correctness_subprocess
    minifier(fx.symbolic_trace(mod), inps, check_nvfuser_subprocess)
"""
    )
    from foo import FxModule

    FxModule().cuda()(*inps)

    return ts_compile(fx_g, inps)


graph_index = 0


def get_inputs(input_data_path):
    """
    Return random inputs matching the inputs meta generated by _save_fx_default.
    """
    inputs = []
    with open(input_data_path, "rb") as f:
        inputs_meta = pickle.load(f)
    for meta in inputs_meta:
        if len(meta) == 1:
            type = meta[0]
            input = type(random.random())
        else:
            type, shape, stride, dtype, device = meta
            if dtype in {
                torch.int,
                torch.int32,
                torch.int64,
                torch.bool,
                torch.uint8,
                int,
                float,
            }:
                input = torch.randint(0, 1, shape, dtype=dtype, device=device)
            else:
                input = torch.rand(shape, dtype=dtype, device=device)
        inputs.append(input)
    return inputs
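

# Usage sketch (illustrative only, kept as a comment so nothing runs at import
# time): rebuilding random inputs from an ``.input`` metadata file written by
# ``_save_fx_default`` below.  The path is a placeholder.
#
#     inps = get_inputs("dumps/mymodel/mymodel_forward_0/mymodel_forward_0.input")
#     # ``inps`` is a list of tensors (or Python scalars) matching the recorded
#     # (type, shape, stride, dtype, device) metadata.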


def _save_fx_default(current_name, folder_name, dump_example_input, gm, example_inputs):
    """
    The forward, backward, and joint computation graphs will be stored in
    {folder_name}/{current_name}/{current_name}_forward_{graph_index},
    {folder_name}/{current_name}/{current_name}_backward_{graph_index}, and
    {folder_name}/{current_name}/{current_name}_joint_{graph_index} respectively.
    The input shapes of the graphs will be stored in the .input files.
    These files can be loaded with pickle,
    and are lists of entries of the format (type, shape, stride, dtype, device).
    In the case of type = int or float, the entry is just (type,).
    For the joint graph input, it is a nested list [[], []]
    where the two inner lists have the same format.
    If dump_example_input is True, example_inputs will be stored in a .pt file.
    Since each function might produce multiple graphs,
    graph_index is used to distinguish the different graphs.
    """
    from functorch.compile import aot_module_simplified

    def get_input_meta(args):
        input_meta = []
        if len(args) > 0 and isinstance(args[0], tuple):  # joint input
            input_meta += get_input_meta(args[0])
            input_meta += get_input_meta(args[1])
            return input_meta
        for arg in args:
            if type(arg) == int or type(arg) == float:
                input_meta.append((type(arg),))
            else:
                input_meta.append(
                    (type(arg), arg.shape, arg.stride(), arg.dtype, arg.device)
                )
        return input_meta

    def graph_saver_helper(gm_to_save, args, type_name):
        global graph_index
        if len(gm_to_save.graph.nodes) == 0:
            log.log(
                logging.WARNING,
                "No nodes in graph %s_%s_%s.",
                current_name,
                type_name,
                graph_index,
            )
            return

        gm = copy.deepcopy(gm_to_save)
        gm.graph.set_codegen(torch.fx.graph.CodeGen())  # remove codegen
        gm.recompile()

        input_meta = get_input_meta(args)

        os.makedirs(f"{folder_name}/{current_name}", exist_ok=True)
        gm.to_folder(
            f"{folder_name}/{current_name}/{current_name}_{type_name}_{graph_index}"
        )
        with open(
            f"{folder_name}/{current_name}/{current_name}_{type_name}_{graph_index}/{current_name}_{type_name}_{graph_index}.input",  # noqa: B950
            "wb",
        ) as f:
            pickle.dump(input_meta, f)
        if dump_example_input:
            torch.save(
                args,
                f"{folder_name}/{current_name}/{current_name}_{type_name}_{graph_index}/{current_name}_{type_name}_{graph_index}.pt",  # noqa: B950
            )

    def graph_saver_forward(gm, fw_args):
        graph_saver_helper(gm, fw_args, "forward")
        return gm

    def graph_saver_backward(gm, bw_args):
        graph_saver_helper(gm, bw_args, "backward")
        global graph_index
        graph_index += 1
        return gm

    def graph_saver_joint(gm, joint_args):
        graph_saver_helper(gm, joint_args, "joint")
        return default_partition(gm, joint_args)

    return aot_module_simplified(
        gm,
        example_inputs,
        fw_compiler=graph_saver_forward,
        bw_compiler=graph_saver_backward,
        partition_fn=graph_saver_joint,
        decompositions=default_decompositions,
    )


# WARNING: This isn't tested anywhere!!
def graph_dumper_aot(current_name, folder_name, dump_example_input=False):
    """
    Dump the forward, backward, and joint computation graphs.
    Example Usage:
    save_fx_func = graph_dumper_aot(current_name, folder_name, dump_example_input=False)
    optimize_ctx = torchdynamo.optimize(
        save_fx_func
    )
    with torch.enable_grad():
        with optimize_ctx:
            result = forward_and_backward_pass(model, example_inputs)
    """
    global graph_index
    graph_index = 0
    return partial(_save_fx_default, current_name, folder_name, dump_example_input)
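

# Usage sketch (illustrative only, kept as a comment so nothing runs at import
# time): reloading a graph dumped by ``graph_dumper_aot``.  The folder written
# by ``gm.to_folder`` exposes an ``FxModule`` class (the same pattern used in
# ``debug_compile`` above); ``dumps`` and ``mymodel`` are placeholder names.
#
#     import sys
#     sys.path.insert(0, "dumps/mymodel")
#     from mymodel_forward_0 import FxModule
#     mod = FxModule()
#     inps = get_inputs("dumps/mymodel/mymodel_forward_0/mymodel_forward_0.input")
#     out = mod(*inps)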