xref: /aosp_15_r20/external/pytorch/torch/jit/_passes/_property_propagation.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
"""
Tools to help with tensor property propagation.

This is not intended to be imported directly; please use the exposed
functionality in `torch.jit`.
"""
8
9from typing import Any, List
10
11import torch
12from torch import TensorType
13from torch._C import Graph
14
15
def apply_input_props_using_example(graph: Graph, example_input: List[Any]):
    """
    Applies properties for each tensor in the graph inputs
    using the example supplied.

    Every graph input paired with a ``torch.Tensor`` example has its type
    replaced by ``TensorType.create_from_tensor(example)``. Examples that are
    ``None`` are skipped without any type check.

    Args:
        graph: The graph whose input types are updated in place.
        example_input: One example value per graph input. If the graph's first
            input is a class-typed value named ``self`` (a method graph), that
            input is stripped and needs no example.

    Raises:
        RuntimeError: If the number of (non-``self``) graph inputs differs
            from ``len(example_input)``, or if an example is a tensor while the
            matching graph input is not a ``TensorType`` (or vice versa).
    """
    graph_inputs = list(graph.inputs())
    if not graph_inputs:
        # NOTE(review): this returns before the count check, so a non-empty
        # example_input for an input-less graph is silently accepted.
        return

    # Strip self args off for methods: the module instance has no
    # counterpart in ``example_input``.
    in_0 = graph_inputs[0]
    if isinstance(in_0.type(), torch._C.ClassType) and in_0.debugName() == "self":
        graph_inputs = graph_inputs[1:]

    if len(graph_inputs) != len(example_input):
        raise RuntimeError(
            "Number of inputs in graph does not match number of inputs in the example "
            f"(graph has {len(graph_inputs)}, example has {len(example_input)})"
        )

    for i, (graph_i, example_i) in enumerate(zip(graph_inputs, example_input)):
        if example_i is None:
            continue  # Skip the type check

        is_tensor_example = isinstance(example_i, torch.Tensor)
        if is_tensor_example != isinstance(graph_i.type(), TensorType):
            # Build a single message: extra positional args to RuntimeError
            # would make str(exc) render as an unreadable tuple.
            raise RuntimeError(
                f"Input {i} does not match type of example: "
                f"graph input type is {graph_i.type()}, "
                f"example type is {type(example_i).__name__}"
            )

        if is_tensor_example:
            graph_i.setType(TensorType.create_from_tensor(example_i))  # type: ignore[arg-type]
48