# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

r"""
Handle the following op conversions:
- convert a functional (or mutable) op to an out variant op
- convert an out variant op to a scratch op.

We assume there is already a functionalization pass being done that removes
aliases and inplace variants.

For the to_out_variant conversion, the functional variant will be represented
as qualified op name plus the overload name. The returned out variant contains
the following information
- the OpOverload for the out variant
- the list of keyword argument names that are out variables. There should be
  at least one out variable. Some ops may also have multiple out variables,
  e.g. aten::topk.values returns both values and indices for the topk elements.

"""

import dataclasses
import logging
from typing import Dict, Optional, Tuple

import torch
from torch._ops import OpOverload
from torchgen.model import FunctionSchema, SchemaKind

# Cache the FunctionSchema so we don't need to parse it every time.
# Use OpOverload as hash key. We can not use torch._C.FunctionSchema as key since
# it's not hashable. The value is None when torchgen fails to parse the schema,
# so that negative result is cached as well.
_op_overload_to_schema_cache: Dict[OpOverload, Optional[FunctionSchema]] = {}

# Value type is Optional so we can cache None if an op does not have an
# out variant/scratch op. This way, we don't confuse the case where the op
# does not exist with a cache miss.
_func_to_out_variant_map: Dict[OpOverload, Optional[OpOverload]] = {}
_out_variant_to_scratch_map: Dict[OpOverload, Optional[OpOverload]] = {}
_mutable_to_out_variant_map: Dict[OpOverload, Optional[OpOverload]] = {}

# We've found a functional and an out variant with the same name, but their
# schemas mismatch. This map collects all of these cases so that a proper
# error message can be shown to the user.
# The key is an `OpOverload` of a functional variant. The value is the schema of
# an out variant registered under the same operator name whose signature did not
# match, or None if no out variant was found at all.
_schema_mismatch_map: Dict[OpOverload, Optional[FunctionSchema]] = {}


def _pybind_schema_to_native_schema(
    pybind_schema: torch._C.FunctionSchema,
) -> Optional[FunctionSchema]:
    """
    Convert a pybind FunctionSchema into a torchgen (native) FunctionSchema.

    We have 2 FunctionSchema definitions in python.
    One is defined in torchgen (call it native FunctionSchema), another is a
    pybind of c10::FunctionSchema (call it pybind FunctionSchema).
    Because we want to leverage torchgen to handle out variants, we convert
    any pybind FunctionSchema to a native FunctionSchema.

    Returns:
        The parsed native schema, or None if torchgen fails to parse it.
    """
    try:
        return FunctionSchema.parse(str(pybind_schema))
    except (RuntimeError, AssertionError, ValueError):
        # Need to catch AssertionError since parsing prim ops like:
        #   aten::to.prim_other(Tensor(a) self, bool non_blocking=False, bool copy=False) -> Tensor(a|b)
        # causes an assertion error in torchgen when parsing annotation 'a|b'.
        # We should ignore it. Hopefully one day the C++ FunctionSchema parsing
        # is 100% consistent with Python FunctionSchema parsing, and then we
        # won't need to catch these exceptions any more.
        #
        # We also need to catch ValueError for schemas like:
        #   aten::copy.Dict_str(Dict(str, t)(a) self) -> Dict(str, t)
        # torchgen throws ValueError since it does not expect the type string
        # to contain commas. Ignore those schemas for now.
        #
        # Ignore the failure and return None. There are some schemas defined as
        # prim ops that can not be parsed by torchgen. E.g.:
        # https://www.fburl.com/code/1vvzhssa
        # We should be safe to ignore them since PyE are not using these ops.
        logging.debug("Fail to parse function schema: %s", pybind_schema)
        return None


def _get_overload_schema(op_overload: OpOverload) -> Optional[FunctionSchema]:
    """
    Return the native schema of ``op_overload``, memoized per OpOverload.

    A None result (schema torchgen can not parse) is cached too, so repeated
    lookups for such ops do not re-run the parser. The previous ``.get``-based
    check treated a cached None as a cache miss.
    """
    if op_overload in _op_overload_to_schema_cache:
        return _op_overload_to_schema_cache[op_overload]
    native_schema = _pybind_schema_to_native_schema(op_overload._schema)
    _op_overload_to_schema_cache[op_overload] = native_schema  # pyre-ignore
    return native_schema


def get_out_args_from_opoverload(op_overload: OpOverload) -> Tuple[str, ...]:
    """Return the names of the out arguments of the out-variant ``op_overload``."""
    return get_out_args_from_schema(_get_overload_schema(op_overload))  # pyre-ignore


