xref: /aosp_15_r20/external/pytorch/torch/csrc/api/include/torch/nn/options/distance.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <torch/arg.h>
4 #include <torch/csrc/Export.h>
5 #include <torch/types.h>
6 
7 namespace torch {
8 namespace nn {
9 
10 /// Options for the `CosineSimilarity` module; bundles the `dim` and `eps` arguments.
11 ///
12 /// Example (options are set through chained calls named after each argument):
13 /// ```
14 /// CosineSimilarity model(CosineSimilarityOptions().dim(0).eps(0.5));
15 /// ```
16 struct TORCH_API CosineSimilarityOptions {
17   /// Dimension along which the cosine similarity is computed. Default: 1
18   TORCH_ARG(int64_t, dim) = 1;
19   /// Small value used to avoid division by zero. Default: 1e-8
20   TORCH_ARG(double, eps) = 1e-8;
21 };
22 
23 namespace functional {
24 /// Options for `torch::nn::functional::cosine_similarity`.
25 ///
26 /// This is a type alias of `torch::nn::CosineSimilarityOptions`, so it takes
27 /// exactly the same arguments (`dim` and `eps`) as that class.
28 ///
29 /// Example:
30 /// ```
31 /// namespace F = torch::nn::functional;
32 /// F::cosine_similarity(input1, input2,
33 /// F::CosineSimilarityFuncOptions().dim(1));
34 /// ```
35 using CosineSimilarityFuncOptions = CosineSimilarityOptions;
36 } // namespace functional
37 
38 // ============================================================================
39 
40 /// Options for the `PairwiseDistance` module; bundles the `p`, `eps` and `keepdim` arguments.
41 ///
42 /// Example (options are set through chained calls named after each argument):
43 /// ```
44 /// PairwiseDistance
45 /// model(PairwiseDistanceOptions().p(3).eps(0.5).keepdim(true));
46 /// ```
47 struct TORCH_API PairwiseDistanceOptions {
48   /// The norm degree, given as a floating-point value. Default: 2
49   TORCH_ARG(double, p) = 2.0;
50   /// Small value used to avoid division by zero. Default: 1e-6
51   TORCH_ARG(double, eps) = 1e-6;
52   /// Whether to keep the vector dimension (true) or not (false). Default: false
53   TORCH_ARG(bool, keepdim) = false;
54 };
55 
56 namespace functional {
57 /// Options for `torch::nn::functional::pairwise_distance`.
58 ///
59 /// This is a type alias of `torch::nn::PairwiseDistanceOptions`, so it takes
60 /// exactly the same arguments (`p`, `eps` and `keepdim`) as that class.
61 ///
62 /// Example:
63 /// ```
64 /// namespace F = torch::nn::functional;
65 /// F::pairwise_distance(input1, input2, F::PairwiseDistanceFuncOptions().p(1));
66 /// ```
67 using PairwiseDistanceFuncOptions = PairwiseDistanceOptions;
68 } // namespace functional
69 
70 } // namespace nn
71 } // namespace torch
72