xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/Distance.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
2 #include <ATen/core/Tensor.h>
3 #include <ATen/core/grad_mode.h>
4 #include <ATen/ExpandUtils.h>
5 #include <ATen/NamedTensorUtils.h>
6 #include <ATen/TensorOperators.h>
7 #include <ATen/native/Distance.h>
8 #include <c10/util/accumulate.h>
9 
10 #ifndef AT_PER_OPERATOR_HEADERS
11 #include <ATen/Functions.h>
12 #include <ATen/NativeFunctions.h>
13 #else
14 #include <ATen/ops/_cdist_backward_native.h>
15 #include <ATen/ops/_cdist_forward.h>
16 #include <ATen/ops/_cdist_forward_native.h>
17 #include <ATen/ops/_euclidean_dist.h>
18 #include <ATen/ops/_euclidean_dist_native.h>
19 #include <ATen/ops/_pdist_backward_native.h>
20 #include <ATen/ops/_pdist_forward.h>
21 #include <ATen/ops/_pdist_forward_native.h>
22 #include <ATen/ops/cat.h>
23 #include <ATen/ops/cdist_native.h>
24 #include <ATen/ops/cosine_similarity_native.h>
25 #include <ATen/ops/empty.h>
26 #include <ATen/ops/empty_like.h>
27 #include <ATen/ops/linalg_vector_norm.h>
28 #include <ATen/ops/norm.h>
29 #include <ATen/ops/ones_like.h>
30 #include <ATen/ops/pairwise_distance_native.h>
31 #include <ATen/ops/pdist_native.h>
32 #include <ATen/ops/pow.h>
33 #include <ATen/ops/result_type.h>
34 #include <ATen/ops/sum.h>
35 #include <ATen/ops/zeros.h>
36 #include <ATen/ops/zeros_like.h>
37 
38 #include <utility>
39 #endif
40 
41 namespace at::native {
42 
43 DEFINE_DISPATCH(pdist_forward_stub);
44 DEFINE_DISPATCH(pdist_backward_stub);
45 DEFINE_DISPATCH(cdist_stub);
46 DEFINE_DISPATCH(cdist_backward_stub);
47 
// p-norm distance between x1 and x2, reduced over the last dimension of the
// broadcast result.
//
// Note: `eps` is added element-wise to the difference (x1 - x2 + eps) before
// the norm is taken; it is not a clamp on the result.
Tensor pairwise_distance(const Tensor& x1, const Tensor& x2, double p, double eps, bool keepdim) {
  // Either operand may be broadcast against the other, so reduce over the
  // last dimension of whichever input has the higher rank.
  const int64_t max_rank = (x2.dim() < x1.dim()) ? x1.dim() : x2.dim();
  const int64_t last_dim = max_rank - 1;
  const Tensor shifted_diff = x1 - x2 + eps;
  return at::norm(shifted_diff, p, last_dim, keepdim);
}
56 
57 // This is to guarantee that the contiguous memory is passed to the backward pass
pdist(const Tensor & self,const double p)58 Tensor pdist(const Tensor& self, const double p) {
59   TORCH_CHECK(self.dim() == 2,
60       "pdist only supports 2D tensors, got: ", self.dim(), "D");
61   TORCH_CHECK(at::isFloatingType(self.scalar_type()), "pdist only supports floating-point dtypes");
62   TORCH_CHECK(p >= 0, "pdist only supports non-negative p values");
63   return at::_pdist_forward(self.contiguous(), p);
64 }
65 
/**
 * First part of the matrix-multiplication form of the Euclidean distance.
 *
 * Each operand is augmented with two extra trailing columns:
 *   lhs = [ -2 * x1, ||x1||^2, 1        ]
 *   rhs = [      x2, 1,        ||x2||^2 ]
 * so that a row of lhs dotted with a row of rhs equals
 *   -2 <x1, x2> + ||x1||^2 + ||x2||^2 = ||x1 - x2||^2,
 * and lhs @ rhs^T yields all squared distances with a single matmul.
 * Negative entries (only possible through floating-point cancellation) are
 * clamped to 0 before the square root.
 *
 * The computation is split out like this to simplify dealing with
 * subgradients in the backward step.
 */
Tensor _euclidean_dist(const Tensor& x1, const Tensor& x2) {
  Tensor x1_sq_norm = x1.pow(2).sum(-1, true);
  Tensor x1_ones = at::ones_like(x1_sq_norm, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
  Tensor x2_sq_norm = x2.pow(2).sum(-1, true);
  Tensor x2_ones = at::ones_like(x2_sq_norm, LEGACY_CONTIGUOUS_MEMORY_FORMAT);

  Tensor lhs = at::cat({x1.mul(-2), std::move(x1_sq_norm), std::move(x1_ones)}, -1);
  Tensor rhs = at::cat({x2, std::move(x2_ones), std::move(x2_sq_norm)}, -1);

  Tensor dist = lhs.matmul(rhs.mT());
  dist.clamp_min_(0);
  dist.sqrt_();
  return dist;
}
80 
cdist_impl(const Tensor & x1,const Tensor & x2,const double p,std::optional<int64_t> compute_mode)81 static Tensor cdist_impl(const Tensor& x1, const Tensor& x2, const double p, std::optional<int64_t> compute_mode) {
82   TORCH_CHECK(at::isFloatingType(x1.scalar_type()), "cdist only supports floating-point dtypes, X1 got: ", x1.scalar_type());
83   auto device1 = x1.device().type();
84   TORCH_CHECK(at::isFloatingType(x2.scalar_type()), "cdist only supports floating-point dtypes, X2 got: ", x2.scalar_type());
85   auto device2 = x2.device().type();
86   TORCH_CHECK(p >= 0, "cdist only supports non-negative p values");
87   TORCH_CHECK(device1 == device2, "X1 and X2 must have the same device type. X1: ", device1, " X2: ", device2);
88   // TODO: This is bad; this test should apply universally
89   TORCH_CHECK(!x1.is_cuda() || x1.get_device() == x2.get_device(), "device of X1 (", x1.get_device(), ") must match device of X2 (", x2.get_device(), ")");
90   SymInt c1 = x1.sym_size(-1);
91   SymInt c2 = x2.sym_size(-1);
92   // 0 - default value. If p = 2 and r1 > 25 or r2 > 25 (these values are based on performance metrics),
93   // it will try to compute distance using matrix multiplication approach
94   // 1 - force to use matrix multiplication for p = 2
95   // 2 - do not use matrix multiplication for p = 2
96   int64_t mode = compute_mode.value_or(0);
97   TORCH_CHECK(mode >= 0 && mode <= 2, "possible modes: 0, 1, 2, but was: ", mode);
98 
99   SymInt r1 = x1.sym_size(-2);
100   SymInt r2 = x2.sym_size(-2);
101 
102   // See Note [cdist relies on cdist_impl redispatching]
103   // Keep this condition in sync with the condition at the Note
104   if (!(p == 2 && (mode == 1 || (mode == 0 && (r1 > 25 || r2 > 25))))) {
105     TORCH_CHECK(device1 == kCPU || device1 == kCUDA || device1 == kXPU, "cdist only supports CPU, XPU and CUDA devices, X1 got: ", device1);
106     TORCH_CHECK(device2 == kCPU || device2 == kCUDA || device2 == kXPU, "cdist only supports CPU, XPU and CUDA devices, X2 got: ", device2);
107   }
108 
109   auto dim1 = x1.dim();
110   auto dim2 = x2.dim();
111 
112   //For batch calculation we expand all dimensions(except the last two) to one, with size that equals to product of them.
113   //The last two dimensions will stay the same
114   SymIntArrayRef batch_tensor1(x1.sym_sizes().data(), dim1 - 2);
115   SymIntArrayRef batch_tensor2(x2.sym_sizes().data(), dim2 - 2);
116   std::vector<SymInt> expand_batch_portion = infer_size_symint(batch_tensor1, batch_tensor2);
117   std::vector<SymInt> tensor1_expand_size(expand_batch_portion);
118   tensor1_expand_size.insert(tensor1_expand_size.end(), {r1, c1});
119   std::vector<SymInt> tensor2_expand_size(expand_batch_portion);
120   tensor2_expand_size.insert(tensor2_expand_size.end(), {r2, c2});
121 
122   const SymInt expand_batch_product = c10::multiply_integers(expand_batch_portion);
123   std::vector<SymInt> tensor1_view{expand_batch_product, r1, c1};
124   std::vector<SymInt> tensor2_view{expand_batch_product, r2, c2};
125 
126   Tensor tensor1_expanded = x1.expand_symint(tensor1_expand_size).contiguous().view_symint(tensor1_view);
127   Tensor tensor2_expanded = x2.expand_symint(tensor2_expand_size).contiguous().view_symint(tensor2_view);
128 
129   std::vector<SymInt> output_shape(std::move(expand_batch_portion));
130   output_shape.insert(output_shape.end(), {r1, r2});
131 
132   Tensor result;
133   if (r1 == 0 || r2 == 0 || expand_batch_product == 0) {
134     result = at::empty_symint(output_shape, x1.options());
135   } else if (c1 == 0) {
136     result = at::zeros_symint(output_shape, x1.options());
137   } else if (p == 2 && (mode == 1 || (mode == 0 && (r1 > 25 || r2 > 25)))) {
138     // See Note [cdist relies on cdist_impl redispatching]
139     // Keep the condition above in sync with the condition at the Note
140     Tensor dist = (expand_batch_product == 1) ? at::_euclidean_dist(x1, x2) :
141                   at::_euclidean_dist(tensor1_expanded, tensor2_expanded);
142     result = dist.view_symint(output_shape);
143   } else {
144     result = at::empty_symint(output_shape, x1.options());
145     cdist_stub(device1, result, tensor1_expanded, tensor2_expanded, p);
146   }
147   return result;
148 }
149 
// Public entry point for cdist.
//
// Validates the inputs, runs the computation with named-tensor names
// stripped, and re-attaches the output names afterwards. The work is routed
// either directly to cdist_impl (p == 2 matrix-multiplication path with
// non-empty inputs) or through at::_cdist_forward, the version with explicit
// autograd. See Note [cdist relies on cdist_impl redispatching] below.
Tensor cdist(const Tensor& x1, const Tensor& x2, const double p, std::optional<int64_t> compute_mode) {
  TORCH_CHECK(x1.dim() >= 2, "cdist only supports at least 2D tensors, X1 got: ", x1.dim(), "D");
  TORCH_CHECK(x2.dim() >= 2, "cdist only supports at least 2D tensors, X2 got: ", x2.dim(), "D");
  TORCH_CHECK(x1.sym_size(-1) == x2.sym_size(-1), "X1 and X2 must have the same number of columns. X1: ", x1.sym_size(-1), " X2: ", x2.sym_size(-1));
  auto maybe_outnames = namedinference::compute_cdist_outnames(x1, x2);

  Tensor result;
  {
    NoNamesGuard guard;
    const SymInt r1 = x1.sym_size(-2);
    const SymInt r2 = x2.sym_size(-2);
    const int64_t mode = compute_mode.value_or(0);

    // Special case for empty input: always call the version with explicit
    // autograd (_cdist_forward) to ensure the graph is properly connected.
    const bool has_elements = x1.sym_numel() != 0 && x2.sym_numel() != 0;

    // Note [cdist relies on cdist_impl redispatching]
    // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
    // When p == 2 and the matrix-multiplication path is selected, cdist_impl
    // is called directly (not through _cdist_forward) so that PyTorch can
    // figure out the backward pass itself from the ops cdist_impl uses.
    // Keep this condition in sync with the "See Note" references.
    const bool use_mm = has_elements && p == 2 &&
        (mode == 1 || (mode == 0 && (r1 > 25 || r2 > 25)));

    result = use_mm ? cdist_impl(x1, x2, p, compute_mode)
                    : at::_cdist_forward(x1, x2, p, compute_mode);
  }
  namedinference::propagate_names_if_nonempty(result, maybe_outnames);
  return result;
}
177 
// Forward entry point with explicit autograd for cdist.
//
// Validates the inputs, runs cdist_impl with named-tensor names stripped,
// and then re-attaches the output names computed from x1 and x2.
//
// The column-count check uses sym_size(), matching cdist() and the
// SymInt-based cdist_impl() it calls. Plain size() would require concrete
// (non-symbolic) sizes just to validate the inputs. For concrete tensors the
// check and its message are unchanged.
Tensor _cdist_forward(const Tensor& x1, const Tensor& x2, const double p, std::optional<int64_t> compute_mode) {
  TORCH_CHECK(x1.dim() >= 2, "cdist only supports at least 2D tensors, X1 got: ", x1.dim(), "D");
  TORCH_CHECK(x2.dim() >= 2, "cdist only supports at least 2D tensors, X2 got: ", x2.dim(), "D");
  TORCH_CHECK(x1.sym_size(-1) == x2.sym_size(-1), "X1 and X2 must have the same number of columns. X1: ", x1.sym_size(-1), " X2: ", x2.sym_size(-1));
  auto maybe_outnames = namedinference::compute_cdist_outnames(x1, x2);
  auto result = [&]() {
    NoNamesGuard guard;
    return cdist_impl(x1, x2, p, compute_mode);
  }();
  namedinference::propagate_names_if_nonempty(result, maybe_outnames);
  return result;
}
190 
// Explicit backward for the cdist forward: returns a gradient for x1.
//
// _x1 is (*batch1, r1, c1) and _x2 is (*batch2, r2, c2). _cdist and _grad are
// presumably the forward output and the incoming gradient (TODO confirm
// against the autograd formula). The returned gradient has the *broadcast*
// shape of x1; reducing it back to _x1's shape is left to autograd.
Tensor _cdist_backward(const Tensor& _grad, const Tensor& _x1, const Tensor& _x2, const double p, const Tensor& _cdist) {
  // Broadcasting might generate non-contiguous Tensors, so handle it before doing checks
  const int64_t r1 = _x1.size(-2);
  const int64_t c1 = _x1.size(-1);
  const int64_t r2 = _x2.size(-2);
  const int64_t c2 = _x2.size(-1);

  IntArrayRef batch1(_x1.sizes().data(), _x1.dim() - 2);
  IntArrayRef batch2(_x2.sizes().data(), _x2.dim() - 2);
  std::vector<int64_t> batch_shape = infer_size(batch1, batch2);

  std::vector<int64_t> x1_target_shape(batch_shape);
  x1_target_shape.insert(x1_target_shape.end(), {r1, c1});
  std::vector<int64_t> x2_target_shape(batch_shape);
  x2_target_shape.insert(x2_target_shape.end(), {r2, c2});

  // Linearized (flattened) batch size.
  const int64_t batch_size = c10::multiply_integers(batch_shape);

  // Gracefully handle empty Tensors
  if (r1 == 0 || r2 == 0 || c1 == 0 || batch_size == 0) {
    return at::zeros_like(_x1, _x1.options());
  }

  // Expand to the broadcast shape only when it differs, then materialize.
  Tensor x1 = (x1_target_shape != _x1.sizes()) ? _x1.expand(x1_target_shape) : _x1;
  Tensor x2 = (x2_target_shape != _x2.sizes()) ? _x2.expand(x2_target_shape) : _x2;
  x1 = x1.contiguous();
  x2 = x2.contiguous();
  Tensor cdist = _cdist.contiguous();
  Tensor grad = _grad.contiguous();

  const auto device1 = x1.device().type();
  TORCH_CHECK(device1 == kCPU || device1 == kCUDA, "_cdist_backward only supports CPU and CUDA devices, X1 got: ", device1);
  const auto device2 = x2.device().type();
  TORCH_CHECK(device2 == kCPU || device2 == kCUDA, "_cdist_backward only supports CPU and CUDA devices, X2 got: ", device2);

  // After expansion x1's last two sizes are exactly (r1, c1).
  Tensor grad_x1 =
      at::empty({batch_size, r1, c1}, x1.options(), LEGACY_CONTIGUOUS_MEMORY_FORMAT);
  cdist_backward_stub(device1, grad_x1, grad, x1, x2, p, cdist);

  // Use x1.sizes() (the broadcast shape) rather than _x1.sizes(): this
  // gradient does not undo broadcasting; the autograd engine handles that.
  return grad_x1.view(x1.sizes());
}
243 
// Forward entry for pdist on a contiguous CPU or CUDA input.
//
// With n = self.size(0), returns a 1-D tensor of n * (n - 1) / 2 distances.
// The result is empty when n <= 1 and all zeros when self.size(1) == 0.
// Otherwise the values are filled by the per-device pdist_forward_stub.
Tensor _pdist_forward(const Tensor& self, const double p) {
  TORCH_CHECK(self.is_contiguous(), "_pdist_forward requires contiguous input");
  const auto device_type = self.device().type();
  TORCH_CHECK(device_type == kCPU || device_type == kCUDA, "_pdist_forward only supports CPU and CUDA devices, got: ", device_type);

  Tensor result = at::empty({0}, self.options(), LEGACY_CONTIGUOUS_MEMORY_FORMAT);
  const int64_t n = self.size(0);
  if (n <= 1) {
    // Fewer than two rows: there are no pairs.
    result.resize_({0});
    return result;
  }

  const int64_t num_pairs = n * (n - 1) / 2;
  result.resize_({num_pairs});
  if (self.size(1) == 0) {
    result.fill_(0);
  } else {
    pdist_forward_stub(device_type, result, self, p);
  }
  return result;
}
263 
// Backward entry for pdist on CPU or CUDA.
//
// Requires `self` and `pdist` (presumably the forward output) to be
// contiguous. Allocates a tensor shaped like `self` and has the per-device
// pdist_backward_stub write the gradient into it.
Tensor _pdist_backward(const Tensor& grad, const Tensor& self, const double p, const Tensor& pdist) {
  TORCH_CHECK(self.is_contiguous(), "_pdist_backward requires self to be contiguous");
  TORCH_CHECK(pdist.is_contiguous(), "_pdist_backward requires pdist to be contiguous");
  const auto device_type = self.device().type();
  const bool supported_device = device_type == kCPU || device_type == kCUDA;
  TORCH_CHECK(supported_device, "_pdist_backward only supports CPU and CUDA devices, got: ", device_type);
  Tensor grad_self = at::empty_like(self, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
  pdist_backward_stub(device_type, grad_self, grad, self, p, pdist);
  return grad_self;
}
273 
Tensor cosine_similarity(const Tensor& x1_, const Tensor& x2_, int64_t dim, double eps) {
  /*
   * cosine_similarity(x1, x2) = <x1, x2> / (||x1|| * ||x2||), reduced over `dim`.
   *
   * Computed by normalizing each input first and then taking the inner product:
   *   x1_normalized = x1 / max(||x1||, eps)
   *   x2_normalized = x2 / max(||x2||, eps)
   *   result        = <x1_normalized, x2_normalized>
   *
   * An earlier formulation computed num = <x1, x2> and
   * denom = max(||x1|| * ||x2||, eps) and returned num / denom. That could
   * lose precision (or overflow) in <x1, x2> and in ||x1|| * ||x2|| when the
   * norms are large. As a result it could return
   * |cosine_similarity(x1, x2)| > 1.0.
   *
   * Normalizing first never forms those products explicitly. The norms
   * themselves become the only source of floating-point imprecision, and the
   * result is kept within [-1.0, 1.0].
   */

  const auto common_dtype = at::result_type(x1_, x2_);
  TORCH_CHECK(at::isFloatingType(common_dtype), "expected common dtype to be floating point, yet common dtype is ", common_dtype);

  // Integral (including bool) inputs are accepted here, but linalg_vector_norm
  // does not accept them, so they are converted to the common floating dtype.
  const auto to_floating = [&common_dtype](const Tensor& t) -> Tensor {
    return c10::isIntegralType(t.scalar_type(), /*includeBool=*/true) ? t.to(common_dtype) : t;
  };
  auto x1_t = to_floating(x1_);
  auto x2_t = to_floating(x2_);
  auto [x1, x2] = expand_outplace(x1_t, x2_t);

  // The norms are cloned because they are clamped in place below. The clamp
  // runs under NoGradGuard, so it is not recorded by autograd, while the
  // norm computation itself still connects the result to x1 and x2.
  auto x1_norm = at::linalg_vector_norm(*x1, 2, /*dim=*/dim, /*keepdim=*/true).clone();
  auto x2_norm = at::linalg_vector_norm(*x2, 2, /*dim=*/dim, /*keepdim=*/true).clone();
  {
    at::NoGradGuard no_grad;
    x1_norm.clamp_min_(eps);
    x2_norm.clamp_min_(eps);
  }

  const Tensor x1_unit = *x1 / x1_norm;
  const Tensor x2_unit = *x2 / x2_norm;
  return (x1_unit * x2_unit).sum(dim);
}
331 
332 }  // namespace at::native
333