xref: /aosp_15_r20/external/executorch/extension/llm/custom_ops/op_fast_hadamard_transform_aten.cpp (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1 /*
2  * Copyright (c) Meta Platforms, Inc. and affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 #include <executorch/extension/aten_util/make_aten_functor_from_et_functor.h>
10 #include <executorch/extension/llm/custom_ops/op_fast_hadamard_transform.h>
11 
12 #include <torch/library.h>
13 
14 namespace torch::executor::native {
15 namespace {
fast_hadamard_transform_out_no_context(const Tensor & vec,Tensor & out)16 Tensor& fast_hadamard_transform_out_no_context(const Tensor& vec, Tensor& out) {
17   exec_aten::RuntimeContext context;
18   return fast_hadamard_transform_out(context, vec, out);
19 }
fast_hadamard_transform_aten(const at::Tensor & vec)20 at::Tensor fast_hadamard_transform_aten(const at::Tensor& vec) {
21   auto out = at::empty_like(vec);
22   WRAP_TO_ATEN(fast_hadamard_transform_out_no_context, 1)
23   (vec, out);
24   return out;
25 }
26 } // namespace
27 } // namespace torch::executor::native
28 
// Declares the operator schemas in the `llama` namespace. A FRAGMENT is used
// so other translation units can add more operators to the same namespace.
TORCH_LIBRARY_FRAGMENT(llama, m) {
  // Functional form: returns a freshly allocated result tensor.
  m.def("fast_hadamard_transform(Tensor mat) -> Tensor");

  // Out form: writes into the keyword-only `out` argument (alias annotation
  // `a!` marks it as mutated) and returns that same tensor.
  m.def(
      "fast_hadamard_transform.out(Tensor mat, *, Tensor(a!) out) -> Tensor(a!)");
}
34 
// Binds kernels for both `llama::fast_hadamard_transform` overloads under the
// CompositeExplicitAutograd dispatch key.
TORCH_LIBRARY_IMPL(llama, CompositeExplicitAutograd, m) {
  namespace et_native = torch::executor::native;

  // Functional overload: allocates its own output.
  m.impl("fast_hadamard_transform", et_native::fast_hadamard_transform_aten);

  // Out overload: the ExecuTorch out-variant adapted to ATen tensors.
  m.impl(
      "fast_hadamard_transform.out",
      WRAP_TO_ATEN(et_native::fast_hadamard_transform_out_no_context, 1));
}
44