/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/transforms/shape_inference/pass.h"

#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallVector.h"
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/Operation.h"  // from @llvm-project
#include "mlir/IR/Value.h"  // from @llvm-project
#include "mlir/IR/Visitors.h"  // from @llvm-project
#include "mlir/Interfaces/InferTypeOpInterface.h"  // from @llvm-project
#include "mlir/Support/LLVM.h"  // from @llvm-project
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/ir/ops.h"
#include "tensorflow/core/ir/tf_op_wrapper.h"
#include "tensorflow/core/ir/types/dialect.h"
#include "tensorflow/core/ir/utils/shape_inference_utils.h"
#include "tensorflow/core/transforms/pass_detail.h"

namespace mlir {
namespace tfg {

using tensorflow::shape_inference::DimensionHandle;
using tensorflow::shape_inference::InferenceContext;
using tensorflow::shape_inference::ShapeHandle;

// Only a non-static shape or a type with subtypes can be refined.
static bool CanBeRefined(Type type) {
  auto shape_type = type.dyn_cast<ShapedType>();
  if (!shape_type) return false;

  // Returns whether a type with subtypes can be further refined.
  auto can_refine_subtypes = [](tf_type::TensorFlowTypeWithSubtype tws) {
    return tws.GetSubtypes().empty() ||
           llvm::any_of(tws.GetSubtypes(), CanBeRefined);
  };
  auto type_with_subtype = shape_type.getElementType()
                               .dyn_cast<tf_type::TensorFlowTypeWithSubtype>();
  if (type_with_subtype && can_refine_subtypes(type_with_subtype)) return true;

  return !shape_type.hasStaticShape();
}

static bool CanBeRefined(Operation *op) {
  return llvm::any_of(op->getResultTypes(),
                      static_cast<bool (*)(Type)>(CanBeRefined));
}

class ShapeInference : public ShapeInferenceBase<ShapeInference> {
 public:
  void runOnOperation() override;

 private:
  // Caches the tensor value if possible. After inferring the shape, some
  // operations may also be able to construct the tensor value, e.g., an
  // Identity op with a Const operand. In these cases, cache the tensor value,
  // which may be useful for its users' shape inference.
  void TryToCacheResultsTensorValue(Operation *op);

  // Returns the cached tensor value if available, nullptr otherwise.
  DenseElementsAttr GetTensorValue(Value result) {
    OpResult op_result = result.dyn_cast<OpResult>();
    if (op_result) {
      auto it = cached_tensor_values_.find(op_result);
      if (it != cached_tensor_values_.end()) return it->second;
    }
    return nullptr;
  }

  DenseMap<OpResult, DenseElementsAttr> cached_tensor_values_;
};

void ShapeInference::TryToCacheResultsTensorValue(Operation *op) {
  // Only ops with static shapes are able to construct the tensor value.
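  // The last result of a TFG op is the control token, which never carries a
  // tensor value, so only the data results (drop_back) are considered. If
  // every data result either has a non-static shape or is already cached,
  // there is nothing left to do.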
  if (llvm::all_of(op->getResults().drop_back(), [this](Value value) {
        auto shape = value.getType().cast<ShapedType>();
        /// NOMUTANTS -- shape.hasStaticShape is a cheaper operation than
        /// GetTensorValue
        return (!shape.hasStaticShape() || GetTensorValue(value) != nullptr);
      })) {
    return;
  }

  StringRef op_name = op->getName().stripDialect();
  if (op_name == "Const") {
    cached_tensor_values_[op->getResult(0)] =
        op->getAttrOfType<DenseElementsAttr>("value");
  } else if (op_name == "Identity" ||
             (op_name == "IdentityN" &&
              TFOp(op).getNonControlOperands().size() == 1)) {
    DenseElementsAttr operand_tensor_value = GetTensorValue(op->getOperand(0));
    if (!operand_tensor_value) return;
    cached_tensor_values_[op->getResult(0)] = operand_tensor_value;
  } else if (op_name == "Rank") {
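    // The rank of a ranked operand is a compile-time constant; materialize it
    // as a scalar tensor whose integer width matches the result element type.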
    ShapedType operand_shape = op->getOperand(0).getType().cast<ShapedType>();
    if (!operand_shape.hasRank()) return;
    ShapedType return_shape = op->getResultTypes()[0];
    DenseElementsAttr tensor_value;
    if (return_shape.getElementType().isInteger(32)) {
      tensor_value = DenseElementsAttr::get(
          op->getResultTypes()[0], ArrayRef<int>(operand_shape.getRank()));
    } else {
      tensor_value = DenseElementsAttr::get(
          op->getResultTypes()[0], ArrayRef<int64_t>(operand_shape.getRank()));
    }
    cached_tensor_values_[op->getResult(0)] = tensor_value;
  } else if (op_name == "Size") {
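    // Size is the total number of elements, which is only known when the
    // operand shape is fully static.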
    ShapedType operand_shape = op->getOperand(0).getType().cast<ShapedType>();
    if (!operand_shape.hasStaticShape()) return;
    ShapedType return_shape = op->getResultTypes()[0];
    DenseElementsAttr tensor_value;
    if (return_shape.getElementType().isInteger(32)) {
      tensor_value =
          DenseElementsAttr::get(op->getResultTypes()[0],
                                 ArrayRef<int>(operand_shape.getNumElements()));
    } else {
      tensor_value = DenseElementsAttr::get(
          op->getResultTypes()[0],
          ArrayRef<int64_t>(operand_shape.getNumElements()));
    }
    cached_tensor_values_[op->getResult(0)] = tensor_value;
  } else if (op_name == "Shape" || op_name == "ShapeN") {
    for (OpOperand &operand : op->getOpOperands()) {
      Type operand_type = operand.get().getType();
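      // Control operands come after all data operands; stop at the first one,
      // since Shape/ShapeN results correspond 1:1 to the data operands.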
      if (operand_type.isa<ControlType>()) break;

      auto operand_shape = operand_type.cast<ShapedType>();
      if (!operand_shape.hasStaticShape()) continue;

      int idx = operand.getOperandNumber();
      ShapedType return_shape = op->getResultTypes()[idx];
      DenseElementsAttr tensor_value;
      if (return_shape.getElementType().isInteger(32)) {
        tensor_value = DenseElementsAttr::get<int>(
            op->getResultTypes()[idx],
            SmallVector<int>(llvm::map_range(
                operand_shape.getShape(),
                [](int64_t dim) { return static_cast<int>(dim); })));
      } else {
        tensor_value = DenseElementsAttr::get(op->getResultTypes()[idx],
                                              operand_shape.getShape());
      }
      cached_tensor_values_[op->getResult(idx)] = tensor_value;
    }
  }

  // TODO(chiahungduan): Grappler has cases for ConcatV2/Pack/Slice/StridedSlice
  // whose shape inference logic is similar to how we constant-fold them.
  // Constant folding should cover most of those cases; handle the rest on
  // demand later.
}

void ShapeInference::runOnOperation() {
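  // Give shape inference access to the cached constant tensors so it can use
  // operand values, not just operand shapes, when inferring.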
  auto operand_as_constant_fn = [this](Value operand) -> Attribute {
    return GetTensorValue(operand);
  };

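  // Convert a result that holds a cached 1-D constant tensor into a
  // ShapeHandle so the InferenceContext can consume it as a shape value.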
  auto op_result_as_shape_fn = [this](InferenceContext &ic,
                                      OpResult op_result) -> ShapeHandle {
    auto rt = op_result.getType().dyn_cast<RankedTensorType>();
    // NOMUTANTS -- TODO(chiahungduan): Review this condition to see if a shape
    // with known rank but unknown dimensions is acceptable.
    if (!rt || rt.getRank() != 1 || !rt.hasStaticShape()) return {};

    std::vector<DimensionHandle> dims(rt.getDimSize(0), ic.UnknownDim());

    DenseElementsAttr attr = GetTensorValue(op_result);
    if (!attr) return {};

    for (const auto &element : llvm::enumerate(attr.getValues<APInt>()))
      dims[element.index()] = ic.MakeDim(element.value().getSExtValue());
    return ic.MakeShape(dims);
  };

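  // Runs TF shape inference for `op` and refines its result types in place.
  // Returns true if any result type was changed.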
  auto infer_and_update_shapes = [&](Operation *op) -> bool {
    auto result_element_type_fn = [&](int idx) -> Type {
      return op->getResult(idx).getType().cast<ShapedType>().getElementType();
    };

    SmallVector<ShapedTypeComponents> results;
    if (failed(InferReturnTypeComponentsForTFOp(
            op->getLoc(), op, TFOp(op).getNonControlOperands(), graph_version_,
            operand_as_constant_fn, op_result_as_shape_fn,
            result_element_type_fn,
            /*get_attr_values_fn=*/nullptr, results)))
      return false;

    bool updated = false;
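    // `results` holds the inferred components for the data results; the
    // trailing control token result keeps its type untouched.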
    for (auto it : llvm::zip(op->getResults().drop_back(), results)) {
      OpResult op_result = std::get<0>(it);
      ShapedTypeComponents result = std::get<1>(it);
      TensorType inferred_type;
      if (result.hasRank()) {
        inferred_type =
            RankedTensorType::get(result.getDims(), result.getElementType());
      } else {
        inferred_type = UnrankedTensorType::get(result.getElementType());
      }

      Type refined_type = tf_type::GetCastCompatibleType(
          op_result.getType().cast<ShapedType>(), inferred_type);

      // Certain attributes like _output_shapes may carry incorrect shape
      // information. When the types are incompatible, use the result of the
      // shape inference context.
      if (!refined_type) refined_type = inferred_type;

      if (refined_type == op_result.getType()) continue;

      op_result.setType(refined_type);
      updated = true;
    }

    if (updated) TryToCacheResultsTensorValue(op);

    return updated;
  };

  // Reset the cached tensor values.
  cached_tensor_values_.clear();

  // Traverse all the operations and do the first round of inference. We don't
  // record the operations that need to be updated, because most of them may
  // still lack shape information.
  getOperation()->walk<WalkOrder::PreOrder>([&](Operation *op) {
    if (auto func = dyn_cast<GraphFuncOp>(op)) {
      // Don't infer the shapes of ops in a generic function; just skip it.
      if (func.generic()) return WalkResult::skip();
      return WalkResult::advance();
    }
    if (isa<ModuleOp, GraphOp>(op) || op->getNumResults() == 0)
      return WalkResult::advance();

    if (!CanBeRefined(op)) {
      TryToCacheResultsTensorValue(op);
      return WalkResult::advance();
    }

    (void)infer_and_update_shapes(op);
    return WalkResult::advance();
  });

  // This is used to track the set of operations that may be able to infer
  // their shapes. When an operation infers its shape successfully, it adds its
  // users to this set, which implies that an operation may be added multiple
  // times if it has multiple operands. Use a SetVector to avoid keeping
  // duplicate entries.
  SetVector<Operation *> may_need_update;

  // Collect the operations that have a chance to infer more precise shape
  // information.
  getOperation()->walk<WalkOrder::PreOrder>([&](Operation *op) {
    if (auto func = dyn_cast<GraphFuncOp>(op)) {
      // Don't infer the shapes of ops in a generic function; just skip it.
      if (func.generic()) return WalkResult::skip();
      return WalkResult::advance();
    }
    if (isa<ModuleOp, tfg::GraphOp>(op) || op->getNumResults() == 0)
      return WalkResult::advance();

    // This op's shape still needs to be refined, so there's no chance for its
    // users to refine their shapes yet.
    if (CanBeRefined(op)) return WalkResult::advance();

    for (OpResult res : op->getResults().drop_back()) {
      for (Operation *user : res.getUsers())
        if (CanBeRefined(user)) may_need_update.insert(user);
    }

    return WalkResult::advance();
  });

  // TODO(chiahungduan): We may need to limit the iterations.
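  // Iterate to a fixed point: each successful refinement may enable the op's
  // users to be refined on a later visit.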
  while (!may_need_update.empty()) {
    Operation *op = may_need_update.pop_back_val();
    bool updated = infer_and_update_shapes(op);
    if (!updated) continue;

    // The users may be able to refine their shapes.
    for (Value v : op->getResults().drop_back()) {
      for (Operation *user : v.getUsers()) {
        if (CanBeRefined(user)) may_need_update.insert(user);
      }
    }
  }

  // Update the function signature.
  getOperation()->walk([&](GraphFuncOp func) {
    FunctionType func_type = func.function_type();
    Operation *return_op = func.getBody()->getTerminator();

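    // Compare the declared result types against the types actually produced by
    // the terminator's operands; rebuild the function type if they differ.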
    bool types_updated = false;
    for (auto &indexed_type : llvm::enumerate(func_type.getResults())) {
      int res_num = indexed_type.index();
      Type return_arg_type = return_op->getOperand(res_num).getType();
      if (return_arg_type != indexed_type.value()) {
        types_updated = true;
        break;
      }
    }

    if (!types_updated) return;

    func.function_typeAttr(TypeAttr::get(
        FunctionType::get(&getContext(), func_type.getInputs(),
                          TFOp(return_op).getNonControlOperands().getTypes())));
  });
}

std::unique_ptr<Pass> CreateShapeInferencePass() {
  return std::make_unique<ShapeInference>();
}

}  // namespace tfg
}  // namespace mlir