xref: /aosp_15_r20/external/pytorch/torch/csrc/api/src/nn/modules/distance.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/nn/modules/distance.h>
2 
3 namespace F = torch::nn::functional;
4 
5 namespace torch {
6 namespace nn {
7 
CosineSimilarityImpl(const CosineSimilarityOptions & options_)8 CosineSimilarityImpl::CosineSimilarityImpl(
9     const CosineSimilarityOptions& options_)
10     : options(options_) {}
11 
reset()12 void CosineSimilarityImpl::reset() {}
13 
pretty_print(std::ostream & stream) const14 void CosineSimilarityImpl::pretty_print(std::ostream& stream) const {
15   stream << std::boolalpha << "torch::nn::CosineSimilarity"
16          << "(dim=" << options.dim() << ", eps=" << options.eps() << ")";
17 }
18 
// Computes the cosine similarity between `x1` and `x2` by delegating to
// F::detail::cosine_similarity with the dim and eps stored in `options`.
Tensor CosineSimilarityImpl::forward(const Tensor& x1, const Tensor& x2) {
  const auto reduce_dim = options.dim();
  const auto epsilon = options.eps();
  return F::detail::cosine_similarity(x1, x2, reduce_dim, epsilon);
}
22 
23 // ============================================================================
24 
PairwiseDistanceImpl(const PairwiseDistanceOptions & options_)25 PairwiseDistanceImpl::PairwiseDistanceImpl(
26     const PairwiseDistanceOptions& options_)
27     : options(options_) {}
28 
reset()29 void PairwiseDistanceImpl::reset() {}
30 
pretty_print(std::ostream & stream) const31 void PairwiseDistanceImpl::pretty_print(std::ostream& stream) const {
32   stream << std::boolalpha << "torch::nn::PairwiseDistance"
33          << "(p=" << options.p() << ", eps=" << options.eps()
34          << ", keepdim=" << options.keepdim() << ")";
35 }
36 
// Computes the pairwise distance between `x1` and `x2` by delegating to
// F::detail::pairwise_distance with the p, eps and keepdim stored in
// `options` (p is presumably the norm order -- see the functional's docs).
Tensor PairwiseDistanceImpl::forward(const Tensor& x1, const Tensor& x2) {
  const auto norm_order = options.p();
  const auto epsilon = options.eps();
  const auto keep_dim = options.keepdim();
  return F::detail::pairwise_distance(x1, x2, norm_order, epsilon, keep_dim);
}
41 
42 } // namespace nn
43 } // namespace torch
44