xref: /aosp_15_r20/external/pytorch/torch/jit/annotations.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker# mypy: allow-untyped-defs
2*da0073e9SAndroid Build Coastguard Workerimport ast
3*da0073e9SAndroid Build Coastguard Workerimport builtins
4*da0073e9SAndroid Build Coastguard Workerimport dis
5*da0073e9SAndroid Build Coastguard Workerimport enum
6*da0073e9SAndroid Build Coastguard Workerimport inspect
7*da0073e9SAndroid Build Coastguard Workerimport re
8*da0073e9SAndroid Build Coastguard Workerimport typing
9*da0073e9SAndroid Build Coastguard Workerimport warnings
10*da0073e9SAndroid Build Coastguard Workerfrom textwrap import dedent
11*da0073e9SAndroid Build Coastguard Workerfrom typing import Type
12*da0073e9SAndroid Build Coastguard Worker
13*da0073e9SAndroid Build Coastguard Workerimport torch
14*da0073e9SAndroid Build Coastguard Workerfrom torch._C import (
15*da0073e9SAndroid Build Coastguard Worker    _GeneratorType,
16*da0073e9SAndroid Build Coastguard Worker    AnyType,
17*da0073e9SAndroid Build Coastguard Worker    AwaitType,
18*da0073e9SAndroid Build Coastguard Worker    BoolType,
19*da0073e9SAndroid Build Coastguard Worker    ComplexType,
20*da0073e9SAndroid Build Coastguard Worker    DeviceObjType,
21*da0073e9SAndroid Build Coastguard Worker    DictType,
22*da0073e9SAndroid Build Coastguard Worker    EnumType,
23*da0073e9SAndroid Build Coastguard Worker    FloatType,
24*da0073e9SAndroid Build Coastguard Worker    FutureType,
25*da0073e9SAndroid Build Coastguard Worker    InterfaceType,
26*da0073e9SAndroid Build Coastguard Worker    IntType,
27*da0073e9SAndroid Build Coastguard Worker    ListType,
28*da0073e9SAndroid Build Coastguard Worker    NoneType,
29*da0073e9SAndroid Build Coastguard Worker    NumberType,
30*da0073e9SAndroid Build Coastguard Worker    OptionalType,
31*da0073e9SAndroid Build Coastguard Worker    StreamObjType,
32*da0073e9SAndroid Build Coastguard Worker    StringType,
33*da0073e9SAndroid Build Coastguard Worker    TensorType,
34*da0073e9SAndroid Build Coastguard Worker    TupleType,
35*da0073e9SAndroid Build Coastguard Worker    UnionType,
36*da0073e9SAndroid Build Coastguard Worker)
37*da0073e9SAndroid Build Coastguard Workerfrom torch._jit_internal import (  # type: ignore[attr-defined]
38*da0073e9SAndroid Build Coastguard Worker    _Await,
39*da0073e9SAndroid Build Coastguard Worker    _qualified_name,
40*da0073e9SAndroid Build Coastguard Worker    Any,
41*da0073e9SAndroid Build Coastguard Worker    BroadcastingList1,
42*da0073e9SAndroid Build Coastguard Worker    BroadcastingList2,
43*da0073e9SAndroid Build Coastguard Worker    BroadcastingList3,
44*da0073e9SAndroid Build Coastguard Worker    Dict,
45*da0073e9SAndroid Build Coastguard Worker    Future,
46*da0073e9SAndroid Build Coastguard Worker    is_await,
47*da0073e9SAndroid Build Coastguard Worker    is_dict,
48*da0073e9SAndroid Build Coastguard Worker    is_future,
49*da0073e9SAndroid Build Coastguard Worker    is_ignored_fn,
50*da0073e9SAndroid Build Coastguard Worker    is_list,
51*da0073e9SAndroid Build Coastguard Worker    is_optional,
52*da0073e9SAndroid Build Coastguard Worker    is_tuple,
53*da0073e9SAndroid Build Coastguard Worker    is_union,
54*da0073e9SAndroid Build Coastguard Worker    List,
55*da0073e9SAndroid Build Coastguard Worker    Optional,
56*da0073e9SAndroid Build Coastguard Worker    Tuple,
57*da0073e9SAndroid Build Coastguard Worker    Union,
58*da0073e9SAndroid Build Coastguard Worker)
59*da0073e9SAndroid Build Coastguard Workerfrom torch._sources import get_source_lines_and_file
60*da0073e9SAndroid Build Coastguard Worker
61*da0073e9SAndroid Build Coastguard Workerfrom ._state import _get_script_class
62*da0073e9SAndroid Build Coastguard Worker
63*da0073e9SAndroid Build Coastguard Worker
64*da0073e9SAndroid Build Coastguard Workerif torch.distributed.rpc.is_available():
65*da0073e9SAndroid Build Coastguard Worker    from torch._C import RRefType
66*da0073e9SAndroid Build Coastguard Worker    from torch._jit_internal import is_rref, RRef
67*da0073e9SAndroid Build Coastguard Worker
68*da0073e9SAndroid Build Coastguard Workerfrom torch._ops import OpOverloadPacket
69*da0073e9SAndroid Build Coastguard Worker
70*da0073e9SAndroid Build Coastguard Worker
class Module:
    """Stand-in for a module namespace during type-comment evaluation.

    Only the entries of ``members`` are reachable as attributes. Any other
    attribute lookup raises RuntimeError naming this pseudo-module.
    """

    def __init__(self, name, members):
        # `name` only appears in error messages; `members` maps attribute
        # names to the objects they resolve to.
        self.name = name
        self.members = members

    def __getattr__(self, name):
        # Python only calls this for attributes not found through normal lookup
        # (i.e. anything other than `name` / `members`).
        try:
            member = self.members[name]
        except KeyError:
            raise RuntimeError(
                f"Module {self.name} has no member called {name}"
            ) from None
        return member
