// external/executorch/kernels/portable/cpu/op_cdist_forward.cpp
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <executorch/kernels/portable/cpu/util/broadcast_util.h>
#include <executorch/kernels/portable/cpu/util/distance_util.h>
#include <executorch/runtime/kernel/kernel_includes.h>

namespace torch {
namespace executor {
namespace native {

using exec_aten::optional;
using exec_aten::Tensor;

namespace {

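// Returns the leading (batch) dimensions of a tensor, i.e. all sizes except
// the trailing two matrix dimensions.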
inline ArrayRef<Tensor::SizesType> get_batch_sizes(const Tensor& tensor) {
  return {tensor.sizes().data(), tensor.sizes().size() - 2};
}

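// Computes pairwise p-norm distances between the rows of the trailing (P, M)
// matrix of x1 and the trailing (R, M) matrix of x2 for every (possibly
// broadcast) batch element, writing a (P, R) block into out. Norm supplies
// the map/reduce/finish steps of the accumulation.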
template <typename CTYPE, typename Norm>
void cdist(const Tensor& x1, const Tensor& x2, Tensor& out, double p) {
  if (out.numel() == 0) {
    return;
  }

  CTYPE* out_data = out.mutable_data_ptr<CTYPE>();

  // If the last dimension of x1 (which is equal to the last dimension of x2)
  // has size 0, then the output is filled with 0s.
  if (x1.numel() == 0) {
    for (size_t out_ix = 0; out_ix < out.numel(); ++out_ix) {
      out_data[out_ix] = 0;
    }
    return;
  }

  const CTYPE* x1_data = x1.const_data_ptr<CTYPE>();
  const CTYPE* x2_data = x2.const_data_ptr<CTYPE>();

  const ArrayRef<Tensor::SizesType> x1_batch_sizes = get_batch_sizes(x1);
  const ArrayRef<Tensor::SizesType> x2_batch_sizes = get_batch_sizes(x2);
  const ArrayRef<Tensor::SizesType> out_batch_sizes = get_batch_sizes(out);

  const bool x1_is_broadcasted = !out_batch_sizes.equals(x1_batch_sizes);
  const bool x2_is_broadcasted = !out_batch_sizes.equals(x2_batch_sizes);
  const bool any_is_broadcasted = (x1_is_broadcasted || x2_is_broadcasted);

  size_t out_batch_numel = 1;
  for (auto i : out_batch_sizes) {
    out_batch_numel *= i;
  }

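  // P: number of rows in x1, R: number of rows in x2, M: size of the shared
  // last dimension over which each distance is computed.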
  size_t P = static_cast<size_t>(x1.size(x1.dim() - 2)); // NOLINT
  size_t R = static_cast<size_t>(x2.size(x2.dim() - 2)); // NOLINT
  size_t M = static_cast<size_t>(x1.size(x1.dim() - 1)); // NOLINT

  size_t x1_inner_size = P * M;
  size_t x2_inner_size = R * M;
  size_t out_inner_size = P * R;

  for (size_t b = 0; b < out_batch_numel; ++b) {
    size_t x1_base_ix = b * x1_inner_size;
    size_t x2_base_ix = b * x2_inner_size;
    size_t out_base_ix = b * out_inner_size;

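    // If x1 or x2 was broadcast over the batch dimensions, map the output
    // batch offset back to the corresponding offset in that input by
    // delinearizing the output index and re-linearizing it against the input.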
    if (any_is_broadcasted) {
      size_t out_base_coord[kTensorDimensionLimit];
      delinearize_index(
          out_base_ix, out, out_base_coord, kTensorDimensionLimit);

      if (x1_is_broadcasted) {
        x1_base_ix = linearize_access_indexes(out_base_coord, out.dim(), x1);
      }
      if (x2_is_broadcasted) {
        x2_base_ix = linearize_access_indexes(out_base_coord, out.dim(), x2);
      }
    }

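    // For each pair (row i of x1, row j of x2), accumulate Norm::map of the
    // absolute elementwise differences and finalize the aggregate with
    // Norm::finish to produce the (i, j) entry of this batch's output block.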
    size_t out_ix = 0;
    for (size_t i = 0; i < P; ++i) {
      const CTYPE* row_i = x1_data + x1_base_ix + i * M;
      for (size_t j = 0; j < R; ++j) {
        const CTYPE* row_j = x2_data + x2_base_ix + j * M;
        CTYPE agg = 0;
        for (size_t k = 0; k < M; ++k) {
          CTYPE diff = std::abs(row_i[k] - row_j[k]);
          agg = Norm::reduce(agg, Norm::map(diff, p));
        }
        out_data[out_base_ix + out_ix++] = Norm::finish(agg, p);
      }
    }
  }
}

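// Dispatches to the norm functor specialized for the given p: L0, L1, L2, and
// Linf have dedicated implementations; any other p falls back to general Lp.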
template <typename CTYPE>
void cdist(const Tensor& x1, const Tensor& x2, Tensor& out, double p) {
  if (p == 0.0) {
    cdist<CTYPE, L0<CTYPE>>(x1, x2, out, p);
  } else if (p == 1.0) {
    cdist<CTYPE, L1<CTYPE>>(x1, x2, out, p);
  } else if (p == 2.0) {
    cdist<CTYPE, L2<CTYPE>>(x1, x2, out, p);
  } else if (p == INFINITY) {
    cdist<CTYPE, Linf<CTYPE>>(x1, x2, out, p);
  } else {
    cdist<CTYPE, Lp<CTYPE>>(x1, x2, out, p);
  }
}

} // namespace

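// Kernel entry point for the _cdist_forward.out op: validates the arguments,
// resizes out to the broadcast result shape, and computes the pairwise
// distances.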
Tensor& _cdist_forward_out(
    KernelRuntimeContext& ctx,
    const Tensor& x1,
    const Tensor& x2,
    double p,
    optional<int64_t> compute_mode,
    Tensor& out) {
  (void)ctx;

  ET_KERNEL_CHECK(
      ctx, tensors_have_same_dim_order(x1, x2, out), InvalidArgument, out);

  ET_KERNEL_CHECK(ctx, tensor_is_default_dim_order(x1), InvalidArgument, out);

  ET_KERNEL_CHECK(
      ctx,
      check_cdist_args(x1, x2, p, compute_mode, out),
      InvalidArgument,
      out);

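  // The output shape is the broadcast of the batch (leading) dims of x1 and
  // x2, followed by P (rows of x1) and R (rows of x2).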
  Tensor::SizesType target_sizes[kTensorDimensionLimit];
  size_t target_ndim = 0;

  ET_KERNEL_CHECK(
      ctx,
      get_broadcast_target_size(
          {x1.sizes().data(), x1.sizes().size() - 2},
          {x2.sizes().data(), x2.sizes().size() - 2},
          target_sizes,
          kTensorDimensionLimit,
          &target_ndim) == Error::Ok,
      InvalidArgument,
      out);

  target_ndim += 2;
  target_sizes[target_ndim - 2] = x1.size(x1.dim() - 2);
  target_sizes[target_ndim - 1] = x2.size(x2.dim() - 2);

  ET_KERNEL_CHECK(
      ctx,
      resize_tensor(out, {target_sizes, target_ndim}) == Error::Ok,
      InvalidArgument,
      out);

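  // Compute the distances for each supported floating-point output dtype.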
  ScalarType out_type = out.scalar_type();
  constexpr auto name = "_cdist_forward.out";

  ET_SWITCH_FLOAT_TYPES(
      out_type, ctx, name, CTYPE, [&] { cdist<CTYPE>(x1, x2, out, p); });

  return out;
}

} // namespace native
} // namespace executor
} // namespace torch