xref: /aosp_15_r20/external/pytorch/torch/ao/ns/fx/ns_types.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1import enum
2from typing import Any, Callable, Dict, List, NamedTuple, Union
3
4from torch.fx.graph import Node
5
6
class NSSingleResultValuesType(str, enum.Enum):
    """Enumerate the kinds of values a single result can hold.

    Because the enum mixes in ``str``, every member is itself a string and
    compares equal to its value, e.g.
    ``NSSingleResultValuesType.WEIGHT == "weight"``.
    """

    # The member values double as the "results type" keys of NSResultsType.
    WEIGHT = "weight"
    NODE_OUTPUT = "node_output"
    NODE_INPUT = "node_input"
11
12
class NSSubgraph(NamedTuple):
    """Immutable triple of FX graph nodes describing one subgraph.

    Positional field order: ``start_node``, ``end_node``, ``base_op_node``.
    """

    # Node at which the subgraph begins.
    start_node: Node
    # Node at which the subgraph ends.
    end_node: Node
    # Presumably the node holding the subgraph's base op (per the field
    # name) -- confirm against the matching code that builds these.
    base_op_node: Node
17
18
# A single logged result, stored as a plain string-keyed dict.
# TODO(future PR): see if we can use a TypedDict (typing.TypedDict, or
# typing_extensions.TypedDict on older Pythons) instead of Dict[str, Any]
# so that the individual keys are properly typed.
# Example layout:
# {
#   # one of NSSingleResultValuesType
#   'type': 'weight',
#   # the values of the type specified above
#   'values': [torch.tensor(...), ...],
#   # name of the node directly before the logger
#   'prev_node_name': 'linear1',
#   # type of the underlying function or module
#   'prev_node_target_type': torch.nn.functional.linear,  # or torch.nn.Linear, etc.
#   # name of the node responsible for adding this logger
#   # Note: this may differ from prev_node_name if we are logging inputs
#   'ref_node_name': 'linear1',
#   # index of this node within the arg of the input/output node
#   # for example, in cat([x1, x2, x3], dim=0), x2 would have index_within_arg == 1
#   'index_within_arg': 0,
#   # index of this node within the args of the input/output node
#   # for example, in add(x1, x2), x2 would have index_of_arg == 1
#   'index_of_arg': 0,
#   # precomputed comparisons of logger values to reference values
#   'comparisons': [torch.tensor(...), ...],
#   # name of the function used for the precomputed comparisons
#   'comparison_fn_name': 'sqnr',
#   # string representation of the qconfig responsible for creating this logger
#   'qconfig_str': 'QConfig(...)',
# }
NSSingleResultType = Dict[str, Any]
47
# Nested results mapping, keyed by subgraph name, then results type, then
# model name; each leaf is a list of NSSingleResultType dicts. Example:
# {
#   'layer_name_1': {  # subgraph name
#     'node_output': {  # results type (node_output, node_input, weight)
#       'model_name_a':  # model name
#          [NSSingleResultType, ...],  # results, ordered by index_within_arg
#       'model_name_b':
#          [NSSingleResultType, ...],
#     },
#   },
# }
NSResultsType = Dict[str, Dict[str, Dict[str, List[NSSingleResultType]]]]
60
# Defines the underlying target type of a node: either a callable (a function
# or a module class) or a method-name string. For example:
# `F.conv1d` for a `call_function` conv node
# `nn.Conv1d` for a `call_module` node calling the forward of a `nn.Conv1d` module
# `'sigmoid'` for a `call_method` node calling `x.sigmoid()`
NSNodeTargetType = Union[Callable, str]
66