83*da0073e9SAndroid Build Coastguard Worker
84*da0073e9SAndroid Build Coastguard Worker
class EvalEnv:
    """Name-lookup mapping passed as ``locals`` when evaluating type comments.

    A name is resolved against the fixed ``env`` table first. Otherwise it goes
    to the resolution callback ``rcb`` when one was given. Failing that, it is
    looked up on Python's ``builtins``, and unknown names yield None.
    """

    # Names always visible to type comments. This dict lives on the class and
    # is therefore shared by every instance.
    env = {
        "torch": Module("torch", {"Tensor": torch.Tensor}),
        "Tensor": torch.Tensor,
        "typing": Module("typing", {"Tuple": Tuple}),
        "Tuple": Tuple,
        "List": List,
        "Dict": Dict,
        "Optional": Optional,
        "Union": Union,
        "Future": Future,
        "Await": _Await,
    }

    def __init__(self, rcb):
        self.rcb = rcb
        if torch.distributed.rpc.is_available():
            # NOTE: this writes into the class-level `env`, shared by all
            # instances; the value stored is always the same `RRef`.
            self.env["RRef"] = RRef

    def __getitem__(self, name):
        try:
            return self.env[name]
        except KeyError:
            pass
        if self.rcb is None:
            return getattr(builtins, name, None)
        return self.rcb(name)
