#include <c10/util/irange.h>
#include <torch/nn/modules/adaptive.h>
#include <torch/nn/options/activation.h>
#include <torch/nn/options/linear.h>

#include <algorithm>
#include <cmath>
#include <set>
#include <utility>

namespace F = torch::nn::functional;

using namespace torch::indexing;

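// Minimal usage sketch (shapes and values are illustrative): construct the
// module from in_features, n_classes, and an increasing cutoff sequence,
// then call forward(input, target) to get per-sample log-probabilities and
// the mean loss.
//
//   AdaptiveLogSoftmaxWithLoss asfm(AdaptiveLogSoftmaxWithLossOptions(
//       /*in_features=*/8, /*n_classes=*/10, /*cutoffs=*/{4, 8}));
//   torch::Tensor input = torch::randn({4, 8});
//   torch::Tensor target = torch::randint(0, 10, {4}, torch::kLong);
//   ASMoutput res = asfm->forward(input, target);
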
namespace torch {
namespace nn {

ASMoutput::ASMoutput(Tensor output_, double loss_)
    : output(std::move(output_)), loss(loss_) {}

AdaptiveLogSoftmaxWithLossImpl::AdaptiveLogSoftmaxWithLossImpl(
    AdaptiveLogSoftmaxWithLossOptions options_)
    : options(std::move(options_)),
      shortlist_size(0),
      n_clusters(0),
      head_size(0) {
  // NOLINTNEXTLINE(clang-analyzer-optin.cplusplus.VirtualCall)
  reset();
}

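// Validates the cutoff sequence, then (re)builds the head classifier
// (shortlist + one logit per cluster) and the tail ModuleList of
// down-projected cluster classifiers.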
void AdaptiveLogSoftmaxWithLossImpl::reset() {
  TORCH_CHECK(
      options.cutoffs().size() > 0,
      "cutoffs should be a sequence of length larger than 0");
  TORCH_CHECK(
      std::is_sorted(options.cutoffs().begin(), options.cutoffs().end()) &&
          *std::min_element(
              options.cutoffs().begin(), options.cutoffs().end()) > 0 &&
          *std::max_element(
              options.cutoffs().begin(), options.cutoffs().end()) <=
              (options.n_classes() - 1) &&
          std::set<int64_t>(options.cutoffs().begin(), options.cutoffs().end())
                  .size() == options.cutoffs().size(),
      "cutoffs should be a sequence of unique, positive integers sorted in increasing order, ",
      "where each value is between 1 and n_classes-1");
  TORCH_CHECK(options.div_value() != 0, "div_value should not be equal to 0");

  cutoffs = options.cutoffs();
  cutoffs.push_back(options.n_classes());

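  // The head scores the shortlist classes directly, plus one routing logit
  // per tail cluster.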
  shortlist_size = cutoffs[0];
  n_clusters = cutoffs.size() - 1;
  head_size = shortlist_size + n_clusters;

  head = this->register_module(
      "head",
      Linear(LinearOptions(options.in_features(), head_size)
                 .bias(options.head_bias())));
  tail = this->register_module("tail", ModuleList());

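  // Each cluster first projects the input down by a factor of
  // div_value^(i+1), then maps the reduced representation to that cluster's
  // output classes.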
  for (const auto i : c10::irange(n_clusters)) {
    int64_t hsz = static_cast<int64_t>(std::floor(
        options.in_features() / std::pow(options.div_value(), (i + 1))));
    int64_t osz = cutoffs[i + 1] - cutoffs[i];

    Sequential projection(
        Linear(LinearOptions(options.in_features(), hsz).bias(false)),
        Linear(LinearOptions(hsz, osz).bias(false)));
    tail->push_back(projection);
  }
}

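// Reinitializes the head and both linear layers of every tail projection.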
void AdaptiveLogSoftmaxWithLossImpl::reset_parameters() {
  head->reset_parameters();
  for (const auto i : c10::irange(tail->size())) {
    auto i2h = tail[i]->children()[0]->as<Linear>();
    auto h2o = tail[i]->children()[1]->as<Linear>();
    i2h->reset_parameters();
    h2o->reset_parameters();
  }
}

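// Computes per-sample log-probabilities of the given targets together with
// the mean negative log-likelihood. Shortlist targets are scored by the head
// directly; tail targets combine the head's cluster routing log-probability
// with the log-probability inside that cluster.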
ASMoutput AdaptiveLogSoftmaxWithLossImpl::forward(
    const Tensor& input_,
    const Tensor& target_) {
  auto targ_dim = target_.dim();

  TORCH_CHECK(
      targ_dim == 1 || targ_dim == 0,
      "0D or 1D target tensor expected, multi-target not supported");

  if (targ_dim == 1) {
    TORCH_CHECK(
        input_.dim() == 2,
        "1D target tensor expects 2D input tensors, but found inputs with sizes ",
        input_.sizes(),
        ".");
  } else {
    TORCH_CHECK(
        input_.dim() == 1,
        "0D target tensor expects 1D input tensors, but found inputs with sizes ",
        input_.sizes(),
        ".");
  }

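  // Promote the unbatched (0D-target) case to a batch of one so the rest of
  // the computation is uniform; undone before returning.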
  bool is_batched = (targ_dim > 0);
  Tensor input = is_batched ? input_ : input_.unsqueeze(0);
  Tensor target = is_batched ? target_ : target_.unsqueeze(0);

  int64_t used_rows = 0;
  const int64_t batch_size = target.size(0);

  Tensor output = input.new_zeros(batch_size);
  Tensor gather_inds = target.new_empty(batch_size);

  auto cutoff_values = cutoffs;
  cutoff_values.insert(cutoff_values.begin(), 0);

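  // Walk the buckets defined by the cutoffs: bucket 0 is the shortlist
  // handled by the head alone; bucket i > 0 maps to tail cluster i - 1.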
  for (const auto i : c10::irange(cutoff_values.size() - 1)) {
    int64_t low_idx = cutoff_values[i];
    int64_t high_idx = cutoff_values[i + 1];

    const Tensor target_mask = (target >= low_idx) * (target < high_idx);
    const Tensor row_indices = target_mask.nonzero().squeeze();

    if (row_indices.numel() == 0) {
      continue;
    }

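    // Shortlist targets (first bucket): the head column to gather is the
    // target index itself.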
    if (i == 0) {
      gather_inds.index_copy_(0, row_indices, target.index({target_mask}));
    } else {
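      // Tail targets: route through cluster i - 1. The head contributes the
      // cluster's routing log-probability (gathered at the end); the
      // cluster's own log-softmax contributes the within-cluster term here.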
      Tensor relative_target = target.index({target_mask}) - low_idx;
      Tensor input_subset = input.index_select(0, row_indices);

      const Tensor cluster_output =
          tail[i - 1]->as<Sequential>()->forward(input_subset);
      int64_t cluster_index = shortlist_size + i - 1;

      gather_inds.index_fill_(0, row_indices, cluster_index);

      const Tensor cluster_logprob = F::log_softmax(cluster_output, 1);
      const Tensor local_logprob =
          cluster_logprob.gather(1, relative_target.unsqueeze(1));
      output.index_copy_(0, row_indices, local_logprob.squeeze(1));
    }

    used_rows += row_indices.numel();
  }

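  // Every row must have matched exactly one bucket; a shortfall means some
  // target fell outside [0, n_classes - 1].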
  TORCH_CHECK(
      used_rows == batch_size,
      "Target values should be in [0, ",
      options.n_classes() - 1,
      "], "
      "but values in range [",
      target.min().item().toDouble(),
      ", ",
      target.max().item().toDouble(),
      "] "
      "were found.");

  const Tensor head_output = head(input);
  const Tensor head_logprob = F::log_softmax(head_output, 1);
  output += head_logprob.gather(1, gather_inds.unsqueeze(1)).squeeze();
  const double loss = (-output).mean().item().toDouble();

  if (!is_batched) {
    output = output.squeeze(0);
  }

  return ASMoutput(output, loss);
}

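// Assembles the full [batch_size, n_classes] log-probability matrix:
// shortlist columns come straight from the head's log-softmax; each cluster
// block is the cluster's log-softmax shifted by the head log-probability of
// routing to that cluster.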
Tensor AdaptiveLogSoftmaxWithLossImpl::_get_full_log_prob(
    const Tensor& input,
    const Tensor& head_output) {
  Tensor out = input.new_empty({head_output.size(0), options.n_classes()});
  const Tensor head_logprob = F::log_softmax(head_output, 1);

  out.index_put_(
      {Slice(), Slice(None, shortlist_size)},
      head_logprob.index({Slice(), Slice(None, shortlist_size)}));

  for (const auto i : c10::irange(cutoffs.size() - 1)) {
    int64_t start_idx = cutoffs[i];
    int64_t stop_idx = cutoffs[i + 1];
    const Tensor cluster_output = tail[i]->as<Sequential>()->forward(input);
    const Tensor cluster_logprob = F::log_softmax(cluster_output, 1);
    auto output_logprob = cluster_logprob +
        head_logprob.index({Slice(), static_cast<int64_t>(shortlist_size + i)})
            .unsqueeze(1);

    out.index_put_({Slice(), Slice(start_idx, stop_idx)}, output_logprob);
  }
  return out;
}

Tensor AdaptiveLogSoftmaxWithLossImpl::log_prob(const Tensor& input) {
  const Tensor head_output = head(input);
  return _get_full_log_prob(input, head_output);
}

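// Returns the argmax class per row. Rows whose head prediction lands in the
// shortlist are already final; the full log-probability matrix is computed
// only for the remaining rows.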
Tensor AdaptiveLogSoftmaxWithLossImpl::predict(const Tensor& input) {
  const Tensor head_output = head(input);
  Tensor output = torch::argmax(head_output, 1);
  auto not_in_shortlist = (output >= shortlist_size);
  auto all_in_shortlist = bitwise_not(not_in_shortlist.any());

  if (all_in_shortlist.item().toBool()) {
    return output;
  } else if (not_in_shortlist.all().item().toBool()) {
    const Tensor log_prob = _get_full_log_prob(input, head_output);
    return torch::argmax(log_prob, 1);
  } else {
    const Tensor log_prob = _get_full_log_prob(
        input.index({not_in_shortlist}), head_output.index({not_in_shortlist}));
    output.index_put_({not_in_shortlist}, torch::argmax(log_prob, 1));
    return output;
  }
}

void AdaptiveLogSoftmaxWithLossImpl::pretty_print(std::ostream& stream) const {
  stream << "torch::nn::AdaptiveLogSoftmaxWithLoss";
}

} // namespace nn
} // namespace torch