/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <executorch/kernels/portable/cpu/util/copy_ops_util.h>
#include <executorch/runtime/kernel/kernel_includes.h>
#include <executorch/runtime/platform/assert.h>
#include <cstddef>

namespace torch {
namespace executor {
namespace native {
namespace {

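// Copies the diagonal of `in` (taken across dim1 and dim2, shifted by
// `offset`) into `out` by constructing a strided view of `in` and delegating
// to as_strided_copy. Assumes `out` has already been resized to the expected
// diagonal shape.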
template <typename CTYPE>
void diagonal_copy_impl(
    const Tensor& in,
    int64_t offset,
    int64_t dim1,
    int64_t dim2,
    Tensor& out) {
  if (out.numel() == 0) {
    return;
  }

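  // Locate the first element of the diagonal. A positive offset starts the
  // diagonal further along dim2 (above the main diagonal); a negative offset
  // starts it further along dim1 (below the main diagonal).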
  int64_t storage_offset = 0;
  size_t diag_size = out.size(out.dim() - 1);

  if (diag_size == 0) {
    // skip
  } else if (offset >= 0) {
    storage_offset += offset * in.strides().at(dim2);
  } else {
    storage_offset -= offset * in.strides().at(dim1);
  }

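  // The view's sizes are exactly the sizes of `out`: the non-diagonal dims of
  // `in` followed by the diagonal length as the last dim.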
  size_t new_ndim = out.dim();
  int64_t new_sizes[kTensorDimensionLimit];
  for (size_t i = 0; i < new_ndim; ++i) {
    new_sizes[i] = out.size(i);
  }

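  // The view's strides keep the stride of every dim except dim1 and dim2, and
  // append a diagonal stride that steps through both dims at once (the sum of
  // their strides).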
  int64_t new_strides[kTensorDimensionLimit];
  size_t shift = 0;
  for (size_t d = 0; d < in.dim(); ++d) {
    if (d == dim1 || d == dim2) {
      shift++;
    } else {
      new_strides[d - shift] = in.strides().at(d);
    }
  }
  new_strides[in.dim() - 2] = in.strides().at(dim1) + in.strides().at(dim2);

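  // Materialize the strided diagonal view into `out`.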
  as_strided_copy<CTYPE>(
      in, {new_sizes, new_ndim}, {new_strides, new_ndim}, storage_offset, out);
}

} // namespace

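// diagonal_copy.out: copies into `out` the diagonal of `in` taken with
// respect to dim1 and dim2 and shifted by `offset`, with the diagonal
// elements laid out along the last dimension of `out`.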
Tensor& diagonal_copy_out(
    KernelRuntimeContext& ctx,
    const Tensor& in,
    int64_t offset,
    int64_t dim1,
    int64_t dim2,
    Tensor& out) {
  (void)ctx;

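  // Validate the arguments (valid dim1/dim2, matching dtypes, dim order) up
  // front; on failure the checks record an error and return `out` as-is.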
  ET_KERNEL_CHECK(
      ctx, check_diagonal_copy_args(in, dim1, dim2, out), InvalidArgument, out);

  ET_KERNEL_CHECK(
      ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out);

  ET_KERNEL_CHECK(ctx, tensor_is_default_dim_order(in), InvalidArgument, out);

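  // Wrap negative dimension indices into the valid range (nonzero_dim treats
  // a zero-dim input as one-dimensional for this purpose).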
  if (dim1 < 0) {
    dim1 += nonzero_dim(in);
  }
  if (dim2 < 0) {
    dim2 += nonzero_dim(in);
  }

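  // Compute the expected output shape for the requested diagonal and resize
  // `out` to match before copying.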
  Tensor::SizesType expected_out_size[kTensorDimensionLimit];
  size_t expected_out_dim = 0;
  get_diagonal_copy_out_target_size(
      in, offset, dim1, dim2, expected_out_size, &expected_out_dim);

  ET_KERNEL_CHECK(
      ctx,
      resize_tensor(out, {expected_out_size, expected_out_dim}) == Error::Ok,
      InvalidArgument,
      out);

  constexpr auto name = "diagonal_copy.out";

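  // Dispatch on the input dtype (real, half, and bool types) and perform the
  // copy for that element type.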
  ET_SWITCH_REALHB_TYPES(in.scalar_type(), ctx, name, CTYPE, [&] {
    diagonal_copy_impl<CTYPE>(in, offset, dim1, dim2, out);
  });

  return out;
}

} // namespace native
} // namespace executor
} // namespace torch