xref: /aosp_15_r20/external/pytorch/torch/csrc/api/src/nn/modules/adaptive.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <c10/util/irange.h>
2 #include <torch/nn/modules/adaptive.h>
3 #include <torch/nn/options/activation.h>
4 #include <torch/nn/options/linear.h>
5 
6 namespace F = torch::nn::functional;
7 
8 using namespace torch::indexing;
9 
10 namespace torch {
11 namespace nn {
12 
ASMoutput(Tensor output_,double loss_)13 ASMoutput::ASMoutput(Tensor output_, double loss_)
14     : output(std::move(output_)), loss(loss_) {}
15 
AdaptiveLogSoftmaxWithLossImpl(AdaptiveLogSoftmaxWithLossOptions options_)16 AdaptiveLogSoftmaxWithLossImpl::AdaptiveLogSoftmaxWithLossImpl(
17     AdaptiveLogSoftmaxWithLossOptions options_)
18     : options(std::move(options_)),
19       shortlist_size(0),
20       n_clusters(0),
21       head_size(0) {
22   // NOLINTNEXTLINE(clang-analyzer-optin.cplusplus.VirtualCall)
23   reset();
24 }
25 
reset()26 void AdaptiveLogSoftmaxWithLossImpl::reset() {
27   TORCH_CHECK(
28       options.cutoffs().size() > 0,
29       "cutoffs should be a sequence of length larger than 0");
30   TORCH_CHECK(
31       std::is_sorted(options.cutoffs().begin(), options.cutoffs().end()) &&
32           *std::min_element(
33               options.cutoffs().begin(), options.cutoffs().end()) > 0 &&
34           *std::max_element(
35               options.cutoffs().begin(), options.cutoffs().end()) <=
36               (options.n_classes() - 1) &&
37           std::set<int64_t>(options.cutoffs().begin(), options.cutoffs().end())
38                   .size() == options.cutoffs().size(),
39       "cutoffs should be a sequence of unique, positive integers sorted in an increasing order, ",
40       "where each value is between 1 and n_classes-1");
41   TORCH_CHECK(options.div_value() != 0, "div_value should not be equal to 0");
42 
43   cutoffs = options.cutoffs();
44   cutoffs.push_back(options.n_classes());
45 
46   shortlist_size = cutoffs[0];
47   n_clusters = cutoffs.size() - 1;
48   head_size = shortlist_size + n_clusters;
49 
50   head = this->register_module(
51       "head",
52       Linear(LinearOptions(options.in_features(), head_size)
53                  .bias(options.head_bias())));
54   tail = this->register_module("tail", ModuleList());
55 
56   for (const auto i : c10::irange(n_clusters)) {
57     int64_t hsz = static_cast<int64_t>(std::floor(
58         options.in_features() / std::pow(options.div_value(), (i + 1))));
59     int64_t osz = cutoffs[i + 1] - cutoffs[i];
60 
61     Sequential projection(
62         Linear(LinearOptions(options.in_features(), hsz).bias(false)),
63         Linear(LinearOptions(hsz, osz).bias(false)));
64     tail->push_back(projection);
65   }
66 }
67 
reset_parameters()68 void AdaptiveLogSoftmaxWithLossImpl::reset_parameters() {
69   head->reset_parameters();
70   for (const auto i : c10::irange(tail->size())) {
71     auto i2h = tail[i]->children()[0]->as<Linear>();
72     auto h2o = tail[i]->children()[1]->as<Linear>();
73     i2h->reset_parameters();
74     h2o->reset_parameters();
75   }
76 }
77 
forward(const Tensor & input_,const Tensor & target_)78 ASMoutput AdaptiveLogSoftmaxWithLossImpl::forward(
79     const Tensor& input_,
80     const Tensor& target_) {
81   auto targ_dim = target_.dim();
82 
83   TORCH_CHECK(
84       targ_dim == 1 || targ_dim == 0,
85       "0D or 1D target tensor expected, multi-target not supported");
86 
87   if (targ_dim == 1) {
88     TORCH_CHECK(
89         input_.dim() == 2,
90         "1D target tensor expects 2D input tensors, but found inputs with sizes ",
91         input_.sizes(),
92         ".");
93   } else {
94     TORCH_CHECK(
95         input_.dim() == 1,
96         "0D target tensor expects 1D input tensors, but found inputs with sizes ",
97         input_.sizes(),
98         ".");
99   }
100 
101   bool is_batched = (targ_dim > 0);
102   Tensor input = is_batched ? input_ : input_.unsqueeze(0);
103   Tensor target = is_batched ? target_ : target_.unsqueeze(0);
104 
105   int64_t used_rows = 0;
106   const int64_t batch_size = target.size(0);
107 
108   Tensor output = input.new_zeros(batch_size);
109   Tensor gather_inds = target.new_empty(batch_size);
110 
111   auto cutoff_values = cutoffs;
112   cutoff_values.insert(cutoff_values.begin(), 0);
113 
114   for (const auto i : c10::irange(cutoff_values.size() - 1)) {
115     int64_t low_idx = cutoff_values[i];
116     int64_t high_idx = cutoff_values[i + 1];
117 
118     const Tensor target_mask = (target >= low_idx) * (target < high_idx);
119     const Tensor row_indices = target_mask.nonzero().squeeze();
120 
121     if (row_indices.numel() == 0) {
122       continue;
123     }
124 
125     if (i == 0) {
126       gather_inds.index_copy_(0, row_indices, target.index({target_mask}));
127     } else {
128       Tensor relative_target = target.index({target_mask}) - low_idx;
129       Tensor input_subset = input.index_select(0, row_indices);
130 
131       const Tensor cluster_output =
132           tail[i - 1]->as<Sequential>()->forward(input_subset);
133       int64_t cluster_index = shortlist_size + i - 1;
134 
135       gather_inds.index_fill_(0, row_indices, cluster_index);
136 
137       const Tensor cluster_logprob = F::log_softmax(cluster_output, 1);
138       const Tensor local_logprob =
139           cluster_logprob.gather(1, relative_target.unsqueeze(1));
140       output.index_copy_(0, row_indices, local_logprob.squeeze(1));
141     }
142 
143     used_rows += row_indices.numel();
144   }
145 
146   TORCH_CHECK(
147       used_rows == batch_size,
148       "Target values should be in [0, ",
149       options.n_classes() - 1,
150       "], "
151       "but values in range [",
152       target.min().item().toDouble(),
153       ", ",
154       target.max().item().toDouble(),
155       "] "
156       "were found. ");
157 
158   const Tensor head_output = head(input);
159   const Tensor head_logprob = F::log_softmax(head_output, 1);
160   output += head_logprob.gather(1, gather_inds.unsqueeze(1)).squeeze();
161   const double loss = (-output).mean().item().toDouble();
162 
163   if (!is_batched) {
164     output = output.squeeze(0);
165   }
166 
167   return ASMoutput(output, loss);
168 }
169 
// Assembles the full (batch, n_classes) log-probability matrix from the head
// output: shortlist columns are copied from the head, and each cluster's
// columns are the cluster's own log-softmax shifted by the head's
// log-probability of that cluster.
Tensor AdaptiveLogSoftmaxWithLossImpl::_get_full_log_prob(
    const Tensor& input,
    const Tensor& head_output) {
  Tensor out = input.new_empty({head_output.size(0), options.n_classes()});
  const Tensor head_logprob = F::log_softmax(head_output, 1);

  // Shortlist columns come straight from the head.
  out.index_put_(
      {Slice(), Slice(None, shortlist_size)},
      head_logprob.index({Slice(), Slice(None, shortlist_size)}));

  for (const auto cluster : c10::irange(cutoffs.size() - 1)) {
    const int64_t start_idx = cutoffs[cluster];
    const int64_t stop_idx = cutoffs[cluster + 1];
    const Tensor cluster_logprob =
        F::log_softmax(tail[cluster]->as<Sequential>()->forward(input), 1);
    // log P(class) = log P(class | cluster) + log P(cluster).
    const auto output_logprob = cluster_logprob +
        head_logprob
            .index({Slice(), static_cast<int64_t>(shortlist_size + cluster)})
            .unsqueeze(1);

    out.index_put_({Slice(), Slice(start_idx, stop_idx)}, output_logprob);
  }
  return out;
}
193 
// Returns the full (batch, n_classes) log-probability matrix for `input`.
// Fix: the definition was doubly qualified
// (`AdaptiveLogSoftmaxWithLossImpl::AdaptiveLogSoftmaxWithLossImpl::log_prob`),
// which only compiles via the injected-class-name and is a typo; the
// redundant qualifier is removed.
Tensor AdaptiveLogSoftmaxWithLossImpl::log_prob(const Tensor& input) {
  const Tensor head_output = head(input);
  return _get_full_log_prob(input, head_output);
}
199 
// Predicts the most likely class per row, computing the full distribution
// only for rows whose head argmax points at a tail cluster.
Tensor AdaptiveLogSoftmaxWithLossImpl::predict(const Tensor& input) {
  const Tensor head_output = head(input);
  Tensor output = torch::argmax(head_output, 1);
  const auto not_in_shortlist = (output >= shortlist_size);
  const auto all_in_shortlist = bitwise_not(not_in_shortlist.any());

  if (all_in_shortlist.item().toBool()) {
    // Fast path: every argmax landed on a shortlist class.
    return output;
  }

  if (not_in_shortlist.all().item().toBool()) {
    // Every row points at a cluster: evaluate the full distribution.
    const Tensor log_prob = _get_full_log_prob(input, head_output);
    return torch::argmax(log_prob, 1);
  }

  // Mixed batch: re-evaluate only the rows that selected a cluster slot.
  const Tensor log_prob = _get_full_log_prob(
      input.index({not_in_shortlist}), head_output.index({not_in_shortlist}));
  output.index_put_({not_in_shortlist}, torch::argmax(log_prob, 1));
  return output;
}
218 
pretty_print(std::ostream & stream) const219 void AdaptiveLogSoftmaxWithLossImpl::pretty_print(std::ostream& stream) const {
220   stream << "torch::nn::AdaptiveLogSoftmaxWithLoss";
221 }
222 
223 } // namespace nn
224 } // namespace torch
225