1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/transforms/constant_folding/pass.h"
17 
18 #include <algorithm>
19 #include <iterator>
20 #include <numeric>
21 #include <string>
22 #include <tuple>
23 #include <type_traits>
24 #include <utility>
25 
26 #include "llvm/ADT/APInt.h"
27 #include "llvm/ADT/DenseSet.h"
28 #include "llvm/ADT/STLExtras.h"
29 #include "llvm/ADT/Sequence.h"
30 #include "llvm/ADT/Twine.h"
31 #include "mlir/Dialect/Traits.h"  // from @llvm-project
32 #include "mlir/IR/BuiltinAttributeInterfaces.h"  // from @llvm-project
33 #include "mlir/IR/PatternMatch.h"  // from @llvm-project
34 #include "mlir/Support/LLVM.h"  // from @llvm-project
35 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
36 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"  // from @llvm-project
37 #include "tensorflow/core/framework/resource_mgr.h"
38 #include "tensorflow/core/framework/types.h"
39 #include "tensorflow/core/ir/dialect.h"
40 #include "tensorflow/core/ir/importexport/convert_types.h"
41 #include "tensorflow/core/ir/utility.h"
42 #include "tensorflow/core/platform/logging.h"
43 #include "tensorflow/core/transforms/pass_detail.h"
44 #include "tensorflow/core/transforms/utils/eval_utils.h"
45 #include "tensorflow/core/transforms/utils/op_cat_helper.h"
46 #include "tensorflow/core/transforms/utils/utils.h"
47 #include "tensorflow/core/util/bcast.h"
48 #include "tensorflow/core/util/device_name_utils.h"
49 
50 namespace mlir {
51 namespace tfg {
52 
53 template <typename T>
54 static std::enable_if_t<std::is_integral<T>::value, ElementsAttr>
55 CreateElementsAttrOfTypeValues(Type element_type, ArrayRef<int64_t> shape,
56                                ArrayRef<T> values) {
57   auto tensor_shape = RankedTensorType::get(shape, element_type);
58   SmallVector<APInt> elements;
59   for (T v : values)
60     elements.push_back(APInt(element_type.getIntOrFloatBitWidth(), v));
61   auto const_attr = DenseElementsAttr::get(tensor_shape, elements);
62   return const_attr;
63 }
64 
65 template <typename T>
66 static std::enable_if_t<std::is_floating_point<T>::value, ElementsAttr>
67 CreateElementsAttrOfTypeValues(Type element_type, ArrayRef<int64_t> shape,
68                                ArrayRef<T> values) {
69   auto tensor_shape = RankedTensorType::get(shape, element_type);
70   SmallVector<APFloat> elements;
71   if (element_type.getIntOrFloatBitWidth() == 32)
72     llvm::for_each(values, [&](float v) { elements.push_back(APFloat(v)); });
73   else
74     llvm::for_each(values, [&](double v) { elements.push_back(APFloat(v)); });
75   auto const_attr = DenseElementsAttr::get(tensor_shape, elements);
76   return const_attr;
77 }
78 
79 static ElementsAttr CreateElementsAttrOfTypeValues(Type element_type,
80                                                    ArrayRef<int64_t> shape,
81                                                    ElementsAttr value_attr) {
82   auto tensor_shape = RankedTensorType::get(shape, element_type);
83   DenseElementsAttr const_attr;
84   if (element_type.isIntOrIndex()) {
85     const_attr = DenseElementsAttr::get(
86         tensor_shape, llvm::to_vector(value_attr.getValues<APInt>()));
87   } else {
88     const_attr = DenseElementsAttr::get(
89         tensor_shape, llvm::to_vector(value_attr.getValues<APFloat>()));
90   }
91   return const_attr;
92 }
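// Illustrative usage of the helpers above (a hedged sketch, not part of the
// pass; the builder and values are hypothetical): for an i32 element type,
// shape {3} and values {1, 2, 3},
//   ElementsAttr attr = CreateElementsAttrOfTypeValues(
//       builder.getI32Type(), /*shape=*/{3}, ArrayRef<int>({1, 2, 3}));
// would produce a dense attribute equivalent to dense<[1, 2, 3]> : tensor<3xi32>.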
93 
94 static ElementsAttr ConvertShapeToAttr(ShapedType shape) {
95   return CreateElementsAttrOfTypeValues(
96       IntegerType::get(shape.getContext(), 32), {shape.getRank()},
97       shape.getShape());
98 }
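// For example (hedged), given a ranked input type tensor<2x3xf32>,
// ConvertShapeToAttr would return an attribute equivalent to
// dense<[2, 3]> : tensor<2xi32>.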
99 
100 static Type GetDataTypeFromOp(OpBuilder &builder, Operation *op) {
101   if (auto t_attr = op->getAttrOfType<TypeAttr>("T")) {
102     return t_attr.getValue();
103   } else if (auto dtype_attr = op->getAttrOfType<TypeAttr>("dtype")) {
104     return dtype_attr.getValue();
105   } else if (op->getName().stripDialect() == "LogicalOr" ||
106              op->getName().stripDialect() == "LogicalAnd") {
107     return builder.getI1Type();
108   }
109   return *(op->result_type_begin());
110 }
111 
112 static FailureOr<TFOp> CreateConstantTensorOp(
113     OpBuilder &builder, Location loc, StringRef name_prefix, Type type,
114     ValueRange control_operands, TypedAttr tensor_value,
115     ArrayRef<NamedAttribute> other_attrs = llvm::None) {
116   if (type.isa<VariantType>()) return failure();
117   // TODO(chiahungduan): Reuse ConstOp Like
118   // OperationFolder::tryGetOrCreateConstant.
119   OperationState state(loc, "tfg.Const");
120   state.addTypes({type, ControlType::get(builder.getContext())});
121 
122   state.attributes = other_attrs;
123   util::EraseRegularNodeAttributes(state.attributes);
124   state.attributes.set(
125       "dtype", TypeAttr::get(
126                    tensor_value.getType().cast<ShapedType>().getElementType()));
127   state.attributes.set("value", tensor_value);
128   if (!name_prefix.empty()) {
129     state.attributes.set(
130         TFGraphDialect::getNameAttrKey(),
131         builder.getStringAttr(Twine(name_prefix, "/const_folded")));
132   }
133 
134   state.addOperands(control_operands);
135   return TFOp(builder.create(state));
136 }
137 
138 static bool IsControlAnchor(TFOp op, TFGraphDialect const *const dialect) {
139   return (dialect->IsIdentity(op) || dialect->IsIdentityNSingleInput(op)) &&
140          op->getResults().drop_back().use_empty();
141 }
142 
143 // We can't anchor control dependencies directly on the switch node: unlike
144 // other nodes, only one of the outputs of the switch node will be generated
145 // when the switch node is executed, and we need to make sure the control
146 // dependency is only triggered when the corresponding output is triggered.
147 // We start by looking for an identity node connected to the output of the
148 // switch node, and use it to anchor the control dependency.
149 // @param builder Builder, used for creating the anchor if necessary
150 // @param value   Output of a switch operation to be replaced
151 // @param dialect TFG dialect (passed in to avoid cost of looking it up)
152 static TFOp GetControlAnchorForSwitchResult(
153     OpBuilder &builder, OpResult value, TFGraphDialect const *const dialect) {
154   assert(builder.getContext()->getLoadedDialect<TFGraphDialect>() == dialect);
155   TFOp switch_op = value.getDefiningOp();
156   assert(dialect->IsSwitch(switch_op));
157   // We cannot get the control edge from the parent op. We instead create a
158   // control anchor i.e. an Identity op without non-control uses and get the
159   // edge from there.
160 
161   // Try to find an existing control anchor
162   if (auto it = llvm::find_if(
163           value.getUsers(),
164           [&](Operation *op) { return IsControlAnchor(op, dialect); });
165       it != value.getUsers().end())
166     return TFOp(*it);
167 
168   // If it doesn't exist, create a new control anchor.
169   OperationState identity_op_state(value.getLoc(), "tfg.Identity");
170   identity_op_state.addOperands(value);
171   identity_op_state.addTypes(
172       {value.getType(), ControlType::get(builder.getContext())});
173   assert(switch_op->hasAttr("T"));
174   identity_op_state.addAttribute("T", switch_op->getAttr("T"));
175   TFOp identity_op = builder.create(identity_op_state);
176   if (StringAttr device_attr = switch_op.deviceAttr())
177     identity_op.setRequestedDevice(device_attr);
178   identity_op.setName(Twine(switch_op.name(), "/ControlDependencyCtrl_") +
179                       Twine(value.cast<OpResult>().getResultNumber()));
180   return identity_op;
181 }
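// Illustrative sketch (hypothetical names, not actual IR from this file): given
//   %true, %false, %ctl = tfg.Switch(%data, %pred)
// a control dependency on the %true output is anchored on an Identity op named
// roughly "<switch_name>/ControlDependencyCtrl_0":
//   %id, %ctl_id = tfg.Identity(%true)
// and consumers that only need the control edge use %ctl_id rather than %ctl,
// so they are triggered only when the corresponding branch is taken.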
182 
183 // Same as LookupControlDependency, except when value originates from a switch
184 // op. In such cases, we cannot add a control dependency to the parent op since
185 // the output does not necessarily activate when the switch op activates. We
186 // add a "control anchor" in the form of an identity op instead.
187 static Value GetControlDependency(OpBuilder &builder, Value value) {
188   if (value.getType().isa<ControlType>()) return value;
189 
190   TFGraphDialect *dialect =
191       builder.getContext()->getLoadedDialect<TFGraphDialect>();
192   assert(dialect);
193   if (OpResult result = value.dyn_cast<OpResult>();
194       result && dialect->IsSwitch(result.getOwner())) {
195     return GetControlAnchorForSwitchResult(builder, result, dialect)
196         .controlRet();
197   } else {
198     return LookupControlDependency(value);
199   }
200 }
201 
202 // Add control operand to `op` if it doesn't exist.
203 static void AddControlOperand(Operation *op, Value control,
204                               PatternRewriter &rewriter) {
205   assert(control.getType().isa<ControlType>());
206   if (llvm::is_contained(op->getOperands(), control)) return;
207   rewriter.startRootUpdate(op);
208   op->insertOperands(op->getNumOperands(), control);
209   rewriter.finalizeRootUpdate(op);
210 }
211 
212 static FailureOr<TFOp> ReplaceOpWithConstantTensor(
213     OpBuilder &builder, TFOp op, ElementsAttr value,
214     ArrayRef<StringRef> exclude_attrs = llvm::None) {
215   // The new const op carries control dependencies on op's non-control operands.
216   SmallVector<Value> operands_controls;
217   llvm::append_range(operands_controls,
218                      OperandControlRetRange(op.getNonControlOperands()));
219 
220   NamedAttrList attr_list;
221   for (NamedAttribute attr : op->getAttrs()) {
222     if (llvm::find_if(exclude_attrs, [&](StringRef name) {
223           return name == attr.getName();
224         }) != exclude_attrs.end()) continue;
225     attr_list.append(attr);
226   }
227   FailureOr<TFOp> const_op = CreateConstantTensorOp(
228       builder, op->getLoc(), /*name_prefix=*/"", value.getType(),
229       operands_controls, value, attr_list);
230   (*const_op).setName(op.nameAttr());
231   if (!op.device().empty()) (*const_op).setRequestedDevice(op.deviceAttr());
232   return *const_op;
233 }
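// Illustrative sketch (hypothetical names): replacing
//   %y, %ctl = tfg.Sqrt(%x)
// with a constant value yields roughly
//   %y, %ctl = tfg.Const [%x.ctl] {dtype = ..., value = ...}
// i.e. the original data inputs survive as control dependencies so that the
// constant still runs after its former producers.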
234 
235 static FailureOr<TFOp> ReplaceOpWithIdentity(OpBuilder &builder, TFOp owner,
236                                              unsigned idx) {
237   OperationState state(owner->getLoc(), "tfg.Identity");
238   state.addTypes({owner->getOperand(idx).getType(),
239                   ControlType::get(builder.getContext())});
240   state.addAttribute(
241       "T", TypeAttr::get(GetDataTypeFromOp(builder, owner.getOperation())));
242 
243   Value kept_value = owner->getOperand(idx);
244   state.addOperands(kept_value);
245   auto [non_control_operands, control_operands] = owner.splitOperands();
246   for (Value value : non_control_operands) {
247     if (value != kept_value)
248       state.addOperands(GetControlDependency(builder, value));
249   }
250   state.addOperands(control_operands);
251 
252   Operation *identity_op = builder.create(state);
253   TFOp(identity_op).setName(owner.nameAttr());
254   if (!owner.device().empty())
255     TFOp(identity_op).setRequestedDevice(owner.deviceAttr());
256   return TFOp(identity_op);
257 }
258 
259 static FailureOr<TFOp> ReplaceOperationWithConstant(OpBuilder &builder,
260                                                     Operation *op,
261                                                     double constant_value) {
262   auto res = (*op->result_type_begin()).cast<ShapedType>();
263   Type dtype = GetDataTypeFromOp(builder, op);
264   Attribute value_attr;
265   if (dtype.isIntOrIndex())
266     value_attr = builder.getIntegerAttr(dtype, constant_value);
267   else
268     value_attr = builder.getFloatAttr(dtype, constant_value);
269 
270   auto const_attr = SplatElementsAttr::get(
271       RankedTensorType::get(res.getShape(), dtype), value_attr);
272   return ReplaceOpWithConstantTensor(builder, op, const_attr);
273 }
274 
275 static FailureOr<TFOp> ReplaceOperationWithSnapshot(OpBuilder &builder, TFOp op,
276                                                     int idx) {
277   // TODO(chiahungduan): If the graph contains no ops that mutate their
278   // inputs, we can use Identity instead of Snapshot.
279   // if (!graph_contains_assign_or_inplace_op_)
280   auto [non_control_operands, control_operands] = op.splitOperands();
281 
282   Value replace_value = op->getOperand(idx);
283   OperationState state(op->getLoc(), "tfg.Snapshot");
284   state.attributes = op->getAttrDictionary();
285   util::EraseRegularNodeAttributes(state.attributes);
286   state.addAttribute(
287       "T", TypeAttr::get(GetDataTypeFromOp(builder, op.getOperation())));
288   // Propagate the designated input through the Snapshot.
289   state.addOperands(replace_value);
290   // Add all other inputs as control dependencies.
291   llvm::append_range(state.operands,
292                      OperandControlRetRange(non_control_operands));
293   // Append the control operands
294   state.addOperands(control_operands);
295   state.addTypes(op->getResultTypes());
296 
297   Operation *snapshot_op = builder.create(state);
298   TFOp(snapshot_op).setName(op.nameAttr());
299   if (!op.device().empty())
300     TFOp(snapshot_op).setRequestedDevice(op.deviceAttr());
301   return TFOp(snapshot_op);
302 }
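// Illustrative sketch (hypothetical names): keeping operand 0 of
//   %z, %ctl = tfg.Add(%x, %zeros)
// as the pass-through value yields roughly
//   %z, %ctl = tfg.Snapshot(%x) [%x.ctl, %zeros.ctl]
// so the designated input flows through while the original inputs are retained
// as control dependencies, preserving execution ordering.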
303 
304 static FailureOr<TFOp> ReplaceOperationWithBroadcastTo(OpBuilder &builder,
305                                                        TFOp op,
306                                                        int idx_to_replace) {
307   ShapedType tensor_type = (*op->result_type_begin()).cast<ShapedType>();
308   if (!tensor_type.hasStaticShape()) return failure();
309   ElementsAttr const_attr = ConvertShapeToAttr(tensor_type);
310 
311   // Create a vector of control operands. We should not fail beyond this point
312   // since GetControlDependency may create a control anchor (a new op).
313   SmallVector<Value> control_operands;
314   for (auto &it : llvm::enumerate(op.getNonControlOperands())) {
315     int idx = it.index();
316     Value v = it.value();
317     if (idx == idx_to_replace) continue;
318     if (llvm::is_contained(control_operands, v)) continue;
319     control_operands.push_back(GetControlDependency(builder, v));
320   }
321   // CreateConstantTensorOp cannot fail; it only fails for variant types and
322   // const_attr is a tensor of i32.
323   TFOp const_op = *CreateConstantTensorOp(
324       builder, op->getLoc(),
325       (Twine(op.name(), "/broadcastto_shape_") + std::to_string(idx_to_replace))
326           .str(),
327       const_attr.getType(), control_operands, const_attr);
328   if (!op.device().empty()) const_op.setRequestedDevice(op.device());
329 
330   OperationState state(op->getLoc(), "tfg.BroadcastTo");
331 
332   state.attributes = op->getAttrDictionary();
333   util::EraseRegularNodeAttributes(state.attributes);
334   state.addAttribute(
335       "T", TypeAttr::get(GetDataTypeFromOp(builder, op.getOperation())));
336   state.addAttribute("Tidx", TypeAttr::get(builder.getI32Type()));
337 
338   state.addOperands({op->getOperand(idx_to_replace), const_op->getResult(0)});
339   state.addOperands(control_operands);
340   state.addTypes(op->getResultTypes());
341 
342   Operation *broadcast_to_op = builder.create(state);
343   TFOp(broadcast_to_op).setName(op.nameAttr());
344   if (!op.device().empty())
345     TFOp(broadcast_to_op).setRequestedDevice(op.deviceAttr());
346   return TFOp(broadcast_to_op);
347 }
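// Illustrative sketch (hypothetical names): keeping operand 0 of
//   %y, %ctl = tfg.Add(%x, %zeros) : (...) -> tensor<2x3xf32>
// and broadcasting it to the statically known result shape yields roughly
//   %shape, %sctl = tfg.Const {value = dense<[2, 3]> : tensor<2xi32>}
//   %y, %ctl = tfg.BroadcastTo(%x, %shape) [%zeros.ctl]
// with the other inputs demoted to control dependencies.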
348 
349 namespace {
350 // A helper class to see if an operation falls into a certain category or has
351 // certain non-trivial properties.
352 class OpPropertyHelper : public OpCatHelper {
353  public:
354   OpPropertyHelper(TFGraphDialect *dialect,
355                    bool disable_compressed_tensor_optimization)
356       : OpCatHelper(dialect),
357         dialect_(dialect),
358         disable_compressed_tensor_optimization_(
359             disable_compressed_tensor_optimization) {}
360 
361   // Return true if the operation modifies the input in-place.
362   bool ModifiesInputsInPlace(TFOp op);
363 
364   // Return true if this operation doesn't have any side effect.
365   bool IsFreeOfSideEffect(TFOp op);
366 
367   // Return true if an operation may modify the frame info.
368   bool ModifiesFrameInfo(TFOp op) {
369     return dialect_->IsEnter(op) || dialect_->IsExit(op) ||
370            dialect_->IsNextIteration(op);
371   }
372 
373   // This combines the results of both MaybeFoldable() and IsFoldableUncached()
374   bool IsFoldable(TFOp op);
375 
376   // Return if this is a preserved op. It checks the `name` attr.
377   bool ShouldPreserveOp(TFOp op);
378 
379   // Disable compressed tensor optimization.
380   bool DisableCompressedTensorOptimization();
381 
382   // Get the TFG dialect instance.
383   TFGraphDialect *getDialect() { return dialect_; }
384 
385  private:
386   // Return true if this operation is safe to be folded. This filters the ops
387   // by name.
388   bool MaybeFoldable(TFOp op);
389 
390   // Return true if this operation is safe to be folded. This filters the ops
391   // by operation properties, e.g. it checks the operands, attributes, etc.
392   bool IsFoldableUncached(TFOp op);
393 
394   // A reference to the TFG dialect.
395   TFGraphDialect *dialect_;
396 
397   // Indicates whether compressed tensor optimization has been disabled.
398   bool disable_compressed_tensor_optimization_;
399 
400   // We only fold/materialize constants smaller than 100kB.
401   static constexpr int64_t kMaxConstantSize = 100 * 1024;
402 };
403 }  // namespace
404 
405 bool OpPropertyHelper::ModifiesInputsInPlace(TFOp op) {
406   StringRef op_name = op->getName().stripDialect();
407 
408   // Ops that modify resource variables effectively modify one of their inputs.
409   if (op_name == "AssignVariableOp" || op_name == "AssignAddVariableOp" ||
410       op_name == "AssignSubVariableOp" || op_name == "ResourceScatterUpdate" ||
411       op_name == "ResourceScatterAdd" || op_name == "ResourceScatterSub" ||
412       op_name == "ResourceScatterMul" || op_name == "ResourceScatterDiv" ||
413       op_name == "ResourceScatterMin" || op_name == "ResourceScatterMax") {
414     return false;
415   }
416 
417   std::string lower_op_name = op_name.str();
418   std::transform(lower_op_name.begin(), lower_op_name.end(),
419                  lower_op_name.begin(), ::tolower);
420   if (absl::StrContains(lower_op_name, "inplace")) return true;
421 
422   return op->hasAttr("in_place") || op->hasAttr("inplace");
423 }
424 
425 bool OpPropertyHelper::IsFreeOfSideEffect(TFOp op) {
426   tensorflow::OpRegistry *op_registry = tensorflow::OpRegistry::Global();
427   const tensorflow::OpDef *op_def;
428   tensorflow::Status status =
429       op_registry->LookUpOpDef(op->getName().stripDialect().str(), &op_def);
430   if (!status.ok()) return false;
431 
432   if (op_def->is_stateful()) return false;
433 
434   for (const auto &input : op_def->input_arg())
435     if (input.is_ref()) return false;
436 
437   if (dialect_->IsQueue(op)) return false;
438 
439   if (dialect_->IsSend(op)) return false;
440 
441   return !ModifiesInputsInPlace(op);
442 }
443 
444 // Determines whether we want to evaluate the value of the operation. There are
445 // several kinds of operations we don't want to evaluate with the eager
446 // runtime. Those operations may not be safe to evaluate or may not be worth
447 // evaluating because of the cost. For example, a Const op already has the
448 // constant value attached as an attribute.
449 bool OpPropertyHelper::MaybeFoldable(TFOp op) {
450   StringRef op_name = op->getName().stripDialect();
451 
452   if (dialect_->IsConstant(op)) return false;
453 
454   // Don't fold stateful ops such as TruncatedNormal.
455   if (!IsFreeOfSideEffect(op)) return false;
456 
457   // TODO(chiahungduan): Handle preserve nodes
458 
459   // Skips ops that don't benefit from folding.
460   if (dialect_->IsPlaceholder(op)) return false;
461 
462   if (dialect_->IsFakeParam(op)) return false;
463 
464   // Skip certain control flow nodes, they can't be folded.
465   if (ModifiesFrameInfo(op)) return false;
466 
467   if (op_name == "AccumulateNV2") return false;
468 
469   // Removing LoopCond nodes can screw up the partitioner.
470   if (op_name == "LoopCond") return false;
471 
472   // TODO(chiahungduan): add fold_quantization_emulation arg.
473   // if (!fold_quantization_emulation && IsQuantizationEmulation(op)) return
474   // false;
475 
476   if (dialect_->IsRestore(op) || op_name.contains("Save") ||
477       op_name.contains("Reader"))
478     return false;
479 
480   if (op_name.contains("Quantized") ||
481       absl::StartsWith(op_name.data(), "Sparse"))
482     return false;
483 
484   // Don't fold nodes that contain TPU attributes.
485   // TODO(rmlarsen): We should be able to fold many of these nodes as long as we
486   // properly forward custom attributes, b/119051778.
487   for (NamedAttribute attr : op->getAttrs())
488     if (attr.getName().strref().find("_tpu_") != StringRef::npos) return false;
489 
490   // Don't fold ops without outputs. Note that almost all tfg ops have an
491   // additional control output value.
492   if (op->getNumResults() <= 1) return false;
493 
494   const tensorflow::OpDef *op_def = nullptr;
495   tensorflow::Status status = tensorflow::OpRegistry::Global()->LookUpOpDef(
496       op->getName().stripDialect().str(), &op_def);
497   if (!status.ok()) {
498     return false;
499   }
500   // Don't fold ops without outputs.
501   if (op_def->output_arg_size() == 0) {
502     return false;
503   }
504 
505   // Don't fold DT_VARIANT outputs as this can cause problems with XLA compile.
506   // TODO(rmlarsen): Only do this for XLA_* devices.
507   for (const tensorflow::OpDef::ArgDef &output_arg : op_def->output_arg()) {
508     if (output_arg.type() == tensorflow::DT_VARIANT) {
509       return false;
510     }
511   }
512 
513   // Don't fold nodes that have no outgoing edges except allowlisted nodes.
514   // Such nodes could be introduced by an earlier constant folding pass and are
515   // preserved in case users want to fetch their values; re-processing them
516   // would lead to an error of adding a duplicated node to the graph.
517   // TODO(chiahungduan): An op that has no users and isn't in nodes_allowlist_
518   // can't be folded.
519   return true;
520 }
521 
522 bool OpPropertyHelper::IsFoldableUncached(TFOp op) {
523   ValueRange operands = op.getNonControlOperands();
524   if (operands.empty()) return false;
525 
526   // We can only fold nodes if all their inputs are known statically, except in
527   // the case of a merge node that propagates the first input that becomes
528   // available, and therefore only requires a single constant input to be
529   // foldable.
530   bool merge_has_constant_input = false;
531   bool is_merge = dialect_->IsMerge(op);
532   for (Value operand : operands) {
533     TFOp operand_op = operand.getDefiningOp();
534     if (operand_op && dialect_->IsConstant(operand_op)) {
535       auto dtype = operand_op->getAttrOfType<TypeAttr>("dtype");
536       if (!dtype || dtype.getValue().isa<tf_type::StringType>()) return false;
537 
538       // Special case: If a Merge node has at least one constant input that
539       // does not depend on a control input, we can fold it.
540       merge_has_constant_input |= operand_op.getControlOperands().empty();
541     } else if (!is_merge) {
542       return false;
543     }
544   }
545 
546   if (is_merge && !merge_has_constant_input) return false;
547   if (DisableCompressedTensorOptimization() &&
548       (dialect_->IsFill(op) || dialect_->IsZerosLike(op) ||
549        dialect_->IsOnesLike(op))) {
550     return false;
551   }
552 
553   // If we know the output shapes, make sure that the outputs are small enough
554   // to materialize.
555   int64_t input_size_bytes = 0;
556   for (Value operand : operands) {
557     auto shape = operand.getType().dyn_cast<ShapedType>();
558     if (!shape || !shape.hasStaticShape()) continue;
559     auto element_type = shape.getElementType();
560 
561     tensorflow::DataType dtype;
562     if (!ConvertScalarTypeToDataType(element_type, &dtype).ok()) return false;
563     input_size_bytes += shape.getNumElements() * DataTypeSize(dtype);
564   }
565   for (Value res : op->getResults().drop_back()) {
566     auto shape = res.getType().dyn_cast<ShapedType>();
567     if (!shape || !shape.hasStaticShape()) continue;
568     auto element_type = shape.getElementType();
569 
570     tensorflow::DataType dtype;
571     if (!ConvertScalarTypeToDataType(element_type, &dtype).ok()) return false;
572     int64_t num_bytes = shape.getNumElements() * DataTypeSize(dtype);
573     if (num_bytes > input_size_bytes && num_bytes > kMaxConstantSize)
574       return false;
575   }
576 
577   return true;
578 }
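// A small worked example of the size limit above (hedged, using assumed dtype
// sizes): with kMaxConstantSize = 100 * 1024 bytes, a float32 output (4 bytes
// per element) is only materialized when it has at most 25,600 elements, unless
// the combined size of the statically known inputs is at least as large as the
// output, in which case folding does not grow the graph.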
579 
580 bool OpPropertyHelper::IsFoldable(TFOp op) {
581   // TODO(chiahungduan): Cache foldable ops
582   if (!MaybeFoldable(op)) return false;
583   return IsFoldableUncached(op);
584 }
585 
586 bool OpPropertyHelper::ShouldPreserveOp(TFOp op) {
587   // TODO(tlongeri): Find a better way to identify preserved ops. A node has its
588   // control output returned if it is a node-to-be-preserved (in
589   // LiftGraphToFunc) - *not* iff, so the following check is overly broad:
590   return llvm::any_of(op.controlRet().getUsers(), [&](TFOp child_op) {
591     return dialect_->IsReturn(child_op);
592   });
593 }
594 
595 bool OpPropertyHelper::DisableCompressedTensorOptimization() {
596   return disable_compressed_tensor_optimization_;
597 }
598 
599 static bool IsValidConstShapeForMulConvPushDown(StringAttr data_format,
600                                                 ShapedType filter_shape,
601                                                 ShapedType const_shape) {
602   if (!filter_shape.hasStaticShape() || !const_shape.hasStaticShape())
603     return false;
604   if (const_shape.getRank() <= data_format.size() &&
605       const_shape.getNumElements() == 1) {
606     return true;
607   }
608   if (data_format == "NHWC" || data_format == "NDHWC") {
609     SmallVector<int64_t> broadcast_shape;
610     if (!OpTrait::util::getBroadcastedShape(
611             filter_shape.getShape(), const_shape.getShape(), broadcast_shape)) {
612       return false;
613     }
614 
615     // TODO(chiahungduan): Symbolic shape equivalence is acceptable.
616     if (filter_shape.getShape() != llvm::makeArrayRef(broadcast_shape))
617       return false;
618 
619     // Only the last dimension could be larger than one, since broadcasting over
620     // the last dimension (the output channel) will result in an invalid filter.
621     for (int dim_size : const_shape.getShape())
622       if (dim_size > 1) return false;
623     return true;
624   } else if (data_format == "NCHW" || data_format == "NCDHW") {
625     // TODO(laigd): support NCHW and NCDHW (b/111214513).
626     return false;
627   }
628   return false;
629 }
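// For example (a hedged reading of the checks above): with data_format "NHWC",
// a filter of shape [3, 3, 8, 16] and a single-element constant of rank <= 4
// (e.g. shape [1, 1, 1, 1]) is accepted, while a constant with any dimension
// larger than one is rejected by this implementation.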
630 
631 namespace {
632 template <typename ConcreteType, template <typename> class... Traits>
633 class ConstantPatternBase : public RewritePattern,
634                             public Traits<ConcreteType>... {
635  public:
636   using RewritePattern::RewritePattern;
637 
638   ConstantPatternBase(StringRef opName, OpPropertyHelper &helper)
639       : RewritePattern(opName, PatternBenefit(1),
640                        helper.getDialect()->getContext()),
641         helper_(helper),
642         dialect_(helper.getDialect()) {}
643   ConstantPatternBase(MatchAnyOpTypeTag tag, OpPropertyHelper &helper)
644       : RewritePattern(tag, PatternBenefit(1),
645                        helper.getDialect()->getContext()),
646         helper_(helper),
647         dialect_(helper.getDialect()) {}
648 
649  protected:
650   OpPropertyHelper &helper_;
651   TFGraphDialect *dialect_;
652 };
653 
654 // A base trait which can help with classifying patterns and filtering patterns
655 // according to the classification.
656 template <typename ConcreteType>
657 struct TraitBase {
658   ConcreteType *getPattern() { return static_cast<ConcreteType *>(this); }
659 };
660 
661 // A trait indicating that the pattern will fold the root operation into
662 // another operation, such as a constant op.
663 template <typename ConcreteType>
664 struct FolderTrait : public TraitBase<ConcreteType> {};
665 
666 // A trait indicating that the pattern may propagate the constant operands to
667 // its users.
668 template <typename ConcreteType>
669 struct PropagationTrait : public TraitBase<ConcreteType> {};
670 
671 template <typename ConcreteType>
672 using FolderPatternBase = ConstantPatternBase<ConcreteType, FolderTrait>;
673 
674 template <typename ConcreteType>
675 using PropagationPatternBase =
676     ConstantPatternBase<ConcreteType, PropagationTrait>;
677 }  // namespace
678 
679 // EvaluateConstant maps to the implementation of ConstantFolding::FoldGraph
680 // in grappler/optimizers/constant_folding.cc
681 class EvaluateConstant : public FolderPatternBase<EvaluateConstant> {
682  public:
683   explicit EvaluateConstant(OpPropertyHelper &helper)
684       : FolderPatternBase<EvaluateConstant>(MatchAnyOpTypeTag(), helper),
685         has_folded_(BoolAttr::get(helper.getDialect()->getContext(), true)),
686         folded_attr_name_(
687             StringAttr::get(helper.getDialect()->getContext(), "has_folded")),
688         cpu_device_(std::make_unique<util::SimpleDevice>()),
689         resource_mgr_(std::make_unique<tensorflow::ResourceMgr>()) {}
690 
691   LogicalResult matchAndRewrite(Operation *op,
692                                 PatternRewriter &rewriter) const override {
693     if (!helper_.IsFoldable(op)) return failure();
694 
695     // TODO(chiahungduan): Switch folding needs to delete dead values.
696     if (dialect_->IsSwitch(op)) return failure();
697 
698     // The op may have already been folded: when it has multiple results we
699     // replace their uses with constant ops, and its control edges prevent it
700     // from being removed. Use the attr to avoid evaluating it again.
701     if (op->hasAttr(folded_attr_name_)) return failure();
702 
703     // If the op has no users, don't invoke the eager runtime.
704     if (op->getNumResults() > 2 &&
705         llvm::all_of(op->getResults().drop_back(),
706                      [](Value v) { return v.use_empty(); })) {
707       return failure();
708     }
709 
710     SmallVector<ElementsAttr> const_operands;
711     for (Value operand : TFOp(op).getNonControlOperands()) {
712       Operation *defining_op = operand.getDefiningOp();
713       if (defining_op && dialect_->IsConstant(defining_op)) {
714         const_operands.push_back(
715             defining_op->getAttrOfType<ElementsAttr>("value"));
716       } else {
717         return failure();
718       }
719     }
720 
721     SmallVector<TypedAttr> result;
722     if (failed(util::EvaluateOperation(cpu_device_.get(), resource_mgr_.get(),
723                                        op, const_operands, result))) {
724       return failure();
725     }
726 
727     StringAttr name_attr = static_cast<TFGraphDialect *>(op->getDialect())
728                                ->getNameAttrIdentifier();
729     SmallVector<Value> control_operands(
730         OperandControlRetRange(op->getOperands()));
731 
732     StringAttr device_attr = TFOp(op).deviceAttr();
733     SmallVector<TFOp> const_ops;
734     for (auto &it : llvm::enumerate(result)) {
735       TypedAttr attr = it.value();
736       FailureOr<TFOp> const_op = CreateConstantTensorOp(
737           rewriter, op->getLoc(),
738           (Twine(TFOp(op).name(), "/eval_") + Twine(it.index())).str(),
739           attr.getType().cast<ShapedType>(), control_operands, attr,
740           NamedAttribute(name_attr, TFOp(op).nameAttr()));
741       if (failed(const_op)) return failure();
742       if (device_attr) (*const_op).setRequestedDevice(device_attr);
743       const_ops.emplace_back(*const_op);
744     }
745 
746     // If there is a single output, just replace the op.
747     if (const_ops.size() == 1) {
748       // Use the same node name for the replacement. Note that even if this op
749       // is not in nodes_to_preserve, certain cases may still expect the op to
750       // have the same name after folding.
751       const_ops[0].setName(TFOp(op).nameAttr());
752       rewriter.replaceOp(op, const_ops[0]->getResults());
753     } else {
754       for (auto &it : llvm::enumerate(const_ops)) {
755         for (OpOperand &user :
756              llvm::make_early_inc_range(op->getResult(it.index()).getUses())) {
757           rewriter.startRootUpdate(user.getOwner());
758           user.set(it.value()->getResult(0));
759           rewriter.finalizeRootUpdate(user.getOwner());
760         }
761       }
762 
763       // Now that all the non-control outputs are replaced with constant ops,
764       // remove the op if its control output has no users either.
765       if (TFOp(op).controlRet().use_empty()) {
766         rewriter.eraseOp(op);
767       } else {
768         // We can't remove it directly. To avoid folding it again, add an attr
769         // to identify these ops. This attr will be removed at the end of the
770         // constant folding pass.
771         op->setAttr(folded_attr_name_, has_folded_);
772       }
773     }
774 
775     return success();
776   }
777 
778  private:
779   BoolAttr has_folded_;
780   StringAttr folded_attr_name_;
781   std::unique_ptr<util::SimpleDevice> cpu_device_;
782   std::unique_ptr<tensorflow::ResourceMgr> resource_mgr_;
783 };
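// Illustrative sketch of EvaluateConstant (hypothetical names): with
//   %c0, %ctl0 = tfg.Const {value = dense<2> : tensor<i32>}
//   %c1, %ctl1 = tfg.Const {value = dense<3> : tensor<i32>}
//   %sum, %ctl = tfg.Add(%c0, %c1)
// the Add is evaluated eagerly and its non-control uses are redirected to a new
//   tfg.Const {value = dense<5> : tensor<i32>}
// carrying control dependencies on %ctl0 and %ctl1; the Add itself is erased
// only if its control output has no remaining users.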
784 
785 // This implementation is mapped to the ShapeOp materialization in
786 // ConstantFolding::MaterializeShapes in grappler/optimizers/constant_folding.cc
787 class MaterializeShapeOp : public FolderPatternBase<MaterializeShapeOp> {
788  public:
789   explicit MaterializeShapeOp(OpPropertyHelper &helper)
790       : FolderPatternBase<MaterializeShapeOp>("tfg.Shape", helper) {}
791 
792   LogicalResult matchAndRewrite(Operation *op,
793                                 PatternRewriter &rewriter) const override {
794     Value input = op->getOperand(0);
795 
796     auto input_shape = input.getType().cast<ShapedType>();
797     if (!input_shape.hasStaticShape()) return failure();
798 
799     // TODO(rmlarsen): Remove this workaround for b/150861569
800     // The bug involves an expression of the form Shape(ExpandDims(x))
801     // with an incorrectly inferred zero-size first dimension.
802     if (!input_shape.getShape().empty() && input_shape.getShape()[0] == 0)
803       return failure();
804 
805     Type output_dtype =
806         op->getResult(0).getType().cast<ShapedType>().getElementType();
807     ElementsAttr const_attr = CreateElementsAttrOfTypeValues(
808         output_dtype, {input_shape.getRank()}, input_shape.getShape());
809 
810     // Add the control edge to `input` to ensure that the constant value will
811     // only be run in the cases where Shape would have been run in the original
812     // graph.
813     TFOp const_op = *CreateConstantTensorOp(
814         rewriter, op->getLoc(), /*name_prefix=*/"", const_attr.getType(),
815         GetControlDependency(rewriter, input), const_attr, op->getAttrs());
816     const_op.setName(TFOp(op).nameAttr());
817 
818     rewriter.replaceOp(op, const_op->getResults());
819 
820     return success();
821   }
822 };
823 
824 // This implementation is mapped to the SizeOp materialization in
825 // ConstantFolding::MaterializeShapes in grappler/optimizers/constant_folding.cc
826 class MaterializeSizeOp : public FolderPatternBase<MaterializeSizeOp> {
827  public:
828   explicit MaterializeSizeOp(OpPropertyHelper &helper)
829       : FolderPatternBase<MaterializeSizeOp>("tfg.Size", helper) {}
830 
831   LogicalResult matchAndRewrite(Operation *op,
832                                 PatternRewriter &rewriter) const override {
833     Value input = op->getOperand(0);
834 
835     auto input_shape = input.getType().cast<ShapedType>();
836     if (!input_shape.hasStaticShape()) return failure();
837 
838     ShapedType result_type = (*op->result_type_begin()).cast<ShapedType>();
839     if (!result_type.getElementType().isIntOrIndexOrFloat()) return failure();
840 
841     ElementsAttr const_attr = CreateElementsAttrOfTypeValues(
842         result_type.getElementType(), {},
843         ArrayRef<int64_t>(input_shape.getNumElements()));
844 
845     // Add the control edge to `input` to ensure that the constant value will
846     // only be run in the cases where Size would have been run in the original
847     // graph.
848     TFOp const_op = *CreateConstantTensorOp(
849         rewriter, op->getLoc(), /*name_prefix=*/"", const_attr.getType(),
850         GetControlDependency(rewriter, input), const_attr, op->getAttrs());
851     const_op.setName(TFOp(op).nameAttr());
852 
853     rewriter.replaceOp(op, const_op->getResults());
854 
855     return success();
856   }
857 };
858 
859 // This implementation is mapped to the RankOp materialization in
860 // ConstantFolding::MaterializeShapes in grappler/optimizers/constant_folding.cc
861 class MaterializeRankOp : public FolderPatternBase<MaterializeRankOp> {
862  public:
863   explicit MaterializeRankOp(OpPropertyHelper &helper)
864       : FolderPatternBase<MaterializeRankOp>("tfg.Rank", helper) {}
865 
866   LogicalResult matchAndRewrite(Operation *op,
867                                 PatternRewriter &rewriter) const override {
868     Value input = op->getOperand(0);
869 
870     auto input_shape = input.getType().cast<ShapedType>();
871     if (!input_shape.hasRank()) return failure();
872 
873     ShapedType result_type = (*op->result_type_begin()).cast<ShapedType>();
874     if (!result_type.getElementType().isIntOrIndexOrFloat()) return failure();
875 
876     ElementsAttr const_attr = CreateElementsAttrOfTypeValues(
877         result_type.getElementType(), {}, ArrayRef<int>(input_shape.getRank()));
878 
879     // Add the control edge to `input` to ensure that the constant value will
880     // only be run in the cases where Rank would have been run in the original
881     // graph.
882     TFOp const_op = *CreateConstantTensorOp(
883         rewriter, op->getLoc(), /*name_prefix=*/"", const_attr.getType(),
884         GetControlDependency(rewriter, input), const_attr, op->getAttrs());
885     const_op.setName(TFOp(op).nameAttr());
886 
887     rewriter.replaceOp(op, const_op->getResults());
888 
889     return success();
890   }
891 };
892 
893 // This implementation is mapped to the TensorArraySizeV3 materialization in
894 // ConstantFolding::MaterializeShapes in grappler/optimizers/constant_folding.cc
895 class MaterializeTensorArraySizeV3Op
896     : public FolderPatternBase<MaterializeTensorArraySizeV3Op> {
897  public:
898   explicit MaterializeTensorArraySizeV3Op(OpPropertyHelper &helper)
899       : FolderPatternBase<MaterializeTensorArraySizeV3Op>(
900             "tfg.TensorArraySizeV3", helper) {}
901 
902   LogicalResult matchAndRewrite(Operation *op,
903                                 PatternRewriter &rewriter) const override {
904     Operation *handle_op = op->getOperand(0).getDefiningOp();
905     if (!handle_op || handle_op->getNumOperands() == 0) return failure();
906 
907     auto dynamic_size = handle_op->getAttrOfType<BoolAttr>("dynamic_size");
908     if (dynamic_size && dynamic_size.getValue()) return failure();
909 
910     Operation *array_size = handle_op->getOperand(0).getDefiningOp();
911     if (!array_size || !dialect_->IsConstant(array_size)) return failure();
912 
913     // Don't materialize 0 sizes to avoid triggering incorrect static checks.
914     // A 0 sized array that can't grow isn't useful anyway.
915     auto size_attr = array_size->getAttrOfType<SplatElementsAttr>("value");
916     if (!size_attr || !size_attr.getElementType().isInteger(32))
917       return failure();
918     if (size_attr.getSplatValue<IntegerAttr>().getInt() == 0) return failure();
919 
920     SmallVector<Value> control_operands;
921     control_operands.push_back(TFOp(handle_op).controlRet());
922     control_operands.push_back(
923         GetControlDependency(rewriter, op->getOperand(1)));
924     // CreateConstantTensorOp cannot fail; its type is tensor of i32
925     TFOp const_op = *CreateConstantTensorOp(
926         rewriter, op->getLoc(), /*name_prefix=*/"", size_attr.getType(),
927         control_operands, size_attr, op->getAttrs());
928     const_op.setName(TFOp(op).nameAttr());
929 
930     rewriter.replaceOp(op, const_op->getResults());
931 
932     return success();
933   }
934 };
935 
936 // This implementation is mapped to the ShapeN materialization in
937 // ConstantFolding::MaterializeShapes in grappler/optimizers/constant_folding.cc
938 class MaterializeShapeNOp : public FolderPatternBase<MaterializeShapeNOp> {
939  public:
940   explicit MaterializeShapeNOp(OpPropertyHelper &helper)
941       : FolderPatternBase<MaterializeShapeNOp>("tfg.ShapeN", helper) {}
942 
943   LogicalResult matchAndRewrite(Operation *op,
944                                 PatternRewriter &rewriter) const override {
945     for (const auto &it : llvm::enumerate(TFOp(op).getNonControlOperands())) {
946       Value operand = op->getOperand(it.index());
947 
948       auto operand_shape = operand.getType().cast<ShapedType>();
949       if (!operand_shape.hasStaticShape()) continue;
950 
951       if (op->getResults()[it.index()].use_empty()) continue;
952 
953       ElementsAttr const_attr = ConvertShapeToAttr(operand_shape);
954 
955       FailureOr<TFOp> const_op = CreateConstantTensorOp(
956           rewriter, op->getLoc(), TFOp(op).name(), *(op->result_type_begin()),
957           TFOp(op).controlRet(), const_attr);
958       if (failed(const_op)) return failure();
959 
960       (*const_op).setName(Twine(TFOp(op).name(), "/matshapes_") +
961                           std::to_string(it.index()));
962       if (!TFOp(op).device().empty())
963         (*const_op).setRequestedDevice(TFOp(op).deviceAttr());
964 
965       // TODO(chiahungduan): Do we need to handle `direct_edges_exist` in
966       // ConstantFolding::MaterializeShapes for ShapeN?
967 
968       for (OpOperand &user :
969            llvm::make_early_inc_range(op->getResult(it.index()).getUses())) {
970         rewriter.startRootUpdate(user.getOwner());
971         user.set((*const_op)->getResult(0));
972         rewriter.finalizeRootUpdate(user.getOwner());
973       }
974     }
975 
976     return success();
977   }
978 };
979 
980 // This implementation is mapped to the BroadcastGradientArgsOp materialization
981 // in ConstantFolding::MaterializeBroadcastGradientArgs in
982 // grappler/optimizers/constant_folding.cc
983 class MaterializeBroadcastGradientArgsOp
984     : public PropagationPatternBase<MaterializeBroadcastGradientArgsOp> {
985  public:
986   explicit MaterializeBroadcastGradientArgsOp(OpPropertyHelper &helper)
987       : PropagationPatternBase<MaterializeBroadcastGradientArgsOp>(
988             "tfg.BroadcastGradientArgs", helper) {}
989   LogicalResult matchAndRewrite(Operation *op,
990                                 PatternRewriter &rewriter) const override {
991     Operation *s0 = op->getOperand(0).getDefiningOp();
992     Operation *s1 = op->getOperand(1).getDefiningOp();
993     if (!s0 || !s1) return failure();
994 
995     if (!dialect_->IsShape(s0) && !dialect_->IsConstant(s0)) return failure();
996     if (!dialect_->IsShape(s1) && !dialect_->IsConstant(s1)) return failure();
997 
998     // This operation has been optimized.
999     if (op->getResult(0).use_empty() || op->getResult(1).use_empty())
1000       return failure();
1001 
1002     auto get_shape = [this](Operation *op,
1003                             SmallVector<int64_t> &shape) -> bool {
1004       if (dialect_->IsShape(op)) {
1005         auto type = op->getOperand(0).getType().cast<ShapedType>();
1006         if (!type.hasRank()) return false;
1007 
1008         llvm::append_range(shape, type.getShape());
1009       } else {
1010         auto attr = op->getAttrOfType<ElementsAttr>("value");
1011         if (!attr) return false;
1012 
1013         Type element_type = attr.getElementType();
1014         if (element_type.isInteger(32)) {
1015           llvm::append_range(shape, attr.getValues<int32_t>());
1016         } else if (element_type.isInteger(64)) {
1017           llvm::append_range(shape, attr.getValues<int64_t>());
1018         } else {
1019           return false;
1020         }
1021       }
1022       return true;
1023     };
1024 
1025     SmallVector<int64_t> s0_shape;
1026     SmallVector<int64_t> s1_shape;
1027     if (!get_shape(s0, s0_shape) || !get_shape(s1, s1_shape)) return failure();
1028 
1029     const int common_dims = std::min(s0_shape.size(), s1_shape.size());
1030     for (int i = 0; i < common_dims; ++i) {
1031       if (s0_shape[i] >= 0 && s1_shape[i] >= 0) continue;
1032 
1033       // TODO(chiahungduan): Check if two dims are symbolically equal. Grappler
1034       // stores the symbolic shape information with dim < -1 which is not a
1035       // convention in TFG. Use symbolic shape information instead.
1036 
1037       // Return failure if two dims are symbolically unequal.
1038       return failure();
1039     }
1040 
1041     for (int i = common_dims; i < s0_shape.size(); ++i)
1042       if (s0_shape[i] < 0) return failure();
1043     for (int i = common_dims; i < s1_shape.size(); ++i)
1044       if (s1_shape[i] < 0) return failure();
1045 
1046     tensorflow::BCast::Vec s0_vec(s0_shape.begin(), s0_shape.end());
1047     tensorflow::BCast::Vec s1_vec(s1_shape.begin(), s1_shape.end());
1048     tensorflow::BCast bcast(s0_vec, s1_vec);
1049     if (!bcast.IsValid()) return failure();
1050 
1051     tensorflow::BCast::Vec reduce_dims[2];
1052     reduce_dims[0] = bcast.grad_x_reduce_idx();
1053     reduce_dims[1] = bcast.grad_y_reduce_idx();
1054 
1055     auto type_attr = op->getAttrOfType<TypeAttr>("T");
1056     if (!type_attr) return failure();
1057     if (!type_attr.getValue().isIntOrIndexOrFloat()) return failure();
1058 
1059     SmallVector<Value, 2> const_values;
1060     for (int j = 0; j < 2; ++j) {
1061       int reduction_indices = reduce_dims[j].size();
1062       ElementsAttr const_attr = CreateElementsAttrOfTypeValues(
1063           type_attr.getValue(), {reduction_indices},
1064           llvm::makeArrayRef<int64_t>(reduce_dims[j].data(),
1065                                       reduction_indices));
1066       FailureOr<TFOp> const_op = CreateConstantTensorOp(
1067           rewriter, op->getLoc(), TFOp(op).name(), op->getResultTypes()[j],
1068           TFOp(op).controlRet(), const_attr);
1069       if (failed(const_op)) return failure();
1070 
1071       (*const_op).setName(Twine(TFOp(op).name(), "/bcastargs_") +
1072                           std::to_string(j));
1073       if (!TFOp(op).device().empty())
1074         (*const_op).setRequestedDevice(TFOp(op).deviceAttr());
1075       const_values.push_back((*const_op)->getResult(0));
1076     }
1077 
1078     for (OpOperand &user :
1079          llvm::make_early_inc_range(op->getResult(0).getUses())) {
1080       rewriter.startRootUpdate(user.getOwner());
1081       user.set(const_values[0]);
1082       rewriter.finalizeRootUpdate(user.getOwner());
1083     }
1084     for (OpOperand &user :
1085          llvm::make_early_inc_range(op->getResult(1).getUses())) {
1086       rewriter.startRootUpdate(user.getOwner());
1087       user.set(const_values[1]);
1088       rewriter.finalizeRootUpdate(user.getOwner());
1089     }
1090 
1091     return success();
1092   }
1093 };
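// Illustrative example (hedged): for BroadcastGradientArgs with input shapes
// s0 = [1, 3] and s1 = [2, 1], the broadcasted shape is [2, 3], so the pattern
// materializes the reduction indices r0 = [0] (reduce the gradient w.r.t. s0
// over dim 0) and r1 = [1] as constants and rewires the users accordingly.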
1094 
1095 // This implementation is mapped to the indices of reduction ops materialization
1096 // in ConstantFolding::MaterializeReductionIndices in
1097 // grappler/optimizers/constant_folding.cc
1098 class MaterializeReductionIndices
1099     : public PropagationPatternBase<MaterializeReductionIndices> {
1100  public:
1101   explicit MaterializeReductionIndices(OpPropertyHelper &helper)
1102       : PropagationPatternBase<MaterializeReductionIndices>(MatchAnyOpTypeTag(),
1103                                                             helper) {}
1104   LogicalResult matchAndRewrite(Operation *op,
1105                                 PatternRewriter &rewriter) const override {
1106     if (!dialect_->IsReduction(op)) return failure();
1107 
1108     Operation *indices = op->getOperand(1).getDefiningOp();
1109     // The reduction indices are already constant, there's nothing to do.
1110     if (!indices || dialect_->IsConstant(indices)) return failure();
1111 
1112     auto indices_shape = indices->getResult(0).getType().cast<ShapedType>();
1113     if (!indices_shape.hasRank()) return failure();
1114     if (!indices_shape.getElementType().isInteger(32) &&
1115         !indices_shape.getElementType().isInteger(64)) {
1116       return failure();
1117     }
1118 
1119     auto input_shape = op->getOperand(0).getType().cast<ShapedType>();
1120     // Unexpected graph, don't try to change it.
1121     if (!input_shape.hasRank() || input_shape.getRank() < 1) return failure();
1122 
1123     auto output_shape = op->getResult(0).getType().cast<ShapedType>();
1124     const int output_rank =
1125         output_shape.hasRank() ? output_shape.getRank() : -1;
1126 
1127     bool full_reduction = output_rank == 0 || (indices_shape.hasStaticShape() &&
1128                                                indices_shape.getNumElements() ==
1129                                                    input_shape.getRank());
1130 
1131     if (!full_reduction) {
1132       // A full reduction will generate a tensor of one of the shapes
1133       // [], [1], [1, 1], [1, 1, ...]. Even if we do not know the number of
1134       // elements in the output of the reduction, we may deduce it from reshape
1135       // nodes following it.
1136       for (Operation *user : op->getResult(0).getUsers()) {
1137         full_reduction = false;
1138         if (!dialect_->IsReshape(user)) return failure();
1139 
1140         auto shape = user->getResult(0).getType().cast<ShapedType>();
1141         if (!shape.hasStaticShape() || shape.getNumElements() != 1)
1142           return failure();
1143         else
1144           full_reduction = true;
1145       }
1146       if (!full_reduction) return failure();
1147     }
1148 
1149     // We know it's a full reduction. We can generate the full set of indices
1150     // to reduce as a constant node.
1151     SmallVector<int> elements(input_shape.getRank());
1152     std::iota(elements.begin(), elements.end(), 0);
1153 
1154     ElementsAttr const_attr = CreateElementsAttrOfTypeValues(
1155         indices_shape.getElementType(), {input_shape.getRank()},
1156         llvm::makeArrayRef(elements));
1157 
1158     FailureOr<TFOp> const_op = CreateConstantTensorOp(
1159         rewriter, indices->getLoc(), Twine(TFOp(op).name(), "/indices").str(),
1160         const_attr.getType(), TFOp(indices).controlRet(), const_attr);
1161     if (failed(const_op)) return failure();
1162 
1163     if (TFOp(op).deviceAttr())
1164       (*const_op).setRequestedDevice(TFOp(op).deviceAttr());
1165 
1166     rewriter.startRootUpdate(op);
1167     op->setOperand(1, (*const_op)->getResults()[0]);
1168     rewriter.finalizeRootUpdate(op);
1169 
1170     return success();
1171   }
1172 };
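// Illustrative example (hedged): for a reduction over a rank-3 input whose only
// user reshapes the result to a single element, the reduction is known to be a
// full reduction, so the indices operand is replaced with a constant [0, 1, 2].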
1173 
1174 // This implementation is mapped to the constant value materialization in
1175 // ConstantFolding::MaterializeConstantValuedNode in
1176 // grappler/optimizers/constant_folding.cc
1177 class MaterializeFillNode : public FolderPatternBase<MaterializeFillNode> {
1178  public:
1179   explicit MaterializeFillNode(OpPropertyHelper &helper)
1180       : FolderPatternBase<MaterializeFillNode>("tfg.Fill", helper) {}
1181   LogicalResult matchAndRewrite(Operation *op,
1182                                 PatternRewriter &rewriter) const override {
1183     if (helper_.DisableCompressedTensorOptimization()) return failure();
1184     // Only handles ops with a single result; the other result is the control ret.
1185     if (op->getNumResults() != 2) return failure();
1186 
1187     auto output_type = op->getResult(0).getType().cast<ShapedType>();
1188     if (!output_type.hasStaticShape()) return failure();
1189     if (!output_type.isIntOrIndexOrFloat()) return failure();
1190 
1191     Operation *dim = op->getOperand(0).getDefiningOp();
1192     Operation *value = op->getOperand(1).getDefiningOp();
1193     if (!dim || !value) return failure();
1194     // In grappler's constant folding, they also check if `dim` is constant,
1195     // which is redundant because its constant property is never used.
1196     if (!dialect_->IsConstant(value)) return failure();
1197 
1198     ElementsAttr const_attr = CreateElementsAttrOfTypeValues(
1199         output_type.getElementType(), output_type.getShape(),
1200         {value->getAttrOfType<ElementsAttr>("value")});
1201 
1202     FailureOr<TFOp> const_op = ReplaceOpWithConstantTensor(
1203         rewriter, op, const_attr,
1204         /*exclude_attrs=*/ArrayRef<StringRef>({"T", "index_type"}));
1205     if (failed(const_op)) return failure();
1206 
1207     rewriter.replaceOp(op, (*const_op)->getResults());
1208 
1209     return success();
1210   }
1211 };
1212 
1213 // This implementation is mapped to the constant value materialization in
1214 // ConstantFolding::MaterializeConstantValuedNode in
1215 // grappler/optimizers/constant_folding.cc
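// Illustrative sketch: ZerosLike(x) and OnesLike(x) with a statically shaped
// result are replaced by a Const op splatting 0 or 1 of the result element
// type over that shape. Fill is skipped here; it is handled by
// MaterializeFillNode above.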
1216 class MaterializeConstantValuedNode
1217     : public FolderPatternBase<MaterializeConstantValuedNode> {
1218  public:
1219   explicit MaterializeConstantValuedNode(OpPropertyHelper &helper)
1220       : FolderPatternBase<MaterializeConstantValuedNode>(MatchAnyOpTypeTag(),
1221                                                          helper) {}
1222   LogicalResult matchAndRewrite(Operation *op,
1223                                 PatternRewriter &rewriter) const override {
1224     if (helper_.DisableCompressedTensorOptimization()) return failure();
1225     // Only handles single-result ops; the other result is the control ret.
1226     if (op->getNumResults() != 2) return failure();
1227 
1228     // FillOp is handled in MaterializeFillNode pattern.
1229     if (dialect_->IsFill(op)) return failure();
1230     if (!dialect_->IsZerosLike(op) && !dialect_->IsOnesLike(op))
1231       return failure();
1232 
1233     // TODO(chiahungduan): If op->getOperand(0) has static shape, can we use
1234     // that to materialize?
1235     auto output_type = op->getResult(0).getType().cast<ShapedType>();
1236     if (!output_type.hasStaticShape()) return failure();
1237 
1238     int value =
1239         dialect_->IsZerosLike(op) ? 0 : (dialect_->IsOnesLike(op) ? 1 : -1);
1240     if (value < 0) return failure();
1241 
1242     if (!output_type.getElementType().isIntOrIndexOrFloat()) return failure();
1243 
1244     ElementsAttr const_attr;
1245     if (output_type.getElementType().isIntOrIndex()) {
1246       const_attr = CreateElementsAttrOfTypeValues(output_type.getElementType(),
1247                                                   output_type.getShape(),
1248                                                   ArrayRef<int>(value));
1249     } else {
1250       const_attr = CreateElementsAttrOfTypeValues(output_type.getElementType(),
1251                                                   output_type.getShape(),
1252                                                   ArrayRef<double>(value));
1253     }
1254 
1255     FailureOr<TFOp> const_op =
1256         ReplaceOpWithConstantTensor(rewriter, op, const_attr);
1257     if (failed(const_op)) return failure();
1258 
1259     rewriter.replaceOp(op, (*const_op)->getResults());
1260     return success();
1261   }
1262 };
1263 
1264 // This implementation is mapped to the output value materialization in
1265 // ConstantFolding::MaterializeOutputValues in
1266 // grappler/optimizers/constant_folding.cc
1267 class MaterializeOutputValue
1268     : public PropagationPatternBase<MaterializeOutputValue> {
1269  public:
1270   explicit MaterializeOutputValue(OpPropertyHelper &helper)
1271       : PropagationPatternBase<MaterializeOutputValue>(MatchAnyOpTypeTag(),
1272                                                        helper) {}
1273   LogicalResult matchAndRewrite(Operation *op,
1274                                 PatternRewriter &rewriter) const override {
1275     // In grappler, the shape information is stored in a separate structure and
1276     // this pass is used to materialize the shape inference information to the
1277     // node. But in MLIR, the shape inference information is stored in the
1278     // operation.
1279     return failure();
1280   }
1281 };
1282 
1283 // This implementation is mapped to the merge node folding in
1284 // ConstantFolding::FoldMergeNode in
1285 // grappler/optimizers/constant_folding.cc
1286 template <typename ConcreteType>
1287 class MergeNodeFoldingBase : public PropagationPatternBase<ConcreteType> {
1288  protected:
1289   MergeNodeFoldingBase(StringRef op_name, OpPropertyHelper &helper)
1290       : PropagationPatternBase<ConcreteType>(op_name, helper),
1291         zero_dim_i32_tensor_type_(RankedTensorType::get(
1292             llvm::None,
1293             IntegerType::get(helper.getDialect()->getContext(), 32))) {}
1294 
1295   LogicalResult matchAndRewrite(Operation *op,
1296                                 PatternRewriter &rewriter) const override {
1297     // Merge nodes are special, in the sense that they execute as soon as one of
1298     // their inputs is ready. We can therefore fold a merge node iff it has at
1299     // least one constant input without control dependency.
1300     // We still need to ensure that the nodes in the fanin of the merge node are
1301     // scheduled. We'll therefore add a control dependency from the merge node
1302     // to the folded constant. We end up with:
1303     //  * the merge node and its inputs are preserved as is
1304     //  * a new constant node C1, driven by the merge node through a control
1305     //  dependency, initialized to the value of the folded input
1306     //  * a new constant node C2, driven by the merge node through a control
1307     //  dependency, initialized to the index of the folded input
1308     //  * the fanout of the merge node is rewired to be driven by either C1 or
1309     //  C2.
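    // Rough sketch of the rewrite for Merge(%c, %x) where %c is a constant
    // (names are illustrative only):
    //   C1 = Const {value = value(%c)}  ^ctl(Merge)
    //   C2 = Const {value = 0 : i32}    ^ctl(Merge)
    // with the non-control users of the Merge's value and index outputs
    // rewired to C1 and C2 respectively.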
1310 
1311     // The node may have already been optimized: all non-control results are unused.
1312     if (llvm::all_of(op->getResults().drop_back(),
1313                      [](Value v) { return v.use_empty(); })) {
1314       return failure();
1315     }
1316 
1317     int idx = 0;
1318     for (Value operand : TFOp(op).getNonControlOperands()) {
1319       Operation *operand_op = operand.getDefiningOp();
1320       if (!operand_op) continue;
1321       if (!this->dialect_->IsConstant(operand_op)) continue;
1322       if (!TFOp(operand_op).getControlOperands().empty()) continue;
1323 
1324       FailureOr<TFOp> const_out = CreateConstantTensorOp(
1325           rewriter, op->getLoc(), TFOp(op).name(),
1326           *(operand_op->result_type_begin()), TFOp(op).controlRet(),
1327           operand_op->getAttrOfType<ElementsAttr>("value"), op->getAttrs());
1328       if (failed(const_out)) return failure();
1329       (*const_out).setName(Twine(TFOp(op).name(), "/const"));
1330       if (!TFOp(op).device().empty())
1331         (*const_out).setRequestedDevice(TFOp(op).device());
1332 
1333       FailureOr<TFOp> const_index = CreateConstantTensorOp(
1334           rewriter, op->getLoc(), TFOp(op).name(), rewriter.getIntegerType(32),
1335           TFOp(op).controlRet(),
1336           DenseElementsAttr::get(zero_dim_i32_tensor_type_, idx++));
1337       if (failed(const_index)) return failure();
1338 
1339       (*const_index).setName(Twine(TFOp(op).name(), "/index"));
1340       if (!TFOp(op).device().empty())
1341         (*const_index).setRequestedDevice(TFOp(op).device());
1342 
1343       for (OpOperand &user :
1344            llvm::make_early_inc_range(op->getResults()[0].getUses())) {
1345         rewriter.startRootUpdate(user.getOwner());
1346         user.set((*const_out)->getResult(0));
1347         rewriter.finalizeRootUpdate(user.getOwner());
1348       }
1349       for (OpOperand &user :
1350            llvm::make_early_inc_range(op->getResults()[1].getUses())) {
1351         rewriter.startRootUpdate(user.getOwner());
1352         user.set((*const_index)->getResult(0));
1353         rewriter.finalizeRootUpdate(user.getOwner());
1354       }
1355 
1356       // Already found an available input.
1357       return success();
1358     }
1359     return failure();
1360   }
1361 
1362   RankedTensorType zero_dim_i32_tensor_type_;
1363 };
1364 
1365 class MergeNodeFolding : public MergeNodeFoldingBase<MergeNodeFolding> {
1366  public:
1367   explicit MergeNodeFolding(OpPropertyHelper &helper)
1368       : MergeNodeFoldingBase("tfg.Merge", helper) {}
1369 };
1370 
1371 class RefMergeNodeFolding : public MergeNodeFoldingBase<RefMergeNodeFolding> {
1372  public:
1373   explicit RefMergeNodeFolding(OpPropertyHelper &helper)
1374       : MergeNodeFoldingBase("tfg.RefMerge", helper) {}
1375 };
1376 
1377 class XlaMergeNodeFolding : public MergeNodeFoldingBase<XlaMergeNodeFolding> {
1378  public:
1379   explicit XlaMergeNodeFolding(OpPropertyHelper &helper)
1380       : MergeNodeFoldingBase("tfg.XlaMerge", helper) {}
1381 };
1382 
1383 // This implementation is mapped with ConstantFolding::RemoveSplitOrSplitV in
1384 // grappler/optimizers/constant_folding.cc
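// Illustrative sketch: Split(axis, x) with num_split == 1 produces a single
// output identical to x, so it is replaced by an Identity forwarding the data
// operand. RemoveSplitVOp below applies the same idea to SplitV.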
1385 class RemoveSplitOp : public FolderPatternBase<RemoveSplitOp> {
1386  public:
1387   explicit RemoveSplitOp(OpPropertyHelper &helper)
1388       : FolderPatternBase<RemoveSplitOp>("tfg.Split", helper) {}
1389   LogicalResult matchAndRewrite(Operation *op,
1390                                 PatternRewriter &rewriter) const override {
1391     auto num_split_attr = op->getAttrOfType<IntegerAttr>("num_split");
1392     if (!num_split_attr || num_split_attr.getInt() != 1) return failure();
1393     FailureOr<TFOp> identity = ReplaceOpWithIdentity(rewriter, op, 1);
1394     if (failed(identity)) return failure();
1395     rewriter.replaceOp(op, (*identity)->getResults());
1396     return success();
1397   }
1398 };
1399 
1400 // This implementation is mapped with ConstantFolding::RemoveSplitOrSplitV in
1401 // grappler/optimizers/constant_folding.cc
1402 class RemoveSplitVOp : public FolderPatternBase<RemoveSplitVOp> {
1403  public:
1404   explicit RemoveSplitVOp(OpPropertyHelper &helper)
1405       : FolderPatternBase<RemoveSplitVOp>("tfg.SplitV", helper) {}
1406   LogicalResult matchAndRewrite(Operation *op,
1407                                 PatternRewriter &rewriter) const override {
1408     auto num_split_attr = op->getAttrOfType<IntegerAttr>("num_split");
1409     if (!num_split_attr || num_split_attr.getInt() != 1) return failure();
1410     FailureOr<TFOp> identity = ReplaceOpWithIdentity(rewriter, op, 0);
1411     if (failed(identity)) return failure();
1412     rewriter.replaceOp(op, (*identity)->getResults());
1413     return success();
1414   }
1415 };
1416 
1417 // TODO(chiahungduan): Do we still have "Shuffle" op?
1418 // This implementation is mapped with ConstantFolding::RemoveShuffleOrTranspose
1419 // in grappler/optimizers/constant_folding.cc
1420 class RemoveShuffleOp : public FolderPatternBase<RemoveShuffleOp> {
1421  public:
1422   explicit RemoveShuffleOp(OpPropertyHelper &helper)
1423       : FolderPatternBase<RemoveShuffleOp>("tfg.Shuffle", helper) {}
1424   LogicalResult matchAndRewrite(Operation *op,
1425                                 PatternRewriter &rewriter) const override {
1426     Operation *perm_op = op->getOperand(1).getDefiningOp();
1427     if (!perm_op || !dialect_->IsConstant(perm_op)) return failure();
1428     ElementsAttr perm_tensor = perm_op->getAttrOfType<ElementsAttr>("value");
1429     if (!perm_tensor) return failure();
1430 
1431     ShapedType x_shape = op->getOperand(0).getType().cast<ShapedType>();
1432     if (!x_shape.hasRank()) return failure();
1433     if (perm_tensor.getNumElements() != x_shape.getRank()) return failure();
1434 
1435     for (unsigned i = 0; i < x_shape.getRank(); ++i) {
1436       int64_t value = perm_tensor.getElementType().isInteger(32)
1437                           ? perm_tensor.getValues<int32_t>()[i]
1438                           : perm_tensor.getValues<int64_t>()[i];
1439       if (value != i && x_shape.getShape()[i] != 1) return failure();
1440     }
1441 
1442     FailureOr<TFOp> identity = ReplaceOpWithIdentity(rewriter, op, 0);
1443     if (failed(identity)) return failure();
1444     rewriter.replaceOp(op, (*identity)->getResults());
1445 
1446     return success();
1447   }
1448 };
1449 
1450 // This implementation is mapped with ConstantFolding::RemoveShuffleOrTranspose
1451 // in grappler/optimizers/constant_folding.cc
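// Illustrative sketch: a Transpose is a no-op when every axis that the
// permutation actually moves has size 1, e.g. Transpose(x, perm = [1, 0, 2])
// with x : tensor<1x1x5xT>; such ops are replaced by an Identity on x. The
// same check is used for Shuffle above.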
1452 class RemoveTransposeOp : public FolderPatternBase<RemoveTransposeOp> {
1453  public:
1454   explicit RemoveTransposeOp(OpPropertyHelper &helper)
1455       : FolderPatternBase<RemoveTransposeOp>("tfg.Transpose", helper) {}
1456   LogicalResult matchAndRewrite(Operation *op,
1457                                 PatternRewriter &rewriter) const override {
1458     Operation *perm_op = op->getOperand(1).getDefiningOp();
1459     if (!perm_op || !dialect_->IsConstant(perm_op)) return failure();
1460     ElementsAttr perm_tensor = perm_op->getAttrOfType<ElementsAttr>("value");
1461     if (!perm_tensor) return failure();
1462 
1463     ShapedType x_shape = op->getOperand(0).getType().cast<ShapedType>();
1464     if (!x_shape.hasRank()) return failure();
1465     if (perm_tensor.getNumElements() != x_shape.getRank()) return failure();
1466 
1467     for (unsigned i = 0; i < x_shape.getRank(); ++i) {
1468       int64_t value = perm_tensor.getElementType().isInteger(32)
1469                           ? perm_tensor.getValues<int32_t>()[i]
1470                           : perm_tensor.getValues<int64_t>()[i];
1471       if (value != i && x_shape.getShape()[i] != 1) return failure();
1472     }
1473 
1474     FailureOr<TFOp> identity = ReplaceOpWithIdentity(rewriter, op, 0);
1475     if (failed(identity)) return failure();
1476     rewriter.replaceOp(op, (*identity)->getResults());
1477 
1478     return success();
1479   }
1480 };
1481 
1482 // This implementation is mapped with ConstantFolding::RemoveRandomShuffle
1483 // in grappler/optimizers/constant_folding.cc
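// Illustrative sketch: RandomShuffle permutes along the first dimension, so it
// is a no-op on a scalar or on a tensor whose first dimension has size 1 and
// is replaced by an Identity.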
1484 class RemoveRandomShuffleOp : public FolderPatternBase<RemoveRandomShuffleOp> {
1485  public:
1486   explicit RemoveRandomShuffleOp(OpPropertyHelper &helper)
1487       : FolderPatternBase<RemoveRandomShuffleOp>("tfg.RandomShuffle", helper) {}
1488   LogicalResult matchAndRewrite(Operation *op,
1489                                 PatternRewriter &rewriter) const override {
1490     auto shape = op->getOperand(0).getType().cast<ShapedType>();
1491     if (!shape.hasRank()) return failure();
1492     if (shape.getRank() != 0 && shape.getShape()[0] != 1) return failure();
1493 
1494     FailureOr<TFOp> identity = ReplaceOpWithIdentity(rewriter, op, 0);
1495     if (failed(identity)) return failure();
1496     rewriter.replaceOp(op, (*identity)->getResults());
1497 
1498     return success();
1499   }
1500 };
1501 
1502 // This implementation is mapped with ConstantFolding::RemoveReverse
1503 // in grappler/optimizers/constant_folding.cc
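// Illustrative sketch: ReverseV2(x, axes) is a no-op when every reversed axis
// (after normalizing negative axis values) has size 1; it is then replaced by
// an Identity on x.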
1504 class RemoveReverse : public FolderPatternBase<RemoveReverse> {
1505  public:
1506   explicit RemoveReverse(OpPropertyHelper &helper)
1507       : FolderPatternBase<RemoveReverse>("tfg.ReverseV2", helper) {}
1508   LogicalResult matchAndRewrite(Operation *op,
1509                                 PatternRewriter &rewriter) const override {
1510     ShapedType tensor_type = op->getOperand(0).getType().cast<ShapedType>();
1511     if (!tensor_type.hasRank()) return failure();
1512 
1513     Operation *dim_op = op->getOperand(1).getDefiningOp();
1514     if (!dim_op || !dialect_->IsConstant(dim_op)) return failure();
1515 
1516     auto dim_attr = dim_op->getAttrOfType<ElementsAttr>("value");
1517     DenseSet<int> target_axis;
1518     for (unsigned i = 0; i < dim_attr.getNumElements(); ++i) {
1519       // Value of axis can be negative.
1520       if (dim_attr.getElementType().isInteger(32)) {
1521         target_axis.insert(
1522             (dim_attr.getValues<int32_t>()[i] + tensor_type.getRank()) %
1523             tensor_type.getRank());
1524       } else {
1525         target_axis.insert(
1526             (dim_attr.getValues<int64_t>()[i] + tensor_type.getRank()) %
1527             tensor_type.getRank());
1528       }
1529     }
1530 
1531     for (unsigned i = 0; i < tensor_type.getRank(); ++i) {
1532       if (tensor_type.getShape()[i] != 1 && target_axis.contains(i))
1533         return failure();
1534     }
1535 
1536     FailureOr<TFOp> identity = ReplaceOpWithIdentity(rewriter, op, 0);
1537     if (failed(identity)) return failure();
1538     rewriter.replaceOp(op, (*identity)->getResults());
1539 
1540     return success();
1541   }
1542 };
1543 
1544 // This implementation is mapped with ConstantFolding::SimplifySlice
1545 // in grappler/optimizers/constant_folding.cc
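// Illustrative sketch: Slice(x, begin, size) is a no-op when begin is all
// zeros and, for every dimension, size is either -1 or the full dimension
// size; the op is then replaced by an Identity on x.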
1546 class SimplifySliceOp : public FolderPatternBase<SimplifySliceOp> {
1547  public:
1548   explicit SimplifySliceOp(OpPropertyHelper &helper)
1549       : FolderPatternBase<SimplifySliceOp>("tfg.Slice", helper) {}
1550   LogicalResult matchAndRewrite(Operation *op,
1551                                 PatternRewriter &rewriter) const override {
1552     Operation *begin_op = op->getOperand(1).getDefiningOp();
1553     Operation *size_op = op->getOperand(2).getDefiningOp();
1554     if (!begin_op || !size_op) return failure();
1555 
1556     if (!dialect_->IsConstant(begin_op) || !dialect_->IsConstant(size_op))
1557       return failure();
1558 
1559     auto begin_attr = begin_op->getAttrOfType<ElementsAttr>("value");
1560     auto size_attr = size_op->getAttrOfType<ElementsAttr>("value");
1561 
1562     ShapedType input_type = op->getOperand(0).getType().cast<ShapedType>();
1563     if (!input_type.hasRank()) return failure();
1564 
1565     for (unsigned i = 0; i < input_type.getRank(); ++i) {
1566       if (begin_attr.getElementType().isInteger(32)) {
1567         if (begin_attr.getValues<int32_t>()[i] != 0) return failure();
1568       } else {
1569         if (begin_attr.getValues<int64_t>()[i] != 0) return failure();
1570       }
1571 
1572       if (size_attr.getElementType().isInteger(32)) {
1573         if (size_attr.getValues<int32_t>()[i] != -1 &&
1574             size_attr.getValues<int32_t>()[i] != input_type.getShape()[i])
1575           return failure();
1576       } else {
1577         if (size_attr.getValues<int64_t>()[i] != -1 &&
1578             size_attr.getValues<int64_t>()[i] != input_type.getShape()[i])
1579           return failure();
1580       }
1581     }
1582 
1583     FailureOr<TFOp> identity = ReplaceOpWithIdentity(rewriter, op, 0);
1584     if (failed(identity)) return failure();
1585     rewriter.replaceOp(op, (*identity)->getResults());
1586 
1587     return success();
1588   }
1589 };
1590 
1591 // This implementation is mapped with ConstantFolding::SimplifyStridedSlice
1592 // in grappler/optimizers/constant_folding.cc
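// Illustrative sketch: a StridedSlice is a no-op when, taking begin_mask,
// end_mask and ellipsis_mask into account, every dimension is sliced from 0 to
// its full size with stride 1 (and no new/shrink axis mask is set); it is then
// replaced by an Identity on the input.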
1593 class SimplifyStridedSlice : public FolderPatternBase<SimplifyStridedSlice> {
1594  public:
1595   explicit SimplifyStridedSlice(OpPropertyHelper &helper)
1596       : FolderPatternBase<SimplifyStridedSlice>("tfg.StridedSlice", helper) {}
1597   LogicalResult matchAndRewrite(Operation *op,
1598                                 PatternRewriter &rewriter) const override {
1599     // Skip ops with new/shrink axis mask, since they involve dimension changes.
1600     if (auto attr = op->getAttrOfType<IntegerAttr>("new_axis_mask")) {
1601       if (attr.getInt() != 0) return failure();
1602     } else {
1603       return failure();
1604     }
1605     if (auto attr = op->getAttrOfType<IntegerAttr>("shrink_axis_mask")) {
1606       if (attr.getInt() != 0) return failure();
1607     } else {
1608       return failure();
1609     }
1610 
1611     auto begin_mask_attr = op->getAttrOfType<IntegerAttr>("begin_mask");
1612     auto end_mask_attr = op->getAttrOfType<IntegerAttr>("end_mask");
1613     auto ellipsis_mask_attr = op->getAttrOfType<IntegerAttr>("ellipsis_mask");
1614     if (!begin_mask_attr || !end_mask_attr || !ellipsis_mask_attr)
1615       return failure();
1616 
1617     ShapedType input_type = op->getOperand(0).getType().cast<ShapedType>();
1618     if (!input_type.hasStaticShape()) return failure();
1619 
1620     Operation *begin_op = op->getOperand(1).getDefiningOp();
1621     Operation *end_op = op->getOperand(2).getDefiningOp();
1622     Operation *strides_op = op->getOperand(3).getDefiningOp();
1623     if (!begin_op || !end_op || !strides_op) return failure();
1624 
1625     if (!dialect_->IsConstant(begin_op) || !dialect_->IsConstant(end_op) ||
1626         !dialect_->IsConstant(strides_op))
1627       return failure();
1628 
1629     ElementsAttr begin_attr = begin_op->getAttrOfType<ElementsAttr>("value");
1630     ElementsAttr end_attr = end_op->getAttrOfType<ElementsAttr>("value");
1631     ElementsAttr strides_attr =
1632         strides_op->getAttrOfType<ElementsAttr>("value");
1633 
1634     const int64_t begin_mask = begin_mask_attr.getInt();
1635     const int64_t end_mask = end_mask_attr.getInt();
1636     const int64_t ellipsis_mask = ellipsis_mask_attr.getInt();
1637     const int64_t num_strides_elements = strides_attr.getNumElements();
1638 
1639     DenseSet<int> expanded_ellipsis_indices;
1640     int ellipsis_index = -1;
1641 
1642     for (unsigned i = 0; i < input_type.getRank(); ++i) {
1643       if (ellipsis_mask & 1 << i ||
1644           (ellipsis_index == -1 && i >= num_strides_elements)) {
1645         ellipsis_index = i;
1646       }
1647       if (ellipsis_index != -1 &&
1648           input_type.getRank() > num_strides_elements + i - ellipsis_index) {
1649         expanded_ellipsis_indices.insert(i);
1650       }
1651     }
1652 
1653     for (unsigned i = 0; i < input_type.getRank(); ++i) {
1654       if (expanded_ellipsis_indices.contains(i)) {
1655         // ellipsis_mask is effective on current dimension.
1656         continue;
1657       }
1658 
1659       int j = i;
1660       int expanded_ellipsis_indices_size = expanded_ellipsis_indices.size();
1661       if (ellipsis_index != -1 &&
1662           i >= ellipsis_index + expanded_ellipsis_indices_size) {
1663         j = i - expanded_ellipsis_indices_size;
1664       }
1665       int b = begin_attr.getElementType().isInteger(32)
1666                   ? begin_attr.getValues<int32_t>()[j]
1667                   : begin_attr.getValues<int64_t>()[j];
1668       int e = end_attr.getElementType().isInteger(32)
1669                   ? end_attr.getValues<int32_t>()[j]
1670                   : end_attr.getValues<int64_t>()[j];
1671       int s = strides_attr.getElementType().isInteger(32)
1672                   ? strides_attr.getValues<int32_t>()[j]
1673                   : strides_attr.getValues<int64_t>()[j];
1674 
1675       if (!(begin_mask & 1 << j || b == 0) ||
1676           !(end_mask & 1 << j || e == input_type.getShape()[i]) || s != 1) {
1677         return failure();
1678       }
1679     }
1680 
1681     FailureOr<TFOp> identity = ReplaceOpWithIdentity(rewriter, op, 0);
1682     if (failed(identity)) return failure();
1683     rewriter.replaceOp(op, (*identity)->getResults());
1684 
1685     return success();
1686   }
1687 };
1688 
1689 // This implementation is mapped with ConstantFolding::SimplifyTile
1690 // in grappler/optimizers/constant_folding.cc
1691 class SimplifyTileOp : public FolderPatternBase<SimplifyTileOp> {
1692  public:
1693   explicit SimplifyTileOp(OpPropertyHelper &helper)
1694       : FolderPatternBase<SimplifyTileOp>("tfg.Tile", helper) {}
1695   LogicalResult matchAndRewrite(Operation *op,
1696                                 PatternRewriter &rewriter) const override {
1697     Operation *multiples_op = op->getOperand(1).getDefiningOp();
1698     if (!multiples_op || !dialect_->IsConstant(multiples_op)) return failure();
1699 
1700     ElementsAttr multiples_attr =
1701         multiples_op->getAttrOfType<ElementsAttr>("value");
1702     if (multiples_attr.getElementType().isInteger(32)) {
1703       if (llvm::any_of(multiples_attr.getValues<int32_t>(),
1704                        [](int v) { return v != 1; })) {
1705         return failure();
1706       }
1707     } else {
1708       if (llvm::any_of(multiples_attr.getValues<int64_t>(),
1709                        [](int64_t v) { return v != 1; })) {
1710         return failure();
1711       }
1712     }
1713 
1714     FailureOr<TFOp> identity = ReplaceOpWithIdentity(rewriter, op, 0);
1715     if (failed(identity)) return failure();
1716     rewriter.replaceOp(op, (*identity)->getResults());
1717 
1718     return success();
1719   }
1720 };
1721 
1722 // This implementation is mapped with ConstantFolding::SimplifyPad
1723 // in grappler/optimizers/constant_folding.cc
1724 template <typename ConcreteType>
1725 class SimplifyPadOpBase : public FolderPatternBase<ConcreteType> {
1726  protected:
1727   SimplifyPadOpBase(StringRef op_name, OpPropertyHelper &helper)
1728       : FolderPatternBase<ConcreteType>(op_name, helper) {}
1729   LogicalResult matchAndRewrite(Operation *op,
1730                                 PatternRewriter &rewriter) const override {
1731     Operation *paddings = op->getOperand(1).getDefiningOp();
1732     if (!paddings || !this->dialect_->IsConstant(paddings)) return failure();
1733 
1734     ElementsAttr paddings_attr = paddings->getAttrOfType<ElementsAttr>("value");
1735     if (paddings_attr.getElementType().isInteger(32)) {
1736       if (llvm::any_of(paddings_attr.getValues<int32_t>(),
1737                        [](int v) { return v != 0; })) {
1738         return failure();
1739       }
1740     } else {
1741       if (llvm::any_of(paddings_attr.getValues<int64_t>(),
1742                        [](int64_t v) { return v != 0; })) {
1743         return failure();
1744       }
1745     }
1746 
1747     FailureOr<TFOp> identity = ReplaceOpWithIdentity(rewriter, op, 0);
1748     if (failed(identity)) return failure();
1749     rewriter.replaceOp(op, (*identity)->getResults());
1750 
1751     return success();
1752   }
1753 };
1754 
1755 // This implementation is mapped with ConstantFolding::SimplifyPad
1756 // in grappler/optimizers/constant_folding.cc
1757 class SimplifyPadOp : public SimplifyPadOpBase<SimplifyPadOp> {
1758  public:
1759   explicit SimplifyPadOp(OpPropertyHelper &helper)
1760       : SimplifyPadOpBase("tfg.Pad", helper) {}
1761 };
1762 
1763 // This implementation is mapped with ConstantFolding::SimplifyPad
1764 // in grappler/optimizers/constant_folding.cc
1765 class SimplifyPadV2Op : public SimplifyPadOpBase<SimplifyPadV2Op> {
1766  public:
1767   explicit SimplifyPadV2Op(OpPropertyHelper &helper)
1768       : SimplifyPadOpBase("tfg.PadV2", helper) {}
1769 };
1770 
1771 // This implementation is mapped with ConstantFolding::SimplifySqueeze
1772 // in grappler/optimizers/constant_folding.cc
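// Illustrative sketch: Squeeze only removes dimensions of size 1, so when the
// input is known to have every dimension larger than 1 the op is a no-op and
// is replaced by an Identity.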
1773 class SimplifySqueezeOp : public FolderPatternBase<SimplifySqueezeOp> {
1774  public:
1775   explicit SimplifySqueezeOp(OpPropertyHelper &helper)
1776       : FolderPatternBase<SimplifySqueezeOp>("tfg.Squeeze", helper) {}
1777   LogicalResult matchAndRewrite(Operation *op,
1778                                 PatternRewriter &rewriter) const override {
1779     auto shape_type = op->getOperand(0).getType().cast<ShapedType>();
1780     if (!shape_type.hasRank()) return failure();
1781     if (llvm::any_of(shape_type.getShape(), [](int64_t s) { return s <= 1; }))
1782       return failure();
1783 
1784     FailureOr<TFOp> identity = ReplaceOpWithIdentity(rewriter, op, 0);
1785     if (failed(identity)) return failure();
1786     rewriter.replaceOp(op, (*identity)->getResults());
1787 
1788     return success();
1789   }
1790 };
1791 
1792 // This implementation is mapped with ConstantFolding::SimplifyPack
1793 // in grappler/optimizers/constant_folding.cc
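// Illustrative sketch: Pack with a single (non-control) input is equivalent to
// ExpandDims on the pack axis, e.g. Pack(x) {axis = 0} becomes
// ExpandDims(x, Const(0)), with the axis materialized as an i32 constant.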
1794 class SimplifyPackOp : public FolderPatternBase<SimplifyPackOp> {
1795  public:
1796   explicit SimplifyPackOp(OpPropertyHelper &helper)
1797       : FolderPatternBase<SimplifyPackOp>("tfg.Pack", helper) {}
1798   LogicalResult matchAndRewrite(Operation *op,
1799                                 PatternRewriter &rewriter) const override {
1800     auto [non_control_operands, control_operands] = TFOp(op).splitOperands();
1801     if (non_control_operands.size() != 1) return failure();
1802 
1803     // It's unsafe to add a control dependency on the feed node, because it
1804     // might never have been executed otherwise.
1805     if (non_control_operands[0].isa<BlockArgument>()) return failure();
1806 
1807     IntegerAttr axis = op->getAttrOfType<IntegerAttr>("axis");
1808     ElementsAttr const_attr = CreateElementsAttrOfTypeValues(
1809         rewriter.getIntegerType(32), /*shape=*/{},
1810         ArrayRef<int>(axis ? axis.getInt() : 0));
1811     // CreateConstantTensorOp cannot fail
1812     TFOp const_op = *CreateConstantTensorOp(
1813         rewriter, op->getLoc(), TFOp(op).name(), const_attr.getType(),
1814         GetControlDependency(rewriter, op->getOperand(0)), const_attr);
1815 
1816     const_op.setName(Twine(TFOp(op).name(), "/_const_axis"));
1817     if (!TFOp(op).device().empty())
1818       const_op.setRequestedDevice(TFOp(op).deviceAttr());
1819 
1820     OperationState state(op->getLoc(), "tfg.ExpandDims");
1821     state.addTypes(op->getResultTypes());
1822 
1823     state.attributes = op->getAttrDictionary();
1824     state.attributes.erase("axis");
1825     state.attributes.erase("N");
1826     state.addAttribute("Tdim", TypeAttr::get(rewriter.getI32Type()));
1827 
1828     state.addOperands({op->getOperand(0), const_op->getResult(0)});
1829     state.addOperands(control_operands);
1830     Operation *expand_dims_op = rewriter.create(state);
1831     rewriter.replaceOp(op, expand_dims_op->getResults());
1832     return success();
1833   }
1834 };
1835 
1836 // This implementation is mapped with ConstantFolding::MoveConstantsPastEnter
1837 // in grappler/optimizers/constant_folding.cc
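// Illustrative sketch: when a constant feeds an Enter with is_constant = true,
// a clone of the constant is created that carries the Enter's control return,
// and the Enter's data users are redirected to the clone so the value is
// available for folding inside the frame.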
1838 template <typename ConcreteType>
1839 class MoveConstantsPastEnterOpBase
1840     : public PropagationPatternBase<ConcreteType> {
1841  protected:
1842   MoveConstantsPastEnterOpBase(StringRef op_name, OpPropertyHelper &helper)
1843       : PropagationPatternBase<ConcreteType>(op_name, helper) {}
1844   LogicalResult matchAndRewrite(Operation *op,
1845                                 PatternRewriter &rewriter) const override {
1846     auto is_constant_attr = op->getAttrOfType<BoolAttr>("is_constant");
1847     if (!is_constant_attr || !is_constant_attr.getValue()) return failure();
1848 
1849     Operation *input = op->getOperand(0).getDefiningOp();
1850     if (!input || !this->dialect_->IsConstant(input)) return failure();
1851 
1852     // Find non-constant nodes that consume the outputs of Enter.
1853     if (op->getResults()[0].use_empty()) return failure();
1854 
1855     FailureOr<TFOp> cloned_const_op = CreateConstantTensorOp(
1856         rewriter, op->getLoc(), TFOp(op).name(), *(input->result_type_begin()),
1857         TFOp(op).controlRet(), input->getAttr("value"), input->getAttrs());
1858     if (failed(cloned_const_op)) return failure();
1859 
1860     (*cloned_const_op).setName(Twine(TFOp(op).name(), "/_enter"));
1861     if (!TFOp(op).device().empty())
1862       (*cloned_const_op).setRequestedDevice(TFOp(op).deviceAttr());
1863 
1864     rewriter.startRootUpdate(op);
1865     op->getResults()[0].replaceAllUsesWith((*cloned_const_op)->getResults()[0]);
1866     rewriter.finalizeRootUpdate(op);
1867     return success();
1868   }
1869 };
1870 
1871 // This implementation is mapped with ConstantFolding::MoveConstantsPastEnter
1872 // in grappler/optimizers/constant_folding.cc
1873 class MoveConstantsPastEnterOp
1874     : public MoveConstantsPastEnterOpBase<MoveConstantsPastEnterOp> {
1875  public:
1876   explicit MoveConstantsPastEnterOp(OpPropertyHelper &helper)
1877       : MoveConstantsPastEnterOpBase("tfg.Enter", helper) {}
1878 };
1879 
1880 // This implementation is mapped with ConstantFolding::MoveConstantsPastEnter
1881 // in grappler/optimizers/constant_folding.cc
1882 class MoveConstantsPastRefEnterOp
1883     : public MoveConstantsPastEnterOpBase<MoveConstantsPastRefEnterOp> {
1884  public:
1885   explicit MoveConstantsPastRefEnterOp(OpPropertyHelper &helper)
1886       : MoveConstantsPastEnterOpBase("tfg.RefEnter", helper) {}
1887 };
1888 
1889 // This implementation is mapped with ConstantFolding::SimplifySwitch
1890 // in grappler/optimizers/constant_folding.cc
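// Illustrative sketch: when the data and predicate operands of a Switch are
// the same value, each output is statically known: the "true" output always
// carries true and the "false" output always carries false. Both are replaced
// by boolean Const ops whose control inputs are Identity anchors on the
// corresponding Switch outputs, preserving execution order.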
1891 class SimplifySwitchOp : public PropagationPatternBase<SimplifySwitchOp> {
1892  public:
1893   explicit SimplifySwitchOp(OpPropertyHelper &helper)
1894       : PropagationPatternBase<SimplifySwitchOp>("tfg.Switch", helper),
1895         zero_dim_i1_tensor_type_(RankedTensorType::get(
1896             {}, IntegerType::get(helper.getDialect()->getContext(), 1))) {}
1897   LogicalResult matchAndRewrite(Operation *op,
1898                                 PatternRewriter &rewriter) const override {
1899     if (op->getOperand(0) != op->getOperand(1)) return failure();
1900 
1901     // If the optimization was already applied, the switch would have exactly
1902     // one Identity node consuming each of its outputs, each without any
1903     // non-control outputs.
1904     // TODO(tlongeri): This does not hold anymore as other patterns may need to
1905     // introduce an anchor. Fix this check, and handle both sides independently.
1906     if (llvm::any_of(op->getResults().drop_back(), [&](Value res) {
1907           return res.hasOneUse() &&
1908                  IsControlAnchor(*res.getUsers().begin(), dialect_);
1909         })) {
1910       return failure();
1911     }
1912 
1913     TFOp true_control_identity =
1914         GetControlAnchorForSwitchResult(rewriter, op->getResult(1), dialect_);
1915     TFOp false_control_identity =
1916         GetControlAnchorForSwitchResult(rewriter, op->getResult(0), dialect_);
1917 
1918     FailureOr<TFOp> true_op = CreateConstantTensorOp(
1919         rewriter, op->getLoc(), TFOp(op).name(), op->getResultTypes()[1],
1920         true_control_identity.controlRet(),
1921         DenseElementsAttr::get(zero_dim_i1_tensor_type_, true));
1922     if (failed(true_op)) return failure();
1923 
1924     (*true_op).setName(Twine(TFOp(op).name(), "/_const_true"));
1925     if (!TFOp(op).device().empty())
1926       (*true_op).setRequestedDevice(TFOp(op).device());
1927 
1928     FailureOr<TFOp> false_op = CreateConstantTensorOp(
1929         rewriter, op->getLoc(), TFOp(op).name(), op->getResultTypes()[0],
1930         false_control_identity.controlRet(),
1931         DenseElementsAttr::get(zero_dim_i1_tensor_type_, false));
1932     if (failed(false_op)) return failure();
1933 
1934     (*false_op).setName(Twine(TFOp(op).name(), "/_const_false"));
1935     if (!TFOp(op).device().empty())
1936       (*false_op).setRequestedDevice(TFOp(op).device().data());
1937 
1938     // Note that we can't use replaceAllUsesWith here because we don't want to
1939     // replace the user of control identity.
1940     for (OpOperand &user :
1941          llvm::make_early_inc_range(op->getResult(1).getUses())) {
1942       if (user.getOwner() == &(*true_control_identity)) continue;
1943 
1944       rewriter.startRootUpdate(user.getOwner());
1945       user.set((*true_op)->getResult(0));
1946       rewriter.finalizeRootUpdate(user.getOwner());
1947     }
1948     for (OpOperand &user :
1949          llvm::make_early_inc_range(op->getResult(0).getUses())) {
1950       if (user.getOwner() == &(*false_control_identity)) continue;
1951 
1952       rewriter.startRootUpdate(user.getOwner());
1953       user.set((*false_op)->getResult(0));
1954       rewriter.finalizeRootUpdate(user.getOwner());
1955     }
1956 
1957     return success();
1958   }
1959 
1960   RankedTensorType zero_dim_i1_tensor_type_;
1961 };
1962 
1963 // This implementation is mapped with ConstantFolding::SimplifyReduction
1964 // in grappler/optimizers/constant_folding.cc
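// Illustrative sketch of the cases handled below: an empty reduction-indices
// tensor makes the reduction a no-op (Identity); reducing a single-element
// input to a single element without keep_dims becomes a Reshape; and with
// keep_dims set, reducing only dimensions of size 1 is also an Identity.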
1965 class SimplifyReductionOp : public FolderPatternBase<SimplifyReductionOp> {
1966  public:
1967   explicit SimplifyReductionOp(OpPropertyHelper &helper)
1968       : FolderPatternBase<SimplifyReductionOp>(MatchAnyOpTypeTag(), helper) {}
1969   LogicalResult matchAndRewrite(Operation *op,
1970                                 PatternRewriter &rewriter) const override {
1971     if (!dialect_->IsReduction(op)) return failure();
1972 
1973     Operation *reduction_indices = op->getOperand(1).getDefiningOp();
1974     if (!reduction_indices) return failure();
1975 
1976     ShapedType indices_type = *(reduction_indices->result_type_begin());
1977     if (indices_type.hasStaticShape() && indices_type.getNumElements() == 0) {
1978       Operation *identity_op = ReplaceReductionWithIdentity(rewriter, op);
1979       if (!identity_op) return failure();
1980 
1981       rewriter.replaceOp(op, identity_op->getResults());
1982       return success();
1983     }
1984 
1985     // Check `IsReductionCandidateForSimplification`
1986     auto input_type = op->getOperand(0).getType().cast<ShapedType>();
1987     auto op_type = (*op->result_type_begin()).cast<ShapedType>();
1988     if (!input_type.hasStaticShape() || !op_type.hasStaticShape())
1989       return failure();
1990 
1991     bool is_single_element_op =
1992         (input_type.getNumElements() == 1) &&
1993         (op_type.hasStaticShape() && op_type.getNumElements() == 1);
1994 
1995     bool keep_dims = false;
1996     if (auto attr = op->getAttrOfType<BoolAttr>("keep_dims")) {
1997       keep_dims = attr.getValue();
1998     }
1999     bool simplifiable_to_reshape =
2000         is_single_element_op && !keep_dims && op->hasAttr("T");
2001 
2002     bool simplifiable_to_identity = keep_dims;
2003     // In grappler, they call EvaluateNode() to try to get the constant value of
2004     // the reduction indices. But if they are constant, the EvaluateConstant
2005     // pattern will have folded them already, so we don't need to evaluate here.
2006     if (dialect_->IsConstant(reduction_indices)) {
2007       ElementsAttr reduction_indices_attr =
2008           reduction_indices->getAttrOfType<ElementsAttr>("value");
2009 
2010       if (reduction_indices_attr.getElementType().isInteger(32)) {
2011         for (int v : reduction_indices_attr.getValues<int32_t>()) {
2012           if (v < 0) v += input_type.getRank();
2013           if (v < 0 || v >= input_type.getRank() ||
2014               input_type.getShape()[v] != 1)
2015             simplifiable_to_identity = false;
2016         }
2017       } else {
2018         for (int64_t v : reduction_indices_attr.getValues<int64_t>()) {
2019           if (v < 0) v += input_type.getRank();
2020           if (v < 0 || v >= input_type.getRank() ||
2021               input_type.getShape()[v] != 1)
2022             simplifiable_to_identity = false;
2023         }
2024       }
2025     }
2026 
2027     if (simplifiable_to_reshape) {
2028       Operation *reshape_op =
2029           ReplaceReductionWithReshape(rewriter, op, reduction_indices);
2030       if (!reshape_op) return failure();
2031 
2032       rewriter.replaceOp(op, reshape_op->getResults());
2033     } else if (simplifiable_to_identity) {
2034       Operation *identity_op = ReplaceReductionWithIdentity(rewriter, op);
2035       if (!identity_op) return failure();
2036 
2037       rewriter.replaceOp(op, identity_op->getResults());
2038     } else {
2039       return failure();
2040     }
2041 
2042     return success();
2043   }
2044 
2045  private:
2046   Operation *ReplaceReductionWithReshape(OpBuilder &builder, Operation *op,
2047                                          Operation *reduction_indices) const {
2048     const int new_num_dimensions =
2049         (*op->result_type_begin()).cast<ShapedType>().getRank();
2050     SmallVector<int64_t> elements(new_num_dimensions);
2051     std::iota(elements.begin(), elements.end(), 1);
2052     ElementsAttr const_attr = CreateElementsAttrOfTypeValues(
2053         builder.getIntegerType(32), {new_num_dimensions},
2054         llvm::makeArrayRef(elements));
2055     FailureOr<TFOp> const_op = CreateConstantTensorOp(
2056         builder, op->getLoc(), TFOp(op).name(),
2057         *(reduction_indices->result_type_begin()),
2058         TFOp(reduction_indices).controlRet(), const_attr);
2059     if (failed(const_op)) return nullptr;
2060 
2061     (*const_op).setName(Twine(TFOp(op).name(), "/_shape_const"));
2062     if (!TFOp(op).device().empty())
2063       (*const_op).setRequestedDevice(TFOp(op).deviceAttr());
2064 
2065     OperationState state(op->getLoc(), "tfg.Reshape");
2066     state.attributes = op->getAttrDictionary();
2067     state.attributes.erase("keep_dims");
2068     state.attributes.erase("Tidx");
2069     state.addAttribute("Tshape", TypeAttr::get(builder.getI32Type()));
2070 
2071     state.addOperands(op->getOperands());
2072     state.operands[1] = (*const_op)->getResult(0);
2073     state.addTypes(op->getResultTypes());
2074 
2075     Operation *reshape_op = builder.create(state);
2076     TFOp(reshape_op).setName(TFOp(op).nameAttr());
2077     if (!TFOp(op).device().empty())
2078       TFOp(reshape_op).setRequestedDevice(TFOp(op).deviceAttr());
2079     return reshape_op;
2080   }
2081 
2082   Operation *ReplaceReductionWithIdentity(OpBuilder &builder,
2083                                           Operation *op) const {
2084     OperationState state(op->getLoc(), "tfg.Identity");
2085     Type t_attr_type;
2086     if (auto T_attr = op->getAttrOfType<TypeAttr>("T"))
2087       t_attr_type = T_attr.getValue();
2088     else if (dialect_->IsAny(op) || dialect_->IsAll(op))
2089       t_attr_type = builder.getI1Type();
2090     else
2091       return nullptr;
2092     state.attributes = op->getAttrDictionary();
2093     util::EraseRegularNodeAttributes(state.attributes);
2094     state.addAttribute("T", TypeAttr::get(t_attr_type));
2095     state.addTypes(op->getResultTypes());
2096     state.addOperands(
2097         {op->getOperand(0), GetControlDependency(builder, op->getOperand(1))});
2098 
2099     Operation *identity_op = builder.create(state);
2100     TFOp(identity_op).setName(TFOp(op).nameAttr());
2101     if (!TFOp(op).device().empty())
2102       TFOp(identity_op).setRequestedDevice(TFOp(op).deviceAttr());
2103     return identity_op;
2104   }
2105 };
2106 
2107 // This implementation is mapped with ConstantFolding::SimplifyReshapeOp
2108 // in grappler/optimizers/constant_folding.cc
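// Illustrative sketch: a Reshape whose constant target shape is compatible
// with the statically known input shape is a no-op, e.g. Reshape(x, [2, -1])
// with x : tensor<2x3xT>; it is replaced by an Identity on x, keeping the
// shape operand only as a control dependency.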
2109 class SimplifyReshapeOp : public FolderPatternBase<SimplifyReshapeOp> {
2110  public:
2111   explicit SimplifyReshapeOp(OpPropertyHelper &helper)
2112       : FolderPatternBase<SimplifyReshapeOp>(MatchAnyOpTypeTag(), helper) {}
2113   LogicalResult matchAndRewrite(Operation *op,
2114                                 PatternRewriter &rewriter) const override {
2115     if (!dialect_->IsReshape(op) || !op->hasAttr("T")) return failure();
2116 
2117     auto input_shape = op->getOperand(0).getType().cast<ShapedType>();
2118     if (!input_shape.hasStaticShape()) return failure();
2119 
2120     Operation *shape_op = op->getOperand(1).getDefiningOp();
2121     if (!shape_op || !dialect_->IsConstant(shape_op)) return failure();
2122 
2123     auto shape_attr = shape_op->getAttrOfType<ElementsAttr>("value");
2124     // TODO(tlongeri): only reason for SmallVector instead of range directly is
2125     // that llvm::zip implementation requires copy assignment (it shouldn't)
2126     SmallVector<APInt> new_shape(shape_attr.getValues<APInt>());
2127 
2128     if (input_shape.getRank() != new_shape.size()) return failure();
2129     for (const auto &it : llvm::zip(input_shape.getShape(), new_shape)) {
2130       int64_t dim_0 = std::get<0>(it);
2131       int64_t dim_1 = std::get<1>(it).getSExtValue();
2132       if (dim_0 >= 0 && dim_1 >= 0 && dim_0 != dim_1) return failure();
2133     }
2134 
2135     OperationState state(op->getLoc(), "tfg.Identity");
2136     state.addTypes(op->getResultTypes());
2137     state.addOperands(
2138         {op->getOperand(0), GetControlDependency(rewriter, op->getOperand(1))});
2139     state.addOperands(TFOp(op).getControlOperands());
2140 
2141     state.attributes = op->getAttrDictionary();
2142     util::EraseRegularNodeAttributes(state.attributes);
2143     state.addAttribute("T", op->getAttrOfType<TypeAttr>("T"));
2144 
2145     Operation *identity_op = rewriter.create(state);
2146     TFOp(identity_op).setName(TFOp(op).nameAttr());
2147     if (!TFOp(op).device().empty())
2148       TFOp(identity_op).setRequestedDevice(TFOp(op).deviceAttr());
2149     rewriter.replaceOp(op, identity_op->getResults());
2150 
2151     return success();
2152   }
2153 };
2154 
2155 // This implementation is mapped with
2156 // ConstantFolding::SimplifyArithmeticOperations in
2157 // grappler/optimizers/constant_folding.cc
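// Illustrative sketch of the algebraic identities handled below:
//   1 * y => y, 0 + y => y, 0 - y => Neg(y), 1 / y => Reciprocal(y) (float or
//   complex only), x * 1 => x, x / 1 => x, x +/- 0 => x, x OR true => true,
//   x * 0 => 0 and 0 * y => 0.
// Depending on how the operand shapes relate to the result shape, the
// replacement is a Snapshot, Identity, BroadcastTo, or Const op.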
2158 class SimplifyArithmeticOp
2159     : public ConstantPatternBase<SimplifyArithmeticOp, FolderTrait,
2160                                  PropagationTrait> {
2161  public:
2162   explicit SimplifyArithmeticOp(OpPropertyHelper &helper)
2163       : ConstantPatternBase(MatchAnyOpTypeTag(), helper) {}
2164   LogicalResult matchAndRewrite(Operation *op,
2165                                 PatternRewriter &rewriter) const override {
2166     const bool is_mul = dialect_->IsAnyMul(op) || dialect_->IsLogicalAnd(op);
2167     const bool is_matmul = dialect_->IsAnyMatMul(op);
2168     const bool is_add = dialect_->IsAdd(op) || dialect_->IsBiasAdd(op) ||
2169                         dialect_->IsLogicalOr(op);
2170     const bool is_sub = dialect_->IsSub(op);
2171     const bool is_any_div = dialect_->IsAnyDiv(op) && !dialect_->IsFloorDiv(op);
2172 
2173     if (!is_mul && !is_matmul && !is_add && !is_sub && !is_any_div)
2174       return failure();
2175 
2176     Operation *x = op->getOperand(0).getDefiningOp();
2177     Operation *y = op->getOperand(1).getDefiningOp();
2178     if (!x || !y) return failure();
2179 
2180     ShapedType op_type = (*op->result_type_begin()).cast<ShapedType>();
2181     ShapedType x_type = (*x->result_type_begin()).cast<ShapedType>();
2182     ShapedType y_type = (*y->result_type_begin()).cast<ShapedType>();
2183 
2184     const bool y_matches_output_shape = op_type == y_type;
2185     const bool x_matches_output_shape = op_type == x_type;
2186 
2187     const bool x_is_zero = helper_.IsZeros(x);
2188     const bool x_is_one = x_is_zero ? false : helper_.IsOnes(x);
2189     // TODO(chiahungduan): Check if the optimization has been applied.
2190     // TODO(chiahungduan): Check if the optimizations has been applied.
2191 
2192     if ((is_mul && x_is_one) || (is_add && x_is_zero)) {
2193       // 1 * y = y or 0 + y = y.
2194       if (y_matches_output_shape) {
2195         FailureOr<TFOp> snapshot_op =
2196             ReplaceOperationWithSnapshot(rewriter, op, 1);
2197         if (failed(snapshot_op)) return failure();
2198         rewriter.replaceOp(op, (*snapshot_op)->getResults());
2199         return success();
2200       } else if (x_matches_output_shape) {
2201         FailureOr<TFOp> broadcast_to_op =
2202             ReplaceOperationWithBroadcastTo(rewriter, op, 1);
2203         rewriter.replaceOp(op, (*broadcast_to_op)->getResults());
2204         return success();
2205       }
2206       return failure();
2207     }
2208 
2209     if (y_matches_output_shape && (is_sub && x_is_zero)) {
2210       // Replace 0 - y with Neg(y).
2211       OperationState state(op->getLoc(), "tfg.Neg");
2212       state.addOperands({op->getOperand(1),
2213                          GetControlDependency(rewriter, op->getOperand(0))});
2214       state.addOperands(TFOp(op).getControlOperands());
2215       state.attributes = op->getAttrDictionary();
2216       state.addTypes(op->getResultTypes());
2217       Operation *neg = rewriter.create(state);
2218       rewriter.replaceOp(op, neg->getResults());
2219       return success();
2220     }
2221 
2222     // Replace 1 / y with Reciprocal op.
2223     if (y_matches_output_shape && is_any_div && x_is_one) {
2224       TypeAttr type_attr = op->getAttrOfType<TypeAttr>("T");
2225       if (!type_attr) return failure();
2226 
2227       if (type_attr.getValue().isa<FloatType>() ||
2228           type_attr.getValue().isa<ComplexType>()) {
2229         OperationState state(op->getLoc(), "tfg.Reciprocal");
2230         state.addOperands({op->getOperand(1),
2231                            GetControlDependency(rewriter, op->getOperand(0))});
2232         state.addOperands(TFOp(op).getControlOperands());
2233         state.attributes = op->getAttrDictionary();
2234         state.addTypes(op->getResultTypes());
2235         Operation *reciprocal_op = rewriter.create(state);
2236         rewriter.replaceOp(op, reciprocal_op->getResults());
2237         return success();
2238       }
2239     }
2240 
2241     const bool y_is_zero = helper_.IsZeros(y);
2242     const bool y_is_one = helper_.IsOnes(y);
2243 
2244     if (((is_mul || is_any_div) && y_is_one) ||
2245         ((is_add || is_sub) && y_is_zero)) {
2246       // x * 1 = x or x / 1 = x or x +/- 0 = x
2247       if (x_matches_output_shape) {
2248         FailureOr<TFOp> snapshot_op =
2249             ReplaceOperationWithSnapshot(rewriter, op, 0);
2250         if (failed(snapshot_op)) return failure();
2251         rewriter.replaceOp(op, (*snapshot_op)->getResults());
2252         return success();
2253       } else if (y_matches_output_shape) {
2254         FailureOr<TFOp> broadcast_to_op =
2255             ReplaceOperationWithBroadcastTo(rewriter, op, 0);
2256         if (failed(broadcast_to_op)) return failure();
2257         rewriter.replaceOp(op, (*broadcast_to_op)->getResults());
2258         return success();
2259       }
2260       return failure();
2261     }
2262 
2263     // x OR true = true OR y = true.
2264     if (op_type.hasStaticShape() && dialect_->IsLogicalOr(op) &&
2265         (y_is_one || x_is_one)) {
2266       FailureOr<TFOp> const_op = ReplaceOperationWithConstant(rewriter, op, 1);
2267       if (failed(const_op)) return failure();
2268       rewriter.replaceOp(op, (*const_op)->getResults());
2269       return success();
2270     }
2271 
2272     // The TFG optimizer doesn't support aggressive mode.
2273     const bool is_aggressive = false;
2274     // Note that this is always false because `is_aggressive` is false. Keep it
2275     // in this form to make it easier to compare this logic with the
2276     // corresponding logic in grappler.
2277     bool optimize_zeros_divided_by_y = is_any_div && x_is_zero && is_aggressive;
2278     if ((x_is_zero || y_is_zero) &&
2279         (is_mul || is_matmul || optimize_zeros_divided_by_y)) {
2280       if (op_type.hasStaticShape()) {
2281         bool is_quantized = dialect_->IsQuantizedMatMul(op);
2282         if (is_quantized) {
2283           // TODO(chiahungduan): AddQuantizedMatMulMinMaxOutConstNodes
2284           return failure();
2285         }
2286 
2287         FailureOr<TFOp> const_op =
2288             ReplaceOperationWithConstant(rewriter, op, 0);
2289         if (failed(const_op)) return failure();
2290 
2291         rewriter.replaceOp(op, (*const_op)->getResults());
2292         return success();
2293       }
2294 
2295       if ((is_mul || is_any_div) && x_is_zero) {
2296         if (x_matches_output_shape) {
2297           FailureOr<TFOp> identity = ReplaceOpWithIdentity(rewriter, op, 0);
2298           if (failed(identity)) return failure();
2299           rewriter.replaceOp(op, (*identity)->getResults());
2300           return success();
2301         } else if (y_matches_output_shape) {
2302           FailureOr<TFOp> broadcast_to_op =
2303               ReplaceOperationWithBroadcastTo(rewriter, op, 0);
2304           if (failed(broadcast_to_op)) return failure();
2305           rewriter.replaceOp(op, (*broadcast_to_op)->getResults());
2306           return success();
2307         }
2308       } else if (is_mul && y_is_zero) {
2309         if (y_matches_output_shape) {
2310           FailureOr<TFOp> identity = ReplaceOpWithIdentity(rewriter, op, 0);
2311           if (failed(identity)) return failure();
2312           rewriter.replaceOp(op, (*identity)->getResults());
2313           return success();
2314         } else if (x_matches_output_shape) {
2315           FailureOr<TFOp> broadcast_to_op =
2316               ReplaceOperationWithBroadcastTo(rewriter, op, 1);
2317           if (failed(broadcast_to_op)) return failure();
2318           rewriter.replaceOp(op, (*broadcast_to_op)->getResults());
2319           return success();
2320         }
2321       }
2322     }
2323 
2324     return failure();
2325   }
2326 };
2327 
2328 // This implementation is mapped with ConstantFolding::ReduceDivToReciprocalMul
2329 // in grappler/optimizers/constant_folding.cc
2330 class ReduceDivToReciprocalMul
2331     : public FolderPatternBase<ReduceDivToReciprocalMul> {
2332  public:
2333   explicit ReduceDivToReciprocalMul(OpPropertyHelper &helper)
2334       : FolderPatternBase<ReduceDivToReciprocalMul>(MatchAnyOpTypeTag(),
2335                                                     helper) {}
2336   LogicalResult matchAndRewrite(Operation *op,
2337                                 PatternRewriter &rewriter) const override {
2338     // Strength reduce floating point division by a constant Div(x, const) to
2339     // multiplication by the reciprocal Mul(x, Reciprocal(const)). This in turn
2340     // will be constant folded to Mul(x, 1.0/const).
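    // For Xdivy the operands of the resulting MulNoNan are ordered as
    // (Reciprocal(const), x) so that, per MulNoNan's semantics, zeros in x
    // still produce zeros, matching Xdivy's behavior at x == 0.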
2341     if (!dialect_->IsDiv(op) && !dialect_->IsRealDiv(op) &&
2342         !dialect_->IsXdivy(op)) {
2343       return failure();
2344     }
2345 
2346     Operation *y = op->getOperand(1).getDefiningOp();
2347     if (!y || !dialect_->IsConstant(y)) return failure();
2348 
2349     TypeAttr type_attr = op->getAttrOfType<TypeAttr>("T");
2350     if (!type_attr) return failure();
2351 
2352     // Skip integer division.
2353     if (dialect_->IsDiv(op) && !(type_attr.getValue().isa<FloatType>() ||
2354                                  type_attr.getValue().isa<ComplexType>())) {
2355       return failure();
2356     }
2357 
2358     OperationState state(op->getLoc(), "tfg.Reciprocal");
2359     state.addOperands(y->getResult(0));
2360     state.addTypes({*(y->result_type_begin()), ControlType::get(getContext())});
2361     state.addAttribute("T", type_attr);
2362     TFOp reciprocal_op = rewriter.create(state);
2363     reciprocal_op.setName(Twine(TFOp(op).name(), "/") +
2364                           Twine(TFOp(y).name(), "/_recip"));
2365     if (!TFOp(op).device().empty())
2366       reciprocal_op.setRequestedDevice(TFOp(op).deviceAttr());
2367 
2368     StringRef new_op_name = dialect_->IsXdivy(op) ? "tfg.MulNoNan" : "tfg.Mul";
2369     OperationState new_op_state(op->getLoc(), new_op_name);
2370 
2371     if (dialect_->IsXdivy(op)) {
2372       new_op_state.addOperands(
2373           {reciprocal_op->getResult(0), op->getOperand(0)});
2374     } else {
2375       new_op_state.addOperands(
2376           {op->getOperand(0), reciprocal_op->getResult(0)});
2377     }
2378     new_op_state.addOperands(TFOp(op).getControlOperands());
2379 
2380     new_op_state.attributes = op->getAttrDictionary();
2381     new_op_state.addTypes(op->getResultTypes());
2382 
2383     Operation *new_op = rewriter.create(new_op_state);
2384     rewriter.replaceOp(op, new_op->getResults());
2385 
2386     return success();
2387   }
2388 };
2389 
2390 namespace {
2391 template <typename ConcreteType>
2392 using Base = ConstantPatternBase<ConcreteType, FolderTrait, PropagationTrait>;
2393 
2394 template <typename ConcreteType>
2395 class ConstantPushDownBase : public Base<ConcreteType> {
2396  protected:
2397   using Base<ConcreteType>::Base;
2398 
2399   bool IsOperandsSafeToMove(Operation *op_child, Operation *const_child) const {
2400     // Don't rewrite the tree if it might create cycles.
2401     // TODO(chiahungduan): Remove the control dependency which may create
2402     // cycles.
2403     if (llvm::any_of(
2404             TFOp(const_child).getControlOperands(),
2405             [op_child](Value v) { return v.getDefiningOp() == op_child; })) {
2406       return false;
2407     }
2408 
2409     // Moving operands may change the result shapes; only do it when each of
2410     // the non-control return values has exactly one user.
2411     if (llvm::any_of(op_child->getResults().drop_back(),
2412                      [](Value v) { return !v.hasOneUse(); })) {
2413       return false;
2414     }
2415     return true;
2416   }
2417 };
2418 }  // namespace
2419 
2420 // Consider the transformation
2421 //
2422 //                      +                +       = parent
2423 //                     / \              / \
2424 //                    C   +    -- >    X   +     = children
2425 //                       / \              / \
2426 //                      X   Y            C   Y   = leaves
2427 //
2428 // where C is constant, X is non-constant, Y may be constant or non-constant,
2429 // and '+' denotes an associative and commutative operator like addition or
2430 // multiplication. This optimization pushes constants down in the tree to
2431 // canonicalize it. Moreover, in cases where the child node has a second
2432 // constant input Y we will create a leaf node that can be folded, e.g.
2433 //
2434 //    Add(C1, Add(C2, X)) -> Add(X, Add(C1, C2)) -> Add(X, C1 + C2)
2435 //
2436 // We also handle the non-commutative cases of subtraction and division
2437 // by rotating the tree locally, e.g.
2438 //    Sub(C, Add(X, Y)) -> Sub(Sub(C, Y), X)
2439 //    Mul(C, Div(X, Y)) -> Mul(X, Div(C, Y)).
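// As a concrete sketch of the non-commutative rotation above:
//    Sub(2.0, Add(x, 1.0)) -> Sub(Sub(2.0, 1.0), x) -> Sub(1.0, x).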
2440 class ConstantPushDown : public ConstantPushDownBase<ConstantPushDown> {
2441  public:
2442   explicit ConstantPushDown(OpPropertyHelper &helper)
2443       : ConstantPushDownBase(MatchAnyOpTypeTag(), helper) {}
2444   LogicalResult matchAndRewrite(Operation *op,
2445                                 PatternRewriter &rewriter) const override {
2446     // Get parent op type.
2447     const bool is_add = dialect_->IsAdd(op);
2448     const bool is_mul = dialect_->IsMul(op);
2449     const bool is_sub = dialect_->IsSub(op);
2450     const bool is_div = dialect_->IsDiv(op);
2451     if (!(is_add || is_sub || is_mul || is_div)) return failure();
2452     const bool is_symmetric = is_add || is_mul;
2453 
2454     Operation *child_op = op->getOperand(0).getDefiningOp();
2455     Operation *const_op = op->getOperand(1).getDefiningOp();
2456     if (!child_op || !const_op) return failure();
2457 
2458     // Don't move nodes across devices.
2459     if (TFOp(op).deviceAttr() != TFOp(child_op).deviceAttr() ||
2460         TFOp(op).deviceAttr() != TFOp(const_op).deviceAttr()) {
2461       return failure();
2462     }
2463 
2464     const bool left_child_is_const = dialect_->IsConstant(child_op);
2465 
2466     // One of the child ops has to be constant.
2467     if (!dialect_->IsConstant(const_op)) std::swap(child_op, const_op);
2468     if (!dialect_->IsConstant(const_op)) return failure();
2469     if (helper_.ShouldPreserveOp(child_op)) return failure();
2470 
2471     if (!IsOperandsSafeToMove(child_op, const_op)) return failure();
2472 
2473     // Get child op type.
2474     const bool is_child_add = dialect_->IsAdd(child_op);
2475     const bool is_child_mul = dialect_->IsMul(child_op);
2476     const bool is_child_sub = dialect_->IsSub(child_op);
2477     const bool is_child_div = dialect_->IsDiv(child_op);
2478     const bool is_add_sub =
2479         (is_add || is_sub) && (is_child_add || is_child_sub);
2480     const bool is_mul_div =
2481         (is_mul || is_div) && (is_child_mul || is_child_div);
2482     if (!is_add_sub && !is_mul_div) return failure();
2483 
2484     const bool is_child_symmetric = is_child_add || is_child_mul;
2485 
2486     TypeAttr t_attr = op->getAttrOfType<TypeAttr>("T");
2487     if (!t_attr) return failure();
2488 
2489     if (!(is_symmetric && is_child_symmetric) &&
2490         t_attr.getValue().isIntOrIndex()) {
2491       return failure();
2492     }
2493 
2494     Operation *left_leaf_op = child_op->getOperand(0).getDefiningOp();
2495     Operation *right_leaf_op = child_op->getOperand(1).getDefiningOp();
2496     if (!left_leaf_op || !right_leaf_op) return failure();
2497 
2498     // Don't move nodes across devices.
2499     if (TFOp(op).deviceAttr() != TFOp(left_leaf_op).deviceAttr() ||
2500         TFOp(op).deviceAttr() != TFOp(right_leaf_op).deviceAttr()) {
2501       return failure();
2502     }
2503 
2504     const bool left_leaf_is_const = dialect_->IsConstant(left_leaf_op);
2505     Operation *y_node = left_leaf_is_const ? left_leaf_op : right_leaf_op;
2506 
2507     if (!dialect_->IsConstant(y_node)) {
2508       // If we know the shapes of the nodes being swapped, make sure we don't
2509       // push down a larger node and create more work by broadcasting earlier
2510       // in the expression tree.
2511       auto c_shape = op->getOperand((left_child_is_const ? 0 : 1))
2512                          .getType()
2513                          .cast<ShapedType>();
2514       auto x_shape = child_op->getOperand((left_leaf_is_const ? 0 : 1))
2515                          .getType()
2516                          .cast<ShapedType>();
2517 
2518       if (c_shape.hasStaticShape() && x_shape.hasStaticShape() &&
2519           c_shape.getNumElements() > x_shape.getNumElements()) {
2520         return failure();
2521       }
2522       if (c_shape.hasRank() && x_shape.hasRank() && c_shape.getRank() > 0) {
2523         for (auto it : llvm::zip(c_shape.getShape(), x_shape.getShape())) {
2524           int c_dim = std::get<0>(it);
2525           int x_dim = std::get<1>(it);
2526           if (x_dim >= 0 && c_dim > x_dim) return failure();
2527         }
2528       }
2529     }
2530 
2531     // Child input
2532     Operation *input_x = left_leaf_is_const
2533                              ? child_op->getOperand(1).getDefiningOp()
2534                              : child_op->getOperand(0).getDefiningOp();
2535     Operation *input_y = left_leaf_is_const
2536                              ? child_op->getOperand(0).getDefiningOp()
2537                              : child_op->getOperand(1).getDefiningOp();
2538     if (!input_x || !input_y) return failure();
2539 
2540     Operation *input_c = const_op;
2541     Operation *input_op = child_op;
2542 
2543     if (op->getOperand(0).getDefiningOp() == input_c)
2544       op->setOperand(0, input_x->getResult(0));
2545     else
2546       op->setOperand(1, input_x->getResult(0));
2547 
2548     if (is_symmetric && is_child_symmetric) {
2549       // Easy case (only commutative ops). We always write this as one of
2550       //   +
2551       //  / \
2552       // X   +
2553       //    / \
2554       //   C   Y
2555       rewriter.startRootUpdate(op);
2556       op->setOperand(0, input_x->getResult(0));
2557       op->setOperand(1, input_op->getResult(0));
2558       rewriter.finalizeRootUpdate(op);
2559       rewriter.startRootUpdate(child_op);
2560       child_op->setOperand(0, input_c->getResult(0));
2561       child_op->setOperand(1, input_y->getResult(0));
2562       rewriter.finalizeRootUpdate(child_op);
2563     } else {
2564       // More complicated case: When there are non-commutative operations like
2565       // subtractions or divisions involved, we may have to rotate the tree
2566       // and/or change op types. There are 6 non-trivial cases depending on
2567       // the effective generalized "sign" of each of the three terms C, Y, and
2568       // X. Here are the final trees we want to generate for those 6 cases:
2569       //
2570       // (CYX signs):   ++-      +--      -+-    --+     +-+      -++
2571       //
2572       //                 -        -        -      -       +        +
2573       //                / \      / \      / \    / \     / \      / \
2574       //               +   X    -   X    -   X  X   +   X   -    X   -
2575       //              / \      / \      / \        / \     / \      / \
2576       //             C   Y    C   Y    Y   C      Y   C   C   Y    Y   C
2577       //
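      // For instance, the '+--' column corresponds to Sub(C, Add(X, Y)), i.e.
      // C - X - Y, which is rewritten to Sub(Sub(C, Y), X).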
2578 
2579       // First, let's determine the effective sign of each term in the original
2580       // expression
2581       auto is_leaf_negated = [&](const bool is_right_leaf) -> bool {
2582         bool leaf_negated = !is_child_symmetric && is_right_leaf;
2583         bool child_negated = !is_symmetric && left_child_is_const;
2584         return leaf_negated != child_negated;
2585       };
2586 
2587       StringRef symmetric_op = (is_add || is_sub) ? "tfg.Add" : "tfg.Mul";
2588       StringRef nonsymmetric_op = (is_add || is_sub) ? "tfg.Sub" : "tfg.Div";
2589       bool neg_c = !is_symmetric && !left_child_is_const;
2590       bool neg_x = is_leaf_negated(left_leaf_is_const);
2591       bool neg_y = is_leaf_negated(!left_leaf_is_const);
2592 
2593       StringRef op_name =
2594           (neg_x || (neg_c && neg_y)) ? nonsymmetric_op : symmetric_op;
2595       OperationState state(op->getLoc(), op_name);
2596       state.addOperands({input_op->getResult(0), input_x->getResult(0)});
2597       if (!neg_x) std::swap(state.operands[0], state.operands[1]);
2598       state.addOperands(TFOp(op).getControlOperands());
2599       state.attributes = op->getAttrDictionary();
2600       state.addTypes(op->getResultTypes());
2601       Operation *new_op = rewriter.create(state);
2602       rewriter.replaceOp(op, new_op->getResults());
2603 
2604       StringRef child_name = neg_c != neg_y ? nonsymmetric_op : symmetric_op;
2605       OperationState new_child_state(child_op->getLoc(), child_name);
2606       new_child_state.addOperands(
2607           {input_y->getResult(0), input_c->getResult(0)});
2608       if (!neg_c)
2609         std::swap(new_child_state.operands[0], new_child_state.operands[1]);
2610       new_child_state.addOperands(TFOp(child_op).getControlOperands());
2611       new_child_state.attributes = child_op->getAttrDictionary();
2612       new_child_state.addTypes(child_op->getResultTypes());
2613       rewriter.setInsertionPoint(child_op);
2614       Operation *new_child_op = rewriter.create(new_child_state);
2615       rewriter.replaceOp(child_op, new_child_op->getResults());
2616     }
2617     return success();
2618   }
2619 };
2620 
2621 // This implementation is mapped with
2622 // ConstantFolding::PartialConstPropThroughIdentityN in
2623 // grappler/optimizers/constant_folding.cc
2624 class PartialConstPropThroughIdentityN
2625     : public PropagationPatternBase<PartialConstPropThroughIdentityN> {
2626  public:
2627   explicit PartialConstPropThroughIdentityN(OpPropertyHelper &helper)
2628       : PropagationPatternBase<PartialConstPropThroughIdentityN>(
2629             MatchAnyOpTypeTag(), helper) {}
2630   LogicalResult matchAndRewrite(Operation *op,
2631                                 PatternRewriter &rewriter) const override {
2632     // In grappler's constant folding, values are propagated from IdentityN.
2633     // Here, we check each operand that is defined by Identity/IdentityN.
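    // A sketch: given idn = IdentityN(c, x) with c a constant, an operand using
    // idn's first result is rewritten to use c directly, and idn's control
    // result is appended to this op as a control operand.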
2634 
2635     SmallVector<Value> control_operands;
2636     for (OpOperand &operand : op->getOpOperands()) {
2637       Value v = operand.get();
2638       if (v.getType().isa<ControlType>()) break;
2639 
2640       Operation *v_op = v.getDefiningOp();
2641       if (!v_op || !dialect_->IsIdentityN(v_op) ||
2642           dialect_->IsIdentityNSingleInput(v_op)) {
2643         continue;
2644       }
2645 
2646       int res_index = v.cast<OpResult>().getResultNumber();
2647       Value value_to_forward = v_op->getOperand(res_index);
2648       if (!value_to_forward.getDefiningOp() ||
2649           !dialect_->IsConstant(value_to_forward.getDefiningOp())) {
2650         continue;
2651       }
2652 
2653       rewriter.startRootUpdate(op);
2654       operand.set(value_to_forward);
2655       rewriter.finalizeRootUpdate(op);
2656 
2657       // Add the control dependency to the Identity/IdentityN. Note that it's
2658       // possible to have multiple operands defined by the same
2659       // Identity/IdentityN. Since the number of control operands is small and
2660       // this propagation usually runs only once per operation, do a linear scan
2661       // before insertion.
2662       Value control = TFOp(v_op).controlRet();
2663       if (!llvm::is_contained(control_operands, control))
2664         control_operands.push_back(control);
2665     }
2666 
2667     // No new control operands implies that we didn't find constants that can be
2668     // propagated through Identity/IdentityN.
2669     if (control_operands.empty()) return failure();
2670 
2671     OperationState state(op->getLoc(), op->getName());
2672     state.attributes = op->getAttrDictionary();
2673     state.addOperands(op->getOperands());
2674     // Append the newly added control operands from Identity/IdentityN.
2675     state.addOperands(control_operands);
2676     state.addTypes(op->getResultTypes());
2677 
2678     Operation *new_op = rewriter.create(state);
2679     rewriter.replaceOp(op, new_op->getResults());
2680 
2681     return success();
2682   }
2683 };
2684 
2685 // This implementation is mapped with
2686 // ConstantFolding::PartialAssocOpConstFolding in
2687 // grappler/optimizers/constant_folding.cc
2688 class PartialAssocOpConstFolding
2689     : public FolderPatternBase<PartialAssocOpConstFolding> {
2690  public:
2691   explicit PartialAssocOpConstFolding(OpPropertyHelper &helper)
2692       : FolderPatternBase<PartialAssocOpConstFolding>(MatchAnyOpTypeTag(),
2693                                                       helper) {}
2694   LogicalResult matchAndRewrite(Operation *op,
2695                                 PatternRewriter &rewriter) const override {
2696     // Partial constant folding for associative operators:
2697     // Split AddN/AccumulateNV2 to enable partial
2698     // folding of ops when more than one but not all inputs are constant.
2699     // For AddN and AccumulateNV2, we may furthermore reorder inputs, since
2700     // addition is commutative.
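    // A sketch: AddN(c1, x, c2, y) -> AddN(AddN(c1, c2), x, y), where the inner
    // AddN over the constant operands can then be folded into a single constant.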
2701     if (!helper_.IsAggregate(op) || !helper_.IsCommutative(op))
2702       return failure();
2703 
2704     SmallVector<Value> const_inputs;
2705     SmallVector<Value> non_const_inputs;
2706 
2707     auto [non_control_operands, control_operands] = TFOp(op).splitOperands();
2708     int non_control_inputs_size = non_control_operands.size();
2709     if (non_control_inputs_size <= 2) return failure();
2710 
2711     if (llvm::any_of(non_control_operands, [](Value v) {
2712           Operation *v_op = v.getDefiningOp();
2713           return v_op &&
2714                  TFOp(v_op).name().rfind("_partial_split_") != StringRef::npos;
2715         })) {
2716       return failure();
2717     }
2718 
2719     for (Value operand : non_control_operands) {
2720       Operation *may_const_op = operand.getDefiningOp();
2721       if (may_const_op && dialect_->IsConstant(may_const_op))
2722         const_inputs.push_back(operand);
2723       else
2724         non_const_inputs.push_back(operand);
2725     }
2726 
2727     if (const_inputs.size() == non_control_inputs_size &&
2728         op->getName().stripDialect() == "AccumulateNV2") {
2729       OperationState state(op->getLoc(), "tfg.AddN");
2730       state.addTypes(op->getResultTypes());
2731       state.addOperands(op->getOperands());
2732       state.attributes = op->getAttrDictionary();
2733       state.attributes.erase("shape");
2734       Operation *add_n = rewriter.create(state);
2735       rewriter.replaceOp(op, add_n->getResults());
2736       return success();
2737     }
2738 
2739     if (const_inputs.size() <= 1) return failure();
2740 
2741     OperationState state(op->getLoc(), "tfg.AddN");
2742     state.addOperands(const_inputs);
2743     state.addTypes(op->getResultTypes());
2744     state.attributes = op->getAttrDictionary();
2745     state.attributes.erase("shape");
2746     state.attributes.set("N", IntegerAttr::get(rewriter.getIntegerType(32),
2747                                                const_inputs.size()));
2748     Operation *add_n = rewriter.create(state);
2749     TFOp(add_n).setName(Twine(TFOp(op).name(), "/_partial_split_") +
2750                         std::to_string(const_inputs.size()));
2751     // add_n inherits all the attrs of op; no need to update the device attr.
2752 
2753     OperationState new_op_state(op->getLoc(), op->getName());
2754     // Note that grappler puts the new AddN at the position of the first const
2755     // operand. Here we always put it at the beginning.
2756     new_op_state.addOperands(add_n->getResult(0));
2757     new_op_state.addOperands(non_const_inputs);
2758     new_op_state.addOperands(control_operands);
2759     new_op_state.addTypes(op->getResultTypes());
2760     new_op_state.attributes = op->getAttrDictionary();
2761     new_op_state.attributes.set("N",
2762                                 IntegerAttr::get(rewriter.getIntegerType(32),
2763                                                  non_const_inputs.size() + 1));
2764 
2765     Operation *new_op = rewriter.create(new_op_state);
2766     rewriter.replaceOp(op, new_op->getResults());
2767 
2768     return success();
2769   }
2770 };
2771 
2772 // This implementation is mapped with ConstantFolding::MergeConcat in
2773 // grappler/optimizers/constant_folding.cc
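// A sketch of the rewrite: ConcatV2(a, ConcatV2(b, c, axis), d, axis) is merged
// into ConcatV2(a, b, c, d, axis), provided both axes match and not all of the
// inner inputs are constant.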
2774 class MergeConcatOp : public FolderPatternBase<MergeConcatOp> {
2775  public:
2776   explicit MergeConcatOp(OpPropertyHelper &helper)
2777       : FolderPatternBase<MergeConcatOp>("tfg.ConcatV2", helper) {}
2778   LogicalResult matchAndRewrite(Operation *op,
2779                                 PatternRewriter &rewriter) const override {
2780     if (helper_.ShouldPreserveOp(op)) return failure();
2781 
2782     auto getAxis = [&](Operation *axis_op) {
2783       ElementsAttr axis_attr = axis_op->getAttrOfType<ElementsAttr>("value");
2784       return axis_attr.getElementType().isInteger(64)
2785                  ? static_cast<int>(axis_attr.getSplatValue<int64_t>())
2786                  : axis_attr.getSplatValue<int>();
2787     };
2788 
2789     auto [non_control_operands, control_operands] = TFOp(op).splitOperands();
2790     Operation *axis_op = non_control_operands.back().getDefiningOp();
2791     if (!axis_op || !dialect_->IsConstant(axis_op)) return failure();
2792     int axis = getAxis(axis_op);
2793 
2794     // In grappler, it checks whether the first user of a ConcatV2 is also a
2795     // ConcatV2. Here we check the operands of the current op instead, and we
2796     // check all of them rather than only the first.
2797     Operation *concat_operand = nullptr;
2798     for (Value operand : non_control_operands) {
2799       Operation *defining_op = operand.getDefiningOp();
2800       if (defining_op && dialect_->IsConcatV2(defining_op)) {
2801         concat_operand = defining_op;
2802         break;
2803       }
2804     }
2805     if (!concat_operand) return failure();
2806 
2807     auto [concat_non_control_operands, concat_control_operands] =
2808         TFOp(concat_operand).splitOperands();
2809     Operation *concat_operand_axis_op =
2810         concat_non_control_operands.back().getDefiningOp();
2811     if (!concat_operand_axis_op ||
2812         !dialect_->IsConstant(concat_operand_axis_op)) {
2813       return failure();
2814     }
2815     if (axis != getAxis(concat_operand_axis_op)) return failure();
2816 
2817     // If all inputs are constant, don't merge and let EvaluateConstant take
2818     // care of it.
2819     if (llvm::all_of(concat_non_control_operands.drop_back(), [&](Value v) {
2820           return v.getDefiningOp() && dialect_->IsConstant(v.getDefiningOp());
2821         })) {
2822       return failure();
2823     }
2824 
2825     // Make a pass over the parent inputs to see if any of them have explicit
2826     // device() fields set, and if different inputs are on different tasks.  If
2827     // so, this concat of concats may have been carefully constructed to be a
2828     // two-stage concat, and we don't want to undo that here.
2829     std::string task, device;
2830     StringRef unique_input_tasks;
2831     for (Value v : non_control_operands) {
2832       Operation *v_op = v.getDefiningOp();
2833       if (!v_op || v_op == axis_op) continue;
2834       StringRef op_device = TFOp(v_op).device();
2835       if (!op_device.empty() && tensorflow::DeviceNameUtils::SplitDeviceName(
2836                                     op_device.str(), &task, &device)) {
2837         if (unique_input_tasks.empty())
2838           unique_input_tasks = task;
2839         else if (unique_input_tasks != task)
2840           return failure();
2841       }
2842     }
2843 
2844     OperationState state(op->getLoc(), "tfg.ConcatV2");
2845     for (Value operand : non_control_operands) {
2846       if (operand == concat_operand->getResult(0)) {
2847         // Inline the non-control operands of concat_operand.
2848         state.addOperands(ValueRange(concat_non_control_operands.drop_back()));
2849       } else {
2850         state.addOperands(operand);
2851       }
2852     }
2853     // Copy the control operands.
2854     state.addOperands(control_operands);
2855     state.addOperands(concat_control_operands);
2856     state.addTypes(op->getResultTypes());
2857     state.attributes = op->getAttrDictionary();
2858     state.attributes.set("N", IntegerAttr::get(rewriter.getIntegerType(32),
2859                                                state.operands.size() - 1));
2860     Operation *concat_op = rewriter.create(state);
2861     rewriter.replaceOp(op, concat_op->getResults());
2862 
2863     return success();
2864   }
2865 };
2866 
2867 // This implementation is mapped with ConstantFolding::MulConvPushDown
2868 // in grappler/optimizers/constant_folding.cc
2869 class MulConvPushDown : public ConstantPatternBase<MulConvPushDown, FolderTrait,
2870                                                    PropagationTrait> {
2871  public:
2872   explicit MulConvPushDown(OpPropertyHelper &helper)
2873       : ConstantPatternBase(MatchAnyOpTypeTag(), helper) {}
2874   LogicalResult matchAndRewrite(Operation *op,
2875                                 PatternRewriter &rewriter) const override {
2876     // Push down multiplication on ConvND.
2877     //                       *                  ConvND
2878     //                     /   \                /    \
2879     //                 ConvND  C2    -- >      X      *
2880     //                  / \                          / \
2881     //                 X  C1                       C1  C2
2882     //
2883     // where C1 and C2 are constants and X is non-constant.
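    // For example (a sketch): Mul(Conv2D(x, c1), c2) -> Conv2D(x, Mul(c1, c2)),
    // subject to the shape and device checks below.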
2884     if (!dialect_->IsAnyMul(op)) return failure();
2885 
2886     Operation *mul_left_child = op->getOperand(0).getDefiningOp();
2887     Operation *mul_right_child = op->getOperand(1).getDefiningOp();
2888     if (!mul_left_child || !mul_right_child) return failure();
2889 
2890     const bool left_child_is_constant = dialect_->IsConstant(mul_left_child);
2891     const bool right_child_is_constant = dialect_->IsConstant(mul_right_child);
2892     // One child must be constant, and the second must be Conv op.
2893     if (!left_child_is_constant && !right_child_is_constant) return failure();
2894 
2895     Operation *conv_node =
2896         left_child_is_constant ? mul_right_child : mul_left_child;
2897     if (!dialect_->IsConv2D(conv_node) && !dialect_->IsConv3D(conv_node))
2898       return failure();
2899 
2900     // Make sure that it is safe to change the value of the convolution
2901     // output.
2902     if (helper_.ShouldPreserveOp(conv_node)) return failure();
2903 
2904     if (TFOp(op).deviceAttr() != TFOp(mul_left_child).deviceAttr() ||
2905         TFOp(op).deviceAttr() != TFOp(mul_right_child).deviceAttr()) {
2906       return failure();
2907     }
2908 
2909     // Identify the nodes to swap.
2910     Operation *conv_left_child = conv_node->getOperand(0).getDefiningOp();
2911     Operation *conv_right_child = conv_node->getOperand(1).getDefiningOp();
2912     const bool conv_left_is_constant =
2913         conv_left_child && dialect_->IsConstant(conv_left_child);
2914     const bool conv_right_is_constant =
2915         conv_right_child && dialect_->IsConstant(conv_right_child);
2916     if (!conv_left_is_constant && !conv_right_is_constant) {
2917       // At least one of the convolution inputs should be constant.
2918       return failure();
2919     }
2920 
2921     if (conv_left_is_constant && conv_right_is_constant) {
2922       // Operation evaluation will handle this.
2923       return failure();
2924     }
2925 
2926     ShapedType mul_shape = (*op->result_type_begin()).cast<ShapedType>();
2927     ShapedType conv_shape =
2928         (*conv_node->result_type_begin()).cast<ShapedType>();
2929     // TODO(chiahungduan): Symbolic shape equivalence is acceptable.
2930     if (!mul_shape.hasStaticShape() || !conv_shape.hasStaticShape() ||
2931         mul_shape != conv_shape) {
2932       return failure();
2933     }
2934 
2935     auto filter_shape = conv_node->getOperand(1).getType().cast<ShapedType>();
2936 
2937     Operation *const_node =
2938         left_child_is_constant ? mul_left_child : mul_right_child;
2939     auto const_node_shape =
2940         (*const_node->result_type_begin()).cast<ShapedType>();
2941     if (!IsValidConstShapeForMulConvPushDown(
2942             conv_node->getAttrOfType<StringAttr>("data_format"), filter_shape,
2943             const_node_shape)) {
2944       return failure();
2945     }
2946 
2947     Operation *conv_const_node =
2948         conv_left_is_constant ? conv_left_child : conv_right_child;
2949     // Make sure we don't introduce loops in the graph by removing control
2950     // dependencies from the conv2d node to c2.
2951     if (Operation *new_const_op =
2952             RemoveControlOperandIfExist(rewriter, const_node, conv_node)) {
2953       rewriter.replaceOp(const_node, new_const_op->getResults());
2954       const_node = new_const_op;
2955 
2956       // Add a control dep from c1 to c2 to ensure c2 is in the right frame
2957       AddControlOperand(const_node, TFOp(conv_const_node).controlRet(),
2958                         rewriter);
2959     }
2960 
2961     StringRef conv_node_name = TFOp(conv_node).name();
2962 
2963     rewriter.startRootUpdate(conv_node);
2964     TFOp(conv_node).setName(TFOp(op).nameAttr());
2965     if (conv_left_is_constant)
2966       conv_node->setOperand(0, op->getResult(0));
2967     else
2968       conv_node->setOperand(1, op->getResult(0));
2969     rewriter.finalizeRootUpdate(conv_node);
2970 
2971     rewriter.startRootUpdate(op);
2972     TFOp(op).setName(Twine(conv_node_name, "/merged_input"));
2973     if (left_child_is_constant)
2974       op->setOperand(1, conv_const_node->getResult(0));
2975     else
2976       op->setOperand(0, conv_const_node->getResult(0));
2977     rewriter.finalizeRootUpdate(op);
2978 
2979     return success();
2980   }
2981 
2982  private:
2983   // Remove the control dependency from `op` to `to_remove` if any.
2984   Operation *RemoveControlOperandIfExist(OpBuilder &builder, Operation *op,
2985                                          Operation *to_remove) const {
2986     auto [non_control_operands, control_operands] = TFOp(op).splitOperands();
2987     Value control_to_remove = TFOp(to_remove).controlRet();
2988     SmallVector<Value> new_control_operands(control_operands);
2989     auto it = llvm::remove_if(
2990         new_control_operands,
2991         [control_to_remove](Value v) { return v == control_to_remove; });
2992     if (it == new_control_operands.end()) return nullptr;
2993     new_control_operands.erase(it, new_control_operands.end());
2994 
2995     OperationState state(op->getLoc(), op->getName());
2996     state.addOperands(non_control_operands);
2997     state.addOperands(new_control_operands);
2998     state.addAttributes(op->getAttrs());
2999     state.addTypes(op->getResultTypes());
3000 
3001     return builder.create(state);
3002   }
3003 };
3004 
3005 // This implementation is mapped with ConstantFolding::PartialConcatConstFolding
3006 // in grappler/optimizers/constant_folding.cc
3007 class PartialConcatConstFolding
3008     : public FolderPatternBase<PartialConcatConstFolding> {
3009  public:
3010   explicit PartialConcatConstFolding(OpPropertyHelper &helper)
3011       : FolderPatternBase<PartialConcatConstFolding>(MatchAnyOpTypeTag(),
3012                                                      helper) {}
3013   LogicalResult matchAndRewrite(Operation *op,
3014                                 PatternRewriter &rewriter) const override {
3015     // Partial constant folding for Concat, which is not commutative; we have
3016     // to preserve order and can only push consecutive runs of constant inputs
3017     // into sub-nodes.
3018     if (!dialect_->IsConcat(op)) return failure();
3019     if (TFOp(op).name().rfind("_partial_split_") != StringRef::npos) {
3020       return failure();
3021     }
3022 
3023     auto [non_control_operands, control_operands] = TFOp(op).splitOperands();
3024     const int num_non_control_inputs = non_control_operands.size();
3025     if (num_non_control_inputs <= 3) return failure();
3026 
3027     int axis_arg = -1;
3028     int begin = 0;
3029     int end = num_non_control_inputs;
3030     // Note that IsConcat includes both Concat and ConcatV2, so we need to
3031     // check ConcatV2 first.
3032     if (dialect_->IsConcatV2(op)) {
3033       end = num_non_control_inputs - 1;
3034       axis_arg = num_non_control_inputs - 1;
3035     } else if (dialect_->IsConcat(op)) {
3036       begin = 1;
3037       axis_arg = 0;
3038     } else {
3039       return failure();
3040     }
3041 
3042     // We search for consecutive runs of constant inputs in the range
3043     // [begin:end] and push them down into child nodes.
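    // A sketch: in ConcatV2(x, c1, c2, y, c3, axis), the run [c1, c2] is pushed
    // into a child ConcatV2(c1, c2, axis) that can be folded, while the lone c3
    // is too short a run to split out.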
3044     SmallVector<std::pair<int, int>> constant_input_runs;
3045     int first = begin;
3046     int last = begin;
3047     while (last < end) {
3048       while (first < end) {
3049         Operation *v_op = op->getOperand(first).getDefiningOp();
3050         if (v_op && dialect_->IsConstant(v_op)) break;
3051         ++first;
3052       }
3053 
3054       // Invariant: node[first] is constant || first >= end.
3055       last = first + 1;
3056       while (last < end) {
3057         Operation *v_op = op->getOperand(last).getDefiningOp();
3058         if (!v_op || !dialect_->IsConstant(v_op)) break;
3059         ++last;
3060       }
3061 
3062       // Invariant: node[last] is not constant || last >= end
3063       // Discard intervals shorter than 2 elements.
3064       if (first < end && (last - first) > 1)
3065         constant_input_runs.emplace_back(first, last);
3066       first = last;
3067     }
3068 
3069     // Skip if all inputs are constant, and let constant folding take over.
3070     if (constant_input_runs.empty() || (constant_input_runs.size() == 1 &&
3071                                         constant_input_runs[0].first == begin &&
3072                                         constant_input_runs[0].second == end)) {
3073       return failure();
3074     }
3075 
3076     // TODO(chiahungduan): The optimization can be applied multiple times. Find
3077     // a better way to name the new ops without creating duplicate names. For
3078     // now we only apply it once.
3079     if (llvm::any_of(non_control_operands, [](Value v) {
3080           Operation *v_op = v.getDefiningOp();
3081           return v_op &&
3082                  TFOp(v_op).name().rfind("_partial_split_") != StringRef::npos;
3083         })) {
3084       return failure();
3085     }
3086 
3087     DenseSet<int> inputs_to_delete;
3088     for (auto interval : constant_input_runs) {
3089       // Push the constant inputs in the interval to a child node that can be
3090       // constant folded.
3091       OperationState state(op->getLoc(), "tfg.ConcatV2");
3092       state.addOperands(op->getOperand(interval.first));
3093       for (auto i : llvm::seq<int>(interval.first + 1, interval.second)) {
3094         state.addOperands(op->getOperand(i));
3095         inputs_to_delete.insert(i);
3096       }
3097       state.addOperands(op->getOperand(axis_arg));
3098       state.attributes = op->getAttrDictionary();
3099       state.attributes.set("N",
3100                            IntegerAttr::get(rewriter.getI32Type(),
3101                                             interval.second - interval.first));
3102       state.addTypes(op->getResultTypes());
3103 
3104       Operation *new_op = rewriter.create(state);
3105       TFOp(new_op).setName(Twine(TFOp(op).name(), "/_partial_split_") +
3106                            std::to_string(interval.first));
3107       // new_op inherits all the attrs of op; no need to update the device attr.
3108 
3109       // Overwrite the first constant input with the result of the added
3110       // child node.
3111       rewriter.startRootUpdate(op);
3112       op->setOperand(interval.first, new_op->getResult(0));
3113       rewriter.finalizeRootUpdate(op);
3114     }
3115 
3116     if (!inputs_to_delete.empty()) {
3117       OperationState state(op->getLoc(), op->getName());
3118       for (auto &it : llvm::enumerate(non_control_operands)) {
3119         if (inputs_to_delete.contains(it.index())) continue;
3120         state.addOperands(it.value());
3121       }
3122       assert(state.operands.size() != non_control_operands.size());
3123       state.addOperands(control_operands);
3124 
3125       state.attributes = op->getAttrDictionary();
3126       state.attributes.set(
3127           "N", IntegerAttr::get(
3128                    rewriter.getI32Type(),
3129                    state.operands.size() - control_operands.size() - 1));
3130       state.addTypes(op->getResultTypes());
3131       Operation *new_op = rewriter.create(state);
3132       rewriter.replaceOp(op, new_op->getResults());
3133     }
3134 
3135     return success();
3136   }
3137 };
3138 
3139 // This implements constant push-down for BiasAdd. In the following "CV" is a
3140 // constant vector (tensor of rank 1), "V" is a (possibly) non-constant vector,
3141 // "CM" is a matrix (tensor of rank >= 2), "M" is a (possibly)
3142 // non-constant matrix, and "BA" is BiasAdd.
3143 // For a valid input graph, the following 4 rewrites are legal:
3144 //
3145 //  1)                  +                +
3146 //                     / \              / \
3147 //                    BA  CV    -- >   BA  V
3148 //                   / \              / \
3149 //                  M   V            M   CV
3150 //
3151 //  2)                  +                +
3152 //                     / \              / \
3153 //                    BA  CM    -- >   BA  M
3154 //                   / \              / \
3155 //                  M   V            CM  V
3156 //
3157 //  3)                  BA               BA
3158 //                     / \              / \
3159 //                    +  CV     -- >   +   V
3160 //                   / \              / \
3161 //                  M   V            M  CV
3162 //
3163 //  4)                  BA               BA      = parent
3164 //                     / \              / \
3165 //                    BA  CV    -- >   BA  V     = children
3166 //                   / \              / \
3167 //                  M   V            M  CV       = leaves
3168 //
3169 // Cases 1 through 3 have additional sub-cases due to the symmetry of Add.
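// For example (a sketch of case 4): BiasAdd(BiasAdd(M, V), CV) is rewritten to
// BiasAdd(BiasAdd(M, CV), V), moving the constant vector next to the matrix.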
3170 class ConstantPushDownBiasAdd
3171     : public ConstantPushDownBase<ConstantPushDownBiasAdd> {
3172  public:
3173   explicit ConstantPushDownBiasAdd(OpPropertyHelper &helper)
3174       : ConstantPushDownBase(MatchAnyOpTypeTag(), helper) {}
3175   LogicalResult matchAndRewrite(Operation *op,
3176                                 PatternRewriter &rewriter) const override {
3177     if (!dialect_->IsBiasAdd(op)) return failure();
3178 
3179     Operation *add_child = op->getOperand(0).getDefiningOp();
3180     if (!add_child) return failure();
3181 
3182     Operation *const_child = op->getOperand(1).getDefiningOp();
3183     if (!const_child || !dialect_->IsConstant(const_child)) return failure();
3184 
3185     if (helper_.ShouldPreserveOp(add_child)) return failure();
3186 
3187     // Special case for BiasAdd: Since the left argument to BiasAdd must be rank
3188     // >= 2 and the leaves must be vectors, we cannot swap them.
3189     if (dialect_->IsConstant(add_child)) return failure();
3190     if (!dialect_->IsBiasAdd(add_child) && !dialect_->IsAdd(add_child))
3191       return failure();
3192 
3193     if (!IsOperandsSafeToMove(add_child, const_child)) return failure();
3194 
3195     auto hasRank = [&](Value value) {
3196       return value.getType().cast<ShapedType>().hasRank();
3197     };
3198 
3199     if (!hasRank(op->getOperand(0)) || !hasRank(op->getOperand(1)) ||
3200         !hasRank(add_child->getOperand(0)) ||
3201         !hasRank(add_child->getOperand(1))) {
3202       return failure();
3203     }
3204 
3205     // Now get the ranks and types of the 3 leaf nodes.
3206     const int left_leaf_rank =
3207         add_child->getOperand(0).getType().cast<ShapedType>().getRank();
3208     const int right_leaf_rank =
3209         add_child->getOperand(1).getType().cast<ShapedType>().getRank();
3210 
3211     // At least one leaf must be a vector.
3212     if (left_leaf_rank != 1 && right_leaf_rank != 1) return failure();
3213 
3214     const int vector_idx = left_leaf_rank == 1 ? 0 : 1;
3215     auto vector_type =
3216         add_child->getOperand(vector_idx).getType().cast<ShapedType>();
3217     Type vector_d_type = vector_type.getElementType();
3218 
3219     auto const_type = const_child->getResultTypes()[0].cast<ShapedType>();
3220     const int const_rank = const_type.getRank();
3221     Type const_d_type = const_type.getElementType();
3222 
3223     if (const_rank != 1 || const_d_type != vector_d_type) return failure();
3224 
3225     // This is case #1, #3, and #4:
3226     int input_to_swap = vector_idx;
3227 
3228     Value leaf_to_swap = add_child->getOperand(input_to_swap);
3229     if (leaf_to_swap.getDefiningOp() &&
3230         dialect_->IsConstant(leaf_to_swap.getDefiningOp())) {
3231       return failure();
3232     }
3233 
3234     rewriter.startRootUpdate(op);
3235     op->setOperand(1, leaf_to_swap);
3236     rewriter.finalizeRootUpdate(op);
3237     rewriter.startRootUpdate(add_child);
3238     add_child->setOperand(input_to_swap, const_child->getResult(0));
3239     rewriter.finalizeRootUpdate(add_child);
3240 
3241     return success();
3242   }
3243 };
3244 
3245 // This implements constant push-down for Add. In the following "CV" is a
3246 // constant vector (tensor of rank 1), "V" is a (possibly) non-constant vector,
3247 // "CM" is a matrix (tensor of rank >= 2), "M" is a (possibly)
3248 // non-constant matrix, and "BA" is BiasAdd.
3249 // For a valid input graph, the following 4 rewrites are legal:
3250 //
3251 //  1)                  +                +
3252 //                     / \              / \
3253 //                    BA  CV    -- >   BA  V
3254 //                   / \              / \
3255 //                  M   V            M   CV
3256 //
3257 //  2)                  +                +
3258 //                     / \              / \
3259 //                    BA  CM    -- >   BA  M
3260 //                   / \              / \
3261 //                  M   V            CM  V
3262 //
3263 //  3)                  BA               BA
3264 //                     / \              / \
3265 //                    +  CV     -- >   +   V
3266 //                   / \              / \
3267 //                  M   V            M  CV
3268 //
3269 //  4)                  BA               BA      = parent
3270 //                     / \              / \
3271 //                    BA  CV    -- >   BA  V     = children
3272 //                   / \              / \
3273 //                  M   V            M  CV       = leaves
3274 //
3275 // Cases 1 through 3 have additional sub-cases due to the symmetry of Add.
3276 class ConstantPushDownAdd : public ConstantPushDownBase<ConstantPushDownAdd> {
3277  public:
3278   explicit ConstantPushDownAdd(OpPropertyHelper &helper)
3279       : ConstantPushDownBase(MatchAnyOpTypeTag(), helper) {}
3280   LogicalResult matchAndRewrite(Operation *op,
3281                                 PatternRewriter &rewriter) const override {
3282     if (!dialect_->IsAdd(op)) return failure();
3283 
3284     Operation *add_child = op->getOperand(0).getDefiningOp();
3285     Operation *const_child = op->getOperand(1).getDefiningOp();
3286     if (!add_child || !const_child) return failure();
3287 
3288     if (!dialect_->IsConstant(const_child)) std::swap(add_child, const_child);
3289     if (!dialect_->IsConstant(const_child)) return failure();
3290 
3291     if (!IsOperandsSafeToMove(add_child, const_child)) return failure();
3292 
3293     bool child_is_bias_add = dialect_->IsBiasAdd(add_child);
3294     if (!child_is_bias_add && !dialect_->IsAdd(add_child)) return failure();
3295 
3296     auto hasRank = [&](Value value) {
3297       return value.getType().cast<ShapedType>().hasRank();
3298     };
3299 
3300     if (!hasRank(op->getOperand(0)) || !hasRank(op->getOperand(1)) ||
3301         !hasRank(add_child->getOperand(0)) ||
3302         !hasRank(add_child->getOperand(1))) {
3303       return failure();
3304     }
3305 
3306     // Now get the ranks and types of the 3 leaf nodes.
3307     const int left_leaf_rank =
3308         add_child->getOperand(0).getType().cast<ShapedType>().getRank();
3309     const int right_leaf_rank =
3310         add_child->getOperand(1).getType().cast<ShapedType>().getRank();
3311     // At least one leaf must be a vector.
3312     if (left_leaf_rank != 1 && right_leaf_rank != 1) return failure();
3313 
3314     const int vector_idx = left_leaf_rank == 1 ? 0 : 1;
3315     const int matrix_idx = 1 - vector_idx;
3316 
3317     ShapedType vector_type =
3318         add_child->getOperand(vector_idx).getType().cast<ShapedType>();
3319     Type vector_d_type = vector_type.getElementType();
3320 
3321     ShapedType matrix_type =
3322         add_child->getOperand(matrix_idx).getType().cast<ShapedType>();
3323     const int matrix_rank = matrix_type.getRank();
3324     Type matrix_d_type = matrix_type.getElementType();
3325 
3326     const int const_index =
3327         op->getOperand(0).getDefiningOp() == const_child ? 0 : 1;
3328     ShapedType const_type =
3329         const_child->getResult(0).getType().cast<ShapedType>();
3330     const int const_rank = const_type.getRank();
3331     Type const_d_type = const_type.getElementType();
3332 
3333     int input_to_swap = -1;
3334 
3335     if (child_is_bias_add && const_rank == matrix_rank &&
3336         const_d_type == matrix_d_type) {
3337       // Case 2:
3338       input_to_swap = matrix_idx;
3339     } else if (const_rank == 1 && const_d_type == vector_d_type) {
3340       // Case 1, 3, and, 4:
3341       input_to_swap = vector_idx;
3342     } else {
3343       return failure();
3344     }
3345 
3346     Value leaf_to_swap = add_child->getOperand(input_to_swap);
3347     if (leaf_to_swap.getDefiningOp() &&
3348         dialect_->IsConstant(leaf_to_swap.getDefiningOp())) {
3349       return failure();
3350     }
3351 
3352     rewriter.startRootUpdate(op);
3353     op->setOperand(const_index, leaf_to_swap);
3354     rewriter.finalizeRootUpdate(op);
3355     rewriter.startRootUpdate(add_child);
3356     add_child->setOperand(input_to_swap, const_child->getResult(0));
3357     rewriter.finalizeRootUpdate(add_child);
3358 
3359     return success();
3360   }
3361 };
3362 
3363 // This implementation is mapped with ConstantFolding::SimplifyCase in
3364 // grappler/optimizers/constant_folding.cc
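// A sketch: when the branch index is a constant, e.g. Case(1, args,
// branches=[f0, f1, f2]), the op is rewritten to PartitionedCall(args) with
// f = f1.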
3365 class SimplifyCaseOp : public FolderPatternBase<SimplifyCaseOp> {
3366  public:
3367   explicit SimplifyCaseOp(OpPropertyHelper &helper)
3368       : FolderPatternBase<SimplifyCaseOp>("tfg.Case", helper) {}
3369   LogicalResult matchAndRewrite(Operation *op,
3370                                 PatternRewriter &rewriter) const override {
3371     Operation *branch_index_op = op->getOperand(0).getDefiningOp();
3372     if (!branch_index_op) return failure();
3373 
3374     ElementsAttr value_attr =
3375         branch_index_op->getAttrOfType<ElementsAttr>("value");
3376     if (!value_attr) return failure();
3377 
3378     int output_idx = value_attr.getSplatValue<int>();
3379     ArrayAttr branch_attr = op->getAttrOfType<ArrayAttr>("branches");
3380     if (output_idx < 0 || output_idx >= branch_attr.size()) return failure();
3381 
3382     OperationState state(op->getLoc(), "tfg.PartitionedCall");
3383     state.addOperands(ValueRange(op->getOperands()).drop_front());
3384 
3385     state.attributes = op->getAttrDictionary();
3386     state.attributes.erase("branches");
3387     // In TFG canonical form, `output_shapes` has been consolidated into op's
3388     // shape. Unlike grappler, we don't need to update the `output_shapes` attr
3389     // here.
3390     state.attributes.set("f", branch_attr[output_idx]);
3391 
3392     state.addTypes(op->getResultTypes());
3393 
3394     Operation *partitioned_call_op = rewriter.create(state);
3395     rewriter.replaceOp(op, partitioned_call_op->getResults());
3396 
3397     return success();
3398   }
3399 };
3400 
3401 // This implementation is mapped with ConstantFolding::SimplifySelect in
3402 // grappler/optimizers/constant_folding.cc
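// A sketch: if the predicate is statically all-true (or all-false),
// Select(pred, t, e) becomes an Identity of the live input (the other inputs
// turn into control dependencies), or a BroadcastTo of it when shapes differ.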
3403 template <typename ConcreteType>
3404 class SimplifySelectOpBase : public FolderPatternBase<ConcreteType> {
3405  protected:
3406   SimplifySelectOpBase(StringRef op_name, OpPropertyHelper &helper)
3407       : FolderPatternBase<ConcreteType>(op_name, helper) {}
3408 
3409   LogicalResult matchAndRewrite(Operation *op,
3410                                 PatternRewriter &rewriter) const override {
3411     Operation *condition_op = op->getOperand(0).getDefiningOp();
3412     if (!condition_op) return failure();
3413 
3414     bool is_all_true = this->helper_.IsOnes(condition_op);
3415     bool is_all_false = this->helper_.IsZeros(condition_op);
3416     if (!is_all_true && !is_all_false) return failure();
3417 
3418     auto condition_type = op->getOperand(0).getType().cast<ShapedType>();
3419     auto t_type = op->getOperand(1).getType().cast<ShapedType>();
3420     auto e_type = op->getOperand(2).getType().cast<ShapedType>();
3421     if (!condition_type.hasStaticShape() || !t_type.hasStaticShape() ||
3422         !e_type.hasStaticShape()) {
3423       return failure();
3424     }
3425 
3426     const int live_input_idx = is_all_true ? 1 : 2;
3427     bool predicate_is_scalar = condition_type.getRank() == 0;
3428 
3429     if (t_type.getShape() == e_type.getShape() &&
3430         (condition_type.getShape() == t_type.getShape() ||
3431          predicate_is_scalar)) {
3432       Value live_operand = op->getOperand(live_input_idx);
3433       OperationState state(op->getLoc(), "tfg.Identity");
3434       state.addTypes(op->getResultTypes());
3435 
3436       state.addOperands(live_operand);
3437       auto [non_control_operands, control_operands] = TFOp(op).splitOperands();
3438       for (Value operand : non_control_operands) {
3439         if (operand == live_operand) continue;
3440         // Add the remaining operands as control operands.
3441         state.addOperands(GetControlDependency(rewriter, operand));
3442       }
3443       // Append control operands
3444       state.addOperands(control_operands);
3445 
3446       state.attributes = op->getAttrDictionary();
3447       Operation *identity = rewriter.create(state);
3448       rewriter.replaceOp(op, identity->getResults());
3449     } else {
3450       FailureOr<TFOp> broadcast_to_op =
3451           ReplaceOperationWithBroadcastTo(rewriter, op, live_input_idx);
3452       if (failed(broadcast_to_op)) return failure();
3453       rewriter.replaceOp(op, (*broadcast_to_op)->getResults());
3454     }
3455 
3456     return success();
3457   }
3458 };
3459 
3460 class SimplifySelectOp : public SimplifySelectOpBase<SimplifySelectOp> {
3461  public:
3462   explicit SimplifySelectOp(OpPropertyHelper &helper)
3463       : SimplifySelectOpBase("tfg.Select", helper) {}
3464 };
3465 
3466 class SimplifySelectV2Op : public SimplifySelectOpBase<SimplifySelectV2Op> {
3467  public:
3468   explicit SimplifySelectV2Op(OpPropertyHelper &helper)
3469       : SimplifySelectOpBase("tfg.SelectV2", helper) {}
3470 };
3471 
3472 namespace {
3473 
3474 // Utilities for filtering desired patterns.
3475 template <bool>
3476 struct FilterPattern {
3477   template <class Pattern>
3478   using type = std::tuple<Pattern>;
3479 };
3480 template <>
3481 struct FilterPattern<false> {
3482   template <class Pattern>
3483   using type = std::tuple<>;
3484 };
3485 template <template <class> class Pred, class... Patterns>
3486 struct FilterPatterns {
3487   using type = decltype(std::tuple_cat(
3488       std::declval<typename FilterPattern<Pred<Patterns>::value>::template type<
3489           Patterns>>()...));
3490 };
3491 
3492 // Predicates of selecting pattern kind.
3493 template <typename Pattern>
3494 using FolderPatterns = std::is_base_of<FolderTrait<Pattern>, Pattern>;
3495 template <typename Pattern>
3496 using PropagationPatterns = std::is_base_of<PropagationTrait<Pattern>, Pattern>;
3497 template <typename Pattern>
3498 using AllPatterns = std::true_type;
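// For example (a sketch, assuming the trait hierarchy of the pattern bases):
// FilterPatterns<FolderPatterns, MergeConcatOp,
// PartialConstPropThroughIdentityN>::type is std::tuple<MergeConcatOp>, since
// only the former derives from FolderTrait.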
3499 
3500 // Registers a set of patterns.
3501 template <typename... Patterns>
3502 struct TargetPatterns;
3503 template <typename... Patterns>
3504 struct TargetPatterns<std::tuple<Patterns...>> {
3505   static void Register(::mlir::RewritePatternSet &patterns,
3506                        OpPropertyHelper &helper) {
3507     patterns.insert<Patterns...>(helper);
3508   }
3509 };
3510 template <template <class> class PatternsFilter>
3511 void RegisterPatterns(::mlir::RewritePatternSet &patterns,
3512                       OpPropertyHelper &helper) {
3513   TargetPatterns<typename FilterPatterns<
3514       PatternsFilter, MaterializeBroadcastGradientArgsOp, MaterializeShapeNOp,
3515       SimplifySwitchOp, MergeNodeFolding, RefMergeNodeFolding,
3516       XlaMergeNodeFolding, MoveConstantsPastEnterOp,
3517       MoveConstantsPastRefEnterOp, MaterializeReductionIndices,
3518       PartialConstPropThroughIdentityN, ConstantPushDown, MulConvPushDown,
3519       ConstantPushDownBiasAdd, ConstantPushDownAdd, EvaluateConstant,
3520       PartialConcatConstFolding, PartialAssocOpConstFolding,
3521       SimplifyArithmeticOp, ReduceDivToReciprocalMul, SimplifyReshapeOp,
3522       RemoveReverse, SimplifyStridedSlice, SimplifyTileOp, SimplifySqueezeOp,
3523       SimplifySliceOp, RemoveTransposeOp, RemoveRandomShuffleOp,
3524       RemoveShuffleOp, SimplifyPackOp, SimplifyReductionOp, SimplifyPadOp,
3525       SimplifyPadV2Op, RemoveSplitOp, RemoveSplitVOp, MaterializeFillNode,
3526       MaterializeConstantValuedNode, MaterializeShapeOp, MaterializeRankOp,
3527       MaterializeSizeOp, MaterializeTensorArraySizeV3Op, MergeConcatOp,
3528       SimplifyCaseOp, SimplifySelectOp,
3529       SimplifySelectV2Op>::type>::Register(patterns, helper);
3530 }
3531 }  // namespace
3532 
3533 class ConstantFolding : public ConstantFoldingPassBase<ConstantFolding> {
3534  public:
3535   LogicalResult initialize(MLIRContext *context) override {
3536     helper_ = std::make_shared<OpPropertyHelper>(
3537         context->getOrLoadDialect<TFGraphDialect>(),
3538         disable_compressed_tensor_optimization_);
3539     RewritePatternSet patterns(context);
3540     populatePatterns(patterns);
3541     final_patterns_ = std::move(patterns);
3542     return success();
3543   }
3544 
3545   void runOnOperation() override;
3546 
3547  private:
3548   void populatePatterns(::mlir::RewritePatternSet &patterns) {
3549     switch (pattern_category_) {
3550       default:
3551         LOG(ERROR) << "unknown pattern category, will run all patterns";
3552         [[fallthrough]];
3553       case 0: {
3554         RegisterPatterns<AllPatterns>(patterns, *helper_);
3555         break;
3556       }
3557       case 1: {
3558         RegisterPatterns<FolderPatterns>(patterns, *helper_);
3559         break;
3560       }
3561       case 2: {
3562         RegisterPatterns<PropagationPatterns>(patterns, *helper_);
3563         break;
3564       }
3565     }
3566   }
3567 
3568   FrozenRewritePatternSet final_patterns_;
3569   std::shared_ptr<OpPropertyHelper> helper_;
3570 };
3571 
3572 void ConstantFolding::runOnOperation() {
3573   // TODO(chiahungduan): Set up the attributes before operation creation.
3574   // For convenience, in some cases we set up the device/name after operation
3575   // creation.
3576 
3577   GraphFuncOp func = getOperation();
3578   Operation *return_op = func.getBody()->getTerminator();
3579   DenseSet<Operation *> unfoldable_ops;
3580   for (Value v : return_op->getOperands())
3581     unfoldable_ops.insert(v.getDefiningOp());
3582 
3583   // The max iteration count is the same as the default max iteration count in
3584   // applyPatternsAndFoldGreedily.
3585   constexpr int max_iterations = 10;
3586   int iteration = 0;
3587 
3588   SmallVector<Operation *> foldable_ops;
3589   do {
3590     // We need to collect the valid operations before each run because the ops
3591     // may be updated or removed.
3592     foldable_ops.clear();
3593     for (Operation &op : func.getBody()->without_terminator()) {
3594       if (unfoldable_ops.contains(&op)) continue;
3595       foldable_ops.push_back(&op);
3596     }
3597 
3598     // Unfoldable ops can't be folded. Their operands may be updated, but the op
3599     // kind needs to stay the same. For example, you may update an operand of an
3600     // AddOp with a constant but you can't fold the AddOp into a ConstOp even if
3601     // all its operands are constants. Therefore, we can't use
3602     // applyPatternsAndFoldGreedily which may optimize the ops as much as
3603     // possible.
3604     if (!applyOpPatternsAndFold(foldable_ops, final_patterns_, /*strict=*/true))
3605       break;
3606   } while (iteration++ < max_iterations);
3607 
3608   // TODO(chiahungduan): This is used to avoid evaluating a node multiple times.
3609   // See more details in the EvaluateConstant pattern. Maybe we can remove this
3610   // by checking whether an op has no users.
3611   auto has_folded = StringAttr::get(&getContext(), "has_folded");
3612   getOperation()->walk([&](Operation *op) { op->removeAttr(has_folded); });
3613 }
3614 
3615 std::unique_ptr<Pass> CreateConstantFoldingPass() {
3616   return std::make_unique<ConstantFolding>();
3617 }
3618 
3619 }  // namespace tfg
3620 }  // namespace mlir
3621