xref: /aosp_15_r20/external/pytorch/torch/onnx/_internal/jit_utils.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2"""Utilities for manipulating the torch.Graph object and the torchscript."""
3
4# TODO(justinchuby): Move more of the symbolic helper functions here and expose
5# them to the user.
6
7from __future__ import annotations
8
9import dataclasses
10import re
11import typing
12from typing import Any, Iterable, Sequence
13
14import torch
15from torch import _C
16from torch.onnx._globals import GLOBALS
17from torch.onnx._internal import registration
18
19
# Matches attribute keyword names of the form "<name>_<kind>", where <kind> is
# one of the single-letter suffixes i, f, s, t, g, z or the two-letter suffix ty.
# Group 1 captures the attribute name and group 2 captures the kind suffix.
_ATTR_PATTERN = re.compile("^(.+)_(([ifstgz])|(ty))$")
# Attribute keys that _create_node skips instead of setting them on the node.
_SKIP_NODE_ATTRIBUTES = {"inplace", "aten"}
22
23
@dataclasses.dataclass
class GraphContext:
    """Extra context for symbolic functions with all methods from torch.Graph.

    NOTE: This class is not meant for external consumption. Please do not depend on
    it outside of torch.onnx as the interface may evolve.

    Attribute lookups that are not resolved on the context itself are forwarded
    to ``graph`` (see ``__getattr__``), so symbolic functions written against a
    ``_C.Graph`` can be handed a ``GraphContext`` instead.

    Attributes:
        graph: The _C.Graph being constructed.
        block: The current _C.Block being constructed.
        opset: The opset version.
        original_node: Current node that is being converted from.
        params_dict: Mapping from graph initializer name to IValue.
        env: Mapping from Torch domain graph Value to ONNX domain graph Value.
        values_in_env: Set of all values in env, for constant-time lookups.
        new_nodes: List that tracks all new nodes that are added (used to make
            sure metadata is propagated to all new nodes).
    """

    graph: _C.Graph
    block: _C.Block
    opset: int
    original_node: _C.Node
    params_dict: dict[str, _C.IValue]
    env: dict[_C.Value, _C.Value]
    values_in_env: set[_C.Value]
    new_nodes: list[_C.Node] = dataclasses.field(default_factory=list)

    # Relay methods from _C.Graph for compatibility with symbolic functions that expect
    # a _C.Graph
    def __getattr__(self, name: str) -> Any:
        """Forward the lookup of an unknown attribute to the wrapped graph.

        Python only calls ``__getattr__`` after normal attribute lookup failed.

        Args:
            name: The attribute that was not found on this context.

        Returns:
            ``getattr(self.graph, name)``.

        Raises:
            AttributeError: If ``name`` is ``"graph"`` (the instance has no graph
                yet), or if the graph does not have the attribute either.
        """
        if name == "graph":
            # Reaching this point means the instance has no ``graph`` attribute
            # yet, e.g. it was created through ``__new__`` and is being probed
            # with ``hasattr`` (``copy.copy`` does this for ``__setstate__``).
            # Reading ``self.graph`` below would re-enter this method without
            # end and surface as a RecursionError instead of AttributeError.
            raise AttributeError(
                f"'{type(self).__name__}' object has no attribute '{name}'"
            )
        return getattr(self.graph, name)

    def op(
        self,
        opname: str,
        *raw_args: torch.Tensor | _C.Value,
        outputs: int = 1,
        **kwargs,
    ):
        """Creates an ONNX operator "opname", taking "raw_args" as inputs and "kwargs" as attributes.

        The set of operators and the inputs/attributes they take
        is documented at https://github.com/onnx/onnx/blob/master/docs/Operators.md

        Args:
            opname: The ONNX operator name, e.g., `Abs` or `Add`, or an operator qualified
                with a namespace, e.g., `aten::add`.
            raw_args: The inputs to the operator; usually provided
                as arguments to the `symbolic` definition.
            outputs: The number of outputs this operator returns.
                By default an operator is assumed to return a single output.
                If `outputs` is greater than one, this functions returns a tuple
                of output `Value`, representing each output of the ONNX operator
                in order.
            kwargs: The attributes of the ONNX operator, whose keys are named
                according to the following convention: `alpha_f` indicates
                the `alpha` attribute with type `f`.  The valid type specifiers are
                `f` (float), `i` (int), `s` (string) or `t` (Tensor).  An attribute
                specified with type float accepts either a single float, or a
                list of floats (e.g., you would say `dims_i` for a `dims` attribute
                that takes a list of integers).

        Returns:
            The value representing the single output of this operator (see the `outputs`
            keyword argument for multi-return nodes).
        """
        # FIXME(justinchuby): Add the return type back once we know how to handle mypy
        return _add_op(self, opname, *raw_args, outputs=outputs, **kwargs)

    def aten_op(self, operator: str, *args, overload_name: str = "", **kwargs):
        """Generates an ONNX ATen op node.

        This function is for backward compatibility with the old symbolic functions.

        Args:
            operator: Stored on the node as the string attribute `operator`.
            args: The inputs to the operator, forwarded to `op`.
            overload_name: Stored on the node as the string attribute
                `overload_name`.
            kwargs: Additional attributes, forwarded to `op` unchanged.

        Returns:
            Whatever `op` returns for the created `aten::ATen` node.
        """
        return self.op(
            "aten::ATen",
            *args,
            operator_s=operator,
            overload_name_s=overload_name,
            **kwargs,
        )

    # NOTE: For backward compatibility with the old symbolic functions.
    # We are probably going to remove this only after the fx exporter is established.
    at = aten_op

    def onnxscript_op(
        self,
        onnx_fn,
        *raw_args: torch.Tensor | _C.Value,
        outputs: int = 1,
        **kwargs,
    ):
        """Creates an ONNX operator from onnx-script function, taking "raw_args" as inputs and "kwargs" as attributes.

        onnx-script repository: https://github.com/microsoft/onnx-script

        Args:
            onnx_fn: ONNXFunction from onnx-script; An example can be found at
                https://github.com/microsoft/onnx-script#example
            raw_args: The inputs to the operator; usually provided
                as arguments to the `symbolic` definition.
            outputs: The number of outputs this operator returns.
                By default an operator is assumed to return a single output.
                If `outputs` is greater than one, this functions returns a tuple
                of output `Value`, representing each output of the ONNX operator
                in order.
            kwargs: The attributes of the ONNX operator, whose keys are named
                according to the following convention: `alpha_f` indicates
                the `alpha` attribute with type `f`.  The valid type specifiers are
                `f` (float), `i` (int), `s` (string) or `t` (Tensor).  An attribute
                specified with type float accepts either a single float, or a
                list of floats (e.g., you would say `dims_i` for a `dims` attribute
                that takes a list of integers).

        Returns:
            The value representing the single output of this operator (see the `outputs`
            keyword argument for multi-return nodes).
        """
        # NOTE(titaiwang): This is using class attributes, and it needs to be updated
        # if onnx-script makes any change on these.
        symbolic_name = f"{onnx_fn.opset.domain}::{onnx_fn.name}"
        opset_version = onnx_fn.opset.version

        # Register the function under "<domain>::<name>" for its opset version
        # before emitting the node that refers to it by that name.
        registration.custom_onnx_symbolic(symbolic_name, opset_version)(onnx_fn)

        return _add_op(self, symbolic_name, *raw_args, outputs=outputs, **kwargs)
