xref: /aosp_15_r20/external/pytorch/torch/export/passes/__init__.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1from typing import Dict, Union
2
3import torch
4import torch.utils._pytree as pytree
5from torch.export.exported_program import ExportedProgram
6
7
8__all__ = ["move_to_device_pass"]
9
10
11def move_to_device_pass(
12    ep: ExportedProgram, location: Union[torch.device, str, Dict[str, str]]
13) -> ExportedProgram:
14    """
15    Move the exported program to the given device.
16
17    Args:
18        ep (ExportedProgram): The exported program to move.
19        location (Union[torch.device, str, Dict[str, str]]): The device to move the exported program to.
20            If a string, it is interpreted as a device name.
21            If a dict, it is interpreted as a mapping from
22            the existing device to the intended one
23
24    Returns:
25        ExportedProgram: The moved exported program.
26    """
27
28    def _get_new_device(
29        curr_device: torch.device,
30        location: Union[torch.device, str, Dict[str, str]],
31    ) -> str:
32        if isinstance(location, dict):
33            if str(curr_device) in location.keys():
34                return location[str(curr_device)]
35            else:
36                return str(curr_device)
37        else:
38            return str(location)
39
40    # move all the state_dict
41    for k, v in ep.state_dict.items():
42        if isinstance(v, torch.nn.Parameter):
43            ep._state_dict[k] = torch.nn.Parameter(
44                v.to(_get_new_device(v.device, location))
45            )
46        else:
47            ep._state_dict[k] = v.to(_get_new_device(v.device, location))
48
49    # move all the constants
50    for k, v in ep.constants.items():
51        if isinstance(v, torch.Tensor):
52            ep._constants[k] = v.to(_get_new_device(v.device, location))
53
54    for node in ep.graph.nodes:
55        # move all the nodes kwargs with burnt-in device
56        if "device" in node.kwargs:
57            kwargs = node.kwargs.copy()
58            kwargs["device"] = _get_new_device(kwargs["device"], location)
59            node.kwargs = kwargs
60        # move all the tensor metadata
61        node.meta["val"] = pytree.tree_map(
62            lambda v: v.to(_get_new_device(v.device, location))
63            if isinstance(v, torch.Tensor)
64            else v,
65            node.meta.get("val"),
66        )
67
68    ep.validate()
69    return ep
70