# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""Example showcasing how to register a custom operator through the torch.library API."""
import torch
from examples.portable.scripts.export import export_to_exec_prog, save_pte_program

from executorch.exir import EdgeCompileConfig
from torch.library import impl, Library

my_op_lib = Library("my_ops", "DEF")

# registering an operator that multiplies input tensor by 3 and returns it.
my_op_lib.define("mul3(Tensor input) -> Tensor")  # should print 'mul3'


@impl(my_op_lib, "mul3", dispatch_key="CompositeExplicitAutograd")
def mul3_impl(a: torch.Tensor) -> torch.Tensor:
    """Functional kernel for ``my_ops::mul3``.

    Args:
        a: Input tensor.

    Returns:
        A new tensor equal to ``a * 3``.
    """
    return a * 3


# registering the out variant.
my_op_lib.define(
    "mul3.out(Tensor input, *, Tensor(a!) output) -> Tensor(a!)"
)  # should print 'mul3.out'


@impl(my_op_lib, "mul3.out", dispatch_key="CompositeExplicitAutograd")
def mul3_out_impl(a: torch.Tensor, *, output: torch.Tensor) -> torch.Tensor:
    """Out-variant kernel for ``my_ops::mul3.out``.

    The keyword-only parameter is named ``output`` to match the schema
    registered above. The dispatcher forwards keyword-only schema arguments
    by their schema names, so any other name here would make calls fail with
    ``TypeError``.

    Args:
        a: Input tensor.
        output: Destination tensor. It is written in place (``Tensor(a!)``
            in the schema).

    Returns:
        ``output``, after it has been filled with ``a * 3``.
    """
    # Copy first, then scale in place, so the result lands in the caller's buffer.
    output.copy_(a)
    output.mul_(3)
    return output


# example model
class Model(torch.nn.Module):
    """Minimal module whose forward pass calls the custom ``my_ops::mul3`` op."""

    def forward(self, a):
        return torch.ops.my_ops.mul3.default(a)


def main():
    """Export ``Model`` through ExecuTorch and save the resulting program.

    The program is passed to ``save_pte_program`` under the name
    ``"custom_ops_1"``.
    """
    model = Model()
    example_input = torch.randn(2, 3)
    # capture and lower
    model_name = "custom_ops_1"
    prog = export_to_exec_prog(
        model,
        (example_input,),
        # NOTE(review): IR validity checking is disabled, presumably because
        # my_ops::mul3 is not a core ATen op. Confirm this is still required.
        edge_compile_config=EdgeCompileConfig(_check_ir_validity=False),
    )
    save_pte_program(prog, model_name)


if __name__ == "__main__":
    main()