152
153
def add_op_with_blocks(
    graph_context: GraphContext,
    opname: str,
    *inputs: _C.Value,
    outputs: int = 1,
    n_blocks: int = 1,
    **attributes,
) -> tuple[Any, tuple[GraphContext, ...], _C.Node]:
    """Creates an ONNX operator "opname" and attaches sub-blocks to its node.

    The operator is created through ``graph_context.op``; afterwards ``n_blocks``
    blocks are added to the resulting node, and one graph context is produced
    per added block.

    Args:
        graph_context: The context for the current graph.
        opname: The ONNX operator name, e.g., `Abs` or `Add`, or an operator qualified
            with a namespace, e.g., `aten::add`.
        inputs: The inputs to the operator.
        outputs: The number of outputs this operator returns.
            By default an operator is assumed to return a single output.
            If `outputs` is greater than one, this functions returns a tuple
            of output `Value`, representing each output of the ONNX operator
            in order.
        n_blocks: The number of sub-blocks to create in the node.
        attributes: The attributes of the ONNX operator.

    Returns:
        A tuple of (output_values, new_contexts, node) where:
            output_values: One or more output value of this operator
                (see the `outputs` keyword argument for multi-return nodes).
            new_contexts: A tuple of new graph contexts for each sub-block.
            node: The node representing the operator.
    """
    output_values = graph_context.op(opname, *inputs, outputs=outputs, **attributes)

    # A multi-output op yields a sequence of values; every one of them belongs
    # to the same node, so the first is enough to reach it.
    first_output = (
        output_values[0] if isinstance(output_values, Sequence) else output_values
    )
    node = first_output.node()

    # dataclasses.replace builds a new context whose fields are the same objects
    # as in graph_context (a shallow copy), with only `block` pointing at the
    # freshly added sub-block.
    new_contexts = tuple(
        dataclasses.replace(graph_context, block=node.addBlock())
        for _ in range(n_blocks)
    )

    return output_values, new_contexts, node
