import argparse
import os.path
import sys

import torch


def get_custom_backend_library_path():
    """
    Get the path to the library containing the custom backend.

    The library is looked up in ``build/`` relative to the current working
    directory, using a platform-specific file name.

    Returns:
        The absolute path to the custom backend library, customized by platform.

    Raises:
        FileNotFoundError: if the library does not exist at the expected path.
    """
    if sys.platform.startswith("win32"):
        library_filename = "custom_backend.dll"
    elif sys.platform.startswith("darwin"):
        library_filename = "libcustom_backend.dylib"
    else:
        library_filename = "libcustom_backend.so"
    path = os.path.abspath(os.path.join("build", library_filename))
    # Explicit check instead of ``assert``: asserts are stripped under
    # ``python -O``, which would defer the failure to an opaque load error.
    if not os.path.exists(path):
        raise FileNotFoundError(f"Custom backend library not found: {path}")
    return path


def to_custom_backend(module):
    """
    Lower ``module`` to the backend registered as "custom_backend".

    This is a helper that wraps ``torch._C._jit_to_backend`` and compiles
    only the ``forward`` method, passing ``{"": ""}`` as its compile spec.
    NOTE(review): ``main`` loads the custom backend library before calling
    this; presumably the backend must be registered first — confirm.

    Args:
        module: input ScriptModule.

    Returns:
        The module, lowered so that it can run on the custom backend.
    """
    lowered_module = torch._C._jit_to_backend(
        "custom_backend", module, {"forward": {"": ""}}
    )
    return lowered_module


class Model(torch.nn.Module):
    """
    Simple model used for testing that to_backend API supports saving, loading,
    and executing in C++.
    """

    def forward(self, a, b):
        # Returns the elementwise sum and difference of the two inputs.
        return (a + b, a - b)


def main():
    """
    Parse ``--export-module-to``, lower ``Model`` to the custom backend,
    and save the lowered TorchScript module to that path.
    """
    parser = argparse.ArgumentParser(description="Lower a Module to a custom backend")
    parser.add_argument("--export-module-to", required=True)
    options = parser.parse_args()

    # Load the library containing the custom backend.
    library_path = get_custom_backend_library_path()
    torch.ops.load_library(library_path)
    # Sanity check that torch recorded the library as loaded.
    assert library_path in torch.ops.loaded_libraries

    # Lower an instance of Model to the custom backend and export it
    # to the specified location.
    lowered_module = to_custom_backend(torch.jit.script(Model()))
    torch.jit.save(lowered_module, options.export_module_to)


if __name__ == "__main__":
    main()