xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.h"
17 
18 #include <algorithm>
19 #include <cstdint>
20 #include <initializer_list>
21 #include <iterator>
22 #include <queue>
23 #include <stack>
24 
25 #include "llvm/ADT/ArrayRef.h"
26 #include "llvm/ADT/Hashing.h"
27 #include "llvm/ADT/None.h"
28 #include "llvm/ADT/PointerUnion.h"
29 #include "llvm/ADT/STLExtras.h"
30 #include "llvm/ADT/SmallVector.h"
31 #include "llvm/ADT/iterator_range.h"
32 #include "llvm/Support/Casting.h"
33 #include "llvm/Support/Debug.h"
34 #include "llvm/Support/FormatVariadic.h"
35 #include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
36 #include "mlir/Dialect/Tensor/IR/Tensor.h"  // from @llvm-project
37 #include "mlir/IR/Attributes.h"  // from @llvm-project
38 #include "mlir/IR/Block.h"  // from @llvm-project
39 #include "mlir/IR/Builders.h"  // from @llvm-project
40 #include "mlir/IR/BuiltinDialect.h"  // from @llvm-project
41 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
42 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
43 #include "mlir/IR/Diagnostics.h"  // from @llvm-project
44 #include "mlir/IR/FunctionInterfaces.h"  // from @llvm-project
45 #include "mlir/IR/Location.h"  // from @llvm-project
46 #include "mlir/IR/Operation.h"  // from @llvm-project
47 #include "mlir/IR/OperationSupport.h"  // from @llvm-project
48 #include "mlir/IR/SymbolTable.h"  // from @llvm-project
49 #include "mlir/IR/TypeUtilities.h"  // from @llvm-project
50 #include "mlir/IR/Value.h"  // from @llvm-project
51 #include "mlir/IR/Visitors.h"  // from @llvm-project
52 #include "mlir/Interfaces/CallInterfaces.h"  // from @llvm-project
53 #include "mlir/Interfaces/FoldInterfaces.h"  // from @llvm-project
54 #include "mlir/Interfaces/InferTypeOpInterface.h"  // from @llvm-project
55 #include "mlir/Pass/Pass.h"  // from @llvm-project
56 #include "mlir/Pass/PassRegistry.h"  // from @llvm-project
57 #include "mlir/Support/DebugStringHelper.h"  // from @llvm-project
58 #include "mlir/Support/LLVM.h"  // from @llvm-project
59 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
60 #include "mlir/Transforms/FoldUtils.h"  // from @llvm-project
61 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
62 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
63 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
64 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
65 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
66 #include "tensorflow/compiler/mlir/tensorflow/utils/serialize_mlir_module_utils.h"
67 #include "tensorflow/compiler/mlir/tensorflow/utils/shape_inference_utils.h"
68 #include "tensorflow/compiler/mlir/tensorflow/utils/translate_utils.h"
69 #include "tensorflow/compiler/xla/window_util.h"
70 #include "tensorflow/compiler/xla/xla_data.pb.h"
71 #include "tensorflow/core/framework/shape_inference.h"
72 #include "tensorflow/core/framework/types.pb.h"
73 #include "tensorflow/core/ir/types/dialect.h"
74 
75 #define DEBUG_TYPE "tf-shape-inference"
76 
77 #define DCOMMENT(MSG) LLVM_DEBUG(llvm::dbgs() << MSG << "\n")
78 #define DCOMMENT_OP(OP, MSG) \
79   LLVM_DEBUG(OP->print(llvm::dbgs() << MSG << " "); llvm::dbgs() << "\n")
80 
81 using ::int64_t;
82 using mlir::func::FuncOp;
83 using tensorflow::shape_inference::DimensionHandle;
84 using tensorflow::shape_inference::InferenceContext;
85 using tensorflow::shape_inference::ShapeHandle;
86 
87 namespace mlir {
88 namespace TF {
89 namespace {
90 
91 // Compute a refined type between two types `lhs` and `rhs`; the result type
92 // is always at least as refined (i.e. has at least as much static information)
93 // as `lhs`. This method actually merges the information contained in the
94 // types; it is capable of refining:
95 //   tensor<!tf_type.variant<tensor<?x8xf32>>>
96 // and:
97 //   tensor<!tf_type.variant<tensor<10x?xf32>>>
98 // into:
99 //   tensor<!tf_type.variant<tensor<10x8xf32>>>
100 //
101 // In case of inconsistencies (rank disagreement for example), it returns `lhs`.
102 Type TypeMeet(Type lhs, Type rhs) {
103   DCOMMENT("RefineTypeWith : " << lhs << " : " << rhs);
104   if (lhs == rhs) return lhs;
105 
106   auto rhs_shape_type = rhs.dyn_cast<ShapedType>();
107   if (!rhs_shape_type) return lhs;
108   auto lhs_shape_type = lhs.cast<ShapedType>();
109   if (lhs_shape_type.hasRank() && rhs_shape_type.hasRank() &&
110       lhs_shape_type.getRank() != rhs_shape_type.getRank()) {
111     DCOMMENT("Unexpected rank mismatch: " << lhs << " vs " << rhs);
112     return lhs;
113   }
114 
115   SmallVector<int64_t> shape;
116   bool refined_shape = false;
117   // Build the shape of the refined type: if `lhs` is unranked, the shape of
118   // `rhs` is used directly; otherwise we merge the two by taking the most
119   // specialized dimension at each position. This combines `10x?x?` and `?x?x8`
120   // into `10x?x8`.
121   if (!lhs_shape_type.hasRank()) {
122     if (rhs_shape_type.hasRank()) {
123       shape.append(rhs_shape_type.getShape().begin(),
124                    rhs_shape_type.getShape().end());
125       refined_shape = true;
126     }
127   } else if (rhs_shape_type.hasRank()) {
128     for (auto shape_elts : llvm::enumerate(
129              llvm::zip(lhs_shape_type.getShape(), rhs_shape_type.getShape()))) {
130       if (ShapedType::isDynamic(std::get<0>(shape_elts.value())) &&
131           !ShapedType::isDynamic(std::get<1>(shape_elts.value()))) {
132         shape.push_back(std::get<1>(shape_elts.value()));
133         refined_shape = true;
134         DCOMMENT("-> refining shape element #" << shape_elts.index());
135       } else {
136         DCOMMENT("-> not refining shape element #" << shape_elts.index());
137         shape.push_back(std::get<0>(shape_elts.value()));
138       }
139     }
140   }
141 
142   // Some tensors have an element type wrapping a subtensor, like resources and
143   // variants. In this case we may recurse on the wrapped subtype.
144   // `element_type` will contain the refined inferred element type for the
145   // returned type.
146   auto lhs_element_type = lhs_shape_type.getElementType();
147   auto rhs_element_type_with_subtype =
148       rhs_shape_type.getElementType().dyn_cast<TF::TensorFlowTypeWithSubtype>();
149   // Look for a resource or variant element type and ensure we refine the subtype.
150   // We only support a single subtype at the moment; we won't handle something
151   // like:
152   //   tensor<!tf_type.variant<tensor<10xf32>, tensor<8xf32>>
153   if (rhs_element_type_with_subtype &&
154       rhs_element_type_with_subtype.GetSubtypes().size() == 1) {
155     auto lhs_element_type_with_subtype =
156         lhs_element_type.dyn_cast<TF::TensorFlowTypeWithSubtype>();
157     TensorType subtype;
158     if (!lhs_element_type_with_subtype) {
159       DCOMMENT(
160           "Unexpected inferred `TensorFlowTypeWithSubtype` when original "
161           "result isn't");
162     } else if (lhs_element_type_with_subtype.GetSubtypes().size() > 1) {
163       DCOMMENT(
164           "Unexpected `TensorFlowTypeWithSubtype` original type with size>1");
165     } else if (lhs_element_type_with_subtype.GetSubtypes().empty()) {
166       subtype = rhs_element_type_with_subtype.GetSubtypes().front();
167     } else {
168       // Recurse on the subtypes in the variant/resource. Basically if the input
169       // were:
170       //   tensor<!tf_type.variant<tensor<?x8xf32>>>
171       // and:
172       //   tensor<!tf_type.variant<tensor<10x8xf32>>>
173       // we'll try here to refine tensor<?x8xf32> with tensor<10x8xf32>.
174       auto refined_subtype =
175           TypeMeet(lhs_element_type_with_subtype.GetSubtypes().front(),
176                    rhs_element_type_with_subtype.GetSubtypes().front())
177               .cast<TensorType>();
178       if (refined_subtype !=
179           lhs_element_type_with_subtype.GetSubtypes().front())
180         subtype = refined_subtype;
181     }
182     // If we managed to refine the subtype, recreate the element type itself
183     // (i.e. the tf.variant or tf.resource).
184     if (subtype) {
185       lhs_element_type = lhs_element_type_with_subtype.clone({subtype});
186     }
187   }
188   if (refined_shape || lhs_element_type != lhs_shape_type.getElementType()) {
189     Type new_type;
190     if (!lhs_shape_type.hasRank() && !rhs_shape_type.hasRank())
191       new_type = UnrankedTensorType::get(lhs_element_type);
192     else
193       new_type = lhs_shape_type.clone(shape, lhs_element_type);
194     DCOMMENT("Refined to: " << new_type);
195     return new_type;
196   }
197   DCOMMENT("No refinement " << lhs);
198   return lhs;
199 }
200 
201 // Returns whether `original_type` can be refined with
202 // `potential_refined_type`.
203 bool CanRefineTypeWith(Type original_type, Type potential_refined_type) {
204   return original_type != TypeMeet(original_type, potential_refined_type);
205 }
206 
207 // Returns whether the shape inference pass supports an op outside the TF dialect.
208 bool IsSupportedNonTFOp(Operation* op) {
209   return isa<tf_device::ReturnOp, tf_device::ClusterOp, tf_device::LaunchOp,
210              tf_executor::EnterOp, tf_executor::ExitOp, tf_executor::FetchOp,
211              tf_executor::GraphOp, tf_executor::IslandOp,
212              tf_executor::LoopCondOp, tf_executor::MergeOp,
213              tf_executor::NextIterationSinkOp, tf_executor::SwitchNOp,
214              tf_executor::SwitchOp, tf_executor::YieldOp>(op) ||
215          isa<InferTypeOpInterface>(op);
216 }
217 
218 // Returns whether a cast back would need to be inserted, i.e., whether the
219 // operation that owns `use` cannot accept the refined operand type without
220 // a cast.
221 bool NeedsCastBack(OpOperand& use, Dialect* tf_dialect) {
222   return use.getOwner()->getDialect() != tf_dialect &&
223          !IsSupportedNonTFOp(use.getOwner());
224 }
225 
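// Creates a ranked tensor type if `shape` is provided, otherwise an unranked
// tensor type, using the given element type.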
226 TensorType CreateTensorType(llvm::Optional<llvm::ArrayRef<int64_t>> shape,
227                             Type element_type) {
228   if (shape.has_value())
229     return RankedTensorType::get(shape.getValue(), element_type);
230   return UnrankedTensorType::get(element_type);
231 }
232 
233 // Returns true if the op creates a TensorList.
234 bool IsTensorListInitOp(Operation* op) {
235   return isa<TensorListReserveOp>(op) || isa<EmptyTensorListOp>(op) ||
236          isa<TensorListFromTensorOp>(op);
237 }
238 
239 // Returns the `element_shape` operand of the ops that create a TensorList.
240 Value GetElementShapeOperand(Operation* op) {
241   if (auto empty_tl = dyn_cast<EmptyTensorListOp>(op))
242     return empty_tl.element_shape();
243   if (auto tl_reserve = dyn_cast<TensorListReserveOp>(op))
244     return tl_reserve.element_shape();
245   if (auto tl_from_tensor = dyn_cast<TensorListFromTensorOp>(op))
246     return tl_from_tensor.element_shape();
247   llvm_unreachable("unsupported TensorList op");
248 }
249 
250 // Utility function to create a ranked tensor type after dropping the first
251 // dimension from the input type.
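// For example, tensor<5x2x3xf32> becomes tensor<2x3xf32>; a null type is
// returned for inputs that are not ranked tensors.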
252 RankedTensorType DropFirstDimension(Type type) {
253   RankedTensorType ranked_type = type.dyn_cast<RankedTensorType>();
254   if (!ranked_type) return {};
255   llvm::ArrayRef<int64_t> dims_except_first =
256       ranked_type.getShape().drop_front();
257   return RankedTensorType::get(dims_except_first, ranked_type.getElementType());
258 }
259 
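// Inserts a cast of `input` to `dst_type` at `loc`: a tensor.cast for index
// element types, a tf.Cast for TF/builtin element types, and nullptr otherwise.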
260 Operation* InsertCast(OpBuilder& b, Location loc, Type dst_type, Value input) {
261   Type element_type = getElementTypeOrSelf(dst_type);
262   if (element_type.isa<IndexType>())
263     return b.create<tensor::CastOp>(loc, dst_type, input);
264   if (isa<TensorFlowDialect, BuiltinDialect>(element_type.getDialect()))
265     return b.create<TF::CastOp>(loc, dst_type, input,
266                                 /*truncate=*/b.getBoolAttr(false));
267   return nullptr;
268 }
269 
270 // Follows the use chain of a TensorList and returns true iff all elements
271 // written to the TensorList have the same static shape. If they do, that shape
272 // is assigned to `potential_element_type`.
273 //
274 // This can handle multiple mutations of a TensorList object and returns
275 // true if, across all mutations, the elements written have the same shape.
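// For example, if one TensorListSetItem writes a tensor<2x3xf32> item and a
// later TensorListPushBack writes another tensor<2x3xf32>, the element type is
// inferred as tensor<2x3xf32>; a write of tensor<4xf32> anywhere in the chain
// makes this return false.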
276 bool CanInferTensorListElementType(Value tensorlist,
277                                    Value initial_element_shape,
278                                    RankedTensorType* potential_element_type) {
279   DCOMMENT("CanInferTensorListElementType " << tensorlist << " with initial "
280                                             << initial_element_shape);
281   // Verifies that the new element type has a static shape and matches the
282   // potential type passed from the caller. Updates `potential_element_type` if
283   // it is not defined yet.
284   auto verify_and_update_potential_element_type =
285       [&](RankedTensorType new_element_type) -> bool {
286     DCOMMENT("\t\tConsidering " << new_element_type << " with old "
287                                 << *potential_element_type);
288     if (!new_element_type || !new_element_type.hasStaticShape()) return false;
289     if (!*potential_element_type) {
290       DCOMMENT("\t\tUpdating potential_element_type " << new_element_type);
291       *potential_element_type = new_element_type;
292       return true;
293     }
294     return *potential_element_type == new_element_type;
295   };
296 
297   std::stack<Value> worklist;
298   worklist.emplace(tensorlist);
299 
300   while (!worklist.empty()) {
301     tensorlist = worklist.top();
302     worklist.pop();
303 
304     // TensorLists are semantically immutable. For example, TensorListSetItem
305     // takes a TensorList as input and produces a TensorList as output. So to
306     // traverse modifications to a TensorList and verify that all elements written
307     // to it have the same shape, we need to follow the use-def chain of ops that
308     // (conceptually) modify it, i.e., ops that take an input TensorList and
309     // produce an output TensorList.
310     for (auto& use : tensorlist.getUses()) {
311       if (auto push = llvm::dyn_cast<TensorListPushBackOp>(use.getOwner())) {
312         auto element_type =
313             push.tensor().getType().dyn_cast<RankedTensorType>();
314         if (!verify_and_update_potential_element_type(element_type))
315           return false;
316         worklist.emplace(push.output_handle());
317         continue;
318       }
319       if (auto scatter = llvm::dyn_cast<TensorListScatterIntoExistingListOp>(
320               use.getOwner())) {
321         // For scatter op we can get the element shape by dropping the first
322         // dimension of the input tensor.
323         RankedTensorType element_type =
324             DropFirstDimension(scatter.tensor().getType());
325         if (!verify_and_update_potential_element_type(element_type))
326           return false;
327         worklist.emplace(scatter.output_handle());
328         continue;
329       }
330       if (auto set_item = llvm::dyn_cast<TensorListSetItemOp>(use.getOwner())) {
331         auto element_type =
332             set_item.item().getType().dyn_cast<RankedTensorType>();
333         DCOMMENT("\tTensorListSetItemOp " << element_type);
334         if (!verify_and_update_potential_element_type(element_type))
335           return false;
336         worklist.emplace(set_item.output_handle());
337         continue;
338       }
339       if (auto pop = llvm::dyn_cast<TensorListPopBackOp>(use.getOwner())) {
340         worklist.emplace(pop.output_handle());
341         continue;
342       }
343       if (auto resize = llvm::dyn_cast<TensorListResizeOp>(use.getOwner())) {
344         worklist.emplace(resize.output_handle());
345         continue;
346       }
347       // WhileRegionOp can explicitly capture a TensorList value to be used inside
348       // its regions. So we check the uses of the corresponding block argument in
349       // each region and the uses of the TensorList returned via YieldOp.
350       if (auto while_region = llvm::dyn_cast<WhileRegionOp>(use.getOwner())) {
351         DCOMMENT("\tTL WhileRegion");
352         for (auto branch : while_region.getRegions())
353           worklist.emplace(branch->getArgument(use.getOperandNumber()));
354         continue;
355       }
356       if (auto yield = llvm::dyn_cast<YieldOp>(use.getOwner())) {
357         Operation* parent = yield->getParentOp();
358         worklist.emplace(parent->getResult(use.getOperandNumber()));
359         continue;
360       }
361       // TODO(jpienaar): This can be generalized.
362       if (isa<IdentityOp, IdentityNOp, StopGradientOp>(use.getOwner())) {
363         worklist.emplace(use.getOwner()->getResult(use.getOperandNumber()));
364         continue;
365       }
366       // Refining the tensor list element type might change the output of
367       // TensorListElementShape, which is expected to be the shape originally
368       // assigned to the TensorList init op. So replace it with the original
369       // element shape value.
370       if (auto tl_element_shape =
371               dyn_cast<TensorListElementShapeOp>(use.getOwner())) {
372         // If element types match, we can do a direct replacement.
373         if (getElementTypeOrSelf(tl_element_shape.getResult()) ==
374             getElementTypeOrSelf(initial_element_shape.getType())) {
375           tl_element_shape.replaceAllUsesWith(initial_element_shape);
376         } else {
377           OpBuilder b(use.getOwner());
378           Operation* cast_op = InsertCast(
379               b, use.getOwner()->getLoc(),
380               tl_element_shape.getResult().getType(), initial_element_shape);
381           if (!cast_op) return false;
382           tl_element_shape.replaceAllUsesWith(cast_op->getResult(0));
383         }
384         continue;
385       }
386       // Ignore ops that just consume a TensorList and do not output another
387       // TensorList.
388       if (isa<TensorListStackOp, TensorListGatherOp, TensorListConcatV2Op,
389               TensorListLengthOp, TensorListGetItemOp>(use.getOwner()))
390         continue;
391 
392       // For any other unknown users of the TensorList, we are conservative and
393       // stop element shape inference.
394       DCOMMENT("TensorListType infer, unknown op " << *use.getOwner());
395       return false;
396     }
397   }
398   return true;
399 }
400 
401 // Returns the tensor type created from the `shape_attr` and `type_attr`
402 // attributes.
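// For example, a shape attribute of `2x?` with an f32 type attribute yields
// tensor<2x?xf32>, while an unranked shape attribute yields tensor<*xf32>.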
403 Type GetType(Attribute shape_attr, Attribute type_attr) {
404   auto shape = shape_attr.cast<tf_type::ShapeAttr>();
405   auto type = type_attr.cast<TypeAttr>();
406   if (shape.hasRank())
407     return RankedTensorType::get(shape.getShape(), type.getValue());
408   else
409     return UnrankedTensorType::get(type.getValue());
410 }
411 
412 }  // namespace
413 
414 // Returns whether type can be further refined.
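// For example, tensor<*xf32>, tensor<?x4xf32>, and a resource/variant tensor
// with a refinable (or missing) subtype can be refined, while a fully static
// tensor<4xf32> cannot.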
415 bool CanBeRefined(Type type) {
416   auto shape_type = type.dyn_cast<ShapedType>();
417   if (!shape_type) return false;
418 
419   // Returns whether type with subtypes can be further refined.
420   auto can_refine_subtypes = [](TF::TensorFlowTypeWithSubtype tws) {
421     return tws.GetSubtypes().empty() ||
422            llvm::any_of(tws.GetSubtypes(), CanBeRefined);
423   };
424   auto type_with_subtype =
425       shape_type.getElementType().dyn_cast<TF::TensorFlowTypeWithSubtype>();
426   if (type_with_subtype && can_refine_subtypes(type_with_subtype)) return true;
427 
428   return !shape_type.hasStaticShape();
429 }
430 
431 // Combination of value producer and port of value produced (e.g.,
432 //   <value result output>:<value in output tensor>,
433 // so for tf.Const -> tensor<10x20xf32>, [0,2,18] would point to a unique output
434 // scalar value).
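// For example, a ValuePort whose producer is a tf.Pack op and whose port is
// {0, 1} refers to element 1 of the op's 1-D result 0 (see
// ComputeInputsRequiredForOutput and ComputeOutputComponent below).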
435 struct ValuePort {
436   PointerUnion<Operation*, BlockArgument> producer;
437   SmallVector<unsigned int, 2> port;
438 
439   bool operator==(const ValuePort& other) const {
440     return producer == other.producer && port == other.port;
441   }
442 
443   // Convert output value to ValuePort.
444   explicit ValuePort(Value v) {
445     OpResult opr = v.dyn_cast<OpResult>();
446     if (opr) {
447       producer = opr.getOwner();
448       port = {opr.getResultNumber()};
449     } else {
450       producer = v.cast<BlockArgument>();
451       port = {0};
452     }
453   }
454   ValuePort(PointerUnion<Operation*, BlockArgument> producer,
455             SmallVector<unsigned int, 2> port)
456       : producer(producer), port(port) {}
457 
458   raw_ostream& print(raw_ostream& os) const {
459     if (auto* op = producer.dyn_cast<Operation*>())
460       os << "op " << op->getName();
461     if (auto ba = producer.dyn_cast<BlockArgument>())
462       os << "block_arg " << ba.getArgNumber();
463     os << formatv(" [{0}]", llvm::make_range(port.begin(), port.end()));
464     return os;
465   }
466 };
467 
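// Hash functor so that ValuePort can be used as a key in an unordered_map; it
// combines the opaque producer pointer with a hash of the port path.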
468 struct ValuePortHasher {
469   std::size_t operator()(const ValuePort& other) const {
470     return hash_combine(llvm::hash_value(other.producer.getOpaqueValue()),
471                         hash_value(ArrayRef<unsigned int>(other.port)));
472   }
473 };
474 
475 using ValuePortResultMap =
476     std::unordered_map<ValuePort, Attribute, ValuePortHasher>;
477 using ComputedQueryFn = function_ref<bool(ValuePort)>;
478 using ValueQueryFn = function_ref<Attribute(const ValuePort&)>;
479 using ValuePortInputs = SmallVectorImpl<ValuePort>;
480 
481 // TODO(jpienaar): ComputeInputsRequiredForOutput and ComputeOutputComponent are
482 // intended to be switched to op interfaces once more refined.
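// For example, to compute output component {0, i} of a 1-D tf.Pack result,
// only the i-th operand of the pack is required as an input.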
483 LogicalResult ComputeInputsRequiredForOutput(ValuePort value_port,
484                                              ComputedQueryFn has_been_computed,
485                                              ValuePortInputs* inputs) {
486   auto op = value_port.producer.dyn_cast<Operation*>();
487   auto& port = value_port.port;
488   if (!op) return failure();
489 
490   // No inputs required for constants.
491   if (matchPattern(op, m_Constant())) return success();
492 
493   // Note: this focuses only on the trivial pack op case and could be
494   // generalized.
495   if (auto pack_op = dyn_cast<TF::PackOp>(op)) {
496     auto type = pack_op.getType().cast<TensorType>();
497     if (!type.hasRank() || type.getRank() != 1) return failure();
498     if (port.size() != 2) return failure();
499     assert(port[0] == 0);
500     ValuePort req(pack_op.getOperand(port[1]));
501     if (!has_been_computed(req)) inputs->push_back(req);
502     return success();
503   }
504 
505   return failure();
506 }
507 
508 // Computes the output produced by ValuePort using the query function of
509 // existing computed values.
510 Attribute ComputeOutputComponent(const ValuePort& value_port,
511                                  ValueQueryFn values) {
512   LLVM_DEBUG(value_port.print(llvm::dbgs() << "Computing output for ") << "\n");
513   if (auto known = values(value_port)) return known;
514 
515   auto op = value_port.producer.dyn_cast<Operation*>();
516   if (!op) return nullptr;
517   auto& port = value_port.port;
518 
519   if (port.empty()) {
520     LLVM_DEBUG(llvm::dbgs() << "skipping, port outside spec of " << op << "\n");
521     return nullptr;
522   }
523 
524   ElementsAttr attr;
525   if (matchPattern(op, m_Constant(&attr))) {
526     if (port.size() == 1 && port[0] == 0) return attr;
527     return nullptr;
528   }
529 
530   if (auto id = dyn_cast<IdentityOp>(op)) {
531     if (port.size() == 1 && port[0] == 0)
532       return ComputeOutputComponent(ValuePort(id.input()), values);
533     return nullptr;
534   }
535 
536   // Note: this focuses only on the trivial pack op case and could be
537   // generalized.
538   if (auto pack_op = dyn_cast<TF::PackOp>(op)) {
539     TensorType type = pack_op.getType().cast<TensorType>();
540     if (!type.hasRank() || type.getRank() != 1) return nullptr;
541     if (port.size() != 2 || port[0] != 0) return nullptr;
542     ValuePort op_port(op->getOperand(port[1]));
543     return values(op_port);
544   }
545 
546   if (auto graph = dyn_cast<tf_executor::GraphOp>(op)) {
547     if (port.size() == 1)
548       return ComputeOutputComponent(
549           ValuePort(graph.GetFetch().fetches()[port[0]]), values);
550     return nullptr;
551   }
552 
553   if (auto island = dyn_cast<tf_executor::IslandOp>(op)) {
554     if (port.size() == 1)
555       return ComputeOutputComponent(
556           ValuePort(island.GetYield().fetches()[port[0]]), values);
557     return nullptr;
558   }
559 
560   return nullptr;
561 }
562 
563 // Context used during ShapeInference. This class contains common information
564 // that is required by the individual shape inference helper functions (e.g.,
565 // TF Graph version, constant values computed, etc.)
566 class ShapeInference {
567  public:
568   ShapeInference(int64_t graph_version, ModuleOp module,
569                  bool propagate_caller_callee_constants);
570 
571   LogicalResult ComputeInputsRequiredForOutput(ValuePort value_port,
572                                                ValuePortInputs* inputs) {
573     return ::mlir::TF::ComputeInputsRequiredForOutput(
574         value_port,
575         [this](const ValuePort& port) {
576           return results_.find(port) != results_.end();
577         },
578         inputs);
579   }
580 
581   Attribute ComputeOutputComponent(const ValuePort& value_port) {
582     if (auto known_attr = results_[value_port]) return known_attr;
583     auto attr = ::mlir::TF::ComputeOutputComponent(
584         value_port, [this](const ValuePort& port) { return results_[port]; });
585     RecordValue(value_port, attr);
586     return attr;
587   }
588 
589   // Returns ShapeHandle if the op result could be computed as shape.
590   ShapeHandle ComputeOutputAsShape(OpResult result, InferenceContext* ic);
591 
592   void RecordValue(const ValuePort& value_port, Attribute value) {
593     LLVM_DEBUG(value_port.print(llvm::dbgs() << "\trecording ")
594                << value << "\n");
595     results_[value_port] = value;
596   }
597 
598   // Infers shape of tf.While/tf.WhileRegion. If `shape_invariant` attribute is
599   // set, operand types are set as result types if associated body result types
600   // match the operand type (does not change per loop iteration). If operand and
601   // body result types are not the same, only handle types are propagated to
602   // result types. This is necessary to not incorrectly change result shapes
603   // when the While op will have a different result shape. Otherwise operand
604   // shapes are propagated to result shapes.
605   template <typename WhileOpTy>
606   bool InferShapeForWhile(WhileOpTy op, TypeRange body_result_types);
607 
608   // Performs shape inference on the provided op and returns true if the type of
609   // at least one result has been changed.
610   // A tf.Cast() is inserted for any use that isn't in the TensorFlow dialect.
611   // `graph_version` indicates the current GraphDef compatibility version
612   // (the versions field in graph.proto).
613   bool InferShapeForSingleOperation(Operation* op, int64_t max_iterations);
614 
615   // Infers shapes in the provided region, including nested regions, iterating
616   // until a fixed point is reached, with a limit of `max_iterations`.
617   // Returns a failure() on error, otherwise returns true to indicate that it
618   // reached convergence, false otherwise.
619   FailureOr<bool> InferShapeUntilFixPoint(Region* region,
620                                           int64_t max_iterations);
621 
622   // Updates input types and refines shapes inside the bodies of functions that
623   // are attached to ControlFlow ops (If/While) or Calls. These functions include
624   // Then/Else branches of IfOp and Cond/Body functions of WhileOp. Functions
625   // attached to control flow share the following common properties:
626   //   1) They are never reused, i.e., they have a single use in the module.
627   //   2) Their input types match those of their parent ops (excluding inputs
628   //      like predicate).
629   // For calls, functions can be reused across multiple call sites. In that case
630   // we propagate the types only when all call sites have the same operand types.
631   // Returns a failure() on error, otherwise returns true to indicate that it
632   // reached convergence, false otherwise.
633   FailureOr<bool> PropagateShapeToFunctions(ModuleOp module,
634                                             TypeRange input_types,
635                                             ArrayRef<func::FuncOp> functions,
636                                             int64_t max_iterations);
637 
638   // Propagates shapes to regions given the shapes of the inputs of the regions.
639   // All regions provided in `regions` are assumed to have inputs of type
640   // `input_types`.
641   // Returns a failure() on error, otherwise returns true to indicate that it
642   // reached convergence, false otherwise.
643   FailureOr<bool> PropagateShapeToRegions(TypeRange input_types,
644                                           ArrayRef<Region*> regions,
645                                           int64_t max_iterations);
646 
647   // Shape propagation for call/control flow ops.
648   // Returns a failure() on error, otherwise returns true to indicate that it
649   // reached convergence, false otherwise.
650   FailureOr<bool> PropagateShapeIntoAttachedFunctions(Operation* op,
651                                                       int64_t max_iterations);
652 
653   // Shape propagation for region based control flow.
654   // Returns a failure() on error, otherwise returns true to indicate that it
655   // reached convergence, false otherwise.
656   FailureOr<bool> PropagateShapeIntoAttachedRegions(Operation* op,
657                                                     int64_t max_iterations);
658 
659   // Propagates any constant operand of call_op to the called function body's
660   // corresponding argument if the callee has only one use.
661   //
662   // TODO(b/154065712): Move this to a more general inter-procedural constant
663   // folding pass.
664   void PropagateConstantToCallee(CallOpInterface call_op, FuncOp func,
665                                  ModuleOp module);
666 
667   // Propagates any constant return value of the callee function to the call
668   // op's corresponding result.
669   void PropagateConstantFromCallee(CallOpInterface call_op, FuncOp func,
670                                    ModuleOp module);
671 
672   // Tries to compute the result of folding the op. This doesn't actually
673   // perform constant folding, it just computes the equivalent constants.
674   // Returns whether it was able to compute constant values.
675   LogicalResult TryToFold(Operation* op);
676 
677   // Makes result types match the operand types (the i-th result type will
678   // match the i-th operand type). Returns true if anything is changed.
679   bool RefineTypeForPassThroughOperands(Operation* op, OperandRange operands,
680                                         ResultRange results);
681 
682   // Makes result type's shape match the corresponding operand's shape.
683   // Returns whether any change was made.
684   bool RefineShapeForPassThroughOps(Operation* op);
685 
686   // Infers shape for necessary ops that are not in the TF dialect. Returns
687   // whether any result type changed.
688   bool InferShapeForNonTFDialectOperation(Operation* op);
689 
690   // Infers shape for function return type and returns whether changed.
691   LogicalResult InferShapeForFunctionReturnType(func::FuncOp func);
692 
693   // Enqueues function for processing.
694   void enqueue(func::FuncOp fn) {
695     LLVM_DEBUG(llvm::dbgs()
696                << "enqueue " << fn.getName() << " ("
697                << (queue_set_.count(fn) ? "already inserted" : "newly inserted")
698                << ")\n");
699     if (queue_set_.insert(fn).second) queue_.push(fn);
700   }
701 
702   // Enqueues the functions that call the given function.
703   void EnqueueCallers(func::FuncOp fn);
704 
705   // Returns the function at the front of the queue.
706   func::FuncOp front() { return queue_.front(); }
707 
708   // Returns whether work queue is empty.
709   bool EmptyQueue() const { return queue_.empty(); }
710 
711   // Removes and returns the function at the front of the work queue.
712   func::FuncOp pop_front() {
713     func::FuncOp ret = queue_.front();
714     queue_.pop();
715     queue_set_.erase(ret);
716     return ret;
717   }
718 
719   // Returns the current size of the queue.
720   std::queue<func::FuncOp>::size_type QueueSize() const {
721     return queue_.size();
722   }
723 
724   Dialect* const tf_dialect_;
725 
726  private:
727   // Returns whether the result of an operation could be updated to a new
728   // inferred type. Also inserts a cast operation for uses that are incompatible
729   // with the new type.
730   bool UpdateTypeAndInsertIncompatibleUseCasts(Type new_type, Value result);
731 
732   // Refines the type of `result` of `op` using the type
733   // `potential_refined_type`. Returns true if the type was changed.
734   bool RefineResultType(Operation* op, Value result,
735                         Type potential_refined_type);
736 
737   // Infers the shape from a (Stateful)PartitionedCall operation by looking up the
738   // called function and propagating the return type.
739   bool InferShapeForCall(CallOpInterface call_op);
740 
741   bool InferShapeForCast(Operation* op);
742 
743   bool InferShapeForRestore(Operation* op);
744 
745   // Infers the shape of IfOp outputs based on the shapes of the then and else
746   // function result types.
747   bool InferShapeForIf(IfOp op);
748 
749   // Infers the shape of IfRegion outputs based on the shapes of the then and else
750   // yields.
751   bool InferShapeForIfRegion(IfRegionOp op);
752 
753   // Infers the shape of _XlaHostComputeMlir based on the host computation
754   // module.  Returns true if a return type was changed.
755   bool InferShapeForXlaHostComputeMlir(_XlaHostComputeMlirOp op);
756 
757   // Infers the shape for MapDatasetOp and its associated function. Returns
758   // whether either op or function type was changed.
759   bool InferShapeForMapDataset(MapDatasetOp op, int64_t max_iterations);
760 
761   // Infers the shape for ReduceDatasetOp and its associated reduce function.
762   // Returns whether either op or function type was changed.
763   bool InferShapeForReduceDataset(ReduceDatasetOp op, int64_t max_iterations);
764 
765   // Infers the shape for TakeWhileDatasetOp and its associated predicate
766   // function. Returns whether either op or function type was changed.
767   bool InferShapeForTakeWhileDataset(TakeWhileDatasetOp op,
768                                      int64_t max_iterations);
769 
770   // Infers shape for dataset ops that have `M` input elements and `N`
771   // arguments, and also propagates the shape to the specified function (called
772   // only when the function exists and has a single use).
773   bool InferShapeForDatasetOpCommon(Operation* op, FuncOp f,
774                                     int64_t max_iterations);
775 
776   // Infers the shape of ops that create TensorList. Specifically,
777   // TensorListReserveOp, EmptyTensorListOp and TensorListFromTensor ops. It
778   // refines the element shape if all tensors written to the list across all
779   // mutations have identical static shape.
780   bool InferShapeForTensorListInitOps(Operation* op);
781 
782   // Infers the shape of VarHandleOp based on the uses of the VarHandleOp to
783   // update the subtypes of the resource type.
784   bool InferShapeForVarHandleOp(VarHandleOp op);
785 
786   // Infers the output shape of XlaConvV2Op based on the input shapes.
787   bool InferShapeForXlaConvV2Op(XlaConvV2Op op);
788 
789   // Infers the output shape of XlaReduceWindowOp based on the input shapes.
790   bool InferShapeForXlaReduceWindowOp(XlaReduceWindowOp op);
791 
792   // Infers the output shape of XlaSelectAndScatterOp based on the input shapes.
793   bool InferShapeForXlaSelectAndScatterOp(XlaSelectAndScatterOp op);
794 
795   bool RefineWithInferTypeOpInterface(InferTypeOpInterface infer_ti);
796 
797   // Returns all the callers of a function.
798   // Note: Usage of the return value of this function may not be interleaved
799   // with insertions to the callers map. This could occur if GetCallers is
800   // called with two separate functions, where the second call incurs a resize
801   // and then both the first and second stored callers are used.
802   ArrayRef<Operation*> GetCallers(func::FuncOp fn);
803 
804   // Mapping from a ValuePort (which corresponds to an OpResult or smaller,
805   // e.g., first element of OpResult produced) to an Attribute if the ValuePort
806   // corresponds to a constant value.
807   ValuePortResultMap results_;
808 
809   // Map from a function to the callers of that function.
810   SymbolTableCollection symbol_table_;
811   SymbolUserMap symbol_users_;
812 
813   // Queue of functions being processed.
814   llvm::DenseSet<func::FuncOp> queue_set_;
815   std::queue<func::FuncOp> queue_;
816 
817   int64_t graph_version_;
818 
819   // TODO(b/154065712): Remove propagate_caller_callee_constants once using
820   // SCCP pass instead.
821   bool propagate_caller_callee_constants_;
822 };
823 
824 ShapeInference::ShapeInference(int64_t graph_version, ModuleOp module,
825                                bool propagate_caller_callee_constants)
826     : tf_dialect_(module->getContext()->getLoadedDialect<TensorFlowDialect>()),
827       symbol_users_(symbol_table_, module),
828       graph_version_(graph_version),
829       propagate_caller_callee_constants_(propagate_caller_callee_constants) {
830   // Create symbol table for module.
831   symbol_table_.getSymbolTable(module);
832 }
833 
834 ArrayRef<Operation*> ShapeInference::GetCallers(func::FuncOp fn) {
835   return symbol_users_.getUsers(fn);
836 }
837 
838 void ShapeInference::EnqueueCallers(func::FuncOp fn) {
839   for (auto user : GetCallers(fn))
840     enqueue(user->getParentOfType<func::FuncOp>());
841 }
842 
843 bool ShapeInference::UpdateTypeAndInsertIncompatibleUseCasts(Type new_type,
844                                                              Value result) {
845   // No changes needed if the new type is unchanged.
846   if (new_type == result.getType()) return false;
847 
848   Operation* cast_op = nullptr;
849   // First insert cast back for uses that need a cast and then
850   // update the type.
851   bool enqueue_callers = false;
852   for (OpOperand& use : make_early_inc_range(result.getUses())) {
853     if (isa<func::ReturnOp>(use.getOwner())) {
854       enqueue_callers = true;
855     } else if (NeedsCastBack(use, tf_dialect_)) {
856       if (!cast_op) {
857         Operation* op = result.getDefiningOp();
858         OpBuilder b(op);
859         b.setInsertionPointAfter(op);
860         cast_op = InsertCast(b, op->getLoc(), result.getType(), result);
861         if (!cast_op) return false;
862       }
863       use.set(Value(cast_op->getResult(0)));
864     }
865   }
866 
867   result.setType(new_type);
868   if (enqueue_callers)
869     EnqueueCallers(result.getDefiningOp()->getParentOfType<func::FuncOp>());
870   return true;
871 }
872 
873 bool ShapeInference::RefineResultType(Operation* op, Value result,
874                                       Type potential_refined_type) {
875   if (!CanRefineTypeWith(result.getType(), potential_refined_type))
876     return false;
877 
878   return UpdateTypeAndInsertIncompatibleUseCasts(potential_refined_type,
879                                                  result);
880 }
881 
882 // Infers the shape from a (Stateful)PartitionedCall operation by looking up the
883 // called function and propagating the return type.
884 bool ShapeInference::InferShapeForCall(CallOpInterface call_op) {
885   func::FuncOp func =
886       dyn_cast_or_null<func::FuncOp>(call_op.resolveCallable(&symbol_table_));
887   if (!func) return false;
888 
889   DCOMMENT("Infer shape for call " << func.getName());
890   Operation* op = call_op.getOperation();
891   bool changed = false;
892   // Map each of the results of the call to the returned type of the
893   // function.
894   for (auto result :
895        zip(op->getResults(), func.getFunctionType().getResults())) {
896     changed = RefineResultType(op, std::get<0>(result), std::get<1>(result)) ||
897               changed;
898   }
899   DCOMMENT(" - call " << func.getName() << "changed ? " << changed << "\n");
900 
901   return changed;
902 }
903 
904 bool ShapeInference::InferShapeForCast(Operation* op) {
905   DCOMMENT_OP(op, "Inferring shape for ");
906   Value result = op->getResult(0);
907   if (!CanBeRefined(result.getType())) return false;
908 
909   Type operand_type = op->getOperand(0).getType();
910   auto ranked_op_type = operand_type.dyn_cast<RankedTensorType>();
911   if (!ranked_op_type) return false;
912   auto ranked_res_type = result.getType().dyn_cast<RankedTensorType>();
913   if (ranked_res_type &&
914       ranked_op_type.getShape() == ranked_res_type.getShape())
915     return false;
916 
917   // Avoid inserting a cast where no user's type could be refined (i.e., where
918   // a cast would need to be inserted for every user again).
919   if (llvm::all_of(result.getUses(), [this](OpOperand& use) {
920         return NeedsCastBack(use, tf_dialect_);
921       }))
922     return false;
923 
924   auto new_type = RankedTensorType::get(
925       ranked_op_type.getShape(),
926       result.getType().cast<ShapedType>().getElementType());
927 
928   return UpdateTypeAndInsertIncompatibleUseCasts(new_type, op->getResult(0));
929 }
930 
931 bool ShapeInference::InferShapeForIf(IfOp op) {
932   DCOMMENT_OP(op.getOperation(), "Infer shape for if ");
933   bool changed = false;
934   auto then_results =
935       op.ResolveThenFunction(&symbol_table_).getFunctionType().getResults();
936   auto else_results =
937       op.ResolveElseFunction(&symbol_table_).getFunctionType().getResults();
938   for (auto it : llvm::zip(op.getResults(), then_results, else_results)) {
939     // If then and else types do not match, skip refinement for that result.
940     if (std::get<1>(it) != std::get<2>(it)) continue;
941     changed = RefineResultType(op, std::get<0>(it), std::get<1>(it)) || changed;
942   }
943   return changed;
944 }
945 
946 bool ShapeInference::InferShapeForIfRegion(IfRegionOp op) {
947   bool changed = false;
948 
949   Operation* then_yield = op.then_branch().front().getTerminator();
950   Operation* else_yield = op.else_branch().front().getTerminator();
951   for (auto result : zip(op.getResults(), then_yield->getOperandTypes(),
952                          else_yield->getOperandTypes())) {
953     // If then and else types do not match, skip refinement for that result.
954     if (std::get<1>(result) != std::get<2>(result)) continue;
955     changed = RefineResultType(op, std::get<0>(result), std::get<1>(result)) ||
956               changed;
957   }
958   return changed;
959 }
960 
961 bool ShapeInference::InferShapeForXlaHostComputeMlir(
962     _XlaHostComputeMlirOp host_compute_op) {
963   // Extract the module and function.
964   // The `_XlaHostComputeMlir` verifier verifies that the `host_mlir_module`
965   // attribute is well formed, so we just return in case of an error in
966   // extracting the host function, since that should never occur.
967   StringAttr host_module =
968       host_compute_op->getAttrOfType<StringAttr>("host_mlir_module");
969   if (host_module.getValue().empty()) return false;
970 
971   mlir::OwningOpRef<mlir::ModuleOp> module_for_func;
972   func::FuncOp func = host_compute_op.GetHostFunc(&module_for_func);
973 
974   // Update/use input shapes for function.
975   FunctionType func_type = func.getFunctionType();
976   func.setType(FunctionType::get(func.getContext(),
977                                  host_compute_op.getOperandTypes(),
978                                  func_type.getResults()));
979 
980   // Run shape inference on the function.
981   if (failed(PropagateShapeToRegions(host_compute_op.getOperandTypes(),
982                                      {&func.getBody()}, 10)))
983     return false;
984   if (failed(InferShapeForFunctionReturnType(func))) return false;
985 
986   bool changed = false;
987   // Use refined function return shape for XlaHostComputeMlirOp.
988   for (auto result :
989        zip(host_compute_op.getResults(), func.getFunctionType().getResults())) {
990     changed = RefineResultType(host_compute_op, std::get<0>(result),
991                                std::get<1>(result)) ||
992               changed;
993   }
994 
995   return changed;
996 }
997 
998 // Infers the shape of `Restore` and `RestoreV2` ops based on the first
999 // `AssignVariableOp` that uses the result. This requires that resource
1000 // subtype inference has completed.
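// For example, for
//   %0 = "tf.RestoreV2"(...)
//   %1 = "tf.Identity"(%0)
//   %2 = "tf.Cast"(%1)
//   "tf.AssignVariableOp"(%var, %2)
// the shape is taken from the resource subtype of %var, while the element type
// of %0's original result is preserved.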
1001 bool ShapeInference::InferShapeForRestore(Operation* op) {
1002   DCOMMENT_OP(op, "Inferring shape for Restore,RestoreV2");
1003   // Currently only support single output.
1004   if (op->getNumResults() != 1) return false;
1005   if (!CanBeRefined(op->getResult(0).getType())) return false;
1006 
1007   llvm::SmallVector<mlir::Operation*> worklist;
1008   llvm::append_range(worklist, op->getUsers());
1009 
1010   // Look for any `AssignVariableOp` that uses the result of this op.
1011   while (!worklist.empty()) {
1012     mlir::Operation* const use = worklist.pop_back_val();
1013 
1014     // Follow the `CastOp`/`IdentityOp`'s users to handle the `RestoreV2` ->
1015     // (optionally `IdentityOp`) -> `CastOp` -> `AssignVariableOp` case.
1016     if (llvm::isa<TF::CastOp, TF::IdentityOp>(use)) {
1017       llvm::append_range(worklist, use->getUsers());
1018       continue;
1019     }
1020 
1021     TF::AssignVariableOp assign_op = llvm::dyn_cast<TF::AssignVariableOp>(use);
1022     if (!assign_op) {
1023       continue;
1024     }
1025     auto subtypes = getElementTypeOrSelf(assign_op.resource())
1026                         .cast<TF::ResourceType>()
1027                         .getSubtypes();
1028     if (subtypes.empty()) {
1029       continue;
1030     }
1031     auto subtype = subtypes.front().dyn_cast<ShapedType>();
1032     if (subtype == nullptr) {
1033       continue;
1034     }
1035     // Preserve the dtype from the restore op even if `AssignVariableOp` uses a
1036     // different dtype, which is possible when there's a `CastOp` between them.
1037     subtype = subtype.clone(
1038         op->getResult(0).getType().cast<ShapedType>().getElementType());
1039     // Update the result type of this op with the resource's type. We only use
1040     // the resource subtype of the first user since shapes from all the users
1041     // should be equal or compatible.
1042     return UpdateTypeAndInsertIncompatibleUseCasts(subtype, op->getResult(0));
1043   }
1044   return false;
1045 }
1046 
1047 // Helper structure to capture shapes & types for Dataset input.
1048 struct DatasetInput {
1049   explicit operator bool() const { return shapes && types; }
1050 
1051   ArrayAttr shapes;
1052   ArrayAttr types;
1053 };
1054 
1055 // Returns the input elements shapes and types for Dataset ops.
1056 DatasetInput GetDatasetInput(Value value) {
1057   // TODO(haoliang): add an interface for DatasetOp to avoid the following
1058   // enumeration.
1059   // Iteratively trace upwards while the defining op is `IdentityOp` or `IdentityNOp`.
1060   while (
1061       llvm::isa_and_nonnull<IdentityOp, IdentityNOp>(value.getDefiningOp())) {
1062     value = value.getDefiningOp()->getOperand(
1063         value.cast<OpResult>().getResultNumber());
1064   }
1065 
1066   Operation* op = value.getDefiningOp();
1067   if (!llvm::isa_and_nonnull<BatchDatasetV2Op, MapDatasetOp, RepeatDatasetOp,
1068                              ParallelMapDatasetOp, ParallelMapDatasetV2Op,
1069                              TakeDatasetOp, TakeWhileDatasetOp>(op))
1070     return DatasetInput{nullptr, nullptr};
1071 
1072   return DatasetInput{op->getAttrOfType<ArrayAttr>("output_shapes"),
1073                       op->getAttrOfType<ArrayAttr>("output_types")};
1074 }
1075 
1076 bool ShapeInference::InferShapeForDatasetOpCommon(Operation* op, FuncOp f,
1077                                                   int64_t max_iterations) {
1078   int N = op->getNumOperands() - 1;
1079   int M = f.getNumArguments() - N;
1080   DCOMMENT_OP(op, "Inferring shape for with N = " << N << " and M = " << M);
1081 
1082   // Initialize with function input types.
1083   auto input_types = llvm::to_vector<1>(
1084       cast<FunctionOpInterface>(f.getOperation()).getArgumentTypes());
1085 
1086   DatasetInput input_elements = GetDatasetInput(op->getOperand(0));
1087   if (!input_elements) {
1088     op->emitWarning("unexpected dataset input; skipping function refinement");
1089     return false;
1090   }
1091 
1092   // Track if changed to skip enqueueing.
1093   bool changed = false;
1094   auto it = input_types.begin();
1095   // First set first M argument shapes & types.
1096   for (int i = 0; i < M; ++i) {
1097     Type t = GetType(input_elements.shapes[i], input_elements.types[i]);
1098     t = TypeMeet(*it, t);
1099     changed = changed || (t != *it);
1100     *it++ = t;
1101   }
1102   // Now the remaining N from operand types.
1103   for (auto t : llvm::drop_begin(op->getOperandTypes())) {
1104     auto meet = TypeMeet(*it, t);
1105     changed = changed || (meet != *it);
1106     *it++ = meet;
1107   }
1108   if (!changed) return false;
1109 
1110   FailureOr<bool> res = PropagateShapeToFunctions(
1111       op->getParentOfType<ModuleOp>(), input_types, {f}, max_iterations);
1112   if (failed(res)) {
1113     op->emitOpError("propagating shapes failed");
1114     return false;
1115   }
1116   return *res;
1117 }
1118 
1119 bool ShapeInference::InferShapeForMapDataset(MapDatasetOp op,
1120                                              int64_t max_iterations) {
1121   // MapDatasetOp's relationship with its associated function is as
1122   // follows: the first M function params are dictated by the set
1123   // output shapes and types, and the next N are the last N inputs of the
1124   // MapDataset op. The MapDataset op always has N+1 inputs.
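  // For example, for a dataset whose elements have two components and a
  // MapDataset op with one extra input (N = 1), the associated function takes
  // M + N = 3 arguments.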
1125   // TODO(jpienaar): Avoid this lookup.
1126   auto module = op->getParentOfType<ModuleOp>();
1127   auto f = module.lookupSymbol<func::FuncOp>(op.f());
1128   // Skip if function is not found or more than one caller.
1129   if (!f || !llvm::hasSingleElement(GetCallers(f))) return false;
1130   return InferShapeForDatasetOpCommon(op, f, max_iterations);
1131 }
1132 
1133 bool ShapeInference::InferShapeForTakeWhileDataset(TakeWhileDatasetOp op,
1134                                                    int64_t max_iterations) {
1135   // TakeWhileDatasetOp's relationship with its associated function is as
1136   // follows: the first M function params are dictated by the set
1137   // output shapes and types, and the next N are the last N inputs of the
1138   // TakeWhileDataset op. The TakeWhileDataset op always has N+1 inputs.
1139   // TODO(jpienaar): Avoid this lookup.
1140   auto module = op->getParentOfType<ModuleOp>();
1141   auto f = module.lookupSymbol<func::FuncOp>(op.predicate());
1142   // Skip if function is not found or more than one caller.
1143   if (!f || !llvm::hasSingleElement(GetCallers(f))) return false;
1144   return InferShapeForDatasetOpCommon(op, f, max_iterations);
1145 }
1146 
1147 bool ShapeInference::InferShapeForReduceDataset(ReduceDatasetOp op,
1148                                                 int64_t max_iterations) {
1149   // ReduceDatasetOp's relationship with its associated reduce function is
1150   // described as follows: The reduce function will in general have (X + Y + Z)
1151   // arguments, where X is the number of tensor components that represent the
1152   // state, Y is the number of tensor components that represent the input
1153   // elements, and Z is the number of tensor components that represent any
1154   // captured arguments. Y is determined by the output_shapes of an op that
1155   // defines the first operand of `op`.
1156 
1157   // TODO(jpienaar): Avoid this lookup.
1158   auto module = op->getParentOfType<ModuleOp>();
1159   auto f = module.lookupSymbol<func::FuncOp>(op.f());
1160 
1161   // Skip if function is not found or it has more than one caller.
1162   if (!f || !llvm::hasSingleElement(GetCallers(f))) return false;
1163 
1164   DatasetInput input_elements = GetDatasetInput(op.input_dataset());
1165 
1166   const int num_states = op.output_shapes().size();
1167   const int num_captured_arguments = op.getNumOperands() - 1 - num_states;
1168 
1169   // If input_elements is undefined, we can still infer the shapes for the
1170   // states and captured arguments.
1171   int num_input_elements;
1172   auto input_types = llvm::to_vector<1>(
1173       cast<FunctionOpInterface>(f.getOperation()).getArgumentTypes());
1174   if (input_elements) {
1175     num_input_elements = input_elements.shapes.size();
1176   } else {
1177     num_input_elements =
1178         input_types.size() - num_states - num_captured_arguments;
1179   }
1180 
1181   DCOMMENT_OP(op,
1182               "Inferring shape for ReduceDataset with #states = "
1183                   << num_states << " , #input_elements = " << num_input_elements
1184                   << " , and #captured_arguments = " << num_captured_arguments);
1185   if (num_states + num_input_elements + num_captured_arguments !=
1186       f.getNumArguments()) {
1187     op->emitOpError(
1188         "propagating shapes for ReduceDataset failed due to inconsistent "
1189         "number of arguments");
1190     return false;
1191   }
1192 
1193   // Track if changed to skip enqueueing.
1194   bool changed = false;
1195   auto it = input_types.begin();
1196 
1197   // Set the first num_states arguments shapes & types from the state.
1198   for (int i = 0; i < num_states; ++i) {
1199     Type t = GetType(op.output_shapes()[i], op.output_types()[i]);
1200     t = TypeMeet(*it, t);
1201     changed = changed || (t != *it);
1202     *it++ = t;
1203   }
1204 
1205   // Second, set the following num_input_elements arguments from the
1206   // input dataset op. Skip propagating shapes if input_elements is
1207   // undefined.
1208   for (int i = 0; i < num_input_elements; ++i) {
1209     if (input_elements) {
1210       Type t = GetType(input_elements.shapes[i], input_elements.types[i]);
1211       t = TypeMeet(*it, t);
1212       changed = changed || (t != *it);
1213       *it++ = t;
1214     } else {
1215       it++;
1216     }
1217   }
1218 
1219   // Last, set the remaining num_captured_arguments from the op.
1220   for (auto t : llvm::drop_begin(op.getOperandTypes(), 1 + num_states)) {
1221     auto meet = TypeMeet(*it, t);
1222     changed = changed || (meet != *it);
1223     *it++ = meet;
1224   }
1225 
1226   if (!changed) return false;
1227 
1228   FailureOr<bool> res =
1229       PropagateShapeToFunctions(module, input_types, {f}, max_iterations);
1230   if (failed(res)) {
1231     op->emitOpError("Propagating shapes for ReduceDataset failed");
1232     return false;
1233   }
1234   return *res;
1235 }
1236 
1237 bool ShapeInference::InferShapeForTensorListInitOps(Operation* op) {
1238   DCOMMENT_OP(op, "Inferring shape for TensorList ");
1239   Value handle = op->getResult(0);
1240   Value initial_element_shape = GetElementShapeOperand(op);
1241   RankedTensorType element_type;
1242   if (auto tl_from_tensor = dyn_cast<TensorListFromTensorOp>(op)) {
1243     // For TensorListFromTensor op we can infer element shape by dropping the
1244     // first dimension of input tensor.
1245     element_type = DropFirstDimension(tl_from_tensor.tensor().getType());
1246     if (!element_type || !element_type.hasStaticShape()) return false;
1247   }
1248   if (!CanInferTensorListElementType(handle, initial_element_shape,
1249                                      &element_type)) {
1250     DCOMMENT("InferShapeForListInitOps " << *op << " could not infer");
1251     return false;
1252   }
1253   DCOMMENT("InferShapeForListInitOps " << *op << " could be inferred "
1254                                        << element_type);
1255   if (!element_type || !element_type.hasStaticShape()) return false;
1256   auto variant_type = VariantType::get(element_type, op->getContext());
1257   auto tensor_type = RankedTensorType::get({}, variant_type);
1258   bool changed = RefineResultType(op, handle, tensor_type);
1259   if (changed) DCOMMENT_OP(op, "Modified after shape inference:");
1260   return changed;
1261 }
1262 
1263 bool ShapeInference::InferShapeForVarHandleOp(VarHandleOp op) {
1264   DCOMMENT_OP(op, "Inferring shape for VarHandleOp");
1265 
1266   Value resource = op.resource();
1267   if (!CanBeRefined(resource.getType())) return false;
1268 
1269   // Make sure the resource is only used by `AssignVariableOp` and
1270   // `ReadVariableOp`; conservatively skip inference for any other use.
1271   for (auto& use : make_early_inc_range(resource.getUses())) {
1272     Operation* def = use.getOwner();
1273     if (!llvm::isa<AssignVariableOp>(def) && !llvm::isa<ReadVariableOp>(def)) {
1274       return false;
1275     }
1276   }
1277 
1278   bool changed = false;
1279 
1280   // Look for any `AssignVariableOp` and `ReadVariableOp` that uses the value of
1281   // this op.
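  // For example (illustrative): if a ReadVariableOp user yields tensor<8xf32>,
  // the resource below is refined towards
  // tensor<*x!tf_type.resource<tensor<8xf32>>>, i.e. an unranked resource
  // tensor carrying the subtype observed at that use.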
1282   for (auto& use : make_early_inc_range(resource.getUses())) {
1283     Operation* def = use.getOwner();
1284     Value value;
1285     if (AssignVariableOp assign_op = dyn_cast<AssignVariableOp>(def)) {
1286       value = assign_op.value();
1287     } else if (ReadVariableOp read_op = dyn_cast<ReadVariableOp>(def)) {
1288       value = read_op.value();
1289     } else {
1290       llvm_unreachable("unexpected operator type");
1291     }
1292 
1293     TensorType resource_subtype = value.getType().cast<TensorType>();
1294     ResourceType resource_type =
1295         ResourceType::get({resource_subtype}, op.getContext());
1296     UnrankedTensorType new_resource_type =
1297         UnrankedTensorType::get(resource_type);
1298 
1299     Type refined_type = TypeMeet(resource.getType(), new_resource_type);
1300     if (refined_type == resource.getType()) continue;
1301     resource.setType(refined_type);
1302     changed = true;
1303   }
1304 
1305   return changed;
1306 }
1307 
1308 // Helper function for creating a Window proto from user-supplied data.
1309 // Returns llvm::None if the user-supplied data was invalid.
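// A minimal illustrative example (hypothetical values): window_dimensions =
// {2, 2} with empty strides/padding/dilations yields a 2-D window whose
// per-dimension stride and dilations default to 1 and whose padding defaults
// to 0, per the defaulting logic below.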
1310 llvm::Optional<xla::Window> InferWindowFromDimensions(
1311     llvm::SmallVector<int64_t> window_dimensions,
1312     llvm::SmallVector<int64_t> window_strides,
1313     llvm::SmallVector<std::pair<int64_t, int64_t>> padding,
1314     llvm::SmallVector<int64_t> lhs_dilation,
1315     llvm::SmallVector<int64_t> rhs_dilation) {
1316   const auto verify_size = [&](const size_t x, const char* x_name) {
1317     if (x == 0 || x == window_dimensions.size()) {
1318       return true;
1319     } else {
1320       llvm::errs()
1321           << "Window has different number of window dimensions than of "
1322           << x_name
1323           << "\nNumber of window dimensions: " << window_dimensions.size()
1324           << "\nNumber of " << x_name << ": " << x << "\n";
1325       return false;
1326     }
1327   };
1328 
1329   if (!(verify_size(window_dimensions.size(), "window_dimensions") &&
1330         verify_size(window_strides.size(), "window strides") &&
1331         verify_size(padding.size(), "padding entries") &&
1332         verify_size(lhs_dilation.size(), "lhs dilation factors") &&
1333         verify_size(rhs_dilation.size(), "rhs dilation factors")))
1334     return llvm::None;
1335 
1336   xla::Window window;
1337   for (size_t i = 0; i < window_dimensions.size(); i++) {
1338     auto dim = window.add_dimensions();
1339     dim->set_size(window_dimensions[i]);
1340     if (!window_strides.empty()) {
1341       dim->set_stride(window_strides[i]);
1342     } else {
1343       dim->set_stride(1);
1344     }
1345     if (!padding.empty()) {
1346       dim->set_padding_low(padding[i].first);
1347       dim->set_padding_high(padding[i].second);
1348     } else {
1349       dim->set_padding_low(0);
1350       dim->set_padding_high(0);
1351     }
1352     if (!lhs_dilation.empty()) {
1353       dim->set_base_dilation(lhs_dilation[i]);
1354     } else {
1355       dim->set_base_dilation(1);
1356     }
1357     if (!rhs_dilation.empty()) {
1358       dim->set_window_dilation(rhs_dilation[i]);
1359     } else {
1360       dim->set_window_dilation(1);
1361     }
1362     dim->set_window_reversal(false);
1363   }
1364   return window;
1365 }
1366 
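// Hedged, illustrative example for the computation below: a static base
// dimension of 5 with window size 2, stride 2, no padding, and unit dilations
// should produce StridedBound(5, 2, 2) = (5 - 2) / 2 + 1 = 2 output elements,
// while dynamic base dimensions remain dynamic.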
1367 llvm::Optional<RankedTensorType> InferWindowOutputShape(
1368     const ShapedType& base_shape, const xla::Window& window,
1369     Type element_type) {
1370   if (window.dimensions_size() != base_shape.getRank()) {
1371     llvm::errs() << "Window has dimension " << window.dimensions_size()
1372                  << " but base shape has dimension " << base_shape.getRank()
1373                  << "\n";
1374     return llvm::None;
1375   }
1376 
1377   std::vector<int64_t> output_dimensions(window.dimensions_size());
1378   std::vector<bool> output_is_dynamic(window.dimensions_size());
1379   for (int64_t i = 0; i < window.dimensions_size(); ++i) {
1380     const auto& dim = window.dimensions(i);
1381     if (dim.size() <= 0) {
1382       llvm::errs() << "Window " << window.DebugString()
1383                    << " has a non-positive dimension.\n";
1384       return llvm::None;
1385     }
1386     if (dim.stride() <= 0) {
1387       llvm::errs() << "Window " << window.DebugString()
1388                    << " has a non-positive stride.\n";
1389       return llvm::None;
1390     }
1391     if (dim.base_dilation() < 1) {
1392       llvm::errs() << "Window " << window.DebugString()
1393                    << " has a non-positive base area dilation factor.\n";
1394       return llvm::None;
1395     }
1396     if (dim.window_dilation() < 1) {
1397       llvm::errs() << "Window " << window.DebugString()
1398                    << " has a non-positive window dilation factor.\n";
1399       return llvm::None;
1400     }
1401 
1402     if (base_shape.isDynamicDim(i)) {
1403       output_dimensions[i] = ShapedType::kDynamicSize;
1404     } else {
1405       const int64_t dilated_base = xla::window_util::DilatedBound(
1406           base_shape.getDimSize(i), dim.base_dilation());
1407       const int64_t padded_dilated_base =
1408           dim.padding_low() + dilated_base + dim.padding_high();
1409       const int64_t dilated_window =
1410           xla::window_util::DilatedBound(dim.size(), dim.window_dilation());
1411 
1412       output_dimensions[i] = xla::window_util::StridedBound(
1413           padded_dilated_base, dilated_window, dim.stride());
1414     }
1415   }
1416 
1417   return RankedTensorType::get(output_dimensions, element_type);
1418 }
1419 
1420 bool ShapeInference::InferShapeForXlaReduceWindowOp(XlaReduceWindowOp op) {
1421   DCOMMENT_OP(op, "Inferring shape for XlaReduceWindowOp");
1422 
1423   bool changed = false;
1424 
1425   auto input_ty = op.input().getType().cast<ShapedType>();
1426   DenseElementsAttr window_dimensions, window_strides, base_dilations,
1427       window_dilations, padding;
1428   if (input_ty.hasStaticShape() &&
1429       matchPattern(op.window_dimensions(), m_Constant(&window_dimensions)) &&
1430       matchPattern(op.window_strides(), m_Constant(&window_strides)) &&
1431       matchPattern(op.base_dilations(), m_Constant(&base_dilations)) &&
1432       matchPattern(op.window_dilations(), m_Constant(&window_dilations)) &&
1433       matchPattern(op.padding(), m_Constant(&padding))) {
1434     llvm::SmallVector<int64_t> window_dimensions_vec, window_strides_vec,
1435         base_dilations_vec, window_dilations_vec;
1436     llvm::SmallVector<std::pair<int64_t, int64_t>> padding_pairs(
1437         padding.getNumElements() / 2);
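    // Note (illustrative): `padding` is treated as a flattened list of
    // (low, high) pairs, so element 2*i is the low padding and element 2*i+1
    // the high padding of dimension i, matching the padding loop below.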
1438 
1439     for (auto i = 0; i < window_dimensions.size(); ++i) {
1440       window_dimensions_vec.push_back(
1441           window_dimensions.getValues<IntegerAttr>()[i].getInt());
1442     }
1443 
1444     for (auto i = 0; i < window_strides.size(); ++i) {
1445       window_strides_vec.push_back(
1446           window_strides.getValues<IntegerAttr>()[i].getInt());
1447     }
1448 
1449     for (auto i = 0; i < base_dilations.size(); ++i) {
1450       base_dilations_vec.push_back(
1451           base_dilations.getValues<IntegerAttr>()[i].getInt());
1452     }
1453 
1454     for (auto i = 0; i < window_dilations.size(); ++i) {
1455       window_dilations_vec.push_back(
1456           window_dilations.getValues<IntegerAttr>()[i].getInt());
1457     }
1458 
1459     for (auto i = 0; i < padding_pairs.size(); ++i) {
1460       padding_pairs[i] = {padding.getValues<IntegerAttr>()[i * 2].getInt(),
1461                           padding.getValues<IntegerAttr>()[i * 2 + 1].getInt()};
1462     }
1463 
1464     auto window = InferWindowFromDimensions(
1465         window_dimensions_vec, window_strides_vec, padding_pairs,
1466         base_dilations_vec, window_dilations_vec);
1467     if (!window) {
1468       op->emitOpError("failed to create window");
           return false;
1469     }
1470     auto output_shape = InferWindowOutputShape(
1471         input_ty, window.getValue(),
1472         op.init_value().getType().cast<ShapedType>().getElementType());
1473 
1474     if (!output_shape) {
1475       op->emitOpError("failed to infer output shape");
           return false;
1476     }
1477 
1478     changed = RefineResultType(op.getOperation(), op.getResult(),
1479                                output_shape.getValue());
1480   }
1481 
1482   return changed;
1483 }
1484 
1485 bool ShapeInference::InferShapeForXlaSelectAndScatterOp(
1486     XlaSelectAndScatterOp op) {
1487   DCOMMENT_OP(op, "Inferring shape for XlaSelectAndScatterOp");
1488 
1489   auto operand_shape = op.operand().getType().cast<ShapedType>();
1490   auto source_shape = op.source().getType().cast<ShapedType>();
1491   DenseElementsAttr window_dimensions, window_strides, padding;
1492   if (operand_shape.hasRank() && source_shape.hasRank() &&
1493       matchPattern(op.window_dimensions(), m_Constant(&window_dimensions)) &&
1494       matchPattern(op.window_strides(), m_Constant(&window_strides)) &&
1495       matchPattern(op.padding(), m_Constant(&padding))) {
1496     llvm::SmallVector<int64_t> window_dimensions_vec, window_strides_vec,
1497         base_dilations_vec, window_dilations_vec;
1498     llvm::SmallVector<std::pair<int64_t, int64_t>> padding_pairs(
1499         padding.getNumElements() / 2);
1500 
1501     for (auto i = 0; i < window_dimensions.size(); ++i) {
1502       window_dimensions_vec.push_back(
1503           window_dimensions.getValues<IntegerAttr>()[i].getInt());
1504     }
1505 
1506     for (auto i = 0; i < window_strides.size(); ++i) {
1507       window_strides_vec.push_back(
1508           window_strides.getValues<IntegerAttr>()[i].getInt());
1509     }
1510 
1511     for (auto i = 0; i < padding_pairs.size(); ++i) {
1512       padding_pairs[i] = {padding.getValues<IntegerAttr>()[i * 2].getInt(),
1513                           padding.getValues<IntegerAttr>()[i * 2 + 1].getInt()};
1514     }
1515 
1516     auto window = InferWindowFromDimensions(
1517         window_dimensions_vec, window_strides_vec, padding_pairs, {}, {});
1518     if (!window) {
1519       op->emitOpError("failed to create window");
           return false;
1520     }
1521     auto window_result_shape = InferWindowOutputShape(
1522         operand_shape, window.getValue(), operand_shape.getElementType());
1523 
1524     if (!window_result_shape) {
1525       op->emitOpError("failed to infer window result shape");
           return false;
1526     }
1527 
1528     if (window_result_shape.getValue() != source_shape) {
1529       op->emitOpError(
1530           "Source shape does not match the shape of window-reduced operand.");
1531     }
1532   }
1533 
1534   return RefineResultType(op.getOperation(), op.getResult(),
1535                           op.operand().getType());
1536 }
1537 
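// Hedged, illustrative example (hypothetical NHWC input / HWIO kernel
// dimension numbers): an input of 1x5x5x3, a kernel of 2x2x3x8, unit strides
// and dilations, no padding, and batch_group_count = 1 would give an output
// batch of 1 / 1 = 1, an output feature of 8, and spatial dimensions of
// (5 - 2) / 1 + 1 = 4, i.e. a 1x4x4x8 result.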
1538 llvm::Optional<RankedTensorType> InferXlaConvOutputShape(
1539     llvm::SmallVector<int64_t> input_tensor_dims,
1540     llvm::SmallVector<int64_t> kernel_tensor_dims,
1541     llvm::SmallVector<int64_t> window_strides,
1542     llvm::SmallVector<std::pair<int64_t, int64_t>> paddings,
1543     llvm::SmallVector<int64_t> lhs_dilations,
1544     llvm::SmallVector<int64_t> rhs_dilations, int64_t batch_group_count,
1545     xla::ConvolutionDimensionNumbers dnums, Type element_type) {
1546   auto num_spatial_dims = input_tensor_dims.size() - 2;
1547   std::vector<int64_t> output_dims(input_tensor_dims.size());
1548 
1549   auto input_batch = input_tensor_dims[dnums.input_batch_dimension()];
1550   auto kernel_output_feature =
1551       kernel_tensor_dims[dnums.kernel_output_feature_dimension()];
1552   output_dims[dnums.output_batch_dimension()] = input_batch / batch_group_count;
1553   DCOMMENT("inferred output batch dimension is "
1554            << output_dims[dnums.output_batch_dimension()]);
1555   output_dims[dnums.output_feature_dimension()] = kernel_output_feature;
1556   DCOMMENT("inferred output feature dimension is "
1557            << output_dims[dnums.output_feature_dimension()]);
1558 
1559   std::vector<int64_t> input_spatial_dims;
1560   llvm::SmallVector<int64_t> window_spatial_dims;
1561   for (auto i = 0; i < num_spatial_dims; ++i) {
1562     input_spatial_dims.push_back(
1563         input_tensor_dims[dnums.input_spatial_dimensions(i)]);
1564     window_spatial_dims.push_back(
1565         kernel_tensor_dims[dnums.kernel_spatial_dimensions(i)]);
1566   }
1567 
1568   ShapedType base_shape =
1569       RankedTensorType::get(input_spatial_dims, element_type);
1570 
1571   auto window =
1572       InferWindowFromDimensions(window_spatial_dims, window_strides, paddings,
1573                                 lhs_dilations, rhs_dilations);
1574 
1575   auto output_shape =
1576       InferWindowOutputShape(base_shape, window.getValue(), element_type);
1577 
1578   for (auto i = 0; i < num_spatial_dims; ++i) {
1579     output_dims[dnums.output_spatial_dimensions(i)] =
1580         output_shape.getValue().getShape()[i];
1581     DCOMMENT("inferred output spatial dimension "
1582              << i << " at dimension number "
1583              << dnums.output_spatial_dimensions(i) << " is "
1584              << output_dims[dnums.output_spatial_dimensions(i)]);
1585   }
1586   return RankedTensorType::get(output_dims, element_type);
1587 }
1588 
1589 // TODO(hanxiongwang): The logic in this function needs to move to the Op
1590 // verify method once the dependency issue of adding the header file
1591 // "third_party/tensorflow/compiler/xla/xla_data.pb.h" into
1592 // "third_party/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc" is
1593 // resolved
1594 LogicalResult PrecheckForXlaConvV2Op(XlaConvV2Op op) {
1595   auto input_tensor = op.lhs();
1596   auto kernel_tensor = op.rhs();
1597   auto window_strides = op.window_strides();
1598   auto padding = op.padding();
1599   auto lhs_dilation = op.lhs_dilation();
1600   auto rhs_dilation = op.rhs_dilation();
1601   auto feature_group_count = op.feature_group_count();
1602   int64_t batch_group_count = op.batch_group_count();
1603 
1604   auto input_args_have_static_shape = [&]() -> bool {
1605     return input_tensor.getType().cast<TensorType>().hasStaticShape() &&
1606            kernel_tensor.getType().cast<TensorType>().hasStaticShape() &&
1607            window_strides.getType().cast<TensorType>().hasStaticShape() &&
1608            padding.getType().cast<TensorType>().hasStaticShape() &&
1609            lhs_dilation.getType().cast<TensorType>().hasStaticShape() &&
1610            rhs_dilation.getType().cast<TensorType>().hasStaticShape() &&
1611            feature_group_count.getType().cast<TensorType>().hasStaticShape();
1612   };
1613 
1614   // Return failure when one of the input args does not have a static shape.
1615   if (!input_args_have_static_shape()) {
1616     return failure();
1617   }
1618 
1619   auto input_tensor_shape =
1620       input_tensor.getType().cast<RankedTensorType>().getShape();
1621   auto kernel_tensor_shape =
1622       kernel_tensor.getType().cast<RankedTensorType>().getShape();
1623 
1624   if (input_tensor_shape.size() <= 2) {
1625     return op.emitOpError()
1626            << "input tensor argument is " << input_tensor_shape.size()
1627            << " which is invalid, since the input tensor argument must have a "
1628            << "rank greater than 2.\n";
1629   }
1630 
1631   if (kernel_tensor_shape.size() <= 2) {
1632     return op.emitOpError()
1633            << "kernel tensor argument is " << kernel_tensor_shape.size()
1634            << " which is invalid, since the kernel tensor argument must have a "
1635            << "rank greater than 2.\n";
1636   }
1637 
1638   if (input_tensor_shape.size() != kernel_tensor_shape.size()) {
1639     return op.emitOpError() << "both input tensor and kernel tensor must "
1640                             << "have the same number of dimensions.\n";
1641   }
1642 
1643   DenseElementsAttr feature_group_count_attr;
1644   xla::ConvolutionDimensionNumbers dnums;
1645   dnums.ParseFromString(op.dimension_numbersAttr().getValue().str());
1646   if (dnums.input_spatial_dimensions_size() !=
1647       dnums.kernel_spatial_dimensions_size()) {
1648     return op.emitOpError() << "Both arguments to convolution must have "
1649                             << "the same number of dimensions.\n";
1650   }
1651 
1652   if (dnums.input_spatial_dimensions_size() !=
1653       dnums.output_spatial_dimensions_size()) {
1654     return op.emitOpError() << "Both input and output of convolution must have "
1655                             << "the same number of dimensions.\n";
1656   }
1657   if (!matchPattern(feature_group_count,
1658                     m_Constant(&feature_group_count_attr))) {
1659     return success();
1660   }
1661 
1662   auto feature_group_count_val =
1663       feature_group_count_attr.getValues<IntegerAttr>()[0].getInt();
1664   auto input_features = input_tensor_shape[dnums.input_feature_dimension()];
1665   auto input_batch = input_tensor_shape[dnums.input_batch_dimension()];
1666   auto kernel_input_features =
1667       kernel_tensor_shape[dnums.kernel_input_feature_dimension()];
1668   auto kernel_output_features =
1669       kernel_tensor_shape[dnums.kernel_output_feature_dimension()];
1670 
1671   if (feature_group_count_val <= 0) {
1672     return op.emitOpError()
1673            << "feature_group_count must be a positive number, got "
1674            << feature_group_count_val;
1675   }
1676 
1677   if (batch_group_count <= 0) {
1678     return op.emitOpError()
1679            << "batch_group_count must be a positive number, got "
1680            << batch_group_count;
1681   }
1682   if (batch_group_count > 1 && feature_group_count_val > 1) {
1683     return op.emitOpError()
1684            << "both batch_group_count " << batch_group_count
1685            << " and feature_group_count " << feature_group_count_val
1686            << " cannot be greater than 1";
1687   }
1688   if (kernel_output_features % batch_group_count != 0) {
1689     return op.emitOpError()
1690            << "Expected output feature dimension size (value "
1691            << kernel_output_features
1692            << ") to be a multiple of batch group count " << batch_group_count;
1693   }
1694   if (input_features % feature_group_count_val != 0 ||
1695       input_features / feature_group_count_val != kernel_input_features) {
1696     return op.emitOpError()
1697            << "Expected the size of kernel_input_features (value "
1698            << kernel_input_features
1699            << ") in rhs times feature_group_count (value "
1700            << feature_group_count_val
1701            << ") in lhs to equal the size of the z dimension (value "
1702            << input_features << ") in lhs.\n";
1703   }
1704   if (kernel_output_features % feature_group_count_val > 0) {
1705     return op.emitOpError() << "Expected output feature dimension (value "
1706                             << kernel_output_features << ") to be divisible by "
1707                             << "feature_group_count (value "
1708                             << feature_group_count_val << ").\n";
1709   }
1710   if (input_batch % batch_group_count != 0) {
1711     return op.emitOpError()
1712            << "Expected input batch dimension (value " << input_batch
1713            << ") to be divisible by batch_group_count (value "
1714            << batch_group_count << "); ";
1715   }
1716   return success();
1717 }
1718 
1719 bool ShapeInference::InferShapeForXlaConvV2Op(XlaConvV2Op op) {
1720   DCOMMENT_OP(op, "Inferring shape for XlaConvV2Op");
1721 
1722   bool changed = false;
1723 
1724   if (PrecheckForXlaConvV2Op(op).failed()) {
1725     return changed;
1726   }
1727 
1728   auto input_tensor = op.lhs();
1729   auto kernel_tensor = op.rhs();
1730   auto window_strides = op.window_strides();
1731   auto padding = op.padding();
1732   auto lhs_dilation = op.lhs_dilation();
1733   auto rhs_dilation = op.rhs_dilation();
1734   int64_t batch_group_count = op.batch_group_count();
1735 
1736   DenseIntElementsAttr window_strides_attr, padding_attr, lhs_dilation_attr,
1737       rhs_dilation_attr;
1738   if (matchPattern(window_strides, m_Constant(&window_strides_attr)) &&
1739       matchPattern(padding, m_Constant(&padding_attr)) &&
1740       matchPattern(lhs_dilation, m_Constant(&lhs_dilation_attr)) &&
1741       matchPattern(rhs_dilation, m_Constant(&rhs_dilation_attr))) {
1742     llvm::SmallVector<int64_t> input_tensor_dims_vec, kernel_tensor_dims_vec,
1743         window_strides_vec, lhs_dilations_vec, rhs_dilations_vec;
1744     llvm::SmallVector<std::pair<int64_t, int64_t>> padding_pairs(
1745         padding_attr.getNumElements() / 2);
1746     xla::ConvolutionDimensionNumbers dnums;
1747     dnums.ParseFromString(op.dimension_numbersAttr().getValue().str());
1748 
1749     auto input_tensor_shape = input_tensor.getType().cast<RankedTensorType>();
1750     for (auto i = 0; i < input_tensor_shape.getShape().size(); ++i) {
1751       DCOMMENT("Input Tensor Shape " << i << "th is "
1752                                      << input_tensor_shape.getShape()[i]);
1753       input_tensor_dims_vec.push_back(input_tensor_shape.getShape()[i]);
1754     }
1755 
1756     auto kernel_tensor_shape = kernel_tensor.getType().cast<RankedTensorType>();
1757     for (auto i = 0; i < kernel_tensor_shape.getShape().size(); ++i) {
1758       DCOMMENT("Kernel tensor Shape " << i << "th is "
1759                                      << kernel_tensor_shape.getShape()[i]);
1760       kernel_tensor_dims_vec.push_back(kernel_tensor_shape.getShape()[i]);
1761     }
1762 
1763     for (const llvm::APInt& i : window_strides_attr) {
1764       window_strides_vec.push_back(i.getSExtValue());
1765     }
1766 
1767     for (auto i = 0; i < padding_pairs.size(); ++i) {
1768       padding_pairs[i] = {
1769           padding_attr.getValues<IntegerAttr>()[i * 2].getInt(),
1770           padding_attr.getValues<IntegerAttr>()[i * 2 + 1].getInt()};
1771     }
1772 
1773     for (const llvm::APInt& i : lhs_dilation_attr) {
1774       lhs_dilations_vec.push_back(i.getSExtValue());
1775     }
1776 
1777     for (const llvm::APInt& i : rhs_dilation_attr) {
1778       rhs_dilations_vec.push_back(i.getSExtValue());
1779     }
1780 
1781     Type input_tensor_element_type = input_tensor_shape.getElementType();
1782     Type result_element_type = op.getType().getElementType();
1783     Type element_type = input_tensor_element_type.getIntOrFloatBitWidth() >=
1784                                 result_element_type.getIntOrFloatBitWidth()
1785                             ? input_tensor_element_type
1786                             : result_element_type;
1787     auto output_shape = InferXlaConvOutputShape(
1788         input_tensor_dims_vec, kernel_tensor_dims_vec, window_strides_vec,
1789         padding_pairs, lhs_dilations_vec, rhs_dilations_vec, batch_group_count,
1790         dnums, element_type);
1791 
1792     if (output_shape.getValue()) {
1793       changed = RefineResultType(op.getOperation(), op.getResult(),
1794                                  output_shape.getValue());
1795       return changed;
1796     }
1797   }
1798   return changed;
1799 }
1800 
1801 bool ShapeInference::RefineWithInferTypeOpInterface(
1802     InferTypeOpInterface infer_ti) {
1803   Operation* op = infer_ti.getOperation();
1804   SmallVector<Type, 4> inferred;
1805   LogicalResult res = infer_ti.inferReturnTypes(
1806       op->getContext(), op->getLoc(), op->getOperands(),
1807       op->getAttrDictionary(), op->getRegions(), inferred);
1808   if (failed(res)) {
1809     op->emitOpError("failed to refine type as inference failed");
1810     return false;
1811   }
1812 
1813   if (inferred == op->getResultTypes()) return false;
1814 
1815   // Map each of the results of the call to the returned type of the
1816   // function.
1817   bool changed = false;
1818   for (auto result : zip(op->getResults(), inferred)) {
1819     if (std::get<0>(result).getType() == std::get<1>(result)) continue;
1820 
1821     if (!UpdateTypeAndInsertIncompatibleUseCasts(std::get<1>(result),
1822                                                  std::get<0>(result)))
1823       continue;
1824     changed = true;
1825   }
1826   return changed;
1827 }
1828 
1829 ShapeHandle ShapeInference::ComputeOutputAsShape(OpResult result,
1830                                                  InferenceContext* ic) {
1831   LLVM_DEBUG(result.print(llvm::dbgs() << "\nEvaluate partially "));
1832   auto rt = result.getType().dyn_cast<RankedTensorType>();
1833   if (!rt || !rt.hasStaticShape() || rt.getRank() != 1) return {};
1834   int dim_size = rt.getDimSize(0);
1835 
1836   // Worklist to direct partial evaluation.
1837   SmallVector<ValuePort, 4> worklist;
1838 
1839   // Simple evaluator that attempts to partially evaluate the input value even
1840   // if unable to evaluate the complete output. Below follows a simple stack-
1841   // based evaluation where it queries which operands (or parts of operands)
1842   // need to be evaluated and attempts to partially evaluate those operands.
1843   // It does so by pushing the required operands onto the worklist before
1844   // enqueuing the operation requiring those values.
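  // For example (per the loop below): if partial evaluation of output element
  // i produces a one-element DenseIntElementsAttr, dims[i] becomes that
  // constant; otherwise dims[i] stays ic->UnknownDim().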
1845   std::vector<DimensionHandle> dims(dim_size, ic->UnknownDim());
1846   for (unsigned int i = 0, e = dims.size(); i != e; ++i) {
1847     LLVM_DEBUG(llvm::dbgs() << "\nConsidering output dim " << i << "\n");
1848 
1849     worklist.push_back(
1850         ValuePort{result.getOwner(), {result.getResultNumber(), i}});
1851     while (!worklist.empty()) {
1852       auto front = worklist.pop_back_val();
1853       LLVM_DEBUG(front.print(llvm::dbgs() << "\nWorklist front "));
1854 
1855       SmallVector<ValuePort, 4> inputs;
1856       auto res = ComputeInputsRequiredForOutput(front, &inputs);
1857       if (failed(res)) {
1858         // Abort if unable to find which required inputs need to be computed.
1859         worklist.clear();
1860         break;
1861       }
1862 
1863       if (!inputs.empty()) {
1864         // Enqueue required computation followed by its required operands in
1865         // stack.
1866         worklist.push_back(std::move(front));
1867         for (auto& it : inputs) worklist.push_back(std::move(it));
1868         continue;
1869       }
1870 
1871       auto ret = ComputeOutputComponent(front);
1872       if (!ret) continue;
1873 
1874       LLVM_DEBUG(ret.print(llvm::dbgs() << "\ncomputed result = "));
1875 
1876       // If worklist is empty, then this is the root query op.
1877       if (worklist.empty()) {
1878         LLVM_DEBUG(llvm::dbgs() << "[root node]\n");
1879         if (auto dea = ret.dyn_cast<DenseIntElementsAttr>()) {
1880           if (dea.getNumElements() != 1) {
1881             LLVM_DEBUG(llvm::dbgs() << "Unexpected number of elements\n");
1882             return {};
1883           }
1884           int64_t val = (*dea.getValues<APInt>().begin()).getSExtValue();
1885           dims[i] = ic->MakeDim(val);
1886         }
1887       }
1888     }
1889   }
1890   return ic->MakeShape(dims);
1891 }
1892 
1893 bool ShapeInference::RefineTypeForPassThroughOperands(Operation* op,
1894                                                       OperandRange operands,
1895                                                       ResultRange results) {
1896   bool changed = false;
1897   for (auto entry : llvm::zip(operands, results)) {
1898     Type operand_type = std::get<0>(entry).getType();
1899     Value result = std::get<1>(entry);
1900     TensorType result_type = result.getType().cast<TensorType>();
1901     Type inferred_type = TypeMeet(result_type, operand_type);
1902     if (result_type == inferred_type) continue;
1903 
1904     if (!UpdateTypeAndInsertIncompatibleUseCasts(inferred_type, result))
1905       continue;
1906     changed = true;
1907   }
1908   return changed;
1909 }
1910 
1911 bool ShapeInference::RefineShapeForPassThroughOps(Operation* op) {
1912   DCOMMENT_OP(op, "Pass through op");
1913   bool changed = false;
1914   for (auto entry : llvm::zip(op->getOperands(), op->getResults())) {
1915     Value operand = std::get<0>(entry);
1916     Value result = std::get<1>(entry);
1917     Type inferred_type = TypeMeet(result.getType(), operand.getType());
1918     if (result.getType() == inferred_type) continue;
1919     if (!UpdateTypeAndInsertIncompatibleUseCasts(inferred_type, result))
1920       continue;
1921     changed = true;
1922   }
1923   return changed;
1924 }
1925 
1926 bool ShapeInference::InferShapeForNonTFDialectOperation(Operation* op) {
1927   if (auto graph_op = dyn_cast<tf_executor::GraphOp>(op)) {
1928     return RefineTypeForPassThroughOperands(
1929         graph_op.GetFetch(), graph_op.GetFetch().fetches(), op->getResults());
1930   }
1931   if (auto island_op = dyn_cast<tf_executor::IslandOp>(op)) {
1932     return RefineTypeForPassThroughOperands(
1933         island_op.GetYield(), island_op.GetYield().fetches(), op->getResults());
1934   }
1935   if (auto iter_sink = dyn_cast<tf_executor::NextIterationSinkOp>(op)) {
1936     auto iter_source = cast<tf_executor::NextIterationSourceOp>(
1937         iter_sink.token().getDefiningOp());
1938     return RefineTypeForPassThroughOperands(
1939         op, iter_sink.getOperands().drop_front().take_front(),
1940         iter_source.getResults());
1941   }
1942   if (auto launch_op = dyn_cast<tf_device::LaunchOp>(op)) {
1943     auto terminator = launch_op.GetBody().getTerminator();
1944     return RefineTypeForPassThroughOperands(op, terminator->getOperands(),
1945                                             op->getResults());
1946   }
1947   if (auto cluster_op = dyn_cast<tf_device::ClusterOp>(op)) {
1948     auto terminator = cluster_op.GetBody().getTerminator();
1949     return RefineTypeForPassThroughOperands(op, terminator->getOperands(),
1950                                             op->getResults());
1951   }
1952   if (op->hasTrait<OpTrait::SameOperandsAndResultShape>())
1953     return RefineShapeForPassThroughOps(op);
1954   if (auto call = dyn_cast<CallOpInterface>(op)) return InferShapeForCall(call);
1955   if (isa<tensor::CastOp>(op)) return InferShapeForCast(op);
1956   return false;
1957 }
1958 
1959 // Finds element type to be used for result from operand, with special handling
1960 // for handle types.
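// For example (illustrative): if the operand element type is
// !tf_type.resource<tensor<8xf32>> and the result element type is a resource
// with no subtypes, the operand's handle type (including its subtype) is
// returned; otherwise the result element type is kept.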
1961 Type GetElementTypeFromOperand(TensorType operand_type,
1962                                TensorType result_type) {
1963   auto operand_handle_type =
1964       operand_type.getElementType().dyn_cast<TensorFlowTypeWithSubtype>();
1965   if (!operand_handle_type) return result_type.getElementType();
1966   auto result_handle_type =
1967       result_type.getElementType().cast<TensorFlowTypeWithSubtype>();
1968   if (operand_handle_type.GetSubtypes().empty() ||
1969       !result_handle_type.GetSubtypes().empty())
1970     return result_type.getElementType();
1971   return operand_handle_type;
1972 }
1973 
1974 // Checks if one tensor type can refine another type for tf.While/
1975 // tf.WhileRegion. If rank differs or static dimensions can be lost, the other
1976 // type cannot be used for refinement.
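// For example (illustrative): tensor<?x4xf32> can be refined with
// tensor<2x4xf32>, but tensor<2x4xf32> cannot be refined with tensor<?x4xf32>
// (a static dimension would be lost) or with a type of a different rank.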
1977 bool CanWhileTypeBeRefinedWith(TensorType current_type,
1978                                TensorType potential_refined_type) {
1979   if (!current_type.hasRank()) return true;
1980   if (!potential_refined_type.hasRank()) return false;
1981   if (current_type.getRank() != potential_refined_type.getRank()) return false;
1982   for (auto dim :
1983        llvm::zip(current_type.getShape(), potential_refined_type.getShape())) {
1984     int64_t current_dim = std::get<0>(dim);
1985     int64_t potential_refined_dim = std::get<1>(dim);
1986     if (current_dim != potential_refined_dim &&
1987         current_dim != ShapedType::kDynamicSize)
1988       return false;
1989   }
1990   return true;
1991 }
1992 
1993 template <typename WhileOpTy>
1994 bool ShapeInference::InferShapeForWhile(WhileOpTy op,
1995                                         TypeRange body_result_types) {
1996   if (!op.shape_invariant())
1997     return RefineTypeForPassThroughOperands(op, op.input(), op.output());
1998 
1999   bool changed = false;
2000   for (auto entry :
2001        zip(op.input().getTypes(), op.output(), body_result_types)) {
2002     Value result = std::get<1>(entry);
2003     TensorType body_result_type =
2004         std::get<2>(entry).template cast<TensorType>();
2005     auto result_type = result.getType().cast<TensorType>();
2006 
2007     Type potential_refined_type;
2008     if (CanWhileTypeBeRefinedWith(result_type, body_result_type)) {
2009       Type element_type =
2010           GetElementTypeFromOperand(body_result_type, result_type);
2011       potential_refined_type = CreateTensorType(
2012           body_result_type.hasRank() ? body_result_type.getShape()
2013                                      : llvm::Optional<ArrayRef<int64_t>>(),
2014           element_type);
2015     } else {
2016       TensorType operand_type = std::get<0>(entry).template cast<TensorType>();
2017       Type element_type = GetElementTypeFromOperand(operand_type, result_type);
2018       potential_refined_type = CreateTensorType(
2019           result_type.hasRank() ? result_type.getShape()
2020                                 : llvm::Optional<ArrayRef<int64_t>>(),
2021           element_type);
2022     }
2023     changed |= RefineResultType(op, result, potential_refined_type);
2024   }
2025   return changed;
2026 }
2027 
2028 bool ShapeInference::InferShapeForSingleOperation(Operation* op,
2029                                                   int64_t max_iterations) {
2030   LLVM_DEBUG(op->print(llvm::dbgs() << "InferShapeForSingleOperation for ");
2031              llvm::dbgs() << "\n");
2032   assert(tf_dialect_ == op->getDialect());
2033   // The shape function of these ops sometimes does not propagate subtypes
2034   // (handle shapes) for resource and variant types. We use a simple passthrough
2035   // to make sure they are preserved in the output.
2036   if (isa<TF::IdentityOp, TF::IdentityNOp, TF::StopGradientOp, TF::ZerosLikeOp>(
2037           op)) {
2038     return RefineTypeForPassThroughOperands(op, op->getOperands(),
2039                                             op->getResults());
2040   }
2041 
2042   // The shape inference function for `ReduceDatasetOp` should always be
2043   // executed regardless of whether the result type can be refined.
2044   if (auto reduce_dataset_op = dyn_cast<ReduceDatasetOp>(op)) {
2045     // TODO(jpienaar): The output type of these ops need to be refined.
2046     return InferShapeForReduceDataset(reduce_dataset_op, max_iterations);
2047   }
2048 
2049   // If no result for this op needs shape inference, we have a fast-path return.
2050   // But if the type is a resource/variant, we do not skip it because we might
2051   // not have the handle shapes.
2052   if (none_of(op->getResultTypes(), CanBeRefined)) {
2053     LLVM_DEBUG(llvm::dbgs() << "Skipping inference for statically shaped op '"
2054                             << op->getName() << "'.\n");
2055     return false;
2056   }
2057 
2058   if (isa<TF::RestoreOp, TF::RestoreV2Op>(op)) return InferShapeForRestore(op);
2059 
2060   // Handle call operations by looking up callee and inferring return shape as
2061   // needed.
2062   if (auto call = dyn_cast<CallOpInterface>(op)) return InferShapeForCall(call);
2063 
2064   // tf.Cast is only inferred if it has at least one user in the TF dialect or
2065   // feeding into the function return. This is necessary to avoid inserting
2066   // casts which cannot be refined.
2067   if (isa<CastOp>(op)) return InferShapeForCast(op);
2068 
2069   // Handle IfOp here by inferring the shape from the else/then function
2070   // results. Since `output_shapes` is a derived attribute, avoid going down the
2071   // TF InferenceContext path as IfOp shape inference is implemented as just
2072   // a lookup of the output_shapes attribute.
2073   if (auto if_op = dyn_cast<IfOp>(op)) return InferShapeForIf(if_op);
2074 
2075   // Handle IfRegion operations by inferring return shape from the then and else
2076   // branches.
2077   if (auto if_region = dyn_cast<IfRegionOp>(op))
2078     return InferShapeForIfRegion(if_region);
2079 
2080   if (auto while_op = dyn_cast<WhileOp>(op))
2081     return InferShapeForWhile(
2082         while_op, while_op.body_function().getFunctionType().getResults());
2083 
2084   if (auto while_region = dyn_cast<WhileRegionOp>(op))
2085     return InferShapeForWhile(
2086         while_region,
2087         while_region.body().front().getTerminator()->getOperandTypes());
2088 
2089   if (auto host_compute_op = dyn_cast<_XlaHostComputeMlirOp>(op)) {
2090     return InferShapeForXlaHostComputeMlir(host_compute_op);
2091   }
2092 
2093   // TODO(jpienaar): Extract function input arg constraint interface.
2094   // TODO(jpienaar): Unify the shape propagation to functions using interface.
2095   if (auto map_dataset_op = dyn_cast<MapDatasetOp>(op)) {
2096     // TODO(jpienaar): The output type of these ops need to be refined.
2097     return InferShapeForMapDataset(map_dataset_op, max_iterations);
2098   }
2099 
2100   if (auto takewhile_dataset_op = dyn_cast<TakeWhileDatasetOp>(op)) {
2101     // TODO(jpienaar): The output type of these ops need to be refined.
2102     return InferShapeForTakeWhileDataset(takewhile_dataset_op, max_iterations);
2103   }
2104 
2105   // Handle TensorList init operations by inferring shape from TensorList write
2106   // operations. If we are unable to refine element shape here, proceed to use
2107   // the InferenceContext below to get more precise shapes.
2108   if (IsTensorListInitOp(op) && InferShapeForTensorListInitOps(op)) return true;
2109 
2110   if (auto var_handle_op = dyn_cast<VarHandleOp>(op)) {
2111     return InferShapeForVarHandleOp(var_handle_op);
2112   }
2113 
2114   if (auto xla_reduce_window_op = dyn_cast<XlaReduceWindowOp>(op)) {
2115     return InferShapeForXlaReduceWindowOp(xla_reduce_window_op);
2116   }
2117 
2118   if (auto xla_select_and_scatter_op = dyn_cast<XlaSelectAndScatterOp>(op)) {
2119     return InferShapeForXlaSelectAndScatterOp(xla_select_and_scatter_op);
2120   }
2121 
2122   if (auto xla_conv_v2_op = dyn_cast<XlaConvV2Op>(op)) {
2123     return InferShapeForXlaConvV2Op(xla_conv_v2_op);
2124   }
2125 
2126   // Return operand as a constant attribute.
2127   auto operand_as_constant_fn = [&](Value operand) {
2128     ValuePort vp(operand);
2129     Attribute attr = ComputeOutputComponent(vp);
2130     if (!attr && matchPattern(operand, m_Constant(&attr)))
2131       RecordValue(vp, attr);
2132     return attr;
2133   };
2134 
2135   // Return op result as a shape.
2136   auto op_result_as_shape_fn = [&](InferenceContext& context,
2137                                    OpResult op_result) {
2138     return ComputeOutputAsShape(op_result, &context);
2139   };
2140 
2141   // Return result element type at `index`.
2142   auto result_element_type_fn = [&](int index) {
2143     return op->getResult(index).getType().cast<TensorType>().getElementType();
2144   };
2145 
2146   llvm::SmallVector<ShapedTypeComponents, 4> inferred_return_shapes;
2147   if (failed(InferReturnTypeComponentsForTFOp(
2148           /*location=*/None, op, graph_version_, operand_as_constant_fn,
2149           op_result_as_shape_fn, result_element_type_fn,
2150           inferred_return_shapes)))
2151     return false;
2152 
2153   // Update the shape for each of the operation result if the InferenceContext
2154   // has more precise shapes recorded.
2155   bool changed = false;
2156   for (auto result : llvm::zip(op->getResults(), inferred_return_shapes)) {
2157     Value op_result = std::get<0>(result);
2158     if (!CanBeRefined(op_result.getType())) continue;
2159 
2160     ShapedTypeComponents inferred = std::get<1>(result);
2161     TensorType inferred_type;
2162     if (inferred.hasRank())
2163       inferred_type =
2164           RankedTensorType::get(inferred.getDims(), inferred.getElementType());
2165     else
2166       inferred_type = UnrankedTensorType::get(inferred.getElementType());
2167 
2168     inferred_type =
2169         TypeMeet(op_result.getType(), inferred_type).cast<TensorType>();
2170     if (op_result.getType() == inferred_type) continue;
2171     if (!UpdateTypeAndInsertIncompatibleUseCasts(inferred_type, op_result))
2172       continue;
2173     changed = true;
2174   }
2175 
2176   if (changed) DCOMMENT_OP(op, "Modified after shape inference:");
2177   return changed;
2178 }
2179 
2180 FailureOr<bool> ShapeInference::PropagateShapeToFunctions(
2181     ModuleOp module, TypeRange input_types, ArrayRef<func::FuncOp> functions,
2182     int64_t max_iterations) {
2183   bool any_failure = false;
2184   bool any_nonconvergence = false;
2185   // If shape propagation fails for one function, return failure, but do not
2186   // early exit and attempt to propagate shapes for all provided functions to
2187   // have a best-effort propagation.
2188   for (func::FuncOp func : functions) {
2189     DCOMMENT("Propagating shape to " << func.getName());
2190     ArrayRef<Operation*> callers = GetCallers(func);
2191     if (!llvm::hasSingleElement(callers) &&
2192         !llvm::all_of(callers.drop_front(), [&](Operation* caller) {
2193           /// TODO(aminim): this is overly conservative as some operations
2194           /// (like TPUPartitionedCallOp) may have extra operands that aren't
2195           /// propagated to the callee.
2196           return isa<CallOpInterface>(caller) &&
2197                  std::equal(caller->getOperandTypes().begin(),
2198                             caller->getOperandTypes().end(),
2199                             callers.front()->getOperandTypes().begin());
2200         })) {
2201       if (llvm::any_of(callers, [](Operation* op) {
2202             return isa<IfOp, WhileOp, CaseOp>(op);
2203           }))
2204         func.emitWarning(formatv(
2205             "expected control flow function @{0} to have exactly 1 use, "
2206             "found {1}.",
2207             func.getName(), callers.size()));
2208 
2209       continue;
2210     }
2211     FunctionType func_type = func.getFunctionType();
2212     func.setType(FunctionType::get(func.getContext(), input_types,
2213                                    func_type.getResults()));
2214 
2215     FailureOr<bool> failure_or_converged =
2216         PropagateShapeToRegions(input_types, {&func.getBody()}, max_iterations);
2217     if (failed(failure_or_converged)) {
2218       any_failure = true;
2219       continue;
2220     }
2221     any_nonconvergence = any_nonconvergence || !failure_or_converged.getValue();
2222     if (failed(InferShapeForFunctionReturnType(func))) any_failure = true;
2223   }
2224   if (any_failure) return failure();
2225   return any_nonconvergence;
2226 }
2227 
2228 FailureOr<bool> ShapeInference::PropagateShapeToRegions(
2229     TypeRange input_types, ArrayRef<Region*> regions, int64_t max_iterations) {
2230   DCOMMENT("\tPropagating shapes to regions");
2231   bool any_failure = false;
2232   bool any_nonconvergence = false;
2233   // If shape propagation fails for one region, return failure, but do not
2234   // early exit and attempt to propagate shapes for all provided regions to
2235   // have a best-effort propagation.
2236   for (auto region : regions) {
2237     // Refine region arguments.
2238     Block& entry = region->front();
2239     assert(llvm::size(input_types) == entry.getNumArguments());
2240     for (auto it : llvm::zip(entry.getArguments(), input_types)) {
2241       BlockArgument arg = std::get<0>(it);
2242       Type type = std::get<1>(it);
2243       arg.setType(type);
2244     }
2245 
2246     // Propagate shapes into the region.
2247     FailureOr<bool> failure_or_converged =
2248         InferShapeUntilFixPoint(region, max_iterations);
2249     if (failed(failure_or_converged))
2250       any_failure = true;
2251     else if (!failure_or_converged.getValue())
2252       any_nonconvergence = true;
2253   }
2254   if (any_failure) return failure();
2255   return any_nonconvergence;
2256 }
2257 
2258 void ShapeInference::PropagateConstantToCallee(CallOpInterface call_op,
2259                                                func::FuncOp func,
2260                                                ModuleOp module) {
2261   auto callers = GetCallers(func);
2262   if (!llvm::hasSingleElement(callers)) return;
2263 
2264   OpBuilder builder(&func.front().front());
2265   Operation* op = call_op.getOperation();
2266   // If this is the only caller, and an operand is a constant, propagate
2267   // the constant value inside the function.
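  // Illustrative example (hypothetical): if the single caller passes the
  // result of a tf.Const as an argument, that constant is either cloned into
  // the callee (when propagate_caller_callee_constants_ is set) or recorded
  // against the argument's ValuePort so later partial evaluation can reuse it.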
2268   for (auto arg : func.getArguments()) {
2269     auto operand = op->getOperand(arg.getArgNumber());
2270     if (propagate_caller_callee_constants_) {
2271       if (isa_and_nonnull<TF::ConstOp>(operand.getDefiningOp())) {
2272         arg.replaceAllUsesWith(
2273             builder.clone(*operand.getDefiningOp())->getResult(0));
2274       }
2275       continue;
2276     }
2277 
2278     auto known_constant = ComputeOutputComponent(ValuePort(operand));
2279     if (!known_constant) continue;
2280     LLVM_DEBUG(call_op.print(llvm::dbgs() << "Propagate to callee: ");
2281                known_constant.print(llvm::dbgs() << " constant ");
2282                llvm::dbgs() << "\n");
2283     RecordValue(ValuePort(arg), known_constant);
2284   }
2285 }
2286 
2287 void ShapeInference::PropagateConstantFromCallee(CallOpInterface call_op,
2288                                                  func::FuncOp func,
2289                                                  ModuleOp module) {
2290   // If the return value is a constant, use the constant as the value of
2291   // the call return.
2292   Operation* op = call_op.getOperation();
2293   OpBuilder builder(op);
2294   builder.setInsertionPointAfter(op);
2295   for (auto retval :
2296        llvm::enumerate(func.front().getTerminator()->getOperands())) {
2297     if (propagate_caller_callee_constants_) {
2298       auto retval_op = retval.value().getDefiningOp();
2299       if (isa_and_nonnull<TF::ConstOp>(retval_op)) {
2300         op->getResult(retval.index())
2301             .replaceAllUsesWith(builder.clone(*retval_op)->getResult(0));
2302       }
2303       continue;
2304     }
2305 
2306     ValuePort vp(retval.value());
2307     if (auto known_constant = ComputeOutputComponent(vp)) {
2308       LLVM_DEBUG(known_constant.print(llvm::dbgs() << "Propagate constant ");
2309                  call_op.print(llvm::dbgs() << "from "); llvm::dbgs() << "\n");
2310       RecordValue(ValuePort(op->getResult(retval.index())), known_constant);
2311     }
2312   }
2313 }
2314 
2315 bool RankedAndSameRank(TensorType lhs, TensorType rhs) {
2316   return lhs.hasRank() && rhs.hasRank() && lhs.getRank() == rhs.getRank();
2317 }
2318 
2319 // Creates a compatible RankedTensorType where mismatched dimensions are
2320 // replaced with dynamic sizes.
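// For example (illustrative): tensor<2x4xf32> and tensor<2x8xf32> produce
// tensor<2x?xf32>, keeping matching dimensions and making mismatched ones
// dynamic.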
2321 RankedTensorType GetCompatibleRankedTensorType(RankedTensorType lhs,
2322                                                RankedTensorType rhs) {
2323   assert(lhs.getRank() == rhs.getRank());
2324   llvm::SmallVector<int64_t, 4> dims;
2325   dims.reserve(lhs.getRank());
2326   for (auto dim : llvm::zip(lhs.getShape(), rhs.getShape())) {
2327     int64_t lhs_dim = std::get<0>(dim);
2328     if (lhs_dim == std::get<1>(dim)) {
2329       dims.push_back(lhs_dim);
2330     } else {
2331       dims.push_back(ShapedType::kDynamicSize);
2332     }
2333   }
2334   return RankedTensorType::get(dims, GetElementTypeFromOperand(lhs, rhs));
2335 }
2336 
2337 // Finds compatible types to propagate into functions/regions of a shape
2338 // invariant tf.While/tf.WhileRegion. If operand and result types are the same,
2339 // that type is returned. If operand and result types are of the same rank, a
2340 // compatible type with matching dimensions is used. Otherwise functions/regions
2341 // arguments are returned but with the handle type from the operand type.
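// For example (illustrative): an operand/result pair of tensor<2x4xf32> and
// tensor<2x8xf32> yields tensor<2x?xf32>; identical operand and result types
// are passed through unchanged; and a rank mismatch falls back to the region
// argument's shape combined with the operand's handle element type.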
2342 llvm::SmallVector<Type, 4> GetWhileCompatibleTypes(
2343     TypeRange operand_types, TypeRange result_types,
2344     TypeRange region_argument_types) {
2345   llvm::SmallVector<Type, 4> types;
2346   types.reserve(operand_types.size());
2347   for (auto entry :
2348        llvm::zip(operand_types, result_types, region_argument_types)) {
2349     auto operand_type = std::get<0>(entry).cast<TensorType>();
2350     auto result_type = std::get<1>(entry).cast<TensorType>();
2351     if (operand_type == result_type) {
2352       types.push_back(operand_type);
2353     } else if (RankedAndSameRank(operand_type, result_type)) {
2354       auto potential_refined_type =
2355           GetCompatibleRankedTensorType(operand_type.cast<RankedTensorType>(),
2356                                         result_type.cast<RankedTensorType>());
2357       types.push_back(potential_refined_type);
2358     } else {
2359       auto region_argument_type = std::get<2>(entry).cast<TensorType>();
2360       Type element_type = GetElementTypeFromOperand(
2361           operand_type.cast<TensorType>(), region_argument_type);
2362       Type potential_refined_type = CreateTensorType(
2363           region_argument_type.hasRank() ? region_argument_type.getShape()
2364                                          : llvm::Optional<ArrayRef<int64_t>>(),
2365           element_type);
2366       types.push_back(potential_refined_type);
2367     }
2368   }
2369   return types;
2370 }
2371 
2372 FailureOr<bool> ShapeInference::PropagateShapeIntoAttachedFunctions(
2373     Operation* op, int64_t max_iterations) {
2374   ModuleOp module = op->getParentOfType<ModuleOp>();
2375   if (auto if_op = dyn_cast<TF::IfOp>(op)) {
2376     DCOMMENT("Propagating shapes into If");
2377     return PropagateShapeToFunctions(
2378         module, if_op.input().getTypes(),
2379         {if_op.ResolveThenFunction(&symbol_table_),
2380          if_op.ResolveElseFunction(&symbol_table_)},
2381         max_iterations);
2382   } else if (auto case_op = dyn_cast<TF::CaseOp>(op)) {
2383     SmallVector<func::FuncOp, 4> branches;
2384     case_op.get_branch_functions(branches);
2385     return PropagateShapeToFunctions(module, case_op.input().getTypes(),
2386                                      branches, max_iterations);
2387   } else if (auto while_op = dyn_cast<TF::WhileOp>(op)) {
2388     // If `shape_invariant` is set, operand shapes cannot be simply propagated
2389     // to result shapes as the op may have different intermediate shapes (e.g.,
2390     // While ops can have different result shapes than operand shapes).
2391     // Compatible shapes must be determined before propagating them.
2392     if (while_op.shape_invariant()) {
2393       auto compatible_types = GetWhileCompatibleTypes(
2394           while_op.input().getTypes(), while_op.output().getTypes(),
2395           while_op.ResolveBodyFunction(&symbol_table_)
2396               .getFunctionType()
2397               .getInputs());
2398       return PropagateShapeToFunctions(
2399           module, compatible_types,
2400           {while_op.ResolveCondFunction(&symbol_table_),
2401            while_op.ResolveBodyFunction(&symbol_table_)},
2402           max_iterations);
2403     }
2404     return PropagateShapeToFunctions(
2405         module, while_op.input().getTypes(),
2406         {while_op.ResolveCondFunction(&symbol_table_),
2407          while_op.ResolveBodyFunction(&symbol_table_)},
2408         max_iterations);
2409   } else if (auto call_op = dyn_cast<CallOpInterface>(op)) {
2410     if (auto func =
2411             dyn_cast<func::FuncOp>(call_op.resolveCallable(&symbol_table_))) {
2412       PropagateConstantToCallee(call_op, func, module);
2413       FailureOr<bool> failure_or_converged = PropagateShapeToFunctions(
2414           module, call_op.getArgOperands().getTypes(), {func}, max_iterations);
2415       if (failed(failure_or_converged)) return failure();
2416       PropagateConstantFromCallee(call_op, func, module);
2417       return failure_or_converged;
2418     }
2419   } else if (isa<TF::XlaReduceWindowOp>(op) ||
2420              isa<TF::XlaSelectAndScatterOp>(op) ||
2421              isa<TF::XlaVariadicReduceV2Op>(op) ||
2422              isa<TF::XlaVariadicSortOp>(op)) {
2423     auto propagate_shape_to = [&](mlir::SymbolRefAttr func_sym) {
2424       auto func = llvm::cast<mlir::func::FuncOp>(
2425           mlir::SymbolTable::lookupSymbolIn(module, func_sym));
2426       mlir::SmallVector<mlir::Type, 2> types;
2427       for (auto type : func.getFunctionType().getInputs()) {
2428         types.push_back(RankedTensorType::get({}, getElementTypeOrSelf(type)));
2429       }
2430       return PropagateShapeToFunctions(module, types, {func}, max_iterations);
2431     };
2432 
2433     if (auto xla_reduce_window_op = dyn_cast<TF::XlaReduceWindowOp>(op)) {
2434       return propagate_shape_to(xla_reduce_window_op.computation());
2435     }
2436     if (auto xla_select_and_scatter_op =
2437             dyn_cast<TF::XlaSelectAndScatterOp>(op)) {
2438       return propagate_shape_to(xla_select_and_scatter_op.select())
2439                  .getValue() &&
2440              propagate_shape_to(xla_select_and_scatter_op.scatter()).getValue();
2441     } else if (auto xla_variadic_reduce_v2_op =
2442                    dyn_cast<TF::XlaVariadicReduceV2Op>(op)) {
2443       return propagate_shape_to(xla_variadic_reduce_v2_op.reducer());
2444     } else if (auto xla_variadic_sort_op =
2445                    dyn_cast<TF::XlaVariadicSortOp>(op)) {
2446       return propagate_shape_to(xla_variadic_sort_op.comparator());
2447     }
2448   }
2449 
2450   // TODO(ycao): Implement support for Call op, including function reuse.
2451 
2452   return true;
2453 }
2454 
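// Region-based counterpart of the above: propagates the operand types of a
// tf.WhileRegion into the arguments of its cond and body regions. When
// `shape_invariant` is set, GetWhileCompatibleTypes is applied first, so e.g.
// an operand of tensor<4xf32> paired with a result of tensor<8xf32> only
// propagates tensor<?xf32> into the region arguments (illustrative shapes).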
2455 FailureOr<bool> ShapeInference::PropagateShapeIntoAttachedRegions(
2456     Operation* op, int64_t max_iterations) {
2457   if (auto while_op = dyn_cast<TF::WhileRegionOp>(op)) {
2458     // If `shape_invariant` is set, operand shapes cannot be simply propagated
2459     // to result shapes, as the op may have different intermediate shapes (such
2460     // While ops can have result shapes that differ from their operand shapes).
2461     // Compatible shapes must be determined before propagating them.
2462     if (while_op.shape_invariant()) {
2463       auto compatible_types = GetWhileCompatibleTypes(
2464           while_op.input().getTypes(), while_op.output().getTypes(),
2465           while_op.body().getArgumentTypes());
2466       return PropagateShapeToRegions(compatible_types,
2467                                      {&while_op.cond(), &while_op.body()},
2468                                      max_iterations);
2469     }
2470     return PropagateShapeToRegions(while_op.input().getTypes(),
2471                                    {&while_op.cond(), &while_op.body()},
2472                                    max_iterations);
2473   }
2474   return true;
2475 }
2476 
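// Best-effort constant folding of `op`: operands already known to be constant
// (via ComputeOutputComponent) are fed to the op's registered fold hook and,
// failing that, to the owning dialect's DialectFoldInterface. Folded attribute
// results are cached with RecordValue; value results are either resolved to a
// cached attribute or used to refine the result type. A minimal sketch of the
// kind of fold this enables (illustrative IR, not taken from a test):
//
//   %s = "tf.Shape"(%x) : (tensor<2x3xf32>) -> tensor<2xi32>
//   // may fold to the attribute dense<[2, 3]> : tensor<2xi32>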
2477 LogicalResult ShapeInference::TryToFold(Operation* op) {
2478   LLVM_DEBUG(op->print(llvm::dbgs() << "TryToFold "); llvm::dbgs() << "\n");
2479   // If any output result is known, then the op probably has been computed
2480   // before.
2481   if (op->getNumResults() > 0 && results_[ValuePort(op->getResult(0))])
2482     return success();
2483 
2484   SmallVector<Attribute, 8> constant_operands(op->getNumOperands());
2485   SmallVector<OpFoldResult, 8> fold_results;
2486 
2487   // Check to see if any operand to the operation is constant and whether
2488   // the operation knows how to constant fold itself.
2489   bool some_unknown = false;
2490   for (int i = 0, e = op->getNumOperands(); i != e; ++i) {
2491     if (!(constant_operands[i] =
2492               ComputeOutputComponent(ValuePort(op->getOperand(i)))))
2493       some_unknown = true;
2494   }
2495 
2496   // Attempt to constant fold the operation.
2497   auto abstract_op = op->getRegisteredInfo();
2498   LogicalResult folded = failure();
2499   if (abstract_op) {
2500     folded = abstract_op->foldHook(op, constant_operands, fold_results);
2501   }
2502   // Attempt dialect fallback if op's fold hook failed.
2503   if (failed(folded)) {
2504     Dialect* dialect = op->getDialect();
2505     if (!dialect) return failure();
2506     // Only attempt TF dialect fallback if there are no unknown operands.
2507     if (some_unknown && dialect == tf_dialect_) return failure();
2508     auto* interface = dialect->getRegisteredInterface<DialectFoldInterface>();
2509     if (!interface) return failure();
2510 
2511     if (failed(interface->fold(op, constant_operands, fold_results)))
2512       return failure();
2513   }
2514 
2515   for (auto result : zip(op->getResults(), fold_results)) {
2516     auto fold_result = std::get<1>(result);
2517     Attribute attr = nullptr;
2518     if ((attr = fold_result.dyn_cast<Attribute>())) {
2519       RecordValue(ValuePort(std::get<0>(result)), attr);
2520     } else {
2521       auto value = fold_result.get<Value>();
2522       if ((attr = ComputeOutputComponent(ValuePort(value)))) {
2523         DCOMMENT("\t\tValue Result mapped to " << attr);
2524         RecordValue(ValuePort(std::get<0>(result)), attr);
2525       } else {
2526         DCOMMENT("\t\tValue result unmapped, consider value type:" << value);
2527         RefineResultType(op, std::get<0>(result), value.getType());
2528       }
2529     }
2530 
2531     if (ElementsAttr eattr = attr.dyn_cast_or_null<ElementsAttr>()) {
2532       if (std::get<0>(result).getType() == eattr.getType()) continue;
2533 
2534       (void)UpdateTypeAndInsertIncompatibleUseCasts(eattr.getType(),
2535                                                     std::get<0>(result));
2536     }
2537   }
2538 
2539   return success();
2540 }
2541 
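// Refines the return type of `func` from the operand types of its single
// return op. Casts feeding the return that differ only in shape refinement are
// folded away (or replaced by a tighter cast), and callers are re-enqueued if
// anything changed. A minimal sketch of the cast folding, with assumed types:
//
//   %c = "tf.Cast"(%v) : (tensor<4xf32>) -> tensor<*xf32>
//   return %c : tensor<*xf32>
//
// becomes `return %v : tensor<4xf32>`, and the function type is updated to
// return tensor<4xf32>.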
2542 LogicalResult ShapeInference::InferShapeForFunctionReturnType(
2543     func::FuncOp func) {
2544   LLVM_DEBUG(llvm::dbgs() << "Inferring return type for: " << func.getName()
2545                           << "\n");
2546 
2547   // Find any return ops.
2548   SmallVector<func::ReturnOp, 4> return_ops;
2549   for (Block& block : func) {
2550     if (auto return_op = dyn_cast<func::ReturnOp>(block.getTerminator())) {
2551       return_ops.push_back(return_op);
2552     }
2553   }
2554 
2555   // Skip functions without a return, but don't flag as failure here.
2556   if (return_ops.empty()) return success();
2557 
2558   // Right now we only handle the case of a single return op.
2559   // To handle multiple return ops, we would need to look at all their shapes
2560   // and come up with a common shape and insert appropriate casts.
2561   if (return_ops.size() != 1) return failure();
2562 
2563   // Find the return type.
2564   auto return_op = return_ops.front();
2565 
2566   // Manually fold tf.Cast that precedes the return instruction and only differs
2567   // in shape refinement level.
2568   bool changed = false;
2569   for (OpOperand& arg_op : return_op.getOperation()->getOpOperands()) {
2570     Operation* arg_defining_op = arg_op.get().getDefiningOp();
2571     if (isa_and_nonnull<CastOp, tensor::CastOp>(arg_defining_op)) {
2572       Value input = arg_defining_op->getOperand(0);
2573       Value result = arg_defining_op->getResult(0);
2574       Type meet = TypeMeet(result.getType(), input.getType());
2575       if (meet == result.getType()) continue;
2576 
2577       LLVM_DEBUG({
2578         llvm::errs() << "\tfolding & updating return type ";
2579         result.getType().print(llvm::errs());
2580         input.getType().print(llvm::errs() << " to ");
2581         llvm::errs() << "\n";
2582       });
2583 
2584       // Shape inference should not change the element type.
2585       if (HasCompatibleElementTypes(input.getType(), result.getType()) &&
2586           meet == input.getType()) {
2587         arg_op.set(input);
2588       } else {
2589         OpBuilder b(return_op.getOperation());
2590         auto new_cast_op = InsertCast(b, return_op.getLoc(), meet, input);
2591         if (!new_cast_op) return failure();
2592         arg_op.set(new_cast_op->getResult(0));
2593       }
2594       if (result.use_empty()) arg_defining_op->erase();
2595       changed = true;
2596     }
2597   }
2598 
2599   DCOMMENT("Updating function type");
2600   func.setType(FunctionType::get(func.getContext(), func.getArgumentTypes(),
2601                                  return_op.getOperandTypes()));
2602 
2603   if (changed) EnqueueCallers(func);
2604   return success();
2605 }
2606 
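// Walks every operation in `region`, up to `max_iterations` times, and for each
// op tries in order: InferTypeOpInterface-based refinement, non-TF-dialect
// inference, constant folding via TryToFold, propagation into attached
// functions and regions, and finally InferShapeForSingleOperation. Returns true
// if a fixed point was reached, false if the iteration budget was exhausted
// while types were still changing, and failure if propagation failed.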
2607 FailureOr<bool> ShapeInference::InferShapeUntilFixPoint(
2608     Region* region, int64_t max_iterations) {
2609   bool changed = true;
2610 
2611   // TODO(aminim): we could have a more efficient traversal by guiding the
2612   // traversal with a worklist and reconsider only the nodes for which an
2613   // operand type was inferred. This would need to be careful if working on a
2614   // region that would not be isolated.
2615   for (int iteration = 0; iteration < max_iterations && changed; ++iteration) {
2616     changed = false;
2617     LLVM_DEBUG(llvm::dbgs()
2618                << "Shape inference, iteration " << iteration << "\n");
2619     auto res = region->walk([&](Operation* op) {
2620       DCOMMENT_OP(op, "Inferring for");
2621       if (auto infer_ti = dyn_cast<InferTypeOpInterface>(op)) {
2622         DCOMMENT("\tRefining with type op interface");
2623         changed |= RefineWithInferTypeOpInterface(infer_ti);
2624         return WalkResult::advance();
2625       }
2626 
2627       if (op->getDialect() != tf_dialect_) {
2628         DCOMMENT("\tInfer non-TF dialect");
2629         changed |= InferShapeForNonTFDialectOperation(op);
2630         return WalkResult::advance();
2631       }
2632 
2633       // Before attempting inference, just try to compute the folded
2634       // value/shape.
2635       if (succeeded(TryToFold(op)) &&
2636           // Folding can "succeed" and yet not all types be refined. In such
2637           // cases we still want to give `InferShapeForSingleOperation` a try.
2638           none_of(op->getResultTypes(), CanBeRefined))
2639         return WalkResult::advance();
2640 
2641       // Best-effort shape inference in attached functions. Do not return
2642       // failure even if it doesn't get to fixed point, but propagate "real"
2643       // failure.
2644       if (failed(PropagateShapeIntoAttachedFunctions(op, max_iterations))) {
2645         op->emitWarning() << "unable to refine shape of attached function "
2646                              "arguments and bodies";
2647         return WalkResult::interrupt();
2648       }
2649 
2650       if (failed(PropagateShapeIntoAttachedRegions(op, max_iterations))) {
2651         op->emitWarning() << "unable to refine shape of attached region "
2652                              "arguments and bodies";
2653         return WalkResult::interrupt();
2654       }
2655 
2656       changed |= InferShapeForSingleOperation(op, max_iterations);
2657       return WalkResult::advance();
2658     });
2659     if (res.wasInterrupted()) return failure();
2660   }
2661 
2662   if (changed) {
2663     region->getParentOp()->emitWarning()
2664         << "shape inference did not reach stable state after " << max_iterations
2665         << " iterations";
2666   }
2667   return !changed;
2668 }
2669 
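// Runs shape inference on the body of `func` until fixed point and then refines
// the function's return type. Returns failure on error, false if the iteration
// limit was hit before convergence, and true otherwise.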
2670 static FailureOr<bool> InferShapeForFunction(ShapeInference& context,
2671                                              func::FuncOp func,
2672                                              int64_t max_iterations) {
2673   FailureOr<bool> failure_or_converged =
2674       context.InferShapeUntilFixPoint(&func.getBody(), max_iterations);
2675   if (failed(failure_or_converged) || !failure_or_converged.getValue())
2676     return failure_or_converged;
2677   // TODO(b/156276510): Verify that it is always fine to refine a function's
2678   // return type, as long as we do not change the argument shapes.
2679   if (failed(context.InferShapeForFunctionReturnType(func))) return failure();
2680   return true;
2681 }
2682 
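// Entry point that additionally seeds the argument types of `func` from
// `arg_shapes` (one shape per argument; element types are preserved) before
// running inference; with an empty `arg_shapes` the arguments are left
// untouched. A minimal usage sketch, assuming a single-argument function and
// arbitrarily chosen shape/version values:
//
//   SmallVector<int64_t, 4> shape = {1, 224, 224, 3};
//   SmallVector<ArrayRef<int64_t>, 1> arg_shapes = {shape};
//   FailureOr<bool> converged = InferShapeForFunction(
//       func, arg_shapes, /*graph_version=*/0, /*max_iterations=*/10);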
2683 FailureOr<bool> InferShapeForFunction(func::FuncOp func,
2684                                       ArrayRef<ArrayRef<int64_t>> arg_shapes,
2685                                       int64_t graph_version,
2686                                       int64_t max_iterations) {
2687   ShapeInference context(graph_version, func->getParentOfType<ModuleOp>(),
2688                          /*propagate_caller_callee_constants=*/true);
2689   if (arg_shapes.empty()) {
2690     return InferShapeForFunction(context, func, max_iterations);
2691   }
2692 
2693   FunctionType func_type = func.getFunctionType();
2694   bool needs_refinement = false;
2695   SmallVector<Type, 4> new_arg_types;
2696   new_arg_types.reserve(func_type.getNumInputs());
2697 
2698   // Update argument types in-place using the provided arg_shapes.
2699   for (size_t i = 0; i < func_type.getNumInputs(); ++i) {
2700     ArrayRef<int64_t> shape = arg_shapes[i];
2701     Type element_type;
2702     if (auto input_ty = func_type.getInput(i).dyn_cast<RankedTensorType>()) {
2703       if (input_ty.getRank() != shape.size()) {
2704         return failure();
2705       }
2706       element_type = input_ty.getElementType();
2707     } else {
2708       auto unranked_input_ty = func_type.getInput(i).dyn_cast<TensorType>();
2709       if (!unranked_input_ty) {
2710         return failure();
2711       }
2712       element_type = unranked_input_ty.getElementType();
2713     }
2714 
2715     auto new_arg_type = RankedTensorType::get(shape, element_type);
2716     if (new_arg_type != func_type.getInput(i)) {
2717       // If the new type is more detailed, trigger shape inference.
2718       func.getArgument(i).setType(new_arg_type);
2719       needs_refinement = true;
2720     }
2721     new_arg_types.push_back(new_arg_type);
2722   }
2723 
2724   if (!needs_refinement) return true;
2725 
2726   FailureOr<bool> failure_or_converged =
2727       context.InferShapeUntilFixPoint(&func.getBody(), max_iterations);
2728   if (failed(failure_or_converged) || !failure_or_converged.getValue())
2729     return failure_or_converged;
2730 
2731   if (failed(context.InferShapeForFunctionReturnType(func))) return failure();
2732   func.setType(FunctionType::get(func.getContext(), new_arg_types,
2733                                  func.getFunctionType().getResults()));
2734 
2735   return true;
2736 }
2737 
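// Module-level driver: reads the TF graph producer version from the module,
// enqueues `main` (if present) followed by every func.func, and drains the work
// queue, re-running inference on callers that were re-enqueued by callee
// refinements. Processing is capped at 4x the initial queue size to bound
// pathological re-enqueue cycles. A minimal sketch of direct use, assuming it
// is called from inside a pass:
//
//   if (failed(InferModuleShape(module, /*max_iterations=*/10)))
//     return signalPassFailure();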
2738 FailureOr<bool> InferModuleShape(ModuleOp module, int64_t max_iterations) {
2739   auto producer_or = tensorflow::GetTfGraphProducerVersion(module);
2740   if (!producer_or.ok()) {
2741     // TODO(jpienaar): Keeping the existing behavior for now but this could
2742     // be relaxed.
2743     LLVM_DEBUG(llvm::dbgs()
2744                << "Skipping inference; " << producer_or.status().ToString());
2745     return true;
2746   }
2747   int64_t producer = producer_or.ValueOrDie();
2748   // TODO(jpienaar): Clean up propagate_NextIterationSinkOp_callee_constants if
2749   // it is no longer needed.
2750   ShapeInference context(producer, module,
2751                          /*propagate_caller_callee_constants=*/false);
2752   if (auto main = module.lookupSymbol<mlir::func::FuncOp>("main"))
2753     context.enqueue(main);
2754   for (auto func : module.getOps<func::FuncOp>()) context.enqueue(func);
2755   // Arbitrarily upper bound the maximum number of functions that get processed
2756   // just to avoid pathological cases.
2757   auto max_iteration = context.QueueSize() * 4;
2758   while (!context.EmptyQueue()) {
2759     func::FuncOp func = context.front();
2760     FailureOr<bool> failure_or_converged =
2761         InferShapeForFunction(context, func, max_iterations);
2762     if (failed(failure_or_converged) || !failure_or_converged.getValue())
2763       return failure_or_converged;
2764     context.pop_front();
2765 
2766     if ((--max_iteration) == 0) {
2767       emitWarning(UnknownLoc::get(module.getContext()))
2768           << "shape inference did not reach stable state after "
2769           << max_iteration << " iterations";
2770       return false;
2771     }
2772   }
2773   return true;
2774 }
2775 
2776 }  // namespace TF
2777 }  // namespace mlir
2778