#!/usr/bin/python3
# mypy: allow-untyped-defs
import importlib
import logging
import os
import sys
import tempfile
from typing import Optional

import torch
from torch.distributed.nn.jit.templates.remote_module_template import (
    get_remote_module_template,
)


logger = logging.getLogger(__name__)


_FILE_PREFIX = "_remote_module_"
_TEMP_DIR = tempfile.TemporaryDirectory()
INSTANTIATED_TEMPLATE_DIR_PATH = _TEMP_DIR.name
logger.info("Created a temporary directory at %s", INSTANTIATED_TEMPLATE_DIR_PATH)
sys.path.append(INSTANTIATED_TEMPLATE_DIR_PATH)


def get_arg_return_types_from_interface(module_interface):
    assert getattr(
        module_interface, "__torch_script_interface__", False
    ), "Expect a TorchScript class interface decorated by @torch.jit.interface."
    qualified_name = torch._jit_internal._qualified_name(module_interface)
    cu = torch.jit._state._python_cu
    module_interface_c = cu.get_interface(qualified_name)
    assert (
        "forward" in module_interface_c.getMethodNames()
    ), f"Expect forward in interface methods, while it has {module_interface_c.getMethodNames()}"
    method_schema = module_interface_c.getMethod("forward")

    arg_str_list = []
    arg_type_str_list = []
    assert method_schema is not None
    for argument in method_schema.arguments:
        arg_str_list.append(argument.name)

        if argument.has_default_value():
            default_value_str = f" = {argument.default_value}"
        else:
            default_value_str = ""
        arg_type_str = f"{argument.name}: {argument.type}{default_value_str}"
        arg_type_str_list.append(arg_type_str)

    arg_str_list = arg_str_list[1:]  # Remove "self".
    args_str = ", ".join(arg_str_list)

    arg_type_str_list = arg_type_str_list[1:]  # Remove "self".
    arg_types_str = ", ".join(arg_type_str_list)

    assert len(method_schema.returns) == 1
    argument = method_schema.returns[0]
    return_type_str = str(argument.type)

    return args_str, arg_types_str, return_type_str
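
# Illustrative sketch (not executed here; ``MyModuleInterface`` is a hypothetical
# interface name, not part of this module). For an interface such as
#
#     @torch.jit.interface
#     class MyModuleInterface:
#         def forward(self, tensor: Tensor, scale: float = 2.0) -> Tensor:
#             pass
#
# get_arg_return_types_from_interface(MyModuleInterface) would return roughly
#
#     ("tensor, scale", "tensor: Tensor, scale: float = 2.0", "Tensor")
#
# i.e. the argument names, the annotated argument list (with defaults), and the
# return type string that are substituted into the remote module template below.
# The exact strings are rendered from the TorchScript schema, so the spelling of
# the types may differ slightly.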


def _write(out_path, text):
    old_text: Optional[str]
    try:
        with open(out_path) as f:
            old_text = f.read()
    except OSError:
        old_text = None
    if old_text != text:
        with open(out_path, "w") as f:
            logger.info("Writing %s", out_path)
            f.write(text)
    else:
        logger.info("Skipped writing %s", out_path)


def _do_instantiate_remote_module_template(
    generated_module_name, str_dict, enable_moving_cpu_tensors_to_cuda
):
    generated_code_text = get_remote_module_template(
        enable_moving_cpu_tensors_to_cuda
    ).format(**str_dict)
    out_path = os.path.join(
        INSTANTIATED_TEMPLATE_DIR_PATH, f"{generated_module_name}.py"
    )
    _write(out_path, generated_code_text)

    # From the importlib doc:
    # > If you are dynamically importing a module that was created since
    # > the interpreter began execution (e.g., created a Python source file),
    # > you may need to call invalidate_caches() in order for the new module
    # > to be noticed by the import system.
    importlib.invalidate_caches()
    generated_module = importlib.import_module(f"{generated_module_name}")
    return generated_module


def instantiate_scriptable_remote_module_template(
    module_interface_cls, enable_moving_cpu_tensors_to_cuda=True
):
    if not getattr(module_interface_cls, "__torch_script_interface__", False):
        raise ValueError(
            f"module_interface_cls {module_interface_cls} must be a type object decorated by "
            "@torch.jit.interface"
        )

    # Generate the template instance name.
    module_interface_cls_name = torch._jit_internal._qualified_name(
        module_interface_cls
    ).replace(".", "_")
    generated_module_name = f"{_FILE_PREFIX}{module_interface_cls_name}"

    # Generate type annotation strings.
    assign_module_interface_cls_str = (
        f"from {module_interface_cls.__module__} import "
        f"{module_interface_cls.__name__} as module_interface_cls"
    )
    args_str, arg_types_str, return_type_str = get_arg_return_types_from_interface(
        module_interface_cls
    )
    kwargs_str = ""
    arrow_and_return_type_str = f" -> {return_type_str}"
    arrow_and_future_return_type_str = f" -> Future[{return_type_str}]"

    str_dict = dict(
        assign_module_interface_cls=assign_module_interface_cls_str,
        arg_types=arg_types_str,
        arrow_and_return_type=arrow_and_return_type_str,
        arrow_and_future_return_type=arrow_and_future_return_type_str,
        args=args_str,
        kwargs=kwargs_str,
        jit_script_decorator="@torch.jit.script",
    )
    return _do_instantiate_remote_module_template(
        generated_module_name, str_dict, enable_moving_cpu_tensors_to_cuda
    )


def instantiate_non_scriptable_remote_module_template():
    generated_module_name = f"{_FILE_PREFIX}non_scriptable"
    str_dict = dict(
        assign_module_interface_cls="module_interface_cls = None",
        args="*args",
        kwargs="**kwargs",
        arg_types="*args, **kwargs",
        arrow_and_return_type="",
        arrow_and_future_return_type="",
        jit_script_decorator="",
    )
    # For the non-scriptable template, always enable moving CPU tensors to a CUDA
    # device, because the generated code is not scripted and is therefore not
    # subject to TorchScript syntax limitations on the extra handling.
    return _do_instantiate_remote_module_template(generated_module_name, str_dict, True)
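
# Usage sketch (illustrative only; ``MyModuleInterface`` is a hypothetical
# interface, not part of this module). Given a class decorated with
# @torch.jit.interface, the scriptable template can be instantiated and the
# generated module imported in-process:
#
#     import torch
#     from torch import Tensor
#
#     @torch.jit.interface
#     class MyModuleInterface:
#         def forward(self, tensor: Tensor) -> Tensor:
#             pass
#
#     generated_module = instantiate_scriptable_remote_module_template(
#         MyModuleInterface
#     )
#
# ``generated_module`` is the dynamically imported module written into the
# temporary directory above; instantiate_non_scriptable_remote_module_template()
# is the counterpart for modules without a TorchScript interface.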