/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <algorithm> // std::min
#include <cinttypes> // PRId64
#include <cstring> // memset

#include <executorch/kernels/portable/cpu/util/copy_ops_util.h>
#include <executorch/runtime/kernel/kernel_includes.h>

namespace torch {
namespace executor {
namespace native {

using exec_aten::Scalar;
using ScalarType = exec_aten::ScalarType;

namespace {

/**
 * Clears `out` by setting all elements to 0.
 */
Tensor& clear_out(Tensor& out) {
  uint8_t* out_data = out.mutable_data_ptr<uint8_t>();
  if (out_data != nullptr) {
    memset(out_data, 0, out.nbytes());
  }
  return out;
}

/**
 * Copies the lower-triangular part of `self` into `out`, as selected by
 * `diagonal` and the given row/column strides. This function is agnostic to
 * whether `self` is a 2D matrix or one matrix of a batch.
 */
template <typename CTYPE>
void apply_tril(
    CTYPE* __restrict__ self,
    CTYPE* __restrict__ out,
    int64_t diagonal,
    int64_t num_rows,
    int64_t num_cols,
    int64_t row_stride,
    int64_t col_stride) {
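  // An element (i, j) is in the lower-triangular region exactly when
  // j <= i + diagonal, so the inner loop bound below is i + diagonal + 1,
  // clamped to num_cols.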
  for (int64_t i = 0; i < num_rows; i++) {
    for (int64_t j = 0; j < std::min(num_cols, i + diagonal + 1); j++) {
      out[i * row_stride + j * col_stride] =
          self[i * row_stride + j * col_stride];
    }
  }
}

/**
 * `tril_out` helper: computes the lower-triangular part of `self` into `out`
 * for a single dtype `CTYPE`, for either a 2D matrix or a batch of matrices.
 */
template <typename CTYPE>
void tril_kernel(
    KernelRuntimeContext& ctx,
    const Tensor& self,
    int64_t diagonal,
    const Tensor& out) {
  // Dynamically compute `self` sizes and strides.

  int64_t ndim = self.dim();

  ET_KERNEL_CHECK_MSG(
      ctx,
      ndim < kTensorDimensionLimit,
      InvalidArgument,
      ,
      "ndim %" PRId64 " >= %zu",
      ndim,
      kTensorDimensionLimit);

  int64_t sizes[kTensorDimensionLimit];
  int64_t strides[kTensorDimensionLimit];

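  // `strides` below are the contiguous strides implied by `sizes` (the
  // product of the trailing dimension sizes) rather than strides queried from
  // the tensor itself; `tril_out` checks that `self` uses the default dim
  // order, which keeps this consistent with the actual layout.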
  for (size_t i = 0; i < ndim; ++i) {
    sizes[i] = self.size(i);
    strides[i] = getTrailingDims(self, static_cast<int64_t>(i));
  }

  IntArrayRef sizes_ref(sizes, ndim);
  IntArrayRef strides_ref(strides, ndim);

  int64_t num_rows = sizes_ref[ndim - 2];
  int64_t num_cols = sizes_ref[ndim - 1];

  // Compute `tril` for a 2D matrix or a batch of matrices. For a batch of
  // matrices, `batch_size` will be >1, and `apply_tril` will be executed
  // multiple times, each referencing a multiple of `self_stride`.

  int64_t batch_size = getLeadingDims(self, ndim - 2);
  int64_t self_stride =
      (self.dim() > 2 && strides_ref[ndim - 3] > 0) ? strides_ref[ndim - 3] : 1;
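  // For ndim > 2, `strides_ref[ndim - 3]` equals `num_rows * num_cols`, the
  // number of elements in one matrix, so stepping by `self_stride` advances
  // to the next matrix in the batch.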

  auto data_self = self.mutable_data_ptr<CTYPE>();
  auto data_out = out.mutable_data_ptr<CTYPE>();

  int64_t row_stride = strides_ref[ndim - 2];
  int64_t col_stride = strides_ref[ndim - 1];
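  // With the contiguous strides computed above, `row_stride == num_cols` and
  // `col_stride == 1`.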

  for (int64_t i = 0; i < batch_size; i++) {
    CTYPE* __restrict__ data_self_ptr = &data_self[i * self_stride];
    CTYPE* __restrict__ data_out_ptr = &data_out[i * self_stride];

    apply_tril<CTYPE>(
        data_self_ptr,
        data_out_ptr,
        diagonal,
        num_rows,
        num_cols,
        row_stride,
        col_stride);
  }
}

} // namespace

/**
 * `tril_out` implementation for all dtypes (real + bool). Returns the
 * lower-triangular part of a 2D matrix or batch of matrices in `out`, where
 * all other elements are set to 0. Further, `diagonal` controls how the
 * lower-triangular subset is defined:
 * 1. `diagonal = 0`: Elements on and below the main diagonal are retained.
 * 2. `diagonal > 0`: Similar to case (1); the `diagonal` diagonals above the
 *     main one are retained as well.
 * 3. `diagonal < 0`: The main diagonal and the `|diagonal| - 1` diagonals
 *     immediately below it are zeroed out as well; only elements further
 *     below are retained.
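 *
 * For example, for a 3x3 input
 *   [[a, b, c],
 *    [d, e, f],
 *    [g, h, i]]
 * `diagonal = 0` yields [[a, 0, 0], [d, e, 0], [g, h, i]], `diagonal = 1`
 * yields [[a, b, 0], [d, e, f], [g, h, i]], and `diagonal = -1` yields
 * [[0, 0, 0], [d, 0, 0], [g, h, 0]].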
 */
Tensor& tril_out(
    KernelRuntimeContext& ctx,
    const Tensor& self,
    int64_t diagonal,
    Tensor& out) {
  (void)ctx;

  ET_KERNEL_CHECK(ctx, check_tril_args(self, out), InvalidArgument, out);

  ET_KERNEL_CHECK(
      ctx,
      resize_tensor(out, self.sizes()) == torch::executor::Error::Ok,
      InvalidArgument,
      out);

  ET_KERNEL_CHECK(
      ctx, tensors_have_same_dim_order(self, out), InvalidArgument, out);

  ET_KERNEL_CHECK(ctx, tensor_is_default_dim_order(self), InvalidArgument, out);

  if (self.numel() == 0) {
    return out;
  }

  // Fill `out` with 0s prior to executing tril, since `apply_tril` only
  // writes the retained (lower-triangular) elements.
  clear_out(out);

  ScalarType out_type = out.scalar_type();
  ET_SWITCH_REAL_TYPES_AND(Bool, out_type, ctx, __func__, CTYPE, [&]() {
    tril_kernel<CTYPE>(ctx, self, diagonal, out);
  });

  return out;
}

} // namespace native
} // namespace executor
} // namespace torch