110*da0073e9SAndroid Build Coastguard Worker
111*da0073e9SAndroid Build Coastguard Worker
def get_signature(fn, rcb, loc, is_method):
    """Return ``(param_types, return_type)`` for ``fn``, or None if unannotated.

    Real (PEP 3107) annotations are tried first. If there are none, the source
    of ``fn`` is searched for a PEP 484 ``# type:`` comment.

    Args:
        fn: the callable to inspect; for an ``OpOverloadPacket`` the real
            annotations are read from its ``op``.
        rcb: resolution callback used to look up names inside a type comment.
        loc: source location used for error reporting.
        is_method: whether ``fn`` is an (unbound) method. In that case the
            leading ``self`` entry is dropped from real annotations.
    """
    target = fn.op if isinstance(fn, OpOverloadPacket) else fn
    signature = try_real_annotations(target, loc)
    if signature is not None:
        if is_method:
            # Real annotations carry a type for `self` but type comments do
            # not, so drop it to keep both forms consistent (`inspect.ismethod`
            # is no help here because `fn` is still unbound).
            param_types, return_type = signature
            signature = (param_types[1:], return_type)
        return signature

    # No real annotations: fall back to a type comment. We may end up with
    # nothing, either because the source could not be fetched (TypeError) or
    # because it has no type comment.
    type_line = None
    try:
        source = dedent("".join(get_source_lines_and_file(fn)[0]))
        type_line = get_type_line(source)
    except TypeError:
        pass
    if type_line is None:
        return None
    return parse_type_line(type_line, rcb, loc)
139*da0073e9SAndroid Build Coastguard Worker
140*da0073e9SAndroid Build Coastguard Worker
def is_function_or_method(the_callable):
    """Return True only for Python functions and bound methods.

    This is stricter than ``inspect.isroutine``: built-in functions and other
    callables (classes, callable instances) are rejected.
    """
    return any(
        predicate(the_callable) for predicate in (inspect.isfunction, inspect.ismethod)
    )
145*da0073e9SAndroid Build Coastguard Worker
146*da0073e9SAndroid Build Coastguard Worker
def is_vararg(the_callable):
    """Return True if ``the_callable`` declares a ``*args`` parameter.

    A callable that is not a Python function/method (e.g. a callable object)
    is first replaced by its ``__call__`` attribute. If the result still is
    not a Python function/method, the answer is False.
    """
    target = the_callable
    if callable(target) and not is_function_or_method(target):  # noqa: B004
        # De-sugar the call so the signature can still be inspected
        target = target.__call__

    if not is_function_or_method(target):
        return False
    return inspect.getfullargspec(target).varargs is not None
157*da0073e9SAndroid Build Coastguard Worker
158*da0073e9SAndroid Build Coastguard Worker
def get_param_names(fn, n_args):
    """Return the parameter names of ``fn`` as a list of strings.

    Python functions/methods are introspected with ``inspect.getfullargspec``.
    This includes callable objects whose ``__call__`` is one, and ignored
    functions, which are unwrapped first. For anything else, placeholder
    names ``"0"``, ``"1"``, ... (``n_args`` of them) are returned.
    """
    if isinstance(fn, OpOverloadPacket):
        fn = fn.op

    # De-sugar calls to classes / callable objects whose `__call__` is a
    # Python function or method.
    if (
        not is_function_or_method(fn)
        and callable(fn)
        and is_function_or_method(fn.__call__)
    ):  # noqa: B004
        fn = fn.__call__

    if not is_function_or_method(fn):
        # Not introspectable as a Python function: use positional placeholders.
        return list(map(str, range(n_args)))

    if is_ignored_fn(fn):
        fn = inspect.unwrap(fn)
    return inspect.getfullargspec(fn).args
179*da0073e9SAndroid Build Coastguard Worker
180*da0073e9SAndroid Build Coastguard Worker
def check_fn(fn, loc):
    """Ensure the source of ``fn`` is exactly one top-level function definition.

    Returns silently when the source cannot be retrieved (OSError/TypeError).

    Raises:
        torch.jit.frontend.FrontendError: if the source is a single class
            definition (class instantiation is not allowed in a script
            function), or is anything other than a single function definition.
    """
    try:
        source = dedent("".join(get_source_lines_and_file(fn)[0]))
    except (OSError, TypeError):
        return

    top_level = ast.parse(source).body
    only_node = top_level[0] if len(top_level) == 1 else None

    if isinstance(only_node, ast.ClassDef):
        raise torch.jit.frontend.FrontendError(
            loc,
            f"Cannot instantiate class '{only_node.name}' in a script function",
        )
    if not isinstance(only_node, ast.FunctionDef):
        raise torch.jit.frontend.FrontendError(
            loc, "Expected a single top-level function"
        )
