1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.h"
17
18 #include <algorithm>
19 #include <cstdint>
20 #include <initializer_list>
21 #include <iterator>
22 #include <queue>
23 #include <stack>
24
25 #include "llvm/ADT/ArrayRef.h"
26 #include "llvm/ADT/Hashing.h"
27 #include "llvm/ADT/None.h"
28 #include "llvm/ADT/PointerUnion.h"
29 #include "llvm/ADT/STLExtras.h"
30 #include "llvm/ADT/SmallVector.h"
31 #include "llvm/ADT/iterator_range.h"
32 #include "llvm/Support/Casting.h"
33 #include "llvm/Support/Debug.h"
34 #include "llvm/Support/FormatVariadic.h"
35 #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
36 #include "mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project
37 #include "mlir/IR/Attributes.h" // from @llvm-project
38 #include "mlir/IR/Block.h" // from @llvm-project
39 #include "mlir/IR/Builders.h" // from @llvm-project
40 #include "mlir/IR/BuiltinDialect.h" // from @llvm-project
41 #include "mlir/IR/BuiltinOps.h" // from @llvm-project
42 #include "mlir/IR/BuiltinTypes.h" // from @llvm-project
43 #include "mlir/IR/Diagnostics.h" // from @llvm-project
44 #include "mlir/IR/FunctionInterfaces.h" // from @llvm-project
45 #include "mlir/IR/Location.h" // from @llvm-project
46 #include "mlir/IR/Operation.h" // from @llvm-project
47 #include "mlir/IR/OperationSupport.h" // from @llvm-project
48 #include "mlir/IR/SymbolTable.h" // from @llvm-project
49 #include "mlir/IR/TypeUtilities.h" // from @llvm-project
50 #include "mlir/IR/Value.h" // from @llvm-project
51 #include "mlir/IR/Visitors.h" // from @llvm-project
52 #include "mlir/Interfaces/CallInterfaces.h" // from @llvm-project
53 #include "mlir/Interfaces/FoldInterfaces.h" // from @llvm-project
54 #include "mlir/Interfaces/InferTypeOpInterface.h" // from @llvm-project
55 #include "mlir/Pass/Pass.h" // from @llvm-project
56 #include "mlir/Pass/PassRegistry.h" // from @llvm-project
57 #include "mlir/Support/DebugStringHelper.h" // from @llvm-project
58 #include "mlir/Support/LLVM.h" // from @llvm-project
59 #include "mlir/Support/LogicalResult.h" // from @llvm-project
60 #include "mlir/Transforms/FoldUtils.h" // from @llvm-project
61 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
62 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
63 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
64 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
65 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
66 #include "tensorflow/compiler/mlir/tensorflow/utils/serialize_mlir_module_utils.h"
67 #include "tensorflow/compiler/mlir/tensorflow/utils/shape_inference_utils.h"
68 #include "tensorflow/compiler/mlir/tensorflow/utils/translate_utils.h"
69 #include "tensorflow/compiler/xla/window_util.h"
70 #include "tensorflow/compiler/xla/xla_data.pb.h"
71 #include "tensorflow/core/framework/shape_inference.h"
72 #include "tensorflow/core/framework/types.pb.h"
73 #include "tensorflow/core/ir/types/dialect.h"
74
75 #define DEBUG_TYPE "tf-shape-inference"
76
77 #define DCOMMENT(MSG) LLVM_DEBUG(llvm::dbgs() << MSG << "\n")
78 #define DCOMMENT_OP(OP, MSG) \
79 LLVM_DEBUG(OP->print(llvm::dbgs() << MSG << " "); llvm::dbgs() << "\n")
80
81 using ::int64_t;
82 using mlir::func::FuncOp;
83 using tensorflow::shape_inference::DimensionHandle;
84 using tensorflow::shape_inference::InferenceContext;
85 using tensorflow::shape_inference::ShapeHandle;
86
87 namespace mlir {
88 namespace TF {
89 namespace {
90
// Computes a refined type between two types `lhs` and `rhs`. The result type
// is always at least as refined (i.e., has at least as much static
// information) as `lhs`. This method merges the information contained in both
// types; for example, it is capable of refining:
//   tensor<!tf_type.variant<tensor<?x8xf32>>>
// and:
//   tensor<!tf_type.variant<tensor<10x?xf32>>>
// into:
//   tensor<!tf_type.variant<tensor<10x8xf32>>>
//
// In case of inconsistencies (e.g., rank disagreement), it returns `lhs`.
Type TypeMeet(Type lhs, Type rhs) {
103 DCOMMENT("RefineTypeWith : " << lhs << " : " << rhs);
104 if (lhs == rhs) return lhs;
105
106 auto rhs_shape_type = rhs.dyn_cast<ShapedType>();
107 if (!rhs_shape_type) return lhs;
108 auto lhs_shape_type = lhs.cast<ShapedType>();
109 if (lhs_shape_type.hasRank() && rhs_shape_type.hasRank() &&
110 lhs_shape_type.getRank() != rhs_shape_type.getRank()) {
111 DCOMMENT("Unexpected rank mismatch: " << lhs << " vs " << rhs);
112 return lhs;
113 }
114
115 SmallVector<int64_t> shape;
116 bool refined_shape = false;
  // Build the shape of the refined type: if `lhs` is unranked, the `rhs`
  // shape is used directly as the refined shape; otherwise the two shapes are
  // merged by taking the most specialized dimension from each. This combines
  // `10x?x?` and `?x?x8` into `10x?x8`.
121 if (!lhs_shape_type.hasRank()) {
122 if (rhs_shape_type.hasRank()) {
123 shape.append(rhs_shape_type.getShape().begin(),
124 rhs_shape_type.getShape().end());
125 refined_shape = true;
126 }
127 } else if (rhs_shape_type.hasRank()) {
128 for (auto shape_elts : llvm::enumerate(
129 llvm::zip(lhs_shape_type.getShape(), rhs_shape_type.getShape()))) {
130 if (ShapedType::isDynamic(std::get<0>(shape_elts.value())) &&
131 !ShapedType::isDynamic(std::get<1>(shape_elts.value()))) {
132 shape.push_back(std::get<1>(shape_elts.value()));
133 refined_shape = true;
134 DCOMMENT("-> refining shape element #" << shape_elts.index());
135 } else {
136 DCOMMENT("-> not refining shape element #" << shape_elts.index());
137 shape.push_back(std::get<0>(shape_elts.value()));
138 }
139 }
140 }
141
  // Some tensors have an element type wrapping a subtensor, like resources
  // and variants. In this case we may recurse on the wrapped subtype.
  // `lhs_element_type` will contain the refined inferred element type for the
  // returned type.
146 auto lhs_element_type = lhs_shape_type.getElementType();
147 auto rhs_element_type_with_subtype =
148 rhs_shape_type.getElementType().dyn_cast<TF::TensorFlowTypeWithSubtype>();
  // Look for a resource or variant element type and ensure we refine the
  // subtype. We only support a single subtype at the moment; we won't handle
  // something like:
  //   tensor<!tf_type.variant<tensor<10xf32>, tensor<8xf32>>>
153 if (rhs_element_type_with_subtype &&
154 rhs_element_type_with_subtype.GetSubtypes().size() == 1) {
155 auto lhs_element_type_with_subtype =
156 lhs_element_type.dyn_cast<TF::TensorFlowTypeWithSubtype>();
157 TensorType subtype;
158 if (!lhs_element_type_with_subtype) {
159 DCOMMENT(
160 "Unexpected inferred `TensorFlowTypeWithSubtype` when original "
161 "result isn't");
162 } else if (lhs_element_type_with_subtype.GetSubtypes().size() > 1) {
163 DCOMMENT(
164 "Unexpected `TensorFlowTypeWithSubtype` original type with size>1");
165 } else if (lhs_element_type_with_subtype.GetSubtypes().empty()) {
166 subtype = rhs_element_type_with_subtype.GetSubtypes().front();
167 } else {
168 // Recurse on the subtypes in the variant/resource. Basically if the input
169 // were:
170 // tensor<!tf_type.variant<tensor<?x8xf32>>>
171 // and:
172 // tensor<!tf_type.variant<tensor<10x8xf32>>>
173 // we'll try here to refine tensor<?x8xf32> with tensor<10x8xf32>.
174 auto refined_subtype =
175 TypeMeet(lhs_element_type_with_subtype.GetSubtypes().front(),
176 rhs_element_type_with_subtype.GetSubtypes().front())
177 .cast<TensorType>();
178 if (refined_subtype !=
179 lhs_element_type_with_subtype.GetSubtypes().front())
180 subtype = refined_subtype;
181 }
182 // If we managed to refine the subtype, recreate the element type itself
183 // (i.e. the tf.variant or tf.resource).
184 if (subtype) {
185 lhs_element_type = lhs_element_type_with_subtype.clone({subtype});
186 }
187 }
188 if (refined_shape || lhs_element_type != lhs_shape_type.getElementType()) {
189 Type new_type;
190 if (!lhs_shape_type.hasRank() && !rhs_shape_type.hasRank())
191 new_type = UnrankedTensorType::get(lhs_element_type);
192 else
193 new_type = lhs_shape_type.clone(shape, lhs_element_type);
194 DCOMMENT("Refined to: " << new_type);
195 return new_type;
196 }
197 DCOMMENT("No refinement " << lhs);
198 return lhs;
199 }
200
201 // Returns whether `original_type` type can be refined with
202 // `potential_refined_type` type.
bool CanRefineTypeWith(Type original_type, Type potential_refined_type) {
204 return original_type != TypeMeet(original_type, potential_refined_type);
205 }
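// Illustrative example, derived from the TypeMeet semantics above (not a
// definitive spec): tensor<?x8xf32> can be refined with tensor<10x8xf32>
// because their meet, tensor<10x8xf32>, differs from the original type; the
// reverse cannot, since tensor<10x8xf32> is already at least as refined and
// TypeMeet would return it unchanged.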
206
// Returns whether the shape inference pass supports an op outside the TF
// dialect.
bool IsSupportedNonTFOp(Operation* op) {
209 return isa<tf_device::ReturnOp, tf_device::ClusterOp, tf_device::LaunchOp,
210 tf_executor::EnterOp, tf_executor::ExitOp, tf_executor::FetchOp,
211 tf_executor::GraphOp, tf_executor::IslandOp,
212 tf_executor::LoopCondOp, tf_executor::MergeOp,
213 tf_executor::NextIterationSinkOp, tf_executor::SwitchNOp,
214 tf_executor::SwitchOp, tf_executor::YieldOp>(op) ||
215 isa<InferTypeOpInterface>(op);
216 }
217
// Returns whether a cast back would need to be inserted, i.e., whether the
// operation that owns `use` does not allow for shape refinement of its
// operand without a cast.
bool NeedsCastBack(OpOperand& use, Dialect* tf_dialect) {
222 return use.getOwner()->getDialect() != tf_dialect &&
223 !IsSupportedNonTFOp(use.getOwner());
224 }
225
TensorType CreateTensorType(llvm::Optional<llvm::ArrayRef<int64_t>> shape,
                            Type element_type) {
228 if (shape.has_value())
229 return RankedTensorType::get(shape.getValue(), element_type);
230 return UnrankedTensorType::get(element_type);
231 }
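// For illustration (assuming the usual MLIR printed forms): a shape of
// {2, ShapedType::kDynamicSize} with an f32 element type yields
// tensor<2x?xf32>, while passing llvm::None yields the unranked
// tensor<*xf32>.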
232
233 // Returns true if the op creates a TensorList.
bool IsTensorListInitOp(Operation* op) {
235 return isa<TensorListReserveOp>(op) || isa<EmptyTensorListOp>(op) ||
236 isa<TensorListFromTensorOp>(op);
237 }
238
239 // Returns the `element_shape` operand of the ops that create a TensorList.
Value GetElementShapeOperand(Operation* op) {
241 if (auto empty_tl = dyn_cast<EmptyTensorListOp>(op))
242 return empty_tl.element_shape();
243 if (auto tl_reserve = dyn_cast<TensorListReserveOp>(op))
244 return tl_reserve.element_shape();
245 if (auto tl_from_tensor = dyn_cast<TensorListFromTensorOp>(op))
246 return tl_from_tensor.element_shape();
247 llvm_unreachable("unsupported TensorList op");
248 }
249
250 // Utility function to create a ranked tensor type after dropping the first
251 // dimension from the input type.
RankedTensorType DropFirstDimension(Type type) {
253 RankedTensorType ranked_type = type.dyn_cast<RankedTensorType>();
254 if (!ranked_type) return {};
255 llvm::ArrayRef<int64_t> dims_except_first =
256 ranked_type.getShape().drop_front();
257 return RankedTensorType::get(dims_except_first, ranked_type.getElementType());
258 }
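// For example, a tensor<5x8x16xf32> input produces tensor<8x16xf32>; unranked
// (or non-ranked-tensor) inputs return a null type, which callers are
// expected to check.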
259
Operation* InsertCast(OpBuilder& b, Location loc, Type dst_type, Value input) {
261 Type element_type = getElementTypeOrSelf(dst_type);
262 if (element_type.isa<IndexType>())
263 return b.create<tensor::CastOp>(loc, dst_type, input);
264 if (isa<TensorFlowDialect, BuiltinDialect>(element_type.getDialect()))
265 return b.create<TF::CastOp>(loc, dst_type, input,
266 /*truncate=*/b.getBoolAttr(false));
267 return nullptr;
268 }
269
// Follows the use chain of a TensorList and returns true iff all elements
// written to the TensorList have the same static shape. If they do, that
// shape is assigned to `potential_element_type`.
//
// This can handle multiple mutations of a TensorList object and returns true
// if, across all mutations, the elements written have the same shape.
bool CanInferTensorListElementType(Value tensorlist,
                                   Value initial_element_shape,
                                   RankedTensorType* potential_element_type) {
279 DCOMMENT("CanInferTensorListElementType " << tensorlist << " with initial "
280 << initial_element_shape);
  // Verifies that the new element type has a static shape and matches the
  // potential type passed from the caller. Updates `potential_element_type`
  // if it is not defined yet.
284 auto verify_and_update_potential_element_type =
285 [&](RankedTensorType new_element_type) -> bool {
286 DCOMMENT("\t\tConsidering " << new_element_type << " with old "
287 << *potential_element_type);
288 if (!new_element_type || !new_element_type.hasStaticShape()) return false;
289 if (!*potential_element_type) {
290 DCOMMENT("\t\tUpdating potential_element_type " << new_element_type);
291 *potential_element_type = new_element_type;
292 return true;
293 }
294 return *potential_element_type == new_element_type;
295 };
296
297 std::stack<Value> worklist;
298 worklist.emplace(tensorlist);
299
300 while (!worklist.empty()) {
301 tensorlist = worklist.top();
302 worklist.pop();
303
    // TensorLists are semantically immutable. For example, TensorListSetItem
    // takes a TensorList as input and produces a TensorList as output. So to
    // traverse modifications to a TensorList and verify that all elements
    // written to it have the same shape, we need to follow the use-def chain
    // of ops that (conceptually) modify it, i.e., ops that take an input
    // TensorList and produce an output TensorList.
310 for (auto& use : tensorlist.getUses()) {
311 if (auto push = llvm::dyn_cast<TensorListPushBackOp>(use.getOwner())) {
312 auto element_type =
313 push.tensor().getType().dyn_cast<RankedTensorType>();
314 if (!verify_and_update_potential_element_type(element_type))
315 return false;
316 worklist.emplace(push.output_handle());
317 continue;
318 }
319 if (auto scatter = llvm::dyn_cast<TensorListScatterIntoExistingListOp>(
320 use.getOwner())) {
321 // For scatter op we can get the element shape by dropping the first
322 // dimension of the input tensor.
323 RankedTensorType element_type =
324 DropFirstDimension(scatter.tensor().getType());
325 if (!verify_and_update_potential_element_type(element_type))
326 return false;
327 worklist.emplace(scatter.output_handle());
328 continue;
329 }
330 if (auto set_item = llvm::dyn_cast<TensorListSetItemOp>(use.getOwner())) {
331 auto element_type =
332 set_item.item().getType().dyn_cast<RankedTensorType>();
333 DCOMMENT("\tTensorListSetItemOp " << element_type);
334 if (!verify_and_update_potential_element_type(element_type))
335 return false;
336 worklist.emplace(set_item.output_handle());
337 continue;
338 }
339 if (auto pop = llvm::dyn_cast<TensorListPopBackOp>(use.getOwner())) {
340 worklist.emplace(pop.output_handle());
341 continue;
342 }
343 if (auto resize = llvm::dyn_cast<TensorListResizeOp>(use.getOwner())) {
344 worklist.emplace(resize.output_handle());
345 continue;
346 }
      // WhileRegionOp can explicitly capture a TensorList value to be used
      // inside its regions. So we check the uses of the corresponding block
      // argument in each region and the uses of the TensorList returned via
      // YieldOp.
350 if (auto while_region = llvm::dyn_cast<WhileRegionOp>(use.getOwner())) {
351 DCOMMENT("\tTL WhileRegion");
352 for (auto branch : while_region.getRegions())
353 worklist.emplace(branch->getArgument(use.getOperandNumber()));
354 continue;
355 }
356 if (auto yield = llvm::dyn_cast<YieldOp>(use.getOwner())) {
357 Operation* parent = yield->getParentOp();
358 worklist.emplace(parent->getResult(use.getOperandNumber()));
359 continue;
360 }
361 // TODO(jpienaar): This can be generalized.
362 if (isa<IdentityOp, IdentityNOp, StopGradientOp>(use.getOwner())) {
363 worklist.emplace(use.getOwner()->getResult(use.getOperandNumber()));
364 continue;
365 }
      // Refining the TensorList element type might change the output of
      // TensorListElementShape, which is expected to be the shape originally
      // assigned to the TensorList init op. So replace it with the original
      // element shape value.
370 if (auto tl_element_shape =
371 dyn_cast<TensorListElementShapeOp>(use.getOwner())) {
372 // If element types match, we can do a direct replacement.
373 if (getElementTypeOrSelf(tl_element_shape.getResult()) ==
374 getElementTypeOrSelf(initial_element_shape.getType())) {
375 tl_element_shape.replaceAllUsesWith(initial_element_shape);
376 } else {
377 OpBuilder b(use.getOwner());
378 Operation* cast_op = InsertCast(
379 b, use.getOwner()->getLoc(),
380 tl_element_shape.getResult().getType(), initial_element_shape);
381 if (!cast_op) return false;
382 tl_element_shape.replaceAllUsesWith(cast_op->getResult(0));
383 }
384 continue;
385 }
386 // Ignore ops that just consume a TensorList and do not output another
387 // TensorList.
388 if (isa<TensorListStackOp, TensorListGatherOp, TensorListConcatV2Op,
389 TensorListLengthOp, TensorListGetItemOp>(use.getOwner()))
390 continue;
391
392 // For any other unknown users of the TensorList, we are conservative and
393 // stop element shape inference.
394 DCOMMENT("TensorListType infer, unknown op " << *use.getOwner());
395 return false;
396 }
397 }
398 return true;
399 }
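// Rough sketch of the behavior above: if a list handle only ever flows into,
// say, two TensorListSetItem ops that both write tensor<8x9xf32> items, the
// element type is inferred as tensor<8x9xf32>. If any write has a dynamic
// shape (e.g., tensor<?x9xf32>) or an unknown op consumes the handle, the
// function conservatively returns false.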
400
401 // Returns the tensor type created from the `shape_attr` and `type_attr`
402 // attributes.
Type GetType(Attribute shape_attr, Attribute type_attr) {
404 auto shape = shape_attr.cast<tf_type::ShapeAttr>();
405 auto type = type_attr.cast<TypeAttr>();
406 if (shape.hasRank())
407 return RankedTensorType::get(shape.getShape(), type.getValue());
408 else
409 return UnrankedTensorType::get(type.getValue());
410 }
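// For example (illustrative): a ranked 10x20 shape attribute combined with an
// f32 type attribute yields tensor<10x20xf32>, while an unranked shape
// attribute yields tensor<*xf32>.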
411
412 } // namespace
413
414 // Returns whether type can be further refined.
bool CanBeRefined(Type type) {
416 auto shape_type = type.dyn_cast<ShapedType>();
417 if (!shape_type) return false;
418
419 // Returns whether type with subtypes can be further refined.
420 auto can_refine_subtypes = [](TF::TensorFlowTypeWithSubtype tws) {
421 return tws.GetSubtypes().empty() ||
422 llvm::any_of(tws.GetSubtypes(), CanBeRefined);
423 };
424 auto type_with_subtype =
425 shape_type.getElementType().dyn_cast<TF::TensorFlowTypeWithSubtype>();
426 if (type_with_subtype && can_refine_subtypes(type_with_subtype)) return true;
427
428 return !shape_type.hasStaticShape();
429 }
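// For instance, tensor<*xf32> and tensor<?x4xf32> can be refined, while a
// fully static tensor<4xf32> cannot. A statically shaped
// tensor<!tf_type.resource> can still be refined when its element type
// carries no subtypes (or only refinable ones).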
430
// Combination of a value producer and the port of the value produced (e.g.,
// <result number of the producer>:<element within the output tensor>), so for
// a tf.Const producing tensor<10x20xf32>, [0,2,18] would point to a unique
// output scalar value.
435 struct ValuePort {
436 PointerUnion<Operation*, BlockArgument> producer;
437 SmallVector<unsigned int, 2> port;
438
  bool operator==(const ValuePort& other) const {
440 return producer == other.producer && port == other.port;
441 }
442
443 // Convert output value to ValuePort.
  explicit ValuePort(Value v) {
445 OpResult opr = v.dyn_cast<OpResult>();
446 if (opr) {
447 producer = opr.getOwner();
448 port = {opr.getResultNumber()};
449 } else {
450 producer = v.cast<BlockArgument>();
451 port = {0};
452 }
453 }
  ValuePort(PointerUnion<Operation*, BlockArgument> producer,
            SmallVector<unsigned int, 2> port)
      : producer(producer), port(port) {}
457
  raw_ostream& print(raw_ostream& os) const {
459 if (auto* op = producer.dyn_cast<Operation*>())
460 os << "op " << op->getName();
461 if (auto ba = producer.dyn_cast<BlockArgument>())
462 os << "block_arg " << ba.getArgNumber();
463 os << formatv(" [{0}]", llvm::make_range(port.begin(), port.end()));
464 return os;
465 }
466 };
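// Illustration of the port encoding used above: for a plain OpResult the port
// is just {result_number}; for ops where individual elements are tracked
// (e.g., the tf.Pack handling below), a port such as {0, 1} denotes element 1
// of result 0.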
467
468 struct ValuePortHasher {
  std::size_t operator()(const ValuePort& other) const {
470 return hash_combine(llvm::hash_value(other.producer.getOpaqueValue()),
471 hash_value(ArrayRef<unsigned int>(other.port)));
472 }
473 };
474
475 using ValuePortResultMap =
476 std::unordered_map<ValuePort, Attribute, ValuePortHasher>;
477 using ComputedQueryFn = function_ref<bool(ValuePort)>;
478 using ValueQueryFn = function_ref<Attribute(const ValuePort&)>;
479 using ValuePortInputs = SmallVectorImpl<ValuePort>;
480
481 // TODO(jpienaar): ComputeInputsRequiredForOutput and ComputeOutputComponent are
482 // intended to be switched to op interfaces once more refined.
LogicalResult ComputeInputsRequiredForOutput(ValuePort value_port,
                                             ComputedQueryFn has_been_computed,
                                             ValuePortInputs* inputs) {
486 auto op = value_port.producer.dyn_cast<Operation*>();
487 auto& port = value_port.port;
488 if (!op) return failure();
489
490 // No inputs required for constants.
491 if (matchPattern(op, m_Constant())) return success();
492
  // Note: this focuses only on the trivial pack op case and could be
  // generalized.
495 if (auto pack_op = dyn_cast<TF::PackOp>(op)) {
496 auto type = pack_op.getType().cast<TensorType>();
497 if (!type.hasRank() || type.getRank() != 1) return failure();
498 if (port.size() != 2) return failure();
499 assert(port[0] == 0);
500 ValuePort req(pack_op.getOperand(port[1]));
501 if (!has_been_computed(req)) inputs->push_back(req);
502 return success();
503 }
504
505 return failure();
506 }
507
508 // Computes the output produced by ValuePort using the query function of
509 // existing computed values.
Attribute ComputeOutputComponent(const ValuePort& value_port,
                                 ValueQueryFn values) {
512 LLVM_DEBUG(value_port.print(llvm::dbgs() << "Computing output for ") << "\n");
513 if (auto known = values(value_port)) return known;
514
515 auto op = value_port.producer.dyn_cast<Operation*>();
516 if (!op) return nullptr;
517 auto& port = value_port.port;
518
519 if (port.empty()) {
520 LLVM_DEBUG(llvm::dbgs() << "skipping, port outside spec of " << op << "\n");
521 return nullptr;
522 }
523
524 ElementsAttr attr;
525 if (matchPattern(op, m_Constant(&attr))) {
526 if (port.size() == 1 && port[0] == 0) return attr;
527 return nullptr;
528 }
529
530 if (auto id = dyn_cast<IdentityOp>(op)) {
531 if (port.size() == 1 && port[0] == 0)
532 return ComputeOutputComponent(ValuePort(id.input()), values);
533 return nullptr;
534 }
535
  // Note: this focuses only on the trivial pack op case and could be
  // generalized.
538 if (auto pack_op = dyn_cast<TF::PackOp>(op)) {
539 TensorType type = pack_op.getType().cast<TensorType>();
540 if (!type.hasRank() || type.getRank() != 1) return nullptr;
541 if (port.size() != 2 || port[0] != 0) return nullptr;
542 ValuePort op_port(op->getOperand(port[1]));
543 return values(op_port);
544 }
545
546 if (auto graph = dyn_cast<tf_executor::GraphOp>(op)) {
547 if (port.size() == 1)
548 return ComputeOutputComponent(
549 ValuePort(graph.GetFetch().fetches()[port[0]]), values);
550 return nullptr;
551 }
552
553 if (auto island = dyn_cast<tf_executor::IslandOp>(op)) {
554 if (port.size() == 1)
555 return ComputeOutputComponent(
556 ValuePort(island.GetYield().fetches()[port[0]]), values);
557 return nullptr;
558 }
559
560 return nullptr;
561 }
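// Worked example of the pack handling above (names are illustrative): given
//   %r = "tf.Pack"(%a, %b) : (tensor<i32>, tensor<i32>) -> tensor<2xi32>
// querying the ValuePort for %r with port [0, 1] returns whatever attribute
// has already been recorded for %b (or nullptr if none is known yet).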
562
563 // Context used during ShapeInference. This class contains common information
564 // that is required by the individual shape inference helper functions (e.g.,
565 // TF Graph version, constant values computed, etc.)
566 class ShapeInference {
567 public:
568 ShapeInference(int64_t graph_version, ModuleOp module,
569 bool propagate_caller_callee_constants);
570
  LogicalResult ComputeInputsRequiredForOutput(ValuePort value_port,
                                               ValuePortInputs* inputs) {
573 return ::mlir::TF::ComputeInputsRequiredForOutput(
574 value_port,
575 [this](const ValuePort& port) {
576 return results_.find(port) != results_.end();
577 },
578 inputs);
579 }
580
  Attribute ComputeOutputComponent(const ValuePort& value_port) {
582 if (auto known_attr = results_[value_port]) return known_attr;
583 auto attr = ::mlir::TF::ComputeOutputComponent(
584 value_port, [this](const ValuePort& port) { return results_[port]; });
585 RecordValue(value_port, attr);
586 return attr;
587 }
588
589 // Returns ShapeHandle if the op result could be computed as shape.
590 ShapeHandle ComputeOutputAsShape(OpResult result, InferenceContext* ic);
591
  void RecordValue(const ValuePort& value_port, Attribute value) {
593 LLVM_DEBUG(value_port.print(llvm::dbgs() << "\trecording ")
594 << value << "\n");
595 results_[value_port] = value;
596 }
597
598 // Infers shape of tf.While/tf.WhileRegion. If `shape_invariant` attribute is
599 // set, operand types are set as result types if associated body result types
600 // match the operand type (does not change per loop iteration). If operand and
601 // body result types are not the same, only handle types are propagated to
602 // result types. This is necessary to not incorrectly change result shapes
603 // when the While op will have a different result shape. Otherwise operand
604 // shapes are propagated to result shapes.
605 template <typename WhileOpTy>
606 bool InferShapeForWhile(WhileOpTy op, TypeRange body_result_types);
607
  // Performs shape inference on the provided op and returns true if the type
  // of at least one result has been changed.
  // A tf.Cast() is inserted for any uses that aren't in the TensorFlow
  // dialect.
  // `graph_version` indicates the current GraphDef compatibility version
  // (the versions field in graph.proto).
613 bool InferShapeForSingleOperation(Operation* op, int64_t max_iterations);
614
  // Infers shape on the provided region, including nested ones, iterating
  // until a fixed point is reached or `max_iterations` is exceeded.
  // Returns failure() on error; otherwise returns true if it reached
  // convergence and false otherwise.
619 FailureOr<bool> InferShapeUntilFixPoint(Region* region,
620 int64_t max_iterations);
621
622 // Updates input types and refine shapes inside body of functions that are
623 // attached to ControlFlow ops (If/While) or Calls. These functions include
624 // Then/Else branches of IfOp and Cond/Body functions of WhileOp. Functions
625 // attached to control flow share following common properties:
  // 1) They are never reused, i.e., they have a single use in the module.
  // 2) Their input types match those of their parent ops (excluding inputs
  //    like predicate).
629 // For calls, functions can be reused across multiple call sites. In this case
630 // we propagate the types when all call sites have the same operand types.
631 // Returns a failure() on error, otherwise returns true to indicate that it
632 // reached convergence, false otherwise.
633 FailureOr<bool> PropagateShapeToFunctions(ModuleOp module,
634 TypeRange input_types,
635 ArrayRef<func::FuncOp> functions,
636 int64_t max_iterations);
637
638 // Propagates shapes to regions given the shapes of the inputs of the regions.
639 // All regions provided in `regions` are assumed to have inputs of type
640 // `input_types`.
641 // Returns a failure() on error, otherwise returns true to indicate that it
642 // reached convergence, false otherwise.
643 FailureOr<bool> PropagateShapeToRegions(TypeRange input_types,
644 ArrayRef<Region*> regions,
645 int64_t max_iterations);
646
647 // Shape propagation for call/control flow ops.
648 // Returns a failure() on error, otherwise returns true to indicate that it
649 // reached convergence, false otherwise.
650 FailureOr<bool> PropagateShapeIntoAttachedFunctions(Operation* op,
651 int64_t max_iterations);
652
653 // Shape propagation for region based control flow.
654 // Returns a failure() on error, otherwise returns true to indicate that it
655 // reached convergence, false otherwise.
656 FailureOr<bool> PropagateShapeIntoAttachedRegions(Operation* op,
657 int64_t max_iterations);
658
659 // Propagates any constant operand of call_op to the called function body's
660 // corresponding argument if the callee has only one use.
661 //
662 // TODO(b/154065712): Move this to a more general inter-procedural constant
663 // folding pass.
664 void PropagateConstantToCallee(CallOpInterface call_op, FuncOp func,
665 ModuleOp module);
666
667 // Propagates any constant return value of the callee function to the call
668 // op's corresponding result.
669 void PropagateConstantFromCallee(CallOpInterface call_op, FuncOp func,
670 ModuleOp module);
671
  // Tries to compute the result of folding the op. This doesn't actually
  // perform constant folding, it just computes the equivalent constants.
  // Returns whether it was able to compute constant values.
675 LogicalResult TryToFold(Operation* op);
676
677 // Makes result types match the operand types (the i-th result type will
678 // match the i-th operand type). Returns true if anything is changed.
679 bool RefineTypeForPassThroughOperands(Operation* op, OperandRange operands,
680 ResultRange results);
681
682 // Makes result type's shape match the corresponding operand's shape.
683 // Returns whether any change was made.
684 bool RefineShapeForPassThroughOps(Operation* op);
685
686 // Infers shape for necessary ops that are not in the TF dialect. Returns
687 // whether any result type changed.
688 bool InferShapeForNonTFDialectOperation(Operation* op);
689
690 // Infers shape for function return type and returns whether changed.
691 LogicalResult InferShapeForFunctionReturnType(func::FuncOp func);
692
693 // Enqueues function for processing.
  void enqueue(func::FuncOp fn) {
695 LLVM_DEBUG(llvm::dbgs()
696 << "enqueue " << fn.getName() << " ("
697 << (queue_set_.count(fn) ? "already inserted" : "newly inserted")
698 << ")\n");
699 if (queue_set_.insert(fn).second) queue_.push(fn);
700 }
701
702 // Enqueues callers on functions.
703 void EnqueueCallers(func::FuncOp fn);
704
705 // Returns the function at the front of the queue.
  func::FuncOp front() { return queue_.front(); }
707
708 // Returns whether work queue is empty.
  bool EmptyQueue() const { return queue_.empty(); }
710
  // Removes and returns the function at the front of the work queue.
  func::FuncOp pop_front() {
713 func::FuncOp ret = queue_.front();
714 queue_.pop();
715 queue_set_.erase(ret);
716 return ret;
717 }
718
719 // Returns the current size of the queue.
  std::queue<func::FuncOp>::size_type QueueSize() const {
721 return queue_.size();
722 }
723
724 Dialect* const tf_dialect_;
725
726 private:
727 // Returns whether the result of an operation could be updated to a new
728 // inferred type. Also inserts cast operation for uses that are incompatible
729 // with the new type.
730 bool UpdateTypeAndInsertIncompatibleUseCasts(Type new_type, Value result);
731
732 // Refines the type of `result` of `op` using the type
733 // `potential_refined_type`. Return true if the type was changed.
734 bool RefineResultType(Operation* op, Value result,
735 Type potential_refined_type);
736
  // Infers the shape from a (Stateful)PartitionedCall operation by looking up
  // the called function and propagating the return type.
739 bool InferShapeForCall(CallOpInterface call_op);
740
741 bool InferShapeForCast(Operation* op);
742
743 bool InferShapeForRestore(Operation* op);
744
  // Infers the shapes of IfOp outputs based on the shapes of the then and
  // else function result types.
747 bool InferShapeForIf(IfOp op);
748
  // Infers the shapes of IfRegion outputs based on the shapes of the then and
  // else yields.
751 bool InferShapeForIfRegion(IfRegionOp op);
752
753 // Infers the shape of _XlaHostComputeMlir based on the host computation
754 // module. Returns true if a return type was changed.
755 bool InferShapeForXlaHostComputeMlir(_XlaHostComputeMlirOp op);
756
757 // Infers the shape for MapDatasetOp and its associated function. Returns
758 // whether either op or function type was changed.
759 bool InferShapeForMapDataset(MapDatasetOp op, int64_t max_iterations);
760
761 // Infers the shape for ReduceDatasetOp and its associated reduce function.
762 // Returns whether either op or function type was changed.
763 bool InferShapeForReduceDataset(ReduceDatasetOp op, int64_t max_iterations);
764
765 // Infers the shape for TakeWhileDatasetOp and its associated predicate
766 // function. Returns whether either op or function type was changed.
767 bool InferShapeForTakeWhileDataset(TakeWhileDatasetOp op,
768 int64_t max_iterations);
769
770 // Infers shape for dataset ops that have `M` input elements and `N`
771 // arguments, and also propagates the shape to the specified function (called
772 // only when function exists and has single use).
773 bool InferShapeForDatasetOpCommon(Operation* op, FuncOp f,
774 int64_t max_iterations);
775
776 // Infers the shape of ops that create TensorList. Specifically,
777 // TensorListReserveOp, EmptyTensorListOp and TensorListFromTensor ops. It
778 // refines the element shape if all tensors written to the list across all
779 // mutations have identical static shape.
780 bool InferShapeForTensorListInitOps(Operation* op);
781
782 // Infers the shape of VarHandleOp based on the uses of the VarHandleOp to
783 // update the subtypes of the resource type.
784 bool InferShapeForVarHandleOp(VarHandleOp op);
785
786 // Infers the output shape of XlaConvV2Op based on the input shapes
787 bool InferShapeForXlaConvV2Op(XlaConvV2Op op);
788
789 // Infers the output shape of XlaReduceWindowOp based on the input shapes.
790 bool InferShapeForXlaReduceWindowOp(XlaReduceWindowOp op);
791
792 // Infers the output shape of XlaSelectAndScatterOp based on the input shapes.
793 bool InferShapeForXlaSelectAndScatterOp(XlaSelectAndScatterOp op);
794
795 bool RefineWithInferTypeOpInterface(InferTypeOpInterface infer_ti);
796
797 // Returns all the callers of a function.
  // Note: Usage of the return value of this function may not be interleaved
  // with insertions into the callers map. This could occur if GetCallers is
  // called with two separate functions, where the second call incurs a resize
  // and then both the first and second stored caller lists are used.
802 ArrayRef<Operation*> GetCallers(func::FuncOp fn);
803
804 // Mapping between ValuePort (which corresponds to an OpResult or smaller,
805 // e.g., first element of OpResult produced) to an Attribute if the ValuePort
806 // corresponds to a constant value.
807 ValuePortResultMap results_;
808
809 // Map from a function to the callers of that function.
810 SymbolTableCollection symbol_table_;
811 SymbolUserMap symbol_users_;
812
813 // Queue of functions being processed.
814 llvm::DenseSet<func::FuncOp> queue_set_;
815 std::queue<func::FuncOp> queue_;
816
817 int64_t graph_version_;
818
819 // TODO(b/154065712): Remove propagate_caller_callee_constants once using
820 // SCCP pass instead.
821 bool propagate_caller_callee_constants_;
822 };
823
ShapeInference::ShapeInference(int64_t graph_version, ModuleOp module,
                               bool propagate_caller_callee_constants)
826 : tf_dialect_(module->getContext()->getLoadedDialect<TensorFlowDialect>()),
827 symbol_users_(symbol_table_, module),
828 graph_version_(graph_version),
829 propagate_caller_callee_constants_(propagate_caller_callee_constants) {
830 // Create symbol table for module.
831 symbol_table_.getSymbolTable(module);
832 }
833
ArrayRef<Operation*> ShapeInference::GetCallers(func::FuncOp fn) {
835 return symbol_users_.getUsers(fn);
836 }
837
void ShapeInference::EnqueueCallers(func::FuncOp fn) {
839 for (auto user : GetCallers(fn))
840 enqueue(user->getParentOfType<func::FuncOp>());
841 }
842
bool ShapeInference::UpdateTypeAndInsertIncompatibleUseCasts(Type new_type,
                                                             Value result) {
845 // No changes needed if the new type is unchanged.
846 if (new_type == result.getType()) return false;
847
848 Operation* cast_op = nullptr;
849 // First insert cast back for uses that need a cast and then
850 // update the type.
851 bool enqueue_callers = false;
852 for (OpOperand& use : make_early_inc_range(result.getUses())) {
853 if (isa<func::ReturnOp>(use.getOwner())) {
854 enqueue_callers = true;
855 } else if (NeedsCastBack(use, tf_dialect_)) {
856 if (!cast_op) {
857 Operation* op = result.getDefiningOp();
858 OpBuilder b(op);
859 b.setInsertionPointAfter(op);
860 cast_op = InsertCast(b, op->getLoc(), result.getType(), result);
861 if (!cast_op) return false;
862 }
863 use.set(Value(cast_op->getResult(0)));
864 }
865 }
866
867 result.setType(new_type);
868 if (enqueue_callers)
869 EnqueueCallers(result.getDefiningOp()->getParentOfType<func::FuncOp>());
870 return true;
871 }
872
bool ShapeInference::RefineResultType(Operation* op, Value result,
                                      Type potential_refined_type) {
875 if (!CanRefineTypeWith(result.getType(), potential_refined_type))
876 return false;
877
878 return UpdateTypeAndInsertIncompatibleUseCasts(potential_refined_type,
879 result);
880 }
881
// Infers the shape from a (Stateful)PartitionedCall operation by looking up
// the called function and propagating the return type.
bool ShapeInference::InferShapeForCall(CallOpInterface call_op) {
885 func::FuncOp func =
886 dyn_cast_or_null<func::FuncOp>(call_op.resolveCallable(&symbol_table_));
887 if (!func) return false;
888
889 DCOMMENT("Infer shape for call " << func.getName());
890 Operation* op = call_op.getOperation();
891 bool changed = false;
892 // Map each of the results of the call to the returned type of the
893 // function.
894 for (auto result :
895 zip(op->getResults(), func.getFunctionType().getResults())) {
896 changed = RefineResultType(op, std::get<0>(result), std::get<1>(result)) ||
897 changed;
898 }
  DCOMMENT(" - call " << func.getName() << " changed ? " << changed << "\n");
900
901 return changed;
902 }
903
bool ShapeInference::InferShapeForCast(Operation* op) {
905 DCOMMENT_OP(op, "Inferring shape for ");
906 Value result = op->getResult(0);
907 if (!CanBeRefined(result.getType())) return false;
908
909 Type operand_type = op->getOperand(0).getType();
910 auto ranked_op_type = operand_type.dyn_cast<RankedTensorType>();
911 if (!ranked_op_type) return false;
912 auto ranked_res_type = result.getType().dyn_cast<RankedTensorType>();
913 if (ranked_res_type &&
914 ranked_op_type.getShape() == ranked_res_type.getShape())
915 return false;
916
  // Avoid inserting a cast where no user's type could be refined (e.g., where
  // there would need to be a cast inserted for every user again).
919 if (llvm::all_of(result.getUses(), [this](OpOperand& use) {
920 return NeedsCastBack(use, tf_dialect_);
921 }))
922 return false;
923
924 auto new_type = RankedTensorType::get(
925 ranked_op_type.getShape(),
926 result.getType().cast<ShapedType>().getElementType());
927
928 return UpdateTypeAndInsertIncompatibleUseCasts(new_type, op->getResult(0));
929 }
930
bool ShapeInference::InferShapeForIf(IfOp op) {
932 DCOMMENT_OP(op.getOperation(), "Infer shape for if ");
933 bool changed = false;
934 auto then_results =
935 op.ResolveThenFunction(&symbol_table_).getFunctionType().getResults();
936 auto else_results =
937 op.ResolveElseFunction(&symbol_table_).getFunctionType().getResults();
938 for (auto it : llvm::zip(op.getResults(), then_results, else_results)) {
939 // If then and else types do not match, skip refinement for that result.
940 if (std::get<1>(it) != std::get<2>(it)) continue;
941 changed = RefineResultType(op, std::get<0>(it), std::get<1>(it)) || changed;
942 }
943 return changed;
944 }
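// For example, if both branch functions return tensor<10xf32> while the IfOp
// result is tensor<*xf32>, the result is refined to tensor<10xf32>; if the
// branch result types disagree, that result is left untouched.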
945
bool ShapeInference::InferShapeForIfRegion(IfRegionOp op) {
947 bool changed = false;
948
949 Operation* then_yield = op.then_branch().front().getTerminator();
950 Operation* else_yield = op.else_branch().front().getTerminator();
951 for (auto result : zip(op.getResults(), then_yield->getOperandTypes(),
952 else_yield->getOperandTypes())) {
953 // If then and else types do not match, skip refinement for that result.
954 if (std::get<1>(result) != std::get<2>(result)) continue;
955 changed = RefineResultType(op, std::get<0>(result), std::get<1>(result)) ||
956 changed;
957 }
958 return changed;
959 }
960
bool ShapeInference::InferShapeForXlaHostComputeMlir(
    _XlaHostComputeMlirOp host_compute_op) {
  // Extract the module and function.
  // The `_XlaHostComputeMlir` verifier checks that the `host_mlir_module`
  // attribute is well formed, so we just return in case of an error in
  // extracting the host function since it should never occur.
967 StringAttr host_module =
968 host_compute_op->getAttrOfType<StringAttr>("host_mlir_module");
969 if (host_module.getValue().empty()) return false;
970
971 mlir::OwningOpRef<mlir::ModuleOp> module_for_func;
972 func::FuncOp func = host_compute_op.GetHostFunc(&module_for_func);
973
974 // Update/use input shapes for function.
975 FunctionType func_type = func.getFunctionType();
976 func.setType(FunctionType::get(func.getContext(),
977 host_compute_op.getOperandTypes(),
978 func_type.getResults()));
979
980 // Run shape inference on the function.
981 if (failed(PropagateShapeToRegions(host_compute_op.getOperandTypes(),
982 {&func.getBody()}, 10)))
983 return false;
984 if (failed(InferShapeForFunctionReturnType(func))) return false;
985
986 bool changed = false;
987 // Use refined function return shape for XlaHostComputeMlirOp.
988 for (auto result :
989 zip(host_compute_op.getResults(), func.getFunctionType().getResults())) {
990 changed = RefineResultType(host_compute_op, std::get<0>(result),
991 std::get<1>(result)) ||
992 changed;
993 }
994
995 return changed;
996 }
997
// Infers the shape of the `Restore` and `RestoreV2` ops based on the first
// `AssignVariableOp` that uses the result. This requires that resource
// subtype inference has already completed.
bool ShapeInference::InferShapeForRestore(Operation* op) {
1002 DCOMMENT_OP(op, "Inferring shape for Restore,RestoreV2");
1003 // Currently only support single output.
1004 if (op->getNumResults() != 1) return false;
1005 if (!CanBeRefined(op->getResult(0).getType())) return false;
1006
1007 llvm::SmallVector<mlir::Operation*> worklist;
1008 llvm::append_range(worklist, op->getUsers());
1009
1010 // Look for any `AssignVariableOp` that uses the result of this op.
1011 while (!worklist.empty()) {
1012 mlir::Operation* const use = worklist.pop_back_val();
1013
    // Follow the users of `CastOp`/`IdentityOp` to handle the
    // `RestoreV2` -> (optional `IdentityOp`) -> `CastOp` -> `AssignVariableOp`
    // case.
1016 if (llvm::isa<TF::CastOp, TF::IdentityOp>(use)) {
1017 llvm::append_range(worklist, use->getUsers());
1018 continue;
1019 }
1020
1021 TF::AssignVariableOp assign_op = llvm::dyn_cast<TF::AssignVariableOp>(use);
1022 if (!assign_op) {
1023 continue;
1024 }
1025 auto subtypes = getElementTypeOrSelf(assign_op.resource())
1026 .cast<TF::ResourceType>()
1027 .getSubtypes();
1028 if (subtypes.empty()) {
1029 continue;
1030 }
1031 auto subtype = subtypes.front().dyn_cast<ShapedType>();
1032 if (subtype == nullptr) {
1033 continue;
1034 }
1035 // Preserve the dtype from the restore op even if `AssignVariableOp` uses a
1036 // different dtype, which is possible when there's a `CastOp` between them.
1037 subtype = subtype.clone(
1038 op->getResult(0).getType().cast<ShapedType>().getElementType());
1039 // Update the result type of this op with the resource's type. We only use
1040 // the resource subtype of the first user since shapes from all the users
1041 // should be equal or compatible.
1042 return UpdateTypeAndInsertIncompatibleUseCasts(subtype, op->getResult(0));
1043 }
1044 return false;
1045 }
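// Sketch of the effect above: if the restored value is assigned to a resource
// whose subtype is tensor<100x50xf32>, the RestoreV2 result is refined to
// tensor<100x50xf32>, but with the restore op's own element type preserved in
// case a CastOp sits between the restore and the assignment.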
1046
1047 // Helper structure to capture shapes & types for Dataset input.
1048 struct DatasetInput {
  explicit operator bool() const { return shapes && types; }
1050
1051 ArrayAttr shapes;
1052 ArrayAttr types;
1053 };
1054
1055 // Returns the input elements shapes and types for Dataset ops.
DatasetInput GetDatasetInput(Value value) {
1057 // TODO(haoliang): add an interface for DatasetOp to avoid the following
1058 // enumeration.
  // Iteratively trace upwards while the defining op is `IdentityOp` or
  // `IdentityNOp`.
1060 while (
1061 llvm::isa_and_nonnull<IdentityOp, IdentityNOp>(value.getDefiningOp())) {
1062 value = value.getDefiningOp()->getOperand(
1063 value.cast<OpResult>().getResultNumber());
1064 }
1065
1066 Operation* op = value.getDefiningOp();
1067 if (!llvm::isa_and_nonnull<BatchDatasetV2Op, MapDatasetOp, RepeatDatasetOp,
1068 ParallelMapDatasetOp, ParallelMapDatasetV2Op,
1069 TakeDatasetOp, TakeWhileDatasetOp>(op))
1070 return DatasetInput{nullptr, nullptr};
1071
1072 return DatasetInput{op->getAttrOfType<ArrayAttr>("output_shapes"),
1073 op->getAttrOfType<ArrayAttr>("output_types")};
1074 }
1075
bool ShapeInference::InferShapeForDatasetOpCommon(Operation* op, FuncOp f,
                                                  int64_t max_iterations) {
1078 int N = op->getNumOperands() - 1;
1079 int M = f.getNumArguments() - N;
1080 DCOMMENT_OP(op, "Inferring shape for with N = " << N << " and M = " << M);
1081
1082 // Initialize with function input types.
1083 auto input_types = llvm::to_vector<1>(
1084 cast<FunctionOpInterface>(f.getOperation()).getArgumentTypes());
1085
1086 DatasetInput input_elements = GetDatasetInput(op->getOperand(0));
1087 if (!input_elements) {
1088 op->emitWarning("unexpected dataset input; skipping function refinement");
1089 return false;
1090 }
1091
1092 // Track if changed to skip enqueueing.
1093 bool changed = false;
1094 auto it = input_types.begin();
1095 // First set first M argument shapes & types.
1096 for (int i = 0; i < M; ++i) {
1097 Type t = GetType(input_elements.shapes[i], input_elements.types[i]);
1098 t = TypeMeet(*it, t);
1099 changed = changed || (t != *it);
1100 *it++ = t;
1101 }
1102 // Now the remaining N from operand types.
1103 for (auto t : llvm::drop_begin(op->getOperandTypes())) {
1104 auto meet = TypeMeet(*it, t);
1105 changed = changed || (meet != *it);
1106 *it++ = meet;
1107 }
1108 if (!changed) return false;
1109
1110 FailureOr<bool> res = PropagateShapeToFunctions(
1111 op->getParentOfType<ModuleOp>(), input_types, {f}, max_iterations);
1112 if (failed(res)) {
1113 op->emitOpError("propagating shapes failed");
1114 return false;
1115 }
1116 return *res;
1117 }
1118
bool ShapeInference::InferShapeForMapDataset(MapDatasetOp op,
                                             int64_t max_iterations) {
  // MapDatasetOp's relationship with its associated function is as follows:
  // the first M function params are dictated by the set output shapes and
  // types, and the next N are the last N inputs of the MapDataset op. The
  // MapDataset op always has N+1 inputs.
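  // As an illustrative example, a MapDataset op with operands
  // (input_dataset, %c0, %c1) has N = 2; if its function takes 3 arguments,
  // then M = 1 and that single leading argument is typed from the dataset's
  // output_shapes/output_types, while the trailing two take the types of %c0
  // and %c1.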
1125 // TODO(jpienaar): Avoid this lookup.
1126 auto module = op->getParentOfType<ModuleOp>();
1127 auto f = module.lookupSymbol<func::FuncOp>(op.f());
1128 // Skip if function is not found or more than one caller.
1129 if (!f || !llvm::hasSingleElement(GetCallers(f))) return false;
1130 return InferShapeForDatasetOpCommon(op, f, max_iterations);
1131 }
1132
bool ShapeInference::InferShapeForTakeWhileDataset(TakeWhileDatasetOp op,
                                                   int64_t max_iterations) {
  // TakeWhileDatasetOp's relationship with its associated function is as
  // follows: the first M function params are dictated by the set output
  // shapes and types, and the next N are the last N inputs of the
  // TakeWhileDataset op. The TakeWhileDataset op always has N+1 inputs.
1139 // TODO(jpienaar): Avoid this lookup.
1140 auto module = op->getParentOfType<ModuleOp>();
1141 auto f = module.lookupSymbol<func::FuncOp>(op.predicate());
1142 // Skip if function is not found or more than one caller.
1143 if (!f || !llvm::hasSingleElement(GetCallers(f))) return false;
1144 return InferShapeForDatasetOpCommon(op, f, max_iterations);
1145 }
1146
bool ShapeInference::InferShapeForReduceDataset(ReduceDatasetOp op,
                                                int64_t max_iterations) {
1149 // ReduceDatasetOp's relationship with its associated reduce function is
1150 // described as follows: The reduce function will in general have (X + Y + Z)
1151 // arguments, where X is the number of tensor components that represent the
1152 // state, Y is the number of tensor components that represent the input
1153 // elements, and Z is the number of tensor components that represent any
1154 // captured arguments. Y is determined by the output_shapes of an op that
1155 // defines the first operand of `op`.
1156
1157 // TODO(jpienaar): Avoid this lookup.
1158 auto module = op->getParentOfType<ModuleOp>();
1159 auto f = module.lookupSymbol<func::FuncOp>(op.f());
1160
1161 // Skip if function is not found or it has more than one caller.
1162 if (!f || !llvm::hasSingleElement(GetCallers(f))) return false;
1163
1164 DatasetInput input_elements = GetDatasetInput(op.input_dataset());
1165
1166 const int num_states = op.output_shapes().size();
1167 const int num_captured_arguments = op.getNumOperands() - 1 - num_states;
1168
1169 // If input_elements is undefined, we can still infer the shapes for the
1170 // states and captured arguments.
1171 int num_input_elements;
1172 auto input_types = llvm::to_vector<1>(
1173 cast<FunctionOpInterface>(f.getOperation()).getArgumentTypes());
1174 if (input_elements) {
1175 num_input_elements = input_elements.shapes.size();
1176 } else {
1177 num_input_elements =
1178 input_types.size() - num_states - num_captured_arguments;
1179 }
1180
1181 DCOMMENT_OP(op,
1182 "Inferring shape for ReduceDataset with #states = "
1183 << num_states << " , #input_elements = " << num_input_elements
1184 << " , and #captured_arguments = " << num_captured_arguments);
1185 if (num_states + num_input_elements + num_captured_arguments !=
1186 f.getNumArguments()) {
1187 op->emitOpError(
1188 "propagating shapes for ReduceDataset failed due to inconsistent "
1189 "number of arguments");
1190 return false;
1191 }
1192
1193 // Track if changed to skip enqueueing.
1194 bool changed = false;
1195 auto it = input_types.begin();
1196
1197 // Set the first num_states arguments shapes & types from the state.
1198 for (int i = 0; i < num_states; ++i) {
1199 Type t = GetType(op.output_shapes()[i], op.output_types()[i]);
1200 t = TypeMeet(*it, t);
1201 changed = changed || (t != *it);
1202 *it++ = t;
1203 }
1204
  // Second, set the following num_input_elements argument types from the
  // input dataset. Skip propagating shapes if input_elements is undefined.
1208 for (int i = 0; i < num_input_elements; ++i) {
1209 if (input_elements) {
1210 Type t = GetType(input_elements.shapes[i], input_elements.types[i]);
1211 t = TypeMeet(*it, t);
1212 changed = changed || (t != *it);
1213 *it++ = t;
1214 } else {
1215 it++;
1216 }
1217 }
1218
1219 // Last set the remaining num_captured_arguments from op.
1220 for (auto t : llvm::drop_begin(op.getOperandTypes(), 1 + num_states)) {
1221 auto meet = TypeMeet(*it, t);
1222 changed = changed || (meet != *it);
1223 *it++ = meet;
1224 }
1225
1226 if (!changed) return false;
1227
1228 FailureOr<bool> res =
1229 PropagateShapeToFunctions(module, input_types, {f}, max_iterations);
1230 if (failed(res)) {
1231 op->emitOpError("Propagating shapes for ReduceDataset failed");
1232 return false;
1233 }
1234 return *res;
1235 }
1236
bool ShapeInference::InferShapeForTensorListInitOps(Operation* op) {
1238 DCOMMENT_OP(op, "Inferring shape for TensorList ");
1239 Value handle = op->getResult(0);
1240 Value initial_element_shape = GetElementShapeOperand(op);
1241 RankedTensorType element_type;
1242 if (auto tl_from_tensor = dyn_cast<TensorListFromTensorOp>(op)) {
1243 // For TensorListFromTensor op we can infer element shape by dropping the
1244 // first dimension of input tensor.
1245 element_type = DropFirstDimension(tl_from_tensor.tensor().getType());
1246 if (!element_type || !element_type.hasStaticShape()) return false;
1247 }
1248 if (!CanInferTensorListElementType(handle, initial_element_shape,
1249 &element_type)) {
1250 DCOMMENT("InferShapeForListInitOps " << op << " could not infer");
1251 return false;
1252 }
1253 DCOMMENT("InferShapeForListInitOps " << *op << " could be inferred "
1254 << element_type);
1255 if (!element_type || !element_type.hasStaticShape()) return false;
1256 auto variant_type = VariantType::get(element_type, op->getContext());
1257 auto tensor_type = RankedTensorType::get({}, variant_type);
1258 bool changed = RefineResultType(op, handle, tensor_type);
1259 if (changed) DCOMMENT_OP(op, "Modified after shape inference:");
1260 return changed;
1261 }
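// For example, a TensorListReserve whose handle only ever receives items of
// type tensor<8x16xf32> has its result refined to the 0-d
// tensor<!tf_type.variant<tensor<8x16xf32>>>.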
1262
bool ShapeInference::InferShapeForVarHandleOp(VarHandleOp op) {
1264 DCOMMENT_OP(op, "Inferring shape for VarHandleOp");
1265
1266 Value resource = op.resource();
1267 if (!CanBeRefined(resource.getType())) return false;
1268
  // Make sure the only users are `AssignVariableOp` and `ReadVariableOp`.
  // For other cases, we skip them to be conservative.
1271 for (auto& use : make_early_inc_range(resource.getUses())) {
1272 Operation* def = use.getOwner();
1273 if (!llvm::isa<AssignVariableOp>(def) && !llvm::isa<ReadVariableOp>(def)) {
1274 return false;
1275 }
1276 }
1277
1278 bool changed = false;
1279
  // Look for any `AssignVariableOp` or `ReadVariableOp` that uses the value
  // of this op.
1282 for (auto& use : make_early_inc_range(resource.getUses())) {
1283 Operation* def = use.getOwner();
1284 Value value;
1285 if (AssignVariableOp assign_op = dyn_cast<AssignVariableOp>(def)) {
1286 value = assign_op.value();
1287 } else if (ReadVariableOp read_op = dyn_cast<ReadVariableOp>(def)) {
1288 value = read_op.value();
1289 } else {
1290 llvm_unreachable("unexpected operator type");
1291 }
1292
1293 TensorType resource_subtype = value.getType().cast<TensorType>();
1294 ResourceType resource_type =
1295 ResourceType::get({resource_subtype}, op.getContext());
1296 UnrankedTensorType new_resource_type =
1297 UnrankedTensorType::get(resource_type);
1298
1299 Type refined_type = TypeMeet(resource.getType(), new_resource_type);
1300 if (refined_type == resource.getType()) continue;
1301 resource.setType(refined_type);
1302 changed = true;
1303 }
1304
1305 return changed;
1306 }
1307
1308 // Helper function for creating a Window proto from user-supplied data.
1309 // Returns llvm::None if the user-supplied data was invalid.
llvm::Optional<xla::Window> InferWindowFromDimensions(
    llvm::SmallVector<int64_t> window_dimensions,
    llvm::SmallVector<int64_t> window_strides,
    llvm::SmallVector<std::pair<int64_t, int64_t>> padding,
    llvm::SmallVector<int64_t> lhs_dilation,
    llvm::SmallVector<int64_t> rhs_dilation) {
1316 const auto verify_size = [&](const size_t x, const char* x_name) {
1317 if (x == 0 || x == window_dimensions.size()) {
1318 return true;
1319 } else {
1320 llvm::errs()
1321 << "Window has different number of window dimensions than of "
1322 << x_name
1323 << "\nNumber of window dimensions: " << window_dimensions.size()
1324 << "\nNumber of " << x_name << ": " << x << "\n";
1325 return false;
1326 }
1327 };
1328
1329 if (!(verify_size(window_dimensions.size(), "window_dimensions") &&
1330 verify_size(window_strides.size(), "window strides") &&
1331 verify_size(padding.size(), "padding entries") &&
1332 verify_size(lhs_dilation.size(), "lhs dilation factors") &&
1333 verify_size(rhs_dilation.size(), "rhs dilation factors")))
1334 return llvm::None;
1335
1336 xla::Window window;
1337 for (size_t i = 0; i < window_dimensions.size(); i++) {
1338 auto dim = window.add_dimensions();
1339 dim->set_size(window_dimensions[i]);
1340 if (!window_strides.empty()) {
1341 dim->set_stride(window_strides[i]);
1342 } else {
1343 dim->set_stride(1);
1344 }
1345 if (!padding.empty()) {
1346 dim->set_padding_low(padding[i].first);
1347 dim->set_padding_high(padding[i].second);
1348 } else {
1349 dim->set_padding_low(0);
1350 dim->set_padding_high(0);
1351 }
1352 if (!lhs_dilation.empty()) {
1353 dim->set_base_dilation(lhs_dilation[i]);
1354 } else {
1355 dim->set_base_dilation(1);
1356 }
1357 if (!rhs_dilation.empty()) {
1358 dim->set_window_dilation(rhs_dilation[i]);
1359 } else {
1360 dim->set_window_dilation(1);
1361 }
1362 dim->set_window_reversal(false);
1363 }
1364 return window;
1365 }
1366
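// Computes the result shape of applying `window` to `base_shape`. For a static
// input dimension the output size is, in effect,
//   StridedBound(padding_low + DilatedBound(base, base_dilation) + padding_high,
//                DilatedBound(window_size, window_dilation), stride).
// Illustrative example (assuming the usual xla::window_util semantics, i.e.
// DilatedBound(b, d) = (b - 1) * d + 1 and StridedBound(b, w, s) = (b - w) / s + 1
// for b >= w): base 5, window 3, stride 2, no padding and no dilation gives
// (5 - 3) / 2 + 1 = 2. Dynamic input dimensions stay dynamic.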
1367 llvm::Optional<RankedTensorType> InferWindowOutputShape(
1368 const ShapedType& base_shape, const xla::Window& window,
1369 Type element_type) {
1370 if (window.dimensions_size() != base_shape.getRank()) {
1371 llvm::errs() << "Window has dimension " << window.dimensions_size()
1372 << " but base shape has dimension " << base_shape.getRank()
1373 << "\n";
1374 return llvm::None;
1375 }
1376
1377 std::vector<int64_t> output_dimensions(window.dimensions_size());
1378 std::vector<bool> output_is_dynamic(window.dimensions_size());
1379 for (int64_t i = 0; i < window.dimensions_size(); ++i) {
1380 const auto& dim = window.dimensions(i);
1381 if (dim.size() <= 0) {
1382 llvm::errs() << "Window " << window.DebugString()
1383 << " has a non-positive dimension.\n";
1384 return llvm::None;
1385 }
1386 if (dim.stride() <= 0) {
1387 llvm::errs() << "Window " << window.DebugString()
1388 << " has a non-positive stride.\n";
1389 return llvm::None;
1390 }
1391 if (dim.base_dilation() < 1) {
1392 llvm::errs() << "Window " << window.DebugString()
1393 << " has a non-positive base area dilation factor.\n";
1394 return llvm::None;
1395 }
1396 if (dim.window_dilation() < 1) {
1397 llvm::errs() << "Window " << window.DebugString()
1398 << " has a non-positive window dilation factor.\n";
1399 return llvm::None;
1400 }
1401
1402 if (base_shape.isDynamicDim(i)) {
1403 output_dimensions[i] = ShapedType::kDynamicSize;
1404 } else {
1405 const int64_t dilated_base = xla::window_util::DilatedBound(
1406 base_shape.getDimSize(i), dim.base_dilation());
1407 const int64_t padded_dilated_base =
1408 dim.padding_low() + dilated_base + dim.padding_high();
1409 const int64_t dilated_window =
1410 xla::window_util::DilatedBound(dim.size(), dim.window_dilation());
1411
1412 output_dimensions[i] = xla::window_util::StridedBound(
1413 padded_dilated_base, dilated_window, dim.stride());
1414 }
1415 }
1416
1417 return RankedTensorType::get(output_dimensions, element_type);
1418 }
1419
1420 bool ShapeInference::InferShapeForXlaReduceWindowOp(XlaReduceWindowOp op) {
1421 DCOMMENT_OP(op, "Inferring shape for XlaReduceWindowOp");
1422
1423 bool changed = false;
1424
1425 auto input_ty = op.input().getType().cast<ShapedType>();
1426 DenseElementsAttr window_dimensions, window_strides, base_dilations,
1427 window_dilations, padding;
1428 if (input_ty.hasStaticShape() &&
1429 matchPattern(op.window_dimensions(), m_Constant(&window_dimensions)) &&
1430 matchPattern(op.window_strides(), m_Constant(&window_strides)) &&
1431 matchPattern(op.base_dilations(), m_Constant(&base_dilations)) &&
1432 matchPattern(op.window_dilations(), m_Constant(&window_dilations)) &&
1433 matchPattern(op.padding(), m_Constant(&padding))) {
1434 llvm::SmallVector<int64_t> window_dimensions_vec, window_strides_vec,
1435 base_dilations_vec, window_dilations_vec;
1436 llvm::SmallVector<std::pair<int64_t, int64_t>> padding_pairs(
1437 padding.getNumElements() / 2);
1438
1439 for (auto i = 0; i < window_dimensions.size(); ++i) {
1440 window_dimensions_vec.push_back(
1441 window_dimensions.getValues<IntegerAttr>()[i].getInt());
1442 }
1443
1444 for (auto i = 0; i < window_strides.size(); ++i) {
1445 window_strides_vec.push_back(
1446 window_strides.getValues<IntegerAttr>()[i].getInt());
1447 }
1448
1449 for (auto i = 0; i < base_dilations.size(); ++i) {
1450 base_dilations_vec.push_back(
1451 base_dilations.getValues<IntegerAttr>()[i].getInt());
1452 }
1453
1454 for (auto i = 0; i < window_dilations.size(); ++i) {
1455 window_dilations_vec.push_back(
1456 window_dilations.getValues<IntegerAttr>()[i].getInt());
1457 }
1458
1459 for (auto i = 0; i < padding_pairs.size(); ++i) {
1460 padding_pairs[i] = {padding.getValues<IntegerAttr>()[i * 2].getInt(),
1461 padding.getValues<IntegerAttr>()[i * 2 + 1].getInt()};
1462 }
1463
1464 auto window = InferWindowFromDimensions(
1465 window_dimensions_vec, window_strides_vec, padding_pairs,
1466 base_dilations_vec, window_dilations_vec);
1467 if (!window) {
1468 op->emitOpError("failed to create window");
1469 }
1470 auto output_shape = InferWindowOutputShape(
1471 input_ty, window.getValue(),
1472 op.init_value().getType().cast<ShapedType>().getElementType());
1473
1474 if (!output_shape) {
1475 op->emitOpError("failed to infer output shape");
1476 }
1477
1478 changed = RefineResultType(op.getOperation(), op.getResult(),
1479 output_shape.getValue());
1480 }
1481
1482 return changed;
1483 }
1484
1485 bool ShapeInference::InferShapeForXlaSelectAndScatterOp(
1486 XlaSelectAndScatterOp op) {
1487 DCOMMENT_OP(op, "Inferring shape for XlaSelectAndScatterOp");
1488
1489 auto operand_shape = op.operand().getType().cast<ShapedType>();
1490 auto source_shape = op.source().getType().cast<ShapedType>();
1491 DenseElementsAttr window_dimensions, window_strides, padding;
1492 if (operand_shape.hasRank() && source_shape.hasRank() &&
1493 matchPattern(op.window_dimensions(), m_Constant(&window_dimensions)) &&
1494 matchPattern(op.window_strides(), m_Constant(&window_strides)) &&
1495 matchPattern(op.padding(), m_Constant(&padding))) {
1496 llvm::SmallVector<int64_t> window_dimensions_vec, window_strides_vec,
1497 base_dilations_vec, window_dilations_vec;
1498 llvm::SmallVector<std::pair<int64_t, int64_t>> padding_pairs(
1499 padding.getNumElements() / 2);
1500
1501 for (auto i = 0; i < window_dimensions.size(); ++i) {
1502 window_dimensions_vec.push_back(
1503 window_dimensions.getValues<IntegerAttr>()[i].getInt());
1504 }
1505
1506 for (auto i = 0; i < window_strides.size(); ++i) {
1507 window_strides_vec.push_back(
1508 window_strides.getValues<IntegerAttr>()[i].getInt());
1509 }
1510
1511 for (auto i = 0; i < padding_pairs.size(); ++i) {
1512 padding_pairs[i] = {padding.getValues<IntegerAttr>()[i * 2].getInt(),
1513 padding.getValues<IntegerAttr>()[i * 2 + 1].getInt()};
1514 }
1515
1516 auto window = InferWindowFromDimensions(
1517 window_dimensions_vec, window_strides_vec, padding_pairs, {}, {});
1518 if (!window) {
1519 op->emitOpError("failed to create window");
1520 }
1521 auto window_result_shape = InferWindowOutputShape(
1522 operand_shape, window.getValue(), operand_shape.getElementType());
1523
1524 if (!window_result_shape) {
1525 op->emitOpError("failed to infer window result shape");
1526 }
1527
1528 if (window_result_shape.getValue() != source_shape) {
1529 op->emitOpError(
1530 "Source shape does not match the shape of window-reduced operand.");
1531 }
1532 }
1533
1534 return RefineResultType(op.getOperation(), op.getResult(),
1535 op.operand().getType());
1536 }
1537
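// Computes the result shape of an XlaConvV2: the output batch dimension is the
// input batch divided by `batch_group_count`, the output feature dimension is
// the kernel output-feature size, and each output spatial dimension is obtained
// by running the window inference above over the corresponding input spatial
// dimension with the given strides, paddings, and dilations.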
1538 llvm::Optional<RankedTensorType> InferXlaConvOutputShape(
1539 llvm::SmallVector<int64_t> input_tensor_dims,
1540 llvm::SmallVector<int64_t> kernel_tensor_dims,
1541 llvm::SmallVector<int64_t> window_strides,
1542 llvm::SmallVector<std::pair<int64_t, int64_t>> paddings,
1543 llvm::SmallVector<int64_t> lhs_dilations,
1544 llvm::SmallVector<int64_t> rhs_dilations, int64_t batch_group_count,
1545 xla::ConvolutionDimensionNumbers dnums, Type element_type) {
1546 auto num_spatial_dims = input_tensor_dims.size() - 2;
1547 std::vector<int64_t> output_dims(input_tensor_dims.size());
1548
1549 auto input_batch = input_tensor_dims[dnums.input_batch_dimension()];
1550 auto kernel_output_feature =
1551 kernel_tensor_dims[dnums.kernel_output_feature_dimension()];
1552 output_dims[dnums.output_batch_dimension()] = input_batch / batch_group_count;
1553 DCOMMENT("inferred output batch dimension is "
1554 << output_dims[dnums.output_batch_dimension()]);
1555 output_dims[dnums.output_feature_dimension()] = kernel_output_feature;
1556 DCOMMENT("inferred output_feature_dimension is "
1557 << output_dims[dnums.output_feature_dimension()]);
1558
1559 std::vector<int64_t> input_spatial_dims;
1560 llvm::SmallVector<int64_t> window_spatial_dims;
1561 for (auto i = 0; i < num_spatial_dims; ++i) {
1562 input_spatial_dims.push_back(
1563 input_tensor_dims[dnums.input_spatial_dimensions(i)]);
1564 window_spatial_dims.push_back(
1565 kernel_tensor_dims[dnums.kernel_spatial_dimensions(i)]);
1566 }
1567
1568 ShapedType base_shape =
1569 RankedTensorType::get(input_spatial_dims, element_type);
1570
1571 auto window =
1572 InferWindowFromDimensions(window_spatial_dims, window_strides, paddings,
1573 lhs_dilations, rhs_dilations);
1574
1575 auto output_shape =
1576 InferWindowOutputShape(base_shape, window.getValue(), element_type);
1577
1578 for (auto i = 0; i < num_spatial_dims; ++i) {
1579 output_dims[dnums.output_spatial_dimensions(i)] =
1580 output_shape.getValue().getShape()[i];
1581 DCOMMENT("inferred output spatial dimension "
1582 << i << " at dimension number "
1583 << dnums.output_spatial_dimensions(i) << " is "
1584 << output_dims[dnums.output_spatial_dimensions(i)]);
1585 }
1586 return RankedTensorType::get(output_dims, element_type);
1587 }
1588
1589 // TODO(hanxiongwang): The logic in this function needs to move to the Op
1590 // Verify method once the dependency issue of adding the header file
1591 // "third_party/tensorflow/compiler/xla/xla_data.pb.h" into
1592 // "third_party/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc" is
1593 // resolved.
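// Validates XlaConvV2 inputs before shape inference: all shape-related operands
// must have static shapes, the input and kernel must have matching ranks greater
// than 2, the parsed dimension numbers must be consistent, and the feature/batch
// group counts must be positive and divide the corresponding dimensions.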
1594 LogicalResult PrecheckForXlaConvV2Op(XlaConvV2Op op) {
1595 auto input_tensor = op.lhs();
1596 auto kernel_tensor = op.rhs();
1597 auto window_strides = op.window_strides();
1598 auto padding = op.padding();
1599 auto lhs_dilation = op.lhs_dilation();
1600 auto rhs_dilation = op.rhs_dilation();
1601 auto feature_group_count = op.feature_group_count();
1602 int64_t batch_group_count = op.batch_group_count();
1603
1604 auto input_args_have_static_shape = [&]() -> bool {
1605 return input_tensor.getType().cast<TensorType>().hasStaticShape() &&
1606 kernel_tensor.getType().cast<TensorType>().hasStaticShape() &&
1607 window_strides.getType().cast<TensorType>().hasStaticShape() &&
1608 padding.getType().cast<TensorType>().hasStaticShape() &&
1609 lhs_dilation.getType().cast<TensorType>().hasStaticShape() &&
1610 rhs_dilation.getType().cast<TensorType>().hasStaticShape() &&
1611 feature_group_count.getType().cast<TensorType>().hasStaticShape();
1612 };
1613
1614 // Return failure when one of the input args does not have a static shape.
1615 if (!input_args_have_static_shape()) {
1616 return failure();
1617 }
1618
1619 auto input_tensor_shape =
1620 input_tensor.getType().cast<RankedTensorType>().getShape();
1621 auto kernel_tensor_shape =
1622 kernel_tensor.getType().cast<RankedTensorType>().getShape();
1623
1624 if (input_tensor_shape.size() <= 2) {
1625 return op.emitOpError()
1626 << "input tensor argument has rank " << input_tensor_shape.size()
1627 << " which is invalid, since the input tensor argument must have a "
1628 << "rank greater than 2.\n";
1629 }
1630
1631 if (kernel_tensor_shape.size() <= 2) {
1632 return op.emitOpError()
1633 << "kernel tensor argument has rank " << kernel_tensor_shape.size()
1634 << " which is invalid, since the kernel tensor argument must have a "
1635 << "rank greater than 2.\n";
1636 }
1637
1638 if (input_tensor_shape.size() != kernel_tensor_shape.size()) {
1639 return op.emitOpError() << "both input tensor and kernel tensor must "
1640 << "have the same number of dimensions.\n";
1641 }
1642
1643 DenseElementsAttr feature_group_count_attr;
1644 xla::ConvolutionDimensionNumbers dnums;
1645 dnums.ParseFromString(op.dimension_numbersAttr().getValue().str());
1646 if (dnums.input_spatial_dimensions_size() !=
1647 dnums.kernel_spatial_dimensions_size()) {
1648 return op.emitOpError() << "Both arguments to convolution must have "
1649 << "the same number of spatial dimensions.\n";
1650 }
1651
1652 if (dnums.input_spatial_dimensions_size() !=
1653 dnums.output_spatial_dimensions_size()) {
1654 return op.emitOpError() << "Both input and output of convolution must have "
1655 << "the same number of spatial dimensions.\n";
1656 }
1657 if (!matchPattern(feature_group_count,
1658 m_Constant(&feature_group_count_attr))) {
1659 return success();
1660 }
1661
1662 auto feature_group_count_val =
1663 feature_group_count_attr.getValues<IntegerAttr>()[0].getInt();
1664 auto input_features = input_tensor_shape[dnums.input_feature_dimension()];
1665 auto input_batch = input_tensor_shape[dnums.input_batch_dimension()];
1666 auto kernel_input_features =
1667 kernel_tensor_shape[dnums.kernel_input_feature_dimension()];
1668 auto kernel_output_features =
1669 kernel_tensor_shape[dnums.kernel_output_feature_dimension()];
1670
1671 if (feature_group_count_val <= 0) {
1672 return op.emitOpError()
1673 << "feature_group_count must be a positive number, got "
1674 << feature_group_count_val;
1675 }
1676
1677 if (batch_group_count <= 0) {
1678 return op.emitOpError()
1679 << "batch_group_count must be a positive number, got "
1680 << batch_group_count;
1681 }
1682 if (batch_group_count > 1 && feature_group_count_val > 1) {
1683 return op.emitOpError()
1684 << "both batch_group_count " << batch_group_count
1685 << " and feature_group_count " << feature_group_count_val
1686 << " cannot be greater than 1";
1687 }
1688 if (kernel_output_features % batch_group_count != 0) {
1689 return op.emitOpError()
1690 << "Expected output feature dimension size (value "
1691 << kernel_output_features
1692 << ") to be a multiple of batch group count " << batch_group_count;
1693 }
1694 if (input_features % feature_group_count_val != 0 ||
1695 input_features / feature_group_count_val != kernel_input_features) {
1696 return op.emitOpError()
1697 << "Expected the size of kernel_input_features (value "
1698 << kernel_input_features
1699 << ") in rhs times feature_group_count (value "
1700 << feature_group_count_val
1701 << ") in lhs to equal the size of the z dimension (value "
1702 << input_features << ") in lhs.\n";
1703 }
1704 if (kernel_output_features % feature_group_count_val > 0) {
1705 return op.emitOpError() << "Expected output feature dimension (value "
1706 << kernel_output_features << ") to be divisible by "
1707 << "feature_group_count (value "
1708 << feature_group_count_val << ").\n";
1709 }
1710 if (input_batch % batch_group_count != 0) {
1711 return op.emitOpError()
1712 << "Expected input batch dimension (value " << input_batch
1713 << ") to be divisible by batch_group_count (value "
1714 << batch_group_count << "); ";
1715 }
1716 return success();
1717 }
1718
1719 bool ShapeInference::InferShapeForXlaConvV2Op(XlaConvV2Op op) {
1720 DCOMMENT_OP(op, "Inferring shape for XlaConvV2Op");
1721
1722 bool changed = false;
1723
1724 if (PrecheckForXlaConvV2Op(op).failed()) {
1725 return changed;
1726 }
1727
1728 auto input_tensor = op.lhs();
1729 auto kernel_tensor = op.rhs();
1730 auto window_strides = op.window_strides();
1731 auto padding = op.padding();
1732 auto lhs_dilation = op.lhs_dilation();
1733 auto rhs_dilation = op.rhs_dilation();
1734 int64_t batch_group_count = op.batch_group_count();
1735
1736 DenseIntElementsAttr window_strides_attr, padding_attr, lhs_dilation_attr,
1737 rhs_dilation_attr;
1738 if (matchPattern(window_strides, m_Constant(&window_strides_attr)) &&
1739 matchPattern(padding, m_Constant(&padding_attr)) &&
1740 matchPattern(lhs_dilation, m_Constant(&lhs_dilation_attr)) &&
1741 matchPattern(rhs_dilation, m_Constant(&rhs_dilation_attr))) {
1742 llvm::SmallVector<int64_t> input_tensor_dims_vec, kernel_tensor_dims_vec,
1743 window_strides_vec, lhs_dilations_vec, rhs_dilations_vec;
1744 llvm::SmallVector<std::pair<int64_t, int64_t>> padding_pairs(
1745 padding_attr.getNumElements() / 2);
1746 xla::ConvolutionDimensionNumbers dnums;
1747 dnums.ParseFromString(op.dimension_numbersAttr().getValue().str());
1748
1749 auto input_tensor_shape = input_tensor.getType().cast<RankedTensorType>();
1750 for (auto i = 0; i < input_tensor_shape.getShape().size(); ++i) {
1751 DCOMMENT("Input Tensor Shape " << i << "th is "
1752 << input_tensor_shape.getShape()[i]);
1753 input_tensor_dims_vec.push_back(input_tensor_shape.getShape()[i]);
1754 }
1755
1756 auto kernel_tensor_shape = kernel_tensor.getType().cast<RankedTensorType>();
1757 for (auto i = 0; i < kernel_tensor_shape.getShape().size(); ++i) {
1758 DCOMMENT("Kernel tensor Shape" << i << "th is "
1759 << kernel_tensor_shape.getShape()[i]);
1760 kernel_tensor_dims_vec.push_back(kernel_tensor_shape.getShape()[i]);
1761 }
1762
1763 for (const llvm::APInt& i : window_strides_attr) {
1764 window_strides_vec.push_back(i.getSExtValue());
1765 }
1766
1767 for (auto i = 0; i < padding_pairs.size(); ++i) {
1768 padding_pairs[i] = {
1769 padding_attr.getValues<IntegerAttr>()[i * 2].getInt(),
1770 padding_attr.getValues<IntegerAttr>()[i * 2 + 1].getInt()};
1771 }
1772
1773 for (const llvm::APInt& i : lhs_dilation_attr) {
1774 lhs_dilations_vec.push_back(i.getSExtValue());
1775 }
1776
1777 for (const llvm::APInt& i : rhs_dilation_attr) {
1778 rhs_dilations_vec.push_back(i.getSExtValue());
1779 }
1780
1781 Type input_tensor_element_type = input_tensor_shape.getElementType();
1782 Type result_element_type = op.getType().getElementType();
1783 Type element_type = input_tensor_element_type.getIntOrFloatBitWidth() >=
1784 result_element_type.getIntOrFloatBitWidth()
1785 ? input_tensor_element_type
1786 : result_element_type;
1787 auto output_shape = InferXlaConvOutputShape(
1788 input_tensor_dims_vec, kernel_tensor_dims_vec, window_strides_vec,
1789 padding_pairs, lhs_dilations_vec, rhs_dilations_vec, batch_group_count,
1790 dnums, element_type);
1791
1792 if (output_shape.getValue()) {
1793 changed = RefineResultType(op.getOperation(), op.getResult(),
1794 output_shape.getValue());
1795 return changed;
1796 }
1797 }
1798 return changed;
1799 }
1800
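// Refines result types via MLIR's InferTypeOpInterface: the interface is asked
// to infer return types from the current operands and attributes, and each
// result whose type differs from the inferred one is updated (inserting casts
// for uses that would otherwise see an incompatible type).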
1801 bool ShapeInference::RefineWithInferTypeOpInterface(
1802 InferTypeOpInterface infer_ti) {
1803 Operation* op = infer_ti.getOperation();
1804 SmallVector<Type, 4> inferred;
1805 LogicalResult res = infer_ti.inferReturnTypes(
1806 op->getContext(), op->getLoc(), op->getOperands(),
1807 op->getAttrDictionary(), op->getRegions(), inferred);
1808 if (failed(res)) {
1809 op->emitOpError("failed to refine type as inference failed");
1810 return false;
1811 }
1812
1813 if (inferred == op->getResultTypes()) return false;
1814
1815 // Map each of the results of the call to the returned type of the
1816 // function.
1817 bool changed = false;
1818 for (auto result : zip(op->getResults(), inferred)) {
1819 if (std::get<0>(result).getType() == std::get<1>(result)) continue;
1820
1821 if (!UpdateTypeAndInsertIncompatibleUseCasts(std::get<1>(result),
1822 std::get<0>(result)))
1823 continue;
1824 changed = true;
1825 }
1826 return changed;
1827 }
1828
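// Attempts to materialize `result` as a ShapeHandle by partially evaluating
// each element of the (expected statically shaped, rank-1) value; elements that
// cannot be evaluated stay unknown dimensions, and a non-rank-1 or dynamically
// shaped result yields an empty handle.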
1829 ShapeHandle ShapeInference::ComputeOutputAsShape(OpResult result,
1830 InferenceContext* ic) {
1831 LLVM_DEBUG(result.print(llvm::dbgs() << "\nEvaluate partially "));
1832 auto rt = result.getType().dyn_cast<RankedTensorType>();
1833 if (!rt || !rt.hasStaticShape() || rt.getRank() != 1) return {};
1834 int dim_size = rt.getDimSize(0);
1835
1836 // Worklist to direct partial evaluation.
1837 SmallVector<ValuePort, 4> worklist;
1838
1839 // Simple evaluator that attempts to partially evaluate the input value even
1840 // if unable to evaluate the complete output. Below follows a simple
1841 // stack-based evaluation where it queries which operands (or parts of
1842 // operands) need to be evaluated and attempts to partially evaluate those
1843 // operands. It does so by pushing the required operands onto the worklist
1844 // before enqueuing the operation that requires those values.
1845 std::vector<DimensionHandle> dims(dim_size, ic->UnknownDim());
1846 for (unsigned int i = 0, e = dims.size(); i != e; ++i) {
1847 LLVM_DEBUG(llvm::dbgs() << "\nConsidering output dim " << i << "\n");
1848
1849 worklist.push_back(
1850 ValuePort{result.getOwner(), {result.getResultNumber(), i}});
1851 while (!worklist.empty()) {
1852 auto front = worklist.pop_back_val();
1853 LLVM_DEBUG(front.print(llvm::dbgs() << "\nWorklist front "));
1854
1855 SmallVector<ValuePort, 4> inputs;
1856 auto res = ComputeInputsRequiredForOutput(front, &inputs);
1857 if (failed(res)) {
1858 // Abort if unable to find which required inputs need to be computed.
1859 worklist.clear();
1860 break;
1861 }
1862
1863 if (!inputs.empty()) {
1864 // Enqueue required computation followed by its required operands in
1865 // stack.
1866 worklist.push_back(std::move(front));
1867 for (auto& it : inputs) worklist.push_back(std::move(it));
1868 continue;
1869 }
1870
1871 auto ret = ComputeOutputComponent(front);
1872 if (!ret) continue;
1873
1874 LLVM_DEBUG(ret.print(llvm::dbgs() << "\ncomputed result = "));
1875
1876 // If worklist is empty, then this is the root query op.
1877 if (worklist.empty()) {
1878 LLVM_DEBUG(llvm::dbgs() << "[root node]\n");
1879 if (auto dea = ret.dyn_cast<DenseIntElementsAttr>()) {
1880 if (dea.getNumElements() != 1) {
1881 LLVM_DEBUG(llvm::dbgs() << "Unexpected number of elements\n");
1882 return {};
1883 }
1884 int64_t val = (*dea.getValues<APInt>().begin()).getSExtValue();
1885 dims[i] = ic->MakeDim(val);
1886 }
1887 }
1888 }
1889 }
1890 return ic->MakeShape(dims);
1891 }
1892
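// Refines the types of `results` from the types of the corresponding
// pass-through `operands`, taking the meet of each operand/result type pair and
// updating the result when the meet is strictly more precise.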
1893 bool ShapeInference::RefineTypeForPassThroughOperands(Operation* op,
1894 OperandRange operands,
1895 ResultRange results) {
1896 bool changed = false;
1897 for (auto entry : llvm::zip(operands, results)) {
1898 Type operand_type = std::get<0>(entry).getType();
1899 Value result = std::get<1>(entry);
1900 TensorType result_type = result.getType().cast<TensorType>();
1901 Type inferred_type = TypeMeet(result_type, operand_type);
1902 if (result_type == inferred_type) continue;
1903
1904 if (!UpdateTypeAndInsertIncompatibleUseCasts(inferred_type, result))
1905 continue;
1906 changed = true;
1907 }
1908 return changed;
1909 }
1910
1911 bool ShapeInference::RefineShapeForPassThroughOps(Operation* op) {
1912 DCOMMENT_OP(op, "Pass through op");
1913 bool changed = false;
1914 for (auto entry : llvm::zip(op->getOperands(), op->getResults())) {
1915 Value operand = std::get<0>(entry);
1916 Value result = std::get<1>(entry);
1917 Type inferred_type = TypeMeet(result.getType(), operand.getType());
1918 if (result.getType() == inferred_type) continue;
1919 if (!UpdateTypeAndInsertIncompatibleUseCasts(inferred_type, result))
1920 continue;
1921 changed = true;
1922 }
1923 return changed;
1924 }
1925
1926 bool ShapeInference::InferShapeForNonTFDialectOperation(Operation* op) {
1927 if (auto graph_op = dyn_cast<tf_executor::GraphOp>(op)) {
1928 return RefineTypeForPassThroughOperands(
1929 graph_op.GetFetch(), graph_op.GetFetch().fetches(), op->getResults());
1930 }
1931 if (auto island_op = dyn_cast<tf_executor::IslandOp>(op)) {
1932 return RefineTypeForPassThroughOperands(
1933 island_op.GetYield(), island_op.GetYield().fetches(), op->getResults());
1934 }
1935 if (auto iter_sink = dyn_cast<tf_executor::NextIterationSinkOp>(op)) {
1936 auto iter_source = cast<tf_executor::NextIterationSourceOp>(
1937 iter_sink.token().getDefiningOp());
1938 return RefineTypeForPassThroughOperands(
1939 op, iter_sink.getOperands().drop_front().take_front(),
1940 iter_source.getResults());
1941 }
1942 if (auto launch_op = dyn_cast<tf_device::LaunchOp>(op)) {
1943 auto terminator = launch_op.GetBody().getTerminator();
1944 return RefineTypeForPassThroughOperands(op, terminator->getOperands(),
1945 op->getResults());
1946 }
1947 if (auto cluster_op = dyn_cast<tf_device::ClusterOp>(op)) {
1948 auto terminator = cluster_op.GetBody().getTerminator();
1949 return RefineTypeForPassThroughOperands(op, terminator->getOperands(),
1950 op->getResults());
1951 }
1952 if (op->hasTrait<OpTrait::SameOperandsAndResultShape>())
1953 return RefineShapeForPassThroughOps(op);
1954 if (auto call = dyn_cast<CallOpInterface>(op)) return InferShapeForCall(call);
1955 if (isa<tensor::CastOp>(op)) return InferShapeForCast(op);
1956 return false;
1957 }
1958
1959 // Finds element type to be used for result from operand, with special handling
1960 // for handle types.
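// For example, if the operand element type is a resource with subtype
// tensor<4xf32> and the result element type is a resource with no subtypes, the
// operand's handle type (carrying the subtype) is chosen; in every other case
// the result's element type is kept.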
1961 Type GetElementTypeFromOperand(TensorType operand_type,
1962 TensorType result_type) {
1963 auto operand_handle_type =
1964 operand_type.getElementType().dyn_cast<TensorFlowTypeWithSubtype>();
1965 if (!operand_handle_type) return result_type.getElementType();
1966 auto result_handle_type =
1967 result_type.getElementType().cast<TensorFlowTypeWithSubtype>();
1968 if (operand_handle_type.GetSubtypes().empty() ||
1969 !result_handle_type.GetSubtypes().empty())
1970 return result_type.getElementType();
1971 return operand_handle_type;
1972 }
1973
1974 // Checks if one tensor type can refine another type for tf.While/
1975 // tf.WhileRegion. If rank differs or static dimensions can be lost, the other
1976 // type cannot be used for refinement.
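// For example, tensor<?x4xf32> can be refined with tensor<2x4xf32> (a dynamic
// dimension becomes static), but tensor<2x4xf32> cannot be refined with
// tensor<3x4xf32> or with an unranked type, since a static dimension would be
// lost.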
1977 bool CanWhileTypeBeRefinedWith(TensorType current_type,
1978 TensorType potential_refined_type) {
1979 if (!current_type.hasRank()) return true;
1980 if (!potential_refined_type.hasRank()) return false;
1981 if (current_type.getRank() != potential_refined_type.getRank()) return false;
1982 for (auto dim :
1983 llvm::zip(current_type.getShape(), potential_refined_type.getShape())) {
1984 int64_t current_dim = std::get<0>(dim);
1985 int64_t potential_refined_dim = std::get<1>(dim);
1986 if (current_dim != potential_refined_dim &&
1987 current_dim != ShapedType::kDynamicSize)
1988 return false;
1989 }
1990 return true;
1991 }
1992
1993 template <typename WhileOpTy>
1994 bool ShapeInference::InferShapeForWhile(WhileOpTy op,
1995 TypeRange body_result_types) {
1996 if (!op.shape_invariant())
1997 return RefineTypeForPassThroughOperands(op, op.input(), op.output());
1998
1999 bool changed = false;
2000 for (auto entry :
2001 zip(op.input().getTypes(), op.output(), body_result_types)) {
2002 Value result = std::get<1>(entry);
2003 TensorType body_result_type =
2004 std::get<2>(entry).template cast<TensorType>();
2005 auto result_type = result.getType().cast<TensorType>();
2006
2007 Type potential_refined_type;
2008 if (CanWhileTypeBeRefinedWith(result_type, body_result_type)) {
2009 Type element_type =
2010 GetElementTypeFromOperand(body_result_type, result_type);
2011 potential_refined_type = CreateTensorType(
2012 body_result_type.hasRank() ? body_result_type.getShape()
2013 : llvm::Optional<ArrayRef<int64_t>>(),
2014 element_type);
2015 } else {
2016 TensorType operand_type = std::get<0>(entry).template cast<TensorType>();
2017 Type element_type = GetElementTypeFromOperand(operand_type, result_type);
2018 potential_refined_type = CreateTensorType(
2019 result_type.hasRank() ? result_type.getShape()
2020 : llvm::Optional<ArrayRef<int64_t>>(),
2021 element_type);
2022 }
2023 changed |= RefineResultType(op, result, potential_refined_type);
2024 }
2025 return changed;
2026 }
2027
2028 bool ShapeInference::InferShapeForSingleOperation(Operation* op,
2029 int64_t max_iterations) {
2030 LLVM_DEBUG(op->print(llvm::dbgs() << "InferShapeForSingleOperation for ");
2031 llvm::dbgs() << "\n");
2032 assert(tf_dialect_ == op->getDialect());
2033 // The shape function of these ops sometimes does not propagate subtypes
2034 // (handle shapes) for resource and variant types. We use a simple passthrough
2035 // to make sure they are preserved in the output.
2036 if (isa<TF::IdentityOp, TF::IdentityNOp, TF::StopGradientOp, TF::ZerosLikeOp>(
2037 op)) {
2038 return RefineTypeForPassThroughOperands(op, op->getOperands(),
2039 op->getResults());
2040 }
2041
2042 // The shape inference function for `ReduceDatasetOp` should always be
2043 // executed regardless of whether the result type can be refined.
2044 if (auto reduce_dataset_op = dyn_cast<ReduceDatasetOp>(op)) {
2045 // TODO(jpienaar): The output type of these ops needs to be refined.
2046 return InferShapeForReduceDataset(reduce_dataset_op, max_iterations);
2047 }
2048
2049 // If no result for this op needs shape inference, we have a fast-path return.
2050 // But if the type is a resource/variant, we do not skip it because we might
2051 // not have the handle shapes.
2052 if (none_of(op->getResultTypes(), CanBeRefined)) {
2053 LLVM_DEBUG(llvm::dbgs() << "Skipping inference for statically shaped op '"
2054 << op->getName() << "'.\n");
2055 return false;
2056 }
2057
2058 if (isa<TF::RestoreOp, TF::RestoreV2Op>(op)) return InferShapeForRestore(op);
2059
2060 // Handle call operations by looking up callee and inferring return shape as
2061 // needed.
2062 if (auto call = dyn_cast<CallOpInterface>(op)) return InferShapeForCall(call);
2063
2064 // tf.Cast is only inferred if it has at least one user in the TF dialect or
2065 // feeding into the function return. This is necessary to avoid inserting
2066 // casts which cannot be refined.
2067 if (isa<CastOp>(op)) return InferShapeForCast(op);
2068
2069 // Handle IfOp here by inferring the shape from the else/then function
2070 // results. Since `output_shapes` is a derived attribute, avoid going down the
2071 // TF InferenceContext path as IfOp shape inference is implemented as just
2072 // a lookup of the output_shapes attribute.
2073 if (auto if_op = dyn_cast<IfOp>(op)) return InferShapeForIf(if_op);
2074
2075 // Handle IfRegion operations by inferring return shape from the then and else
2076 // branches.
2077 if (auto if_region = dyn_cast<IfRegionOp>(op))
2078 return InferShapeForIfRegion(if_region);
2079
2080 if (auto while_op = dyn_cast<WhileOp>(op))
2081 return InferShapeForWhile(
2082 while_op, while_op.body_function().getFunctionType().getResults());
2083
2084 if (auto while_region = dyn_cast<WhileRegionOp>(op))
2085 return InferShapeForWhile(
2086 while_region,
2087 while_region.body().front().getTerminator()->getOperandTypes());
2088
2089 if (auto host_compute_op = dyn_cast<_XlaHostComputeMlirOp>(op)) {
2090 return InferShapeForXlaHostComputeMlir(host_compute_op);
2091 }
2092
2093 // TODO(jpienaar): Extract function input arg constraint interface.
2094 // TODO(jpienaar): Unify the shape propagation to functions using interface.
2095 if (auto map_dataset_op = dyn_cast<MapDatasetOp>(op)) {
2096 // TODO(jpienaar): The output type of these ops needs to be refined.
2097 return InferShapeForMapDataset(map_dataset_op, max_iterations);
2098 }
2099
2100 if (auto takewhile_dataset_op = dyn_cast<TakeWhileDatasetOp>(op)) {
2101 // TODO(jpienaar): The output type of these ops needs to be refined.
2102 return InferShapeForTakeWhileDataset(takewhile_dataset_op, max_iterations);
2103 }
2104
2105 // Handle TensorList init operations by inferring shape from TensorList write
2106 // operations. If we are unable to refine element shape here, proceed to use
2107 // the InferenceContext below to get more precise shapes.
2108 if (IsTensorListInitOp(op) && InferShapeForTensorListInitOps(op)) return true;
2109
2110 if (auto var_handle_op = dyn_cast<VarHandleOp>(op)) {
2111 return InferShapeForVarHandleOp(var_handle_op);
2112 }
2113
2114 if (auto xla_reduce_window_op = dyn_cast<XlaReduceWindowOp>(op)) {
2115 return InferShapeForXlaReduceWindowOp(xla_reduce_window_op);
2116 }
2117
2118 if (auto xla_select_and_scatter_op = dyn_cast<XlaSelectAndScatterOp>(op)) {
2119 return InferShapeForXlaSelectAndScatterOp(xla_select_and_scatter_op);
2120 }
2121
2122 if (auto xla_conv_v2_op = dyn_cast<XlaConvV2Op>(op)) {
2123 return InferShapeForXlaConvV2Op(xla_conv_v2_op);
2124 }
2125
2126 // Return operand as a constant attribute.
2127 auto operand_as_constant_fn = [&](Value operand) {
2128 ValuePort vp(operand);
2129 Attribute attr = ComputeOutputComponent(vp);
2130 if (!attr && matchPattern(operand, m_Constant(&attr)))
2131 RecordValue(vp, attr);
2132 return attr;
2133 };
2134
2135 // Return op result as a shape.
2136 auto op_result_as_shape_fn = [&](InferenceContext& context,
2137 OpResult op_result) {
2138 return ComputeOutputAsShape(op_result, &context);
2139 };
2140
2141 // Return result element type at `index`.
2142 auto result_element_type_fn = [&](int index) {
2143 return op->getResult(index).getType().cast<TensorType>().getElementType();
2144 };
2145
2146 llvm::SmallVector<ShapedTypeComponents, 4> inferred_return_shapes;
2147 if (failed(InferReturnTypeComponentsForTFOp(
2148 /*location=*/None, op, graph_version_, operand_as_constant_fn,
2149 op_result_as_shape_fn, result_element_type_fn,
2150 inferred_return_shapes)))
2151 return false;
2152
2153 // Update the shape for each of the operation result if the InferenceContext
2154 // has more precise shapes recorded.
2155 bool changed = false;
2156 for (auto result : llvm::zip(op->getResults(), inferred_return_shapes)) {
2157 Value op_result = std::get<0>(result);
2158 if (!CanBeRefined(op_result.getType())) continue;
2159
2160 ShapedTypeComponents inferred = std::get<1>(result);
2161 TensorType inferred_type;
2162 if (inferred.hasRank())
2163 inferred_type =
2164 RankedTensorType::get(inferred.getDims(), inferred.getElementType());
2165 else
2166 inferred_type = UnrankedTensorType::get(inferred.getElementType());
2167
2168 inferred_type =
2169 TypeMeet(op_result.getType(), inferred_type).cast<TensorType>();
2170 if (op_result.getType() == inferred_type) continue;
2171 if (!UpdateTypeAndInsertIncompatibleUseCasts(inferred_type, op_result))
2172 continue;
2173 changed = true;
2174 }
2175
2176 if (changed) DCOMMENT_OP(op, "Modified after shape inference:");
2177 return changed;
2178 }
2179
2180 FailureOr<bool> ShapeInference::PropagateShapeToFunctions(
2181 ModuleOp module, TypeRange input_types, ArrayRef<func::FuncOp> functions,
2182 int64_t max_iterations) {
2183 bool any_failure = false;
2184 bool any_nonconvergence = false;
2185 // If shape propagation fails for one function, return failure, but do not
2186 // early exit and attempt to propagate shapes for all provided functions to
2187 // have a best-effort propagation.
2188 for (func::FuncOp func : functions) {
2189 DCOMMENT("Propagating shape to " << func.getName());
2190 ArrayRef<Operation*> callers = GetCallers(func);
2191 if (!llvm::hasSingleElement(callers) &&
2192 !llvm::all_of(callers.drop_front(), [&](Operation* caller) {
2193 /// TODO(aminim): this is overly conservative as some operations
2194 /// (like TPUPartitionedCallOp) may have extra operands that aren't
2195 /// propagated to the callee.
2196 return isa<CallOpInterface>(caller) &&
2197 std::equal(caller->getOperandTypes().begin(),
2198 caller->getOperandTypes().end(),
2199 callers.front()->getOperandTypes().begin());
2200 })) {
2201 if (llvm::any_of(callers, [](Operation* op) {
2202 return isa<IfOp, WhileOp, CaseOp>(op);
2203 }))
2204 func.emitWarning(formatv(
2205 "expected control flow function @{0} to have exactly 1 use, "
2206 "found {1}.",
2207 func.getName(), callers.size()));
2208
2209 continue;
2210 }
2211 FunctionType func_type = func.getFunctionType();
2212 func.setType(FunctionType::get(func.getContext(), input_types,
2213 func_type.getResults()));
2214
2215 FailureOr<bool> failure_or_converged =
2216 PropagateShapeToRegions(input_types, {&func.getBody()}, max_iterations);
2217 if (failed(failure_or_converged)) {
2218 any_failure = true;
2219 continue;
2220 }
2221 any_nonconvergence = any_nonconvergence || !failure_or_converged.getValue();
2222 if (failed(InferShapeForFunctionReturnType(func))) any_failure = true;
2223 }
2224 if (any_failure) return failure();
2225 return any_nonconvergence;
2226 }
2227
2228 FailureOr<bool> ShapeInference::PropagateShapeToRegions(
2229 TypeRange input_types, ArrayRef<Region*> regions, int64_t max_iterations) {
2230 DCOMMENT("\tPropagating shapes to regions");
2231 bool any_failure = false;
2232 bool any_nonconvergence = false;
2233 // If shape propagation fails for one region, return failure, but do not
2234 // early exit and attempt to propagate shapes for all provided regions to
2235 // have a best-effort propagation.
2236 for (auto region : regions) {
2237 // Refine region arguments.
2238 Block& entry = region->front();
2239 assert(llvm::size(input_types) == entry.getNumArguments());
2240 for (auto it : llvm::zip(entry.getArguments(), input_types)) {
2241 BlockArgument arg = std::get<0>(it);
2242 Type type = std::get<1>(it);
2243 arg.setType(type);
2244 }
2245
2246 // Propagate shapes into the region.
2247 FailureOr<bool> failure_or_converged =
2248 InferShapeUntilFixPoint(region, max_iterations);
2249 if (failed(failure_or_converged))
2250 any_failure = true;
2251 else if (!failure_or_converged.getValue())
2252 any_nonconvergence = true;
2253 }
2254 if (any_failure) return failure();
2255 return any_nonconvergence;
2256 }
2257
2258 void ShapeInference::PropagateConstantToCallee(CallOpInterface call_op,
2259 func::FuncOp func,
2260 ModuleOp module) {
2261 auto callers = GetCallers(func);
2262 if (!llvm::hasSingleElement(callers)) return;
2263
2264 OpBuilder builder(&func.front().front());
2265 Operation* op = call_op.getOperation();
2266 // If this is the only caller, and an operand is a constant, propagate
2267 // the constant value inside the function.
2268 for (auto arg : func.getArguments()) {
2269 auto operand = op->getOperand(arg.getArgNumber());
2270 if (propagate_caller_callee_constants_) {
2271 if (isa_and_nonnull<TF::ConstOp>(operand.getDefiningOp())) {
2272 arg.replaceAllUsesWith(
2273 builder.clone(*operand.getDefiningOp())->getResult(0));
2274 }
2275 continue;
2276 }
2277
2278 auto known_constant = ComputeOutputComponent(ValuePort(operand));
2279 if (!known_constant) continue;
2280 LLVM_DEBUG(call_op.print(llvm::dbgs() << "Propagate to callee: ");
2281 known_constant.print(llvm::dbgs() << " constant ");
2282 llvm::dbgs() << "\n");
2283 RecordValue(ValuePort(arg), known_constant);
2284 }
2285 }
2286
2287 void ShapeInference::PropagateConstantFromCallee(CallOpInterface call_op,
2288 func::FuncOp func,
2289 ModuleOp module) {
2290 // If the return value is a constant, use the constant as the value of
2291 // the call return.
2292 Operation* op = call_op.getOperation();
2293 OpBuilder builder(op);
2294 builder.setInsertionPointAfter(op);
2295 for (auto retval :
2296 llvm::enumerate(func.front().getTerminator()->getOperands())) {
2297 if (propagate_caller_callee_constants_) {
2298 auto retval_op = retval.value().getDefiningOp();
2299 if (isa_and_nonnull<TF::ConstOp>(retval_op)) {
2300 op->getResult(retval.index())
2301 .replaceAllUsesWith(builder.clone(*retval_op)->getResult(0));
2302 }
2303 continue;
2304 }
2305
2306 ValuePort vp(retval.value());
2307 if (auto known_constant = ComputeOutputComponent(vp)) {
2308 LLVM_DEBUG(known_constant.print(llvm::dbgs() << "Propagate constant ");
2309 call_op.print(llvm::dbgs() << "from "); llvm::dbgs() << "\n");
2310 RecordValue(ValuePort(op->getResult(retval.index())), known_constant);
2311 }
2312 }
2313 }
2314
2315 bool RankedAndSameRank(TensorType lhs, TensorType rhs) {
2316 return lhs.hasRank() && rhs.hasRank() && lhs.getRank() == rhs.getRank();
2317 }
2318
2319 // Creates a compatible RankedTensorType where mismatched dimensions are
2320 // replaced with dynamic sizes.
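// For example, tensor<2x3xf32> and tensor<2x4xf32> yield tensor<2x?xf32>: the
// matching dimension is kept and the mismatched one becomes dynamic.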
2321 RankedTensorType GetCompatibleRankedTensorType(RankedTensorType lhs,
2322 RankedTensorType rhs) {
2323 assert(lhs.getRank() == rhs.getRank());
2324 llvm::SmallVector<int64_t, 4> dims;
2325 dims.reserve(lhs.getRank());
2326 for (auto dim : llvm::zip(lhs.getShape(), rhs.getShape())) {
2327 int64_t lhs_dim = std::get<0>(dim);
2328 if (lhs_dim == std::get<1>(dim)) {
2329 dims.push_back(lhs_dim);
2330 } else {
2331 dims.push_back(ShapedType::kDynamicSize);
2332 }
2333 }
2334 return RankedTensorType::get(dims, GetElementTypeFromOperand(lhs, rhs));
2335 }
2336
2337 // Finds compatible types to propagate into functions/regions of a shape
2338 // invariant tf.While/tf.WhileRegion. If operand and result types are the same,
2339 // that type is returned. If operand and result types are of the same rank, a
2340 // compatible type with matching dimensions is used. Otherwise the function/
2341 // region argument type is returned, with the handle element type taken from the operand.
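// For example, with operand tensor<2x?xf32> and result tensor<2x4xf32> (same
// rank, mismatched second dimension) the compatible type is tensor<2x?xf32>;
// if the ranks differ, the region argument's shape is used together with the
// element type derived from the operand's handle type.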
2342 llvm::SmallVector<Type, 4> GetWhileCompatibleTypes(
2343 TypeRange operand_types, TypeRange result_types,
2344 TypeRange region_argument_types) {
2345 llvm::SmallVector<Type, 4> types;
2346 types.reserve(operand_types.size());
2347 for (auto entry :
2348 llvm::zip(operand_types, result_types, region_argument_types)) {
2349 auto operand_type = std::get<0>(entry).cast<TensorType>();
2350 auto result_type = std::get<1>(entry).cast<TensorType>();
2351 if (operand_type == result_type) {
2352 types.push_back(operand_type);
2353 } else if (RankedAndSameRank(operand_type, result_type)) {
2354 auto potential_refined_type =
2355 GetCompatibleRankedTensorType(operand_type.cast<RankedTensorType>(),
2356 result_type.cast<RankedTensorType>());
2357 types.push_back(potential_refined_type);
2358 } else {
2359 auto region_argument_type = std::get<2>(entry).cast<TensorType>();
2360 Type element_type = GetElementTypeFromOperand(
2361 operand_type.cast<TensorType>(), region_argument_type);
2362 Type potential_refined_type = CreateTensorType(
2363 region_argument_type.hasRank() ? region_argument_type.getShape()
2364 : llvm::Optional<ArrayRef<int64_t>>(),
2365 element_type);
2366 types.push_back(potential_refined_type);
2367 }
2368 }
2369 return types;
2370 }
2371
2372 FailureOr<bool> ShapeInference::PropagateShapeIntoAttachedFunctions(
2373 Operation* op, int64_t max_iterations) {
2374 ModuleOp module = op->getParentOfType<ModuleOp>();
2375 if (auto if_op = dyn_cast<TF::IfOp>(op)) {
2376 DCOMMENT("Propagating shapes into If");
2377 return PropagateShapeToFunctions(
2378 module, if_op.input().getTypes(),
2379 {if_op.ResolveThenFunction(&symbol_table_),
2380 if_op.ResolveElseFunction(&symbol_table_)},
2381 max_iterations);
2382 } else if (auto case_op = dyn_cast<TF::CaseOp>(op)) {
2383 SmallVector<func::FuncOp, 4> branches;
2384 case_op.get_branch_functions(branches);
2385 return PropagateShapeToFunctions(module, case_op.input().getTypes(),
2386 branches, max_iterations);
2387 } else if (auto while_op = dyn_cast<TF::WhileOp>(op)) {
2388 // If `shape_invariant` is set, operand shapes cannot be simply propagated
2389 // to result shapes as the op may have different intermediate shapes (e.g.
2390 // While ops can have result shapes that differ from their operand shapes).
2391 // Compatible shapes must be determined before propagating them.
2392 if (while_op.shape_invariant()) {
2393 auto compatible_types = GetWhileCompatibleTypes(
2394 while_op.input().getTypes(), while_op.output().getTypes(),
2395 while_op.ResolveBodyFunction(&symbol_table_)
2396 .getFunctionType()
2397 .getInputs());
2398 return PropagateShapeToFunctions(
2399 module, compatible_types,
2400 {while_op.ResolveCondFunction(&symbol_table_),
2401 while_op.ResolveBodyFunction(&symbol_table_)},
2402 max_iterations);
2403 }
2404 return PropagateShapeToFunctions(
2405 module, while_op.input().getTypes(),
2406 {while_op.ResolveCondFunction(&symbol_table_),
2407 while_op.ResolveBodyFunction(&symbol_table_)},
2408 max_iterations);
2409 } else if (auto call_op = dyn_cast<CallOpInterface>(op)) {
2410 if (auto func =
2411 dyn_cast<func::FuncOp>(call_op.resolveCallable(&symbol_table_))) {
2412 PropagateConstantToCallee(call_op, func, module);
2413 FailureOr<bool> failure_or_converged = PropagateShapeToFunctions(
2414 module, call_op.getArgOperands().getTypes(), {func}, max_iterations);
2415 if (failed(failure_or_converged)) return failure();
2416 PropagateConstantFromCallee(call_op, func, module);
2417 return failure_or_converged;
2418 }
2419 } else if (isa<TF::XlaReduceWindowOp>(op) ||
2420 isa<TF::XlaSelectAndScatterOp>(op) ||
2421 isa<TF::XlaVariadicReduceV2Op>(op) ||
2422 isa<TF::XlaVariadicSortOp>(op)) {
2423 auto propagate_shape_to = [&](mlir::SymbolRefAttr func_sym) {
2424 auto func = llvm::cast<mlir::func::FuncOp>(
2425 mlir::SymbolTable::lookupSymbolIn(module, func_sym));
2426 mlir::SmallVector<mlir::Type, 2> types;
2427 for (auto type : func.getFunctionType().getInputs()) {
2428 types.push_back(RankedTensorType::get({}, getElementTypeOrSelf(type)));
2429 }
2430 return PropagateShapeToFunctions(module, types, {func}, max_iterations);
2431 };
2432
2433 if (auto xla_reduce_window_op = dyn_cast<TF::XlaReduceWindowOp>(op)) {
2434 return propagate_shape_to(xla_reduce_window_op.computation());
2435 }
2436 if (auto xla_select_and_scatter_op =
2437 dyn_cast<TF::XlaSelectAndScatterOp>(op)) {
2438 return propagate_shape_to(xla_select_and_scatter_op.select())
2439 .getValue() &&
2440 propagate_shape_to(xla_select_and_scatter_op.scatter()).getValue();
2441 } else if (auto xla_variadic_reduce_v2_op =
2442 dyn_cast<TF::XlaVariadicReduceV2Op>(op)) {
2443 return propagate_shape_to(xla_variadic_reduce_v2_op.reducer());
2444 } else if (auto xla_variadic_sort_op =
2445 dyn_cast<TF::XlaVariadicSortOp>(op)) {
2446 return propagate_shape_to(xla_variadic_sort_op.comparator());
2447 }
2448 }
2449
2450 // TODO(ycao): Implement support for Call op, including function reuse.
2451
2452 return true;
2453 }
2454
2455 FailureOr<bool> ShapeInference::PropagateShapeIntoAttachedRegions(
2456 Operation* op, int64_t max_iterations) {
2457 if (auto while_op = dyn_cast<TF::WhileRegionOp>(op)) {
2458 // If `shape_invariant` is set, operand shapes cannot be simply propagated
2459 // to result shapes as the op may have different intermediate shapes (e.g.
2460 // While ops can have result shapes that differ from their operand shapes).
2461 // Compatible shapes must be determined before propagating them.
2462 if (while_op.shape_invariant()) {
2463 auto compatible_types = GetWhileCompatibleTypes(
2464 while_op.input().getTypes(), while_op.output().getTypes(),
2465 while_op.body().getArgumentTypes());
2466 return PropagateShapeToRegions(compatible_types,
2467 {&while_op.cond(), &while_op.body()},
2468 max_iterations);
2469 }
2470 return PropagateShapeToRegions(while_op.input().getTypes(),
2471 {&while_op.cond(), &while_op.body()},
2472 max_iterations);
2473 }
2474 return true;
2475 }
2476
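// Attempts to constant fold `op`: operand values already computed by partial
// evaluation are fed to the op's fold hook, falling back to the dialect fold
// interface, and any folded results are recorded (and result types refined) so
// later partial evaluation can reuse them.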
2477 LogicalResult ShapeInference::TryToFold(Operation* op) {
2478 LLVM_DEBUG(op->print(llvm::dbgs() << "TryToFold "); llvm::dbgs() << "\n");
2479 // If any output result is known, then the op has probably been computed
2480 // before.
2481 if (op->getNumResults() > 0 && results_[ValuePort(op->getResult(0))])
2482 return success();
2483
2484 SmallVector<Attribute, 8> constant_operands(op->getNumOperands());
2485 SmallVector<OpFoldResult, 8> fold_results;
2486
2487 // Check to see if any operand of the operation is constant and whether
2488 // the operation knows how to constant fold itself.
2489 bool some_unknown = false;
2490 for (int i = 0, e = op->getNumOperands(); i != e; ++i) {
2491 if (!(constant_operands[i] =
2492 ComputeOutputComponent(ValuePort(op->getOperand(i)))))
2493 some_unknown = true;
2494 }
2495
2496 // Attempt to constant fold the operation.
2497 auto abstract_op = op->getRegisteredInfo();
2498 LogicalResult folded = failure();
2499 if (abstract_op) {
2500 folded = abstract_op->foldHook(op, constant_operands, fold_results);
2501 }
2502 // Attempt dialect fallback if op's fold hook failed.
2503 if (failed(folded)) {
2504 Dialect* dialect = op->getDialect();
2505 if (!dialect) return failure();
2506 // Only attempt TF dialect fallback if there are no unknown operands.
2507 if (some_unknown && dialect == tf_dialect_) return failure();
2508 auto* interface = dialect->getRegisteredInterface<DialectFoldInterface>();
2509 if (!interface) return failure();
2510
2511 if (failed(interface->fold(op, constant_operands, fold_results)))
2512 return failure();
2513 }
2514
2515 for (auto result : zip(op->getResults(), fold_results)) {
2516 auto fold_result = std::get<1>(result);
2517 Attribute attr = nullptr;
2518 if ((attr = fold_result.dyn_cast<Attribute>())) {
2519 RecordValue(ValuePort(std::get<0>(result)), attr);
2520 } else {
2521 auto value = fold_result.get<Value>();
2522 if ((attr = ComputeOutputComponent(ValuePort(value)))) {
2523 DCOMMENT("\t\tValue Result mapped to " << attr);
2524 RecordValue(ValuePort(std::get<0>(result)), attr);
2525 } else {
2526 DCOMMENT("\t\tValue result unmapped, consider value type:" << value);
2527 RefineResultType(op, std::get<0>(result), value.getType());
2528 }
2529 }
2530
2531 if (ElementsAttr eattr = attr.dyn_cast_or_null<ElementsAttr>()) {
2532 if (std::get<0>(result).getType() == eattr.getType()) continue;
2533
2534 (void)UpdateTypeAndInsertIncompatibleUseCasts(eattr.getType(),
2535 std::get<0>(result));
2536 }
2537 }
2538
2539 return success();
2540 }
2541
2542 LogicalResult ShapeInference::InferShapeForFunctionReturnType(
2543 func::FuncOp func) {
2544 LLVM_DEBUG(llvm::dbgs() << "Inferring return type for: " << func.getName()
2545 << "\n");
2546
2547 // Find any return ops.
2548 SmallVector<func::ReturnOp, 4> return_ops;
2549 for (Block& block : func) {
2550 if (auto return_op = dyn_cast<func::ReturnOp>(block.getTerminator())) {
2551 return_ops.push_back(return_op);
2552 }
2553 }
2554
2555 // Skip functions without a return, but don't flag as failure here.
2556 if (return_ops.empty()) return success();
2557
2558 // Right now we only handle the case of a single return op.
2559 // To handle multiple return ops, we would need to look at all their shapes
2560 // and come up with a common shape and insert appropriate casts.
2561 if (return_ops.size() != 1) return failure();
2562
2563 // Find the return type.
2564 auto return_op = return_ops.front();
2565
2566 // Manually fold tf.Cast that precedes the return instruction and only differs
2567 // in shape refinement level.
2568 bool changed = false;
2569 for (OpOperand& arg_op : return_op.getOperation()->getOpOperands()) {
2570 Operation* arg_defining_op = arg_op.get().getDefiningOp();
2571 if (isa_and_nonnull<CastOp, tensor::CastOp>(arg_defining_op)) {
2572 Value input = arg_defining_op->getOperand(0);
2573 Value result = arg_defining_op->getResult(0);
2574 Type meet = TypeMeet(result.getType(), input.getType());
2575 if (meet == result.getType()) continue;
2576
2577 LLVM_DEBUG({
2578 llvm::errs() << "\tfolding & updating return type ";
2579 result.getType().print(llvm::errs());
2580 input.getType().print(llvm::errs() << " to ");
2581 llvm::errs() << "\n";
2582 });
2583
2584 // Shape inference should not change the element type.
2585 if (HasCompatibleElementTypes(input.getType(), result.getType()) &&
2586 meet == input.getType()) {
2587 arg_op.set(input);
2588 } else {
2589 OpBuilder b(return_op.getOperation());
2590 auto new_cast_op = InsertCast(b, return_op.getLoc(), meet, input);
2591 if (!new_cast_op) return failure();
2592 arg_op.set(new_cast_op->getResult(0));
2593 }
2594 if (result.use_empty()) arg_defining_op->erase();
2595 changed = true;
2596 }
2597 }
2598
2599 DCOMMENT("Updating function type");
2600 func.setType(FunctionType::get(func.getContext(), func.getArgumentTypes(),
2601 return_op.getOperandTypes()));
2602
2603 if (changed) EnqueueCallers(func);
2604 return success();
2605 }
2606
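// Runs shape inference over `region` until it stops changing or
// `max_iterations` is reached. Returns true if a fixed point was reached,
// false (with a warning on the parent op) if the last iteration still changed
// something, and failure if inference of an attached function or region
// signaled an error.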
2607 FailureOr<bool> ShapeInference::InferShapeUntilFixPoint(
2608 Region* region, int64_t max_iterations) {
2609 bool changed = true;
2610
2611 // TODO(aminim): we could have a more efficient traversal by guiding the
2612 // traversal with a worklist, reconsidering only the nodes for which an
2613 // operand type was inferred. This would need care when working on a
2614 // region that is not isolated.
2615 for (int iteration = 0; iteration < max_iterations && changed; ++iteration) {
2616 changed = false;
2617 LLVM_DEBUG(llvm::dbgs()
2618 << "Shape inference, iteration " << iteration << "\n");
2619 auto res = region->walk([&](Operation* op) {
2620 DCOMMENT_OP(op, "Inferring for");
2621 if (auto infer_ti = dyn_cast<InferTypeOpInterface>(op)) {
2622 DCOMMENT("\tRefining with type op interface");
2623 changed |= RefineWithInferTypeOpInterface(infer_ti);
2624 return WalkResult::advance();
2625 }
2626
2627 if (op->getDialect() != tf_dialect_) {
2628 DCOMMENT("\tInfer non-TF dialect");
2629 changed |= InferShapeForNonTFDialectOperation(op);
2630 return WalkResult::advance();
2631 }
2632
2633 // Before attempting inference, just try to compute the folded
2634 // value/shape.
2635 if (succeeded(TryToFold(op)) &&
2636 // Folding can "succeed" and yet not all types be refined. In such
2637 // cases we still want to give `InferShapeForSingleOperation` a try.
2638 none_of(op->getResultTypes(), CanBeRefined))
2639 return WalkResult::advance();
2640
2641 // Best-effort shape inference in attached functions. Do not return
2642 // failure even if it doesn't get to fixed point, but propagate "real"
2643 // failure.
      if (failed(PropagateShapeIntoAttachedFunctions(op, max_iterations))) {
        op->emitWarning() << "unable to refine shape of attached function "
                             "arguments and bodies";
        return WalkResult::interrupt();
      }

      if (failed(PropagateShapeIntoAttachedRegions(op, max_iterations))) {
        op->emitWarning() << "unable to refine shape of attached region "
                             "arguments and bodies";
        return WalkResult::interrupt();
      }

      changed |= InferShapeForSingleOperation(op, max_iterations);
      return WalkResult::advance();
    });
    if (res.wasInterrupted()) return failure();
  }

  if (changed) {
    region->getParentOp()->emitWarning()
        << "shape inference did not reach stable state after "
        << max_iterations << " iterations";
  }
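  // `changed` still being true here means the final walk made updates, i.e.
  // the budget of `max_iterations` ran out before a fixed point was reached.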
  return !changed;
}

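// Runs shape inference over the body of `func` and, if the body converged,
// also refines the function's return type.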
static FailureOr<bool> InferShapeForFunction(ShapeInference& context,
                                             func::FuncOp func,
                                             int64_t max_iterations) {
  FailureOr<bool> failure_or_converged =
      context.InferShapeUntilFixPoint(&func.getBody(), max_iterations);
  if (failed(failure_or_converged) || !failure_or_converged.getValue())
    return failure_or_converged;
  // TODO(b/156276510): Verify that it is always fine to refine a function's
  // return type, as long as we do not change the argument shapes.
  if (failed(context.InferShapeForFunctionReturnType(func))) return failure();
  return true;
}

FailureOr<bool> InferShapeForFunction(func::FuncOp func,
                                      ArrayRef<ArrayRef<int64_t>> arg_shapes,
                                      int64_t graph_version,
                                      int64_t max_iterations) {
  ShapeInference context(graph_version, func->getParentOfType<ModuleOp>(),
                         /*propagate_caller_callee_constants=*/true);
  if (arg_shapes.empty()) {
    return InferShapeForFunction(context, func, max_iterations);
  }

  FunctionType func_type = func.getFunctionType();
  bool needs_refinement = false;
  SmallVector<Type, 4> new_arg_types;
  new_arg_types.reserve(func_type.getNumInputs());

  // Update argument types in-place using the provided arg_shapes.
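  // Illustrative example: with arg_shapes = {{1, 224, 224, 3}}, an argument
  // of type tensor<?x?x?x3xf32> is rewritten to tensor<1x224x224x3xf32>,
  // which then drives the inference over the body below.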
  for (size_t i = 0; i < func_type.getNumInputs(); ++i) {
    ArrayRef<int64_t> shape = arg_shapes[i];
    Type element_type;
    if (auto input_ty = func_type.getInput(i).dyn_cast<RankedTensorType>()) {
      if (input_ty.getRank() != shape.size()) {
        return failure();
      }
      element_type = input_ty.getElementType();
    } else {
      auto unranked_input_ty = func_type.getInput(i).dyn_cast<TensorType>();
      if (!unranked_input_ty) {
        return failure();
      }
      element_type = unranked_input_ty.getElementType();
    }

    auto new_arg_type = RankedTensorType::get(shape, element_type);
    if (new_arg_type != func_type.getInput(i)) {
      // If the new type is more detailed, trigger shape inference.
      func.getArgument(i).setType(new_arg_type);
      needs_refinement = true;
    }
    new_arg_types.push_back(new_arg_type);
  }

  if (!needs_refinement) return true;

  FailureOr<bool> failure_or_converged =
      context.InferShapeUntilFixPoint(&func.getBody(), max_iterations);
  if (failed(failure_or_converged) || !failure_or_converged.getValue())
    return failure_or_converged;

  if (failed(context.InferShapeForFunctionReturnType(func))) return failure();
  func.setType(FunctionType::get(func.getContext(), new_arg_types,
                                 func.getFunctionType().getResults()));

  return true;
}

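// Seeds a work queue with the module's functions (starting from "main" when
// present) and runs per-function inference until the queue drains; callers of
// functions whose signatures changed are re-enqueued via EnqueueCallers.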
FailureOr<bool> InferModuleShape(ModuleOp module, int64_t max_iterations) {
  auto producer_or = tensorflow::GetTfGraphProducerVersion(module);
  if (!producer_or.ok()) {
    // TODO(jpienaar): Keeping the existing behavior for now but this could
    // be relaxed.
    LLVM_DEBUG(llvm::dbgs()
               << "Skipping inference; " << producer_or.status().ToString());
    return true;
  }
  int64_t producer = producer_or.ValueOrDie();
  // TODO(jpienaar): Clean up propagate_caller_callee_constants if it is no
  // longer needed.
  ShapeInference context(producer, module,
                         /*propagate_caller_callee_constants=*/false);
  if (auto main = module.lookupSymbol<mlir::func::FuncOp>("main"))
    context.enqueue(main);
  for (auto func : module.getOps<func::FuncOp>()) context.enqueue(func);
  // Arbitrarily upper bound the maximum number of functions that get processed
  // just to avoid pathological cases.
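  // Illustrative example: if the queue starts with 3 functions, the budget is
  // 12 function-processing steps, which bounds repeated re-enqueueing of
  // callers whose callees keep changing.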
  const auto max_num_functions = context.QueueSize() * 4;
  auto remaining_functions = max_num_functions;
  while (!context.EmptyQueue()) {
    func::FuncOp func = context.front();
    FailureOr<bool> failure_or_converged =
        InferShapeForFunction(context, func, max_iterations);
    if (failed(failure_or_converged) || !failure_or_converged.getValue())
      return failure_or_converged;
    context.pop_front();

    if ((--remaining_functions) == 0) {
      emitWarning(UnknownLoc::get(module.getContext()))
          << "shape inference did not reach stable state after processing "
          << max_num_functions << " functions";
      return false;
    }
  }
  return true;
}
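
// Typical call site (illustrative sketch, not part of this file): a module
// pass that owns a `max_iterations` option would simply forward it, e.g.
//
//   if (failed(TF::InferModuleShape(module, max_iterations)))
//     return signalPassFailure();
//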

}  // namespace TF
}  // namespace mlir