1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/compiler/mlir/quantization/tensorflow/utils/lift_as_function_call_utils.h"
16
17 #include <queue>
18 #include <stack>
19 #include <string>
20
21 #include "llvm/ADT/SmallVector.h"
22 #include "llvm/ADT/StringRef.h"
23 #include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
24 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h"
25 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
26
27 namespace mlir {
28 namespace quant {
29
30 // Checks if the op is inside a lifted function.
IsInLiftedFunc(Operation * op)31 bool IsInLiftedFunc(Operation *op) {
32 return op->getParentOfType<func::FuncOp>()->hasAttr(kFusedFunctionAttr);
33 }
34
35 // Inserts the function to the symbol table of the module thread-safely.
InsertToSymbolTable(Operation * module,Operation * function,const std::string & func_name)36 StringAttr InsertToSymbolTable(Operation *module, Operation *function,
37 const std::string &func_name) {
38 static tensorflow::mutex *mtx = new tensorflow::mutex();
39 tensorflow::mutex_lock lock(*mtx);
40
41 SymbolTable symbol_table(module);
42 std::string unique_name = func_name;
43 int32_t uniquing_counter = 0;
44 while (symbol_table.lookup(unique_name) != nullptr) {
45 ++uniquing_counter;
46 unique_name = func_name + "_" + std::to_string(uniquing_counter);
47 }
48 function->setAttr("sym_name",
49 StringAttr::get(module->getContext(), unique_name));
50 return symbol_table.insert(function);
51 }
52
createFusedFnCall(OpBuilder builder,Location location,StringRef func_name,TypeRange output_types,ValueRange args)53 ValueRange createFusedFnCall(OpBuilder builder, Location location,
54 StringRef func_name, TypeRange output_types,
55 ValueRange args) {
56 TF::PartitionedCallOp call_op = builder.create<TF::PartitionedCallOp>(
57 location, output_types, args,
58 FlatSymbolRefAttr::get(builder.getStringAttr(func_name)),
59 /*config=*/"", /*config_proto=*/"", /*executor_type=*/"");
60 call_op->setAttr(
61 kQuantTraitAttrName,
62 builder.getStringAttr(llvm::StringRef(
63 std::string(QuantTraitValues[QuantizationTrait::FullyQuantizable]))));
64
65 return call_op.output();
66 }
67
68 // Finds ops in the paths from arguments to results. The ops is listed in an
69 // order that the former ops shouldn't have any dependencies on the later ones.
FindOpsFromArgumentsToResults(const llvm::SmallVector<Value> & arguments,const llvm::SmallVector<Value> & results)70 llvm::SmallVector<Operation *> FindOpsFromArgumentsToResults(
71 const llvm::SmallVector<Value> &arguments,
72 const llvm::SmallVector<Value> &results) {
73 std::queue<Value> value_queue;
74 for (Value result : results) {
75 value_queue.push(result);
76 }
77 absl::flat_hash_set<mlir::detail::ValueImpl *> argument_set;
78 for (Value argument : arguments) {
79 argument_set.insert(argument.getImpl());
80 }
81
82 // Searching for ops from results to arguments. Duplicate ops in the op stack
83 // are intentional in order to make sure the op on the top of the stack
84 // doesn't depends on any ops below it.
85 std::stack<Operation *> op_stack;
86 while (!value_queue.empty()) {
87 Value current_value = value_queue.front();
88 value_queue.pop();
89
90 Operation *defining_node = current_value.getDefiningOp();
91 if (defining_node == nullptr) continue;
92 op_stack.push(defining_node);
93 for (const auto &arg : defining_node->getOperands()) {
94 if (!argument_set.contains(arg.getImpl())) {
95 value_queue.push(arg);
96 }
97 }
98 }
99
100 // Remove duplicate ops from the op stack.
101 llvm::SmallVector<Operation *> sorted_ops;
102 absl::flat_hash_set<Operation *> unique_ops;
103 while (!op_stack.empty()) {
104 Operation *current_op = op_stack.top();
105 op_stack.pop();
106 if (unique_ops.contains(current_op)) continue;
107 sorted_ops.push_back(current_op);
108 unique_ops.insert(current_op);
109 }
110 return sorted_ops;
111 }
112
113 // Finds the name of each attribute in `attributes` and set the attr_map
114 // attribute which maps an attribute identifier to its attribute name. The
115 // identifier is the order of that attribute in `attributes`. This map
116 // is then used to set attributes in the quantized functions in the
117 // QuantizeCompositeFunctionsPass.
118 // For example, for tf.MatMul with `attributes` = {{"transpose_a", false},
119 // {"transpose_b", false}}, the generated attr_map is
120 // "0:transpose_a,1:transpose_b", where 0 and 1 are the respective attribute
121 // identifiers.
122 // This function returns success if all attributes could be found.
SetAttributeMap(MLIRContext * context,const llvm::SmallVector<NamedAttribute> & attributes,const llvm::SmallVector<Operation * > & ops)123 LogicalResult SetAttributeMap(
124 MLIRContext *context, const llvm::SmallVector<NamedAttribute> &attributes,
125 const llvm::SmallVector<Operation *> &ops) {
126 // A map to find which operation an attribute belongs to.
127 // The key for this map uses the entire NamedAttribute object, i.e. the
128 // {attribute_name, attribute_value} pair.
129 llvm::SmallDenseMap<NamedAttribute, Operation *> attr_to_op_map;
130 for (Operation *op : ops) {
131 for (const auto &named_attr : op->getAttrs()) {
132 attr_to_op_map.insert({named_attr, op});
133 }
134 }
135
136 for (int idx : llvm::seq<int>(0, attributes.size())) {
137 const NamedAttribute &attribute = attributes[idx];
138
139 // Skip the following steps if the attribute value is `NullAttribute`.
140 if (const auto string_attr =
141 attribute.getValue().dyn_cast_or_null<StringAttr>();
142 string_attr != nullptr &&
143 string_attr.getValue().equals(kNullAttributeValue)) {
144 continue;
145 }
146
147 if (attr_to_op_map.count(attribute) == 0) {
148 mlir::emitError(UnknownLoc::get(context),
149 "Could not find attribute: " + attribute.getName().str());
150 return failure();
151 }
152
153 Operation *owner_op = attr_to_op_map[attribute];
154
155 std::string new_attr_map_str{};
156 if (owner_op->hasAttr(kAttrMapAttribute)) {
157 new_attr_map_str =
158 owner_op->getAttrOfType<StringAttr>(kAttrMapAttribute).str();
159 absl::StrAppend(&new_attr_map_str, ",");
160 }
161
162 // Append "<identifier>:<attribute_name>". Ex) "0:transpose_a".
163 const std::string identifier = std::to_string(idx);
164 const mlir::StringAttr attribute_name = attribute.getName();
165 absl::StrAppend(&new_attr_map_str, identifier, ":", attribute_name.str());
166 owner_op->setAttr(kAttrMapAttribute,
167 StringAttr::get(context, new_attr_map_str));
168 }
169 return success();
170 }
171
172 // Creates a function to wrap the section between arguments and results.
LiftAsFunctionCall(OpBuilder builder,Location location,StringRef func_name,const llvm::SmallVector<Value> & arguments,const llvm::SmallVector<Value> & results,const llvm::SmallVector<NamedAttribute> & attributes)173 llvm::SmallVector<Value, 4> LiftAsFunctionCall(
174 OpBuilder builder, Location location, StringRef func_name,
175 const llvm::SmallVector<Value> &arguments,
176 const llvm::SmallVector<Value> &results,
177 const llvm::SmallVector<NamedAttribute> &attributes) {
178 MLIRContext *context = builder.getContext();
179 if (results.empty()) {
180 mlir::emitError(UnknownLoc::get(context), "No result values specified");
181 return {};
182 }
183 Operation *result_op = results[0].getDefiningOp();
184 auto module = result_op->getParentOfType<ModuleOp>();
185
186 // Create a private function and copy all ops between arguments and results.
187 auto current_func = result_op->getParentOfType<func::FuncOp>();
188 auto guard = OpBuilder::InsertionGuard(builder);
189 builder.setInsertionPointAfter(current_func);
190 TypeRange arg_types{ValueRange{arguments}};
191 TypeRange result_types{ValueRange{results}};
192 auto func_type = FunctionType::get(context, arg_types, result_types);
193
194 llvm::SmallVector<Location> arg_locs;
195 for (const auto &arg : arguments) {
196 arg_locs.push_back(arg.getLoc());
197 }
198 auto wrap_func = builder.create<func::FuncOp>(location, func_name, func_type);
199 wrap_func.setVisibility(SymbolTable::Visibility::Private);
200 wrap_func->setAttr(kFusedFunctionAttr, builder.getUnitAttr());
201 builder.createBlock(&wrap_func.getBody(), wrap_func.begin(), arg_types,
202 arg_locs);
203
204 BlockAndValueMapping mapping;
205 for (int32_t i : llvm::seq<int32_t>(0, arguments.size())) {
206 mapping.map(arguments[i], wrap_func.getArgument(i));
207 }
208
209 auto cloning_ops = FindOpsFromArgumentsToResults(arguments, results);
210 if (failed(SetAttributeMap(context, attributes, cloning_ops))) {
211 current_func.emitError() << "Some attributes couldn't be found.";
212 }
213 for (Operation *op : cloning_ops) {
214 builder.clone(*op, mapping);
215 }
216
217 llvm::SmallVector<Value> return_values;
218 for (Value result : results) {
219 return_values.push_back(mapping.lookupOrNull(result));
220 }
221 builder.create<mlir::func::ReturnOp>(location, return_values);
222
223 // Create a function call to the newly created function.
224 StringAttr new_func_name =
225 InsertToSymbolTable(module, wrap_func, func_name.str());
226 builder.setInsertionPointAfter(result_op);
227 ValueRange new_results = createFusedFnCall(
228 builder, location, new_func_name.getValue(), result_types, arguments);
229 return llvm::SmallVector<Value, 4>(new_results.begin(), new_results.end());
230 }
231
LiftAsFunctionCall(OpBuilder builder,Location location,StringRef func_name,const llvm::SmallVector<Value> & arguments,const llvm::SmallVector<Value> & results)232 llvm::SmallVector<Value, 4> LiftAsFunctionCall(
233 OpBuilder builder, Location location, StringRef func_name,
234 const llvm::SmallVector<Value> &arguments,
235 const llvm::SmallVector<Value> &results) {
236 llvm::SmallVector<NamedAttribute> attributes;
237 return LiftAsFunctionCall(builder, location, func_name, arguments, results,
238 attributes);
239 }
240
241 } // namespace quant
242 } // namespace mlir
243