# mypy: allow-untyped-defs
import collections
import logging

import torch
from torch.fx.passes.graph_transform_observer import GraphTransformObserver
from torch.fx.passes.shape_prop import _extract_tensor_metadata

from .. import config, inductor_prims
from ..pattern_matcher import (
    CallFunctionVarArgs,
    Match,
    PatternMatcherPass,
    register_graph_pattern,
)
from ..virtualized import V


log = logging.getLogger(__name__)
patterns = PatternMatcherPass()
aten = torch.ops.aten


def replace_random_passes(gm: torch.fx.GraphModule):
    """Rewrite ``gm`` so random ops use backend-native primitives.

    Returns the number of graph changes made; 0 when the config asks to
    fall back to eager-mode randomness.
    """
    if config.fallback_random:
        # User wants bit-exact eager RNG semantics; leave the graph alone.
        return 0

    num_changes = patterns.apply(gm)
    with GraphTransformObserver(
        gm, "fuse_seed_creation_pass", config.trace.log_url_for_graph_xform
    ):
        num_changes += fuse_seed_creation_pass(gm.graph)

    return num_changes


def fuse_seed_creation_pass(graph: torch.fx.Graph):
    """
    Horizontally fuse all the seed generation on each device

    a = inductor_seed(dev)
    b = inductor_seed(dev)

    Becomes:
    seeds = inductor_seeds(2, dev)
    a = inductor_lookup_seed(seeds, 0)
    b = inductor_lookup_seed(seeds, 1)

    We do this because seed creation is entirely launch overhead bound.

    Returns the number of devices whose seed nodes were fused.
    """
    # Group every inductor_prims.seed(device) call by its device argument.
    seeds_by_device = collections.defaultdict(list)
    for node in graph.nodes:
        if CallFunctionVarArgs(inductor_prims.seed).match(node):
            seeds_by_device[node.args[0]].append(node)

    if not seeds_by_device:
        return 0

    for device, seed_nodes in seeds_by_device.items():
        # Emit one batched seed tensor ahead of the first per-op seed.
        with graph.inserting_before(seed_nodes[0]):
            fused = graph.call_function(
                inductor_prims.seeds, (len(seed_nodes), device)
            )
            with V.fake_mode:
                fused.meta["val"] = torch.empty(
                    [len(seed_nodes)], device=device, dtype=torch.int64
                )
                fused.meta["tensor_meta"] = _extract_tensor_metadata(
                    fused.meta["val"]
                )

        # Replace each original seed node with a lookup into the batch,
        # then erase it (RAUW must happen before erase_node).
        for i, old_seed in enumerate(seed_nodes):
            with graph.inserting_before(old_seed):
                lookup = graph.call_function(
                    inductor_prims.lookup_seed, (fused, i)
                )
            old_seed.replace_all_uses_with(lookup)
            lookup.meta.update(old_seed.meta)
            graph.erase_node(old_seed)

    return len(seeds_by_device)


def default_kwargs(device):
    """Extra keyword arguments for the replacement random ops (none today)."""
    return {}


def get_device(device):
    """Return ``device``, or the current default device when it is None."""
    if device is None:
        # A fresh empty tensor inherits the default device.
        return torch.empty([]).device
    return device


@register_graph_pattern(CallFunctionVarArgs(aten.rand.default), pass_dict=patterns)
@register_graph_pattern(CallFunctionVarArgs(aten.rand.generator), pass_dict=patterns)
@register_graph_pattern(CallFunctionVarArgs(aten.randn.default), pass_dict=patterns)
@register_graph_pattern(CallFunctionVarArgs(aten.randn.generator), pass_dict=patterns)
def replace_random(
    match: Match,
    size,
    *,
    generator=None,
    dtype=None,
    device=None,
    layout=None,
    pin_memory=None,
):
    """Swap ``aten.rand``/``aten.randn`` for the inductor random primitive."""
    if generator is not None:
        # An explicit generator requires eager semantics; don't rewrite.
        return

    # Pick the mode string from whichever overload packet actually matched.
    mode = {
        aten.rand: "rand",
        aten.randn: "randn",
    }[
        match.output_node().target.overloadpacket  # type: ignore[union-attr]
    ]
    device = get_device(device)

    def replacement(size):
        result = inductor_prims.random(
            size, inductor_prims.seed(device), mode, **default_kwargs(device)
        )
        return result if dtype is None else result.to(dtype)

    match.replace_by_example(replacement, [size])


@register_graph_pattern(CallFunctionVarArgs(aten.randint.low), pass_dict=patterns)
def replace_randint(
    match: Match,
    low,
    high,
    size,
    *,
    dtype=torch.int64,
    device=None,
    layout=None,
    pin_memory=None,
):
    """Swap ``aten.randint`` for the inductor randint primitive."""
    device = get_device(device)

    def replacement(low, high, size):
        seed = inductor_prims.seed(device)
        return inductor_prims.randint(low, high, size, seed).to(dtype)

    match.replace_by_example(replacement, [low, high, size])