/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/transforms/shape_inference/pass.h"

#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallVector.h"
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/Operation.h"  // from @llvm-project
#include "mlir/IR/Value.h"  // from @llvm-project
#include "mlir/IR/Visitors.h"  // from @llvm-project
#include "mlir/Interfaces/InferTypeOpInterface.h"  // from @llvm-project
#include "mlir/Support/LLVM.h"  // from @llvm-project
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/ir/ops.h"
#include "tensorflow/core/ir/tf_op_wrapper.h"
#include "tensorflow/core/ir/types/dialect.h"
#include "tensorflow/core/ir/utils/shape_inference_utils.h"
#include "tensorflow/core/transforms/pass_detail.h"

namespace mlir {
namespace tfg {

using tensorflow::shape_inference::DimensionHandle;
using tensorflow::shape_inference::InferenceContext;
using tensorflow::shape_inference::ShapeHandle;

// Only a non-static shape, or a type with subtypes, can be refined.
static bool CanBeRefined(Type type) {
  auto shape_type = type.dyn_cast<ShapedType>();
  if (!shape_type) return false;

  // Returns whether a type with subtypes (e.g., resource or variant) can be
  // further refined.
  auto can_refine_subtypes = [](tf_type::TensorFlowTypeWithSubtype tws) {
    return tws.GetSubtypes().empty() ||
           llvm::any_of(tws.GetSubtypes(), CanBeRefined);
  };
  auto type_with_subtype = shape_type.getElementType()
                               .dyn_cast<tf_type::TensorFlowTypeWithSubtype>();
  if (type_with_subtype && can_refine_subtypes(type_with_subtype)) return true;

  return !shape_type.hasStaticShape();
}

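// An op can be refined if any of its result types can be refined. The
// static_cast selects the `Type` overload of CanBeRefined for llvm::any_of.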
static bool CanBeRefined(Operation *op) {
  return llvm::any_of(op->getResultTypes(),
                      static_cast<bool (*)(Type)>(CanBeRefined));
}

class ShapeInference : public ShapeInferenceBase<ShapeInference> {
 public:
  void runOnOperation() override;

 private:
  // Caches the tensor value of an op's results if possible. After inferring
  // the shape, some operations may also be able to construct the tensor
  // value, e.g., an IdentityOp with a Const operand. In these cases, cache the
  // tensor value, which may be useful for the shape inference of their users.
  void TryToCacheResultsTensorValue(Operation *op);

  // Returns the cached tensor value if available, nullptr otherwise.
  DenseElementsAttr GetTensorValue(Value result) {
    OpResult op_result = result.dyn_cast<OpResult>();
    if (op_result) {
      auto it = cached_tensor_values_.find(op_result);
      if (it != cached_tensor_values_.end()) return it->second;
    }
    return nullptr;
  }

  DenseMap<OpResult, DenseElementsAttr> cached_tensor_values_;
};

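// Materializes constant values for the results of `op` when they can be
// derived without running the op: Const (from its value attribute),
// Identity/IdentityN with a single data operand (forwarding a cached value),
// and Rank/Size/Shape/ShapeN (from statically known operand shapes).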
void ShapeInference::TryToCacheResultsTensorValue(Operation *op) {
  // Only an op with statically shaped results can construct the tensor value.
  // Skip the op if every (non-control) result either lacks a static shape or
  // already has a cached value.
  if (llvm::all_of(op->getResults().drop_back(), [this](Value value) {
        auto shape = value.getType().cast<ShapedType>();
        /// NOMUTANTS -- shape.hasStaticShape is a cheaper operation than
        /// GetTensorValue
        return (!shape.hasStaticShape() || GetTensorValue(value) != nullptr);
      })) {
    return;
  }

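  // Dispatch on the op name. For Rank/Size/Shape(N), the result element type
  // decides whether an int32 or an int64 tensor value is built.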
  StringRef op_name = op->getName().stripDialect();
  if (op_name == "Const") {
    cached_tensor_values_[op->getResult(0)] =
        op->getAttrOfType<DenseElementsAttr>("value");
  } else if (op_name == "Identity" ||
             (op_name == "IdentityN" &&
              TFOp(op).getNonControlOperands().size() == 1)) {
    DenseElementsAttr operand_tensor_value = GetTensorValue(op->getOperand(0));
    if (!operand_tensor_value) return;
    cached_tensor_values_[op->getResult(0)] = operand_tensor_value;
  } else if (op_name == "Rank") {
    ShapedType operand_shape = op->getOperand(0).getType().cast<ShapedType>();
    if (!operand_shape.hasRank()) return;
    ShapedType return_shape = op->getResultTypes()[0];
    DenseElementsAttr tensor_value;
    if (return_shape.getElementType().isInteger(32)) {
      tensor_value = DenseElementsAttr::get(
          op->getResultTypes()[0], ArrayRef<int>(operand_shape.getRank()));
    } else {
      tensor_value = DenseElementsAttr::get(
          op->getResultTypes()[0], ArrayRef<int64_t>(operand_shape.getRank()));
    }
    cached_tensor_values_[op->getResult(0)] = tensor_value;
  } else if (op_name == "Size") {
    ShapedType operand_shape = op->getOperand(0).getType().cast<ShapedType>();
    if (!operand_shape.hasStaticShape()) return;
    ShapedType return_shape = op->getResultTypes()[0];
    DenseElementsAttr tensor_value;
    if (return_shape.getElementType().isInteger(32)) {
      tensor_value =
          DenseElementsAttr::get(op->getResultTypes()[0],
                                 ArrayRef<int>(operand_shape.getNumElements()));
    } else {
      tensor_value = DenseElementsAttr::get(
          op->getResultTypes()[0],
          ArrayRef<int64_t>(operand_shape.getNumElements()));
    }
    cached_tensor_values_[op->getResult(0)] = tensor_value;
  } else if (op_name == "Shape" || op_name == "ShapeN") {
    for (OpOperand &operand : op->getOpOperands()) {
      Type operand_type = operand.get().getType();
      if (operand_type.isa<ControlType>()) break;

      auto operand_shape = operand_type.cast<ShapedType>();
      if (!operand_shape.hasStaticShape()) continue;

      int idx = operand.getOperandNumber();
      ShapedType return_shape = op->getResultTypes()[idx];
      DenseElementsAttr tensor_value;
      if (return_shape.getElementType().isInteger(32)) {
        tensor_value = DenseElementsAttr::get<int>(
            op->getResultTypes()[idx],
            SmallVector<int>(llvm::map_range(
                operand_shape.getShape(),
                [](int64_t dim) { return static_cast<int>(dim); })));
      } else {
        tensor_value = DenseElementsAttr::get(op->getResultTypes()[idx],
                                              operand_shape.getShape());
      }
      cached_tensor_values_[op->getResult(idx)] = tensor_value;
    }
  }

  // TODO(chiahungduan): Grappler has special cases for
  // ConcatV2/Pack/Slice/StridedSlice, whose shape inference logic is similar
  // to how we do constant folding on them. Constant folding should cover most
  // of those cases; handle them on demand later.
}

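// Runs shape inference over the module: a first sweep infers shapes and seeds
// the constant-value cache, then a worklist loop keeps refining ops whose
// operands' shapes were updated until a fixed point is reached.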
void ShapeInference::runOnOperation() {
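  // Hook for InferReturnTypeComponentsForTFOp: returns an operand's constant
  // value from the cache, or a null attribute if it is unknown.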
  auto operand_as_constant_fn = [this](Value operand) -> Attribute {
    return GetTensorValue(operand);
  };

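  // Hook for InferReturnTypeComponentsForTFOp: converts a result holding a
  // cached 1-D integer constant (e.g., the output of a Shape op) into a
  // ShapeHandle that the shape function can consume.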
  auto op_result_as_shape_fn = [this](InferenceContext &ic,
                                      OpResult op_result) -> ShapeHandle {
    auto rt = op_result.getType().dyn_cast<RankedTensorType>();
    // NOMUTANTS -- TODO(chiahungduan): Review this condition to see if a shape
    // with known rank but unknown dimensions is acceptable.
    if (!rt || rt.getRank() != 1 || !rt.hasStaticShape()) return {};

    std::vector<DimensionHandle> dims(rt.getDimSize(0), ic.UnknownDim());

    DenseElementsAttr attr = GetTensorValue(op_result);
    if (!attr) return {};

    for (const auto &element : llvm::enumerate(attr.getValues<APInt>()))
      dims[element.index()] = ic.MakeDim(element.value().getSExtValue());
    return ic.MakeShape(dims);
  };

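  // Runs shape inference on `op` and refines its result types in place.
  // Returns true if any result type changed.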
  auto infer_and_update_shapes = [&](Operation *op) -> bool {
    auto result_element_type_fn = [&](int idx) -> Type {
      return op->getResult(idx).getType().cast<ShapedType>().getElementType();
    };

    SmallVector<ShapedTypeComponents> results;
    if (failed(InferReturnTypeComponentsForTFOp(
            op->getLoc(), op, TFOp(op).getNonControlOperands(), graph_version_,
            operand_as_constant_fn, op_result_as_shape_fn,
            result_element_type_fn,
            /*get_attr_values_fn=*/nullptr, results)))
      return false;

    bool updated = false;
    for (auto it : llvm::zip(op->getResults().drop_back(), results)) {
      OpResult op_result = std::get<0>(it);
      ShapedTypeComponents result = std::get<1>(it);
      TensorType inferred_type;
      if (result.hasRank()) {
        inferred_type =
            RankedTensorType::get(result.getDims(), result.getElementType());
      } else {
        inferred_type = UnrankedTensorType::get(result.getElementType());
      }

      Type refined_type = tf_type::GetCastCompatibleType(
          op_result.getType().cast<ShapedType>(), inferred_type);

      // Certain attributes like _output_shapes may have incorrect shape
      // information. When it's incompatible, use the result of the shape
      // inference context.
      if (!refined_type) refined_type = inferred_type;

      if (refined_type == op_result.getType()) continue;

      op_result.setType(refined_type);
      updated = true;
    }

    if (updated) TryToCacheResultsTensorValue(op);

    return updated;
  };

  // Reset the cached tensor values.
  cached_tensor_values_.clear();

  // Traverse all the operations and do the first round of inference. We don't
  // record which operations need to be updated because most of them may still
  // lack shape information.
  getOperation()->walk<WalkOrder::PreOrder>([&](Operation *op) {
    if (auto func = dyn_cast<GraphFuncOp>(op)) {
      // Don't infer the shapes of ops in a generic function; just skip it.
      if (func.generic()) return WalkResult::skip();
      return WalkResult::advance();
    }
    if (isa<ModuleOp, GraphOp>(op) || op->getNumResults() == 0)
      return WalkResult::advance();

    if (!CanBeRefined(op)) {
      TryToCacheResultsTensorValue(op);
      return WalkResult::advance();
    }

    (void)infer_and_update_shapes(op);
    return WalkResult::advance();
  });

  // This is used to track the set of operations that may be able to infer
  // their shapes. When an operation infers its shape successfully, it adds its
  // users to this set, which implies that an operation may be added multiple
  // times if it has multiple operands. Use a SetVector to avoid keeping
  // duplicate entries.
  SetVector<Operation *> may_need_update;

  // Collect operations that have a chance to infer more precise shape
  // information.
  getOperation()->walk<WalkOrder::PreOrder>([&](Operation *op) {
    if (auto func = dyn_cast<GraphFuncOp>(op)) {
      // Don't infer the shapes of ops in a generic function; just skip it.
      if (func.generic()) return WalkResult::skip();
      return WalkResult::advance();
    }
    if (isa<ModuleOp, tfg::GraphOp>(op) || op->getNumResults() == 0)
      return WalkResult::advance();

    // This op still needs to refine its shape, so its users have no chance to
    // refine their shapes either.
    if (CanBeRefined(op)) return WalkResult::advance();

    for (OpResult res : op->getResults().drop_back()) {
      for (Operation *user : res.getUsers())
        if (CanBeRefined(user)) may_need_update.insert(user);
    }

    return WalkResult::advance();
  });

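  // Iterate to a fixed point: every successful refinement re-enqueues the
  // users of the refined op.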
  // TODO(chiahungduan): We may need to limit the iterations.
  while (!may_need_update.empty()) {
    Operation *op = may_need_update.pop_back_val();
    bool updated = infer_and_update_shapes(op);
    if (!updated) continue;

    // The users may be able to refine their shapes.
    for (Value v : op->getResults().drop_back()) {
      for (Operation *user : v.getUsers()) {
        if (CanBeRefined(user)) may_need_update.insert(user);
      }
    }
  }

  // Update the function signatures so that the result types match the refined
  // types of the terminator's non-control operands.
  getOperation()->walk([&](GraphFuncOp func) {
    FunctionType func_type = func.function_type();
    Operation *return_op = func.getBody()->getTerminator();

    bool types_updated = false;
    for (auto &indexed_type : llvm::enumerate(func_type.getResults())) {
      int res_num = indexed_type.index();
      Type return_arg_type = return_op->getOperand(res_num).getType();
      if (return_arg_type != indexed_type.value()) {
        types_updated = true;
        break;
      }
    }

    if (!types_updated) return;

    func.function_typeAttr(TypeAttr::get(
        FunctionType::get(&getContext(), func_type.getInputs(),
                          TFOp(return_op).getNonControlOperands().getTypes())));
  });
}

std::unique_ptr<Pass> CreateShapeInferencePass() {
  return std::make_unique<ShapeInference>();
}

}  // namespace tfg
}  // namespace mlir