200*da0073e9SAndroid Build Coastguard Worker
201*da0073e9SAndroid Build Coastguard Worker
202*da0073e9SAndroid Build Coastguard Workerdef _eval_no_call(stmt, glob, loc):
203*da0073e9SAndroid Build Coastguard Worker    """Evaluate statement as long as it does not contain any method/function calls."""
204*da0073e9SAndroid Build Coastguard Worker    bytecode = compile(stmt, "", mode="eval")
205*da0073e9SAndroid Build Coastguard Worker    for insn in dis.get_instructions(bytecode):
206*da0073e9SAndroid Build Coastguard Worker        if "CALL" in insn.opname:
207*da0073e9SAndroid Build Coastguard Worker            raise RuntimeError(
208*da0073e9SAndroid Build Coastguard Worker                f"Type annotation should not contain calls, but '{stmt}' does"
209*da0073e9SAndroid Build Coastguard Worker            )
210*da0073e9SAndroid Build Coastguard Worker    return eval(bytecode, glob, loc)  # type: ignore[arg-type] # noqa: P204
211*da0073e9SAndroid Build Coastguard Worker
212*da0073e9SAndroid Build Coastguard Worker
def parse_type_line(type_line, rcb, loc):
    """Parse a type annotation specified as a comment.

    Example inputs:
        # type: (Tensor, torch.Tensor) -> Tuple[Tensor]
        # type: (Tensor, Tuple[Tensor, Tensor]) -> Tensor

    Returns:
        ``(arg_types, return_type)``: a list with one TorchScript type per
        argument, and the TorchScript return type.

    Raises:
        RuntimeError: if the argument list or the return type fails to
            evaluate with a NameError or SyntaxError.
    """
    arg_src, ret_src = split_type_line(type_line)

    def evaluate(annotation_src, error_msg):
        # Names are resolved through EvalEnv (fixed typing/torch names, then rcb).
        try:
            return _eval_no_call(annotation_src, {}, EvalEnv(rcb))
        except (NameError, SyntaxError) as e:
            raise RuntimeError(error_msg) from e

    arg_anns = evaluate(
        arg_src, "Failed to parse the argument list of a type annotation"
    )
    # `(T)` evaluates to T itself rather than a 1-tuple; normalize it.
    if not isinstance(arg_anns, tuple):
        arg_anns = (arg_anns,)

    ret_ann = evaluate(ret_src, "Failed to parse the return type of a type annotation")

    arg_types = [ann_to_type(ann, loc) for ann in arg_anns]
    return arg_types, ann_to_type(ret_ann, loc)
