xref: /aosp_15_r20/external/pytorch/test/custom_backend/backend.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Workerimport argparse
2*da0073e9SAndroid Build Coastguard Workerimport os.path
3*da0073e9SAndroid Build Coastguard Workerimport sys
4*da0073e9SAndroid Build Coastguard Worker
5*da0073e9SAndroid Build Coastguard Workerimport torch
6*da0073e9SAndroid Build Coastguard Worker
7*da0073e9SAndroid Build Coastguard Worker
def get_custom_backend_library_path(build_dir="build"):
    """
    Get the path to the library containing the custom backend.

    The library file name depends on the host platform: ``.dll`` on Windows,
    ``.dylib`` on macOS and ``.so`` elsewhere. The file is looked up in
    ``build_dir``. A relative ``build_dir`` is resolved against the current
    working directory.

    Args:
        build_dir: directory expected to contain the built library.
            Defaults to ``"build"``, the previously hard-coded location.

    Return:
        The absolute path to the custom backend library, customized by platform.

    Raises:
        FileNotFoundError: if no regular file exists at that path.
    """
    if sys.platform.startswith("win32"):
        library_filename = "custom_backend.dll"
    elif sys.platform.startswith("darwin"):
        library_filename = "libcustom_backend.dylib"
    else:
        library_filename = "libcustom_backend.so"
    path = os.path.abspath(os.path.join(build_dir, library_filename))
    # Use an explicit check instead of `assert` so it still runs under
    # `python -O`. `isfile` also rejects a directory that merely has the
    # library's name.
    if not os.path.isfile(path):
        raise FileNotFoundError(path)
    return path
24*da0073e9SAndroid Build Coastguard Worker
25*da0073e9SAndroid Build Coastguard Worker
def to_custom_backend(module):
    """
    Lower ``module`` to the backend registered under the name ``"custom_backend"``.

    This is a thin wrapper around ``torch._C._jit_to_backend`` that lowers
    only the ``forward`` method. The compile spec for ``forward`` is the
    placeholder ``{"": ""}``; no real compilation options are passed.

    The ``"custom_backend"`` backend must already be available to the JIT.
    In this script, that is arranged by loading the custom backend library
    with ``torch.ops.load_library`` before lowering.

    Args:
        module: input ScriptModule.

    Returns:
        The module, lowered so that its ``forward`` runs on the custom backend.
    """
    # Map each method name to lower onto its per-method compile spec.
    method_compile_spec = {"forward": {"": ""}}
    lowered_module = torch._C._jit_to_backend(
        "custom_backend", module, method_compile_spec
    )
    return lowered_module
41*da0073e9SAndroid Build Coastguard Worker
42*da0073e9SAndroid Build Coastguard Worker
class Model(torch.nn.Module):
    """
    Minimal module used to check that a model lowered with the to_backend
    API can be saved, loaded, and executed in C++.

    ``forward`` takes two inputs and returns the pair ``(a + b, a - b)``.
    """

    def forward(self, a, b):
        # Compute both outputs first, then pack them. The tuple order
        # (sum, then difference) is what consumers of the module see.
        total = a + b
        delta = a - b
        return total, delta
51*da0073e9SAndroid Build Coastguard Worker
52*da0073e9SAndroid Build Coastguard Worker
def main(argv=None):
    """
    Lower ``Model`` to the custom backend and save it with ``torch.jit.save``.

    The destination path comes from the required ``--export-module-to``
    command-line option.

    Args:
        argv: optional list of command-line arguments, excluding the program
            name. ``None`` (the default) makes argparse read ``sys.argv[1:]``,
            exactly as before.

    Raises:
        RuntimeError: if, after loading, the custom backend library is not
            listed in ``torch.ops.loaded_libraries``.
    """
    parser = argparse.ArgumentParser(description="Lower a Module to a custom backend")
    parser.add_argument("--export-module-to", required=True)
    options = parser.parse_args(argv)

    # Load the library containing the custom backend.
    library_path = get_custom_backend_library_path()
    torch.ops.load_library(library_path)
    # NOTE(review): torch.ops.load_library may record the library under its
    # resolved (realpath) form. That differs from the abspath above when the
    # build directory is reached through a symlink, so accept either spelling.
    # An explicit check (rather than `assert`) keeps this active under
    # `python -O`.
    loaded = torch.ops.loaded_libraries
    if library_path not in loaded and os.path.realpath(library_path) not in loaded:
        raise RuntimeError(f"custom backend library was not loaded: {library_path}")

    # Lower an instance of Model to the custom backend and export it
    # to the specified location.
    lowered_module = to_custom_backend(torch.jit.script(Model()))
    torch.jit.save(lowered_module, options.export_module_to)


if __name__ == "__main__":
    main()
71