# mypy: ignore-errors

import copy
import logging
import os
import pickle
import random
from contextlib import contextmanager
from functools import partial
from typing import Callable, Union

import sympy

import torch
import torch.fx as fx
import torch.nn as nn
import torch.utils._pytree as pytree
from torch import SymInt
from torch._decomp import get_decompositions
from torch.fx.experimental.symbolic_shapes import bind_symbols

from .aot_autograd import aot_function, aot_module, make_boxed_compiler
from .compile_utils import strip_overloads
from .partitioners import (
    default_partition,
    draw_graph,
    min_cut_rematerialization_partition,
)


log = logging.getLogger(__name__)


# These canonicalizations are needed here (and not as decompositions), as the
# ops we're trying to canonicalize to are CompositeImplicitAutograd.
def _canonicalize(fx_g):
    for node in fx_g.graph.find_nodes(
        op="call_function", target=torch.ops.aten._to_copy
    ):
        node.target = torch.ops.aten.to
    fx_g.recompile()
    return fx_g


@contextmanager
def _disable_jit_autocast():
    old_jit_autocast_flag = torch._C._jit_set_autocast_mode(False)
    try:
        yield
    finally:
        torch._C._jit_set_autocast_mode(old_jit_autocast_flag)


@make_boxed_compiler
def ts_compile(fx_g: fx.GraphModule, inps) -> Callable:
    """
    Compiles the :attr:`fx_g` graph module with the TorchScript compiler.

    .. warning::
        This API is experimental and likely to change.

    Args:
        fx_g(fx.GraphModule): The input FX graph module to be compiled.

    Returns:
        The TorchScript-compiled module.
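
    Example (a minimal sketch; ``fn`` and its input are illustrative, and
    :func:`ts_compile` is normally handed to AOT Autograd as a forward/backward
    compiler rather than called directly)::

        >>> # xdoctest: +SKIP
        >>> # Illustrative only: compile a toy function with TorchScript as the
        >>> # forward and backward compiler via AOT Autograd.
        >>> from functorch.compile import aot_function, ts_compile
        >>> def fn(x):
        ...     return torch.sin(x).cos()
        >>> aot_fn = aot_function(
        ...     fn, fw_compiler=ts_compile, bw_compiler=ts_compile
        ... )
        >>> out = aot_fn(torch.randn(8, requires_grad=True))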
    """

    with _disable_jit_autocast():
        strip_overloads(fx_g)

        for node in fx_g.graph.find_nodes(
            op="call_function", target=torch.ops.aten._to_copy
        ):
            if len(node.args) == 1 and len(node.kwargs) == 1 and "dtype" in node.kwargs:
                node.target = torch.ops.aten.to

        for node in fx_g.graph.nodes:
            new_kwargs = {}
            for k, v in node.kwargs.items():
                if isinstance(v, torch.device):
                    v = v.type
                new_kwargs[k] = v
            node.kwargs = new_kwargs

        fx_g.graph.lint()

        fx_g.recompile()

        f = torch.jit.script(fx_g)

        torch._C._jit_pass_remove_mutation(f.graph)

        f = torch.jit.freeze(f.eval())
        f = torch.jit.optimize_for_inference(f)
        if not any(isinstance(t, torch._subclasses.FakeTensor) for t in inps):
            f(*inps)
    return f


def _draw_graph_compile(fx_g, _, name, clear_meta=True):
    print(fx_g.code)
    draw_graph(fx_g, name, clear_meta=clear_meta)
    return fx_g


def draw_graph_compile(name):
    return make_boxed_compiler(partial(_draw_graph_compile, name=name))


@make_boxed_compiler
def nop(fx_g: fx.GraphModule, _) -> Callable:
    """
    Returns the :attr:`fx_g` Fx graph module as it is. This is a no-op compiler
    and can be used to check accuracy.

    .. warning::
        This API is experimental and likely to change.
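
    Example (a minimal sketch; ``fn`` and ``x`` are illustrative names)::

        >>> # xdoctest: +SKIP
        >>> # Illustrative only: run the traced graph as-is to separate
        >>> # accuracy issues from the backend compiler.
        >>> from functorch.compile import aot_function, nop
        >>> def fn(x):
        ...     return x.sin().sum()
        >>> aot_fn = aot_function(fn, fw_compiler=nop)
        >>> x = torch.randn(4, requires_grad=True)
        >>> torch.testing.assert_close(aot_fn(x), fn(x))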

    """
    return fx_g


class DebugInterpreter(fx.Interpreter):
    def run(self, *args):
        self.symbol_mapping = bind_symbols(self.module, *args)
        # Propagate the interpreter's outputs to the caller.
        return super().run(*args)

    def run_node(self, n):
        def subst_symint(ni):
            if not isinstance(ni, SymInt):
                return ni
            r = sympy.expand(ni.node.expr.xreplace(self.symbol_mapping))
            assert r.is_number, r
            return int(r)

        def subst_symint_tuple(nis):
            return tuple(subst_symint(ni) for ni in nis)

        def check_significant_strides(a, b):
            if subst_symint(a.numel()) > 0:
                for idx in range(a.ndim):
                    if (
                        subst_symint(a.stride(idx)) != b.stride(idx)
                        and subst_symint(a.size(idx)) > 1
                    ):
                        return False
            return True

        def check(nv, rv, desc):
            assert callable(desc)
            assert nv.dtype == rv.dtype, f"{desc()}: {nv.dtype} != {rv.dtype}"
            assert (
                subst_symint_tuple(nv.size()) == rv.size()
            ), f"{desc()}: {nv.size()} aka {subst_symint_tuple(nv.size())} != {rv.size()}"
            same_strides = check_significant_strides(nv, rv)
            assert (
                same_strides
            ), f"{desc()}: {nv.stride()} aka {subst_symint_tuple(nv.stride())} != {rv.stride()}"

        r = super().run_node(n)
        if "val" in n.meta:
            n_vals, n_spec = pytree.tree_flatten(n.meta["val"])
            r_vals, r_spec = pytree.tree_flatten(r)
            # TODO: There is some sort of problem where we record that an
            # operator returned a tuple/list, and then later it turns out the
            # real version of the operator returned a list/tuple. Need to
            # figure out what's actually going on here, the error itself is
            # harmless enough as we only getitem out the outputs.
            # assert n_spec == r_spec, f"{n_spec} != {r_spec}"
            assert len(n_vals) == len(r_vals), f"{len(n_vals)} != {len(r_vals)}"
            for i, nv, rv in zip(range(len(n_vals)), n_vals, r_vals):
                if not isinstance(rv, torch.Tensor):
                    continue
                check(nv, rv, lambda: f"output {i} where {self.symbol_mapping}")
        return r


@make_boxed_compiler
def debug_nop(fx_g: fx.GraphModule, _) -> Callable:
    """
    Returns a (slow) interpreter over the FX graph module that also checks
    various debugging properties (e.g., that the strides recorded during
    tracing match the real strides).
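
    Example (a minimal sketch; ``fn`` and its input are illustrative, and
    ``debug_nop`` lives in ``torch._functorch.compilers``)::

        >>> # xdoctest: +SKIP
        >>> # Illustrative only: each node's recorded metadata is checked
        >>> # against the real tensors while the graph is interpreted.
        >>> from functorch.compile import aot_function
        >>> from torch._functorch.compilers import debug_nop
        >>> fn = aot_function(lambda x: (x * 2).sum(), fw_compiler=debug_nop)
        >>> out = fn(torch.randn(5, requires_grad=True))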
    """
    return DebugInterpreter(fx_g).run


@make_boxed_compiler
def simple_ts_compile(fx_g, _):
    strip_overloads(fx_g)
    f = torch.jit.script(fx_g)
    f = torch.jit.freeze(f.eval())
    return f


def nnc_jit(f):
    return aot_function(f, simple_ts_compile)


aten = torch.ops.aten
default_decompositions = {
    aten.detach,
    aten.gelu_backward,
    aten.leaky_relu_backward,
    aten.sigmoid_backward,
    aten.threshold_backward,
    aten.hardtanh_backward,
    aten.hardsigmoid_backward,
    aten.hardswish_backward,
    aten.tanh_backward,
    aten.silu_backward,
    aten.elu_backward,
    aten.cudnn_batch_norm,
    aten.cudnn_batch_norm_backward,
    aten.masked_fill.Scalar,
    aten.masked_fill.Tensor,
    aten.elu,
    aten.leaky_relu,
    aten.hardtanh,
    aten.hardswish,
    aten.hardsigmoid,
    aten.conj_physical,
    aten.is_same_size,
}

default_decompositions = get_decompositions(default_decompositions)


@make_boxed_compiler
def print_compile(fx_g, _):
    print(fx_g.code)
    return fx_g


def memory_efficient_fusion(
    fn: Union[Callable, nn.Module],
    **kwargs,
):
    """
    Wrapper function over :func:`aot_function` and :func:`aot_module` to perform
    memory-efficient fusion. It uses the
    :func:`min_cut_rematerialization_partition` partitioner to perform efficient
    recomputation. It uses nvFuser to compile the generated forward and backward
    graphs.

    .. warning::
        This API is experimental and likely to change.

    Args:
        fn (Union[Callable, nn.Module]): A Python function or a ``nn.Module``
            that takes one or more arguments. Must return one or more Tensors.
        **kwargs: Any other overrides you want to make to the settings.

    Returns:
        Returns a ``Callable`` or ``nn.Module`` that retains the eager behavior
        of the original :attr:`fn`, but whose forward and backward graphs have
        gone through recomputation optimizations, and the graphs have been
        compiled with nvFuser.
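
    Example (a minimal sketch; ``fn``, ``x``, and ``y`` are illustrative, and a
    CUDA device is assumed since the fused graphs target the GPU)::

        >>> # xdoctest: +SKIP
        >>> # Illustrative only: fuse a toy pointwise function.
        >>> from functorch.compile import memory_efficient_fusion
        >>> def fn(x, y):
        ...     return (x + y).sin().sum()
        >>> fused_fn = memory_efficient_fusion(fn)
        >>> x = torch.randn(1024, device="cuda", requires_grad=True)
        >>> y = torch.randn(1024, device="cuda", requires_grad=True)
        >>> fused_fn(x, y).backward()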

    """
    config = {
        "fw_compiler": ts_compile,
        "bw_compiler": ts_compile,
        "partition_fn": min_cut_rematerialization_partition,
        "decompositions": default_decompositions,
    }
    config.update(kwargs)
    if isinstance(fn, torch.nn.Module):
        return aot_module(fn, **config)
    else:
        return aot_function(fn, **config)


def debug_compile(fx_g, inps):
    fx_g.to_folder("foo")
    print(
        f"""
##############################################################
# To minimize FX graph, copy and paste the below and run it  #
##############################################################

import torch
import torch.fx as fx
from functorch.compile import minifier, check_nvfuser_subprocess, check_nvfuser_correctness_subprocess

inps = {[(i.shape, i.dtype) for i in inps]}
inps = [torch.ones(shape, dtype=dtype, device='cuda') for (shape, dtype) in inps]
from foo import FxModule
mod = FxModule().cuda()

with torch.jit.fuser("fuser2"):
  # check_nvfuser_subprocess can be replaced with check_nvfuser_correctness_subprocess
  minifier(fx.symbolic_trace(mod), inps, check_nvfuser_subprocess)
"""
    )
    from foo import FxModule

    FxModule().cuda()(*inps)

    return ts_compile(fx_g, inps)


graph_index = 0


def get_inputs(input_data_path):
    """
    Return random inputs matching the inputs meta generated by _save_fx_default.
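
    Example (illustrative; the path assumes a dump produced by
    :func:`graph_dumper_aot` with ``folder_name="dump"`` and
    ``current_name="model_name"``)::

        >>> # xdoctest: +SKIP
        >>> inps = get_inputs(
        ...     "dump/model_name/model_name_forward_0/model_name_forward_0.input"
        ... )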
    """
    inputs = []
    with open(input_data_path, "rb") as f:
        inputs_meta = pickle.load(f)
        for meta in inputs_meta:
            # meta is (type,) for Python scalars, otherwise
            # (type, shape, stride, dtype, device).
            if len(meta) == 1:
                type = meta[0]
                input = type(random.random())
            else:
                type, shape, stride, dtype, device = meta
                if dtype in {
                    torch.int,
                    torch.int32,
                    torch.int64,
                    torch.bool,
                    torch.uint8,
                    int,
                    float,
                }:
                    input = torch.randint(0, 1, shape, dtype=dtype, device=device)
                else:
                    input = torch.rand(shape, dtype=dtype, device=device)
            inputs.append(input)
    return inputs


def _save_fx_default(current_name, folder_name, dump_example_input, gm, example_inputs):
    """
    The forward, backward, and joint computation graphs will be stored in
    {folder_name}/{current_name}/{current_name}_forward_{graph_index},
    {folder_name}/{current_name}/{current_name}_backward_{graph_index}, and
    {folder_name}/{current_name}/{current_name}_joint_{graph_index} respectively.
    The metadata of the graph inputs will be stored in the .input files.
    These files can be loaded with pickle and contain a list of
    (type, shape, stride, dtype, device) tuples.
    In the case of type = int or float, the entry is just (type,).
    For joint graph inputs, the pickle holds a nested list [[], []]
    where the two inner lists have the same format.
    If dump_example_input is True, example_inputs will be stored in a .pt file.
    Since each function might produce multiple graphs,
    graph_index is used to distinguish different graphs.
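
    Example (illustrative; shows how one of the dumped .input files could be
    read back, assuming folder_name="dump" and current_name="model_name")::

        >>> # xdoctest: +SKIP
        >>> import pickle
        >>> path = (
        ...     "dump/model_name/model_name_forward_0/model_name_forward_0.input"
        ... )
        >>> # Each entry is a (type, shape, stride, dtype, device) tuple.
        >>> with open(path, "rb") as f:
        ...     inputs_meta = pickle.load(f)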
    """
    from functorch.compile import aot_module_simplified

    def get_input_meta(args):
        input_meta = []
        if len(args) > 0 and isinstance(args[0], tuple):  # joint input
            input_meta += get_input_meta(args[0])
            input_meta += get_input_meta(args[1])
            return input_meta
        for arg in args:
            if type(arg) == int or type(arg) == float:
                input_meta.append((type(arg),))
            else:
                input_meta.append(
                    (type(arg), arg.shape, arg.stride(), arg.dtype, arg.device)
                )
        return input_meta

    def graph_saver_helper(gm_to_save, args, type_name):
        global graph_index
        if len(gm_to_save.graph.nodes) == 0:
            log.log(
                logging.WARNING,
                "No nodes in graph %s_%s_%s.",
                current_name,
                type_name,
                graph_index,
            )
            return

        gm = copy.deepcopy(gm_to_save)
        gm.graph.set_codegen(torch.fx.graph.CodeGen())  # remove codegen
        gm.recompile()

        input_meta = get_input_meta(args)

        os.makedirs(f"{folder_name}/{current_name}", exist_ok=True)
        gm.to_folder(
            f"{folder_name}/{current_name}/{current_name}_{type_name}_{graph_index}"
        )
        pickle.dump(
            input_meta,
            open(
                f"{folder_name}/{current_name}/{current_name}_{type_name}_{graph_index}/{current_name}_{type_name}_{graph_index}.input",  # noqa: B950
                "wb",
            ),
        )  # noqa: E501
        if dump_example_input:
            torch.save(
                args,
                f"{folder_name}/{current_name}/{current_name}_{type_name}_{graph_index}/{current_name}_{type_name}_{graph_index}.pt",  # noqa: B950
            )  # noqa: E501

    def graph_saver_forward(gm, fw_args):
        graph_saver_helper(gm, fw_args, "forward")
        return gm

    def graph_saver_backward(gm, bw_args):
        graph_saver_helper(gm, bw_args, "backward")
        global graph_index
        graph_index += 1
        return gm

    def graph_saver_joint(gm, joint_args):
        graph_saver_helper(gm, joint_args, "joint")
        return default_partition(gm, joint_args)

    return aot_module_simplified(
        gm,
        example_inputs,
        fw_compiler=graph_saver_forward,
        bw_compiler=graph_saver_backward,
        partition_fn=graph_saver_joint,
        decompositions=default_decompositions,
    )


# WARNING: This isn't tested anywhere!!
def graph_dumper_aot(current_name, folder_name, dump_example_input=False):
    """
    Dump the forward, backward, and joint computation graphs.

    Example Usage:

        save_fx_func = graph_dumper_aot(current_name, folder_name, dump_example_input=False)
        optimize_ctx = torchdynamo.optimize(save_fx_func)
        with torch.enable_grad():
            with optimize_ctx:
                result = forward_and_backward_pass(model, example_inputs)
    """
    global graph_index
    graph_index = 0
    return partial(_save_fx_default, current_name, folder_name, dump_example_input)
446