1 #pragma once 2 3 #include <torch/arg.h> 4 #include <torch/csrc/Export.h> 5 #include <torch/types.h> 6 7 namespace torch { 8 namespace nn { 9 10 /// Options for the `CosineSimilarity` module. 11 /// 12 /// Example: 13 /// ``` 14 /// CosineSimilarity model(CosineSimilarityOptions().dim(0).eps(0.5)); 15 /// ``` 16 struct TORCH_API CosineSimilarityOptions { 17 /// Dimension where cosine similarity is computed. Default: 1 18 TORCH_ARG(int64_t, dim) = 1; 19 /// Small value to avoid division by zero. Default: 1e-8 20 TORCH_ARG(double, eps) = 1e-8; 21 }; 22 23 namespace functional { 24 /// Options for `torch::nn::functional::cosine_similarity`. 25 /// 26 /// See the documentation for `torch::nn::CosineSimilarityOptions` class to 27 /// learn what arguments are supported. 28 /// 29 /// Example: 30 /// ``` 31 /// namespace F = torch::nn::functional; 32 /// F::cosine_similarity(input1, input2, 33 /// F::CosineSimilarityFuncOptions().dim(1)); 34 /// ``` 35 using CosineSimilarityFuncOptions = CosineSimilarityOptions; 36 } // namespace functional 37 38 // ============================================================================ 39 40 /// Options for the `PairwiseDistance` module. 41 /// 42 /// Example: 43 /// ``` 44 /// PairwiseDistance 45 /// model(PairwiseDistanceOptions().p(3).eps(0.5).keepdim(true)); 46 /// ``` 47 struct TORCH_API PairwiseDistanceOptions { 48 /// The norm degree. Default: 2 49 TORCH_ARG(double, p) = 2.0; 50 /// Small value to avoid division by zero. Default: 1e-6 51 TORCH_ARG(double, eps) = 1e-6; 52 /// Determines whether or not to keep the vector dimension. Default: false 53 TORCH_ARG(bool, keepdim) = false; 54 }; 55 56 namespace functional { 57 /// Options for `torch::nn::functional::pairwise_distance`. 58 /// 59 /// See the documentation for `torch::nn::PairwiseDistanceOptions` class to 60 /// learn what arguments are supported. 61 /// 62 /// Example: 63 /// ``` 64 /// namespace F = torch::nn::functional; 65 /// F::pairwise_distance(input1, input2, F::PairwiseDistanceFuncOptions().p(1)); 66 /// ``` 67 using PairwiseDistanceFuncOptions = PairwiseDistanceOptions; 68 } // namespace functional 69 70 } // namespace nn 71 } // namespace torch 72