# xref: /aosp_15_r20/external/pytorch/torch/_inductor/fx_passes/replace_random.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
# mypy: allow-untyped-defs
2import collections
3import logging
4
5import torch
6from torch.fx.passes.graph_transform_observer import GraphTransformObserver
7from torch.fx.passes.shape_prop import _extract_tensor_metadata
8
9from .. import config, inductor_prims
10from ..pattern_matcher import (
11    CallFunctionVarArgs,
12    Match,
13    PatternMatcherPass,
14    register_graph_pattern,
15)
16from ..virtualized import V
17
18
# Module-level logger.
log = logging.getLogger(__name__)
# Registry of graph patterns: the @register_graph_pattern handlers below add
# themselves to it, and replace_random_passes() applies it to a GraphModule.
patterns = PatternMatcherPass()
# Shorthand for the ATen op namespace used when registering patterns.
aten = torch.ops.aten
22
23
def replace_random_passes(gm: torch.fx.GraphModule):
    """Rewrite the random ops in ``gm`` to use backend-native random prims.

    First applies the registered ``patterns`` to the graph, then horizontally
    fuses the seed-creation nodes (under a GraphTransformObserver). Returns
    the sum of the counts reported by those two steps. When
    ``config.fallback_random`` is enabled the graph is left untouched and 0
    is returned.
    """
    if not config.fallback_random:
        num_rewrites = patterns.apply(gm)
        observer = GraphTransformObserver(
            gm, "fuse_seed_creation_pass", config.trace.log_url_for_graph_xform
        )
        with observer:
            num_rewrites += fuse_seed_creation_pass(gm.graph)
        return num_rewrites
    return 0
36
37
38def fuse_seed_creation_pass(graph: torch.fx.Graph):
39    """
40    Horizontally fuse all the seed generation on each device
41
42        a = inductor_seed(dev)
43        b = inductor_seed(dev)
44
45    Becomes:
46        seeds = inductor_seeds(2, dev)
47        a = inductor_lookup_seed(seeds, 0)
48        b = inductor_lookup_seed(seeds, 1)
49
50    We do this because seed creation is entirely launch overhead bound.
51    """
52    device_seeds = collections.defaultdict(list)
53    for node in graph.nodes:
54        if CallFunctionVarArgs(inductor_prims.seed).match(node):
55            device_seeds[node.args[0]].append(node)
56
57    if not device_seeds:
58        return 0
59
60    for device, seeds in device_seeds.items():
61        with graph.inserting_before(seeds[0]):
62            combined = graph.call_function(inductor_prims.seeds, (len(seeds), device))
63            with V.fake_mode:
64                combined.meta["val"] = torch.empty(
65                    [len(seeds)], device=device, dtype=torch.int64
66                )
67                combined.meta["tensor_meta"] = _extract_tensor_metadata(
68                    combined.meta["val"]
69                )
70
71        for idx, seed in enumerate(seeds):
72            with graph.inserting_before(seed):
73                new_seed = graph.call_function(
74                    inductor_prims.lookup_seed, (combined, idx)
75                )
76            seed.replace_all_uses_with(new_seed)
77            new_seed.meta.update(seed.meta)
78            graph.erase_node(seed)
79
80    return len(device_seeds)
81
82
def default_kwargs(device):
    """Return extra keyword arguments to forward to ``inductor_prims.random``.

    Currently no extra arguments are supplied for any device, so a fresh
    empty dict is returned on every call. ``device`` is accepted but unused
    (presumably kept as a per-device hook).
    """
    del device  # unused
    return dict()
85
86
def get_device(device):
    """Return ``device`` unchanged, or PyTorch's current default when it is None.

    The default is discovered by allocating a 0-d tensor with ``torch.empty``
    and reading its ``.device``, so it reflects whatever default is in effect
    at call time. Only ``None`` triggers the fallback (falsy values such as
    ``0`` are returned as-is).
    """
    return torch.empty([]).device if device is None else device
91
92
@register_graph_pattern(CallFunctionVarArgs(aten.rand.default), pass_dict=patterns)
@register_graph_pattern(CallFunctionVarArgs(aten.rand.generator), pass_dict=patterns)
@register_graph_pattern(CallFunctionVarArgs(aten.randn.default), pass_dict=patterns)
@register_graph_pattern(CallFunctionVarArgs(aten.randn.generator), pass_dict=patterns)
def replace_random(
    match: Match,
    size,
    *,
    generator=None,
    dtype=None,
    device=None,
    layout=None,
    pin_memory=None,
):
    """Pattern handler for ``aten.rand`` / ``aten.randn`` (default and generator overloads).

    Calls that pass an explicit ``generator`` are left unreplaced. Otherwise
    the matched call is replaced by
    ``inductor_prims.random(size, inductor_prims.seed(device), mode)`` where
    ``mode`` is ``"rand"`` or ``"randn"`` according to the matched op; the
    result is cast with ``.to(dtype)`` only when ``dtype`` is given. A
    ``None`` device is resolved with ``get_device``. ``layout`` and
    ``pin_memory`` are accepted to mirror the op's keyword arguments but are
    not used.
    """
    if generator is not None:
        return

    mode_by_op = {
        aten.rand: "rand",
        aten.randn: "randn",
    }
    matched_op = match.output_node().target.overloadpacket  # type: ignore[union-attr]
    mode = mode_by_op[matched_op]
    device = get_device(device)

    def replacement(size):
        rng_seed = inductor_prims.seed(device)
        out = inductor_prims.random(size, rng_seed, mode, **default_kwargs(device))
        return out if dtype is None else out.to(dtype)

    match.replace_by_example(replacement, [size])
126
127
@register_graph_pattern(CallFunctionVarArgs(aten.randint.low), pass_dict=patterns)
def replace_randint(
    match: Match,
    low,
    high,
    size,
    *,
    dtype=torch.int64,
    device=None,
    layout=None,
    pin_memory=None,
):
    """Pattern handler for ``aten.randint.low``.

    Replaces the matched call with
    ``inductor_prims.randint(low, high, size, inductor_prims.seed(device))``
    cast via ``.to(dtype)`` (``dtype`` defaults to ``torch.int64``). A
    ``None`` device is resolved with ``get_device``. ``layout`` and
    ``pin_memory`` are accepted to mirror the op's keyword arguments but are
    not used.
    """
    device = get_device(device)

    def replacement(low, high, size):
        seed = inductor_prims.seed(device)
        return inductor_prims.randint(low, high, size, seed).to(dtype)

    match.replace_by_example(replacement, [low, high, size])
146