1 #include <torch/nn/modules/distance.h>
2
3 namespace F = torch::nn::functional;
4
5 namespace torch {
6 namespace nn {
7
CosineSimilarityImpl(const CosineSimilarityOptions & options_)8 CosineSimilarityImpl::CosineSimilarityImpl(
9 const CosineSimilarityOptions& options_)
10 : options(options_) {}
11
reset()12 void CosineSimilarityImpl::reset() {}
13
pretty_print(std::ostream & stream) const14 void CosineSimilarityImpl::pretty_print(std::ostream& stream) const {
15 stream << std::boolalpha << "torch::nn::CosineSimilarity"
16 << "(dim=" << options.dim() << ", eps=" << options.eps() << ")";
17 }
18
forward(const Tensor & x1,const Tensor & x2)19 Tensor CosineSimilarityImpl::forward(const Tensor& x1, const Tensor& x2) {
20 return F::detail::cosine_similarity(x1, x2, options.dim(), options.eps());
21 }
22
23 // ============================================================================
24
PairwiseDistanceImpl(const PairwiseDistanceOptions & options_)25 PairwiseDistanceImpl::PairwiseDistanceImpl(
26 const PairwiseDistanceOptions& options_)
27 : options(options_) {}
28
reset()29 void PairwiseDistanceImpl::reset() {}
30
pretty_print(std::ostream & stream) const31 void PairwiseDistanceImpl::pretty_print(std::ostream& stream) const {
32 stream << std::boolalpha << "torch::nn::PairwiseDistance"
33 << "(p=" << options.p() << ", eps=" << options.eps()
34 << ", keepdim=" << options.keepdim() << ")";
35 }
36
forward(const Tensor & x1,const Tensor & x2)37 Tensor PairwiseDistanceImpl::forward(const Tensor& x1, const Tensor& x2) {
38 return F::detail::pairwise_distance(
39 x1, x2, options.p(), options.eps(), options.keepdim());
40 }
41
42 } // namespace nn
43 } // namespace torch
44