1 #pragma once 2 3 #include <torch/arg.h> 4 #include <torch/csrc/Export.h> 5 #include <torch/types.h> 6 7 namespace torch { 8 namespace nn { 9 10 /// Options for the `AdaptiveLogSoftmaxWithLoss` module. 11 /// 12 /// Example: 13 /// ``` 14 /// AdaptiveLogSoftmaxWithLoss model(AdaptiveLogSoftmaxWithLossOptions(8, 10, 15 /// {4, 8}).div_value(2.).head_bias(true)); 16 /// ``` 17 struct TORCH_API AdaptiveLogSoftmaxWithLossOptions { 18 /* implicit */ AdaptiveLogSoftmaxWithLossOptions( 19 int64_t in_features, 20 int64_t n_classes, 21 std::vector<int64_t> cutoffs); 22 23 /// Number of features in the input tensor 24 TORCH_ARG(int64_t, in_features); 25 26 /// Number of classes in the dataset 27 TORCH_ARG(int64_t, n_classes); 28 29 /// Cutoffs used to assign targets to their buckets 30 TORCH_ARG(std::vector<int64_t>, cutoffs); 31 32 /// value used as an exponent to compute sizes of the clusters. Default: 4.0 33 TORCH_ARG(double, div_value) = 4.; 34 35 /// If ``true``, adds a bias term to the 'head' of 36 /// the adaptive softmax. Default: false 37 TORCH_ARG(bool, head_bias) = false; 38 }; 39 40 } // namespace nn 41 } // namespace torch 42