241*da0073e9SAndroid Build Coastguard Worker
242*da0073e9SAndroid Build Coastguard Worker
def get_type_line(source):
    """Try to find the line containing a comment with the type annotation.

    Two forms are handled:

    * single-line: one ``# type: (args) -> ret`` comment;
    * PEP 484 multi-line: every parameter carries its own ``# type: T``
      comment, and a ``# type: (...) -> ret`` comment supplies the return type.

    ``# type: ignore`` / ``# type: ignore[code]`` comments at the end of a
    line are disregarded.

    Args:
        source: dedented source text of the function.

    Returns:
        The signature type comment as a single string, or None if no usable
        type comment exists. In the multi-line form, the parameter types are
        substituted into the ``(...)`` placeholder only. Any Ellipsis in the
        types (e.g. ``Tuple[int, ...]``) is left alone. The result starts at
        the ``# type:`` marker, with indentation removed, just like the
        stripped single-line form.

    Raises:
        RuntimeError: if a comment looks like a misspelled ``# type:`` prefix,
            or if a multi-line annotation has no ``# type: (...) -> ...`` line.
    """
    type_comment = "# type:"

    lines = list(enumerate(source.split("\n")))
    type_lines = [line for line in lines if type_comment in line[1]]
    # `type: ignore` comments may be needed in JIT'ed functions for mypy, due
    # to the hack in torch/_VF.py.

    # An ignore type comment can be of following format:
    #   1) type: ignore
    #   2) type: ignore[rule-code]
    # This ignore statement must be at the end of the line

    # adding an extra backslash before the space, to avoid triggering
    # one of the checks in .github/workflows/lint.yml
    type_pattern = re.compile("# type:\\ ignore(\\[[a-zA-Z-]+\\])?$")
    type_lines = [line for line in type_lines if not type_pattern.search(line[1])]

    if not type_lines:
        # Catch common typo patterns like extra spaces, typo in 'ignore', etc.
        wrong_type_pattern = re.compile("#[\t ]*type[\t ]*(?!: ignore(\\[.*\\])?$):")
        wrong_type_lines = [
            line for line in lines if wrong_type_pattern.search(line[1])
        ]
        if wrong_type_lines:
            raise RuntimeError(
                "The annotation prefix in line "
                + str(wrong_type_lines[0][0])
                + " is probably invalid.\nIt must be '# type:'"
                + "\nSee PEP 484 (https://www.python.org/dev/peps/pep-0484/#suggested-syntax-for-python-2-7-and-straddling-code)"  # noqa: B950
                + "\nfor examples"
            )
        return None
    if len(type_lines) == 1:
        # Only 1 type line, quit now
        return type_lines[0][1].strip()

    # Parse split up argument types according to PEP 484
    # https://www.python.org/dev/peps/pep-0484/#suggested-syntax-for-python-2-7-and-straddling-code
    return_marker = "# type: (...) -> "
    return_line = None
    parameter_type_lines = []
    for _, line in type_lines:
        if return_marker in line:
            return_line = line
            break
        # Every entry of `type_lines` contains `type_comment`, so each line
        # preceding the return line is a parameter type line.
        parameter_type_lines.append(line)
    if return_line is None:
        raise RuntimeError(
            "Return type line '# type: (...) -> ...' not found on multiline "
            "type annotation\nfor type lines:\n"
            + "\n".join([line[1] for line in type_lines])
            + "\n(See PEP 484 https://www.python.org/dev/peps/pep-0484/#suggested-syntax-for-python-2-7-and-straddling-code)"
        )

    parameter_types = ", ".join(
        line[line.find(type_comment) + len(type_comment) :].strip()
        for line in parameter_type_lines
    )

    # Keep only the comment itself. The raw line still has the function-body
    # indentation (or code before the comment), which would break
    # `split_type_line`'s fixed-offset slicing. Then fill in only the first
    # `...`, i.e. the `(...)` placeholder right after the marker; an Ellipsis
    # in the return type must survive.
    return_comment = return_line[return_line.index(return_marker) :]
    return return_comment.replace("...", parameter_types, 1).strip()
308*da0073e9SAndroid Build Coastguard Worker
309*da0073e9SAndroid Build Coastguard Worker
def split_type_line(type_line):
    """Split the comment with the type annotation into parts for argument and return types.

    For example, for an input of:
        # type: (Tensor, torch.Tensor) -> Tuple[Tensor, Tensor]

    This function will return:
        ("(Tensor, torch.Tensor)", "Tuple[Tensor, Tensor]")

    The ``# type:`` marker does not have to start the string. Leading
    indentation, or code before the comment, is skipped. When the marker is
    absent, the text after the first ``len("# type:")`` characters is used, as
    before.

    Raises:
        RuntimeError: if no ``->`` follows the start of the annotation.
    """
    type_comment = "# type:"
    comment_pos = type_line.find(type_comment)
    if comment_pos == -1:
        # Historical behaviour: assume the marker occupies the first chars.
        start_offset = len(type_comment)
        search_from = 0
    else:
        start_offset = comment_pos + len(type_comment)
        # Look for the arrow only inside the comment, not in preceding code.
        search_from = start_offset
    try:
        arrow_pos = type_line.index("->", search_from)
    except ValueError:
        raise RuntimeError(
            "Syntax error in type annotation (couldn't find `->`)"
        ) from None
    return type_line[start_offset:arrow_pos].strip(), type_line[arrow_pos + 2 :].strip()
