/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <cmath>

#include <executorch/kernels/portable/cpu/util/broadcast_util.h>
#include <executorch/kernels/portable/cpu/util/distance_util.h>
#include <executorch/runtime/kernel/kernel_includes.h>

namespace torch {
namespace executor {
namespace native {

using exec_aten::optional;
using exec_aten::Tensor;

namespace {

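// Returns the sizes of the batch dimensions, i.e. all but the last two
// dimensions of the given tensor.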
inline ArrayRef<Tensor::SizesType> get_batch_sizes(const Tensor& tensor) {
  return {tensor.sizes().data(), tensor.sizes().size() - 2};
}

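// Computes, for each (broadcasted) batch, the p-norm distance between every
// row of x1 and every row of x2. Norm is a policy type (see distance_util.h)
// whose map/reduce/finish hooks accumulate per-element contributions into a
// distance value.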
template <typename CTYPE, typename Norm>
void cdist(const Tensor& x1, const Tensor& x2, Tensor& out, double p) {
  if (out.numel() == 0) {
    return;
  }

  CTYPE* out_data = out.mutable_data_ptr<CTYPE>();

  // If the last dimension of x1 (which is equal to the last dimension of x2)
  // has size 0, then the output is filled with 0s.
  if (x1.numel() == 0) {
    for (size_t out_ix = 0; out_ix < out.numel(); ++out_ix) {
      out_data[out_ix] = 0;
    }
    return;
  }

  const CTYPE* x1_data = x1.const_data_ptr<CTYPE>();
  const CTYPE* x2_data = x2.const_data_ptr<CTYPE>();

  const ArrayRef<Tensor::SizesType> x1_batch_sizes = get_batch_sizes(x1);
  const ArrayRef<Tensor::SizesType> x2_batch_sizes = get_batch_sizes(x2);
  const ArrayRef<Tensor::SizesType> out_batch_sizes = get_batch_sizes(out);

  const bool x1_is_broadcasted = !out_batch_sizes.equals(x1_batch_sizes);
  const bool x2_is_broadcasted = !out_batch_sizes.equals(x2_batch_sizes);
  const bool any_is_broadcasted = (x1_is_broadcasted || x2_is_broadcasted);

  size_t out_batch_numel = 1;
  for (auto i : out_batch_sizes) {
    out_batch_numel *= i;
  }

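  // P: number of rows (points) in x1, R: number of rows in x2, and M: the
  // size of the shared last dimension over which distances are computed.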
  size_t P = static_cast<size_t>(x1.size(x1.dim() - 2)); // NOLINT
  size_t R = static_cast<size_t>(x2.size(x2.dim() - 2)); // NOLINT
  size_t M = static_cast<size_t>(x1.size(x1.dim() - 1)); // NOLINT

  size_t x1_inner_size = P * M;
  size_t x2_inner_size = R * M;
  size_t out_inner_size = P * R;

  for (size_t b = 0; b < out_batch_numel; ++b) {
    size_t x1_base_ix = b * x1_inner_size;
    size_t x2_base_ix = b * x2_inner_size;
    size_t out_base_ix = b * out_inner_size;

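    // When the batch dimensions are broadcasted, the per-batch offsets into
    // x1/x2 cannot be derived by simple multiplication. Delinearize the
    // output offset into a coordinate and map it back into each broadcasted
    // input's index space.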
    if (any_is_broadcasted) {
      size_t out_base_coord[kTensorDimensionLimit];
      delinearize_index(
          out_base_ix, out, out_base_coord, kTensorDimensionLimit);

      if (x1_is_broadcasted) {
        x1_base_ix = linearize_access_indexes(out_base_coord, out.dim(), x1);
      }
      if (x2_is_broadcasted) {
        x2_base_ix = linearize_access_indexes(out_base_coord, out.dim(), x2);
      }
    }

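    // For every pair (i, j), reduce Norm::map(|x1[i, k] - x2[j, k]|, p) over
    // the M elements of the last dimension, then apply Norm::finish to obtain
    // the final distance (e.g. the 1/p-th root for the general Lp case).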
    size_t out_ix = 0;
    for (size_t i = 0; i < P; ++i) {
      const CTYPE* row_i = x1_data + x1_base_ix + i * M;
      for (size_t j = 0; j < R; ++j) {
        const CTYPE* row_j = x2_data + x2_base_ix + j * M;
        CTYPE agg = 0;
        for (size_t k = 0; k < M; ++k) {
          CTYPE diff = std::abs(row_i[k] - row_j[k]);
          agg = Norm::reduce(agg, Norm::map(diff, p));
        }
        out_data[out_base_ix + out_ix++] = Norm::finish(agg, p);
      }
    }
  }
}

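// Selects a specialized norm functor based on p, mirroring torch.cdist
// semantics: L0 counts non-zero differences, L1 sums absolute differences,
// L2 is the Euclidean distance, Linf takes the maximum difference, and Lp
// handles any other p. The functors are expected to be defined in
// distance_util.h.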
template <typename CTYPE>
void cdist(const Tensor& x1, const Tensor& x2, Tensor& out, double p) {
  if (p == 0.0) {
    cdist<CTYPE, L0<CTYPE>>(x1, x2, out, p);
  } else if (p == 1.0) {
    cdist<CTYPE, L1<CTYPE>>(x1, x2, out, p);
  } else if (p == 2.0) {
    cdist<CTYPE, L2<CTYPE>>(x1, x2, out, p);
  } else if (p == INFINITY) {
    cdist<CTYPE, Linf<CTYPE>>(x1, x2, out, p);
  } else {
    cdist<CTYPE, Lp<CTYPE>>(x1, x2, out, p);
  }
}

} // namespace

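// Out-variant of aten::_cdist_forward. Computes the p-norm distance between
// each pair of row vectors taken from the last two dimensions of x1 and x2.
// The batch dimensions (all but the last two) are broadcast, and out is
// resized to (broadcasted batch sizes) x P x R, where P = x1.size(-2) and
// R = x2.size(-2). compute_mode is validated by check_cdist_args but is
// otherwise unused in this portable implementation.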
Tensor& _cdist_forward_out(
    KernelRuntimeContext& ctx,
    const Tensor& x1,
    const Tensor& x2,
    double p,
    optional<int64_t> compute_mode,
    Tensor& out) {
  (void)ctx;

  ET_KERNEL_CHECK(
      ctx, tensors_have_same_dim_order(x1, x2, out), InvalidArgument, out);

  ET_KERNEL_CHECK(ctx, tensor_is_default_dim_order(x1), InvalidArgument, out);

  ET_KERNEL_CHECK(
      ctx,
      check_cdist_args(x1, x2, p, compute_mode, out),
      InvalidArgument,
      out);

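  // The output shape is the broadcast of the batch dimensions of x1 and x2,
  // followed by P = x1.size(-2) and R = x2.size(-2).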
  Tensor::SizesType target_sizes[kTensorDimensionLimit];
  size_t target_ndim = 0;

  ET_KERNEL_CHECK(
      ctx,
      get_broadcast_target_size(
          {x1.sizes().data(), x1.sizes().size() - 2},
          {x2.sizes().data(), x2.sizes().size() - 2},
          target_sizes,
          kTensorDimensionLimit,
          &target_ndim) == Error::Ok,
      InvalidArgument,
      out);

  target_ndim += 2;
  target_sizes[target_ndim - 2] = x1.size(x1.dim() - 2);
  target_sizes[target_ndim - 1] = x2.size(x2.dim() - 2);

  ET_KERNEL_CHECK(
      ctx,
      resize_tensor(out, {target_sizes, target_ndim}) == Error::Ok,
      InvalidArgument,
      out);

  ScalarType out_type = out.scalar_type();
  constexpr auto name = "_cdist_forward.out";

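  // Resolve the floating-point dtype of out at runtime and instantiate
  // cdist<CTYPE> for it.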
  ET_SWITCH_FLOAT_TYPES(
      out_type, ctx, name, CTYPE, [&] { cdist<CTYPE>(x1, x2, out, p); });

  return out;
}

} // namespace native
} // namespace executor
} // namespace torch