1# mypy: allow-untyped-defs 2"""A context manager that disables the decomposition of certain ops during dynamo tracing. 3 4The approach is to temporarily hijack the operator callable with PT2 custom operator. 5The custom operator will not be decomposed and will show up as a single node to be exported to ONNX. 6 7For the time being the decomposition of these ops is otherwise unavoidable. 8 9https://github.com/pytorch/pytorch/issues/116684 10https://github.com/pytorch/pytorch/issues/115883 11 12This solution will no longer be required once the issue is resolved. 13""" 14 15from __future__ import annotations 16 17import abc 18import contextlib 19from typing import Callable, Sequence 20 21from onnxscript.function_libs.torch_lib.ops import ( # type: ignore[import-not-found] 22 core as torchlib_core, 23 nn as torchlib_nn, 24) 25 26import torch 27from torch._decomp import decompositions 28 29 30_NEW_OP_NAMESPACE: str = "onnx_export" 31"""The namespace for the custom operator.""" 32 33 34class DecompSkip(abc.ABC): 35 op_callable: Callable 36 """The original operator callable to skip decomposition.""" 37 onnxscript_function: Callable 38 """The ONNXScript function to be registered for exporting the custom operator.""" 39 40 new_op_name: str 41 """The name for the custom operator.""" 42 new_op_schema: str 43 """The schema for the custom operator. This should match with the signature of the original operator.""" 44 45 @classmethod 46 @abc.abstractmethod 47 def register(cls, export_options: torch.onnx.ExportOptions): 48 """Registers the custom operator and overrides the original operator. 49 50 It should do the following steps in order: 51 52 1. Register the custom operator. 53 2. Override the original operator with the replacement callable. 54 3. Register the ONNXScript function for exporting the custom operator. 55 """ 56 ... 57 58 @classmethod 59 @abc.abstractmethod 60 def unregister(cls): 61 """Restores the original operator callable.""" 62 ... 63 64 @classmethod 65 @abc.abstractmethod 66 def abstract(cls, *args, **kwargs): 67 """An abstract impl (meta kernel) for the operator.""" 68 ... 69 70 @classmethod 71 def register_custom_op(cls): 72 """Registers the custom operator.""" 73 new_op_qualname = f"{_NEW_OP_NAMESPACE}::{cls.new_op_name}" 74 torch.library.define(new_op_qualname, cls.new_op_schema) 75 torch.library.impl(new_op_qualname, "default", cls.replacement) 76 torch.library.register_fake(new_op_qualname, cls.abstract) 77 78 @classmethod 79 def replacement(cls, *args, **kwargs): 80 """A replacement callable for the operator to be hijacked. 81 82 This has the same signature and eager behavior as the original operator. 83 """ 84 return cls.op_callable(*args, **kwargs) 85 86 87class UpsampleBilinear2DDecompSkip(DecompSkip): 88 op_callable = torch._C._nn.upsample_bilinear2d # type: ignore[attr-defined] 89 onnxscript_function = torchlib_nn.aten_upsample_bilinear2d_vec # type: ignore[attr-defined] 90 new_op_name = "upsample_bilinear2d" 91 new_op_schema = "(Tensor self, SymInt[]? output_size, bool align_corners, float[]? scale_factors) -> (Tensor)" 92 93 @classmethod 94 def register(cls, export_options: torch.onnx.ExportOptions): 95 if not hasattr(torch.ops, _NEW_OP_NAMESPACE) or not hasattr( 96 torch.ops.onnx_export, cls.new_op_name 97 ): 98 cls.register_custom_op() 99 torch._C._nn.upsample_bilinear2d = torch.ops.onnx_export.upsample_bilinear2d # type: ignore[attr-defined] 100 if export_options.onnx_registry is None: 101 export_options.onnx_registry = torch.onnx.OnnxRegistry() 102 registry = export_options.onnx_registry 103 registry.register_op( 104 function=cls.onnxscript_function, 105 namespace=_NEW_OP_NAMESPACE, 106 op_name=cls.new_op_name, 107 ) 108 109 @classmethod 110 def unregister(cls): 111 torch._C._nn.upsample_bilinear2d = cls.op_callable # type: ignore[attr-defined] 112 113 @classmethod 114 def abstract(cls, input, output_size, align_corners, scale_factors): 115 osize = decompositions.upsample_compute_output_size( 116 input.size(), output_size, scale_factors 117 ) 118 return torch.empty( 119 (input.size(0), input.size(1), *osize), 120 dtype=input.dtype, 121 device=input.device, 122 ) 123 124 125class UpsampleTrilinear3DDecompSkip(DecompSkip): 126 op_callable = torch._C._nn.upsample_trilinear3d # type: ignore[attr-defined] 127 onnxscript_function = torchlib_nn.aten_upsample_trilinear3d_vec # type: ignore[attr-defined] 128 new_op_name = "upsample_trilinear3d" 129 new_op_schema = "(Tensor self, SymInt[]? output_size, bool align_corners, float[]? scale_factors) -> (Tensor)" 130 131 @classmethod 132 def register(cls, export_options: torch.onnx.ExportOptions): 133 if not hasattr(torch.ops, _NEW_OP_NAMESPACE) or not hasattr( 134 torch.ops.onnx_export, cls.new_op_name 135 ): 136 cls.register_custom_op() 137 torch._C._nn.upsample_trilinear3d = torch.ops.onnx_export.upsample_trilinear3d # type: ignore[attr-defined] 138 if export_options.onnx_registry is None: 139 export_options.onnx_registry = torch.onnx.OnnxRegistry() 140 registry = export_options.onnx_registry 141 registry.register_op( 142 function=cls.onnxscript_function, 143 namespace=_NEW_OP_NAMESPACE, 144 op_name=cls.new_op_name, 145 ) 146 147 @classmethod 148 def unregister(cls): 149 torch._C._nn.upsample_trilinear3d = cls.op_callable # type: ignore[attr-defined] 150 151 @classmethod 152 def abstract(cls, input, output_size, align_corners, scale_factors): 153 osize = decompositions.upsample_compute_output_size( 154 input.size(), output_size, scale_factors 155 ) 156 return torch.empty( 157 (input.size(0), input.size(1), input.size(2), *osize), 158 dtype=input.dtype, 159 device=input.device, 160 ) 161 162 163class InstanceNormDecompSkip(DecompSkip): 164 op_callable = torch.instance_norm # type: ignore[attr-defined] 165 onnxscript_function = torchlib_core.aten_instance_norm # type: ignore[attr-defined] 166 new_op_name = "instance_norm" 167 new_op_schema = ( 168 "(Tensor input, Tensor? weight, Tensor? bias, " 169 "Tensor? running_mean, Tensor? running_var, " 170 "bool use_input_stats, float momentum, float eps, " 171 "bool cudnn_enabled) -> Tensor" 172 ) 173 174 @classmethod 175 def register(cls, export_options: torch.onnx.ExportOptions): 176 if not hasattr(torch.ops, _NEW_OP_NAMESPACE) or not hasattr( 177 torch.ops.onnx_export, cls.new_op_name 178 ): 179 cls.register_custom_op() 180 181 torch.instance_norm = torch.ops.onnx_export.instance_norm # type: ignore[attr-defined] 182 if export_options.onnx_registry is None: 183 export_options.onnx_registry = torch.onnx.OnnxRegistry() 184 registry = export_options.onnx_registry 185 registry.register_op( 186 function=cls.onnxscript_function, 187 namespace=_NEW_OP_NAMESPACE, 188 op_name=cls.new_op_name, 189 ) 190 191 @classmethod 192 def unregister(cls): 193 torch.instance_norm = cls.op_callable # type: ignore[attr-defined] 194 195 @classmethod 196 def abstract( 197 cls, 198 input, 199 weight, 200 bias, 201 running_mean, 202 running_var, 203 use_input_stats: bool, 204 momentum: float, 205 eps: float, 206 cudnn_enabled: bool, 207 ): 208 return torch.empty( 209 input.size(), 210 dtype=input.dtype, 211 device=input.device, 212 ) 213 214 215_DEFAULT_SKIP_LIST = [ 216 UpsampleBilinear2DDecompSkip, 217 InstanceNormDecompSkip, 218 UpsampleTrilinear3DDecompSkip, 219] 220 221 222@contextlib.contextmanager 223def enable_decomposition_skips( 224 export_options: torch.onnx.ExportOptions, 225 skips: Sequence[type[DecompSkip]] = _DEFAULT_SKIP_LIST, 226): 227 """A context manager that enables the decomposition skips. 228 229 The original operator callables that are otherwise decomposed are replaced with custom operators. 230 The ONNXScript functions for exporting the custom operators are added to the ONNX registry inside export_options. 231 """ 232 try: 233 for skip in skips: 234 skip.register(export_options) 235 yield 236 finally: 237 for skip in skips: 238 skip.unregister() 239