xref: /aosp_15_r20/external/executorch/examples/portable/custom_ops/custom_ops_1.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright (c) Meta Platforms, Inc. and affiliates.
2# All rights reserved.
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
"""Example showing how to register a custom operator through the torch.library API."""
from typing import Optional

import torch
from examples.portable.scripts.export import export_to_exec_prog, save_pte_program

from executorch.exir import EdgeCompileConfig
from torch.library import impl, Library
13
# Library handle that defines (kind "DEF") operator schemas in the "my_ops"
# namespace; kernels are registered against this same handle.
my_op_lib = Library("my_ops", kind="DEF")

# Functional variant: returns a new tensor equal to the input multiplied by 3.
# `define` returns the operator name ("mul3"); the value is not used here.
my_op_lib.define("mul3(Tensor input) -> Tensor")
18
19
@impl(my_op_lib, "mul3", dispatch_key="CompositeExplicitAutograd")
def mul3_impl(a: torch.Tensor) -> torch.Tensor:
    """Eager kernel for ``my_ops::mul3``: return a new tensor ``a * 3``."""
    return torch.mul(a, 3)
23
24
# Out variant of mul3: the result is written into the keyword-only `output`
# tensor, which the `Tensor(a!)` alias annotation marks as mutated and returned.
my_op_lib.define("mul3.out(Tensor input, *, Tensor(a!) output) -> Tensor(a!)")
29
30
@impl(my_op_lib, "mul3.out", dispatch_key="CompositeExplicitAutograd")
def mul3_out_impl(
    a: torch.Tensor,
    *,
    output: Optional[torch.Tensor] = None,
    out: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Eager kernel for ``my_ops::mul3.out``: write ``a * 3`` into the out tensor.

    The registered schema names its keyword-only out argument ``output``, and
    the dispatcher hands keyword-only schema arguments to Python kernels as
    keyword arguments with those schema names, so this kernel must accept
    ``output=``. ``out=`` is still accepted for direct Python callers that
    used the previous parameter name.

    Args:
        a: Input tensor.
        output: Destination tensor (schema name); overwritten in place.
        out: Backward-compatible alias for ``output``.

    Returns:
        The destination tensor, after it has been overwritten with ``a * 3``.

    Raises:
        TypeError: if neither or both of ``output`` and ``out`` are given.
    """
    if (output is None) == (out is None):
        raise TypeError("mul3_out_impl() requires exactly one of 'output' or 'out'")
    dest = output if output is not None else out
    # Copy first, then scale in place, so no temporary result tensor is needed.
    dest.copy_(a)
    dest.mul_(3)
    return dest
36
37
class Model(torch.nn.Module):
    """Example module whose forward pass is a single call to ``my_ops::mul3``."""

    def forward(self, a: torch.Tensor) -> torch.Tensor:
        # Explicitly use the functional `default` overload (not `.out`).
        mul3 = torch.ops.my_ops.mul3.default
        return mul3(a)
42
43
def main():
    """Export ``Model`` (which calls the custom op) and save the program.

    Captures the model with an example ``(2, 3)`` random input, lowers it via
    ``export_to_exec_prog``, and passes the result to ``save_pte_program``
    under the name ``"custom_ops_1"``.
    """
    model = Model()
    example_input = torch.randn(2, 3)
    model_name = "custom_ops_1"
    # IR validity checking is disabled, presumably because my_ops::mul3 is not
    # a core ATen op and would otherwise be rejected — TODO confirm.
    prog = export_to_exec_prog(
        model,
        (example_input,),
        edge_compile_config=EdgeCompileConfig(_check_ir_validity=False),
    )
    save_pte_program(prog, model_name)


if __name__ == "__main__":
    main()
59