# mypy: allow-untyped-defs
"""Utilities for manipulating the torch.Graph object and the torchscript."""

# TODO(justinchuby): Move more of the symbolic helper functions here and expose
# them to the user.

from __future__ import annotations

import dataclasses
import re
import typing
from typing import Any, Iterable, Sequence

import torch
from torch import _C
from torch.onnx._globals import GLOBALS
from torch.onnx._internal import registration


# Attribute keywords have the form "<name>_<kind>", where <kind> is one of the
# single-letter suffixes i/f/s/t/g/z or the two-letter suffix "ty".
_ATTR_PATTERN = re.compile(r"^(.+)_(([ifstgz])|(ty))$")
# Keyword arguments with exactly these names are never written onto a node.
_SKIP_NODE_ATTRIBUTES = {"inplace", "aten"}


@dataclasses.dataclass
class GraphContext:
    """Extra context for symbolic functions with all methods from torch.Graph.

    NOTE: This class is not meant for external consumption. Please do not depend on
    it outside of torch.onnx as the interface may evolve.

    Attributes:
        graph: The _C.Graph being constructed.
        block: The current _C.Block being constructed.
        opset: The opset version.
        original_node: Current node that is being converted from.
        params_dict: Mapping from graph initializer name to IValue.
        env: Mapping from Torch domain graph Value to ONNX domain graph Value.
        values_in_env: Set of all values in env, for constant-time lookups.
        new_nodes: List that tracks all new nodes that are added (used to make
            sure metadata is propagated to all new nodes).
    """

    graph: _C.Graph
    block: _C.Block
    opset: int
    original_node: _C.Node
    params_dict: dict[str, _C.IValue]
    env: dict[_C.Value, _C.Value]
    values_in_env: set[_C.Value]
    new_nodes: list[_C.Node] = dataclasses.field(default_factory=list)

    # Relay methods from _C.Graph for compatibility with symbolic functions that expect
    # a _C.Graph
    def __getattr__(self, name: str) -> Any:
        """Forward attribute lookups that fail on the context to ``self.graph``.

        Raises:
            AttributeError: If ``graph`` itself has not been set on the instance.
        """
        # __getattr__ only runs after normal lookup has failed. If the missing
        # name is `graph` (e.g. on an instance created without __init__, which
        # is what copy/pickle do), delegating to `self.graph` would re-enter
        # this method until RecursionError, so fail with AttributeError instead.
        if name == "graph":
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {name!r}"
            )
        return getattr(self.graph, name)

    def op(
        self,
        opname: str,
        *raw_args: torch.Tensor | _C.Value,
        outputs: int = 1,
        **kwargs,
    ):
        """Creates an ONNX operator "opname", taking "raw_args" as inputs and "kwargs" as attributes.

        The set of operators and the inputs/attributes they take
        is documented at https://github.com/onnx/onnx/blob/master/docs/Operators.md

        Args:
            opname: The ONNX operator name, e.g., `Abs` or `Add`, or an operator qualified
                with a namespace, e.g., `aten::add`.
            raw_args: The inputs to the operator; usually provided
                as arguments to the `symbolic` definition.
            outputs: The number of outputs this operator returns.
                By default an operator is assumed to return a single output.
                If `outputs` is greater than one, this functions returns a tuple
                of output `Value`, representing each output of the ONNX operator
                in order.
            kwargs: The attributes of the ONNX operator, whose keys are named
                according to the following convention: `alpha_f` indicates
                the `alpha` attribute with type `f`.  The valid type specifiers are
                `f` (float), `i` (int), `s` (string) or `t` (Tensor).  An attribute
                specified with type float accepts either a single float, or a
                list of floats (e.g., you would say `dims_i` for a `dims` attribute
                that takes a list of integers).

        Returns:
            The value representing the single output of this operator (see the `outputs`
            keyword argument for multi-return nodes).
        """
        # FIXME(justinchuby): Add the return type back once we know how to handle mypy
        return _add_op(self, opname, *raw_args, outputs=outputs, **kwargs)

    def aten_op(self, operator: str, *args, overload_name: str = "", **kwargs):
        """Generates an ONNX ATen op node.

        This function is for backward compatibility with the old symbolic functions.

        The node is created through `op` with the name `aten::ATen`; `operator`
        and `overload_name` are passed on as the string attributes `operator_s`
        and `overload_name_s`, and the remaining arguments are forwarded as-is.
        """
        return self.op(
            "aten::ATen",
            *args,
            operator_s=operator,
            overload_name_s=overload_name,
            **kwargs,
        )

    # NOTE: For backward compatibility with the old symbolic functions.
    # We are probably going to remove this only after the fx exporter is established.
    at = aten_op

    def onnxscript_op(
        self,
        onnx_fn,
        *raw_args: torch.Tensor | _C.Value,
        outputs: int = 1,
        **kwargs,
    ):
        """Creates an ONNX operator from onnx-script function, taking "raw_args" as inputs and "kwargs" as attributes.

        onnx-script repository: https://github.com/microsoft/onnx-script

        Args:
            onnx_fn: ONNXFunction from onnx-script; An example can be found at
                https://github.com/microsoft/onnx-script#example
            raw_args: The inputs to the operator; usually provided
                as arguments to the `symbolic` definition.
            outputs: The number of outputs this operator returns.
                By default an operator is assumed to return a single output.
                If `outputs` is greater than one, this functions returns a tuple
                of output `Value`, representing each output of the ONNX operator
                in order.
            kwargs: The attributes of the ONNX operator, whose keys are named
                according to the following convention: `alpha_f` indicates
                the `alpha` attribute with type `f`.  The valid type specifiers are
                `f` (float), `i` (int), `s` (string) or `t` (Tensor).  An attribute
                specified with type float accepts either a single float, or a
                list of floats (e.g., you would say `dims_i` for a `dims` attribute
                that takes a list of integers).

        Returns:
            The value representing the single output of this operator (see the `outputs`
            keyword argument for multi-return nodes).
        """
        # NOTE(titaiwang): This is using class attributes, and it needs to be updated
        # if onnx-script makes any change on these.
        symbolic_name = f"{onnx_fn.opset.domain}::{onnx_fn.name}"
        opset_version = onnx_fn.opset.version

        # Register the function under "<domain>::<name>" before emitting a node
        # of that same kind.
        registration.custom_onnx_symbolic(symbolic_name, opset_version)(onnx_fn)

        return _add_op(self, symbolic_name, *raw_args, outputs=outputs, **kwargs)