328*da0073e9SAndroid Build Coastguard Worker
329*da0073e9SAndroid Build Coastguard Worker
def try_real_annotations(fn, loc):
    """Try to use the Py3.5+ annotation syntax to get the type.

    Returns ``(arg_types, return_type)``, each converted with
    ``ann_to_type``. Returns None when ``inspect.signature`` raises
    ValueError, or when neither the parameters nor the return value carry an
    annotation. Unannotated entries are still passed to ``ann_to_type`` as
    ``inspect.Signature.empty``.
    """
    try:
        # Note: anything annotated as `Optional[T]` will automatically
        # be returned as `Union[T, None]` per
        # https://github.com/python/typing/blob/master/src/typing.py#L850
        sig = inspect.signature(fn)
    except ValueError:
        return None

    params = list(sig.parameters.values())
    unannotated = sig.return_annotation is sig.empty and all(
        p.annotation is sig.empty for p in params
    )
    if unannotated:
        return None

    arg_types = [ann_to_type(p.annotation, loc) for p in params]
    return arg_types, ann_to_type(sig.return_annotation, loc)
349*da0073e9SAndroid Build Coastguard Worker
350*da0073e9SAndroid Build Coastguard Worker
# Finds the common TorchScript type for the values of an Enum class by unifying
# the types of its members' values. AnyType is returned only if unification
# yields a falsy result; per the note in the body, mixed value types are
# expected to raise instead.
def get_enum_value_type(e: Type[enum.Enum], loc):
    """Return the TorchScript type shared by all member values of ``e``.

    Args:
        e: the Enum class whose member values are inspected.
        loc: source location, forwarded to ``try_ann_to_type``.

    Raises:
        ValueError: if ``e`` defines no members.
    """
    enum_values: List[enum.Enum] = list(e)
    if not enum_values:
        # Name the Enum class itself: `e.__class__` is its metaclass
        # (e.g. `enum.EnumMeta`), which does not identify the offending Enum.
        raise ValueError(f"No enum values defined for: '{e.__name__}'")

    types = {type(v.value) for v in enum_values}
    ir_types = [try_ann_to_type(t, loc) for t in types]

    # If Enum values are of different types, an exception will be raised here.
    # Even though Python supports this case, we chose to not implement it to
    # avoid overcomplicate logic here for a rare use case. Please report a
    # feature request if you find it necessary.
    res = torch._C.unify_type_list(ir_types)
    if not res:
        return AnyType.get()
    return res
369*da0073e9SAndroid Build Coastguard Worker
370*da0073e9SAndroid Build Coastguard Worker
def is_tensor(ann):
    """Return True if the class ``ann`` should be treated as a Tensor annotation.

    Subclasses of ``torch.Tensor`` qualify directly. The dtype-specific legacy
    tensor types (``torch.LongTensor`` and friends) also qualify, but a warning
    is emitted because their dtype constraint is not enforced.

    ``ann`` must be a class, since it is handed straight to ``issubclass``.
    """
    if issubclass(ann, torch.Tensor):
        return True

    dtype_specific_types = (
        torch.LongTensor,
        torch.DoubleTensor,
        torch.FloatTensor,
        torch.IntTensor,
        torch.ShortTensor,
        torch.HalfTensor,
        torch.CharTensor,
        torch.ByteTensor,
        torch.BoolTensor,
    )
    if not issubclass(ann, dtype_specific_types):
        return False

    warnings.warn(
        "TorchScript will treat type annotations of Tensor "
        "dtype-specific subtypes as if they are normal Tensors. "
        "dtype constraints are not enforced in compilation either."
    )
    return True
