xref: /aosp_15_r20/external/tensorflow/tensorflow/dtensor/mlir/handle_sparsetensors.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <string>
17 #include <unordered_set>
18 #include <utility>
19 #include <vector>
20 
21 #include "absl/container/flat_hash_set.h"
22 #include "llvm/ADT/SmallVector.h"
23 #include "llvm/ADT/StringRef.h"
24 #include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
25 #include "mlir/IR/Attributes.h"  // from @llvm-project
26 #include "mlir/IR/Builders.h"  // from @llvm-project
27 #include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
28 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
29 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
30 #include "mlir/IR/Operation.h"  // from @llvm-project
31 #include "mlir/IR/SymbolTable.h"  // from @llvm-project
32 #include "mlir/IR/Visitors.h"  // from @llvm-project
33 #include "mlir/Pass/Pass.h"  // from @llvm-project
34 #include "mlir/Pass/PassManager.h"  // from @llvm-project
35 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
36 #include "mlir/Transforms/Passes.h"  // from @llvm-project
37 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
38 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.h"
39 #include "tensorflow/compiler/mlir/tensorflow/utils/attribute_utils.h"
40 #include "tensorflow/dtensor/cc/constants.h"
41 #include "tensorflow/dtensor/mlir/device_utils.h"
42 #include "tensorflow/dtensor/mlir/dtensor_mlir_passes.h"
43 #include "tensorflow/dtensor/mlir/dtensor_mlir_passes_classes.h"
44 #include "tensorflow/dtensor/mlir/op_utils.h"
45 #include "tensorflow/dtensor/mlir/spmd_expander_common.h"
46 #include "tensorflow/dtensor/mlir/value_utils.h"
47 
48 namespace tensorflow {
49 namespace dtensor {
50 namespace {
51 
52 constexpr char kEntryFuncAttr[] = "tf.entry_function";
53 constexpr char kSparseIndicesStr[] = "op_input_sparse_indices";
54 constexpr char kSparseDenseShapesStr[] = "op_input_sparse_dense_shapes";
55 constexpr char kSparseValuesStr[] = "op_input_sparse_values";
56 
57 typedef struct SparseTensorToComponentInfo {
58   mlir::RankedTensorType indices;
59   mlir::RankedTensorType values;
60   mlir::RankedTensorType dense_shapes;
61   unsigned int func_op_arg_index;
62 } SparseTensorToComponentInfo;
63 
UpdateFunctionSignature(mlir::func::FuncOp function,mlir::OpBuilder & builder)64 void UpdateFunctionSignature(mlir::func::FuncOp function,
65                              mlir::OpBuilder& builder) {
66   function.setType(mlir::FunctionType::get(
67       builder.getContext(),
68       llvm::to_vector<4>(function.front().getArgumentTypes()),
69       function.getFunctionType().getResults()));
70 }
71 
72 // Add input attributes for new sparsetensor components and remove the
73 // old sparsetensor value input attributes.
74 //
75 // TF has a list of comma separated input names within `kEntryFuncAttr`
76 // attribute, under 'inputs'. Update this comma separated list of input names
77 // by correctly deleting the sparse tensor input name and replacing it with
78 // three new sparse component input names.
79 //
80 // Without this update, MLIR conversion to GraphDef will fail since
81 // the number of input names will not match with the FuncOp num arguments.
82 //
83 // e.g. "op_input_1" should become
84 // "op_input_sparse_indices_0,op_input_sparse_dense_shapes_0,
85 // "op_input_sparse_values_0"
UpdateFunctionInputAttributes(mlir::MLIRContext & context,mlir::func::FuncOp main_func,mlir::OpBuilder & builder,const std::vector<SparseTensorToComponentInfo> & sparse_tensor_components)86 mlir::LogicalResult UpdateFunctionInputAttributes(
87     mlir::MLIRContext& context, mlir::func::FuncOp main_func,
88     mlir::OpBuilder& builder,
89     const std::vector<SparseTensorToComponentInfo>& sparse_tensor_components) {
90   llvm::SmallVector<llvm::StringRef, 2> input_names;
91 
92   auto dict_attr =
93       main_func->getAttrOfType<mlir::DictionaryAttr>(kEntryFuncAttr);
94   if (dict_attr) {
95     if (!dict_attr.get("inputs").isa<mlir::StringAttr>())
96       return main_func.emitOpError("Missing attribute inputs in main FuncOp.");
97 
98     dict_attr.get("inputs").cast<mlir::StringAttr>().getValue().split(
99         input_names, ',', /*MaxSplit=*/-1, /*KeepEmpty=*/false);
100 
101     llvm::SmallVector<std::string, 2> new_input_names;
102 
103     absl::flat_hash_set<int> skip_indices;
104     for (const auto component : sparse_tensor_components) {
105       skip_indices.insert(component.func_op_arg_index);
106     }
107 
108     for (auto i = 0; i < input_names.size(); ++i) {
109       if (skip_indices.find(i) == skip_indices.end()) {
110         new_input_names.push_back(input_names[i].str());
111       }
112     }
113 
114     for (const auto component : sparse_tensor_components) {
115       int arg_index = component.func_op_arg_index;
116       new_input_names.push_back(
117           absl::StrCat(kSparseIndicesStr, "_", arg_index));
118       new_input_names.push_back(
119           absl::StrCat(kSparseDenseShapesStr, "_", arg_index));
120       new_input_names.push_back(absl::StrCat(kSparseValuesStr, "_", arg_index));
121     }
122 
123     mlir::NamedAttrList attributes(dict_attr);
124     attributes.set(
125         "inputs",
126         mlir::StringAttr::get(&context, absl::StrJoin(new_input_names, ",")));
127     main_func->setAttr(kEntryFuncAttr, attributes.getDictionary(&context));
128   }
129   UpdateFunctionSignature(main_func, builder);
130   return mlir::success();
131 }
132 
133 // For each SparseTensor block argument of the main FuncOp, create
134 // three of the component tensors, `indices`, `values`, and `dense_shapes`
135 // and add it to `sparse_tensor_components`.
CreateComponentTensorsFromSparseTensors(mlir::func::FuncOp main_func,mlir::OpBuilder & builder,std::vector<SparseTensorToComponentInfo> * sparse_tensor_components)136 void CreateComponentTensorsFromSparseTensors(
137     mlir::func::FuncOp main_func, mlir::OpBuilder& builder,
138     std::vector<SparseTensorToComponentInfo>* sparse_tensor_components) {
139   for (const auto block_arg : main_func.getArguments()) {
140     const auto is_sparse = main_func.getArgAttrOfType<mlir::BoolAttr>(
141         block_arg.getArgNumber(), kSparseValue);
142     if (is_sparse) {
143       sparse_tensor_components->push_back(SparseTensorToComponentInfo{
144           /*indices=*/mlir::RankedTensorType::get({-1, ValueRank(block_arg)},
145                                                   builder.getI64Type()),
146           /*values=*/
147           mlir::RankedTensorType::get({-1},
148                                       block_arg.getType()
149                                           .dyn_cast<mlir::RankedTensorType>()
150                                           .getElementType()),
151           /*dense_shapes=*/
152           mlir::RankedTensorType::get({ValueRank(block_arg)},
153                                       builder.getI64Type()),
154           /*func_op_arg_index=*/block_arg.getArgNumber()});
155     }
156   }
157 }
158 
159 // Inserts SparseTensor components `components` into `main_func` at the end
160 // of block arguments list.
UpdateFunctionWithSparseTensorComponents(mlir::MLIRContext & context,mlir::func::FuncOp main_func,mlir::OpBuilder & builder,const SparseTensorToComponentInfo & component)161 void UpdateFunctionWithSparseTensorComponents(
162     mlir::MLIRContext& context, mlir::func::FuncOp main_func,
163     mlir::OpBuilder& builder, const SparseTensorToComponentInfo& component) {
164   main_func.front().addArgument(component.indices, main_func.getLoc());
165   main_func.front().addArgument(component.dense_shapes, main_func.getLoc());
166   main_func.front().addArgument(component.values, main_func.getLoc());
167   UpdateFunctionSignature(main_func, builder);
168 }
169 
170 struct DTensorSparseTensorToDenseTensor
171     : public DTensorSparseTensorToDenseTensorBase<
172           DTensorSparseTensorToDenseTensor> {
runOnOperationtensorflow::dtensor::__anon64d813170111::DTensorSparseTensorToDenseTensor173   void runOnOperation() override {
174     mlir::MLIRContext& context = getContext();
175     auto module = getOperation();
176     mlir::OpBuilder builder(&context);
177 
178     mlir::func::FuncOp main_func =
179         module.lookupSymbol<mlir::func::FuncOp>("main");
180 
181     // Save Arg Attributes for each argument for later use, this will be
182     // reset and reordered after we insert sparse tensor components arguments.
183     llvm::DenseMap<mlir::Value, llvm::ArrayRef<mlir::NamedAttribute>>
184         arg_attribute_map;
185     for (auto block_arg : main_func.getArguments()) {
186       arg_attribute_map.insert(std::make_pair(
187           block_arg, main_func.getArgAttrs(block_arg.getArgNumber())));
188     }
189 
190     std::vector<SparseTensorToComponentInfo> sparse_tensor_components;
191     CreateComponentTensorsFromSparseTensors(main_func, builder,
192                                             &sparse_tensor_components);
193 
194     // Update func arguments in place by replacing SparseTensors with their
195     // components and emitting a SparseToDenseOp before all ops that consume
196     // a SparseTensor.
197     for (const SparseTensorToComponentInfo& components :
198          sparse_tensor_components) {
199       // Insert SparseTensor component into the main function's block
200       // arguments.
201       mlir::Value sparse_tensor_value =
202           main_func.getArgument(components.func_op_arg_index);
203 
204       UpdateFunctionWithSparseTensorComponents(context, main_func, builder,
205                                                components);
206       mlir::Operation* front_op = &main_func.front().front();
207       builder.setInsertionPoint(front_op);
208 
209       // Emit a SparseToDenseOp and replace the SparseTensor with the result of
210       // this new op.
211       auto zero_scalar = CreateZeroScalarConst(builder, front_op->getLoc(),
212                                                sparse_tensor_value.getType()
213                                                    .cast<mlir::TensorType>()
214                                                    .getElementType());
215       if (!zero_scalar.has_value()) return signalPassFailure();
216       mlir::TF::SparseToDenseOp sparse_to_dense_op =
217           builder.create<mlir::TF::SparseToDenseOp>(
218               front_op->getLoc(), sparse_tensor_value.getType(),
219               mlir::ValueRange(
220                   {main_func.getArgument(main_func.getNumArguments() - 3),
221                    main_func.getArgument(main_func.getNumArguments() - 2),
222                    main_func.getArgument(main_func.getNumArguments() - 1),
223                    zero_scalar.value()}));
224 
225       sparse_tensor_value.replaceAllUsesWith(sparse_to_dense_op);
226       if (!sparse_tensor_value.use_empty()) return signalPassFailure();
227     }
228 
229     // Erase sparse tensor arguments now that we converted all of them.
230     for (int i = 0; i < sparse_tensor_components.size(); ++i)
231       main_func.front().eraseArgument(
232           sparse_tensor_components[i].func_op_arg_index - i);
233 
234     // Reset block argument attributes since they are likely mixed up
235     // due to change in ordering of arguments.
236     for (auto block_arg : main_func.getArguments()) {
237       if (arg_attribute_map.find(block_arg) == arg_attribute_map.end()) {
238         main_func.setArgAttrs(block_arg.getArgNumber(),
239                               llvm::ArrayRef<mlir::NamedAttribute>{});
240       } else {
241         main_func.setArgAttrs(block_arg.getArgNumber(),
242                               arg_attribute_map[block_arg]);
243       }
244     }
245     if (mlir::failed(UpdateFunctionInputAttributes(context, main_func, builder,
246                                                    sparse_tensor_components)))
247       return signalPassFailure();
248   };
249 };
250 
251 }  // namespace
252 
253 std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
CreateDTensorSparseTensorToDenseTensor()254 CreateDTensorSparseTensorToDenseTensor() {
255   return std::make_unique<DTensorSparseTensorToDenseTensor>();
256 }
257 
258 }  // namespace dtensor
259 }  // namespace tensorflow
260