# mypy: allow-untyped-defs
"""
Tools to help with tensor property propagation.

This is not intended to be imported directly; please use the exposed
functionalities in `torch.jit`.
"""

from typing import Any, List

import torch
from torch import TensorType
from torch._C import Graph


def apply_input_props_using_example(graph: Graph, example_input: List[Any]) -> None:
    """
    Applies properties for each tensor in the graph inputs
    using the example supplied.

    For every graph input whose example is a ``torch.Tensor``, the input's
    type is replaced (in place, via ``setType``) with a ``TensorType`` built
    from that example tensor.

    Args:
        graph: The TorchScript graph whose inputs are updated in place.
        example_input: One example value per graph input, in order. If the
            graph's first input is a class-typed ``self`` argument, it is
            excluded and ``example_input`` must cover only the remaining
            inputs. A ``None`` entry skips both the type check and the
            property update for that input.

    Raises:
        RuntimeError: If the number of (non-``self``) graph inputs differs
            from ``len(example_input)``, or if an example's tensor-ness does
            not match whether the corresponding graph input is a tensor type.

    Note:
        A graph with no inputs at all returns immediately without validating
        ``example_input``.
    """
    graph_inputs = list(graph.inputs())
    if not graph_inputs:
        return

    # Strip self args off for methods
    first_input = graph_inputs[0]
    if (
        isinstance(first_input.type(), torch._C.ClassType)
        and first_input.debugName() == "self"
    ):
        graph_inputs = graph_inputs[1:]

    if len(graph_inputs) != len(example_input):
        raise RuntimeError(
            "Number of inputs in graph does not match number of inputs in the example"
            f" (graph has {len(graph_inputs)}, example has {len(example_input)})"
        )

    for i, (graph_i, example_i) in enumerate(zip(graph_inputs, example_input)):
        if example_i is None:
            continue  # Skip the type check

        example_is_tensor = isinstance(example_i, torch.Tensor)
        if example_is_tensor != isinstance(graph_i.type(), TensorType):
            # Build a single message string: passing several positional args
            # to RuntimeError would make str(exc) render as a tuple.
            raise RuntimeError(
                f"Input {i} does not match type of example: graph input "
                f"'%{graph_i.debugName()}' has type {graph_i.type()}, but the "
                f"example is of type {type(example_i).__name__}"
            )

        if example_is_tensor:
            graph_i.setType(TensorType.create_from_tensor(example_i))  # type: ignore[arg-type]