397*da0073e9SAndroid Build Coastguard Worker
398*da0073e9SAndroid Build Coastguard Worker
399*da0073e9SAndroid Build Coastguard Workerdef _fake_rcb(inp):
400*da0073e9SAndroid Build Coastguard Worker    return None
401*da0073e9SAndroid Build Coastguard Worker
402*da0073e9SAndroid Build Coastguard Worker
def try_ann_to_type(ann, loc, rcb=None):
    """Try to convert a Python type annotation into a TorchScript JIT type.

    The checks run in a fixed order:

    - special values (``inspect.Signature.empty`` and ``None``);
    - tensor classes;
    - typing containers (``Tuple``, ``List``, ``Dict``, ``Optional``,
      ``Union``, ``RRef`` when RPC is available, ``Future``, ``Await``);
    - Python builtins and selected ``torch`` objects;
    - ``enum.Enum`` subclasses;
    - other classes (already scripted, or compilable);
    - finally ``torch._C._resolve_type_from_object`` as a fallback.

    Args:
        ann: the annotation to convert.
        loc: source location of the annotation. It is used in error
            messages and forwarded to recursive calls, to class compilation
            and to the fallback resolver.
        rcb: resolution callback passed to the fallback resolver. Defaults
            to ``_fake_rcb``. It is not forwarded to the recursive calls
            made for container element types.

    Returns:
        The resolved JIT type. The result may be ``None`` when the
        annotation cannot be resolved (``ann_to_type`` treats ``None`` as an
        unknown annotation).

    Raises:
        ValueError: if the key or value annotation of a ``Dict`` cannot be
            resolved.
        AssertionError: if the contained type of an ``Optional``, or a
            member of a ``Union``, cannot be resolved.
    """
    ann_args = typing.get_args(ann)  # always returns a tuple!

    if ann is inspect.Signature.empty:
        # No annotation at all: use the inferred Tensor type.
        return TensorType.getInferred()
    if ann is None:
        return NoneType.get()
    if inspect.isclass(ann) and is_tensor(ann):
        return TensorType.get()
    if is_tuple(ann):
        # Special case for the empty Tuple type annotation `Tuple[()]`
        if len(ann_args) == 1 and ann_args[0] == ():
            return TupleType([])
        return TupleType([try_ann_to_type(a, loc) for a in ann_args])
    if is_list(ann):
        elem_type = try_ann_to_type(ann_args[0], loc)
        if elem_type:
            return ListType(elem_type)
        # Otherwise fall through and let the remaining checks try `ann`.
    if is_dict(ann):
        key = try_ann_to_type(ann_args[0], loc)
        value = try_ann_to_type(ann_args[1], loc)
        # Raise error if key or value is None
        if key is None:
            raise ValueError(
                f"Unknown type annotation: '{ann_args[0]}' at {loc.highlight()}"
            )
        if value is None:
            raise ValueError(
                f"Unknown type annotation: '{ann_args[1]}' at {loc.highlight()}"
            )
        return DictType(key, value)
    if is_optional(ann):
        # The None member may appear in either position (`Optional[T]` puts
        # it second; `Union[None, T]` puts it first). Compare by identity
        # instead of `issubclass`: `issubclass` raises TypeError when the
        # other member is not a class (e.g. `Union[None, List[int]]`).
        # NoneType cannot be subclassed, so identity is an exact check.
        second = ann_args[1]
        if second is None or second is type(None):
            contained = ann_args[0]
        else:
            contained = second
        valid_type = try_ann_to_type(contained, loc)
        msg = "Unsupported annotation {} could not be resolved because {} could not be resolved. At\n{}"
        assert valid_type, msg.format(repr(ann), repr(contained), repr(loc))
        return OptionalType(valid_type)
    if is_union(ann):
        # TODO: this is hack to recognize NumberType
        if set(ann_args) == {int, float, complex}:
            return NumberType.get()
        inner: List = []
        # A literal `None` member maps directly to NoneType. Any other member
        # that fails to resolve is reported as an error.
        # TODO: Determine if the other cases need to be fixed as well
        for a in ann_args:
            if a is None:
                inner.append(NoneType.get())
                continue  # don't also append the recursive result for None
            maybe_type = try_ann_to_type(a, loc)
            msg = "Unsupported annotation {} could not be resolved because {} could not be resolved. At\n{}"
            # Name the member that failed (`a`). `maybe_type` is falsy
            # whenever this assertion fires, so it carries no information.
            assert maybe_type, msg.format(repr(ann), repr(a), repr(loc))
            inner.append(maybe_type)
        return UnionType(inner)  # type: ignore[arg-type]
    if torch.distributed.rpc.is_available() and is_rref(ann):
        return RRefType(try_ann_to_type(ann_args[0], loc))
    if is_future(ann):
        return FutureType(try_ann_to_type(ann_args[0], loc))
    if is_await(ann):
        # A bare `Await` with no type argument contains `Any`.
        elementType = try_ann_to_type(ann_args[0], loc) if ann_args else AnyType.get()
        return AwaitType(elementType)
    if ann is float:
        return FloatType.get()
    if ann is complex:
        return ComplexType.get()
    if ann is int or ann is torch.SymInt:
        return IntType.get()
    if ann is str:
        return StringType.get()
    if ann is bool:
        return BoolType.get()
    if ann is Any:
        return AnyType.get()
    if ann is type(None):
        return NoneType.get()
    if inspect.isclass(ann) and hasattr(ann, "__torch_script_interface__"):
        return InterfaceType(ann.__torch_script_interface__)
    if ann is torch.device:
        return DeviceObjType.get()
    if ann is torch.Generator:
        return _GeneratorType.get()
    if ann is torch.Stream:
        return StreamObjType.get()
    if ann is torch.dtype:
        return IntType.get()  # dtype not yet bound in as its own type
    if inspect.isclass(ann) and issubclass(ann, enum.Enum):
        # Compile the enum class on first use; otherwise reuse its name.
        if _get_script_class(ann) is None:
            scripted_class = torch.jit._script._recursive_compile_class(ann, loc)
            name = scripted_class.qualified_name()
        else:
            name = _qualified_name(ann)
        return EnumType(name, get_enum_value_type(ann, loc), list(ann))
    if inspect.isclass(ann):
        maybe_script_class = _get_script_class(ann)
        if maybe_script_class is not None:
            return maybe_script_class
        if torch._jit_internal.can_compile_class(ann):
            return torch.jit._script._recursive_compile_class(ann, loc)

    # Maybe resolve a NamedTuple to a Tuple Type
    if rcb is None:
        rcb = _fake_rcb
    return torch._C._resolve_type_from_object(ann, loc, rcb)