def add_op_with_blocks(
    graph_context: GraphContext,
    opname: str,
    *inputs: _C.Value,
    outputs: int = 1,
    n_blocks: int = 1,
    **attributes,
) -> tuple[Any, tuple[GraphContext, ...], _C.Node]:
    """Creates an ONNX operator "opname", taking inputs and attributes.

    Args:
        graph_context: The context for the current graph.
        opname: The ONNX operator name, e.g., `Abs` or `Add`, or an operator qualified
            with a namespace, e.g., `aten::add`.
        inputs: The inputs to the operator.
        outputs: The number of outputs this operator returns.
            By default an operator is assumed to return a single output.
            If `outputs` is greater than one, this functions returns a tuple
            of output `Value`, representing each output of the ONNX operator
            in order.
        n_blocks: The number of sub-blocks to create in the node.
        attributes: The attributes of the ONNX operator.

    Returns:
        A tuple of (output_values, new_contexts, node) where:
            output_values: One or more output value of this operator
                (see the `outputs` keyword argument for multi-return nodes).
            new_contexts: A tuple of new graph contexts for each sub-block.
            node: The node representing the operator.
    """

    output_values = graph_context.op(opname, *inputs, outputs=outputs, **attributes)
    # `op` returns either a single value or a tuple of values; every output
    # belongs to the same node, so the first one is enough to reach it.
    first_output = (
        output_values[0] if isinstance(output_values, Sequence) else output_values
    )
    node = first_output.node()

    # One context per sub-block. dataclasses.replace makes a shallow copy, so
    # every field other than `block` (including the `new_nodes` list object)
    # is shared with `graph_context`.
    new_contexts = tuple(
        dataclasses.replace(graph_context, block=node.addBlock())
        for _ in range(n_blocks)
    )

    return output_values, new_contexts, node


