# mypy: allow-untyped-defs
import functools
from typing import Dict, Set, Tuple

import torch
from torch._dynamo.utils import counters
from torch._ops import OpOverload, OpOverloadPacket

from ..pattern_matcher import fwd_only, register_replacement


aten = torch.ops.aten


@functools.lru_cache(None)
def _misc_patterns_init():
    from .joint_graph import patterns as joint_graph_patterns
    from .post_grad import pass_patterns as post_grad_patterns_all

    post_grad_patterns = post_grad_patterns_all[1]  # medium priority

    if torch.cuda.is_available():
        # workaround https://github.com/pytorch/pytorch/issues/97894
        device = "cuda"
    else:
        device = "cpu"

    # These patterns do 2 things
    # 1. Since we know that index is completely unique, we can codegen it using
    # stores instead of atomic adds, which is quite a bit faster.
    # 2. Also, since we are guaranteed that they are completely within bounds,
    # we can use unsafe indexing and skip debug asserts
    def randperm_index_add_pattern(x, y):
        index = torch.randperm(x.shape[0], device=x.device)[: y.shape[0]]
        return torch.index_add(x, dim=0, source=y, index=index), index

    def randperm_index_add_replacement(x, y):
        index = torch.randperm(x.shape[0], device=x.device)[: y.shape[0]]
        return (
            torch.ops.aten._unsafe_index_put(
                x, (index,), aten._unsafe_index(x, (index,)) + y, accumulate=False
            ),
            index,
        )

    register_replacement(
        randperm_index_add_pattern,
        randperm_index_add_replacement,
        [torch.empty(4, 8, device=device), torch.empty(2, 8, device=device)],
        fwd_only,
        [post_grad_patterns, joint_graph_patterns],
    )

    def randperm_index_pattern(x, slice_shape):
        index = torch.randperm(x.shape[0], device=x.device)[:slice_shape]
        return torch.ops.aten.index(x, (index,)), index

    def randperm_index_replacement(x, slice_shape):
        index = torch.randperm(x.shape[0], device=x.device)[:slice_shape]
        return torch.ops.aten._unsafe_index(x, (index,)), index

    register_replacement(
        randperm_index_pattern,
        randperm_index_replacement,
        [torch.empty(4, 8, device=device)],
        fwd_only,
        [post_grad_patterns, joint_graph_patterns],
        scalar_workaround={"slice_shape": 42},
    )


class NumpyCompatNormalization:
    numpy_compat: Dict[str, Tuple[str, ...]] = {
        "dim": ("axis",),
        "keepdim": ("keepdims",),
        "input": ("x", "a", "x1"),
        "other": ("x2",),
    }
    inverse_mapping: Dict[str, str]
    cache: Dict["torch.fx.graph.Target", Set[str]]

    def __init__(self) -> None:
        self.cache = {}  # callable -> tuple of replaceable args e.g. ["axis"]
        self.inverse_mapping = {}
        for actual_kwarg, numpy_kwargs in self.numpy_compat.items():
            for numpy_kwarg in numpy_kwargs:
                assert numpy_kwarg not in self.inverse_mapping
                self.inverse_mapping[numpy_kwarg] = actual_kwarg

    def __call__(self, graph: torch.fx.Graph):
        for node in graph.nodes:
            if node.op != "call_function":
                continue
            if isinstance(node.target, (OpOverload, OpOverloadPacket)):
                # only applies to torch ops; e.g. torch.stack(axis=1) works, torch.ops.aten.stack(axis=1) doesn't.
                continue
            kwargs = node.kwargs

            if node.target in self.cache:
                replaceable_kwargs = self.cache[node.target]
            else:
                signatures = torch.fx.operator_schemas.get_signature_for_torch_op(
                    node.target
                )
                signatures = () if signatures is None else signatures
                replaceable_kwargs = set()
                for sig in signatures:
                    for param_name in sig.parameters.keys():
                        if param_name in self.numpy_compat:
                            replaceable_kwargs.update(self.numpy_compat[param_name])

                self.cache[node.target] = replaceable_kwargs

            if not replaceable_kwargs:
                continue

            new_kwargs = {}
            kwargs_changed = False
            for k, v in kwargs.items():
                if k in replaceable_kwargs:
                    kwargs_changed = True
                    new_kwargs[self.inverse_mapping[k]] = v
                else:
                    new_kwargs[k] = v

            if kwargs_changed:
                node.kwargs = torch.fx.immutable_collections.immutable_dict(new_kwargs)
                counters["inductor"]["numpy_compat_normalization"] += 1


numpy_compat_normalization = NumpyCompatNormalization()
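

if __name__ == "__main__":
    # Minimal usage sketch of the normalization pass above, not part of the
    # pass logic itself; `_demo` and this `__main__` guard are illustrative
    # only and the guard keeps import behavior unchanged. It assumes
    # torch.fx.symbolic_trace records the numpy-style kwarg verbatim on the
    # call_function node, which it does for torch.* functions. Because of the
    # relative imports above, run it as a module, e.g.
    # `python -m torch._inductor.fx_passes.misc_patterns`.
    def _demo(x):
        # numpy-style "axis" kwarg; eager torch.sum accepts it as an alias for "dim"
        return torch.sum(x, axis=0)

    gm = torch.fx.symbolic_trace(_demo)
    numpy_compat_normalization(gm.graph)  # rewrites axis=0 -> dim=0 in place
    gm.recompile()
    print(gm.code)  # the regenerated forward now calls torch.sum(x, dim=0)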