1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/compiler/mlir/quantization/tensorflow/utils/lift_as_function_call_utils.h"
16 
17 #include <queue>
18 #include <stack>
19 #include <string>
20 
21 #include "llvm/ADT/SmallVector.h"
22 #include "llvm/ADT/StringRef.h"
23 #include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
24 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h"
25 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
26 
27 namespace mlir {
28 namespace quant {
29 
30 // Checks if the op is inside a lifted function.
IsInLiftedFunc(Operation * op)31 bool IsInLiftedFunc(Operation *op) {
32   return op->getParentOfType<func::FuncOp>()->hasAttr(kFusedFunctionAttr);
33 }
34 
35 // Inserts the function to the symbol table of the module thread-safely.
InsertToSymbolTable(Operation * module,Operation * function,const std::string & func_name)36 StringAttr InsertToSymbolTable(Operation *module, Operation *function,
37                                const std::string &func_name) {
38   static tensorflow::mutex *mtx = new tensorflow::mutex();
39   tensorflow::mutex_lock lock(*mtx);
40 
41   SymbolTable symbol_table(module);
42   std::string unique_name = func_name;
43   int32_t uniquing_counter = 0;
44   while (symbol_table.lookup(unique_name) != nullptr) {
45     ++uniquing_counter;
46     unique_name = func_name + "_" + std::to_string(uniquing_counter);
47   }
48   function->setAttr("sym_name",
49                     StringAttr::get(module->getContext(), unique_name));
50   return symbol_table.insert(function);
51 }
52 
createFusedFnCall(OpBuilder builder,Location location,StringRef func_name,TypeRange output_types,ValueRange args)53 ValueRange createFusedFnCall(OpBuilder builder, Location location,
54                              StringRef func_name, TypeRange output_types,
55                              ValueRange args) {
56   TF::PartitionedCallOp call_op = builder.create<TF::PartitionedCallOp>(
57       location, output_types, args,
58       FlatSymbolRefAttr::get(builder.getStringAttr(func_name)),
59       /*config=*/"", /*config_proto=*/"", /*executor_type=*/"");
60   call_op->setAttr(
61       kQuantTraitAttrName,
62       builder.getStringAttr(llvm::StringRef(
63           std::string(QuantTraitValues[QuantizationTrait::FullyQuantizable]))));
64 
65   return call_op.output();
66 }
67 
68 // Finds ops in the paths from arguments to results. The ops is listed in an
69 // order that the former ops shouldn't have any dependencies on the later ones.
FindOpsFromArgumentsToResults(const llvm::SmallVector<Value> & arguments,const llvm::SmallVector<Value> & results)70 llvm::SmallVector<Operation *> FindOpsFromArgumentsToResults(
71     const llvm::SmallVector<Value> &arguments,
72     const llvm::SmallVector<Value> &results) {
73   std::queue<Value> value_queue;
74   for (Value result : results) {
75     value_queue.push(result);
76   }
77   absl::flat_hash_set<mlir::detail::ValueImpl *> argument_set;
78   for (Value argument : arguments) {
79     argument_set.insert(argument.getImpl());
80   }
81 
82   // Searching for ops from results to arguments. Duplicate ops in the op stack
83   // are intentional in order to make sure the op on the top of the stack
84   // doesn't depends on any ops below it.
85   std::stack<Operation *> op_stack;
86   while (!value_queue.empty()) {
87     Value current_value = value_queue.front();
88     value_queue.pop();
89 
90     Operation *defining_node = current_value.getDefiningOp();
91     if (defining_node == nullptr) continue;
92     op_stack.push(defining_node);
93     for (const auto &arg : defining_node->getOperands()) {
94       if (!argument_set.contains(arg.getImpl())) {
95         value_queue.push(arg);
96       }
97     }
98   }
99 
100   // Remove duplicate ops from the op stack.
101   llvm::SmallVector<Operation *> sorted_ops;
102   absl::flat_hash_set<Operation *> unique_ops;
103   while (!op_stack.empty()) {
104     Operation *current_op = op_stack.top();
105     op_stack.pop();
106     if (unique_ops.contains(current_op)) continue;
107     sorted_ops.push_back(current_op);
108     unique_ops.insert(current_op);
109   }
110   return sorted_ops;
111 }
112 
113 // Finds the name of each attribute in `attributes` and set the attr_map
114 // attribute which maps an attribute identifier to its attribute name. The
115 // identifier is the order of that attribute in `attributes`. This map
116 // is then used to set attributes in the quantized functions in the
117 // QuantizeCompositeFunctionsPass.
118 // For example, for tf.MatMul with `attributes` = {{"transpose_a", false},
119 // {"transpose_b", false}}, the generated attr_map is
120 // "0:transpose_a,1:transpose_b", where 0 and 1 are the respective attribute
121 // identifiers.
122 // This function returns success if all attributes could be found.
SetAttributeMap(MLIRContext * context,const llvm::SmallVector<NamedAttribute> & attributes,const llvm::SmallVector<Operation * > & ops)123 LogicalResult SetAttributeMap(
124     MLIRContext *context, const llvm::SmallVector<NamedAttribute> &attributes,
125     const llvm::SmallVector<Operation *> &ops) {
126   // A map to find which operation an attribute belongs to.
127   // The key for this map uses the entire NamedAttribute object, i.e. the
128   // {attribute_name, attribute_value} pair.
129   llvm::SmallDenseMap<NamedAttribute, Operation *> attr_to_op_map;
130   for (Operation *op : ops) {
131     for (const auto &named_attr : op->getAttrs()) {
132       attr_to_op_map.insert({named_attr, op});
133     }
134   }
135 
136   for (int idx : llvm::seq<int>(0, attributes.size())) {
137     const NamedAttribute &attribute = attributes[idx];
138 
139     // Skip the following steps if the attribute value is `NullAttribute`.
140     if (const auto string_attr =
141             attribute.getValue().dyn_cast_or_null<StringAttr>();
142         string_attr != nullptr &&
143         string_attr.getValue().equals(kNullAttributeValue)) {
144       continue;
145     }
146 
147     if (attr_to_op_map.count(attribute) == 0) {
148       mlir::emitError(UnknownLoc::get(context),
149                       "Could not find attribute: " + attribute.getName().str());
150       return failure();
151     }
152 
153     Operation *owner_op = attr_to_op_map[attribute];
154 
155     std::string new_attr_map_str{};
156     if (owner_op->hasAttr(kAttrMapAttribute)) {
157       new_attr_map_str =
158           owner_op->getAttrOfType<StringAttr>(kAttrMapAttribute).str();
159       absl::StrAppend(&new_attr_map_str, ",");
160     }
161 
162     // Append "<identifier>:<attribute_name>". Ex) "0:transpose_a".
163     const std::string identifier = std::to_string(idx);
164     const mlir::StringAttr attribute_name = attribute.getName();
165     absl::StrAppend(&new_attr_map_str, identifier, ":", attribute_name.str());
166     owner_op->setAttr(kAttrMapAttribute,
167                       StringAttr::get(context, new_attr_map_str));
168   }
169   return success();
170 }
171 
172 // Creates a function to wrap the section between arguments and results.
LiftAsFunctionCall(OpBuilder builder,Location location,StringRef func_name,const llvm::SmallVector<Value> & arguments,const llvm::SmallVector<Value> & results,const llvm::SmallVector<NamedAttribute> & attributes)173 llvm::SmallVector<Value, 4> LiftAsFunctionCall(
174     OpBuilder builder, Location location, StringRef func_name,
175     const llvm::SmallVector<Value> &arguments,
176     const llvm::SmallVector<Value> &results,
177     const llvm::SmallVector<NamedAttribute> &attributes) {
178   MLIRContext *context = builder.getContext();
179   if (results.empty()) {
180     mlir::emitError(UnknownLoc::get(context), "No result values specified");
181     return {};
182   }
183   Operation *result_op = results[0].getDefiningOp();
184   auto module = result_op->getParentOfType<ModuleOp>();
185 
186   // Create a private function and copy all ops between arguments and results.
187   auto current_func = result_op->getParentOfType<func::FuncOp>();
188   auto guard = OpBuilder::InsertionGuard(builder);
189   builder.setInsertionPointAfter(current_func);
190   TypeRange arg_types{ValueRange{arguments}};
191   TypeRange result_types{ValueRange{results}};
192   auto func_type = FunctionType::get(context, arg_types, result_types);
193 
194   llvm::SmallVector<Location> arg_locs;
195   for (const auto &arg : arguments) {
196     arg_locs.push_back(arg.getLoc());
197   }
198   auto wrap_func = builder.create<func::FuncOp>(location, func_name, func_type);
199   wrap_func.setVisibility(SymbolTable::Visibility::Private);
200   wrap_func->setAttr(kFusedFunctionAttr, builder.getUnitAttr());
201   builder.createBlock(&wrap_func.getBody(), wrap_func.begin(), arg_types,
202                       arg_locs);
203 
204   BlockAndValueMapping mapping;
205   for (int32_t i : llvm::seq<int32_t>(0, arguments.size())) {
206     mapping.map(arguments[i], wrap_func.getArgument(i));
207   }
208 
209   auto cloning_ops = FindOpsFromArgumentsToResults(arguments, results);
210   if (failed(SetAttributeMap(context, attributes, cloning_ops))) {
211     current_func.emitError() << "Some attributes couldn't be found.";
212   }
213   for (Operation *op : cloning_ops) {
214     builder.clone(*op, mapping);
215   }
216 
217   llvm::SmallVector<Value> return_values;
218   for (Value result : results) {
219     return_values.push_back(mapping.lookupOrNull(result));
220   }
221   builder.create<mlir::func::ReturnOp>(location, return_values);
222 
223   // Create a function call to the newly created function.
224   StringAttr new_func_name =
225       InsertToSymbolTable(module, wrap_func, func_name.str());
226   builder.setInsertionPointAfter(result_op);
227   ValueRange new_results = createFusedFnCall(
228       builder, location, new_func_name.getValue(), result_types, arguments);
229   return llvm::SmallVector<Value, 4>(new_results.begin(), new_results.end());
230 }
231 
LiftAsFunctionCall(OpBuilder builder,Location location,StringRef func_name,const llvm::SmallVector<Value> & arguments,const llvm::SmallVector<Value> & results)232 llvm::SmallVector<Value, 4> LiftAsFunctionCall(
233     OpBuilder builder, Location location, StringRef func_name,
234     const llvm::SmallVector<Value> &arguments,
235     const llvm::SmallVector<Value> &results) {
236   llvm::SmallVector<NamedAttribute> attributes;
237   return LiftAsFunctionCall(builder, location, func_name, arguments, results,
238                             attributes);
239 }
240 
241 }  // namespace quant
242 }  // namespace mlir
243