def _add_op(
    graph_context: GraphContext,
    opname: str,
    *args: torch.Tensor | _C.Value,
    outputs: int = 1,
    **kwargs,
):
    """Creates an ONNX operator "opname", taking "args" as inputs and attributes "kwargs".

    The set of operators and the inputs/attributes they take
    is documented at https://github.com/onnx/onnx/blob/master/docs/Operators.md

    The node is added to `graph_context.block` and appended to
    `graph_context.new_nodes`.

    Args:
        graph_context: The context whose block receives the new node.
        opname: The ONNX operator name, e.g., `Abs` or `Add`, or an operator qualified
            with a namespace, e.g., `aten::add`. Names without `::` are prefixed
            with `onnx::`.
        args: The inputs to the operator; usually provided
            as arguments to the `symbolic` definition. Inputs that are neither
            `None` nor a `_C.Value` are first wrapped in an `onnx::Constant` node.
        outputs: The number of outputs this operator returns.
            By default an operator is assumed to return a single output.
            If `outputs` is greater than one, this functions returns a tuple
            of output `Value`, representing each output of the ONNX operator
            in order.
        kwargs: The attributes of the ONNX operator, whose keys are named
            according to the following convention: `alpha_f` indicates
            the `alpha` attribute with type `f`.  The valid type specifiers are
            `f` (float), `i` (int), `s` (string) or `t` (Tensor).  An attribute
            specified with type float accepts either a single float, or a
            list of floats (e.g., you would say `dims_i` for a `dims` attribute
            that takes a list of integers). Attributes whose value is `None`
            are dropped.

    Returns:
        (Union[_C.Value, Tuple[_C.Value, ...]])
        The value representing the single output of this operator (see the `outputs`
        keyword argument for multi-return nodes).
    """
    inputs = [_const_if_tensor(graph_context, arg) for arg in args]
    # Filter out None attributes, this can be convenient client side because
    # now they can pass through None attributes, and have them not show up
    attributes = {k: v for k, v in kwargs.items() if v is not None}

    if "::" not in opname:
        opname = "onnx::" + opname

    node = _create_node(
        graph_context.block,
        opname,
        inputs,
        attributes,
        params_dict=graph_context.params_dict,
        opset_version=graph_context.opset,
        n_outputs=outputs,
        shape_inference=GLOBALS.onnx_shape_inference,
    )
    graph_context.new_nodes.append(node)

    if outputs == 1:
        return node.output()
    return tuple(node.outputs())


def _const_if_tensor(graph_context: GraphContext, arg):
    """Return `arg` unchanged if it is `None` or a `_C.Value`, else wrap it.

    Any other argument is turned into the output of a new `onnx::Constant`
    node that carries it as the `value_z` attribute.
    """
    if arg is None or isinstance(arg, _C.Value):
        return arg

    return _add_op(graph_context, "onnx::Constant", value_z=arg)


def _create_node(
    graph_or_block: _C.Graph | _C.Block,
    domain_op: str,
    inputs: Sequence,
    attributes: dict,
    params_dict: dict,
    opset_version: int,
    n_outputs: int,
    shape_inference: bool = True,
) -> _C.Node:
    """Creates an node 'domain_op', taking inputs and attributes.

    Args:
        graph_or_block: The graph or block the node is added to.
        domain_op: The qualified node kind, e.g. `onnx::Add`.
        inputs: The input values of the node.
        attributes: Attributes keyed by `<name>_<kind>`; keys listed in
            `_SKIP_NODE_ATTRIBUTES` are ignored.
        params_dict: Passed to node-level shape/type inference.
        opset_version: Passed to node-level shape/type inference.
        n_outputs: The number of outputs the node must have.
        shape_inference: Whether to run shape/type inference on the new node.

    Returns:
        The newly created node.

    Raises:
        TypeError: If `graph_or_block` is neither a `_C.Graph` nor a `_C.Block`.
        ValueError: If an attribute key does not follow the naming convention.
    """
    if isinstance(graph_or_block, _C.Graph):
        graph = graph_or_block
        node = graph.create(domain_op, inputs, n_outputs)
        node = graph.insertNode(node)
    elif isinstance(graph_or_block, _C.Block):
        block = graph_or_block
        node = block.addNode(domain_op, inputs)

        # Block does not have create defined, so we need to add outputs manually.
        # Only the outputs beyond the first are added here; the assertion below
        # checks that the total comes out as n_outputs.
        for _ in range(1, n_outputs):
            node.addOutput()
    else:
        # Fail with a clear error instead of an UnboundLocalError on `node`.
        raise TypeError(
            "graph_or_block must be a torch._C.Graph or torch._C.Block, "
            f"got {type(graph_or_block).__name__}"
        )

    node_outputs = tuple(node.outputs())
    assert len(node_outputs) == n_outputs

    aten = domain_op.startswith("aten::")

    # Add all attributes, in sorted key order so the result is deterministic
    for key, value in sorted(attributes.items()):
        if key in _SKIP_NODE_ATTRIBUTES:
            continue
        _add_attribute(node, key, value, aten=aten)
    if shape_inference:
        _C._jit_pass_onnx_node_shape_type_inference(node, params_dict, opset_version)
    return node


