1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include <string>
17 #include <unordered_set>
18 #include <utility>
19 #include <vector>
20
21 #include "absl/container/flat_hash_set.h"
22 #include "llvm/ADT/SmallVector.h"
23 #include "llvm/ADT/StringRef.h"
24 #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
25 #include "mlir/IR/Attributes.h" // from @llvm-project
26 #include "mlir/IR/Builders.h" // from @llvm-project
27 #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project
28 #include "mlir/IR/BuiltinOps.h" // from @llvm-project
29 #include "mlir/IR/BuiltinTypes.h" // from @llvm-project
30 #include "mlir/IR/Operation.h" // from @llvm-project
31 #include "mlir/IR/SymbolTable.h" // from @llvm-project
32 #include "mlir/IR/Visitors.h" // from @llvm-project
33 #include "mlir/Pass/Pass.h" // from @llvm-project
34 #include "mlir/Pass/PassManager.h" // from @llvm-project
35 #include "mlir/Support/LogicalResult.h" // from @llvm-project
36 #include "mlir/Transforms/Passes.h" // from @llvm-project
37 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
38 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.h"
39 #include "tensorflow/compiler/mlir/tensorflow/utils/attribute_utils.h"
40 #include "tensorflow/dtensor/cc/constants.h"
41 #include "tensorflow/dtensor/mlir/device_utils.h"
42 #include "tensorflow/dtensor/mlir/dtensor_mlir_passes.h"
43 #include "tensorflow/dtensor/mlir/dtensor_mlir_passes_classes.h"
44 #include "tensorflow/dtensor/mlir/op_utils.h"
45 #include "tensorflow/dtensor/mlir/spmd_expander_common.h"
46 #include "tensorflow/dtensor/mlir/value_utils.h"
47
48 namespace tensorflow {
49 namespace dtensor {
50 namespace {
51
// Name of the dictionary attribute on the main function whose `inputs` entry
// holds the comma separated list of input names.
constexpr char kEntryFuncAttr[] = "tf.entry_function";

// Prefixes of the input names generated for the three components that replace
// a SparseTensor argument. The original argument index is appended as
// "_<index>" when the names are built.
constexpr char kSparseIndicesStr[] = "op_input_sparse_indices";
constexpr char kSparseDenseShapesStr[] = "op_input_sparse_dense_shapes";
constexpr char kSparseValuesStr[] = "op_input_sparse_values";
56
57 typedef struct SparseTensorToComponentInfo {
58 mlir::RankedTensorType indices;
59 mlir::RankedTensorType values;
60 mlir::RankedTensorType dense_shapes;
61 unsigned int func_op_arg_index;
62 } SparseTensorToComponentInfo;
63
UpdateFunctionSignature(mlir::func::FuncOp function,mlir::OpBuilder & builder)64 void UpdateFunctionSignature(mlir::func::FuncOp function,
65 mlir::OpBuilder& builder) {
66 function.setType(mlir::FunctionType::get(
67 builder.getContext(),
68 llvm::to_vector<4>(function.front().getArgumentTypes()),
69 function.getFunctionType().getResults()));
70 }
71
72 // Add input attributes for new sparsetensor components and remove the
73 // old sparsetensor value input attributes.
74 //
75 // TF has a list of comma separated input names within `kEntryFuncAttr`
76 // attribute, under 'inputs'. Update this comma separated list of input names
77 // by correctly deleting the sparse tensor input name and replacing it with
78 // three new sparse component input names.
79 //
80 // Without this update, MLIR conversion to GraphDef will fail since
81 // the number of input names will not match with the FuncOp num arguments.
82 //
83 // e.g. "op_input_1" should become
84 // "op_input_sparse_indices_0,op_input_sparse_dense_shapes_0,
85 // "op_input_sparse_values_0"
UpdateFunctionInputAttributes(mlir::MLIRContext & context,mlir::func::FuncOp main_func,mlir::OpBuilder & builder,const std::vector<SparseTensorToComponentInfo> & sparse_tensor_components)86 mlir::LogicalResult UpdateFunctionInputAttributes(
87 mlir::MLIRContext& context, mlir::func::FuncOp main_func,
88 mlir::OpBuilder& builder,
89 const std::vector<SparseTensorToComponentInfo>& sparse_tensor_components) {
90 llvm::SmallVector<llvm::StringRef, 2> input_names;
91
92 auto dict_attr =
93 main_func->getAttrOfType<mlir::DictionaryAttr>(kEntryFuncAttr);
94 if (dict_attr) {
95 if (!dict_attr.get("inputs").isa<mlir::StringAttr>())
96 return main_func.emitOpError("Missing attribute inputs in main FuncOp.");
97
98 dict_attr.get("inputs").cast<mlir::StringAttr>().getValue().split(
99 input_names, ',', /*MaxSplit=*/-1, /*KeepEmpty=*/false);
100
101 llvm::SmallVector<std::string, 2> new_input_names;
102
103 absl::flat_hash_set<int> skip_indices;
104 for (const auto component : sparse_tensor_components) {
105 skip_indices.insert(component.func_op_arg_index);
106 }
107
108 for (auto i = 0; i < input_names.size(); ++i) {
109 if (skip_indices.find(i) == skip_indices.end()) {
110 new_input_names.push_back(input_names[i].str());
111 }
112 }
113
114 for (const auto component : sparse_tensor_components) {
115 int arg_index = component.func_op_arg_index;
116 new_input_names.push_back(
117 absl::StrCat(kSparseIndicesStr, "_", arg_index));
118 new_input_names.push_back(
119 absl::StrCat(kSparseDenseShapesStr, "_", arg_index));
120 new_input_names.push_back(absl::StrCat(kSparseValuesStr, "_", arg_index));
121 }
122
123 mlir::NamedAttrList attributes(dict_attr);
124 attributes.set(
125 "inputs",
126 mlir::StringAttr::get(&context, absl::StrJoin(new_input_names, ",")));
127 main_func->setAttr(kEntryFuncAttr, attributes.getDictionary(&context));
128 }
129 UpdateFunctionSignature(main_func, builder);
130 return mlir::success();
131 }
132
133 // For each SparseTensor block argument of the main FuncOp, create
134 // three of the component tensors, `indices`, `values`, and `dense_shapes`
135 // and add it to `sparse_tensor_components`.
CreateComponentTensorsFromSparseTensors(mlir::func::FuncOp main_func,mlir::OpBuilder & builder,std::vector<SparseTensorToComponentInfo> * sparse_tensor_components)136 void CreateComponentTensorsFromSparseTensors(
137 mlir::func::FuncOp main_func, mlir::OpBuilder& builder,
138 std::vector<SparseTensorToComponentInfo>* sparse_tensor_components) {
139 for (const auto block_arg : main_func.getArguments()) {
140 const auto is_sparse = main_func.getArgAttrOfType<mlir::BoolAttr>(
141 block_arg.getArgNumber(), kSparseValue);
142 if (is_sparse) {
143 sparse_tensor_components->push_back(SparseTensorToComponentInfo{
144 /*indices=*/mlir::RankedTensorType::get({-1, ValueRank(block_arg)},
145 builder.getI64Type()),
146 /*values=*/
147 mlir::RankedTensorType::get({-1},
148 block_arg.getType()
149 .dyn_cast<mlir::RankedTensorType>()
150 .getElementType()),
151 /*dense_shapes=*/
152 mlir::RankedTensorType::get({ValueRank(block_arg)},
153 builder.getI64Type()),
154 /*func_op_arg_index=*/block_arg.getArgNumber()});
155 }
156 }
157 }
158
159 // Inserts SparseTensor components `components` into `main_func` at the end
160 // of block arguments list.
UpdateFunctionWithSparseTensorComponents(mlir::MLIRContext & context,mlir::func::FuncOp main_func,mlir::OpBuilder & builder,const SparseTensorToComponentInfo & component)161 void UpdateFunctionWithSparseTensorComponents(
162 mlir::MLIRContext& context, mlir::func::FuncOp main_func,
163 mlir::OpBuilder& builder, const SparseTensorToComponentInfo& component) {
164 main_func.front().addArgument(component.indices, main_func.getLoc());
165 main_func.front().addArgument(component.dense_shapes, main_func.getLoc());
166 main_func.front().addArgument(component.values, main_func.getLoc());
167 UpdateFunctionSignature(main_func, builder);
168 }
169
170 struct DTensorSparseTensorToDenseTensor
171 : public DTensorSparseTensorToDenseTensorBase<
172 DTensorSparseTensorToDenseTensor> {
runOnOperationtensorflow::dtensor::__anon64d813170111::DTensorSparseTensorToDenseTensor173 void runOnOperation() override {
174 mlir::MLIRContext& context = getContext();
175 auto module = getOperation();
176 mlir::OpBuilder builder(&context);
177
178 mlir::func::FuncOp main_func =
179 module.lookupSymbol<mlir::func::FuncOp>("main");
180
181 // Save Arg Attributes for each argument for later use, this will be
182 // reset and reordered after we insert sparse tensor components arguments.
183 llvm::DenseMap<mlir::Value, llvm::ArrayRef<mlir::NamedAttribute>>
184 arg_attribute_map;
185 for (auto block_arg : main_func.getArguments()) {
186 arg_attribute_map.insert(std::make_pair(
187 block_arg, main_func.getArgAttrs(block_arg.getArgNumber())));
188 }
189
190 std::vector<SparseTensorToComponentInfo> sparse_tensor_components;
191 CreateComponentTensorsFromSparseTensors(main_func, builder,
192 &sparse_tensor_components);
193
194 // Update func arguments in place by replacing SparseTensors with their
195 // components and emitting a SparseToDenseOp before all ops that consume
196 // a SparseTensor.
197 for (const SparseTensorToComponentInfo& components :
198 sparse_tensor_components) {
199 // Insert SparseTensor component into the main function's block
200 // arguments.
201 mlir::Value sparse_tensor_value =
202 main_func.getArgument(components.func_op_arg_index);
203
204 UpdateFunctionWithSparseTensorComponents(context, main_func, builder,
205 components);
206 mlir::Operation* front_op = &main_func.front().front();
207 builder.setInsertionPoint(front_op);
208
209 // Emit a SparseToDenseOp and replace the SparseTensor with the result of
210 // this new op.
211 auto zero_scalar = CreateZeroScalarConst(builder, front_op->getLoc(),
212 sparse_tensor_value.getType()
213 .cast<mlir::TensorType>()
214 .getElementType());
215 if (!zero_scalar.has_value()) return signalPassFailure();
216 mlir::TF::SparseToDenseOp sparse_to_dense_op =
217 builder.create<mlir::TF::SparseToDenseOp>(
218 front_op->getLoc(), sparse_tensor_value.getType(),
219 mlir::ValueRange(
220 {main_func.getArgument(main_func.getNumArguments() - 3),
221 main_func.getArgument(main_func.getNumArguments() - 2),
222 main_func.getArgument(main_func.getNumArguments() - 1),
223 zero_scalar.value()}));
224
225 sparse_tensor_value.replaceAllUsesWith(sparse_to_dense_op);
226 if (!sparse_tensor_value.use_empty()) return signalPassFailure();
227 }
228
229 // Erase sparse tensor arguments now that we converted all of them.
230 for (int i = 0; i < sparse_tensor_components.size(); ++i)
231 main_func.front().eraseArgument(
232 sparse_tensor_components[i].func_op_arg_index - i);
233
234 // Reset block argument attributes since they are likely mixed up
235 // due to change in ordering of arguments.
236 for (auto block_arg : main_func.getArguments()) {
237 if (arg_attribute_map.find(block_arg) == arg_attribute_map.end()) {
238 main_func.setArgAttrs(block_arg.getArgNumber(),
239 llvm::ArrayRef<mlir::NamedAttribute>{});
240 } else {
241 main_func.setArgAttrs(block_arg.getArgNumber(),
242 arg_attribute_map[block_arg]);
243 }
244 }
245 if (mlir::failed(UpdateFunctionInputAttributes(context, main_func, builder,
246 sparse_tensor_components)))
247 return signalPassFailure();
248 };
249 };
250
251 } // namespace
252
253 std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
CreateDTensorSparseTensorToDenseTensor()254 CreateDTensorSparseTensorToDenseTensor() {
255 return std::make_unique<DTensorSparseTensorToDenseTensor>();
256 }
257
258 } // namespace dtensor
259 } // namespace tensorflow
260