199
200
def _add_op(
    graph_context: GraphContext,
    opname: str,
    *args: torch.Tensor | _C.Value,
    outputs: int = 1,
    **kwargs,
):
    """Creates an ONNX operator "opname", taking "args" as inputs and attributes "kwargs".

    The set of operators and the inputs/attributes they take
    is documented at https://github.com/onnx/onnx/blob/master/docs/Operators.md

    This function is monkey-patched onto Graph.

    Args:
        graph_context: The context whose block receives the new node and whose
            `new_nodes` list records it.
        opname: The ONNX operator name, e.g., `Abs` or `Add`, or an operator qualified
            with a namespace, e.g., `aten::add`. A name without `::` is placed
            in the `onnx` namespace.
        args: The inputs to the operator; usually provided
            as arguments to the `symbolic` definition.
        outputs: The number of outputs this operator returns.
            By default an operator is assumed to return a single output.
            If `outputs` is greater than one, this functions returns a tuple
            of output `Value`, representing each output of the ONNX operator
            in order.
        kwargs: The attributes of the ONNX operator, whose keys are named
            according to the following convention: `alpha_f` indicates
            the `alpha` attribute with type `f`.  The valid type specifiers are
            `f` (float), `i` (int), `s` (string) or `t` (Tensor).  An attribute
            specified with type float accepts either a single float, or a
            list of floats (e.g., you would say `dims_i` for a `dims` attribute
            that takes a list of integers). Attributes whose value is `None`
            are ignored.

    Returns:
        (Union[_C.Value, Tuple[_C.Value, ...]])
        The value representing the single output of this operator (see the `outputs`
        keyword argument for multi-return nodes).
    """
    # Inputs that are neither None nor a _C.Value are turned into constants
    # first, before the node for this operator is created.
    inputs = []
    for arg in args:
        inputs.append(_const_if_tensor(graph_context, arg))

    # Filter out None attributes, this can be convenient client side because
    # now they can pass through None attributes, and have them not show up
    attributes = {}
    for attr_key, attr_value in kwargs.items():
        if attr_value is not None:
            attributes[attr_key] = attr_value

    # Unqualified operator names default to the onnx namespace.
    qualified_name = opname if "::" in opname else "onnx::" + opname

    node = _create_node(
        graph_context.block,
        qualified_name,
        inputs,
        attributes,
        params_dict=graph_context.params_dict,
        opset_version=graph_context.opset,
        n_outputs=outputs,
        shape_inference=GLOBALS.onnx_shape_inference,
    )
    # Record the node on the context (see GraphContext.new_nodes).
    graph_context.new_nodes.append(node)

    if outputs != 1:
        return tuple(node.outputs())
    return node.output()
