xref: /aosp_15_r20/external/tensorflow/tensorflow/core/util/einsum_op_util.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/util/einsum_op_util.h"
17 
18 #include <string>
19 
20 #include "absl/container/flat_hash_map.h"
21 #include "absl/strings/str_split.h"
22 #include "tensorflow/core/lib/core/errors.h"
23 #include "tensorflow/core/lib/core/status.h"
24 #include "tensorflow/core/lib/gtl/inlined_vector.h"
25 
26 namespace tensorflow {
27 
ValidateEinsumEquation(const string & equation,gtl::InlinedVector<string,2> * input_subscripts,string * output_subscript)28 Status ValidateEinsumEquation(const string& equation,
29                               gtl::InlinedVector<string, 2>* input_subscripts,
30                               string* output_subscript) {
31   gtl::InlinedVector<string, 2> inputs_and_output_subscripts =
32       absl::StrSplit(equation, "->");
33   if (inputs_and_output_subscripts.size() != 2) {
34     return errors::InvalidArgument(
35         "Expecting exactly one '->' in einsum equation: ", equation);
36   }
37   *output_subscript = std::move(inputs_and_output_subscripts[1]);
38   *input_subscripts =
39       absl::StrSplit(std::move(inputs_and_output_subscripts[0]), ',');
40   if (input_subscripts->size() != 1 && input_subscripts->size() != 2) {
41     return errors::InvalidArgument(
42         "Expecting 1 or 2 input subscripts in equation '", equation,
43         "' but got: ", input_subscripts->size());
44   }
45   return OkStatus();
46 }
47 
48 // Returns the EinsumDimensionType given whether the corresponding label is
49 // present in exactly one input subscript (is_unique) and whether it is absent
50 // from the output subscripts (is_removed). Does not handle broadcasting
51 // dimensions.
GetDimensionType(bool is_removed,bool is_unique)52 EinsumDimensionType GetDimensionType(bool is_removed, bool is_unique) {
53   if (!is_removed && !is_unique)
54     return kBatch;
55   else if (!is_removed && is_unique)
56     return kFree;
57   else if (is_removed && !is_unique)
58     return kContract;
59   else  // is_removed && is_unique
60     return kReduce;
61 }
62 
63 // Maps the character labels to consecutive integers.
MapToLabels(const string & subscript,Labels * labels,absl::flat_hash_map<char,int> * label_mapping)64 void MapToLabels(const string& subscript, Labels* labels,
65                  absl::flat_hash_map<char, int>* label_mapping) {
66   for (int i = 0; i < subscript.size(); ++i) {
67     const char label_char = subscript[i];
68     if (label_char == '.') {
69       labels->push_back(kEllipsisLabel);
70       i += 2;  // Skip next 2 characters as well.
71       continue;
72     }
73     if (!label_mapping->contains(label_char)) {
74       const int next_label = label_mapping->size();
75       (*label_mapping)[label_char] = next_label;
76     }
77     const int mapped_label = (*label_mapping)[label_char];
78     labels->push_back(mapped_label);
79   }
80 }
81 
ParseEinsumEquation(const string & equation,OperandLabels * input_labels,Labels * output_labels,std::vector<EinsumDimensionType> * label_types,OperandLabelCounts * input_label_counts,LabelCounts * output_label_counts,gtl::InlinedVector<bool,2> * input_has_ellipsis,bool * output_has_ellipsis)82 Status ParseEinsumEquation(const string& equation, OperandLabels* input_labels,
83                            Labels* output_labels,
84                            std::vector<EinsumDimensionType>* label_types,
85                            OperandLabelCounts* input_label_counts,
86                            LabelCounts* output_label_counts,
87                            gtl::InlinedVector<bool, 2>* input_has_ellipsis,
88                            bool* output_has_ellipsis) {
89   gtl::InlinedVector<string, 2> input_str;
90   string output_str;
91   TF_RETURN_IF_ERROR(ValidateEinsumEquation(equation, &input_str, &output_str));
92 
93   // Temporary map from single character labels to (consecutive) integer labels.
94   absl::flat_hash_map<char, int> label_mapping;
95   int num_inputs = input_str.size();
96   input_labels->resize(num_inputs);
97 
98   // Map from single characters to integer labels.
99   for (int i = 0; i < num_inputs; ++i) {
100     MapToLabels(input_str[i], &input_labels->at(i), &label_mapping);
101   }
102   MapToLabels(output_str, output_labels, &label_mapping);
103 
104   // Compute counts for input and output labels.
105   int num_labels = label_mapping.size();
106   input_label_counts->resize(num_inputs);
107   input_has_ellipsis->resize(num_inputs);
108   for (int i = 0; i < num_inputs; ++i) {
109     input_label_counts->at(i).resize(num_labels);
110     input_has_ellipsis->at(i) = false;
111     for (const int label : input_labels->at(i)) {
112       if (label != kEllipsisLabel)
113         input_label_counts->at(i)[label] += 1;
114       else
115         input_has_ellipsis->at(i) = true;
116     }
117   }
118   output_label_counts->resize(num_labels);
119   *output_has_ellipsis = false;
120   for (const int label : *output_labels) {
121     if (label != kEllipsisLabel)
122       output_label_counts->at(label) += 1;
123     else
124       *output_has_ellipsis = true;
125   }
126 
127   // Map each label to a unique EinsumDimensionType.
128   label_types->resize(num_labels);
129   for (int label = 0; label < num_labels; ++label) {
130     if (label == kEllipsisLabel) continue;
131     bool removed = (*output_label_counts)[label] == 0;
132     bool unique = num_inputs == 1 || (*input_label_counts)[0][label] == 0 ||
133                   (*input_label_counts)[1][label] == 0;
134     (*label_types)[label] = GetDimensionType(removed, unique);
135   }
136   return OkStatus();
137 }
138 
139 }  // namespace tensorflow
140