xref: /aosp_15_r20/external/pytorch/torch/csrc/api/include/torch/nn/options/adaptive.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <torch/arg.h>
4 #include <torch/csrc/Export.h>
5 #include <torch/types.h>
6 
7 namespace torch {
8 namespace nn {
9 
10 /// Options for the `AdaptiveLogSoftmaxWithLoss` module.
11 ///
12 /// Example:
13 /// ```
14 /// AdaptiveLogSoftmaxWithLoss model(AdaptiveLogSoftmaxWithLossOptions(8, 10,
15 /// {4, 8}).div_value(2.).head_bias(true));
16 /// ```
17 struct TORCH_API AdaptiveLogSoftmaxWithLossOptions {
18   /* implicit */ AdaptiveLogSoftmaxWithLossOptions(
19       int64_t in_features,
20       int64_t n_classes,
21       std::vector<int64_t> cutoffs);
22 
23   /// Number of features in the input tensor
24   TORCH_ARG(int64_t, in_features);
25 
26   /// Number of classes in the dataset
27   TORCH_ARG(int64_t, n_classes);
28 
29   /// Cutoffs used to assign targets to their buckets
30   TORCH_ARG(std::vector<int64_t>, cutoffs);
31 
32   /// value used as an exponent to compute sizes of the clusters. Default: 4.0
33   TORCH_ARG(double, div_value) = 4.;
34 
35   /// If ``true``, adds a bias term to the 'head' of
36   /// the adaptive softmax. Default: false
37   TORCH_ARG(bool, head_bias) = false;
38 };
39 
40 } // namespace nn
41 } // namespace torch
42