import enum
from typing import Any, Callable, Dict, List, NamedTuple, Union

from torch.fx.graph import Node


class NSSingleResultValuesType(str, enum.Enum):
    """Allowed values of the ``'type'`` key of an ``NSSingleResultType`` dict.

    Mixing in ``str`` makes each member compare equal to its string value,
    e.g. ``NSSingleResultValuesType.WEIGHT == "weight"``.
    """

    WEIGHT = "weight"
    NODE_OUTPUT = "node_output"
    NODE_INPUT = "node_input"


class NSSubgraph(NamedTuple):
    """Three FX graph ``Node`` references that together identify a subgraph."""

    # presumably the first node of the subgraph -- confirm against callers
    start_node: Node
    # presumably the last node of the subgraph -- confirm against callers
    end_node: Node
    # presumably the node holding the subgraph's main op -- confirm against callers
    base_op_node: Node


# TODO(future PR): see if we can use typing_extensions's TypedDict instead
# to properly type the various keys
# {
#   # one of NSSingleResultValuesType
#   'type': 'weight',
#   # the values of type specified above
#   'values': [torch.tensor(...), ...],
#   # name of the node directly before the logger
#   'prev_node_name': 'linear1',
#   # type of the underlying function or module
#   'prev_node_target_type': torch.nn.functional.linear,  # or torch.nn.Linear, etc.
#   # name of the node responsible for adding this logger
#   # Note: this may differ from prev_node_name if we are logging inputs
#   'ref_node_name': 'linear1',
#   # index of this node within the arg of the input/output node
#   # for example, in cat([x1, x2, x3], dim=0), x2 would have index_within_arg == 1
#   'index_within_arg': 0,
#   # index of this node within the args of the input/output node
#   # for example, in add(x1, x2), x2 would have index_of_arg == 1
#   'index_of_arg': 0,
#   # precomputed comparisons of logger values to reference values
#   'comparisons': [torch.tensor(...), ...],
#   # name of function used for precomputed comparisons
#   'comparison_fn_name': 'sqnr',
#   # string representation of qconfig responsible for creating this logger
#   'qconfig_str': 'QConfig(...)',
# }
NSSingleResultType = Dict[str, Any]

# {
#   'layer_name_1': {  # subgraph name
#     'node_output': {  # results type (node_output, node_input, weight)
#       'model_name_a':  # model name
#          [NSSingleResultType, ...],  # results, ordered by index_within_arg
#       'model_name_b':
#          [NSSingleResultType, ...],
#     },
#   },
# }
#
NSResultsType = Dict[str, Dict[str, Dict[str, List[NSSingleResultType]]]]

# Defines the underlying target type of a node, for example:
# `F.conv1d` for a `call_function` conv node
# `nn.Conv1d` for a `call_module` node calling the forward of a `nn.Conv1d` module
# `'sigmoid'` for a `call_method` node calling `x.sigmoid()`
NSNodeTargetType = Union[Callable, str]