def get_out_args_from_schema(out_var_schema: FunctionSchema) -> Tuple[str, ...]:
    """
    Assume the input is the schema for an out variant.
    Return the names of the out arguments (there may be more than one).

    Raises:
        AssertionError: if ``out_var_schema`` is not an out variant.
    """
    assert (
        out_var_schema.is_out_fn()
    ), f"Expect an out variant, but get: {out_var_schema}"
    return tuple(arg.name for arg in out_var_schema.arguments.out)


def parse_qualified_opname(qualified_opname: str) -> Tuple[str, str]:
    """
    Given a qualified opname like aten::add, return a tuple for namespace
    (aten here) and op name (add here).

    Raises:
        RuntimeError: if the name does not contain exactly one "::" separator.
    """
    ns_and_opname = qualified_opname.split("::")
    if len(ns_and_opname) != 2:
        raise RuntimeError(f"Invalid qualified_opname {qualified_opname}")
    ns, opname = ns_and_opname
    return ns, opname


def get_op_overload(qualified_opname: str, overload: str) -> OpOverload:
    """
    Look up an OpOverload through ``torch.ops``.

    Arguments:
        qualified_opname: string like {namespace}::{op name}
        overload: the overload string of the op; an empty string selects the
            "default" overload.
    """
    ns, opname = parse_qualified_opname(qualified_opname)
    if not overload:
        overload = "default"
    return getattr(getattr(getattr(torch.ops, ns), opname), overload)


def schema_to_opoverload(schema: FunctionSchema) -> OpOverload:
    """Resolve a native FunctionSchema back to its OpOverload via ``torch.ops``."""
    qualified_name = str(schema.name.name)
    overload = schema.name.overload_name
    return get_op_overload(qualified_name, overload)


def set_mapping_for_op(op: OpOverload) -> None:
    """
    Populate the conversion caches for every overload sharing ``op``'s name.

    ``op`` can be a functional op, a mutable op, or an out variant op. This
    method is only expected to be called when ``op`` is missing from the cache
    matching its kind:
    1. a functional op missing from _func_to_out_variant_map,
    2. an out variant op missing from _out_variant_to_scratch_map, or
    3. a mutable op missing from _mutable_to_out_variant_map.

    All schemas registered under the same operator name are grouped by their
    signature (ignoring return types). Within each group, the functional,
    mutable, out and scratch kinds are paired up and recorded in
    _func_to_out_variant_map, _mutable_to_out_variant_map and
    _out_variant_to_scratch_map. Missing counterparts are cached as None.
    Functional ops without a matching out variant are also recorded in
    _schema_mismatch_map for error reporting.
    """
    native_schema = _pybind_schema_to_native_schema(op._schema)
    assert native_schema is not None, f"Fail to parse function schema: {op._schema}"
    op_kind = native_schema.kind()
    assert op_kind in (
        SchemaKind.functional,
        SchemaKind.out,
        SchemaKind.mutable,
    )
    assert not (op_kind == SchemaKind.functional and op in _func_to_out_variant_map)
    assert not (op_kind == SchemaKind.out and op in _out_variant_to_scratch_map)
    assert not (op_kind == SchemaKind.mutable and op in _mutable_to_out_variant_map)
    qualified_opname = str(op._schema.name)

    all_schemas = [
        _pybind_schema_to_native_schema(pybind_schema)
        for pybind_schema in torch._C._jit_get_schemas_for_operator(qualified_opname)
    ]

    # Skip the schemas that can not be parsed by torchgen.
    all_schemas = [schema for schema in all_schemas if schema is not None]

    # Keys are return-type-stripped signatures (themselves FunctionSchemas).
    group_by_signature: Dict[FunctionSchema, Dict[SchemaKind, FunctionSchema]] = {}

    for schema in all_schemas:
        # Override the return type to an empty tuple. Otherwise, for ops like
        # aten.split.Tensor_out that return a Tensor list, the signature of
        # the out schema does not match the one for the functional op
        # aten.split.Tensor because of a different return type.
        # Schema for aten.split.Tensor_out:
        #   split.Tensor_out(Tensor(a -> *) self, int split_size, int dim=0, *, Tensor(a!)[] out) -> ()
        # Schema for aten.split.Tensor:
        #   split.Tensor(Tensor(a -> *) self, int split_size, int dim=0) -> Tensor(a)[]
        # The reason for the above inconsistency is explained in:
        # https://github.com/pytorch/pytorch/pull/76049
        signature = dataclasses.replace(schema.signature(), returns=())

        kind = schema.kind()
        group_by_kind = group_by_signature.setdefault(signature, {})
        assert (
            kind not in group_by_kind
        ), f"Schema of kind {kind} already exist for {schema}"
        group_by_kind[kind] = schema

    # Add all the functional/mutable op -> out variant op pairs (and the
    # out variant -> scratch op pairs) to the caches.
    for group_by_kind in group_by_signature.values():
        func_op_schema = group_by_kind.get(SchemaKind.functional)
        out_var_schema = group_by_kind.get(SchemaKind.out)
        mutable_op_schema = group_by_kind.get(SchemaKind.mutable)
        scratch_schema = group_by_kind.get(SchemaKind.scratch)

        # Update the map even if out_var_schema is None to cache the negative
        # case.
        if func_op_schema:
            _func_to_out_variant_map[schema_to_opoverload(func_op_schema)] = (
                schema_to_opoverload(out_var_schema) if out_var_schema else None
            )
            # The out variant schema is missing from group_by_kind.
            if out_var_schema is None:
                # Find an out variant whose schema differs from the functional
                # variant, to report it in the error message.
                mismatched_out_schema: Optional[FunctionSchema] = next(
                    (s for s in all_schemas if s.kind() == SchemaKind.out), None
                )
                _schema_mismatch_map[schema_to_opoverload(func_op_schema)] = (
                    mismatched_out_schema
                )

        # Update the map even if scratch_schema is None to cache the negative
        # case.
        if out_var_schema:
            _out_variant_to_scratch_map[schema_to_opoverload(out_var_schema)] = (
                schema_to_opoverload(scratch_schema) if scratch_schema else None
            )
        if mutable_op_schema:
            _mutable_to_out_variant_map[schema_to_opoverload(mutable_op_schema)] = (
                schema_to_opoverload(out_var_schema) if out_var_schema else None
            )