def _is_onnx_list(value):
    """Return True for iterables other than `str`, `bytes` and `torch.Tensor`."""
    return isinstance(value, Iterable) and not isinstance(
        value, (str, bytes, torch.Tensor)
    )


def _scalar(x: torch.Tensor):
    """Return the sole element of a single-element tensor, as `x[0]`.

    Asserts that `x` holds exactly one element.

    NOTE(review): the result is whatever indexing `x[0]` yields, not the output
    of `x.item()`; this presumably expects `x` to have at least one dimension.
    Confirm against callers before relying on it.
    """
    assert x.numel() == 1
    return x[0]


def _add_attribute(node: _C.Node, key: str, value: Any, aten: bool):
    r"""Initializes the right attribute based on type of value.

    Args:
        node: The node to set the attribute on.
        key: The attribute keyword in the form `<name>_<kind>`, e.g. `dim_i`.
        value: The attribute value. For list-like values (see `_is_onnx_list`)
            an `s` is appended to the kind, e.g. `i` becomes `is`.
        aten: Accepted for interface compatibility; not used by this function.

    Returns:
        The result of calling the node's `<kind>_` setter with the name and value.

    Raises:
        ValueError: If `key` does not match `_ATTR_PATTERN`.
    """
    m = _ATTR_PATTERN.match(key)
    if m is None:
        raise ValueError(
            f"Invalid attribute specifier '{key}' names "
            "must be suffixed with type, e.g. 'dim_i' or 'dims_i'"
        )
    name, kind = m.group(1), m.group(2)
    if _is_onnx_list(value):
        kind += "s"

    # The setter method on the node is named after the kind, e.g. `i_` or `is_`.
    return getattr(node, f"{kind}_")(name, value)


# TODO: Expose this to user when migrating symbolic helper functions to here.
def _is_tensor(x: _C.Value) -> bool:
    """Return whether the type of `x` is a subtype of `_C.TensorType`."""
    return x.type().isSubtypeOf(_C.TensorType.get())


def get_device_from_value(value: _C.Value) -> torch.device | None:
    """Return the device recorded on a tensor-typed value.

    Returns `None` when `value` is not tensor-typed; otherwise returns whatever
    the tensor type's `device()` reports.
    """
    if not _is_tensor(value):
        return None
    tensor_type = typing.cast(_C.TensorType, value.type())
    return tensor_type.device()


def parse_node_kind(kind: str) -> tuple[str, str]:
    """Parse node kind into domain and Op name.

    Args:
        kind: The node kind in the form `<domain>::<op>`, e.g. `onnx::Add`.

    Returns:
        The `(domain, op_name)` pair.

    Raises:
        ValueError: If `::` does not appear exactly once in `kind`.
    """
    if "::" not in kind:
        raise ValueError(f"Node kind: {kind} is invalid. '::' is not in node kind.")
    domain, opname = kind.split("::", 1)
    if "::" in opname:
        raise ValueError(f"Node kind: {kind} is invalid. '::' should only apear once.")
    return domain, opname


def is_aten(domain: str) -> bool:
    """Check if the domain is `aten`."""
    return domain == "aten"


def is_prim(domain: str) -> bool:
    """Check if the domain is `prim`."""
    return domain == "prim"


def is_onnx(domain: str) -> bool:
    """Check if the domain is `onnx`."""
    return domain == "onnx"