xref: /aosp_15_r20/external/pytorch/test/dynamo/utils.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["module: dynamo"]
2import importlib
3import os
4import sys
5import types
6
7import torch
8import torch._dynamo
9
10
# Module-level tensor of ten ones. Presumably read as a global by dynamo tests
# that import it from this module. NOTE(review): confirm against the tests.
g_tensor_export = torch.ones(10)


# 10x10 tensor of ones. Judging by its name, it is meant for tests that import
# it from this module. NOTE(review): confirm against the tests.
tensor_for_import_testing = torch.ones(10, 10)
15
16
def inner_func():
    """Report whether autograd is currently recording operations.

    Returns:
        bool: the value of ``torch.is_grad_enabled()`` at call time.
    """
    grad_mode_on = torch.is_grad_enabled()
    return grad_mode_on
19
20
def outer_func(func):
    """Wrap ``func`` so that a dynamo graph break follows its call.

    The returned ``wrapped`` callable:

    1. calls ``func`` with the positional arguments it receives;
    2. forces a graph break via ``torch._dynamo.graph_break()``;
    3. returns a pair ``(torch.sin(result + 1), inner_func())``.

    The second element reports the grad mode seen after the break.
    """

    def wrapped(*inputs):
        first = func(*inputs)
        # Everything below runs after the forced break.
        torch._dynamo.graph_break()
        shifted = first + 1
        return torch.sin(shifted), inner_func()

    return wrapped
28
29
# Source text for a dummy Python module used to test skipfiles rules.
# create_dummy_module_and_function() exec's it into a fresh module object.
module_code = """
def add(x):
    return x + 1
"""
35
36
def add(x):
    """Return ``x + 1``.

    This copy lives in a regular ``.py`` file, so its ``__code__.co_filename``
    is a real path. create_dummy_module_and_function() installs it in place of
    the exec'd ``add``.
    """
    incremented = x + 1
    return incremented
39
40
def create_dummy_module_and_function():
    """Build and register an in-memory ``dummy_module`` for skipfiles tests.

    Steps:

    1. Give the module a ``ModuleSpec`` whose ``origin`` is this file's
       absolute path.
    2. Populate it by exec-ing ``module_code``.
    3. Insert it into ``sys.modules["dummy_module"]``.
    4. Replace its ``add`` attribute with this file's module-level ``add``.

    Returns:
        tuple: ``(module, module.add)``.
    """
    # ``import importlib`` alone does not guarantee that the ``machinery``
    # submodule is loaded and bound on the package. Import it explicitly
    # rather than relying on a side effect of some other import.
    import importlib.machinery

    module = types.ModuleType("dummy_module")
    module.__spec__ = importlib.machinery.ModuleSpec(
        "dummy_module", None, origin=os.path.abspath(__file__)
    )
    exec(module_code, module.__dict__)
    sys.modules["dummy_module"] = module
    # The exec'd ``add`` has a synthetic ``__code__.co_filename`` ("<string>"),
    # not a regular Python file name. The skipfiles rules use the filename when
    # checking SKIP_DIRS, so swap in the ``add`` defined in this file.
    module.add = add
    return module, module.add
52