#pragma once

#include <torch/nn/cloneable.h>
#include <torch/nn/functional/activation.h>
#include <torch/nn/module.h>
#include <torch/nn/modules/container/modulelist.h>
#include <torch/nn/modules/container/sequential.h>
#include <torch/nn/modules/linear.h>
#include <torch/nn/options/adaptive.h>

namespace torch {
namespace nn {

/// The output of a single invocation of an AdaptiveLogSoftmaxWithLoss
/// module's `forward()` method.
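///
/// A minimal usage sketch (assuming an already-constructed
/// `AdaptiveLogSoftmaxWithLoss` module `model` and suitable `input` and
/// `target` tensors):
/// ```
/// ASMoutput result = model->forward(input, target);
/// torch::Tensor target_log_probs = result.output; // one value per example
/// double nll_loss = result.loss;
/// ```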
struct TORCH_API ASMoutput {
  ASMoutput(Tensor output_, double loss_);

  /// Tensor containing computed target log probabilities for each example
  Tensor output;

  /// Scalar representing the computed negative log likelihood loss
  double loss;
};

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ AdaptiveLogSoftmaxWithLoss
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

/// Efficient softmax approximation as described in
/// `Efficient softmax approximation for GPUs`_ by Edouard Grave, Armand Joulin,
/// Moustapha Cissé, David Grangier, and Hervé Jégou.
/// See
/// https://pytorch.org/docs/main/nn.html#torch.nn.AdaptiveLogSoftmaxWithLoss
/// to learn about the exact behavior of this module.
///
/// See the documentation for the `torch::nn::AdaptiveLogSoftmaxWithLossOptions`
/// class to learn what constructor arguments are supported for this module.
///
/// Example:
/// ```
/// AdaptiveLogSoftmaxWithLoss model(AdaptiveLogSoftmaxWithLossOptions(8, 10,
/// {4, 8}).div_value(2.).head_bias(true));
/// ```
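///
/// A forward pass returns an `ASMoutput` (a sketch; it assumes `input` of
/// shape `[batch_size, in_features]` and `target` holding class indices below
/// `n_classes`, and the concrete values are illustrative only):
/// ```
/// torch::Tensor input = torch::randn({3, 8});
/// torch::Tensor target = torch::tensor({0, 5, 9}, torch::kLong);
/// ASMoutput result = model->forward(input, target);
/// ```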
class TORCH_API AdaptiveLogSoftmaxWithLossImpl
    : public Cloneable<AdaptiveLogSoftmaxWithLossImpl> {
 public:
  AdaptiveLogSoftmaxWithLossImpl(
      int64_t in_features,
      int64_t n_classes,
      std::vector<int64_t> cutoffs)
      : AdaptiveLogSoftmaxWithLossImpl(AdaptiveLogSoftmaxWithLossOptions(
            in_features,
            n_classes,
            cutoffs)) {}

  explicit AdaptiveLogSoftmaxWithLossImpl(
      AdaptiveLogSoftmaxWithLossOptions options_);

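  /// Computes the target log probabilities and the negative log likelihood
  /// loss for `input` against `target`, returning both as an `ASMoutput`.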
  ASMoutput forward(const Tensor& input, const Tensor& target);

  void reset() override;

  void reset_parameters();

  /// Pretty prints the `AdaptiveLogSoftmaxWithLoss` module into the given
  /// `stream`.
  void pretty_print(std::ostream& stream) const override;

  /// Given the input tensor and the output of `head`, computes the log of the
  /// full distribution
  Tensor _get_full_log_prob(const Tensor& input, const Tensor& head_output);

  /// Computes log probabilities for all n_classes
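  /// (the returned tensor has shape `[batch_size, n_classes]`).
  ///
  /// A minimal sketch (assuming `input` of shape `[batch_size, in_features]`):
  /// ```
  /// torch::Tensor all_log_probs = model->log_prob(input);
  /// ```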
  Tensor log_prob(const Tensor& input);

  /// This is equivalent to `log_prob(input).argmax(1)` but is more efficient
  /// in some cases
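  /// (a sketch; both lines below compute per-example class predictions):
  /// ```
  /// torch::Tensor fast = model->predict(input);
  /// torch::Tensor slow = model->log_prob(input).argmax(1);
  /// ```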
  Tensor predict(const Tensor& input);

  /// The options with which this `Module` was constructed
  AdaptiveLogSoftmaxWithLossOptions options;

  /// Cutoffs used to assign targets to their buckets. It should be an ordered
  /// sequence of integers sorted in increasing order
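  /// (e.g. with cutoffs `{4, 8}` and `n_classes = 10`, targets `0..3` fall in
  /// the head shortlist, `4..7` in the first tail cluster, and `8..9` in the
  /// second; these concrete values are illustrative only).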
  std::vector<int64_t> cutoffs;

  /// Number of targets in the shortlist handled directly by the head
  int64_t shortlist_size;

  /// Number of clusters
  int64_t n_clusters;

  /// Output size of head classifier
  int64_t head_size;

  /// Linear layer producing the head output (shortlist logits plus one logit
  /// per tail cluster)
  Linear head = nullptr;

  /// Tail modules, one per cluster
  ModuleList tail;
};

/// A `ModuleHolder` subclass for `AdaptiveLogSoftmaxWithLossImpl`.
/// See the documentation for the `AdaptiveLogSoftmaxWithLossImpl` class to
/// learn what methods it provides, and examples of how to use
/// `AdaptiveLogSoftmaxWithLoss` with
/// `torch::nn::AdaptiveLogSoftmaxWithLossOptions`. See the documentation for
/// `ModuleHolder` to learn about PyTorch's module storage semantics.
TORCH_MODULE(AdaptiveLogSoftmaxWithLoss);

} // namespace nn
} // namespace torch