def to_out_variant(op_overload: OpOverload) -> Tuple[OpOverload, Tuple[str, ...]]:
    r"""
    Convert the passed in OpOverload to its out variant. Raise an exception if
    on return the op_overload is not guaranteed to be an out variant.

    If a conversion is found, return the out variant OpOverload along with the
    names of its out arguments. An op that is already an out variant is
    returned unchanged.

    Raises:
        AssertionError: if the op is neither an out variant, a functional op,
            nor a mutable op.
        RuntimeError: if no out variant can be found for the op.
    """
    schema = _get_overload_schema(op_overload)
    if schema.is_out_fn():  # pyre-ignore
        return op_overload, get_out_args_from_schema(schema)  # pyre-ignore[6]

    # Should be a functional-ish op here.
    assert (
        schema.kind() == SchemaKind.functional  # pyre-ignore[16]
        or schema.kind() == SchemaKind.mutable
    ), f"Expect a functionalish op, but get {schema.kind()} {schema}"

    if (
        op_overload not in _func_to_out_variant_map
        and op_overload not in _mutable_to_out_variant_map
    ):
        # Set up the out variant mapping for this op name.
        set_mapping_for_op(op_overload)

    if op_overload in _mutable_to_out_variant_map:
        out_var = _mutable_to_out_variant_map[op_overload]
    else:
        out_var = _func_to_out_variant_map.get(op_overload)

    if not out_var:
        msg = f"Missing out variant for functional op: {schema} . Make sure you have loaded your custom operator library for compiler. E.g., custom_ops_generated_lib"
        mismatched_out_schema = _schema_mismatch_map.get(op_overload)
        if mismatched_out_schema:
            msg += (
                f"\nFound an out variant for operator name {op_overload.name()} but its schema mismatched with functional op."
                f"\nfunctional op schema:\t{schema}"
                f"\nout variant op schema:\t{mismatched_out_schema}"
            )
        raise RuntimeError(msg)

    return out_var, get_out_args_from_opoverload(out_var)


def to_scratch_op(op_overload: OpOverload) -> Optional[OpOverload]:
    """
    Return the scratch op for an out variant ``op_overload``, or None.

    None is returned both when the out variant has no scratch op and when
    ``op_overload`` is not an out variant at all.
    """
    schema = _get_overload_schema(op_overload)

    # If the op is not an out variant, then we must have ignored some failure in
    # the to_out_var pass. Return immediately rather than throwing an exception,
    # since the user must have ignored errors for some reason (e.g. designing
    # special unit tests, or unblocking new use cases).
    if schema.kind() != SchemaKind.out:  # pyre-ignore
        logging.debug("Expect an out variant op as input, got: %s", schema.kind())
        return None

    if op_overload not in _out_variant_to_scratch_map:
        set_mapping_for_op(op_overload)

    # scratch_op can be None.
    return _out_variant_to_scratch_map.get(op_overload)


def is_out_variant(qualified_opname: str, overload: str) -> bool:
    """Return True if the named overload is an out variant (False if unparseable)."""
    op_overload = get_op_overload(qualified_opname, overload)
    schema = _get_overload_schema(op_overload)
    if schema is None:
        return False
    return schema.is_out_fn()


def is_inplace_variant(qualified_opname: str, overload: str) -> bool:
    """Return True if the named overload is an inplace variant (False if unparseable)."""
    op_overload = get_op_overload(qualified_opname, overload)
    schema = _get_overload_schema(op_overload)
    if schema is None:
        return False
    return schema.kind() == SchemaKind.inplace