#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/core/List.h>
#include <ATen/Dispatch.h>
#include <ATen/Parallel.h>
#include <ATen/TensorIterator.h>
#include <ATen/TensorOperators.h>
#include <ATen/TensorUtils.h>
#include <ATen/native/BinaryOps.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_sparse_coo_tensor_unsafe.h>
#include <ATen/ops/embedding_backward_native.h>
#include <ATen/ops/embedding_dense_backward.h>
#include <ATen/ops/embedding_dense_backward_native.h>
#include <ATen/ops/embedding_native.h>
#include <ATen/ops/embedding_renorm_native.h>
#include <ATen/ops/embedding_sparse_backward.h>
#include <ATen/ops/embedding_sparse_backward_native.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/zeros.h>
#endif

#include <c10/util/irange.h>

#include <algorithm>
#include <cstring>
#include <memory>
#include <utility>
#include <vector>


namespace at::native {

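// Forward lookup: returns the rows of `weight` selected by `indices`.
// Illustrative usage sketch from the caller's side (an assumption for
// illustration, not part of this file):
//   auto weight  = at::randn({10, 4});              // [num_embeddings, embedding_dim]
//   auto indices = at::arange(3, at::kLong);        // int or long indices, any shape
//   auto out     = at::embedding(weight, indices);  // shape [3, 4]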
Tensor embedding_symint(const Tensor & weight, const Tensor & indices,
    c10::SymInt padding_idx, bool scale_grad_by_freq, bool sparse) {
  TORCH_CHECK(weight.dim() == 2, "'weight' must be 2-D");
  auto indices_arg = TensorArg(indices, "indices", 1);
  checkScalarTypes("embedding", indices_arg, {kLong, kInt});

  // TODO: use tensor.index() after improving perf
  if (indices.dim() == 1) {
    return weight.index_select(0, indices);
  }

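  // For n-dimensional indices the output shape is indices.sizes() followed by the
  // embedding dimension, e.g. [B, T] indices with a [V, D] weight give a [B, T, D] result.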
  auto size = indices.sym_sizes().vec();
  for (const auto& d : weight.sym_sizes().slice(1)) {
    size.push_back(d);
  }

  return weight.index_select(0, indices.reshape(-1)).view_symint(size);
}

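// Backward entry point: with sparse=true the gradient is returned as a sparse COO
// tensor covering only the rows that were actually looked up; otherwise a dense
// [num_weights, embedding_dim] gradient is produced by embedding_dense_backward.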
Tensor embedding_backward_symint(
    const Tensor & grad, const Tensor & indices, c10::SymInt num_weights,
    c10::SymInt padding_idx, bool scale_grad_by_freq, bool sparse) {
  if (sparse) {
    // TODO: if we teach sparse tensors how to propagate symints, the guard
    // here is not strictly necessary. However, we think it is fine as is
    // because num_weights is derived from a parameter and therefore
    // typically not varying.
    return at::embedding_sparse_backward(
        grad, indices,
        num_weights.guard_int(__FILE__, __LINE__),
        padding_idx.guard_int(__FILE__, __LINE__),
        scale_grad_by_freq);
  } else {
    return at::embedding_dense_backward_symint(
        grad, indices, std::move(num_weights), padding_idx, scale_grad_by_freq);
  }
}

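// Sparse backward: rather than materializing a dense [num_weights, embedding_dim]
// gradient, build a COO tensor whose indices are the (flattened) embedding indices
// and whose values are the matching rows of `grad`. For example, with num_weights = 10,
// embedding_dim = 4 and three lookups, the result is a sparse [10, 4] tensor with
// three value rows (duplicate indices are left uncoalesced).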
Tensor embedding_sparse_backward(
    const Tensor & grad_, const Tensor & indices_, int64_t num_weights,
    int64_t padding_idx, bool scale_grad_by_freq) {

  auto indices_arg = TensorArg(indices_, "indices", 2);
  checkScalarTypes("embedding_backward", indices_arg, {kLong, kInt});

  // TODO: implement scale_grad_by_freq
  if (scale_grad_by_freq) {
    AT_ERROR(
        "embedding_backward: scale_grad_by_freq not supported with sparse gradients");
  }

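  // Drop every position that hit padding_idx so it contributes no gradient;
  // `indices != padding_idx` builds the boolean mask used for advanced indexing.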
  Tensor indices = indices_;
  Tensor grad = grad_;
  if (padding_idx != -1) {
    c10::List<std::optional<Tensor>> c({indices != padding_idx});
    indices = indices.index(c);
    grad = grad.index(c);
  }

  auto num_features = grad_.sym_size(-1);
  auto weight_size = std::array<c10::SymInt, 2>{{ num_weights, num_features }};
  auto dense_options = grad.options();

  // Check whether every gradient entry was filtered out because it came from padding_idx.
  if (grad.sym_numel() == 0) {
    return at::_sparse_coo_tensor_unsafe_symint(
        at::empty({1, 0}, indices_.options().dtype(kLong)),
        at::empty_symint({c10::SymInt(0), std::move(num_features)}, dense_options),
        weight_size);
  }

  auto index = indices.reshape({1, -1});
  auto values = grad.reshape_symint({c10::SymInt(-1), std::move(num_features)});
  return at::_sparse_coo_tensor_unsafe_symint(index.to(kLong), values, weight_size);
}

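// Dense backward: accumulate each row of `grad` into grad_weight[indices[i]],
// skipping padding_idx and, when scale_grad_by_freq is set, dividing by how often
// the index occurs. Conceptually (ignoring padding and scaling) this is
// grad_weight.index_add_(0, indices, grad).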
Tensor embedding_dense_backward_cpu(
    const Tensor & grad_, const Tensor & indices, int64_t num_weights,
    int64_t padding_idx, bool scale_grad_by_freq) {

  auto indices_arg = TensorArg(indices, "indices", 2);
  checkScalarTypes("embedding_backward", indices_arg, {kLong, kInt});

  auto grad_weight = at::zeros({num_weights, grad_.size(-1)}, grad_.options());
  auto indices_contig = indices.contiguous();
  int64_t numel = indices.numel();
  auto grad = grad_.contiguous().view({numel, grad_.size(-1)});

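  // Configure a single per-row add iterator (dim 0 of `grad` is squashed away) and
  // reuse it below: the operand pointers are swapped with unsafe_replace_operand for
  // every index instead of building a fresh TensorIterator per row.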
  auto add_iter = TensorIteratorConfig()
    .add_output(grad_weight)
    .add_input(grad_weight)
    .add_const_input(grad)
    .resize_outputs(false)
    .declare_static_shape(grad.sizes(), /*squash_dims=*/0)
    .build();

  const auto gW_data = reinterpret_cast<char*>(grad_weight.data_ptr());
  const auto gO_data = reinterpret_cast<const char*>(grad.const_data_ptr());
  const auto gW_stride = grad_weight.strides()[0] * grad_weight.element_size();
  const auto gO_stride = grad.strides()[0] * grad.element_size();

  AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "embedding_dense_backward_cpu", [&] () {
    auto indices_data = indices_contig.const_data_ptr<index_t>();

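    // counts[k] holds how many times index k occurs in `indices`; it is used to
    // divide the accumulated gradient by the index frequency. Only slots that
    // actually occur are zero-initialized, and only those are ever read.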
    // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
    std::unique_ptr<index_t[]> counts;
    if (scale_grad_by_freq) {
      counts.reset(new index_t[num_weights]);
      for (const auto i : c10::irange(numel)) {
        counts[indices_data[i]] = 0;
      }
      for (const auto i : c10::irange(numel)) {
        counts[indices_data[i]]++;
      }
    }

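    // Parallelize over destination rows: each task scans all of `indices` but only
    // writes rows k that fall in its own [start, end) slice of grad_weight, so no
    // two tasks ever touch the same output row and no locking is needed.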
    auto parallel_section = [&](index_t start, index_t end) {
      TensorIterator iter(add_iter);
      for (const auto i : c10::irange(numel)) {
        if (indices_data[i] != padding_idx) {
          index_t k = indices_data[i];
          if (k >= start && k < end) {
            double scale = 1.0;
            if (scale_grad_by_freq) {
              // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays)
              scale /= counts[k];
            }

            // grad_weight[k].add_(grad[i], scale);
            iter.unsafe_replace_operand(0, gW_data + k * gW_stride);
            iter.unsafe_replace_operand(1, gW_data + k * gW_stride);
            iter.unsafe_replace_operand(2, const_cast<char*>(gO_data + i * gO_stride));
            add_stub(kCPU, iter, scale);
          }
        }
      }
    };

    at::parallel_for(0, num_weights, 1000, parallel_section);

  });

  return grad_weight;
}

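// In-place renorm: every row of `self` referenced by `indices` whose norm_type-norm
// exceeds max_norm is rescaled so that its norm becomes (approximately) max_norm.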
Tensor & embedding_renorm_cpu_(
    Tensor & self, const Tensor & indices, double max_norm, double norm_type) {
  auto self_arg = TensorArg(self, "self", 1);
  auto indices_arg = TensorArg(indices, "indices", 2);
  checkDim("embedding_renorm_", self_arg, 2);
  checkScalarTypes("embedding_renorm_", indices_arg, {kLong, kInt});

  auto indices_contig = indices.contiguous();
  auto num_indices = indices.numel();

  AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "embedding_renorm_cpu_", [&]() {
    auto data_ptr = indices_contig.const_data_ptr<index_t>();
    auto sorted_indices = std::vector<index_t>(data_ptr, data_ptr + num_indices);
    std::sort(sorted_indices.begin(), sorted_indices.end());

    // Note that we cannot use at::parallel_for here because we perform operations on
    // Tensor inside the loop. See github.com/pytorch/pytorch/issues/28370 for more details.
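    // The indices were sorted above, so duplicates are adjacent and each referenced
    // row is renormalized at most once.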
    for (const auto i : c10::irange(num_indices)) {
      if (i > 0 && sorted_indices[i] == sorted_indices[i - 1]) {
        continue;
      }
      auto row = self[sorted_indices[i]];
      auto norm = row.norm(norm_type).item<double>();
      if (norm > max_norm) {
        auto scale = max_norm / (norm + 1e-7);
        row *= scale;
      }
    }
  });

  return self;
}


} // namespace at::native