1*da0073e9SAndroid Build Coastguard Worker# mypy: allow-untyped-defs 2*da0073e9SAndroid Build Coastguard Workerimport ast 3*da0073e9SAndroid Build Coastguard Workerimport builtins 4*da0073e9SAndroid Build Coastguard Workerimport dis 5*da0073e9SAndroid Build Coastguard Workerimport enum 6*da0073e9SAndroid Build Coastguard Workerimport inspect 7*da0073e9SAndroid Build Coastguard Workerimport re 8*da0073e9SAndroid Build Coastguard Workerimport typing 9*da0073e9SAndroid Build Coastguard Workerimport warnings 10*da0073e9SAndroid Build Coastguard Workerfrom textwrap import dedent 11*da0073e9SAndroid Build Coastguard Workerfrom typing import Type 12*da0073e9SAndroid Build Coastguard Worker 13*da0073e9SAndroid Build Coastguard Workerimport torch 14*da0073e9SAndroid Build Coastguard Workerfrom torch._C import ( 15*da0073e9SAndroid Build Coastguard Worker _GeneratorType, 16*da0073e9SAndroid Build Coastguard Worker AnyType, 17*da0073e9SAndroid Build Coastguard Worker AwaitType, 18*da0073e9SAndroid Build Coastguard Worker BoolType, 19*da0073e9SAndroid Build Coastguard Worker ComplexType, 20*da0073e9SAndroid Build Coastguard Worker DeviceObjType, 21*da0073e9SAndroid Build Coastguard Worker DictType, 22*da0073e9SAndroid Build Coastguard Worker EnumType, 23*da0073e9SAndroid Build Coastguard Worker FloatType, 24*da0073e9SAndroid Build Coastguard Worker FutureType, 25*da0073e9SAndroid Build Coastguard Worker InterfaceType, 26*da0073e9SAndroid Build Coastguard Worker IntType, 27*da0073e9SAndroid Build Coastguard Worker ListType, 28*da0073e9SAndroid Build Coastguard Worker NoneType, 29*da0073e9SAndroid Build Coastguard Worker NumberType, 30*da0073e9SAndroid Build Coastguard Worker OptionalType, 31*da0073e9SAndroid Build Coastguard Worker StreamObjType, 32*da0073e9SAndroid Build Coastguard Worker StringType, 33*da0073e9SAndroid Build Coastguard Worker TensorType, 34*da0073e9SAndroid Build Coastguard Worker TupleType, 35*da0073e9SAndroid Build Coastguard Worker UnionType, 36*da0073e9SAndroid Build Coastguard Worker) 37*da0073e9SAndroid Build Coastguard Workerfrom torch._jit_internal import ( # type: ignore[attr-defined] 38*da0073e9SAndroid Build Coastguard Worker _Await, 39*da0073e9SAndroid Build Coastguard Worker _qualified_name, 40*da0073e9SAndroid Build Coastguard Worker Any, 41*da0073e9SAndroid Build Coastguard Worker BroadcastingList1, 42*da0073e9SAndroid Build Coastguard Worker BroadcastingList2, 43*da0073e9SAndroid Build Coastguard Worker BroadcastingList3, 44*da0073e9SAndroid Build Coastguard Worker Dict, 45*da0073e9SAndroid Build Coastguard Worker Future, 46*da0073e9SAndroid Build Coastguard Worker is_await, 47*da0073e9SAndroid Build Coastguard Worker is_dict, 48*da0073e9SAndroid Build Coastguard Worker is_future, 49*da0073e9SAndroid Build Coastguard Worker is_ignored_fn, 50*da0073e9SAndroid Build Coastguard Worker is_list, 51*da0073e9SAndroid Build Coastguard Worker is_optional, 52*da0073e9SAndroid Build Coastguard Worker is_tuple, 53*da0073e9SAndroid Build Coastguard Worker is_union, 54*da0073e9SAndroid Build Coastguard Worker List, 55*da0073e9SAndroid Build Coastguard Worker Optional, 56*da0073e9SAndroid Build Coastguard Worker Tuple, 57*da0073e9SAndroid Build Coastguard Worker Union, 58*da0073e9SAndroid Build Coastguard Worker) 59*da0073e9SAndroid Build Coastguard Workerfrom torch._sources import get_source_lines_and_file 60*da0073e9SAndroid Build Coastguard Worker 61*da0073e9SAndroid Build Coastguard Workerfrom ._state import _get_script_class 62*da0073e9SAndroid Build Coastguard Worker 63*da0073e9SAndroid Build Coastguard Worker 64*da0073e9SAndroid Build Coastguard Workerif torch.distributed.rpc.is_available(): 65*da0073e9SAndroid Build Coastguard Worker from torch._C import RRefType 66*da0073e9SAndroid Build Coastguard Worker from torch._jit_internal import is_rref, RRef 67*da0073e9SAndroid Build Coastguard Worker 68*da0073e9SAndroid Build Coastguard Workerfrom torch._ops import OpOverloadPacket 69*da0073e9SAndroid Build Coastguard Worker 70*da0073e9SAndroid Build Coastguard Worker 71*da0073e9SAndroid Build Coastguard Workerclass Module: 72*da0073e9SAndroid Build Coastguard Worker def __init__(self, name, members): 73*da0073e9SAndroid Build Coastguard Worker self.name = name 74*da0073e9SAndroid Build Coastguard Worker self.members = members 75*da0073e9SAndroid Build Coastguard Worker 76*da0073e9SAndroid Build Coastguard Worker def __getattr__(self, name): 77*da0073e9SAndroid Build Coastguard Worker try: 78*da0073e9SAndroid Build Coastguard Worker return self.members[name] 79*da0073e9SAndroid Build Coastguard Worker except KeyError: 80*da0073e9SAndroid Build Coastguard Worker raise RuntimeError( 81*da0073e9SAndroid Build Coastguard Worker f"Module {self.name} has no member called {name}" 82*da0073e9SAndroid Build Coastguard Worker ) from None 83*da0073e9SAndroid Build Coastguard Worker 84*da0073e9SAndroid Build Coastguard Worker 85*da0073e9SAndroid Build Coastguard Workerclass EvalEnv: 86*da0073e9SAndroid Build Coastguard Worker env = { 87*da0073e9SAndroid Build Coastguard Worker "torch": Module("torch", {"Tensor": torch.Tensor}), 88*da0073e9SAndroid Build Coastguard Worker "Tensor": torch.Tensor, 89*da0073e9SAndroid Build Coastguard Worker "typing": Module("typing", {"Tuple": Tuple}), 90*da0073e9SAndroid Build Coastguard Worker "Tuple": Tuple, 91*da0073e9SAndroid Build Coastguard Worker "List": List, 92*da0073e9SAndroid Build Coastguard Worker "Dict": Dict, 93*da0073e9SAndroid Build Coastguard Worker "Optional": Optional, 94*da0073e9SAndroid Build Coastguard Worker "Union": Union, 95*da0073e9SAndroid Build Coastguard Worker "Future": Future, 96*da0073e9SAndroid Build Coastguard Worker "Await": _Await, 97*da0073e9SAndroid Build Coastguard Worker } 98*da0073e9SAndroid Build Coastguard Worker 99*da0073e9SAndroid Build Coastguard Worker def __init__(self, rcb): 100*da0073e9SAndroid Build Coastguard Worker self.rcb = rcb 101*da0073e9SAndroid Build Coastguard Worker if torch.distributed.rpc.is_available(): 102*da0073e9SAndroid Build Coastguard Worker self.env["RRef"] = RRef 103*da0073e9SAndroid Build Coastguard Worker 104*da0073e9SAndroid Build Coastguard Worker def __getitem__(self, name): 105*da0073e9SAndroid Build Coastguard Worker if name in self.env: 106*da0073e9SAndroid Build Coastguard Worker return self.env[name] 107*da0073e9SAndroid Build Coastguard Worker if self.rcb is not None: 108*da0073e9SAndroid Build Coastguard Worker return self.rcb(name) 109*da0073e9SAndroid Build Coastguard Worker return getattr(builtins, name, None) 110*da0073e9SAndroid Build Coastguard Worker 111*da0073e9SAndroid Build Coastguard Worker 112*da0073e9SAndroid Build Coastguard Workerdef get_signature(fn, rcb, loc, is_method): 113*da0073e9SAndroid Build Coastguard Worker if isinstance(fn, OpOverloadPacket): 114*da0073e9SAndroid Build Coastguard Worker signature = try_real_annotations(fn.op, loc) 115*da0073e9SAndroid Build Coastguard Worker else: 116*da0073e9SAndroid Build Coastguard Worker signature = try_real_annotations(fn, loc) 117*da0073e9SAndroid Build Coastguard Worker if signature is not None and is_method: 118*da0073e9SAndroid Build Coastguard Worker # If this is a method, then the signature will include a type for 119*da0073e9SAndroid Build Coastguard Worker # `self`, but type comments do not contain a `self`. So strip it 120*da0073e9SAndroid Build Coastguard Worker # away here so everything is consistent (`inspect.ismethod` does 121*da0073e9SAndroid Build Coastguard Worker # not work here since `fn` is unbound at this point) 122*da0073e9SAndroid Build Coastguard Worker param_types, return_type = signature 123*da0073e9SAndroid Build Coastguard Worker param_types = param_types[1:] 124*da0073e9SAndroid Build Coastguard Worker signature = (param_types, return_type) 125*da0073e9SAndroid Build Coastguard Worker 126*da0073e9SAndroid Build Coastguard Worker if signature is None: 127*da0073e9SAndroid Build Coastguard Worker type_line, source = None, None 128*da0073e9SAndroid Build Coastguard Worker try: 129*da0073e9SAndroid Build Coastguard Worker source = dedent("".join(get_source_lines_and_file(fn)[0])) 130*da0073e9SAndroid Build Coastguard Worker type_line = get_type_line(source) 131*da0073e9SAndroid Build Coastguard Worker except TypeError: 132*da0073e9SAndroid Build Coastguard Worker pass 133*da0073e9SAndroid Build Coastguard Worker # This might happen both because we failed to get the source of fn, or 134*da0073e9SAndroid Build Coastguard Worker # because it didn't have any annotations. 135*da0073e9SAndroid Build Coastguard Worker if type_line is not None: 136*da0073e9SAndroid Build Coastguard Worker signature = parse_type_line(type_line, rcb, loc) 137*da0073e9SAndroid Build Coastguard Worker 138*da0073e9SAndroid Build Coastguard Worker return signature 139*da0073e9SAndroid Build Coastguard Worker 140*da0073e9SAndroid Build Coastguard Worker 141*da0073e9SAndroid Build Coastguard Workerdef is_function_or_method(the_callable): 142*da0073e9SAndroid Build Coastguard Worker # A stricter version of `inspect.isroutine` that does not pass for built-in 143*da0073e9SAndroid Build Coastguard Worker # functions 144*da0073e9SAndroid Build Coastguard Worker return inspect.isfunction(the_callable) or inspect.ismethod(the_callable) 145*da0073e9SAndroid Build Coastguard Worker 146*da0073e9SAndroid Build Coastguard Worker 147*da0073e9SAndroid Build Coastguard Workerdef is_vararg(the_callable): 148*da0073e9SAndroid Build Coastguard Worker if not is_function_or_method(the_callable) and callable(the_callable): # noqa: B004 149*da0073e9SAndroid Build Coastguard Worker # If `the_callable` is a class, de-sugar the call so we can still get 150*da0073e9SAndroid Build Coastguard Worker # the signature 151*da0073e9SAndroid Build Coastguard Worker the_callable = the_callable.__call__ 152*da0073e9SAndroid Build Coastguard Worker 153*da0073e9SAndroid Build Coastguard Worker if is_function_or_method(the_callable): 154*da0073e9SAndroid Build Coastguard Worker return inspect.getfullargspec(the_callable).varargs is not None 155*da0073e9SAndroid Build Coastguard Worker else: 156*da0073e9SAndroid Build Coastguard Worker return False 157*da0073e9SAndroid Build Coastguard Worker 158*da0073e9SAndroid Build Coastguard Worker 159*da0073e9SAndroid Build Coastguard Workerdef get_param_names(fn, n_args): 160*da0073e9SAndroid Build Coastguard Worker if isinstance(fn, OpOverloadPacket): 161*da0073e9SAndroid Build Coastguard Worker fn = fn.op 162*da0073e9SAndroid Build Coastguard Worker 163*da0073e9SAndroid Build Coastguard Worker if ( 164*da0073e9SAndroid Build Coastguard Worker not is_function_or_method(fn) 165*da0073e9SAndroid Build Coastguard Worker and callable(fn) 166*da0073e9SAndroid Build Coastguard Worker and is_function_or_method(fn.__call__) 167*da0073e9SAndroid Build Coastguard Worker ): # noqa: B004 168*da0073e9SAndroid Build Coastguard Worker # De-sugar calls to classes 169*da0073e9SAndroid Build Coastguard Worker fn = fn.__call__ 170*da0073e9SAndroid Build Coastguard Worker 171*da0073e9SAndroid Build Coastguard Worker if is_function_or_method(fn): 172*da0073e9SAndroid Build Coastguard Worker if is_ignored_fn(fn): 173*da0073e9SAndroid Build Coastguard Worker fn = inspect.unwrap(fn) 174*da0073e9SAndroid Build Coastguard Worker return inspect.getfullargspec(fn).args 175*da0073e9SAndroid Build Coastguard Worker else: 176*da0073e9SAndroid Build Coastguard Worker # The `fn` was not a method or function (maybe a class with a __call__ 177*da0073e9SAndroid Build Coastguard Worker # method, so use a default param name list) 178*da0073e9SAndroid Build Coastguard Worker return [str(i) for i in range(n_args)] 179*da0073e9SAndroid Build Coastguard Worker 180*da0073e9SAndroid Build Coastguard Worker 181*da0073e9SAndroid Build Coastguard Workerdef check_fn(fn, loc): 182*da0073e9SAndroid Build Coastguard Worker # Make sure the function definition is not a class instantiation 183*da0073e9SAndroid Build Coastguard Worker try: 184*da0073e9SAndroid Build Coastguard Worker source = dedent("".join(get_source_lines_and_file(fn)[0])) 185*da0073e9SAndroid Build Coastguard Worker except (OSError, TypeError): 186*da0073e9SAndroid Build Coastguard Worker return 187*da0073e9SAndroid Build Coastguard Worker if source is None: 188*da0073e9SAndroid Build Coastguard Worker return 189*da0073e9SAndroid Build Coastguard Worker 190*da0073e9SAndroid Build Coastguard Worker py_ast = ast.parse(source) 191*da0073e9SAndroid Build Coastguard Worker if len(py_ast.body) == 1 and isinstance(py_ast.body[0], ast.ClassDef): 192*da0073e9SAndroid Build Coastguard Worker raise torch.jit.frontend.FrontendError( 193*da0073e9SAndroid Build Coastguard Worker loc, 194*da0073e9SAndroid Build Coastguard Worker f"Cannot instantiate class '{py_ast.body[0].name}' in a script function", 195*da0073e9SAndroid Build Coastguard Worker ) 196*da0073e9SAndroid Build Coastguard Worker if len(py_ast.body) != 1 or not isinstance(py_ast.body[0], ast.FunctionDef): 197*da0073e9SAndroid Build Coastguard Worker raise torch.jit.frontend.FrontendError( 198*da0073e9SAndroid Build Coastguard Worker loc, "Expected a single top-level function" 199*da0073e9SAndroid Build Coastguard Worker ) 200*da0073e9SAndroid Build Coastguard Worker 201*da0073e9SAndroid Build Coastguard Worker 202*da0073e9SAndroid Build Coastguard Workerdef _eval_no_call(stmt, glob, loc): 203*da0073e9SAndroid Build Coastguard Worker """Evaluate statement as long as it does not contain any method/function calls.""" 204*da0073e9SAndroid Build Coastguard Worker bytecode = compile(stmt, "", mode="eval") 205*da0073e9SAndroid Build Coastguard Worker for insn in dis.get_instructions(bytecode): 206*da0073e9SAndroid Build Coastguard Worker if "CALL" in insn.opname: 207*da0073e9SAndroid Build Coastguard Worker raise RuntimeError( 208*da0073e9SAndroid Build Coastguard Worker f"Type annotation should not contain calls, but '{stmt}' does" 209*da0073e9SAndroid Build Coastguard Worker ) 210*da0073e9SAndroid Build Coastguard Worker return eval(bytecode, glob, loc) # type: ignore[arg-type] # noqa: P204 211*da0073e9SAndroid Build Coastguard Worker 212*da0073e9SAndroid Build Coastguard Worker 213*da0073e9SAndroid Build Coastguard Workerdef parse_type_line(type_line, rcb, loc): 214*da0073e9SAndroid Build Coastguard Worker """Parse a type annotation specified as a comment. 215*da0073e9SAndroid Build Coastguard Worker 216*da0073e9SAndroid Build Coastguard Worker Example inputs: 217*da0073e9SAndroid Build Coastguard Worker # type: (Tensor, torch.Tensor) -> Tuple[Tensor] 218*da0073e9SAndroid Build Coastguard Worker # type: (Tensor, Tuple[Tensor, Tensor]) -> Tensor 219*da0073e9SAndroid Build Coastguard Worker """ 220*da0073e9SAndroid Build Coastguard Worker arg_ann_str, ret_ann_str = split_type_line(type_line) 221*da0073e9SAndroid Build Coastguard Worker 222*da0073e9SAndroid Build Coastguard Worker try: 223*da0073e9SAndroid Build Coastguard Worker arg_ann = _eval_no_call(arg_ann_str, {}, EvalEnv(rcb)) 224*da0073e9SAndroid Build Coastguard Worker except (NameError, SyntaxError) as e: 225*da0073e9SAndroid Build Coastguard Worker raise RuntimeError( 226*da0073e9SAndroid Build Coastguard Worker "Failed to parse the argument list of a type annotation" 227*da0073e9SAndroid Build Coastguard Worker ) from e 228*da0073e9SAndroid Build Coastguard Worker 229*da0073e9SAndroid Build Coastguard Worker if not isinstance(arg_ann, tuple): 230*da0073e9SAndroid Build Coastguard Worker arg_ann = (arg_ann,) 231*da0073e9SAndroid Build Coastguard Worker 232*da0073e9SAndroid Build Coastguard Worker try: 233*da0073e9SAndroid Build Coastguard Worker ret_ann = _eval_no_call(ret_ann_str, {}, EvalEnv(rcb)) 234*da0073e9SAndroid Build Coastguard Worker except (NameError, SyntaxError) as e: 235*da0073e9SAndroid Build Coastguard Worker raise RuntimeError( 236*da0073e9SAndroid Build Coastguard Worker "Failed to parse the return type of a type annotation" 237*da0073e9SAndroid Build Coastguard Worker ) from e 238*da0073e9SAndroid Build Coastguard Worker 239*da0073e9SAndroid Build Coastguard Worker arg_types = [ann_to_type(ann, loc) for ann in arg_ann] 240*da0073e9SAndroid Build Coastguard Worker return arg_types, ann_to_type(ret_ann, loc) 241*da0073e9SAndroid Build Coastguard Worker 242*da0073e9SAndroid Build Coastguard Worker 243*da0073e9SAndroid Build Coastguard Workerdef get_type_line(source): 244*da0073e9SAndroid Build Coastguard Worker """Try to find the line containing a comment with the type annotation.""" 245*da0073e9SAndroid Build Coastguard Worker type_comment = "# type:" 246*da0073e9SAndroid Build Coastguard Worker 247*da0073e9SAndroid Build Coastguard Worker lines = source.split("\n") 248*da0073e9SAndroid Build Coastguard Worker lines = list(enumerate(lines)) 249*da0073e9SAndroid Build Coastguard Worker type_lines = list(filter(lambda line: type_comment in line[1], lines)) 250*da0073e9SAndroid Build Coastguard Worker # `type: ignore` comments may be needed in JIT'ed functions for mypy, due 251*da0073e9SAndroid Build Coastguard Worker # to the hack in torch/_VF.py. 252*da0073e9SAndroid Build Coastguard Worker 253*da0073e9SAndroid Build Coastguard Worker # An ignore type comment can be of following format: 254*da0073e9SAndroid Build Coastguard Worker # 1) type: ignore 255*da0073e9SAndroid Build Coastguard Worker # 2) type: ignore[rule-code] 256*da0073e9SAndroid Build Coastguard Worker # This ignore statement must be at the end of the line 257*da0073e9SAndroid Build Coastguard Worker 258*da0073e9SAndroid Build Coastguard Worker # adding an extra backslash before the space, to avoid triggering 259*da0073e9SAndroid Build Coastguard Worker # one of the checks in .github/workflows/lint.yml 260*da0073e9SAndroid Build Coastguard Worker type_pattern = re.compile("# type:\\ ignore(\\[[a-zA-Z-]+\\])?$") 261*da0073e9SAndroid Build Coastguard Worker type_lines = list(filter(lambda line: not type_pattern.search(line[1]), type_lines)) 262*da0073e9SAndroid Build Coastguard Worker 263*da0073e9SAndroid Build Coastguard Worker if len(type_lines) == 0: 264*da0073e9SAndroid Build Coastguard Worker # Catch common typo patterns like extra spaces, typo in 'ignore', etc. 265*da0073e9SAndroid Build Coastguard Worker wrong_type_pattern = re.compile("#[\t ]*type[\t ]*(?!: ignore(\\[.*\\])?$):") 266*da0073e9SAndroid Build Coastguard Worker wrong_type_lines = list( 267*da0073e9SAndroid Build Coastguard Worker filter(lambda line: wrong_type_pattern.search(line[1]), lines) 268*da0073e9SAndroid Build Coastguard Worker ) 269*da0073e9SAndroid Build Coastguard Worker if len(wrong_type_lines) > 0: 270*da0073e9SAndroid Build Coastguard Worker raise RuntimeError( 271*da0073e9SAndroid Build Coastguard Worker "The annotation prefix in line " 272*da0073e9SAndroid Build Coastguard Worker + str(wrong_type_lines[0][0]) 273*da0073e9SAndroid Build Coastguard Worker + " is probably invalid.\nIt must be '# type:'" 274*da0073e9SAndroid Build Coastguard Worker + "\nSee PEP 484 (https://www.python.org/dev/peps/pep-0484/#suggested-syntax-for-python-2-7-and-straddling-code)" # noqa: B950 275*da0073e9SAndroid Build Coastguard Worker + "\nfor examples" 276*da0073e9SAndroid Build Coastguard Worker ) 277*da0073e9SAndroid Build Coastguard Worker return None 278*da0073e9SAndroid Build Coastguard Worker elif len(type_lines) == 1: 279*da0073e9SAndroid Build Coastguard Worker # Only 1 type line, quit now 280*da0073e9SAndroid Build Coastguard Worker return type_lines[0][1].strip() 281*da0073e9SAndroid Build Coastguard Worker 282*da0073e9SAndroid Build Coastguard Worker # Parse split up argument types according to PEP 484 283*da0073e9SAndroid Build Coastguard Worker # https://www.python.org/dev/peps/pep-0484/#suggested-syntax-for-python-2-7-and-straddling-code 284*da0073e9SAndroid Build Coastguard Worker return_line = None 285*da0073e9SAndroid Build Coastguard Worker parameter_type_lines = [] 286*da0073e9SAndroid Build Coastguard Worker for line_num, line in type_lines: 287*da0073e9SAndroid Build Coastguard Worker if "# type: (...) -> " in line: 288*da0073e9SAndroid Build Coastguard Worker return_line = (line_num, line) 289*da0073e9SAndroid Build Coastguard Worker break 290*da0073e9SAndroid Build Coastguard Worker elif type_comment in line: 291*da0073e9SAndroid Build Coastguard Worker parameter_type_lines.append(line) 292*da0073e9SAndroid Build Coastguard Worker if return_line is None: 293*da0073e9SAndroid Build Coastguard Worker raise RuntimeError( 294*da0073e9SAndroid Build Coastguard Worker "Return type line '# type: (...) -> ...' not found on multiline " 295*da0073e9SAndroid Build Coastguard Worker "type annotation\nfor type lines:\n" 296*da0073e9SAndroid Build Coastguard Worker + "\n".join([line[1] for line in type_lines]) 297*da0073e9SAndroid Build Coastguard Worker + "\n(See PEP 484 https://www.python.org/dev/peps/pep-0484/#suggested-syntax-for-python-2-7-and-straddling-code)" 298*da0073e9SAndroid Build Coastguard Worker ) 299*da0073e9SAndroid Build Coastguard Worker 300*da0073e9SAndroid Build Coastguard Worker def get_parameter_type(line): 301*da0073e9SAndroid Build Coastguard Worker item_type = line[line.find(type_comment) + len(type_comment) :] 302*da0073e9SAndroid Build Coastguard Worker return item_type.strip() 303*da0073e9SAndroid Build Coastguard Worker 304*da0073e9SAndroid Build Coastguard Worker types = map(get_parameter_type, parameter_type_lines) 305*da0073e9SAndroid Build Coastguard Worker parameter_types = ", ".join(types) 306*da0073e9SAndroid Build Coastguard Worker 307*da0073e9SAndroid Build Coastguard Worker return return_line[1].replace("...", parameter_types) 308*da0073e9SAndroid Build Coastguard Worker 309*da0073e9SAndroid Build Coastguard Worker 310*da0073e9SAndroid Build Coastguard Workerdef split_type_line(type_line): 311*da0073e9SAndroid Build Coastguard Worker """Split the comment with the type annotation into parts for argument and return types. 312*da0073e9SAndroid Build Coastguard Worker 313*da0073e9SAndroid Build Coastguard Worker For example, for an input of: 314*da0073e9SAndroid Build Coastguard Worker # type: (Tensor, torch.Tensor) -> Tuple[Tensor, Tensor] 315*da0073e9SAndroid Build Coastguard Worker 316*da0073e9SAndroid Build Coastguard Worker This function will return: 317*da0073e9SAndroid Build Coastguard Worker ("(Tensor, torch.Tensor)", "Tuple[Tensor, Tensor]") 318*da0073e9SAndroid Build Coastguard Worker 319*da0073e9SAndroid Build Coastguard Worker """ 320*da0073e9SAndroid Build Coastguard Worker start_offset = len("# type:") 321*da0073e9SAndroid Build Coastguard Worker try: 322*da0073e9SAndroid Build Coastguard Worker arrow_pos = type_line.index("->") 323*da0073e9SAndroid Build Coastguard Worker except ValueError: 324*da0073e9SAndroid Build Coastguard Worker raise RuntimeError( 325*da0073e9SAndroid Build Coastguard Worker "Syntax error in type annotation (couldn't find `->`)" 326*da0073e9SAndroid Build Coastguard Worker ) from None 327*da0073e9SAndroid Build Coastguard Worker return type_line[start_offset:arrow_pos].strip(), type_line[arrow_pos + 2 :].strip() 328*da0073e9SAndroid Build Coastguard Worker 329*da0073e9SAndroid Build Coastguard Worker 330*da0073e9SAndroid Build Coastguard Workerdef try_real_annotations(fn, loc): 331*da0073e9SAndroid Build Coastguard Worker """Try to use the Py3.5+ annotation syntax to get the type.""" 332*da0073e9SAndroid Build Coastguard Worker try: 333*da0073e9SAndroid Build Coastguard Worker # Note: anything annotated as `Optional[T]` will automatically 334*da0073e9SAndroid Build Coastguard Worker # be returned as `Union[T, None]` per 335*da0073e9SAndroid Build Coastguard Worker # https://github.com/python/typing/blob/master/src/typing.py#L850 336*da0073e9SAndroid Build Coastguard Worker sig = inspect.signature(fn) 337*da0073e9SAndroid Build Coastguard Worker except ValueError: 338*da0073e9SAndroid Build Coastguard Worker return None 339*da0073e9SAndroid Build Coastguard Worker 340*da0073e9SAndroid Build Coastguard Worker all_annots = [sig.return_annotation] + [ 341*da0073e9SAndroid Build Coastguard Worker p.annotation for p in sig.parameters.values() 342*da0073e9SAndroid Build Coastguard Worker ] 343*da0073e9SAndroid Build Coastguard Worker if all(ann is sig.empty for ann in all_annots): 344*da0073e9SAndroid Build Coastguard Worker return None 345*da0073e9SAndroid Build Coastguard Worker 346*da0073e9SAndroid Build Coastguard Worker arg_types = [ann_to_type(p.annotation, loc) for p in sig.parameters.values()] 347*da0073e9SAndroid Build Coastguard Worker return_type = ann_to_type(sig.return_annotation, loc) 348*da0073e9SAndroid Build Coastguard Worker return arg_types, return_type 349*da0073e9SAndroid Build Coastguard Worker 350*da0073e9SAndroid Build Coastguard Worker 351*da0073e9SAndroid Build Coastguard Worker# Finds common type for enum values belonging to an Enum class. If not all 352*da0073e9SAndroid Build Coastguard Worker# values have the same type, AnyType is returned. 353*da0073e9SAndroid Build Coastguard Workerdef get_enum_value_type(e: Type[enum.Enum], loc): 354*da0073e9SAndroid Build Coastguard Worker enum_values: List[enum.Enum] = list(e) 355*da0073e9SAndroid Build Coastguard Worker if not enum_values: 356*da0073e9SAndroid Build Coastguard Worker raise ValueError(f"No enum values defined for: '{e.__class__}'") 357*da0073e9SAndroid Build Coastguard Worker 358*da0073e9SAndroid Build Coastguard Worker types = {type(v.value) for v in enum_values} 359*da0073e9SAndroid Build Coastguard Worker ir_types = [try_ann_to_type(t, loc) for t in types] 360*da0073e9SAndroid Build Coastguard Worker 361*da0073e9SAndroid Build Coastguard Worker # If Enum values are of different types, an exception will be raised here. 362*da0073e9SAndroid Build Coastguard Worker # Even though Python supports this case, we chose to not implement it to 363*da0073e9SAndroid Build Coastguard Worker # avoid overcomplicate logic here for a rare use case. Please report a 364*da0073e9SAndroid Build Coastguard Worker # feature request if you find it necessary. 365*da0073e9SAndroid Build Coastguard Worker res = torch._C.unify_type_list(ir_types) 366*da0073e9SAndroid Build Coastguard Worker if not res: 367*da0073e9SAndroid Build Coastguard Worker return AnyType.get() 368*da0073e9SAndroid Build Coastguard Worker return res 369*da0073e9SAndroid Build Coastguard Worker 370*da0073e9SAndroid Build Coastguard Worker 371*da0073e9SAndroid Build Coastguard Workerdef is_tensor(ann): 372*da0073e9SAndroid Build Coastguard Worker if issubclass(ann, torch.Tensor): 373*da0073e9SAndroid Build Coastguard Worker return True 374*da0073e9SAndroid Build Coastguard Worker 375*da0073e9SAndroid Build Coastguard Worker if issubclass( 376*da0073e9SAndroid Build Coastguard Worker ann, 377*da0073e9SAndroid Build Coastguard Worker ( 378*da0073e9SAndroid Build Coastguard Worker torch.LongTensor, 379*da0073e9SAndroid Build Coastguard Worker torch.DoubleTensor, 380*da0073e9SAndroid Build Coastguard Worker torch.FloatTensor, 381*da0073e9SAndroid Build Coastguard Worker torch.IntTensor, 382*da0073e9SAndroid Build Coastguard Worker torch.ShortTensor, 383*da0073e9SAndroid Build Coastguard Worker torch.HalfTensor, 384*da0073e9SAndroid Build Coastguard Worker torch.CharTensor, 385*da0073e9SAndroid Build Coastguard Worker torch.ByteTensor, 386*da0073e9SAndroid Build Coastguard Worker torch.BoolTensor, 387*da0073e9SAndroid Build Coastguard Worker ), 388*da0073e9SAndroid Build Coastguard Worker ): 389*da0073e9SAndroid Build Coastguard Worker warnings.warn( 390*da0073e9SAndroid Build Coastguard Worker "TorchScript will treat type annotations of Tensor " 391*da0073e9SAndroid Build Coastguard Worker "dtype-specific subtypes as if they are normal Tensors. " 392*da0073e9SAndroid Build Coastguard Worker "dtype constraints are not enforced in compilation either." 393*da0073e9SAndroid Build Coastguard Worker ) 394*da0073e9SAndroid Build Coastguard Worker return True 395*da0073e9SAndroid Build Coastguard Worker 396*da0073e9SAndroid Build Coastguard Worker return False 397*da0073e9SAndroid Build Coastguard Worker 398*da0073e9SAndroid Build Coastguard Worker 399*da0073e9SAndroid Build Coastguard Workerdef _fake_rcb(inp): 400*da0073e9SAndroid Build Coastguard Worker return None 401*da0073e9SAndroid Build Coastguard Worker 402*da0073e9SAndroid Build Coastguard Worker 403*da0073e9SAndroid Build Coastguard Workerdef try_ann_to_type(ann, loc, rcb=None): 404*da0073e9SAndroid Build Coastguard Worker ann_args = typing.get_args(ann) # always returns a tuple! 405*da0073e9SAndroid Build Coastguard Worker 406*da0073e9SAndroid Build Coastguard Worker if ann is inspect.Signature.empty: 407*da0073e9SAndroid Build Coastguard Worker return TensorType.getInferred() 408*da0073e9SAndroid Build Coastguard Worker if ann is None: 409*da0073e9SAndroid Build Coastguard Worker return NoneType.get() 410*da0073e9SAndroid Build Coastguard Worker if inspect.isclass(ann) and is_tensor(ann): 411*da0073e9SAndroid Build Coastguard Worker return TensorType.get() 412*da0073e9SAndroid Build Coastguard Worker if is_tuple(ann): 413*da0073e9SAndroid Build Coastguard Worker # Special case for the empty Tuple type annotation `Tuple[()]` 414*da0073e9SAndroid Build Coastguard Worker if len(ann_args) == 1 and ann_args[0] == (): 415*da0073e9SAndroid Build Coastguard Worker return TupleType([]) 416*da0073e9SAndroid Build Coastguard Worker return TupleType([try_ann_to_type(a, loc) for a in ann_args]) 417*da0073e9SAndroid Build Coastguard Worker if is_list(ann): 418*da0073e9SAndroid Build Coastguard Worker elem_type = try_ann_to_type(ann_args[0], loc) 419*da0073e9SAndroid Build Coastguard Worker if elem_type: 420*da0073e9SAndroid Build Coastguard Worker return ListType(elem_type) 421*da0073e9SAndroid Build Coastguard Worker if is_dict(ann): 422*da0073e9SAndroid Build Coastguard Worker key = try_ann_to_type(ann_args[0], loc) 423*da0073e9SAndroid Build Coastguard Worker value = try_ann_to_type(ann_args[1], loc) 424*da0073e9SAndroid Build Coastguard Worker # Raise error if key or value is None 425*da0073e9SAndroid Build Coastguard Worker if key is None: 426*da0073e9SAndroid Build Coastguard Worker raise ValueError( 427*da0073e9SAndroid Build Coastguard Worker f"Unknown type annotation: '{ann_args[0]}' at {loc.highlight()}" 428*da0073e9SAndroid Build Coastguard Worker ) 429*da0073e9SAndroid Build Coastguard Worker if value is None: 430*da0073e9SAndroid Build Coastguard Worker raise ValueError( 431*da0073e9SAndroid Build Coastguard Worker f"Unknown type annotation: '{ann_args[1]}' at {loc.highlight()}" 432*da0073e9SAndroid Build Coastguard Worker ) 433*da0073e9SAndroid Build Coastguard Worker return DictType(key, value) 434*da0073e9SAndroid Build Coastguard Worker if is_optional(ann): 435*da0073e9SAndroid Build Coastguard Worker if issubclass(ann_args[1], type(None)): 436*da0073e9SAndroid Build Coastguard Worker contained = ann_args[0] 437*da0073e9SAndroid Build Coastguard Worker else: 438*da0073e9SAndroid Build Coastguard Worker contained = ann_args[1] 439*da0073e9SAndroid Build Coastguard Worker valid_type = try_ann_to_type(contained, loc) 440*da0073e9SAndroid Build Coastguard Worker msg = "Unsupported annotation {} could not be resolved because {} could not be resolved. At\n{}" 441*da0073e9SAndroid Build Coastguard Worker assert valid_type, msg.format(repr(ann), repr(contained), repr(loc)) 442*da0073e9SAndroid Build Coastguard Worker return OptionalType(valid_type) 443*da0073e9SAndroid Build Coastguard Worker if is_union(ann): 444*da0073e9SAndroid Build Coastguard Worker # TODO: this is hack to recognize NumberType 445*da0073e9SAndroid Build Coastguard Worker if set(ann_args) == {int, float, complex}: 446*da0073e9SAndroid Build Coastguard Worker return NumberType.get() 447*da0073e9SAndroid Build Coastguard Worker inner: List = [] 448*da0073e9SAndroid Build Coastguard Worker # We need these extra checks because both `None` and invalid 449*da0073e9SAndroid Build Coastguard Worker # values will return `None` 450*da0073e9SAndroid Build Coastguard Worker # TODO: Determine if the other cases need to be fixed as well 451*da0073e9SAndroid Build Coastguard Worker for a in typing.get_args(ann): 452*da0073e9SAndroid Build Coastguard Worker if a is None: 453*da0073e9SAndroid Build Coastguard Worker inner.append(NoneType.get()) 454*da0073e9SAndroid Build Coastguard Worker maybe_type = try_ann_to_type(a, loc) 455*da0073e9SAndroid Build Coastguard Worker msg = "Unsupported annotation {} could not be resolved because {} could not be resolved. At\n{}" 456*da0073e9SAndroid Build Coastguard Worker assert maybe_type, msg.format(repr(ann), repr(maybe_type), repr(loc)) 457*da0073e9SAndroid Build Coastguard Worker inner.append(maybe_type) 458*da0073e9SAndroid Build Coastguard Worker return UnionType(inner) # type: ignore[arg-type] 459*da0073e9SAndroid Build Coastguard Worker if torch.distributed.rpc.is_available() and is_rref(ann): 460*da0073e9SAndroid Build Coastguard Worker return RRefType(try_ann_to_type(ann_args[0], loc)) 461*da0073e9SAndroid Build Coastguard Worker if is_future(ann): 462*da0073e9SAndroid Build Coastguard Worker return FutureType(try_ann_to_type(ann_args[0], loc)) 463*da0073e9SAndroid Build Coastguard Worker if is_await(ann): 464*da0073e9SAndroid Build Coastguard Worker elementType = try_ann_to_type(ann_args[0], loc) if ann_args else AnyType.get() 465*da0073e9SAndroid Build Coastguard Worker return AwaitType(elementType) 466*da0073e9SAndroid Build Coastguard Worker if ann is float: 467*da0073e9SAndroid Build Coastguard Worker return FloatType.get() 468*da0073e9SAndroid Build Coastguard Worker if ann is complex: 469*da0073e9SAndroid Build Coastguard Worker return ComplexType.get() 470*da0073e9SAndroid Build Coastguard Worker if ann is int or ann is torch.SymInt: 471*da0073e9SAndroid Build Coastguard Worker return IntType.get() 472*da0073e9SAndroid Build Coastguard Worker if ann is str: 473*da0073e9SAndroid Build Coastguard Worker return StringType.get() 474*da0073e9SAndroid Build Coastguard Worker if ann is bool: 475*da0073e9SAndroid Build Coastguard Worker return BoolType.get() 476*da0073e9SAndroid Build Coastguard Worker if ann is Any: 477*da0073e9SAndroid Build Coastguard Worker return AnyType.get() 478*da0073e9SAndroid Build Coastguard Worker if ann is type(None): 479*da0073e9SAndroid Build Coastguard Worker return NoneType.get() 480*da0073e9SAndroid Build Coastguard Worker if inspect.isclass(ann) and hasattr(ann, "__torch_script_interface__"): 481*da0073e9SAndroid Build Coastguard Worker return InterfaceType(ann.__torch_script_interface__) 482*da0073e9SAndroid Build Coastguard Worker if ann is torch.device: 483*da0073e9SAndroid Build Coastguard Worker return DeviceObjType.get() 484*da0073e9SAndroid Build Coastguard Worker if ann is torch.Generator: 485*da0073e9SAndroid Build Coastguard Worker return _GeneratorType.get() 486*da0073e9SAndroid Build Coastguard Worker if ann is torch.Stream: 487*da0073e9SAndroid Build Coastguard Worker return StreamObjType.get() 488*da0073e9SAndroid Build Coastguard Worker if ann is torch.dtype: 489*da0073e9SAndroid Build Coastguard Worker return IntType.get() # dtype not yet bound in as its own type 490*da0073e9SAndroid Build Coastguard Worker if inspect.isclass(ann) and issubclass(ann, enum.Enum): 491*da0073e9SAndroid Build Coastguard Worker if _get_script_class(ann) is None: 492*da0073e9SAndroid Build Coastguard Worker scripted_class = torch.jit._script._recursive_compile_class(ann, loc) 493*da0073e9SAndroid Build Coastguard Worker name = scripted_class.qualified_name() 494*da0073e9SAndroid Build Coastguard Worker else: 495*da0073e9SAndroid Build Coastguard Worker name = _qualified_name(ann) 496*da0073e9SAndroid Build Coastguard Worker return EnumType(name, get_enum_value_type(ann, loc), list(ann)) 497*da0073e9SAndroid Build Coastguard Worker if inspect.isclass(ann): 498*da0073e9SAndroid Build Coastguard Worker maybe_script_class = _get_script_class(ann) 499*da0073e9SAndroid Build Coastguard Worker if maybe_script_class is not None: 500*da0073e9SAndroid Build Coastguard Worker return maybe_script_class 501*da0073e9SAndroid Build Coastguard Worker if torch._jit_internal.can_compile_class(ann): 502*da0073e9SAndroid Build Coastguard Worker return torch.jit._script._recursive_compile_class(ann, loc) 503*da0073e9SAndroid Build Coastguard Worker 504*da0073e9SAndroid Build Coastguard Worker # Maybe resolve a NamedTuple to a Tuple Type 505*da0073e9SAndroid Build Coastguard Worker if rcb is None: 506*da0073e9SAndroid Build Coastguard Worker rcb = _fake_rcb 507*da0073e9SAndroid Build Coastguard Worker return torch._C._resolve_type_from_object(ann, loc, rcb) 508*da0073e9SAndroid Build Coastguard Worker 509*da0073e9SAndroid Build Coastguard Worker 510*da0073e9SAndroid Build Coastguard Workerdef ann_to_type(ann, loc, rcb=None): 511*da0073e9SAndroid Build Coastguard Worker the_type = try_ann_to_type(ann, loc, rcb) 512*da0073e9SAndroid Build Coastguard Worker if the_type is not None: 513*da0073e9SAndroid Build Coastguard Worker return the_type 514*da0073e9SAndroid Build Coastguard Worker raise ValueError(f"Unknown type annotation: '{ann}' at {loc.highlight()}") 515*da0073e9SAndroid Build Coastguard Worker 516*da0073e9SAndroid Build Coastguard Worker 517*da0073e9SAndroid Build Coastguard Worker__all__ = [ 518*da0073e9SAndroid Build Coastguard Worker "Any", 519*da0073e9SAndroid Build Coastguard Worker "List", 520*da0073e9SAndroid Build Coastguard Worker "BroadcastingList1", 521*da0073e9SAndroid Build Coastguard Worker "BroadcastingList2", 522*da0073e9SAndroid Build Coastguard Worker "BroadcastingList3", 523*da0073e9SAndroid Build Coastguard Worker "Tuple", 524*da0073e9SAndroid Build Coastguard Worker "is_tuple", 525*da0073e9SAndroid Build Coastguard Worker "is_list", 526*da0073e9SAndroid Build Coastguard Worker "Dict", 527*da0073e9SAndroid Build Coastguard Worker "is_dict", 528*da0073e9SAndroid Build Coastguard Worker "is_optional", 529*da0073e9SAndroid Build Coastguard Worker "is_union", 530*da0073e9SAndroid Build Coastguard Worker "TensorType", 531*da0073e9SAndroid Build Coastguard Worker "TupleType", 532*da0073e9SAndroid Build Coastguard Worker "FloatType", 533*da0073e9SAndroid Build Coastguard Worker "ComplexType", 534*da0073e9SAndroid Build Coastguard Worker "IntType", 535*da0073e9SAndroid Build Coastguard Worker "ListType", 536*da0073e9SAndroid Build Coastguard Worker "StringType", 537*da0073e9SAndroid Build Coastguard Worker "DictType", 538*da0073e9SAndroid Build Coastguard Worker "AnyType", 539*da0073e9SAndroid Build Coastguard Worker "Module", 540*da0073e9SAndroid Build Coastguard Worker # TODO: Consider not exporting these during wildcard import (reserve 541*da0073e9SAndroid Build Coastguard Worker # that for the types; for idiomatic typing code.) 542*da0073e9SAndroid Build Coastguard Worker "get_signature", 543*da0073e9SAndroid Build Coastguard Worker "check_fn", 544*da0073e9SAndroid Build Coastguard Worker "get_param_names", 545*da0073e9SAndroid Build Coastguard Worker "parse_type_line", 546*da0073e9SAndroid Build Coastguard Worker "get_type_line", 547*da0073e9SAndroid Build Coastguard Worker "split_type_line", 548*da0073e9SAndroid Build Coastguard Worker "try_real_annotations", 549*da0073e9SAndroid Build Coastguard Worker "try_ann_to_type", 550*da0073e9SAndroid Build Coastguard Worker "ann_to_type", 551*da0073e9SAndroid Build Coastguard Worker] 552