262
263
264def _const_if_tensor(graph_context: GraphContext, arg):
265    if arg is None:
266        return arg
267    if isinstance(arg, _C.Value):
268        return arg
269
270    return _add_op(graph_context, "onnx::Constant", value_z=arg)
271
272
273def _create_node(
274    graph_or_block: _C.Graph | _C.Block,
275    domain_op: str,
276    inputs: Sequence,
277    attributes: dict,
278    params_dict: dict,
279    opset_version: int,
280    n_outputs: int,
281    shape_inference: bool = True,
282) -> _C.Node:
283    """Creates an node 'domain_op', taking inputs and attributes."""
284    if isinstance(graph_or_block, _C.Graph):
285        graph = graph_or_block
286        node = graph.create(domain_op, inputs, n_outputs)
287        node = graph.insertNode(node)
288    elif isinstance(graph_or_block, _C.Block):
289        block = graph_or_block
290        node = block.addNode(domain_op, inputs)
291
292        # Block does not have create defined, so we need to add outputs manually
293        if n_outputs > 1:
294            for _ in range(1, n_outputs):
295                node.addOutput()
296
297    node_outputs = tuple(node.outputs())  # type: ignore[possibly-undefined]
298    assert len(node_outputs) == n_outputs
299
300    aten = domain_op.startswith("aten::")
301
302    # Add all attributes
303    for key, value in sorted(attributes.items()):
304        if key in _SKIP_NODE_ATTRIBUTES:
305            continue
306        _add_attribute(node, key, value, aten=aten)
307    if shape_inference:
308        _C._jit_pass_onnx_node_shape_type_inference(node, params_dict, opset_version)
309    return node
310
311
312def _is_onnx_list(value):
313    return isinstance(value, Iterable) and not isinstance(
314        value, (str, bytes, torch.Tensor)
315    )
316
317
318def _scalar(x: torch.Tensor):
319    """Convert a scalar tensor into a Python value."""
320    assert x.numel() == 1
321    return x[0]
322
323
def _add_attribute(node: _C.Node, key: str, value: Any, aten: bool):
    r"""Sets attribute `key` on `node` using the setter selected by its suffix.

    `key` must match `_ATTR_PATTERN`, i.e. `<name>_<kind>`. The setter called on
    the node is `<kind>_`, or `<kind>s_` when `value` is list-like according to
    `_is_onnx_list` (e.g. `dims_i` with a list value calls `node.is_`).

    Args:
        node: The node receiving the attribute.
        key: The attribute name suffixed with its kind.
        value: The attribute value.
        aten: Not used by this function.

    Returns:
        Whatever the node's setter returns.

    Raises:
        ValueError: If `key` does not carry a valid kind suffix.
    """
    matched = _ATTR_PATTERN.match(key)
    if matched is None:
        raise ValueError(
            f"Invalid attribute specifier '{key}' names "
            "must be suffixed with type, e.g. 'dim_i' or 'dims_i'"
        )
    attr_name = matched.group(1)
    kind_suffix = matched.group(2)

    # List-like values go through the plural setter of the same kind.
    if _is_onnx_list(value):
        setter_name = f"{kind_suffix}s_"
    else:
        setter_name = f"{kind_suffix}_"

    setter = getattr(node, setter_name)
    return setter(attr_name, value)
337
338
339# TODO: Expose this to user when migrating symbolic helper functions to here.
340def _is_tensor(x: _C.Value) -> bool:
341    return x.type().isSubtypeOf(_C.TensorType.get())
342
343
def get_device_from_value(value: _C.Value) -> torch.device | None:
    """Return the device reported by the tensor type of `value`.

    Returns `None` when `value` is not tensor-typed according to `_is_tensor`;
    otherwise returns whatever `value.type().device()` yields.
    """
    if _is_tensor(value):
        # The cast only informs the type checker; it has no runtime effect.
        return typing.cast(_C.TensorType, value.type()).device()
    return None
349
350
def parse_node_kind(kind: str) -> tuple[str, str]:
    """Split a node kind of the form `domain::op` into `(domain, op)`.

    Raises:
        ValueError: If `kind` contains no `::` separator, or more than one.
    """
    if "::" not in kind:
        raise ValueError(f"Node kind: {kind} is invalid. '::' is not in node kind.")

    # Exactly one separator produces exactly two pieces.
    pieces = kind.split("::")
    if len(pieces) != 2:
        raise ValueError(f"Node kind: {kind} is invalid. '::' should only apear once.")

    domain, opname = pieces
    return domain, opname
359
360
def is_aten(domain: str) -> bool:
    """Return whether `domain` is exactly the `aten` domain."""
    aten_domain = "aten"
    return domain == aten_domain
364
365
def is_prim(domain: str) -> bool:
    """Return whether `domain` is exactly the `prim` domain."""
    prim_domain = "prim"
    return domain == prim_domain
369
370
def is_onnx(domain: str) -> bool:
    """Return whether `domain` is exactly the `onnx` domain."""
    onnx_domain = "onnx"
    return domain == onnx_domain
374