1from typing import Dict, Union 2 3import torch 4import torch.utils._pytree as pytree 5from torch.export.exported_program import ExportedProgram 6 7 8__all__ = ["move_to_device_pass"] 9 10 11def move_to_device_pass( 12 ep: ExportedProgram, location: Union[torch.device, str, Dict[str, str]] 13) -> ExportedProgram: 14 """ 15 Move the exported program to the given device. 16 17 Args: 18 ep (ExportedProgram): The exported program to move. 19 location (Union[torch.device, str, Dict[str, str]]): The device to move the exported program to. 20 If a string, it is interpreted as a device name. 21 If a dict, it is interpreted as a mapping from 22 the existing device to the intended one 23 24 Returns: 25 ExportedProgram: The moved exported program. 26 """ 27 28 def _get_new_device( 29 curr_device: torch.device, 30 location: Union[torch.device, str, Dict[str, str]], 31 ) -> str: 32 if isinstance(location, dict): 33 if str(curr_device) in location.keys(): 34 return location[str(curr_device)] 35 else: 36 return str(curr_device) 37 else: 38 return str(location) 39 40 # move all the state_dict 41 for k, v in ep.state_dict.items(): 42 if isinstance(v, torch.nn.Parameter): 43 ep._state_dict[k] = torch.nn.Parameter( 44 v.to(_get_new_device(v.device, location)) 45 ) 46 else: 47 ep._state_dict[k] = v.to(_get_new_device(v.device, location)) 48 49 # move all the constants 50 for k, v in ep.constants.items(): 51 if isinstance(v, torch.Tensor): 52 ep._constants[k] = v.to(_get_new_device(v.device, location)) 53 54 for node in ep.graph.nodes: 55 # move all the nodes kwargs with burnt-in device 56 if "device" in node.kwargs: 57 kwargs = node.kwargs.copy() 58 kwargs["device"] = _get_new_device(kwargs["device"], location) 59 node.kwargs = kwargs 60 # move all the tensor metadata 61 node.meta["val"] = pytree.tree_map( 62 lambda v: v.to(_get_new_device(v.device, location)) 63 if isinstance(v, torch.Tensor) 64 else v, 65 node.meta.get("val"), 66 ) 67 68 ep.validate() 69 return ep 70