# mypy: allow-untyped-defs
from typing import Any, Dict, List, NamedTuple, Optional

import torch
from torch.fx._compatibility import compatibility
from torch.fx.graph import Graph
from torch.fx.graph_module import GraphModule
from torch.fx.node import (
    map_arg,
    Node,
    Target,
)
from torch.fx.passes.shape_prop import ShapeProp

__all__ = ['replace_target_nodes_with', 'size_bytes', 'get_size_of_all_nodes', 'get_tensor_meta',
           'get_size_of_node']

@compatibility(is_backward_compatible=False)
def replace_target_nodes_with(
    fx_module: GraphModule,
    old_op: str,
    old_target: Target,
    new_op: str,
    new_target: Target,
):
    """Modifies all nodes in fx_module.graph.nodes that match the specified op code and target,
    replacing them with nodes that use the new op code and target."""
    new_graph = Graph()
    val_map: Dict[Node, Node] = {}
    for node in fx_module.graph.nodes:
        if node.op == old_op and node.target == old_target:
            args = map_arg(node.args, lambda n: val_map[n])
            kwargs = map_arg(node.kwargs, lambda n: val_map[n])
            assert isinstance(args, tuple)
            assert isinstance(kwargs, dict)
            val_map[node] = new_graph.create_node(
                new_op, new_target, args, kwargs, node.name
            )
        else:
            # Copy unmatched nodes unchanged, remapping their inputs into the new graph
            val_map[node] = new_graph.node_copy(node, lambda n: val_map[n])
    fx_module.graph = new_graph


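# Illustrative usage sketch: the _AddExample module below is hypothetical and only
# shows one plausible call pattern, rewriting every call_function node that targets
# torch.add into one that targets torch.mul, then recompiling the GraphModule.
def _example_replace_target_nodes_with() -> None:
    class _AddExample(torch.nn.Module):
        def forward(self, x, y):
            return torch.add(x, y)

    gm = torch.fx.symbolic_trace(_AddExample())
    replace_target_nodes_with(
        gm,
        old_op="call_function",
        old_target=torch.add,
        new_op="call_function",
        new_target=torch.mul,
    )
    gm.recompile()  # regenerate gm.code from the rewritten graph
    # The traced add now behaves as a multiply: 2.0 * 3.0 == 6.0
    assert torch.equal(gm(torch.tensor(2.0), torch.tensor(3.0)), torch.tensor(6.0))

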
@compatibility(is_backward_compatible=False)
class size_bytes(NamedTuple):
    output_size: int
    total_size: int


@compatibility(is_backward_compatible=False)
def get_size_of_all_nodes(
    fx_module: GraphModule, args: Optional[List[torch.Tensor]] = None
) -> None:
    """Given an FX graph module, update each node with its total size (weights + bias + output)
    and its output size. For a non-module node, the total size equals the output size."""
    if args is not None:
        # Run shape propagation so each node gets tensor metadata (shape and dtype)
        ShapeProp(fx_module).propagate(*args)
    # Attach a size_bytes record to every node before the output node
    for node in fx_module.graph.nodes:
        if node.op == "output":
            break
        node.size_bytes = get_size_of_node(fx_module, node)
    return


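# Illustrative usage sketch: the _SizeExample module and tensor shapes below are
# hypothetical. Passing sample inputs lets get_size_of_all_nodes run ShapeProp
# first, after which every non-output node carries a size_bytes NamedTuple.
def _example_get_size_of_all_nodes() -> None:
    class _SizeExample(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.linear = torch.nn.Linear(4, 2)

        def forward(self, x):
            return self.linear(x).relu()

    gm = torch.fx.symbolic_trace(_SizeExample())
    get_size_of_all_nodes(gm, [torch.randn(3, 4)])
    for node in gm.graph.nodes:
        if node.op == "output":
            break
        print(node.name, node.size_bytes.output_size, node.size_bytes.total_size)

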
@compatibility(is_backward_compatible=False)
def get_tensor_meta(node: Node) -> Any:
    tensor_meta = node.meta.get("tensor_meta")

    if not tensor_meta:
        raise RuntimeError(
            f"Node {node} has no tensor metadata associated with it! "
            f"Check that shape propagation has run."
        )

    return tensor_meta


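# Illustrative usage sketch: the Sequential model below is hypothetical. Tensor
# metadata only exists after shape propagation, so calling get_tensor_meta on a
# freshly traced graph would raise the RuntimeError above.
def _example_get_tensor_meta() -> None:
    gm = torch.fx.symbolic_trace(torch.nn.Sequential(torch.nn.Linear(4, 2)))
    ShapeProp(gm).propagate(torch.randn(1, 4))
    for node in gm.graph.nodes:
        if node.op == "call_module":
            meta = get_tensor_meta(node)
            # Prints the node's name, output shape, and dtype
            print(node.name, tuple(meta.shape), meta.dtype)

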
@compatibility(is_backward_compatible=False)
def get_size_of_node(fx_module: GraphModule, node: Node) -> size_bytes:
    """Given a node with tensor metadata (shape and dtype), return its total size and its output size.
    total_size = weights + bias + output_size
    """
    # Total number of elements
    total_num_of_elems = 0
    # For a module, consider all parameters
    if node.op == "call_module":
        submodule_dict = dict(fx_module.named_modules())
        submodule = submodule_dict[node.target]
        parameters = submodule.named_parameters()
        # named_parameters() yields (name, parameter) pairs
        for name, p in parameters:
            total_num_of_elems += p.numel()
    # Don't forget the output size
    # tensor_meta.shape is the shape of this node's output
    tensor_meta = get_tensor_meta(node)
    output_elem = tensor_meta.shape.numel()
    total_num_of_elems += output_elem
    # Assume for now if it's quantized then it's qint8 or quint8
    if tensor_meta.is_quantized:
        size_per_elem_bytes = torch._empty_affine_quantized(
            [], dtype=tensor_meta.dtype
        ).element_size()
    else:
        size_per_elem_bytes = torch.tensor([], dtype=tensor_meta.dtype).element_size()
    total_size = size_per_elem_bytes * total_num_of_elems
    output_size = size_per_elem_bytes * output_elem
    return size_bytes(output_size, total_size)
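

# Illustrative usage sketch: the _Tiny module and byte counts below are
# assumptions for float32 tensors. For a call_module node, total_size also
# counts the submodule's parameters, while output_size covers only the output.
def _example_get_size_of_node() -> None:
    class _Tiny(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.fc = torch.nn.Linear(8, 8, bias=True)

        def forward(self, x):
            return self.fc(x)

    gm = torch.fx.symbolic_trace(_Tiny())
    ShapeProp(gm).propagate(torch.randn(1, 8))
    (fc_node,) = [n for n in gm.graph.nodes if n.op == "call_module"]
    sizes = get_size_of_node(gm, fc_node)
    # float32 output: 1 * 8 * 4 = 32 bytes; parameters: (8 * 8 + 8) * 4 = 288 bytes
    print(sizes.output_size, sizes.total_size)  # 32 and 320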