/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <algorithm>
#include <cmath>
#include <string>
#include <tuple>

#include "absl/algorithm/container.h"
#include "absl/container/btree_set.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/match.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_split.h"
#include "absl/strings/string_view.h"
#include "absl/types/optional.h"
#include "tensorflow/cc/framework/grad_op_registry.h"
#include "tensorflow/cc/framework/gradients.h"
#include "tensorflow/cc/gradients/grad_helper.h"
#include "tensorflow/cc/ops/array_ops_internal.h"
#include "tensorflow/cc/ops/math_ops_internal.h"
#include "tensorflow/cc/ops/standard_ops.h"

namespace tensorflow {
namespace ops {
namespace {

constexpr absl::string_view kEllipsis = "...";

// Returns the axis (possibly negative) corresponding to a label.
//
// Returns the axis index of the axis label if it occurs before the ellipsis
// (or if no ellipsis is present), and the negative index if it occurs after
// the ellipsis. E.g. the index of `b` in `ab...cd` is `1`, but that of `c` is
// `-2`.
//
// For multiple occurrences, returns the leftmost one. If not found, returns
// absl::nullopt.
//
// Parameters:
//   subscripts: A string denoting the einsum subscript (e.g. `ab...cd`)
//   label: The single-character axis label.
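//
// For illustration, given the code below these calls would return:
//   EinsumGetAxisFromLabel("ab...cd", 'b')  ->  1
//   EinsumGetAxisFromLabel("ab...cd", 'c')  -> -2
//   EinsumGetAxisFromLabel("ab...cd", 'z')  -> absl::nullopt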
absl::optional<int> EinsumGetAxisFromLabel(absl::string_view subscripts,
                                           char label) {
  std::vector<absl::string_view> splits = absl::StrSplit(subscripts, kEllipsis);
  auto index = splits[0].find(label);
  if (index != splits[0].npos) {
    return index;
  }
  if (splits.size() < 2) {
    return absl::nullopt;
  }
  index = splits[1].find(label);
  if (index != splits[1].npos) {
    return index - splits[1].length();
  }
  return absl::nullopt;
}

// Returns a tuple denoting the slice mapping to ellipsis.
//
// For a given subscript, returns a tuple (start, end) denoting the start
// axis index and the (negative) end axis index respectively. For any input
// Tensor `x` described by the subscript, `x[start:end]` would be the slice
// represented by the ellipsis. E.g. for `ab...cd` returns `(2, -2)`.
//
// If ellipsis is not present in `subscripts`, returns `(0, 0)`.
//
// Parameters:
//   subscripts: A string denoting the einsum subscript.
//
// Returns:
//   start: The start axis index.
//   end: The (negative) end axis index, or nullopt to slice to the end.
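//
// For illustration, the code below yields:
//   EinsumGetBcastSubshape("ab...cd") -> (2, -2)
//   EinsumGetBcastSubshape("ab...")   -> (2, absl::nullopt)
//   EinsumGetBcastSubshape("abcd")    -> (0, 0)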
std::tuple<int, absl::optional<int>> EinsumGetBcastSubshape(
    absl::string_view subscripts) {
  int start = subscripts.find(kEllipsis);
  if (start == subscripts.npos) {
    return std::make_tuple(0, 0);
  }
  int remaining = subscripts.length() - (start + kEllipsis.length());
  absl::optional<int> end;
  if (remaining > 0) {
    end = -remaining;
  } else {
    end = absl::nullopt;
  }
  return std::make_tuple(start, end);
}

// Slices elements of a 1-D tensor in the range [start, end).
// If end is nullopt, the slice extends to the end of the tensor.
// Supports negative values for end.
// This attempts to give the same result as tensor[start:end] would give in
// Python.
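// For illustration, if `tensor` holds [2, 5, 3, 4], then
// Slice1dHelper(scope, tensor, 1, -1) yields [5, 3], matching the Python
// slice tensor[1:-1].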
Output Slice1dHelper(const Scope& scope, Output tensor, int start,
                     absl::optional<int> end) {
  if (end.has_value() && *end > 0) {
    return Slice(scope, tensor, Const(scope, start, TensorShape({1})),
                 Const(scope, *end - start, TensorShape({1})));
  } else {
    return Slice(scope, tensor, Const(scope, start, TensorShape({1})),
                 Add(scope, Shape(scope, tensor), end.value_or(0) - start));
  }
}

// Returns reduced subscripts and their corresponding dimensions and axes.
//
// Given a set of axis labels, returns their concatenated subscript, their
// corresponding dimensions from input_shape, and their corresponding axes.
// Note that the concatenated subscript `reduced_subs` may have axis labels
// from `reduced_label_set` in any order. For example, for the reduced label
// set `{b, d}`, subscripts `aabbcd` and input shape `[2,2,5,5,3,4]`, returns
// subscripts `bd`, dimensions `[5,4]` and axes `[2,5]`.
//
// Args:
//   reduced_label_set: Set of axis labels which appear in `subscripts`.
//   input_shape: A `Tensor` representing the shape of the einsum operand
//     corresponding to `subscripts`.
//   subscripts: A string denoting the einsum subscript.
//
// Returns:
//   reduced_subs: Subscripts formed by a concatenation of labels in
//     `reduced_label_set`.
//   reduced_dims: Dimensions from `input_shape` corresponding to each label
//     in `reduced_subs`.
//   reduced_axes: Axes described by `subscripts` corresponding to each label
//     in `reduced_subs`. If there are multiple occurrences in `subscripts`,
//     we consider only the leftmost one.
std::tuple<std::string, Output, Output> EinsumGetReducedSubscripts(
    const Scope& scope, const absl::btree_set<char>& reduced_label_set,
    Output input_shape, absl::string_view subscripts) {
  // Concatenate the sequence of reduced axis labels.
  const std::string reduced_subs =
      std::string(reduced_label_set.begin(), reduced_label_set.end());
  // Get the axis (may be positive, negative or zero) for each of the reduced
  // labels. If the same label appears multiple times, get the left-most axis.
  std::vector<int> reduced_axes;
  reduced_axes.reserve(reduced_subs.size());
  for (const char s : reduced_subs) {
    auto axis = EinsumGetAxisFromLabel(subscripts, s);
    if (!axis.has_value()) {
      // Should never happen.
      scope.UpdateStatus(errors::Internal(
          absl::StrCat("Missing axis: ", absl::string_view(&s, 1))));
    } else {
      reduced_axes.push_back(*axis);
    }
  }
  // Get the corresponding dimensions for each reduced axis.
  std::vector<Output> reduced_dims_inputs;
  reduced_dims_inputs.reserve(reduced_axes.size());
  for (const int i : reduced_axes) {
    if (i < 0) {
      reduced_dims_inputs.push_back(
          Gather(scope, input_shape, Add(scope, Size(scope, input_shape), i)));
    } else {
      reduced_dims_inputs.push_back(Gather(scope, input_shape, i));
    }
  }
  const Output reduced_dims = Stack(scope, reduced_dims_inputs);
  Tensor reduced_axes_tensor(
      DataType::DT_INT32, TensorShape({static_cast<int>(reduced_axes.size())}));
  std::copy_n(reduced_axes.begin(), reduced_axes.size(),
              reduced_axes_tensor.flat<int>().data());
  return std::make_tuple(reduced_subs, reduced_dims,
                         Const(scope, reduced_axes_tensor));
}

// Returns the gradient wrt input for a unary einsum with reductions.
//
//  scope: Scope for grad operations.
//  output_grad: The gradient wrt the output of a unary einsum operation.
//  output_subs: The output subscript. (E.g. `ac` for equation `abc->ac`).
//  input_subs: The input subscript. (E.g. `abc` for equation `abc->ac`).
//  input_shape: The shape of the input operand.
//  reduced_label_set: The set of axis labels appearing in `input_subs` but
//    not in `output_subs`.
Output EinsumGradReducedHelper(const Scope& scope, const Output& output_grad,
                               absl::string_view output_subs,
                               absl::string_view input_subs,
                               const Output& input_shape,
                               const absl::btree_set<char>& reduced_label_set) {
  // Let's say the einsum operation was "aabbcd->ca", where axis labels 'b' and
  // 'd' are reduced with input_shape [2,2,5,5,3,4]. Then obtain the reduced
  // subscripts "bd", corresponding dimensions [5,4] and axes [2,5].
  std::string reduced_subs;
  Output reduced_dims, reduced_axes;
  std::tie(reduced_subs, reduced_dims, reduced_axes) =
      EinsumGetReducedSubscripts(scope, reduced_label_set, input_shape,
                                 input_subs);
  // Whether either the input or the output subscripts have a repeated label.
  // This is true for "aabbcd->ca" or "abd->cca" but false for "abcd->ca".
  const int distinct_input_labels =
      absl::flat_hash_set<char>(input_subs.begin(), input_subs.end()).size();
  const int distinct_output_labels =
      absl::flat_hash_set<char>(output_subs.begin(), output_subs.end()).size();
  const bool has_repeated_labels =
      (distinct_input_labels + distinct_output_labels) <
      input_subs.length() + output_subs.length();
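  // For illustration with the running example "aabbcd->ca": there are 4
  // distinct input labels and 2 distinct output labels, and 4 + 2 < 6 + 2,
  // so has_repeated_labels is true here.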
  // Compute the input subscripts without the reduced axis labels, e.g. "aac"
  // for the equation "aabbcd->ca".
  std::string input_subs_without_reduced_labels;
  for (const char s : input_subs) {
    if (!absl::c_linear_search(reduced_label_set, s)) {
      input_subs_without_reduced_labels.push_back(s);
    }
  }

  // The gradient wrt the input for the equation "abc->ac" (or, equivalently
  // reduce_sum(..., axis=1)) is just the gradient of the output tiled N times
  // along axis 1, where label 'b' represents a dimension of size N.
  //
  // If we're not dealing with repeated labels, and the non-reduced labels
  // don't need to be transposed, then just tiling is enough and there is no
  // need to call another einsum. For example, tiling is sufficient for
  // "abcd->ac". But for equations like "aabbcd->ac" (generalized traces) or
  // "abc->ca" (transpose), we'd need another einsum operation after tiling.
  if (!has_repeated_labels &&
      input_subs_without_reduced_labels == output_subs) {
    // Obtain the shape of the output, as if keepdims=True on reduce sum. E.g.
    // for the equation "abcd->ac" with input shape [2,5,3,4], we get the
    // reduced shape [2,1,3,1].
    auto reduced_shape = ReducedShapeHelper(scope, input_shape, reduced_axes);
    // Reshaping the gradient (wrt "ac") to [2,1,3,1] and broadcasting it to
    // the shape [2,5,3,4] results in the gradient wrt "abcd".
    return BroadcastTo(scope, Reshape(scope, output_grad, reduced_shape),
                       input_shape);
  }

  // If we *do* have traces or transpose operations, then prepend the extra
  // reduced dimensions to the front. E.g. Given the equation "aabbcd->ca" we'd
  // first obtain the VJP for "bdca->ca", and then the VJP for "aabbcd->bdca".
  //
  // Obtain the input shape with reduced dimensions prepended, viz. [5,4,3,2].
  // This is the shape of the intermediate "bdca".
  Output output_grad_shape = Shape(scope, output_grad);
  auto grad_shape_with_reduced_labels =
      Concat(scope, {reduced_dims, output_grad_shape}, /*axis=*/0);

  // Obtain the output shape of the reduction-only equation "bdca->ca" as if
  // keepdims=True; viz. [1,1,3,2]. Since we prepended the reduced labels,
  // we just have to prepend that many 1s to the output shape.
  auto reduced_shape = Concat(
      scope,
      {Const(scope, 1, TensorShape{static_cast<int>(reduced_label_set.size())}),
       output_grad_shape},
      /*axis=*/0);
  // Compute the VJP for the intermediate (viz. "bdca->ca") for which
  // broadcasting is sufficient.
  Output broadcasted_grad =
      BroadcastTo(scope, Reshape(scope, output_grad, reduced_shape),
                  grad_shape_with_reduced_labels);
  // Compute the VJP for the final step (viz. "aabbcd->bdca"). We can
  // use einsum with the input and output subscripts reversed (viz.
  // "bdca->aabbcd") since the output axis labels now appear in the
  // input subscripts.
  return Einsum(scope, {broadcasted_grad},
                absl::StrCat(reduced_subs, output_subs, "->", input_subs));
}

// Returns the gradient wrt an input operand for a binary einsum.
//
// This function does not handle (un)broadcasting. This must be done separately
// on the returned gradient.
//
// Args:
//   output_grad: The gradient wrt the output of a binary einsum operation.
//   other_operand: The complementary `Tensor` operand, i.e. the one that is
//     not the input operand.
//   input_shape: A `Tensor` representing the shape of the input operand.
//   input_subs: The subscripts of the input operand.
//   other_subs: The subscripts of the complementary operand.
//   output_subs: The output subscripts.
Output EinsumGradWrt(const Scope& scope, Output output_grad,
                     Output other_operand, Output input_shape,
                     absl::string_view input_subs, absl::string_view other_subs,
                     absl::string_view output_subs) {
  // Claim: For the einsum operation z = einsum("{eq_x},{eq_y}->{eq_z}", x, y),
  //   where the equation involves only Tensor contractions, generalized traces
  //   and transposes, the input gradients are given by the vector-jacobian
  //   products (VJPs):
  //
  //     grad_wrt_x = einsum("{eq_y},{eq_z}->{eq_x}", y, grad_wrt_z)
  //     grad_wrt_y = einsum("{eq_x},{eq_z}->{eq_y}", x, grad_wrt_z)
  //
  //   where grad_wrt_x and grad_wrt_y are the gradients with respect to inputs
  //   x and y and grad_wrt_z is the given gradient with respect to output z.
  //
  // Proof: For unary einsum equations involving only transpose ("ij->ji") and
  //   traces ("ii->i"), the linear mapping's Jacobian at input x is given
  //   by the function itself. We can verify that the linear maps given by the
  //   VJPs are einsums with the equations "ji->ij" and "i->ii" respectively,
  //   where the latter represents 'un-tracing', or filling the diagonal with
  //   the input axis and setting the non-diagonal entries to zero.
  //        Furthermore, recall that matrix multiplication, which is
  //   represented by the equation "ab,bc->ac", has its VJPs given by the
  //   einsum equations "ac,bc->ab" and "ab,ac->bc" (see, for example,
  //   https://math.stackexchange.com/a/2755680). Combined with transposes and
  //   traces we can rewrite Tensor contractions as regular matrix
  //   multiplications. Since each of these operations has its VJP described
  //   by an einsum of the required pattern, the result follows.
  //
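  // For illustration, consider plain matrix multiplication: if x has shape
  // [2,3], y has shape [3,4] and z = einsum("ab,bc->ac", x, y), then
  //   grad_wrt_x = einsum("bc,ac->ab", y, grad_wrt_z)  // shape [2,3]
  //   grad_wrt_y = einsum("ab,ac->bc", x, grad_wrt_z)  // shape [3,4]
  // which are exactly grad_wrt_z * y^T and x^T * grad_wrt_z.
  //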
  // Accordingly, einsum operations, except for those with reductions such as
  // "abc,cd->ad", have their VJPs defined by:
  //   "{output_subs},{other_subs}->{input_subs}".
  //
  // But if there is a reduction, this would lead to the equation "ad,cd->abc"
  // which is invalid because the reduced axis label 'b' is present in the
  // output but not in any of the inputs. Therefore, we compute the VJP in two
  // steps: first we obtain the VJP for "ac,cd->ad" and then we compute the VJP
  // of "abc->ac" or, equivalently, reduce_sum(..., axis=1).
  //
  // Compute the set of input axis labels which appear in neither the
  // output subscripts nor the other operand's subscripts. E.g. the set {'b'}
  // for the equation "abc,cd->ad".
  absl::btree_set<char> reduced_label_set(input_subs.begin(), input_subs.end());
  for (const char x : output_subs) {
    reduced_label_set.erase(x);
  }
  for (const char x : other_subs) {
    reduced_label_set.erase(x);
  }
  reduced_label_set.erase('.');

  // Obtain the input subscripts with the reduced axis labels removed. E.g.
  // "ac" in the above example.
  std::string left_subs;
  for (const char s : input_subs) {
    if (!reduced_label_set.contains(s)) {
      left_subs.push_back(s);
    }
  }

  // Compute the gradient wrt the input, without accounting for the operation
  // "abc->ac". So, now we have the VJP of the operation "ac,cd->ad".
  Output grad_reduced =
      Einsum(scope, {output_grad, other_operand},
             absl::StrCat(output_subs, ",", other_subs, "->", left_subs));

  // If the reduced_label_set is empty, then we already have the gradient
  // wrt the input.
  if (reduced_label_set.empty()) {
    return grad_reduced;
  }
  // Otherwise, we currently have the gradient wrt the output of the reduction
  // operation "abc->ac". Invoke the subroutine for the gradient for unary
  // einsum with reductions.
  return EinsumGradReducedHelper(scope, grad_reduced, left_subs, input_subs,
                                 input_shape, reduced_label_set);
}

Status EinsumGrad(const Scope& scope, const Operation& op,
                  const std::vector<Output>& grad_inputs,
                  std::vector<Output>* grad_outputs) {
  if (grad_inputs.size() != 1) {
    return errors::InvalidArgument("Expected exactly 1 grad input.");
  }
  const Output& grad = grad_inputs[0];

  std::string equation;
  TF_RETURN_IF_ERROR(GetNodeAttr(op.node()->attrs(), "equation", &equation));
  std::vector<absl::string_view> equation_split =
      absl::StrSplit(equation, "->");
  if (equation_split.size() != 2) {
    return errors::InvalidArgument("Equation must contain a single '->'.");
  }

  const absl::string_view input_subs = equation_split[0];
  const absl::string_view output_subs = equation_split[1];
  if (op.num_inputs() == 1) {
    // For the unary einsum z = einsum("{eq_x}->{eq_z}", x), the gradient wrt
    // the input (VJP) is given by the reversed equation:
    //   grad_wrt_x = einsum("{eq_z}->{eq_x}", grad_wrt_z)
    // (See the justification in EinsumGradWrt). This is valid unless there are
    // reduced axis labels; i.e. axis labels appearing in the input but not in
    // the output subscripts.
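    // For illustration, for the transpose equation "ab->ba" there are no
    // reduced labels, so the gradient is simply einsum("ba->ab", grad_wrt_z).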
    auto input_shape = Shape(scope, op.input(0));
    // Find the axis labels which appear only in the input.
    absl::btree_set<char> reduced_label_set(input_subs.begin(),
                                            input_subs.end());
    for (const char x : output_subs) {
      reduced_label_set.erase(x);
    }
    reduced_label_set.erase('.');
    if (reduced_label_set.empty()) {
      grad_outputs->push_back(Einsum(
          scope, grad_inputs, absl::StrCat(output_subs, "->", input_subs)));
      return scope.status();
    }
    // We do have reduced axes, so we invoke the subroutine for reduced unary
    // einsums.
    grad_outputs->push_back(EinsumGradReducedHelper(
        scope, grad, output_subs, input_subs, input_shape, reduced_label_set));
    return scope.status();
  }

  std::vector<absl::string_view> subs = absl::StrSplit(input_subs, ',');
  if (subs.size() != 2) {
    return errors::InvalidArgument("Only 2 inputs are supported");
  }
  std::string x_subs(subs[0]);
  std::string y_subs(subs[1]);
  // Add ellipsis for broadcasted dimensions if any operand does not have it.
  // This is because the equation "...ij,jk->ik" may be valid if the 0th input's
  // batch shape is empty, but the VJP equation "jk,ik->...ij" is not valid
  // because only the output subscripts contain ellipsis.
  if (absl::StrContains(output_subs, kEllipsis)) {
    if (!absl::StrContains(x_subs, kEllipsis)) {
      absl::StrAppend(&x_subs, kEllipsis);
    }
    if (!absl::StrContains(y_subs, kEllipsis)) {
      absl::StrAppend(&y_subs, kEllipsis);
    }
  }
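  // For illustration, given "...ab,bc->...ac", y_subs is extended to "bc..."
  // here so that the VJP equation for y, "...ac,...ab->bc...", is well-formed.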

  // Obtain the gradients wrt the inputs x and y, without taking into account
  // the unbroadcasting.
  tensorflow::Output x = op.input(0);
  tensorflow::Output y = op.input(1);
  if (DataTypeIsComplex(grad.type())) {
    x = Conj(scope, x);
    y = Conj(scope, y);
  }

  const auto x_shape = Shape(scope, x);
  const auto y_shape = Shape(scope, y);
  Output grad_x =
      EinsumGradWrt(scope, grad, y, x_shape, x_subs, y_subs, output_subs);
  Output grad_y =
      EinsumGradWrt(scope, grad, x, y_shape, y_subs, x_subs, output_subs);

  if (!absl::StrContains(output_subs, kEllipsis)) {
    // If there is no ellipsis in the output, there is no need to unbroadcast.
    grad_outputs->push_back(grad_x);
    grad_outputs->push_back(grad_y);
    return scope.status();
  }

  // Below we handle the case where broadcasting between x and y was necessary,
  // with x and y having possibly different batch shapes.

  // Obtain the range of axes which map to ellipsis. E.g. for subscripts
  // 'ab...c' and a shape of rank 10, the range [2:-1] denotes the broadcasted
  // axes.
  int bx_start, by_start;
  absl::optional<int> bx_end, by_end;
  std::tie(bx_start, bx_end) = EinsumGetBcastSubshape(x_subs);
  std::tie(by_start, by_end) = EinsumGetBcastSubshape(y_subs);
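  // For illustration, suppose x_subs is "...ab" with x of shape [5,2,3] and
  // y_subs is "...bc" with y of shape [3,4]; the batch slices are then [5]
  // and [], BroadcastGradientArgs below returns r0 = [] and r1 = [0], and
  // grad_y is summed over axis 0 before being reshaped to y_shape.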

  // Sum the gradient across the broadcasted axes.
  auto args = internal::BroadcastGradientArgs(
      scope, Slice1dHelper(scope, x_shape, bx_start, bx_end),
      Slice1dHelper(scope, y_shape, by_start, by_end));
  grad_x = Reshape(
      scope, ReduceSum(scope, grad_x, Add(scope, bx_start, args.r0)), x_shape);
  grad_y = Reshape(
      scope, ReduceSum(scope, grad_y, Add(scope, by_start, args.r1)), y_shape);
  grad_outputs->push_back(grad_x);
  grad_outputs->push_back(grad_y);
  return scope.status();
}

REGISTER_GRADIENT_OP("Einsum", EinsumGrad);

}  // namespace
}  // namespace ops
}  // namespace tensorflow