// aten/src/ATen/native/Embedding.cpp
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/core/List.h>
#include <ATen/Dispatch.h>
#include <ATen/Parallel.h>
#include <ATen/TensorIterator.h>
#include <ATen/TensorOperators.h>
#include <ATen/TensorUtils.h>
#include <ATen/native/BinaryOps.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_sparse_coo_tensor_unsafe.h>
#include <ATen/ops/embedding_backward_native.h>
#include <ATen/ops/embedding_dense_backward.h>
#include <ATen/ops/embedding_dense_backward_native.h>
#include <ATen/ops/embedding_native.h>
#include <ATen/ops/embedding_renorm_native.h>
#include <ATen/ops/embedding_sparse_backward.h>
#include <ATen/ops/embedding_sparse_backward_native.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/zeros.h>
#endif

#include <c10/util/irange.h>

#include <cstring>
#include <memory>
#include <utility>
#include <vector>


namespace at::native {

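// Forward embedding lookup: gathers rows of the 2-D `weight` table at the
// positions given by `indices`. `padding_idx`, `scale_grad_by_freq` and
// `sparse` do not change the forward result; they only control how the
// backward pass is computed.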
Tensor embedding_symint(const Tensor & weight, const Tensor & indices,
                        c10::SymInt padding_idx, bool scale_grad_by_freq, bool sparse) {
  TORCH_CHECK(weight.dim() == 2,  "'weight' must be 2-D");
  auto indices_arg = TensorArg(indices, "indices", 1);
  checkScalarTypes("embedding", indices_arg, {kLong, kInt});

  // TODO: use tensor.index() after improving perf
  if (indices.dim() == 1) {
    return weight.index_select(0, indices);
  }

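  // For N-D indices the output has shape indices.sizes() + [embedding_dim]:
  // flatten the indices, gather the rows, then view the result back.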
  auto size = indices.sym_sizes().vec();
  for (const auto& d : weight.sym_sizes().slice(1)) {
    size.push_back(d);
  }

  return weight.index_select(0, indices.reshape(-1)).view_symint(size);
}

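// Backward entry point: routes to either the sparse or the dense gradient
// implementation depending on the `sparse` flag passed to embedding().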
Tensor embedding_backward_symint(
    const Tensor & grad, const Tensor & indices, c10::SymInt num_weights,
    c10::SymInt padding_idx, bool scale_grad_by_freq, bool sparse) {
  if (sparse) {
    // TODO: if we teach sparse tensor how to propagate symints, the guard
    // here is not strictly necessary.  However, we think it is fine as is
    // because num weights is derived from a parameter and therefore
    // typically not varying.
    return at::embedding_sparse_backward(
      grad, indices,
      num_weights.guard_int(__FILE__, __LINE__),
      padding_idx.guard_int(__FILE__, __LINE__),
      scale_grad_by_freq);
  } else {
    return at::embedding_dense_backward_symint(
      grad, indices, std::move(num_weights), padding_idx, scale_grad_by_freq);
  }
}

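// Builds the gradient w.r.t. the weight table as a sparse COO tensor of shape
// [num_weights, embedding_dim]: the COO indices are the (flattened) lookup
// indices and the values are the corresponding rows of the incoming gradient.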
Tensor embedding_sparse_backward(
    const Tensor & grad_, const Tensor & indices_, int64_t num_weights,
    int64_t padding_idx, bool scale_grad_by_freq) {

  auto indices_arg = TensorArg(indices_, "indices", 2);
  checkScalarTypes("embedding_backward", indices_arg, {kLong, kInt});

  // TODO: implement scale_grad_by_freq
  if (scale_grad_by_freq) {
    AT_ERROR(
        "embedding_backward: scale_grad_by_freq not supported with sparse gradients");
  }

  Tensor indices = indices_;
  Tensor grad = grad_;
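  // Drop entries that correspond to padding_idx so no gradient flows into
  // the padding row.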
  if (padding_idx != -1) {
    c10::List<std::optional<Tensor>> c({indices != padding_idx});
    indices = indices.index(c);
    grad = grad.index(c);
  }

  auto num_features = grad_.sym_size(-1);
  auto weight_size = std::array<c10::SymInt, 2>{{ num_weights, num_features }};
  auto dense_options = grad.options();

  // check if all our gradients come from padding_idx
  if (grad.sym_numel() == 0) {
    return at::_sparse_coo_tensor_unsafe_symint(at::empty({1, 0}, indices_.options().dtype(kLong)),
                                         at::empty_symint({c10::SymInt(0), std::move(num_features)}, dense_options),
                                         weight_size);
  }

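  // COO layout: `index` is a [1, nnz] matrix of row ids and `values` is a
  // [nnz, num_features] matrix of gradient rows; duplicate indices are left
  // uncoalesced.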
  auto index = indices.reshape({1, -1});
  auto values = grad.reshape_symint({c10::SymInt(-1), std::move(num_features)});
  return at::_sparse_coo_tensor_unsafe_symint(index.to(kLong), values, weight_size);
}

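// Dense CPU backward: accumulates each incoming gradient row into
// grad_weight[indices[i]] (skipping padding_idx), optionally scaling each
// contribution by 1 / frequency of its index. The accumulation is
// parallelized over disjoint ranges of destination rows so that no two
// threads write to the same row.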
Tensor embedding_dense_backward_cpu(
    const Tensor & grad_, const Tensor & indices, int64_t num_weights,
    int64_t padding_idx, bool scale_grad_by_freq) {

  auto indices_arg = TensorArg(indices, "indices", 2);
  checkScalarTypes("embedding_backward", indices_arg, {kLong, kInt});

  auto grad_weight = at::zeros({num_weights, grad_.size(-1)}, grad_.options());
  auto indices_contig = indices.contiguous();
  int64_t numel = indices.numel();
  auto grad = grad_.contiguous().view({numel, grad_.size(-1)});

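  // Pre-build a TensorIterator for `grad_weight[k] += scale * grad[i]` with a
  // fixed per-row shape; inside the hot loop only the operand pointers are
  // swapped (see unsafe_replace_operand below), avoiding per-row setup cost.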
  auto add_iter = TensorIteratorConfig()
    .add_output(grad_weight)
    .add_input(grad_weight)
    .add_const_input(grad)
    .resize_outputs(false)
    .declare_static_shape(grad.sizes(), /*squash_dims=*/0)
    .build();

  const auto gW_data = reinterpret_cast<char*>(grad_weight.data_ptr());
  const auto gO_data = reinterpret_cast<const char*>(grad.const_data_ptr());
  const auto gW_stride = grad_weight.strides()[0] * grad_weight.element_size();
  const auto gO_stride = grad.strides()[0] * grad.element_size();

  AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "embedding_dense_backward_cpu", [&] () {
    auto indices_data = indices_contig.const_data_ptr<index_t>();

    // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
    std::unique_ptr<index_t[]> counts;
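    // When scaling by frequency, count how often each index occurs so that
    // each accumulated row can be divided by its number of occurrences.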
    if (scale_grad_by_freq) {
      counts.reset(new index_t[num_weights]);
      for (const auto i : c10::irange(numel)) {
        counts[indices_data[i]] = 0;
      }
      for (const auto i : c10::irange(numel)) {
        counts[indices_data[i]]++;
      }
    }

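    // Each worker owns the destination rows in [start, end) and scans all
    // gradient rows, accumulating only those whose index falls in its range;
    // entries equal to padding_idx are skipped entirely.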
    auto parallel_section = [&](index_t start, index_t end) {
      TensorIterator iter(add_iter);
      for (const auto i : c10::irange(numel)) {
        if (indices_data[i] != padding_idx) {
          index_t k = indices_data[i];
          if (k >= start && k < end) {
            double scale = 1.0;
            if (scale_grad_by_freq) {
              // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
              scale /= counts[k];
            }

            // grad_weight[k].add_(grad[i], scale);
            iter.unsafe_replace_operand(0, gW_data + k * gW_stride);
            iter.unsafe_replace_operand(1, gW_data + k * gW_stride);
            iter.unsafe_replace_operand(2, const_cast<char*>(gO_data + i * gO_stride));
            add_stub(kCPU, iter, scale);
          }
        }
      }
    };

    at::parallel_for(0, num_weights, 1000, parallel_section);

  });

  return grad_weight;
}

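// In-place renormalization: for every distinct row of `self` referenced by
// `indices`, rescales the row so its `norm_type` norm does not exceed
// `max_norm`. Rows not referenced by `indices` are left untouched.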
Tensor & embedding_renorm_cpu_(
    Tensor & self, const Tensor & indices, double max_norm, double norm_type) {
  auto self_arg = TensorArg(self, "self", 1);
  auto indices_arg = TensorArg(indices, "indices", 2);
  checkDim("embedding_renorm_", self_arg, 2);
  checkScalarTypes("embedding_renorm_", indices_arg, {kLong, kInt});

  auto indices_contig = indices.contiguous();
  auto num_indices = indices.numel();

  AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "embedding_renorm_cpu_", [&]() {
    auto data_ptr = indices_contig.const_data_ptr<index_t>();
    auto sorted_indices = std::vector<index_t>(data_ptr, data_ptr + num_indices);
    std::sort(sorted_indices.begin(), sorted_indices.end());

    // Note that we cannot use at::parallel_for here because we perform operations on
    // Tensor inside the loop. See github.com/pytorch/pytorch/issues/28370 for more details.
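    // Sorting lets the loop detect duplicates and renormalize each row once.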
    for (const auto i : c10::irange(num_indices)) {
      if (i > 0 && sorted_indices[i] == sorted_indices[i - 1]) {
        continue;
      }
      auto row = self[sorted_indices[i]];
      auto norm = row.norm(norm_type).item<double>();
      if (norm > max_norm) {
        auto scale = max_norm / (norm + 1e-7);
        row *= scale;
      }
    }
  });

  return self;
}


}  // namespace at::native