/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <executorch/kernels/portable/cpu/util/copy_ops_util.h>
#include <executorch/runtime/kernel/kernel_includes.h>
#include <algorithm>
#include <cinttypes>
#include <cstring>

namespace torch {
namespace executor {
namespace native {

using exec_aten::Scalar;
using ScalarType = exec_aten::ScalarType;

namespace {

/**
 * Clears `out` by setting all elements to 0.
 */
Tensor& clear_out(Tensor& out) {
  uint8_t* out_data = out.mutable_data_ptr<uint8_t>();
  if (out_data != nullptr) {
    memset(out_data, 0, out.nbytes());
  }
  return out;
}

/**
 * Copies the lower-triangular part of `self` into `out`, as selected by the
 * given `diagonal`, size, and stride parameters. This function is agnostic to
 * whether `self` is a 2D matrix or a batch of matrices.
 */
template <typename CTYPE>
void apply_tril(
    CTYPE* __restrict__ self,
    CTYPE* __restrict__ out,
    int64_t diagonal,
    int64_t num_rows,
    int64_t num_cols,
    int64_t row_stride,
    int64_t col_stride) {
  for (int64_t i = 0; i < num_rows; i++) {
    for (int64_t j = 0; j < std::min(num_cols, i + diagonal + 1); j++) {
      out[i * row_stride + j * col_stride] =
          self[i * row_stride + j * col_stride];
    }
  }
}
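
// Illustration of the copy pattern above, assuming a contiguous, row-major
// 3x3 matrix (row_stride = 3, col_stride = 1) and diagonal = 0: the inner
// loop bound std::min(num_cols, i + diagonal + 1) copies only the positions
// marked `x`; the `.` positions are left untouched (they are zeroed by
// `clear_out` before tril_kernel runs).
//
//   [ x . . ]
//   [ x x . ]
//   [ x x x ]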

/**
 * `tril_out` helper function.
 */
template <typename CTYPE>
void tril_kernel(
    KernelRuntimeContext& ctx,
    const Tensor& self,
    int64_t diagonal,
    const Tensor& out) {
  // Dynamically compute `self` sizes and strides.

  int64_t ndim = self.dim();

  ET_KERNEL_CHECK_MSG(
      ctx,
      ndim < kTensorDimensionLimit,
      InvalidArgument,
      ,
      "ndim %" PRId64 " >= %zu",
      ndim,
      kTensorDimensionLimit);

  int64_t sizes[kTensorDimensionLimit];
  int64_t strides[kTensorDimensionLimit];

  for (size_t i = 0; i < ndim; ++i) {
    sizes[i] = self.size(i);
    strides[i] = getTrailingDims(self, static_cast<int64_t>(i));
  }

  IntArrayRef sizes_ref(sizes, ndim);
  IntArrayRef strides_ref(strides, ndim);

  int64_t num_rows = sizes_ref[ndim - 2];
  int64_t num_cols = sizes_ref[ndim - 1];

  // Compute `tril` for a 2D matrix or a batch of matrices. For a batch of
  // matrices, `batch_size` will be >1, and `apply_tril` will be executed
  // multiple times, each referencing a multiple of `self_stride`.
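  // For example, for a contiguous tensor of shape (B, R, C), the size-derived
  // strides computed above are {R * C, C, 1}, so batch_size = B,
  // self_stride = R * C, row_stride = C, and col_stride = 1; batch `i` then
  // starts at offset i * R * C in both `self` and `out`.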

  int64_t batch_size = getLeadingDims(self, ndim - 2);
  int64_t self_stride =
      (self.dim() > 2 && strides_ref[ndim - 3] > 0) ? strides_ref[ndim - 3] : 1;

  auto data_self = self.mutable_data_ptr<CTYPE>();
  auto data_out = out.mutable_data_ptr<CTYPE>();

  int64_t row_stride = strides_ref[ndim - 2];
  int64_t col_stride = strides_ref[ndim - 1];

  for (int64_t i = 0; i < batch_size; i++) {
    CTYPE* __restrict__ data_self_ptr = &data_self[i * self_stride];
    CTYPE* __restrict__ data_out_ptr = &data_out[i * self_stride];

    apply_tril<CTYPE>(
        data_self_ptr,
        data_out_ptr,
        diagonal,
        num_rows,
        num_cols,
        row_stride,
        col_stride);
  }
}

} // namespace

/**
 * `tril_out` implementation for all dtypes (real + bool). Computes the
 * lower-triangular part of a 2D matrix or batch of matrices and writes it to
 * `out`, with all other elements set to 0. Further, `diagonal` controls how
 * the lower-triangular subset is defined:
 *    1. `diagonal = 0`: Elements on and below the main diagonal are retained.
 *    2. `diagonal > 0`: Similar to case (1); `diagonal` additional diagonals
 *       above the main one are also retained.
 *    3. `diagonal < 0`: Only elements on and below the `|diagonal|`-th
 *       diagonal below the main one are retained; the main diagonal and the
 *       `|diagonal| - 1` diagonals directly below it are zeroed.
 */
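// For illustration (example values, not drawn from any test): with
//   self = [[1, 2, 3],
//           [4, 5, 6],
//           [7, 8, 9]],
// diagonal =  0 yields [[1, 0, 0], [4, 5, 0], [7, 8, 9]],
// diagonal =  1 yields [[1, 2, 0], [4, 5, 6], [7, 8, 9]], and
// diagonal = -1 yields [[0, 0, 0], [4, 0, 0], [7, 8, 0]].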
Tensor& tril_out(
    KernelRuntimeContext& ctx,
    const Tensor& self,
    int64_t diagonal,
    Tensor& out) {
  (void)ctx;

  ET_KERNEL_CHECK(ctx, check_tril_args(self, out), InvalidArgument, out);

  ET_KERNEL_CHECK(
      ctx,
      resize_tensor(out, self.sizes()) == torch::executor::Error::Ok,
      InvalidArgument,
      out);

  ET_KERNEL_CHECK(
      ctx, tensors_have_same_dim_order(self, out), InvalidArgument, out);

  ET_KERNEL_CHECK(ctx, tensor_is_default_dim_order(self), InvalidArgument, out);

  if (self.numel() == 0) {
    return out;
  }

  // Fill `out` with 0s prior to executing tril.
  clear_out(out);

  ScalarType out_type = out.scalar_type();
  ET_SWITCH_REAL_TYPES_AND(Bool, out_type, ctx, __func__, CTYPE, [&]() {
    tril_kernel<CTYPE>(ctx, self, diagonal, out);
  });

  return out;
}

} // namespace native
} // namespace executor
} // namespace torch