// xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/Distance.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <ATen/native/DispatchStub.h>
4 
5 namespace at {
6 class Tensor;
7 
8 namespace native {
9 
10 using pdist_forward_fn = void(*)(Tensor&, const Tensor&, const double p);
11 using pdist_backward_fn = void(*)(Tensor&, const Tensor&, const Tensor&, const double p, const Tensor&);
12 using cdist_fn = void(*)(Tensor&, const Tensor&, const Tensor&, const double p);
13 using cdist_backward_fn = void(*)(Tensor&, const Tensor&, const Tensor&, const Tensor&, const double p, const Tensor&);
14 
15 DECLARE_DISPATCH(pdist_forward_fn, pdist_forward_stub);
16 DECLARE_DISPATCH(pdist_backward_fn, pdist_backward_stub);
17 DECLARE_DISPATCH(cdist_fn, cdist_stub);
18 DECLARE_DISPATCH(cdist_backward_fn, cdist_backward_stub);
19 
20 }} // namespace at::native
21