xref: /aosp_15_r20/external/tensorflow/tensorflow/core/transforms/graph_compactor/pass.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/transforms/graph_compactor/pass.h"
17 
18 #include <iterator>
19 #include <memory>
20 #include <string>
21 #include <utility>
22 
23 #include "llvm/ADT/BitVector.h"
24 #include "mlir/IR/Builders.h"  // from @llvm-project
25 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
26 #include "mlir/IR/OperationSupport.h"  // from @llvm-project
27 #include "mlir/Pass/Pass.h"  // from @llvm-project
28 #include "mlir/Support/LLVM.h"  // from @llvm-project
29 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
30 #include "tensorflow/core/framework/op.h"
31 #include "tensorflow/core/framework/op_def_builder.h"
32 #include "tensorflow/core/ir/dialect.h"
33 #include "tensorflow/core/ir/importexport/convert_attributes.h"
34 #include "tensorflow/core/ir/interfaces.h"
35 #include "tensorflow/core/ir/ops.h"
36 #include "tensorflow/core/ir/tf_op_registry.h"
37 #include "tensorflow/core/platform/statusor.h"
38 #include "tensorflow/core/transforms/pass_detail.h"
39 
40 namespace mlir {
41 namespace tfg {
42 
43 // Encode an unsigned integer in as few characters as possible to a string that
44 // is still a valid TensorFlow node name. The regex for valid names, according
45 // to `NodeDef`, is "[A-Za-z0-9.][A-Za-z0-9_>./]*"
46 //
47 // The valid characters are provided in the two arrays `first_valid_chars` and
48 // `trailing_valid_chars`.
EncodeName(unsigned counter,std::string & output,ArrayRef<char> first_valid_chars,ArrayRef<char> trailing_valid_chars)49 static void EncodeName(unsigned counter, std::string &output,
50                        ArrayRef<char> first_valid_chars,
51                        ArrayRef<char> trailing_valid_chars) {
52   assert(!first_valid_chars.empty() && !trailing_valid_chars.empty());
53   unsigned rem = counter % first_valid_chars.size();
54   counter /= first_valid_chars.size();
55   output.push_back(first_valid_chars[rem]);
56   while (counter > 0) {
57     --counter;
58     rem = counter % trailing_valid_chars.size();
59     counter /= trailing_valid_chars.size();
60     output.push_back(trailing_valid_chars[rem]);
61   }
62 }
63 
64 // Encode an unsigned integer to a valid TensorFlow node name.
EncodeName(unsigned counter,std::string & output)65 static void EncodeName(unsigned counter, std::string &output) {
66   // The alphabet of valid characters, but the last 3 are only valid in trailing
67   // characters.
68   static constexpr char valid_chars[] =
69       "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789._>/";
70   // Sanity check: all alphanumeric characters, four special characters, and a
71   // null terminator.
72   constexpr unsigned valid_first_chars = 26 * 2 + 10 + 1;
73   constexpr unsigned valid_trailing_chars = valid_first_chars + 3;
74   static_assert(sizeof(valid_chars) == valid_trailing_chars + 1,
75                 "alphabet sanity check");
76   EncodeName(counter, output,
77              llvm::makeArrayRef(valid_chars, valid_first_chars),
78              llvm::makeArrayRef(valid_chars, valid_trailing_chars));
79 }
80 
81 namespace {
82 class NameCompressPass : public NameCompressBase<NameCompressPass> {
83  public:
initialize(MLIRContext * context)84   LogicalResult initialize(MLIRContext *context) override {
85     dialect_ = context->getOrLoadDialect<TFGraphDialect>();
86     empty_dict_ = DictionaryAttr::get(context);
87     return success();
88   }
89 
runOnOperation()90   void runOnOperation() override {
91     GraphFuncOp func = getOperation();
92 
93     Builder b(&getContext());
94     unsigned counter = 0;
95     std::string name;
96     const auto encode_new_name = [&name, &b, &counter] {
97       name.clear();
98       EncodeName(counter++, name);
99       return b.getStringAttr(name);
100     };
101 
102     // Rename the arguments and results.
103     NamedAttrList attrs = func->getAttrDictionary();
104     if (func.getNumArguments()) {
105       assert(func.arg_attrs().has_value() && "expected argument attributes");
106       SmallVector<Attribute> arg_attrs;
107       arg_attrs.reserve(func.getNumArguments());
108       // Iterate over the function arguments, skipping the control tokens.
109       for (int i = 0, e = func.getNumArguments(); i != e; i += 2) {
110         NamedAttrList attrs = func.arg_attrsAttr()[i].cast<DictionaryAttr>();
111         attrs.set(dialect_->getTfgNameAttrIdentifier(), encode_new_name());
112         arg_attrs.append({attrs.getDictionary(&getContext()), empty_dict_});
113       }
114       attrs.set(func.arg_attrsAttrName(), b.getArrayAttr(arg_attrs));
115     }
116     if (func.getNumResults()) {
117       assert(func.res_attrs().has_value() && "expected result attributes");
118       SmallVector<Attribute> res_attrs;
119       res_attrs.reserve(func.getNumResults());
120       for (NamedAttrList attrs :
121            func.res_attrsAttr().getAsRange<DictionaryAttr>()) {
122         attrs.set(dialect_->getTfgNameAttrIdentifier(), encode_new_name());
123         res_attrs.push_back(attrs.getDictionary(&getContext()));
124       }
125       attrs.set(func.res_attrsAttrName(), b.getArrayAttr(res_attrs));
126     }
127     if (func.getNumArguments() || func.getNumResults()) {
128       func->setAttrs(attrs.getDictionary(&getContext()));
129     }
130 
131     // Rename the control results.
132     ReturnOp terminator = cast<ReturnOp>(func.getBody()->getTerminator());
133     ArrayAttr control_attrs = terminator.control_ret_attrs();
134     if (!attrs.empty()) {
135       SmallVector<Attribute> control_ret_attrs;
136       control_ret_attrs.reserve(control_attrs.size());
137       for (NamedAttrList attrs : control_attrs.getAsRange<DictionaryAttr>()) {
138         attrs.set(dialect_->getTfgNameAttrIdentifier(), encode_new_name());
139         control_ret_attrs.push_back(attrs.getDictionary(&getContext()));
140       }
141       terminator.control_ret_attrsAttr(b.getArrayAttr(control_ret_attrs));
142     }
143 
144     // Rename all non-intrisic operations.
145     func.walk([this, &encode_new_name](Operation *op) {
146       if (op->hasTrait<OpTrait::IntrinsicOperation>()) return;
147       op->setAttr(dialect_->getNameAttrIdentifier(), encode_new_name());
148     });
149   }
150 
151  private:
152   // An instance of the TFG dialect for accessing cached identifiers.
153   TFGraphDialect *dialect_;
154   // An instance of the empty dictionary attribute.
155   DictionaryAttr empty_dict_;
156 };
157 }  // namespace
158 
CreateNameCompressPass()159 std::unique_ptr<Pass> CreateNameCompressPass() {
160   return std::make_unique<NameCompressPass>();
161 }
162 
163 namespace {
164 class StripDefaultAttrsPass
165     : public StripDefaultAttrsBase<StripDefaultAttrsPass> {
166  public:
initialize(MLIRContext * context)167   LogicalResult initialize(MLIRContext *context) override {
168     // Initialize the pass by getting a registered instance of the TensorFlow
169     // operation registry. If no instance was registered, this pass will fail.
170     dialect_ = context->getOrLoadDialect<TFGraphDialect>();
171     registry_ = nullptr;
172     if (auto registry_interface =
173             dialect_->getRegisteredInterface<TensorFlowOpRegistryInterface>()) {
174       registry_ = registry_interface->GetRegistry();
175     }
176     return success(registry_);
177   }
178 
runOnOperation()179   void runOnOperation() override {
180     WalkResult result = getOperation()->walk([&](Operation *op) {
181       // Ignore intrinsic operations.
182       if (op->hasTrait<OpTrait::IntrinsicOperation>())
183         return WalkResult::advance();
184 
185       // If removing default-valued attributes failed (attribute conversion
186       // error), bail out.
187       if (failed(removeDefaultValuedAttrs(op))) return WalkResult::interrupt();
188 
189       return WalkResult::advance();
190     });
191 
192     // If the pass failed on any operation, signal failure.
193     if (result.wasInterrupted()) return signalPassFailure();
194   }
195 
196  private:
197   // Remove attributes from the operation equal to their default values
198   // according to the TensorFlow op registry.
199   LogicalResult removeDefaultValuedAttrs(Operation *op);
200 
201   // The TFG dialect instance.
202   TFGraphDialect *dialect_;
203   // The TensorFlow op registry to query for default-valued attributes.
204   const tensorflow::OpRegistry *registry_;
205 };
206 }  // namespace
207 
removeDefaultValuedAttrs(Operation * op)208 LogicalResult StripDefaultAttrsPass::removeDefaultValuedAttrs(Operation *op) {
209   const tensorflow::OpRegistrationData *op_reg_data =
210       registry_->LookUp(op->getName().stripDialect().str());
211   // Ignore unregistered ops.
212   if (!op_reg_data) return success();
213 
214   // Find the attributes to remove.
215   ArrayRef<NamedAttribute> attrs = op->getAttrs();
216   llvm::BitVector indices_to_remove(attrs.size());
217   Builder b(&getContext());
218   for (const tensorflow::OpDef::AttrDef &attr : op_reg_data->op_def.attr()) {
219     // Ignore attributes without default values.
220     if (!attr.has_default_value()) continue;
221     auto it = impl::findAttrSorted(attrs.begin(), attrs.end(), attr.name());
222     // Ignore default-valued attributes that are already missing.
223     if (!it.second) continue;
224     // Convert the TensorFlow attribute value and compare it to the MLIR
225     // attribute.
226     tensorflow::StatusOr<Attribute> maybe_attr =
227         ConvertAttributeValue(attr.default_value(), b);
228     if (!maybe_attr.ok())
229       return op->emitError(maybe_attr.status().error_message());
230     if (maybe_attr.ValueOrDie() == it.first->getValue())
231       indices_to_remove.set(std::distance(attrs.begin(), it.first));
232   }
233   if (indices_to_remove.none()) return success();
234 
235   // Construct and set the new attributes.
236   SmallVector<NamedAttribute> new_attrs;
237   new_attrs.reserve(attrs.size());
238   for (auto &it : llvm::enumerate(attrs)) {
239     if (indices_to_remove.test(it.index())) continue;
240     new_attrs.push_back(it.value());
241   }
242   op->setAttrs(DictionaryAttr::getWithSorted(&getContext(), new_attrs));
243 
244   return success();
245 }
246 
CreateStripDefaultAttrsPass()247 std::unique_ptr<Pass> CreateStripDefaultAttrsPass() {
248   return std::make_unique<StripDefaultAttrsPass>();
249 }
250 
251 namespace {
252 class AddDefaultAttrsPass : public AddDefaultAttrsBase<AddDefaultAttrsPass> {
253  public:
initialize(MLIRContext * context)254   LogicalResult initialize(MLIRContext *context) override {
255     // Initialize the pass by getting a registered instance of the TensorFlow
256     // operation registry. If no instance was registered, this pass will fail.
257     dialect_ = context->getOrLoadDialect<TFGraphDialect>();
258     registry_ = nullptr;
259     if (auto registry_interface =
260             dialect_->getRegisteredInterface<TensorFlowOpRegistryInterface>()) {
261       registry_ = registry_interface->GetRegistry();
262     }
263     return success(registry_);
264   }
265 
runOnOperation()266   void runOnOperation() override {
267     WalkResult result = getOperation()->walk([&](Operation *op) {
268       // Ignore intrinsic operations.
269       if (op->hasTrait<OpTrait::IntrinsicOperation>())
270         return WalkResult::advance();
271 
272       // If removing default-valued attributes failed (attribute conversion
273       // error), bail out.
274       if (failed(addDefaultValuedAttrs(op))) return WalkResult::interrupt();
275 
276       return WalkResult::advance();
277     });
278 
279     // If the pass failed on any operation, signal failure.
280     if (result.wasInterrupted()) return signalPassFailure();
281   }
282 
283  private:
284   // Remove attributes from the operation equal to their default values
285   // according to the TensorFlow op registry.
286   LogicalResult addDefaultValuedAttrs(Operation *op);
287 
288   // The TFG dialect instance.
289   TFGraphDialect *dialect_;
290   // The TensorFlow op registry to query for default-valued attributes.
291   const tensorflow::OpRegistry *registry_;
292 };
293 }  // namespace
294 
addDefaultValuedAttrs(Operation * op)295 LogicalResult AddDefaultAttrsPass::addDefaultValuedAttrs(Operation *op) {
296   const tensorflow::OpRegistrationData *op_reg_data =
297       registry_->LookUp(op->getName().stripDialect().str());
298   // Ignore unregistered ops.
299   if (!op_reg_data) return success();
300 
301   // Ignore operations with no default-valued attributes.
302   if (llvm::all_of(op_reg_data->op_def.attr(),
303                    [](const auto &attr) { return !attr.has_default_value(); }))
304     return success();
305 
306   // Add missing default-valued attributes
307   Builder b(&getContext());
308   NamedAttrList attrs = op->getAttrDictionary();
309   for (const auto &attr : op_reg_data->op_def.attr()) {
310     // Ignore attributes without default values.
311     if (!attr.has_default_value()) continue;
312     // Ignore default-valued attributes that are present.
313     if (attrs.get(attr.name())) continue;
314     // Convert the TensorFlow attribute value and set it.
315     tensorflow::StatusOr<Attribute> maybe_attr =
316         ConvertAttributeValue(attr.default_value(), b);
317     if (!maybe_attr.ok())
318       return op->emitError(maybe_attr.status().error_message());
319     attrs.set(attr.name(), maybe_attr.ValueOrDie());
320   }
321   op->setAttrs(attrs.getDictionary(&getContext()));
322 
323   return success();
324 }
325 
CreateAddDefaultAttrsPass()326 std::unique_ptr<Pass> CreateAddDefaultAttrsPass() {
327   return std::make_unique<AddDefaultAttrsPass>();
328 }
329 
330 }  // namespace tfg
331 }  // namespace mlir
332