xref: /aosp_15_r20/external/pytorch/torch/onnx/_internal/fx/fx_onnx_interpreter.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
# mypy: allow-untyped-defs
from __future__ import annotations

import inspect
import logging
import operator
import re
from typing import Callable, Sequence

import onnxscript  # type: ignore[import]
from onnxscript.function_libs.torch_lib import (  # type: ignore[import]
    graph_building as onnxscript_graph_building,
)

import torch
import torch.fx
from torch.onnx import _type_utils as jit_type_utils
from torch.onnx._internal.fx import (
    _pass,
    diagnostics,
    onnxfunction_dispatcher,
    type_utils as fx_type_utils,
)
from torch.utils import _pytree


def _fx_node_to_onnx_message_formatter(
    fn: Callable,
    self,
    node: torch.fx.Node,
    *args,
    **kwargs,
) -> str:
    return f"FX Node: {node.op}:{node.target}[name={node.name}]. "


def _fx_graph_to_onnx_message_formatter(
    fn: Callable,
    self,
    fx_graph_module: torch.fx.GraphModule,
    *args,
    **kwargs,
) -> str:
    return f"FX Graph: {fx_graph_module._get_name()}. "


def _location_from_fx_stack_trace(
    node_stack_trace: str,
) -> diagnostics.infra.Location | None:
    """Extract location from FX node stack trace.

    TODO(bowbao): Create fx utils module and move this function there.

    Args:
        node_stack_trace: The stack trace of the FX node. Example:

            File "path/file.py", line 311, in <function>
                <code>
            |   File "path/file2.py", line 389, in <function>
                <code>

    Returns:
        location: The location of the FX node.
    """
    if "File" not in node_stack_trace:
        return None

    lines = node_stack_trace.strip().split("\n")
    idx = 0
    while idx < len(lines) and "File" not in lines[idx]:
        idx += 1
    if idx + 1 >= len(lines):
        return None

    pattern = re.compile(r"^File \"(.+)\", line (\d+), in (.+)$")
    matches = pattern.match(lines[idx].strip())
    if matches:
        uri = matches.group(1)
        line_number = int(matches.group(2))
        snippet = lines[idx + 1].strip()
        return diagnostics.infra.Location(uri=uri, line=line_number, snippet=snippet)
    return None

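# Illustrative example (not part of the original source, shown for clarity): given a
# hypothetical stack trace such as
#
#     File "model.py", line 311, in forward
#         x = self.linear(x)
#
# the regex above would produce
# Location(uri="model.py", line=311, snippet="x = self.linear(x)").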

def _retrieve_or_adapt_input_to_graph_set(
    fx_node_arg: fx_type_utils.Argument,
    fx_name_to_onnxscript_value: dict[
        str,
        onnxscript_graph_building.TorchScriptTensor
        | tuple[onnxscript_graph_building.TorchScriptTensor, ...],
    ],
    tracer: onnxscript_graph_building.TorchScriptTracingEvaluator,
):
    """Map an FX value to a TorchScript value.

    When creating a TorchScript graph from an FX graph, we need a mapping from FX
    variables to TorchScript variables. This function maps the FX variable,
    fx_node_arg, to a torch.jit.Value.
    """

    onnx_tensor = fx_node_arg
    if isinstance(onnx_tensor, torch.fx.Node):
        # 1. fx_node_arg is a torch.fx.Node, which means
        #    fx_node_arg stands for the output of that torch.fx.Node.
        # 2. fx_node_arg (variable in torch.fx.Graph) is mapped to
        #    torch.jit.Value, fx_name_to_onnxscript_value[fx_node_arg.name],
        #    in the TorchScript graph.
        return fx_name_to_onnxscript_value[onnx_tensor.name]
    elif isinstance(onnx_tensor, (tuple, list)) and any(
        isinstance(node, torch.fx.Node)
        and fx_type_utils.is_torch_symbolic_type(node.meta.get("val"))
        for node in onnx_tensor
    ):
        # This is intended to handle dynamic axes. For example, if the input size of
        # op.Expand is dynamic, each dimension is a variable (i.e., a sym variable in
        # the PyTorch FX graph; note that a sym variable is mapped to a tensor in the
        # ONNX Script world) calculated by other operators.
        sequence_mixed_elements: list[
            onnxscript_graph_building.TorchScriptTensor
            | tuple[onnxscript_graph_building.TorchScriptTensor, ...]
            | list[int]
        ] = []
        # onnx_tensor contains a list of scalars, each of which could be
        #   - a tensor with empty shape,
        #   - a tensor with shape (1,),
        #   - a torch.SymInt,
        #   - an int,
        #   - ...
        # They should all be promoted to tensors with shape (1,)
        # in order to call ONNX's Concat.
        for tensor in onnx_tensor:
            # Prepare `tensor` as input of ONNX's Concat.

            if isinstance(
                tensor, torch.fx.Node
            ) and fx_type_utils.is_torch_symbolic_type(tensor.meta.get("val")):
                # In this case, tensor is a torch.SymInt from Dynamo's perspective.
                # It might be mapped to a tensor with shape () or (1,) in ONNX.
                element_value = fx_name_to_onnxscript_value[tensor.name]
                if isinstance(
                    element_value, onnxscript_graph_building.TorchScriptTensor
                ):
                    # All elements of sequence_mixed_elements will be sent to ONNX's
                    # Concat as inputs. Therefore, they are required to have the same
                    # rank. Since tensors with rank=0 (i.e., scalars) cannot be
                    # concatenated, all scalars are promoted to tensors with shape (1,).
                    with onnxscript.evaluator.default_as(tracer):
                        element_value = onnxscript.opset18.Reshape(element_value, [1])  # type: ignore[arg-type, type-var]
                sequence_mixed_elements.append(element_value)
            elif isinstance(tensor, int):
                # NOTE: op.Concat doesn't support scalars, so we need to wrap the int
                # in a list (adding a dim), and onnx-script will promote it to a
                # tensor(int64).
                sequence_mixed_elements.append([tensor])
            else:
                raise RuntimeError(
                    f"Unsupported type in sequence_mixed_elements: {type(tensor)}"
                )
        # Concat all the elements in the sequence.
        # Shapes are mapped to tensors in the ONNX graph (TorchScriptGraph),
        # so a list of sym_ints is concatenated into a tensor before calling the ONNX op.

        # For example:
        #    inputs: [[2], [4], fx.Node(SymIntA), [1], fx.Node(SymIntB)]
        #    outputs: op.Concat([op.Constant(2), op.Constant(4), TorchScriptTensor(A), op.Constant(1), TorchScriptTensor(B)])

        # onnx-script automatically wraps Python numbers with op.Constant,
        # so we don't need to process them specially.
        with onnxscript.evaluator.default_as(tracer):
            output = onnxscript.opset18.Concat(*sequence_mixed_elements, axis=0)  # type: ignore[type-var]
        output.dtype = torch.int64  # type: ignore[union-attr]
        output.shape = [len(sequence_mixed_elements)]  # type: ignore[union-attr]
        return output
    elif isinstance(onnx_tensor, (tuple, list)) and all(
        isinstance(node, torch.fx.Node) or node is None for node in onnx_tensor
    ):
        sequence_elements: list[
            onnxscript_graph_building.TorchScriptTensor
            | None
            | tuple[onnxscript_graph_building.TorchScriptTensor, ...]
        ] = []
        for tensor in onnx_tensor:
            sequence_elements.append(
                fx_name_to_onnxscript_value[tensor.name] if tensor is not None else None
            )
        return sequence_elements
    if isinstance(onnx_tensor, torch.dtype):
        onnx_tensor = int(  # type: ignore[call-overload]
            jit_type_utils.JitScalarType.from_dtype(onnx_tensor).onnx_type()
        )
    # NOTE: if device is specified in kwargs (not consumed), it can be safely ignored.
    # But if it's in args, we need to convert it to a string for the dispatcher to
    # match the schema.
    if isinstance(onnx_tensor, torch.device):
        # torch.device is not supported by onnxscript (no op). We turn it into
        # a string.
        return str(onnx_tensor)
    # In all other cases, we do nothing.
    return onnx_tensor

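# Illustrative example (not part of the original source): for non-Node arguments the
# function above performs simple value adaptation, e.g. (the exact integer depends on
# the ONNX TensorProto enum):
#
#     torch.float32          -> 1  (TensorProto.FLOAT, via JitScalarType)
#     torch.device("cuda:0") -> "cuda:0"
#     3.14                   -> 3.14  (returned unchanged)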

def filter_incompatible_and_dtype_convert_kwargs(kwargs):
    """Filter out kwargs that are not supported by onnxscript."""
    filtered = {}
    for key, value in kwargs.items():
        if key in {
            "layout",
            "device",
            "requires_grad",
            "pin_memory",
            "memory_format",
            "implicit",
        }:
            continue
        if key == "dtype":
            if value is None:
                # We omit dtype if it is not provided, because onnxscript handles the
                # default case.
                continue
            else:
                value = int(jit_type_utils.JitScalarType.from_dtype(value).onnx_type())  # type: ignore[call-overload]
        filtered[key] = value
    return filtered

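# Illustrative example (not part of the original source): a call such as
#
#     filter_incompatible_and_dtype_convert_kwargs(
#         {"dtype": torch.float32, "device": "cpu", "pin_memory": False, "alpha": 2}
#     )
#
# would be expected to return {"dtype": 1, "alpha": 2}, where 1 is the ONNX
# TensorProto enum value for FLOAT; "device" and "pin_memory" are dropped.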

def _fill_tensor_shape_type(
    onnxscript_values: onnxscript_graph_building.TorchScriptTensor
    | tuple[onnxscript_graph_building.TorchScriptTensor, ...],
    name: str,
    expected_values: fx_type_utils.META_VALUE_TYPE
    | list[fx_type_utils.META_VALUE_TYPE]
    | tuple[fx_type_utils.META_VALUE_TYPE | None, ...],
):
    """Fill the meta information of onnxscript_values with that from the fx FakeTensor."""

    if isinstance(expected_values, (list, tuple)) and not isinstance(
        onnxscript_values, (list, tuple)
    ):
        # ex: aten::split - in onnx_dtype: seq(tensor)
        # onnxscript_values is a single tensor, but expected_values is a list of tensors.
        return

    flat_onnxscript_values, _ = _pytree.tree_flatten(onnxscript_values)
    flat_expected_values, _ = _pytree.tree_flatten(expected_values)
    for i, (onnxscript_value, expected_value) in enumerate(
        zip(flat_onnxscript_values, flat_expected_values)
    ):
        if expected_value is None:
            # There is no shape/type to extract from None.
            # NOTE: according to https://github.com/pytorch/pytorch/blob/main/torch/_meta_registrations.py,
            # None can be a valid return value, so we need to handle it,
            # e.g. the function meta__scaled_dot_product_flash() in cpu mode.
            continue
        elif fx_type_utils.is_torch_symbolic_type(expected_value):
            # The output of aten::sym_size is an int, not a tensor, which stands
            # for the size of one dim. We treat it as a 1-D tensor.
            onnxscript_value.dtype = fx_type_utils.from_sym_value_to_torch_dtype(
                expected_value
            )
            onnxscript_value.shape = torch.Size([1])
        elif isinstance(expected_value, (int, float, bool)):
            onnxscript_value.dtype = fx_type_utils.from_scalar_type_to_torch_dtype(
                type(expected_value)
            )
            onnxscript_value.shape = torch.Size([])
        elif isinstance(expected_value, complex):
            # From complex scalar to real representation
            onnxscript_value_to_torch_dtype = (
                fx_type_utils.from_scalar_type_to_torch_dtype(type(expected_value))
            )
            onnxscript_value.dtype = (
                fx_type_utils.from_complex_to_float(onnxscript_value_to_torch_dtype)
                if onnxscript_value_to_torch_dtype is not None
                else None
            )
            onnxscript_value.shape = torch.Size([2])
        elif fx_type_utils.is_torch_complex_dtype(expected_value.dtype):
            # Like torch.view_as_real, we flatten complex tensors to real tensors with
            # an additional last dimension of 2
            onnxscript_value.shape = torch.Size((*expected_value.size(), 2))
            # complex64 -> float32, complex128 -> float64, etc.
            onnxscript_value.dtype = fx_type_utils.from_complex_to_float(
                expected_value.dtype
            )
            # Dispatcher needs to know the value is complex
            onnxscript_value.is_complex = True
        else:
            # We set node output sizes to be dynamic to continue the model conversion,
            # and inputs are also set to be dynamic in add_input().
            onnxscript_value.shape = expected_value.size()
            onnxscript_value.dtype = expected_value.dtype

        # naming
        if i > 0:
            onnxscript_value.name = f"{name}_{i}"
        else:
            onnxscript_value.name = name

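# Illustrative example (not part of the original source): for an FX node whose
# meta["val"] is a FakeTensor of dtype torch.complex64 and shape (3, 4), the helper
# above would set the corresponding TorchScriptTensor to
#
#     shape = torch.Size([3, 4, 2])   # real view, extra trailing dim of 2
#     dtype = torch.float32           # complex64 -> float32
#     is_complex = True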

def _fill_in_default_kwargs(
    node: torch.fx.Node,
) -> tuple[list[fx_type_utils.Argument], dict[str, fx_type_utils.Argument]]:
    """Find the kwargs that were not provided and fill them in with default values."""

    # TODO: aten::sym_size has an overload, but the fx graph uses the
    # overloadpacket for some reason.
    # https://github.com/pytorch/pytorch/issues/97201
    # We manually assign the overload for aten::sym_size.
    if hasattr(node.target, "_schema"):
        node_schema = node.target._schema  # type: ignore[union-attr]
    else:
        node_schema = torch.ops.aten.sym_size.int._schema  # type: ignore[union-attr]

    # This function assumes the order of arguments in the FX op is the
    # same as the order of arguments in the TorchScript op.
    complete_args: list[fx_type_utils.Argument] = []
    complete_kwargs: dict[str, fx_type_utils.Argument] = {}

    if inspect.isbuiltin(node.target):
        complete_args = list(node.args)
    else:
        for i, expected_arg in enumerate(node_schema.arguments):
            if i < len(node.args):
                complete_args.append(node.args[i])
            elif expected_arg.name in node.kwargs:
                complete_kwargs[expected_arg.name] = node.kwargs[expected_arg.name]
            else:
                # Get default from schema.
                complete_kwargs[expected_arg.name] = expected_arg.default_value

    return complete_args, complete_kwargs

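# Illustrative example (not part of the original source): for a node targeting
# torch.ops.aten.add.Tensor, whose schema is roughly
# "add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor", a call with
# node.args == (x, y) and node.kwargs == {} would be expected to yield
#
#     complete_args   == [x, y]
#     complete_kwargs == {"alpha": 1}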

def _wrap_fx_args_as_onnxscript_args(
    complete_args: list[fx_type_utils.Argument],
    complete_kwargs: dict[str, fx_type_utils.Argument],
    fx_name_to_onnxscript_value: dict[
        str,
        onnxscript_graph_building.TorchScriptTensor
        | tuple[onnxscript_graph_building.TorchScriptTensor, ...],
    ],
    tracer: onnxscript_graph_building.TorchScriptTracingEvaluator,
) -> tuple[
    Sequence[
        onnxscript_graph_building.TorchScriptTensor
        | str
        | int
        | float
        | bool
        | list
        | complex
        | None
    ],
    dict[str, fx_type_utils.Argument],
]:
    """Map all FX arguments of a node to arguments in the TorchScript graph."""

    onnxscript_args = tuple(
        _retrieve_or_adapt_input_to_graph_set(arg, fx_name_to_onnxscript_value, tracer)
        for arg in complete_args
    )
    onnxscript_kwargs = filter_incompatible_and_dtype_convert_kwargs(complete_kwargs)

    return onnxscript_args, onnxscript_kwargs


class FxOnnxInterpreter:
    """Stateless class to process FX graph nodes and translate them into their ONNX counterparts.

    All FX nodes described by [FX Graph](https://pytorch.org/docs/stable/fx.html#torch.fx.Graph) are supported.
    Similar to the [FX Interpreter pattern](https://pytorch.org/docs/stable/fx.html#torch.fx.Interpreter), each FX node
    type is implemented by its own method in this class.

    Each operator's implementation returns either an `onnxscript.OnnxFunction` or
    `onnxscript.TracedOnnxFunction` instance based on the dispatch algorithm. They can
    also raise RuntimeError if no overloaded functions are available for the given FX node.

    TODO: Convert methods to @staticmethod when the diagnostic system supports it
          DO NOT ADD NEW ATTRIBUTES TO THIS CLASS!
    """

    def __init__(
        self,
        diagnostic_context: diagnostics.DiagnosticContext,
    ):
        # THIS SHOULD BE THE ONLY STATE IN THIS CLASS (constraint from the diagnostics API)
        # TODO: Diagnostics API should be revised to get rid of this attribute.
        # DO NOT add other class-level attributes.
        self.diagnostic_context = diagnostic_context

    @diagnostics.diagnose_call(
        diagnostics.rules.fx_node_to_onnx,
        diagnostic_message_formatter=_fx_node_to_onnx_message_formatter,
    )
    def run_node(
        self,
        node,
        fx_graph_module: torch.fx.GraphModule,
        onnxfunction_dispatcher: onnxfunction_dispatcher.OnnxFunctionDispatcher,
        onnxscript_graph: onnxscript_graph_building.TorchScriptGraph,
        onnxscript_tracer: onnxscript_graph_building.TorchScriptTracingEvaluator,
        fx_name_to_onnxscript_value: dict[
            str,
            onnxscript_graph_building.TorchScriptTensor
            | tuple[onnxscript_graph_building.TorchScriptTensor, ...],
        ],
    ):
        """Execute a single FX node to produce its ONNX counterpart.

        Args:
            node: The FX node to be translated.
            fx_graph_module: The FX graph module containing the node.
            onnxfunction_dispatcher: The dispatcher to find the best matched ONNX op.
            onnxscript_graph: The ONNX graph to be populated.
            onnxscript_tracer: The tracer to trace the ONNX graph.
            fx_name_to_onnxscript_value: The mapping from FX node name to ONNX Script value.

        Raises:
            RuntimeError: When a node.op is not supported.
        """
        # Record stack trace of node in diagnostic.
        node_stack_trace = node.stack_trace
        if node_stack_trace:
            diagnostic = self.diagnostic_context.inflight_diagnostic(
                rule=diagnostics.rules.fx_node_to_onnx
            )
            with diagnostic.log_section(logging.INFO, "PyTorch source information"):
                diagnostic.info("```\n%s\n```", node_stack_trace)
            location = _location_from_fx_stack_trace(node_stack_trace)
            if location is not None:
                diagnostic.with_location(location)

        if node.op == "placeholder":
            self.placeholder(node, onnxscript_graph, fx_name_to_onnxscript_value)
        elif node.op == "get_attr":
            self.get_attr(
                node,
                onnxscript_graph,
                fx_name_to_onnxscript_value,
                fx_graph_module,
            )
        elif node.op == "call_function":
            self.call_function(
                node,
                onnxscript_tracer,
                fx_name_to_onnxscript_value,
                onnxfunction_dispatcher,
                fx_graph_module,
            )
        elif node.op == "call_method":
            self.call_method(node)
        elif node.op == "call_module":
            self.call_module(
                node,
                onnxscript_graph,
                fx_name_to_onnxscript_value,
                onnxscript_tracer,
                fx_graph_module,
                onnxfunction_dispatcher,
            )
        elif node.op == "output":
            self.output(node, onnxscript_graph, fx_name_to_onnxscript_value)
        else:
            raise RuntimeError(f"Found node type not defined in torch.fx: {node.op}")

    @diagnostics.diagnose_call(
        diagnostics.rules.fx_graph_to_onnx,
        diagnostic_message_formatter=_fx_graph_to_onnx_message_formatter,
    )
    def run(
        self,
        fx_graph_module: torch.fx.GraphModule,
        onnxfunction_dispatcher: onnxfunction_dispatcher.OnnxFunctionDispatcher,
        parent_onnxscript_graph: onnxscript_graph_building.TorchScriptGraph
        | None = None,
    ) -> onnxscript_graph_building.TorchScriptGraph:
        """Analyze all FX nodes and trigger their ONNX translation.

        Args:
            fx_graph_module: FX graph module to be translated.
            onnxfunction_dispatcher: ONNX function dispatcher.
            parent_onnxscript_graph: The parent TorchScript graph. Must be provided if
                `fx_graph_module` is a submodule. If not provided,
                `fx_graph_module` is assumed to be the root module.
        """
        diagnostic = self.diagnostic_context.inflight_diagnostic()
        with diagnostic.log_section(logging.DEBUG, "FX Graph:"):
            diagnostic.debug(
                "```\n%s\n```",
                diagnostics.LazyString(fx_graph_module.print_readable, False),
            )

        if parent_onnxscript_graph is not None:
            # If parent_onnxscript_graph is provided, we assume fx_graph_module is a
            # submodule representing a forward call of an nn.Module.
            # Compose the package and version where the nn.Module is defined as the
            # domain name for the local function.

            onnx_meta: _pass.GraphModuleOnnxMeta | None = fx_graph_module.meta.get(
                "onnx"
            )
            if onnx_meta is None:
                raise RuntimeError(
                    f"ONNX meta is not found in submodule {fx_graph_module._get_name()}. "
                    f"Only submodules produced by the `Modularize` pass are supported in ONNX export."
                )

            onnx_domain = onnx_meta.package_info.to_onnx_domain_string()
        else:
            # Leave as the default domain name for the root module.
            onnx_domain = None

        onnxscript_graph = onnxscript_graph_building.TorchScriptGraph(
            parent_onnxscript_graph, domain_name=onnx_domain
        )
        onnxscript_tracer = onnxscript_graph_building.TorchScriptTracingEvaluator(
            onnxscript_graph
        )
        # In the following loop, a TorchScript graph is created to
        # represent the input FX graph with ONNX symbols (e.g., onnx::add).
        # To connect the values to nodes in the TorchScript graph, we maintain
        # fx_name_to_onnxscript_value. Basically, we want to translate
        #   fx_tensor_x (type: torch.fx.Node) -> fx_node_1 -> fx_tensor_y (type: torch.fx.Node)
        # to
        #   fx_name_to_onnxscript_value[fx_tensor_x.name] -> onnx_node_1 -> fx_name_to_onnxscript_value[fx_tensor_y.name]
        fx_name_to_onnxscript_value: dict[
            str,
            onnxscript_graph_building.TorchScriptTensor
            | tuple[onnxscript_graph_building.TorchScriptTensor, ...],
        ] = {}

        # TODO: Fix FakeTensorMode limitation asap
        # We want to pass lists of ints and floats to the TorchScript graph correctly
        # in _export_fx_to_ts, so we must disable FakeTensorMode. Otherwise, the graph
        # may receive FakeTensors and result in a runtime error. In addition, the
        # TorchScript-based ONNX exporter used in _ts_graph_to_onnx_model_in_protobuf
        # is not compatible with FakeTensorMode.
        with torch.utils._mode_utils.no_dispatch():
            for node in fx_graph_module.graph.nodes:
                self.run_node(
                    node,
                    fx_graph_module,
                    onnxfunction_dispatcher,
                    onnxscript_graph,
                    onnxscript_tracer,
                    fx_name_to_onnxscript_value,
                )

        with diagnostic.log_section(logging.DEBUG, "ONNX Graph:"):
            diagnostic.debug("```\n%s\n```", onnxscript_graph.torch_graph)  # type: ignore[attr-defined]

        return onnxscript_graph

    def placeholder(
        self,
        node: torch.fx.Node,
        onnxscript_graph: onnxscript_graph_building.TorchScriptGraph,
        fx_name_to_onnxscript_value: dict[
            str,
            onnxscript_graph_building.TorchScriptTensor
            | tuple[onnxscript_graph_building.TorchScriptTensor, ...],
        ],
    ):
        # Input of graph.
        # The node.meta["val"] is generated by FakeTensorProp.
        # NOTE: add_input() intends to create nodes with shape/type
        fake_tensor = node.meta.get("val", None)
        # NOTE: During tracing, when inputs are constants, they are represented
        # by nodes with node.meta['val'] being None (nn.Module to dynamo_export)
        # or nodes with node.meta['val'] being a builtin value (ExportedProgram to dynamo_export).
        # Nonetheless, the nodes are not consumed by others, so we don't need to
        # create a TorchScriptTensor for them.
        if fake_tensor is None or isinstance(fake_tensor, (int, float, bool, str)):
            output = onnxscript_graph.add_input(
                input_name=None,
            )
        elif isinstance(fake_tensor, torch.Tensor):
            # NOTE: ONNX doesn't support tensors of complex64/complex128, so we
            # convert them to float32/float64 with a real representation.
            if fx_type_utils.is_torch_complex_dtype(fake_tensor.dtype):
                fake_tensor = torch.view_as_real(fake_tensor.resolve_conj())
            output = onnxscript_graph.add_input(
                input_name=node.name,
                shape=fake_tensor.shape,
                dtype=fake_tensor.dtype,
            )

        elif fx_type_utils.is_torch_symbolic_type(fake_tensor):
            output = onnxscript_graph.add_input(
                input_name=node.name,
                shape=torch.Size([]),
                dtype=fx_type_utils.from_sym_value_to_torch_dtype(fake_tensor),
            )
        else:
            raise RuntimeError(
                f"Unsupported type(node.meta['val']) for placeholder: {type(fake_tensor)}"
            )
        assert (
            output is not None
        ), f"Node creates None with target={node.target} and name={node.name}"

        assert isinstance(output, onnxscript_graph_building.TorchScriptTensor)
        assert isinstance(output, onnxscript.tensor.Tensor)

        fx_name_to_onnxscript_value[node.name] = output

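    # Illustrative sketch (not part of the original source): for an aten node such as
    # call_function[target=torch.ops.aten.add.Tensor] with args (x, y), `call_function`
    # below roughly performs the following steps; the overload/function names shown are
    # examples only:
    #
    #     fx_args, fx_kwargs = _fill_in_default_kwargs(node)       # [x, y], {"alpha": 1}
    #     onnx_args, onnx_kwargs = _wrap_fx_args_as_onnxscript_args(
    #         fx_args, fx_kwargs, fx_name_to_onnxscript_value, onnxscript_tracer
    #     )
    #     symbolic_fn = onnxfunction_dispatcher.dispatch(...)       # e.g. torchlib "aten_add"
    #     output = symbolic_fn(*onnx_args, **onnx_kwargs)           # traced into the ONNX graph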
    def call_function(
        self,
        node: torch.fx.Node,
        onnxscript_tracer: onnxscript_graph_building.TorchScriptTracingEvaluator,
        fx_name_to_onnxscript_value: dict[
            str,
            onnxscript_graph_building.TorchScriptTensor
            | tuple[onnxscript_graph_building.TorchScriptTensor, ...],
        ],
        onnxfunction_dispatcher: onnxfunction_dispatcher.OnnxFunctionDispatcher,
        fx_graph_module: torch.fx.GraphModule,
    ):
        # aten ops and other stateless functions.
        if node.target == operator.getitem and isinstance(
            fx_name_to_onnxscript_value[node.args[0].name],  # type: ignore[union-attr,index]
            tuple,
        ):
            onnx_tensor_tuple = fx_name_to_onnxscript_value[node.args[0].name]  # type: ignore[union-attr,index]
            index = node.args[1]
            value = onnx_tensor_tuple[index]  # type: ignore[index]
            assert (
                value is not None
            ), f"Node creates None with target={node.target} and name={node.name}"
            assert isinstance(
                value, (onnxscript_graph_building.TorchScriptTensor, tuple)
            ), type(value)

            fx_name_to_onnxscript_value[node.name] = value
            return

        # Map FX inputs to ONNX inputs and fill optional inputs with default values.
        # torch_args and torch_kwargs are for op-level validation.
        fx_args, fx_kwargs = _fill_in_default_kwargs(node)

        onnx_args, onnx_kwargs = _wrap_fx_args_as_onnxscript_args(
            fx_args,
            fx_kwargs,
            fx_name_to_onnxscript_value,
            onnxscript_tracer,
        )
        # Dispatch to the ONNX op through OpSchema. The input argument dtypes are
        # compared to the function signatures in OpSchema to find the best matched
        # overload.
        symbolic_fn = onnxfunction_dispatcher.dispatch(
            node=node,
            onnx_args=onnx_args,  # type: ignore[arg-type]
            onnx_kwargs=onnx_kwargs,
            diagnostic_context=self.diagnostic_context,
        )
        with onnxscript.evaluator.default_as(onnxscript_tracer):
            output: (
                onnxscript_graph_building.TorchScriptTensor
                | tuple[onnxscript_graph_building.TorchScriptTensor, ...]
            ) = symbolic_fn(*onnx_args, **onnx_kwargs)
        assert (
            output is not None
        ), f"Node creates None with target={node.target}, name={node.name}, args={onnx_args}, kwargs={onnx_kwargs}"
        # Assign type and shape from the fx graph.
        _fill_tensor_shape_type(output, node.name, node.meta["val"])
        # One fx node could produce multiple outputs (e.g., a tuple of tensors); in
        # that case, `output` is a tuple of TorchScriptTensors.
        assert isinstance(
            output, (onnxscript_graph_building.TorchScriptTensor, tuple)
        ), type(output)
        fx_name_to_onnxscript_value[node.name] = output

    def output(
        self,
        node: torch.fx.Node,
        onnxscript_graph: onnxscript_graph_building.TorchScriptGraph,
        fx_name_to_onnxscript_value: dict[
            str,
            onnxscript_graph_building.TorchScriptTensor
            | tuple[onnxscript_graph_building.TorchScriptTensor, ...],
        ],
    ):
        if isinstance(node.args[0], torch.fx.Node):
            onnx_tensor_or_tensor_tuple = fx_name_to_onnxscript_value[node.args[0].name]
            onnxscript_graph.register_outputs(onnx_tensor_or_tensor_tuple)
        else:
            # ONNX can't represent collection types (e.g., dictionary, tuple of tuples
            # of tensors, etc.), so we flatten the collection and register each element
            # as an output.
            flat_args, _ = _pytree.tree_flatten(node.args[0])
            for arg in flat_args:
                assert isinstance(
                    arg, torch.fx.Node
                ), f"arg must be a torch.fx.Node, not {type(arg)}"
                onnx_tensor_or_tensor_tuple = fx_name_to_onnxscript_value[arg.name]
                onnxscript_graph.register_outputs(onnx_tensor_or_tensor_tuple)

    def call_method(self, node: torch.fx.Node):
        # TODO(wechi): Support call_method.
        raise RuntimeError("call_method is not supported yet.")

    def call_module(
        self,
        node: torch.fx.Node,
        parent_onnxscript_graph: onnxscript_graph_building.TorchScriptGraph,
        fx_name_to_onnxscript_value: dict[
            str,
            onnxscript_graph_building.TorchScriptTensor
            | tuple[onnxscript_graph_building.TorchScriptTensor, ...],
        ],
        tracer: onnxscript_graph_building.TorchScriptTracingEvaluator,
        root_fx_graph_module: torch.fx.GraphModule,
        onnxfunction_dispatcher: onnxfunction_dispatcher.OnnxFunctionDispatcher,
    ) -> None:
        """Export a fx.GraphModule submodule to an ONNXScript graph.

        The export process specifically targets `call_module` nodes that are created by
        the exporter's `Modularize` pass. Each `call_module` node has an associated fx.GraphModule
        identified by `node.target` underneath the root fx.GraphModule. These `call_module` nodes are exported as ONNX
        function nodes. The related `sub_module` is then exported as an ONNX model local function,
        which is represented by another `TorchScriptGraph`. This `TorchScriptGraph` sets the current
        `onnxscript_graph` as its parent.

        Args:
            node: The call_module node in the FX graph that represents the submodule call.
            parent_onnxscript_graph: The parent ONNXScript graph to which the ONNX function and
                function node belong.
            fx_name_to_onnxscript_value: The mapping from FX node name to ONNXScript value.
            tracer: The tracer used to trace the ONNXScript graph.
            root_fx_graph_module: The root FX module.
            onnxfunction_dispatcher: The dispatcher.
        """
        assert isinstance(
            node.target, str
        ), f"node.target must be a str, not {type(node.target)} for node {node}."

        sub_module = root_fx_graph_module.get_submodule(node.target)

        assert isinstance(
            sub_module, torch.fx.GraphModule
        ), f"sub_module must be a torch.fx.GraphModule, not {type(sub_module)} for node {node}."

        sub_onnxscript_graph = self.run(
            sub_module, onnxfunction_dispatcher, parent_onnxscript_graph
        )

        onnx_args, _ = _wrap_fx_args_as_onnxscript_args(
            list(node.args), {}, fx_name_to_onnxscript_value, tracer
        )

        # TODO: We may want to consider other naming styles. The goal is to be stable and
        # unique such that it can be easily identified in case of kernel substitution.
        # An example of the current style is the combination of the qualified module class
        # name and the module attribute name: `torch_nn_modules_conv_Conv2d_conv1`.
        # Other naming styles, such as a qualified module class name made unique, can also
        # be considered.
        unique_module_name = f"{sub_module._get_name()}_{node.target}"

        outputs: (
            onnxscript_graph_building.TorchScriptTensor
            | tuple[onnxscript_graph_building.TorchScriptTensor, ...]
        ) = parent_onnxscript_graph.add_module_call(  # type: ignore[assignment]
            unique_module_name, sub_onnxscript_graph, onnx_args
        )

        assert isinstance(
            outputs, (onnxscript_graph_building.TorchScriptTensor, tuple)
        ), f"Unexpected outputs type {type(outputs)} for node {node}."

        _fill_tensor_shape_type(outputs, node.name, node.meta["val"])
        fx_name_to_onnxscript_value[node.name] = outputs

        # Skip op_level_validation for call_module. Subgraph nodes are validated individually.

    def get_attr(
        self,
        node: torch.fx.Node,
        onnxscript_graph: onnxscript_graph_building.TorchScriptGraph,
        fx_name_to_onnxscript_value: dict[
            str,
            onnxscript_graph_building.TorchScriptTensor
            | tuple[onnxscript_graph_building.TorchScriptTensor, ...],
        ],
        fx_graph_module: torch.fx.GraphModule,
    ):
        assert isinstance(node.target, str), f"node.target {node.target} is not a str."
        attr_tensor = getattr(fx_graph_module, node.target)
        assert isinstance(attr_tensor, torch.Tensor), f"{attr_tensor} is not a tensor."

        # Parameter/buffer name cannot contain "."
        # Revert from "/" to restore namespace formatting.
        input_ = onnxscript_graph.add_initializer(
            name=node.target.replace("/", "."),
            value=attr_tensor,
        )

        assert isinstance(input_, onnxscript_graph_building.TorchScriptTensor)
        assert isinstance(input_, onnxscript.tensor.Tensor)
        fx_name_to_onnxscript_value[node.name] = input_
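

# Illustrative usage sketch (not part of the original source): this interpreter is
# normally driven by the dynamo-based ONNX exporter rather than called directly. A
# rough, hypothetical driver would look like:
#
#     interpreter = FxOnnxInterpreter(diagnostic_context=diagnostic_context)
#     onnxscript_graph = interpreter.run(
#         fx_graph_module=fx_graph_module,
#         onnxfunction_dispatcher=dispatcher,
#     )
#
# where `diagnostic_context`, `fx_graph_module`, and `dispatcher` are assumed to be
# provided by the surrounding export pipeline.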
795