#pragma once

#include <torch/nn/cloneable.h>
#include <torch/nn/functional/activation.h>
#include <torch/nn/module.h>
#include <torch/nn/modules/container/modulelist.h>
#include <torch/nn/modules/container/sequential.h>
#include <torch/nn/modules/linear.h>
#include <torch/nn/options/adaptive.h>

namespace torch {
namespace nn {

/// The output of a single invocation of an `AdaptiveLogSoftmaxWithLoss`
/// module's `forward()` method.
struct TORCH_API ASMoutput {
  ASMoutput(Tensor output_, double loss_);

  /// Tensor containing the computed target log probabilities for each example
  Tensor output;

  /// Scalar representing the computed negative log likelihood loss
  double loss;
};

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ AdaptiveLogSoftmaxWithLoss
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

/// Efficient softmax approximation as described in
/// `Efficient softmax approximation for GPUs`_ by Edouard Grave, Armand
/// Joulin, Moustapha Cissé, David Grangier, and Hervé Jégou.
/// See
/// https://pytorch.org/docs/main/nn.html#torch.nn.AdaptiveLogSoftmaxWithLoss
/// to learn about the exact behavior of this module.
///
/// See the documentation for `torch::nn::AdaptiveLogSoftmaxWithLossOptions`
/// class to learn what constructor arguments are supported for this module.
///
/// Example:
/// ```
/// AdaptiveLogSoftmaxWithLoss model(AdaptiveLogSoftmaxWithLossOptions(8, 10,
/// {4, 8}).div_value(2.).head_bias(true));
/// ```
class TORCH_API AdaptiveLogSoftmaxWithLossImpl
    : public Cloneable<AdaptiveLogSoftmaxWithLossImpl> {
 public:
  AdaptiveLogSoftmaxWithLossImpl(
      int64_t in_features,
      int64_t n_classes,
      std::vector<int64_t> cutoffs)
      : AdaptiveLogSoftmaxWithLossImpl(AdaptiveLogSoftmaxWithLossOptions(
            in_features,
            n_classes,
            cutoffs)) {}

  explicit AdaptiveLogSoftmaxWithLossImpl(
      AdaptiveLogSoftmaxWithLossOptions options_);

  ASMoutput forward(const Tensor& input, const Tensor& target);

  void reset() override;

  void reset_parameters();

  /// Pretty prints the `AdaptiveLogSoftmaxWithLoss` module into the given
  /// `stream`.
  void pretty_print(std::ostream& stream) const override;

  /// Given the input tensor and the output of `head`, computes the log of the
  /// full distribution.
  Tensor _get_full_log_prob(const Tensor& input, const Tensor& head_output);

  /// Computes log probabilities for all `n_classes`.
  Tensor log_prob(const Tensor& input);

  /// This is equivalent to `log_prob(input).argmax(1)` but is more efficient
  /// in some cases.
  Tensor predict(const Tensor& input);

  /// The options with which this `Module` was constructed.
  AdaptiveLogSoftmaxWithLossOptions options;

  /// Cutoffs used to assign targets to their buckets. Must be an ordered
  /// sequence of integers sorted in increasing order.
  std::vector<int64_t> cutoffs;

  int64_t shortlist_size;

  /// Number of clusters.
  int64_t n_clusters;

  /// Output size of head classifier.
  int64_t head_size;

  Linear head = nullptr;

  ModuleList tail;
};
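
// Usage sketch (hedged: the tensor shapes below mirror the Python
// `torch.nn.AdaptiveLogSoftmaxWithLoss` documentation and are stated here as
// assumptions, not guarantees made by this header):
//
//   // 8 input features, 10 classes; targets in [0, 4) hit the head's
//   // shortlist, [4, 8) the first tail cluster, and [8, 10) the second.
//   AdaptiveLogSoftmaxWithLoss model(
//       AdaptiveLogSoftmaxWithLossOptions(8, 10, {4, 8}).div_value(2.));
//
//   auto input = torch::randn({3, 8});                      // (batch, in_features)
//   auto target = torch::randint(0, 10, {3}, torch::kLong); // in [0, n_classes)
//
//   auto out = model->forward(input, target); // torch::nn::ASMoutput
//   auto target_log_probs = out.output; // log-probability of each target
//   double nll = out.loss;              // negative log likelihood loss
//
//   auto full = model->log_prob(input); // (batch, n_classes) log-probabilities
//   auto pred = model->predict(input);  // same result as full.argmax(1)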

/// A `ModuleHolder` subclass for `AdaptiveLogSoftmaxWithLossImpl`.
/// See the documentation for `AdaptiveLogSoftmaxWithLossImpl` class to learn
/// what methods it provides, and examples of how to use
/// `AdaptiveLogSoftmaxWithLoss` with
/// `torch::nn::AdaptiveLogSoftmaxWithLossOptions`. See the documentation for
/// `ModuleHolder` to learn about PyTorch's module storage semantics.
TORCH_MODULE(AdaptiveLogSoftmaxWithLoss);

} // namespace nn
} // namespace torch