1 /*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9 #include <executorch/extension/aten_util/make_aten_functor_from_et_functor.h>
10 #include <executorch/extension/llm/custom_ops/op_fast_hadamard_transform.h>
11
12 #include <torch/library.h>
13
14 namespace torch::executor::native {
15 namespace {
fast_hadamard_transform_out_no_context(const Tensor & vec,Tensor & out)16 Tensor& fast_hadamard_transform_out_no_context(const Tensor& vec, Tensor& out) {
17 exec_aten::RuntimeContext context;
18 return fast_hadamard_transform_out(context, vec, out);
19 }
fast_hadamard_transform_aten(const at::Tensor & vec)20 at::Tensor fast_hadamard_transform_aten(const at::Tensor& vec) {
21 auto out = at::empty_like(vec);
22 WRAP_TO_ATEN(fast_hadamard_transform_out_no_context, 1)
23 (vec, out);
24 return out;
25 }
26 } // namespace
27 } // namespace torch::executor::native
28
// Register operator schemas with the PyTorch dispatcher under the `llama`
// namespace: a functional variant and an out variant of
// fast_hadamard_transform.
TORCH_LIBRARY_FRAGMENT(llama, m) {
  m.def("fast_hadamard_transform(Tensor mat) -> Tensor");
  m.def(
      "fast_hadamard_transform.out(Tensor mat, *, Tensor(a!) out) -> Tensor(a!)");
}
34
// Bind the implementations registered above to the CompositeExplicitAutograd
// dispatch key for the `llama` library.
TORCH_LIBRARY_IMPL(llama, CompositeExplicitAutograd, m) {
  m.impl(
      "fast_hadamard_transform",
      torch::executor::native::fast_hadamard_transform_aten);
  // Out variant: adapt the ExecuTorch-style kernel into an ATen-callable
  // functor. NOTE(review): the trailing `1` is presumably the number of out
  // tensors expected by WRAP_TO_ATEN — confirm against the macro definition.
  m.impl(
      "fast_hadamard_transform.out",
      WRAP_TO_ATEN(
          torch::executor::native::fast_hadamard_transform_out_no_context, 1));
}
44