xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/mlir/tfr/utils/utils.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/mlir/tfr/utils/utils.h"
17 
18 #include "llvm/ADT/ArrayRef.h"
19 #include "llvm/ADT/StringRef.h"
20 #include "llvm/ADT/StringSet.h"
21 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
22 #include "tensorflow/compiler/mlir/tfr/ir/tfr_ops.h"
23 
24 namespace mlir {
25 namespace TFR {
26 namespace {
27 
28 // TODO(b/174692018): Use the official allowlist of the unregistered attrs.
GetAllowedAttributes()29 const llvm::StringSet<>& GetAllowedAttributes() {
30   static auto* const ops = new llvm::StringSet<>({"device", "_tpu_replicate"});
31   return *ops;
32 }
33 
34 // Some TFL optional attributes may not appear in their corresponding TF op
35 // attributes.
GetOptionalAttributes()36 const llvm::StringSet<>& GetOptionalAttributes() {
37   static auto* const ops =
38       new llvm::StringSet<>({"asymmetric_quantize_inputs"});
39   return *ops;
40 }
41 
CollectAllowedAttrs(CallOp src,NamedAttrList * attrs)42 void CollectAllowedAttrs(CallOp src, NamedAttrList* attrs) {
43   for (auto& attr : src->getAttrs()) {
44     if (GetAllowedAttributes().contains(attr.getName().strref())) {
45       attrs->append(attr);
46     }
47   }
48 }
49 
50 // Adds `attrs` to all the operations between `begin` and `end` in the same
51 // block. Does not include `end`.
AddAttributesInSameBlock(Block::iterator begin,Block::iterator end,const NamedAttrList & attrs)52 void AddAttributesInSameBlock(Block::iterator begin, Block::iterator end,
53                               const NamedAttrList& attrs) {
54   for (Block::iterator it = begin; it != end; ++it) {
55     for (auto& attr : attrs) {
56       it->setAttr(attr.getName(), attr.getValue());
57     }
58   }
59 }
60 
61 // Adds `attrs` to all the operations between `begin` and `end`. Does not
62 // include `end`. The operations might be across multiple  blocks.
AddAttributes(Block::iterator begin,Block::iterator end,const NamedAttrList & attrs)63 void AddAttributes(Block::iterator begin, Block::iterator end,
64                    const NamedAttrList& attrs) {
65   if (begin->getBlock() == end->getBlock()) {
66     AddAttributesInSameBlock(begin, end, attrs);
67   } else {
68     Region::iterator begin_block = Region::iterator(begin->getBlock());
69     Region::iterator end_block = Region::iterator(end->getBlock());
70     AddAttributesInSameBlock(begin, begin_block->end(), attrs);
71     for (Region::iterator it = ++begin_block; it != end_block; ++it) {
72       AddAttributesInSameBlock(it->begin(), it->end(), attrs);
73     }
74   }
75 }
76 
77 }  // namespace
78 
GetComposeFuncName(StringRef tf_op_name)79 std::string GetComposeFuncName(StringRef tf_op_name) {
80   std::string compose_func_name;
81   for (int i = 0; i < tf_op_name.size(); ++i) {
82     if (tf_op_name[i] == '_') {
83       // The field name must not contain "_"s. "_Arg" and "_RetVal" are special
84       // op names and we can return empty string to skip the decomposition.
85       return {};
86     }
87     if (tf_op_name[i] == '.') {
88       compose_func_name.push_back('_');
89     } else if (tf_op_name[i] >= 'A' && tf_op_name[i] <= 'Z') {
90       compose_func_name.push_back('_');
91       compose_func_name.push_back(tf_op_name[i] + 'a' - 'A');
92     } else {
93       compose_func_name.push_back(tf_op_name[i]);
94     }
95   }
96   return compose_func_name;
97 }
98 
GetTFOpName(StringRef compose_func_name)99 std::string GetTFOpName(StringRef compose_func_name) {
100   std::string tf_op_name;
101   bool after_underscore = false;
102   for (int i = 0; i < compose_func_name.size(); ++i) {
103     if (compose_func_name[i] >= 'A' && compose_func_name[i] <= 'Z') {
104       // The field name must not contain uppercase letters.
105       return {};
106     }
107     if (after_underscore) {
108       if (compose_func_name[i] >= 'a' && compose_func_name[i] <= 'z') {
109         tf_op_name.push_back(compose_func_name[i] + 'A' - 'a');
110         after_underscore = false;
111       } else {
112         // The character after a "_" must be a lowercase letter.
113         return {};
114       }
115     } else if (compose_func_name[i] == '_') {  // first time visit '_'
116       if (i + 1 < compose_func_name.size() && compose_func_name[i + 1] == '_') {
117         tf_op_name.push_back('.');
118         i++;
119       }
120       after_underscore = true;
121     } else {
122       tf_op_name.push_back(compose_func_name[i]);
123     }
124   }
125   if (after_underscore) {
126     // Trailing "_".
127     return {};
128   }
129   return tf_op_name;
130 }
131 
ValidateAttrs(Operation * src,const StringSet<> & registered)132 LogicalResult ValidateAttrs(Operation* src, const StringSet<>& registered) {
133   for (auto& attr : src->getAttrs()) {
134     StringRef attr_name = attr.getName().strref();
135 
136     if (!registered.contains(attr_name) &&
137         !(GetAllowedAttributes().contains(attr_name) ||
138           GetOptionalAttributes().contains(attr_name))) {
139       src->emitError("Denied unregistered attribute was found: " + attr_name);
140       return failure();
141     }
142   }
143   return success();
144 }
145 
CopyAllowedUnregisteredAttrs(Operation * src,CallOp dst,const StringSet<> & registered)146 LogicalResult CopyAllowedUnregisteredAttrs(Operation* src, CallOp dst,
147                                            const StringSet<>& registered) {
148   for (auto& attr : src->getAttrs()) {
149     StringRef attr_name = attr.getName().strref();
150     // Skip the registered or optional attribute.
151     if (registered.contains(attr_name) ||
152         GetOptionalAttributes().contains(attr_name))
153       continue;
154 
155     // Unregistered attribute.
156     if (GetAllowedAttributes().contains(attr_name)) {
157       dst->setAttr(attr.getName(), attr.getValue());
158     } else {
159       src->emitError("Denied unregistered attribute was found: " + attr_name);
160       return failure();
161     }
162   }
163   return success();
164 }
165 
CopyNonSymbolRefAttrs(CallOp src,Operation * dst)166 LogicalResult CopyNonSymbolRefAttrs(CallOp src, Operation* dst) {
167   NamedAttrList attrs;
168   CollectAllowedAttrs(src, &attrs);
169 
170   for (auto& attr : attrs) {
171     dst->setAttr(attr.getName(), attr.getValue());
172   }
173 
174   return success();
175 }
176 
PropagateAttrsToOperations(CallOp src,Block::iterator begin,Block::iterator end)177 void PropagateAttrsToOperations(CallOp src, Block::iterator begin,
178                                 Block::iterator end) {
179   // Find all the attributes in the call op. These attributes are not in the
180   // op definition, so needs to be propagated to all the target ops.
181   NamedAttrList attrs;
182   CollectAllowedAttrs(src, &attrs);
183 
184   // Add all the attributes to the operations in the range.
185   if (!attrs.empty()) {
186     AddAttributes(begin, end, attrs);
187   }
188 }
189 
190 }  // namespace TFR
191 }  // namespace mlir
192