xref: /aosp_15_r20/external/pytorch/test/dynamo/utils.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["module: dynamo"]
2*da0073e9SAndroid Build Coastguard Workerimport importlib
3*da0073e9SAndroid Build Coastguard Workerimport os
4*da0073e9SAndroid Build Coastguard Workerimport sys
5*da0073e9SAndroid Build Coastguard Workerimport types
6*da0073e9SAndroid Build Coastguard Worker
7*da0073e9SAndroid Build Coastguard Workerimport torch
8*da0073e9SAndroid Build Coastguard Workerimport torch._dynamo
9*da0073e9SAndroid Build Coastguard Worker
10*da0073e9SAndroid Build Coastguard Worker
# Module-level tensors; presumably read as globals / imported module
# attributes by the export and import tests (names suggest so — confirm
# against callers).
g_tensor_export = torch.ones((10,))

tensor_for_import_testing = torch.ones((10, 10))
15*da0073e9SAndroid Build Coastguard Worker
16*da0073e9SAndroid Build Coastguard Worker
def inner_func():
    """Return the current global autograd grad-mode flag (a bool)."""
    grad_mode_on = torch.is_grad_enabled()
    return grad_mode_on
19*da0073e9SAndroid Build Coastguard Worker
20*da0073e9SAndroid Build Coastguard Worker
def outer_func(func):
    """Wrap ``func`` so a Dynamo graph break follows its call.

    The returned callable forwards its positional arguments to ``func``,
    triggers ``torch._dynamo.graph_break()``, and then returns the pair
    ``(torch.sin(result + 1), inner_func())``.
    """

    def wrapped(*args):
        result = func(*args)
        # Deliberately split the traced graph between ``func`` and the
        # follow-up computation.
        torch._dynamo.graph_break()
        shifted = result + 1
        return torch.sin(shifted), inner_func()

    return wrapped
28*da0073e9SAndroid Build Coastguard Worker
29*da0073e9SAndroid Build Coastguard Worker
# Source text for a dummy Python module (defining ``add``) used to test the
# skipfiles rules.
module_code = "\ndef add(x):\n    return x + 1\n"
35*da0073e9SAndroid Build Coastguard Worker
36*da0073e9SAndroid Build Coastguard Worker
def add(x):
    """Return ``x + 1``."""
    incremented = x + 1
    return incremented
39*da0073e9SAndroid Build Coastguard Worker
40*da0073e9SAndroid Build Coastguard Worker
def create_dummy_module_and_function():
    """Build and register a throwaway ``dummy_module`` for skipfiles tests.

    The module's ``__spec__`` names this file as its origin.
    ``module_code`` is executed into the module's namespace, and the module
    is registered in ``sys.modules`` under ``"dummy_module"``, replacing any
    previous entry.

    Returns:
        A ``(module, function)`` tuple, where ``function`` is ``module.add``
        (this file's top-level ``add``).
    """
    # ``import importlib`` alone does not guarantee that the ``machinery``
    # submodule is loaded, so ``importlib.machinery`` would only work if some
    # other code had already imported it. Import it explicitly.
    import importlib.machinery

    module = types.ModuleType("dummy_module")
    module.__spec__ = importlib.machinery.ModuleSpec(
        "dummy_module", None, origin=os.path.abspath(__file__)
    )
    exec(module_code, module.__dict__)
    sys.modules["dummy_module"] = module
    # The exec'd ``add`` has ``__code__.co_filename == "<string>"``, which is
    # not a regular Python file name. The skipfiles rules use the file name
    # when checking SKIP_DIRS, so replace it with this file's ``add``.
    module.add = add
    return module, module.add
52