xref: /aosp_15_r20/external/pytorch/torch/distributed/nn/jit/instantiator.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1#!/usr/bin/python3
2# mypy: allow-untyped-defs
3import importlib
4import logging
5import os
6import sys
7import tempfile
8from typing import Optional
9
10import torch
11from torch.distributed.nn.jit.templates.remote_module_template import (
12    get_remote_module_template,
13)
14
15
logger = logging.getLogger(__name__)


# Prefix used for the names of generated remote-module source files.
_FILE_PREFIX = "_remote_module_"
# Temporary directory that holds the instantiated template files. The
# TemporaryDirectory object is kept alive at module scope so the directory
# (and the generated modules inside it) survives for the process lifetime.
_TEMP_DIR = tempfile.TemporaryDirectory()
INSTANTIATED_TEMPLATE_DIR_PATH = _TEMP_DIR.name
logger.info("Created a temporary directory at %s", INSTANTIATED_TEMPLATE_DIR_PATH)
# Put the directory on sys.path so the generated files are importable
# via importlib.import_module.
sys.path.append(INSTANTIATED_TEMPLATE_DIR_PATH)
24
25
def get_arg_return_types_from_interface(module_interface):
    """Extract argument and return-type strings from a TorchScript interface.

    Reads the ``forward`` method schema of a class decorated with
    ``@torch.jit.interface`` and renders it as strings suitable for
    substitution into the remote-module code template.

    Args:
        module_interface: A TorchScript class interface (must carry the
            ``__torch_script_interface__`` marker).

    Returns:
        A tuple ``(args_str, arg_types_str, return_type_str)`` where
        ``args_str`` is a comma-separated list of ``forward``'s argument
        names (with ``self`` removed), ``arg_types_str`` additionally
        carries type annotations and default values, and
        ``return_type_str`` is the single return type.
    """
    assert getattr(
        module_interface, "__torch_script_interface__", False
    ), "Expect a TorchScript class interface decorated by @torch.jit.interface."
    qualified_name = torch._jit_internal._qualified_name(module_interface)
    compilation_unit = torch.jit._state._python_cu
    interface_c = compilation_unit.get_interface(qualified_name)
    method_names = interface_c.getMethodNames()
    assert (
        "forward" in method_names
    ), f"Expect forward in interface methods, while it has {method_names}"
    schema = interface_c.getMethod("forward")
    assert schema is not None

    names = []
    annotated = []
    for arg in schema.arguments:
        names.append(arg.name)
        # Render "name: type" plus " = default" when a default exists.
        default = f" = {arg.default_value}" if arg.has_default_value() else ""
        annotated.append(f"{arg.name}: {arg.type}{default}")

    # Drop the leading "self" from both renderings.
    args_str = ", ".join(names[1:])
    arg_types_str = ", ".join(annotated[1:])

    assert len(schema.returns) == 1
    return_type_str = str(schema.returns[0].type)

    return args_str, arg_types_str, return_type_str
62
63
def _write(out_path, text):
    """Write ``text`` to ``out_path``, skipping the write when unchanged.

    Leaving an up-to-date file untouched avoids a spurious mtime change,
    so repeated instantiations of the same template are cheap.
    """
    existing: Optional[str] = None
    try:
        with open(out_path) as f:
            existing = f.read()
    except OSError:
        # File missing or unreadable: treat as "no previous content".
        pass
    if existing == text:
        logger.info("Skipped writing %s", out_path)
    else:
        with open(out_path, "w") as f:
            logger.info("Writing %s", out_path)
            f.write(text)
77
78
def _do_instantiate_remote_module_template(
    generated_module_name, str_dict, enable_moving_cpu_tensors_to_cuda
):
    """Render the remote-module template, write it to disk, and import it.

    Args:
        generated_module_name: Module name (without ``.py``) for the
            generated source file.
        str_dict: Mapping of template placeholders to their substitutions.
        enable_moving_cpu_tensors_to_cuda: Whether the generated code
            includes the extra handling that moves CPU tensors to CUDA.

    Returns:
        The freshly imported generated module object.
    """
    template = get_remote_module_template(enable_moving_cpu_tensors_to_cuda)
    source_text = template.format(**str_dict)
    out_path = os.path.join(
        INSTANTIATED_TEMPLATE_DIR_PATH, f"{generated_module_name}.py"
    )
    _write(out_path, source_text)

    # Per the importlib docs, a module source file created after the
    # interpreter began execution (as done just above) may not be noticed
    # by the import system until invalidate_caches() is called.
    importlib.invalidate_caches()
    return importlib.import_module(generated_module_name)
98
99
def instantiate_scriptable_remote_module_template(
    module_interface_cls, enable_moving_cpu_tensors_to_cuda=True
):
    """Generate and import a scriptable remote module for an interface.

    Args:
        module_interface_cls: A class decorated with
            ``@torch.jit.interface`` describing the module's ``forward``
            signature.
        enable_moving_cpu_tensors_to_cuda: Whether the generated code
            should move CPU tensors to a CUDA device before invoking
            ``forward``. Defaults to ``True``.

    Returns:
        The imported generated module.

    Raises:
        ValueError: If ``module_interface_cls`` is not a TorchScript
            interface.
    """
    if not getattr(module_interface_cls, "__torch_script_interface__", False):
        raise ValueError(
            f"module_interface_cls {module_interface_cls} must be a type object decorated by "
            "@torch.jit.interface"
        )

    # Derive a unique, filesystem-safe module name from the interface's
    # fully qualified name.
    qualified = torch._jit_internal._qualified_name(module_interface_cls)
    generated_module_name = _FILE_PREFIX + qualified.replace(".", "_")

    # Build the strings substituted into the code template.
    assign_module_interface_cls_str = (
        f"from {module_interface_cls.__module__} import "
        f"{module_interface_cls.__name__} as module_interface_cls"
    )
    args_str, arg_types_str, return_type_str = get_arg_return_types_from_interface(
        module_interface_cls
    )
    str_dict = {
        "assign_module_interface_cls": assign_module_interface_cls_str,
        "arg_types": arg_types_str,
        "arrow_and_return_type": f" -> {return_type_str}",
        "arrow_and_future_return_type": f" -> Future[{return_type_str}]",
        "args": args_str,
        "kwargs": "",
        "jit_script_decorator": "@torch.jit.script",
    }
    return _do_instantiate_remote_module_template(
        generated_module_name, str_dict, enable_moving_cpu_tensors_to_cuda
    )
139
140
def instantiate_non_scriptable_remote_module_template():
    """Generate and import the non-scriptable remote-module variant.

    The generated code accepts arbitrary ``*args``/``**kwargs`` and carries
    no TorchScript annotations or decorators.

    Returns:
        The imported generated module.
    """
    str_dict = {
        "assign_module_interface_cls": "module_interface_cls = None",
        "args": "*args",
        "kwargs": "**kwargs",
        "arg_types": "*args, **kwargs",
        "arrow_and_return_type": "",
        "arrow_and_future_return_type": "",
        "jit_script_decorator": "",
    }
    # For a non-scriptable template, always enable moving CPU tensors to a
    # cuda device, because there is no syntax limitation on the extra
    # handling caused by the script.
    return _do_instantiate_remote_module_template(
        f"{_FILE_PREFIX}non_scriptable", str_dict, True
    )
155