508*da0073e9SAndroid Build Coastguard Worker
509*da0073e9SAndroid Build Coastguard Worker
def ann_to_type(ann, loc, rcb=None):
    """Convert a type annotation into a TorchScript JIT type, or fail.

    Thin wrapper around ``try_ann_to_type`` (same arguments) that turns an
    unresolved annotation, i.e. a ``None`` result, into an error.

    Raises:
        ValueError: if ``try_ann_to_type`` returns ``None`` for ``ann``.
    """
    resolved = try_ann_to_type(ann, loc, rcb)
    if resolved is None:
        raise ValueError(f"Unknown type annotation: '{ann}' at {loc.highlight()}")
    return resolved
515*da0073e9SAndroid Build Coastguard Worker
516*da0073e9SAndroid Build Coastguard Worker
# Public API: a wildcard import of this module exports exactly these names,
# in this order.
__all__ = [
    # Annotation constructs and the predicates that recognize them.
    "Any", "List",
    "BroadcastingList1", "BroadcastingList2", "BroadcastingList3",
    "Tuple", "is_tuple", "is_list",
    "Dict", "is_dict", "is_optional", "is_union",
    # JIT type classes.
    "TensorType", "TupleType", "FloatType", "ComplexType", "IntType",
    "ListType", "StringType", "DictType", "AnyType",
    "Module",
    # TODO: Consider not exporting the helpers below during wildcard import
    # (reserve wildcard export for the types, as is idiomatic for typing code).
    "get_signature", "check_fn", "get_param_names",
    "parse_type_line", "get_type_line", "split_type_line",
    "try_real_annotations", "try_ann_to_type", "ann_to_type",
]
552