/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/transforms/constant_folding/pass.h"

#include <algorithm>
#include <iterator>
#include <numeric>
#include <string>
#include <tuple>
#include <type_traits>
#include <utility>

#include "llvm/ADT/APInt.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/Twine.h"
#include "mlir/Dialect/Traits.h"  // from @llvm-project
#include "mlir/IR/BuiltinAttributeInterfaces.h"  // from @llvm-project
#include "mlir/IR/PatternMatch.h"  // from @llvm-project
#include "mlir/Support/LLVM.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"  // from @llvm-project
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/ir/dialect.h"
#include "tensorflow/core/ir/importexport/convert_types.h"
#include "tensorflow/core/ir/utility.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/transforms/pass_detail.h"
#include "tensorflow/core/transforms/utils/eval_utils.h"
#include "tensorflow/core/transforms/utils/op_cat_helper.h"
#include "tensorflow/core/transforms/utils/utils.h"
#include "tensorflow/core/util/bcast.h"
#include "tensorflow/core/util/device_name_utils.h"

namespace mlir {
namespace tfg {

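// Builds a DenseElementsAttr of the given integer element type and shape from
// a list of integral values.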
template <typename T>
static std::enable_if_t<std::is_integral<T>::value, ElementsAttr>
CreateElementsAttrOfTypeValues(Type element_type, ArrayRef<int64_t> shape,
                               ArrayRef<T> values) {
  auto tensor_shape = RankedTensorType::get(shape, element_type);
  SmallVector<APInt> elements;
  for (T v : values)
    elements.push_back(APInt(element_type.getIntOrFloatBitWidth(), v));
  auto const_attr = DenseElementsAttr::get(tensor_shape, elements);
  return const_attr;
}

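// Floating-point overload of the above; packs the values as f32 or f64
// APFloats depending on the bit width of `element_type`.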
template <typename T>
static std::enable_if_t<std::is_floating_point<T>::value, ElementsAttr>
CreateElementsAttrOfTypeValues(Type element_type, ArrayRef<int64_t> shape,
                               ArrayRef<T> values) {
  auto tensor_shape = RankedTensorType::get(shape, element_type);
  SmallVector<APFloat> elements;
  if (element_type.getIntOrFloatBitWidth() == 32)
    llvm::for_each(values, [&](float v) { elements.push_back(APFloat(v)); });
  else
    llvm::for_each(values, [&](double v) { elements.push_back(APFloat(v)); });
  auto const_attr = DenseElementsAttr::get(tensor_shape, elements);
  return const_attr;
}

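// Overload that repacks the values of an existing ElementsAttr into a tensor
// of the given shape and element type.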
static ElementsAttr CreateElementsAttrOfTypeValues(Type element_type,
                                                   ArrayRef<int64_t> shape,
                                                   ElementsAttr value_attr) {
  auto tensor_shape = RankedTensorType::get(shape, element_type);
  DenseElementsAttr const_attr;
  if (element_type.isIntOrIndex()) {
    const_attr = DenseElementsAttr::get(
        tensor_shape, llvm::to_vector(value_attr.getValues<APInt>()));
  } else {
    const_attr = DenseElementsAttr::get(
        tensor_shape, llvm::to_vector(value_attr.getValues<APFloat>()));
  }
  return const_attr;
}

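// Converts the shape of a ranked tensor type into a 1-D i32 constant
// attribute.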
static ElementsAttr ConvertShapeToAttr(ShapedType shape) {
  return CreateElementsAttrOfTypeValues(
      IntegerType::get(shape.getContext(), 32), {shape.getRank()},
      shape.getShape());
}

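// Returns the data type of `op` from its "T" or "dtype" attribute, falling
// back to i1 for LogicalOr/LogicalAnd and to the first result type otherwise.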
static Type GetDataTypeFromOp(OpBuilder &builder, Operation *op) {
  if (auto t_attr = op->getAttrOfType<TypeAttr>("T")) {
    return t_attr.getValue();
  } else if (auto dtype_attr = op->getAttrOfType<TypeAttr>("dtype")) {
    return dtype_attr.getValue();
  } else if (op->getName().stripDialect() == "LogicalOr" ||
             op->getName().stripDialect() == "LogicalAnd") {
    return builder.getI1Type();
  }
  return *(op->result_type_begin());
}

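// Creates a tfg.Const op that holds `tensor_value`, taking the given control
// operands. Returns failure if `type` is a variant type.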
static FailureOr<TFOp> CreateConstantTensorOp(
    OpBuilder &builder, Location loc, StringRef name_prefix, Type type,
    ValueRange control_operands, TypedAttr tensor_value,
    ArrayRef<NamedAttribute> other_attrs = llvm::None) {
  if (type.isa<VariantType>()) return failure();
  // TODO(chiahungduan): Reuse ConstOp like
  // OperationFolder::tryGetOrCreateConstant.
  OperationState state(loc, "tfg.Const");
  state.addTypes({type, ControlType::get(builder.getContext())});

  state.attributes = other_attrs;
  util::EraseRegularNodeAttributes(state.attributes);
  state.attributes.set(
      "dtype",
      TypeAttr::get(
          tensor_value.getType().cast<ShapedType>().getElementType()));
  state.attributes.set("value", tensor_value);
  if (!name_prefix.empty()) {
    state.attributes.set(
        TFGraphDialect::getNameAttrKey(),
        builder.getStringAttr(Twine(name_prefix, "/const_folded")));
  }

  state.addOperands(control_operands);
  return TFOp(builder.create(state));
}

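// Returns true if `op` can serve as a control anchor: an Identity (or
// IdentityN with a single input) whose non-control results have no uses.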
static bool IsControlAnchor(TFOp op, TFGraphDialect const *const dialect) {
  return (dialect->IsIdentity(op) || dialect->IsIdentityNSingleInput(op)) &&
         op->getResults().drop_back().use_empty();
}

// We can't anchor control dependencies directly on the switch node: unlike
// other nodes, only one of the outputs of the switch node will be generated
// when the switch node is executed, and we need to make sure the control
// dependency is only triggered when the corresponding output is triggered.
// We start by looking for an identity node connected to the output of the
// switch node, and use it to anchor the control dependency.
// @param builder Builder, used for creating the anchor if necessary
// @param value Output of a switch operation to be replaced
// @param dialect TFG dialect (passed in to avoid cost of looking it up)
static TFOp GetControlAnchorForSwitchResult(
    OpBuilder &builder, OpResult value, TFGraphDialect const *const dialect) {
  assert(builder.getContext()->getLoadedDialect<TFGraphDialect>() == dialect);
  TFOp switch_op = value.getDefiningOp();
  assert(dialect->IsSwitch(switch_op));
  // We cannot get the control edge from the parent op. We instead create a
  // control anchor, i.e. an Identity op without non-control uses, and get the
  // edge from there.

  // Try to find an existing control anchor.
  if (auto it = llvm::find_if(
          value.getUsers(),
          [&](Operation *op) { return IsControlAnchor(op, dialect); });
      it != value.getUsers().end())
    return TFOp(*it);

  // If it doesn't exist, create a new control anchor.
  OperationState identity_op_state(value.getLoc(), "tfg.Identity");
  identity_op_state.addOperands(value);
  identity_op_state.addTypes(
      {value.getType(), ControlType::get(builder.getContext())});
  assert(switch_op->hasAttr("T"));
  identity_op_state.addAttribute("T", switch_op->getAttr("T"));
  TFOp identity_op = builder.create(identity_op_state);
  if (StringAttr device_attr = switch_op.deviceAttr())
    identity_op.setRequestedDevice(device_attr);
  identity_op.setName(Twine(switch_op.name(), "/ControlDependencyCtrl_") +
                      Twine(value.cast<OpResult>().getResultNumber()));
  return identity_op;
}

// Same as LookupControlDependency, except when value originates from a switch
// op. In such cases, we cannot add a control dependency to the parent op since
// the output does not necessarily activate when the switch op activates. We
// add a "control anchor" in the form of an identity op instead.
static Value GetControlDependency(OpBuilder &builder, Value value) {
  if (value.getType().isa<ControlType>()) return value;

  TFGraphDialect *dialect =
      builder.getContext()->getLoadedDialect<TFGraphDialect>();
  assert(dialect);
  if (OpResult result = value.dyn_cast<OpResult>();
      result && dialect->IsSwitch(result.getOwner())) {
    return GetControlAnchorForSwitchResult(builder, result, dialect)
        .controlRet();
  } else {
    return LookupControlDependency(value);
  }
}

// Adds a control operand to `op` if it doesn't already exist.
static void AddControlOperand(Operation *op, Value control,
                              PatternRewriter &rewriter) {
  assert(control.getType().isa<ControlType>());
  if (llvm::is_contained(op->getOperands(), control)) return;
  rewriter.startRootUpdate(op);
  op->insertOperands(op->getNumOperands(), control);
  rewriter.finalizeRootUpdate(op);
}

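// Replaces `op` with a constant tensor holding `value`; the op's non-control
// operands are carried over as control dependencies of the new constant.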
static FailureOr<TFOp> ReplaceOpWithConstantTensor(
    OpBuilder &builder, TFOp op, ElementsAttr value,
    ArrayRef<StringRef> exclude_attrs = llvm::None) {
  // The new const op takes control dependencies on op's non-control operands.
  SmallVector<Value> operands_controls;
  llvm::append_range(operands_controls,
                     OperandControlRetRange(op.getNonControlOperands()));

  NamedAttrList attr_list;
  for (NamedAttribute attr : op->getAttrs()) {
    if (llvm::find_if(exclude_attrs, [&](StringRef name) {
          return name == attr.getName();
        }) != exclude_attrs.end())
      continue;
    attr_list.append(attr);
  }
  FailureOr<TFOp> const_op = CreateConstantTensorOp(
      builder, op->getLoc(), /*name_prefix=*/"", value.getType(),
      operands_controls, value, attr_list);
  (*const_op).setName(op.nameAttr());
  if (!op.device().empty()) (*const_op).setRequestedDevice(op.deviceAttr());
  return *const_op;
}

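// Replaces `owner` with an Identity op forwarding its `idx`-th operand; all
// other non-control operands are converted into control dependencies.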
static FailureOr<TFOp> ReplaceOpWithIdentity(OpBuilder &builder, TFOp owner,
                                             unsigned idx) {
  OperationState state(owner->getLoc(), "tfg.Identity");
  state.addTypes({owner->getOperand(idx).getType(),
                  ControlType::get(builder.getContext())});
  state.addAttribute(
      "T", TypeAttr::get(GetDataTypeFromOp(builder, owner.getOperation())));

  Value kept_value = owner->getOperand(idx);
  state.addOperands(kept_value);
  auto [non_control_operands, control_operands] = owner.splitOperands();
  for (Value value : non_control_operands) {
    if (value != kept_value)
      state.addOperands(GetControlDependency(builder, value));
  }
  state.addOperands(control_operands);

  Operation *identity_op = builder.create(state);
  TFOp(identity_op).setName(owner.nameAttr());
  if (!owner.device().empty())
    TFOp(identity_op).setRequestedDevice(owner.deviceAttr());
  return TFOp(identity_op);
}

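// Replaces `op` with a splat constant of its result shape, filled with
// `constant_value` converted to the op's data type.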
static FailureOr<TFOp> ReplaceOperationWithConstant(OpBuilder &builder,
                                                    Operation *op,
                                                    double constant_value) {
  auto res = (*op->result_type_begin()).cast<ShapedType>();
  Type dtype = GetDataTypeFromOp(builder, op);
  Attribute value_attr;
  if (dtype.isIntOrIndex())
    value_attr = builder.getIntegerAttr(dtype, constant_value);
  else
    value_attr = builder.getFloatAttr(dtype, constant_value);

  auto const_attr = SplatElementsAttr::get(
      RankedTensorType::get(res.getShape(), dtype), value_attr);
  return ReplaceOpWithConstantTensor(builder, op, const_attr);
}

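// Replaces `op` with a Snapshot of its `idx`-th operand; all other non-control
// operands are converted into control dependencies.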
static FailureOr<TFOp> ReplaceOperationWithSnapshot(OpBuilder &builder, TFOp op,
                                                    int idx) {
  // TODO(chiahungduan): If the graph contains no ops that mutate their
  // inputs, we can use Identity instead of Snapshot.
  // if (!graph_contains_assign_or_inplace_op_)
  auto [non_control_operands, control_operands] = op.splitOperands();

  Value replace_value = op->getOperand(idx);
  OperationState state(op->getLoc(), "tfg.Snapshot");
  state.attributes = op->getAttrDictionary();
  util::EraseRegularNodeAttributes(state.attributes);
  state.addAttribute(
      "T", TypeAttr::get(GetDataTypeFromOp(builder, op.getOperation())));
  // Propagate the designated input through the Snapshot.
  state.addOperands(replace_value);
  // Add all other inputs as control dependencies.
  llvm::append_range(state.operands,
                     OperandControlRetRange(non_control_operands));
  // Append the control operands.
  state.addOperands(control_operands);
  state.addTypes(op->getResultTypes());

  Operation *snapshot_op = builder.create(state);
  TFOp(snapshot_op).setName(op.nameAttr());
  if (!op.device().empty())
    TFOp(snapshot_op).setRequestedDevice(op.deviceAttr());
  return TFOp(snapshot_op);
}

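// Replaces `op` with a BroadcastTo of its `idx_to_replace`-th operand to the
// op's statically known result shape; the shape is materialized as an i32
// constant and all other operands become control dependencies.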
static FailureOr<TFOp> ReplaceOperationWithBroadcastTo(OpBuilder &builder,
                                                       TFOp op,
                                                       int idx_to_replace) {
  ShapedType tensor_type = (*op->result_type_begin()).cast<ShapedType>();
  if (!tensor_type.hasStaticShape()) return failure();
  ElementsAttr const_attr = ConvertShapeToAttr(tensor_type);

  // Create a vector of control operands. We should not fail beyond this point
  // since GetControlDependency may create a control anchor (a new op).
  SmallVector<Value> control_operands;
  for (auto &it : llvm::enumerate(op.getNonControlOperands())) {
    int idx = it.index();
    Value v = it.value();
    if (idx == idx_to_replace) continue;
    if (llvm::is_contained(control_operands, v)) continue;
    control_operands.push_back(GetControlDependency(builder, v));
  }
  // CreateConstantTensorOp cannot fail; it only fails for variant types and
  // const_attr is a tensor of i32.
  TFOp const_op = *CreateConstantTensorOp(
      builder, op->getLoc(),
      (Twine(op.name(), "/broadcastto_shape_") + std::to_string(idx_to_replace))
          .str(),
      const_attr.getType(), control_operands, const_attr);
  if (!op.device().empty()) const_op.setRequestedDevice(op.device());

  OperationState state(op->getLoc(), "tfg.BroadcastTo");

  state.attributes = op->getAttrDictionary();
  util::EraseRegularNodeAttributes(state.attributes);
  state.addAttribute(
      "T", TypeAttr::get(GetDataTypeFromOp(builder, op.getOperation())));
  state.addAttribute("Tidx", TypeAttr::get(builder.getI32Type()));

  state.addOperands({op->getOperand(idx_to_replace), const_op->getResult(0)});
  state.addOperands(control_operands);
  state.addTypes(op->getResultTypes());

  Operation *broadcast_to_op = builder.create(state);
  TFOp(broadcast_to_op).setName(op.nameAttr());
  if (!op.device().empty())
    TFOp(broadcast_to_op).setRequestedDevice(op.deviceAttr());
  return TFOp(broadcast_to_op);
}

namespace {
// A helper class to check whether an operation falls into a certain category
// or has certain non-trivial properties.
class OpPropertyHelper : public OpCatHelper {
 public:
  OpPropertyHelper(TFGraphDialect *dialect,
                   bool disable_compressed_tensor_optimization)
      : OpCatHelper(dialect),
        dialect_(dialect),
        disable_compressed_tensor_optimization_(
            disable_compressed_tensor_optimization) {}

  // Return true if the operation modifies the input in-place.
  bool ModifiesInputsInPlace(TFOp op);

  // Return true if this operation doesn't have any side effect.
  bool IsFreeOfSideEffect(TFOp op);

  // Return true if an operation may modify the frame info.
  bool ModifiesFrameInfo(TFOp op) {
    return dialect_->IsEnter(op) || dialect_->IsExit(op) ||
           dialect_->IsNextIteration(op);
  }

  // This combines the results of both MaybeFoldable() and
  // IsFoldableUncached().
  bool IsFoldable(TFOp op);

  // Return true if this is a preserved op, i.e. one whose control result is
  // consumed by a return op.
  bool ShouldPreserveOp(TFOp op);

  // Return true if compressed tensor optimization is disabled.
  bool DisableCompressedTensorOptimization();

  // Get the TFG dialect instance.
  TFGraphDialect *getDialect() { return dialect_; }

 private:
  // Return true if this operation is safe to be folded. This filters ops by
  // name.
  bool MaybeFoldable(TFOp op);

  // Return true if this operation is safe to be folded. This filters ops by
  // operation properties, e.g. it checks the operands, attributes, etc.
  bool IsFoldableUncached(TFOp op);

  // A reference to the TFG dialect.
  TFGraphDialect *dialect_;

  // Indicates whether compressed tensor optimization has been disabled.
  bool disable_compressed_tensor_optimization_;

  // We only fold/materialize constants smaller than 100kB.
  static constexpr int64_t kMaxConstantSize = 100 * 1024;
};
}  // namespace

bool OpPropertyHelper::ModifiesInputsInPlace(TFOp op) {
  StringRef op_name = op->getName().stripDialect();

  // Ops that modify resource variables effectively modify one of their inputs.
  if (op_name == "AssignVariableOp" || op_name == "AssignAddVariableOp" ||
      op_name == "AssignSubVariableOp" || op_name == "ResourceScatterUpdate" ||
      op_name == "ResourceScatterAdd" || op_name == "ResourceScatterSub" ||
      op_name == "ResourceScatterMul" || op_name == "ResourceScatterDiv" ||
      op_name == "ResourceScatterMin" || op_name == "ResourceScatterMax") {
    return false;
  }

  std::string lower_op_name = op_name.str();
  std::transform(lower_op_name.begin(), lower_op_name.end(),
                 lower_op_name.begin(), ::tolower);
  if (absl::StrContains(lower_op_name, "inplace")) return true;

  return op->hasAttr("in_place") || op->hasAttr("inplace");
}

bool OpPropertyHelper::IsFreeOfSideEffect(TFOp op) {
  tensorflow::OpRegistry *op_registry = tensorflow::OpRegistry::Global();
  const tensorflow::OpDef *op_def;
  tensorflow::Status status =
      op_registry->LookUpOpDef(op->getName().stripDialect().str(), &op_def);
  if (!status.ok()) return false;

  if (op_def->is_stateful()) return false;

  for (const auto &input : op_def->input_arg())
    if (input.is_ref()) return false;

  if (dialect_->IsQueue(op)) return false;

  if (dialect_->IsSend(op)) return false;

  return !ModifiesInputsInPlace(op);
}

// Determines whether we want to evaluate the value of the operation. There are
// several kinds of operations we don't want to evaluate with the eager
// runtime: they may not be safe to evaluate, or may not be worth evaluating
// because of the cost. For example, a Const op already has its constant value
// attached as an attribute.
bool OpPropertyHelper::MaybeFoldable(TFOp op) {
  StringRef op_name = op->getName().stripDialect();

  if (dialect_->IsConstant(op)) return false;

  // Don't fold stateful ops such as TruncatedNormal.
  if (!IsFreeOfSideEffect(op)) return false;

  // TODO(chiahungduan): Handle preserve nodes

  // Skip ops that don't benefit from folding.
  if (dialect_->IsPlaceholder(op)) return false;

  if (dialect_->IsFakeParam(op)) return false;

  // Skip certain control flow nodes, they can't be folded.
  if (ModifiesFrameInfo(op)) return false;

  if (op_name == "AccumulateNV2") return false;

  // Removing LoopCond nodes can screw up the partitioner.
  if (op_name == "LoopCond") return false;

  // TODO(chiahungduan): add fold_quantization_emulation arg.
  // if (!fold_quantization_emulation && IsQuantizationEmulation(op)) return
  // false;

  if (dialect_->IsRestore(op) || op_name.contains("Save") ||
      op_name.contains("Reader"))
    return false;

  if (op_name.contains("Quantized") ||
      absl::StartsWith(op_name.data(), "Sparse"))
    return false;

  // Don't fold nodes that contain TPU attributes.
  // TODO(rmlarsen): We should be able to fold many of these nodes as long as
  // we properly forward custom attributes, b/119051778.
  for (NamedAttribute attr : op->getAttrs())
    if (attr.getName().strref().find("_tpu_") != StringRef::npos) return false;

  // Don't fold ops without outputs. Note that almost all tfg ops have an
  // additional control output value.
  if (op->getNumResults() <= 1) return false;

  const tensorflow::OpDef *op_def = nullptr;
  tensorflow::Status status = tensorflow::OpRegistry::Global()->LookUpOpDef(
      op->getName().stripDialect().str(), &op_def);
  if (!status.ok()) {
    return false;
  }
  // Don't fold ops without outputs.
  if (op_def->output_arg_size() == 0) {
    return false;
  }

  // Don't fold DT_VARIANT outputs as this can cause problems with XLA compile.
  // TODO(rmlarsen): Only do this for XLA_* devices.
  for (const tensorflow::OpDef::ArgDef &output_arg : op_def->output_arg()) {
    if (output_arg.type() == tensorflow::DT_VARIANT) {
      return false;
    }
  }

  // Don't fold nodes that have no outgoing edges except allowlisted nodes.
  // Such nodes could be introduced by an earlier constant folding pass and are
  // preserved in case users want to fetch their values; re-processing them
  // would lead to an error from adding a duplicate node to the graph.
  // TODO(chiahungduan): An op that has no users and is not in
  // nodes_allowlist_ can't be folded.
  return true;
}

bool OpPropertyHelper::IsFoldableUncached(TFOp op) {
  ValueRange operands = op.getNonControlOperands();
  if (operands.empty()) return false;

  // We can only fold nodes if all their inputs are known statically, except in
  // the case of a merge node, which propagates the first input that becomes
  // available and therefore only requires a single constant input to be
  // foldable.
  bool merge_has_constant_input = false;
  bool is_merge = dialect_->IsMerge(op);
  for (Value operand : operands) {
    TFOp operand_op = operand.getDefiningOp();
    if (operand_op && dialect_->IsConstant(operand_op)) {
      auto dtype = operand_op->getAttrOfType<TypeAttr>("dtype");
      if (!dtype || dtype.getValue().isa<tf_type::StringType>()) return false;

      // Special case: If a Merge node has at least one constant input that
      // does not depend on a control input, we can fold it.
      merge_has_constant_input |= operand_op.getControlOperands().empty();
    } else if (!is_merge) {
      return false;
    }
  }

  if (is_merge && !merge_has_constant_input) return false;
  if (DisableCompressedTensorOptimization() &&
      (dialect_->IsFill(op) || dialect_->IsZerosLike(op) ||
       dialect_->IsOnesLike(op))) {
    return false;
  }

  // If we know the output shapes, make sure that the outputs are small enough
  // to materialize.
  int64_t input_size_bytes = 0;
  for (Value operand : operands) {
    auto shape = operand.getType().dyn_cast<ShapedType>();
    if (!shape || !shape.hasStaticShape()) continue;
    auto element_type = shape.getElementType();

    tensorflow::DataType dtype;
    if (!ConvertScalarTypeToDataType(element_type, &dtype).ok()) return false;
    input_size_bytes += shape.getNumElements() * DataTypeSize(dtype);
  }
  for (Value res : op->getResults().drop_back()) {
    auto shape = res.getType().dyn_cast<ShapedType>();
    if (!shape || !shape.hasStaticShape()) continue;
    auto element_type = shape.getElementType();

    tensorflow::DataType dtype;
    if (!ConvertScalarTypeToDataType(element_type, &dtype).ok()) return false;
    int64_t num_bytes = shape.getNumElements() * DataTypeSize(dtype);
    if (num_bytes > input_size_bytes && num_bytes > kMaxConstantSize)
      return false;
  }

  return true;
}

bool OpPropertyHelper::IsFoldable(TFOp op) {
  // TODO(chiahungduan): Cache foldable ops
  if (!MaybeFoldable(op)) return false;
  return IsFoldableUncached(op);
}

bool OpPropertyHelper::ShouldPreserveOp(TFOp op) {
  // TODO(tlongeri): Find a better way to identify preserved ops. A node has
  // its control output returned if it is a node-to-be-preserved (in
  // LiftGraphToFunc) - *not* iff, so the following check is overly broad:
  return llvm::any_of(op.controlRet().getUsers(), [&](TFOp child_op) {
    return dialect_->IsReturn(child_op);
  });
}

bool OpPropertyHelper::DisableCompressedTensorOptimization() {
  return disable_compressed_tensor_optimization_;
}

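// Returns true if the shape of the constant operand of a Mul allows the
// multiplication to be pushed down through a convolution with the given
// filter shape and data format.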
static bool IsValidConstShapeForMulConvPushDown(StringAttr data_format,
                                                ShapedType filter_shape,
                                                ShapedType const_shape) {
  if (!filter_shape.hasStaticShape() || !const_shape.hasStaticShape())
    return false;
  if (const_shape.getRank() <= data_format.size() &&
      const_shape.getNumElements() == 1) {
    return true;
  }
  if (data_format == "NHWC" || data_format == "NDHWC") {
    SmallVector<int64_t> broadcast_shape;
    if (!OpTrait::util::getBroadcastedShape(
            filter_shape.getShape(), const_shape.getShape(), broadcast_shape)) {
      return false;
    }

    // TODO(chiahungduan): Symbolic shape equivalence is acceptable.
    if (filter_shape.getShape() != llvm::makeArrayRef(broadcast_shape))
      return false;

    // Only the last dimension could be larger than one, since broadcasting
    // over the last dimension (the output channel) will result in an invalid
    // filter.
    for (int dim_size : const_shape.getShape())
      if (dim_size > 1) return false;
    return true;
  } else if (data_format == "NCHW" || data_format == "NCDHW") {
    // TODO(laigd): support NCHW and NCDHW (b/111214513).
    return false;
  }
  return false;
}

namespace {
template <typename ConcreteType, template <typename> class... Traits>
class ConstantPatternBase : public RewritePattern,
                            public Traits<ConcreteType>... {
 public:
  using RewritePattern::RewritePattern;

  ConstantPatternBase(StringRef opName, OpPropertyHelper &helper)
      : RewritePattern(opName, PatternBenefit(1),
                       helper.getDialect()->getContext()),
        helper_(helper),
        dialect_(helper.getDialect()) {}
  ConstantPatternBase(MatchAnyOpTypeTag tag, OpPropertyHelper &helper)
      : RewritePattern(tag, PatternBenefit(1),
                       helper.getDialect()->getContext()),
        helper_(helper),
        dialect_(helper.getDialect()) {}

 protected:
  OpPropertyHelper &helper_;
  TFGraphDialect *dialect_;
};

// A base trait which helps with classifying patterns and filtering them
// according to the classification.
template <typename ConcreteType>
struct TraitBase {
  ConcreteType *getPattern() { return static_cast<ConcreteType *>(this); }
};

// A trait indicating that the pattern folds the root operation into another
// operation, e.g. a constant op.
template <typename ConcreteType>
struct FolderTrait : public TraitBase<ConcreteType> {};

// A trait indicating that the pattern may propagate constant operands to the
// operation's users.
template <typename ConcreteType>
struct PropagationTrait : public TraitBase<ConcreteType> {};

template <typename ConcreteType>
using FolderPatternBase = ConstantPatternBase<ConcreteType, FolderTrait>;

template <typename ConcreteType>
using PropagationPatternBase =
    ConstantPatternBase<ConcreteType, PropagationTrait>;
}  // namespace

// EvaluateConstant maps to the implementation of ConstantFolding::FoldGraph in
// grappler/optimizers/constant_folding.cc
class EvaluateConstant : public FolderPatternBase<EvaluateConstant> {
 public:
  explicit EvaluateConstant(OpPropertyHelper &helper)
      : FolderPatternBase<EvaluateConstant>(MatchAnyOpTypeTag(), helper),
        has_folded_(BoolAttr::get(helper.getDialect()->getContext(), true)),
        folded_attr_name_(
            StringAttr::get(helper.getDialect()->getContext(), "has_folded")),
        cpu_device_(std::make_unique<util::SimpleDevice>()),
        resource_mgr_(std::make_unique<tensorflow::ResourceMgr>()) {}

  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    if (!helper_.IsFoldable(op)) return failure();

    // TODO(chiahungduan): Switch folding needs to delete dead values.
    if (dialect_->IsSwitch(op)) return failure();

    // The op may have been folded already: its results were replaced with
    // constant ops, but control edges prevented the op itself from being
    // removed. Use the attr to avoid evaluating it again.
    if (op->hasAttr(folded_attr_name_)) return failure();

    // If the op has no users, don't invoke the eager runtime.
    if (op->getNumResults() > 2 &&
        llvm::all_of(op->getResults().drop_back(),
                     [](Value v) { return v.use_empty(); })) {
      return failure();
    }

    SmallVector<ElementsAttr> const_operands;
    for (Value operand : TFOp(op).getNonControlOperands()) {
      Operation *defining_op = operand.getDefiningOp();
      if (defining_op && dialect_->IsConstant(defining_op)) {
        const_operands.push_back(
            defining_op->getAttrOfType<ElementsAttr>("value"));
      } else {
        return failure();
      }
    }

    SmallVector<TypedAttr> result;
    if (failed(util::EvaluateOperation(cpu_device_.get(), resource_mgr_.get(),
                                       op, const_operands, result))) {
      return failure();
    }

    StringAttr name_attr = static_cast<TFGraphDialect *>(op->getDialect())
                               ->getNameAttrIdentifier();
    SmallVector<Value> control_operands(
        OperandControlRetRange(op->getOperands()));

    StringAttr device_attr = TFOp(op).deviceAttr();
    SmallVector<TFOp> const_ops;
    for (auto &it : llvm::enumerate(result)) {
      TypedAttr attr = it.value();
      FailureOr<TFOp> const_op = CreateConstantTensorOp(
          rewriter, op->getLoc(),
          (Twine(TFOp(op).name(), "/eval_") + Twine(it.index())).str(),
          attr.getType().cast<ShapedType>(), control_operands, attr,
          NamedAttribute(name_attr, TFOp(op).nameAttr()));
      if (failed(const_op)) return failure();
      if (device_attr) (*const_op).setRequestedDevice(device_attr);
      const_ops.emplace_back(*const_op);
    }

    // If there is a single output, just replace the op.
    if (const_ops.size() == 1) {
      // Use the same node name for the replacement. Note that even if this op
      // is not in nodes_to_preserve, certain cases may still expect the op to
      // keep the same name after folding.
      const_ops[0].setName(TFOp(op).nameAttr());
      rewriter.replaceOp(op, const_ops[0]->getResults());
    } else {
      for (auto &it : llvm::enumerate(const_ops)) {
        for (OpOperand &user :
             llvm::make_early_inc_range(op->getResult(it.index()).getUses())) {
          rewriter.startRootUpdate(user.getOwner());
          user.set(it.value()->getResult(0));
          rewriter.finalizeRootUpdate(user.getOwner());
        }
      }

      // Now that all the non-control results are replaced with constant ops,
      // remove the op if its control output has no users either.
      if (TFOp(op).controlRet().use_empty()) {
        rewriter.eraseOp(op);
      } else {
        // We can't remove it directly. To avoid folding it again, add an attr
        // to identify these ops. This will be removed at the end of the
        // constant folding pass.
        op->setAttr(folded_attr_name_, has_folded_);
      }
    }

    return success();
  }

 private:
  BoolAttr has_folded_;
  StringAttr folded_attr_name_;
  std::unique_ptr<util::SimpleDevice> cpu_device_;
  std::unique_ptr<tensorflow::ResourceMgr> resource_mgr_;
};

// This implementation is mapped to the ShapeOp materialization in
// ConstantFolding::MaterializeShapes in grappler/optimizers/constant_folding.cc
class MaterializeShapeOp : public FolderPatternBase<MaterializeShapeOp> {
 public:
  explicit MaterializeShapeOp(OpPropertyHelper &helper)
      : FolderPatternBase<MaterializeShapeOp>("tfg.Shape", helper) {}

  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    Value input = op->getOperand(0);

    auto input_shape = input.getType().cast<ShapedType>();
    if (!input_shape.hasStaticShape()) return failure();

    // TODO(rmlarsen): Remove this workaround for b/150861569
    // The bug involves an expression of the form Shape(ExpandDims(x)
    // with an incorrectly inferred zero-size first dimension.
    if (!input_shape.getShape().empty() && input_shape.getShape()[0] == 0)
      return failure();

    Type output_dtype =
        op->getResult(0).getType().cast<ShapedType>().getElementType();
    ElementsAttr const_attr = CreateElementsAttrOfTypeValues(
        output_dtype, {input_shape.getRank()}, input_shape.getShape());

    // Add the control edge to `input` to ensure that the constant value will
    // only be run in the cases where Shape would have been run in the original
    // graph.
    TFOp const_op = *CreateConstantTensorOp(
        rewriter, op->getLoc(), /*name_prefix=*/"", const_attr.getType(),
        GetControlDependency(rewriter, input), const_attr, op->getAttrs());
    const_op.setName(TFOp(op).nameAttr());

    rewriter.replaceOp(op, const_op->getResults());

    return success();
  }
};

// This implementation is mapped to the SizeOp materialization in
// ConstantFolding::MaterializeShapes in grappler/optimizers/constant_folding.cc
class MaterializeSizeOp : public FolderPatternBase<MaterializeSizeOp> {
 public:
  explicit MaterializeSizeOp(OpPropertyHelper &helper)
      : FolderPatternBase<MaterializeSizeOp>("tfg.Size", helper) {}

  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    Value input = op->getOperand(0);

    auto input_shape = input.getType().cast<ShapedType>();
    if (!input_shape.hasStaticShape()) return failure();

    ShapedType result_type = (*op->result_type_begin()).cast<ShapedType>();
    if (!result_type.getElementType().isIntOrIndexOrFloat()) return failure();

    ElementsAttr const_attr = CreateElementsAttrOfTypeValues(
        result_type.getElementType(), {},
        ArrayRef<int64_t>(input_shape.getNumElements()));

    // Add the control edge to `input` to ensure that the constant value will
    // only be run in the cases where Size would have been run in the original
    // graph.
    TFOp const_op = *CreateConstantTensorOp(
        rewriter, op->getLoc(), /*name_prefix=*/"", const_attr.getType(),
        GetControlDependency(rewriter, input), const_attr, op->getAttrs());
    const_op.setName(TFOp(op).nameAttr());

    rewriter.replaceOp(op, const_op->getResults());

    return success();
  }
};

// This implementation is mapped to the RankOp materialization in
// ConstantFolding::MaterializeShapes in grappler/optimizers/constant_folding.cc
class MaterializeRankOp : public FolderPatternBase<MaterializeRankOp> {
 public:
  explicit MaterializeRankOp(OpPropertyHelper &helper)
      : FolderPatternBase<MaterializeRankOp>("tfg.Rank", helper) {}

  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    Value input = op->getOperand(0);

    auto input_shape = input.getType().cast<ShapedType>();
    if (!input_shape.hasRank()) return failure();

    ShapedType result_type = (*op->result_type_begin()).cast<ShapedType>();
    if (!result_type.getElementType().isIntOrIndexOrFloat()) return failure();

    ElementsAttr const_attr = CreateElementsAttrOfTypeValues(
        result_type.getElementType(), {},
        ArrayRef<int>(input_shape.getRank()));

    // Add the control edge to `input` to ensure that the constant value will
    // only be run in the cases where Rank would have been run in the original
    // graph.
    TFOp const_op = *CreateConstantTensorOp(
        rewriter, op->getLoc(), /*name_prefix=*/"", const_attr.getType(),
        GetControlDependency(rewriter, input), const_attr, op->getAttrs());
    const_op.setName(TFOp(op).nameAttr());

    rewriter.replaceOp(op, const_op->getResults());

    return success();
  }
};

// This implementation is mapped to the TensorArraySizeV3 materialization in
// ConstantFolding::MaterializeShapes in grappler/optimizers/constant_folding.cc
class MaterializeTensorArraySizeV3Op
    : public FolderPatternBase<MaterializeTensorArraySizeV3Op> {
 public:
  explicit MaterializeTensorArraySizeV3Op(OpPropertyHelper &helper)
      : FolderPatternBase<MaterializeTensorArraySizeV3Op>(
            "tfg.TensorArraySizeV3", helper) {}

  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    Operation *handle_op = op->getOperand(0).getDefiningOp();
    if (!handle_op || handle_op->getNumOperands() == 0) return failure();

    auto dynamic_size = handle_op->getAttrOfType<BoolAttr>("dynamic_size");
    if (dynamic_size && dynamic_size.getValue()) return failure();

    Operation *array_size = handle_op->getOperand(0).getDefiningOp();
    if (!array_size || !dialect_->IsConstant(array_size)) return failure();

    // Don't materialize 0 sizes to avoid triggering incorrect static checks.
    // A 0-sized array that can't grow isn't useful anyway.
    auto size_attr = array_size->getAttrOfType<SplatElementsAttr>("value");
    if (!size_attr || !size_attr.getElementType().isInteger(32))
      return failure();
    if (size_attr.getSplatValue<IntegerAttr>().getInt() == 0) return failure();

    SmallVector<Value> control_operands;
    control_operands.push_back(TFOp(handle_op).controlRet());
    control_operands.push_back(
        GetControlDependency(rewriter, op->getOperand(1)));
    // CreateConstantTensorOp cannot fail; its type is a tensor of i32.
    TFOp const_op = *CreateConstantTensorOp(
        rewriter, op->getLoc(), /*name_prefix=*/"", size_attr.getType(),
        control_operands, size_attr, op->getAttrs());
    const_op.setName(TFOp(op).nameAttr());

    rewriter.replaceOp(op, const_op->getResults());

    return success();
  }
};

// This implementation is mapped to the ShapeN materialization in
// ConstantFolding::MaterializeShapes in grappler/optimizers/constant_folding.cc
class MaterializeShapeNOp : public FolderPatternBase<MaterializeShapeNOp> {
 public:
  explicit MaterializeShapeNOp(OpPropertyHelper &helper)
      : FolderPatternBase<MaterializeShapeNOp>("tfg.ShapeN", helper) {}

  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    for (const auto &it : llvm::enumerate(TFOp(op).getNonControlOperands())) {
      Value operand = op->getOperand(it.index());

      auto operand_shape = operand.getType().cast<ShapedType>();
      if (!operand_shape.hasStaticShape()) continue;

      if (op->getResults()[it.index()].use_empty()) continue;

      ElementsAttr const_attr = ConvertShapeToAttr(operand_shape);

      FailureOr<TFOp> const_op = CreateConstantTensorOp(
          rewriter, op->getLoc(), TFOp(op).name(), *(op->result_type_begin()),
          TFOp(op).controlRet(), const_attr);
      if (failed(const_op)) return failure();

      (*const_op).setName(Twine(TFOp(op).name(), "/matshapes_") +
                          std::to_string(it.index()));
      if (!TFOp(op).device().empty())
        (*const_op).setRequestedDevice(TFOp(op).deviceAttr());

      // TODO(chiahungduan): Do we need to handle `direct_edges_exist` in
      // ConstantFolding::MaterializeShapes for ShapeN?

      for (OpOperand &user :
           llvm::make_early_inc_range(op->getResult(it.index()).getUses())) {
        rewriter.startRootUpdate(user.getOwner());
        user.set((*const_op)->getResult(0));
        rewriter.finalizeRootUpdate(user.getOwner());
      }
    }

    return success();
  }
};

// This implementation is mapped to the BroadcastGradientArgsOp materialization
// in ConstantFolding::MaterializeBroadcastGradientArgs in
// grappler/optimizers/constant_folding.cc
class MaterializeBroadcastGradientArgsOp
    : public PropagationPatternBase<MaterializeBroadcastGradientArgsOp> {
 public:
  explicit MaterializeBroadcastGradientArgsOp(OpPropertyHelper &helper)
      : PropagationPatternBase<MaterializeBroadcastGradientArgsOp>(
            "tfg.BroadcastGradientArgs", helper) {}
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    Operation *s0 = op->getOperand(0).getDefiningOp();
    Operation *s1 = op->getOperand(1).getDefiningOp();
    if (!s0 || !s1) return failure();

    if (!dialect_->IsShape(s0) && !dialect_->IsConstant(s0)) return failure();
    if (!dialect_->IsShape(s1) && !dialect_->IsConstant(s1)) return failure();

    // This operation has been optimized.
    if (op->getResult(0).use_empty() || op->getResult(1).use_empty())
      return failure();

    auto get_shape = [this](Operation *op,
                            SmallVector<int64_t> &shape) -> bool {
      if (dialect_->IsShape(op)) {
        auto type = op->getOperand(0).getType().cast<ShapedType>();
        if (!type.hasRank()) return false;

        llvm::append_range(shape, type.getShape());
      } else {
        auto attr = op->getAttrOfType<ElementsAttr>("value");
        if (!attr) return false;

        Type element_type = attr.getElementType();
        if (element_type.isInteger(32)) {
          llvm::append_range(shape, attr.getValues<int32_t>());
        } else if (element_type.isInteger(64)) {
          llvm::append_range(shape, attr.getValues<int64_t>());
        } else {
          return false;
        }
      }
      return true;
    };

    SmallVector<int64_t> s0_shape;
    SmallVector<int64_t> s1_shape;
    if (!get_shape(s0, s0_shape) || !get_shape(s1, s1_shape)) return failure();

    const int common_dims = std::min(s0_shape.size(), s1_shape.size());
    for (int i = 0; i < common_dims; ++i) {
      if (s0_shape[i] >= 0 && s1_shape[i] >= 0) continue;

      // TODO(chiahungduan): Check if two dims are symbolically equal. Grappler
      // stores the symbolic shape information with dim < -1 which is not a
      // convention in TFG. Use symbolic shape information instead.

      // Return failure if two dims are symbolically unequal.
      return failure();
    }

    for (int i = common_dims; i < s0_shape.size(); ++i)
      if (s0_shape[i] < 0) return failure();
    for (int i = common_dims; i < s1_shape.size(); ++i)
      if (s1_shape[i] < 0) return failure();

    tensorflow::BCast::Vec s0_vec(s0_shape.begin(), s0_shape.end());
    tensorflow::BCast::Vec s1_vec(s1_shape.begin(), s1_shape.end());
    tensorflow::BCast bcast(s0_vec, s1_vec);
    if (!bcast.IsValid()) return failure();

    tensorflow::BCast::Vec reduce_dims[2];
    reduce_dims[0] = bcast.grad_x_reduce_idx();
    reduce_dims[1] = bcast.grad_y_reduce_idx();

    auto type_attr = op->getAttrOfType<TypeAttr>("T");
    if (!type_attr) return failure();
    if (!type_attr.getValue().isIntOrIndexOrFloat()) return failure();

    SmallVector<Value, 2> const_values;
    for (int j = 0; j < 2; ++j) {
      int reduction_indices = reduce_dims[j].size();
      ElementsAttr const_attr = CreateElementsAttrOfTypeValues(
          type_attr.getValue(), {reduction_indices},
          llvm::makeArrayRef<int64_t>(reduce_dims[j].data(),
                                      reduction_indices));
      FailureOr<TFOp> const_op = CreateConstantTensorOp(
          rewriter, op->getLoc(), TFOp(op).name(), op->getResultTypes()[j],
          TFOp(op).controlRet(), const_attr);
      if (failed(const_op)) return failure();

      (*const_op).setName(Twine(TFOp(op).name(), "/bcastargs_") +
                          std::to_string(j));
      if (!TFOp(op).device().empty())
        (*const_op).setRequestedDevice(TFOp(op).deviceAttr());
      const_values.push_back((*const_op)->getResult(0));
    }

    for (OpOperand &user :
         llvm::make_early_inc_range(op->getResult(0).getUses())) {
      rewriter.startRootUpdate(user.getOwner());
      user.set(const_values[0]);
      rewriter.finalizeRootUpdate(user.getOwner());
    }
    for (OpOperand &user :
         llvm::make_early_inc_range(op->getResult(1).getUses())) {
      rewriter.startRootUpdate(user.getOwner());
      user.set(const_values[1]);
      rewriter.finalizeRootUpdate(user.getOwner());
    }

    return success();
  }
};

// This implementation is mapped to the materialization of reduction indices in
// ConstantFolding::MaterializeReductionIndices in
// grappler/optimizers/constant_folding.cc
class MaterializeReductionIndices
    : public PropagationPatternBase<MaterializeReductionIndices> {
 public:
  explicit MaterializeReductionIndices(OpPropertyHelper &helper)
      : PropagationPatternBase<MaterializeReductionIndices>(MatchAnyOpTypeTag(),
                                                            helper) {}
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    if (!dialect_->IsReduction(op)) return failure();

    Operation *indices = op->getOperand(1).getDefiningOp();
    // If the reduction indices are already constant, there's nothing to do.
    if (!indices || dialect_->IsConstant(indices)) return failure();

    auto indices_shape = indices->getResult(0).getType().cast<ShapedType>();
    if (!indices_shape.hasRank()) return failure();
    if (!indices_shape.getElementType().isInteger(32) &&
        !indices_shape.getElementType().isInteger(64)) {
      return failure();
    }

    auto input_shape = op->getOperand(0).getType().cast<ShapedType>();
    // Unexpected graph, don't try to change it.
    if (!input_shape.hasRank() || input_shape.getRank() < 1) return failure();

    auto output_shape = op->getResult(0).getType().cast<ShapedType>();
    const int output_rank =
        output_shape.hasRank() ? output_shape.getRank() : -1;

    bool full_reduction =
        output_rank == 0 ||
        (indices_shape.hasStaticShape() &&
         indices_shape.getNumElements() == input_shape.getRank());

    if (!full_reduction) {
      // A full reduction will generate a tensor of one of the shapes
      // [], [1], [1, 1], [1, 1, ...]. Even if we do not know the number of
      // elements in the output of the reduction, we may deduce it from reshape
      // nodes following it.
      for (Operation *user : op->getResult(0).getUsers()) {
        full_reduction = false;
        if (!dialect_->IsReshape(user)) return failure();

        auto shape = user->getResult(0).getType().cast<ShapedType>();
        if (!shape.hasStaticShape() || shape.getNumElements() != 1)
          return failure();
        else
          full_reduction = true;
      }
      if (!full_reduction) return failure();
    }

    // We know it's a full reduction. We can generate the full set of indices
    // to reduce as a constant node.
    SmallVector<int> elements(input_shape.getRank());
    std::iota(elements.begin(), elements.end(), 0);

    ElementsAttr const_attr = CreateElementsAttrOfTypeValues(
        indices_shape.getElementType(), {input_shape.getRank()},
        llvm::makeArrayRef(elements));

    FailureOr<TFOp> const_op = CreateConstantTensorOp(
        rewriter, indices->getLoc(), Twine(TFOp(op).name(), "/indices").str(),
        const_attr.getType(), TFOp(indices).controlRet(), const_attr);
    if (failed(const_op)) return failure();

    if (TFOp(op).deviceAttr())
      (*const_op).setRequestedDevice(TFOp(op).deviceAttr());

    rewriter.startRootUpdate(op);
    op->setOperand(1, (*const_op)->getResults()[0]);
    rewriter.finalizeRootUpdate(op);

    return success();
  }
};

// This implementation is mapped to the constant value materialization in
// ConstantFolding::MaterializeConstantValuedNode in
// grappler/optimizers/constant_folding.cc
class MaterializeFillNode : public FolderPatternBase<MaterializeFillNode> {
 public:
  explicit MaterializeFillNode(OpPropertyHelper &helper)
      : FolderPatternBase<MaterializeFillNode>("tfg.Fill", helper) {}
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    if (helper_.DisableCompressedTensorOptimization()) return failure();
    // Only handle single-result ops. Note that the other result is the control
    // token.
    if (op->getNumResults() != 2) return failure();

    auto output_type = op->getResult(0).getType().cast<ShapedType>();
    if (!output_type.hasStaticShape()) return failure();
    if (!output_type.getElementType().isIntOrIndexOrFloat()) return failure();

    Operation *dim = op->getOperand(0).getDefiningOp();
    Operation *value = op->getOperand(1).getDefiningOp();
    if (!dim || !value) return failure();
    // Grappler's constant folding also checks whether `dim` is a constant,
    // which is redundant because its constant property is never used.
    if (!dialect_->IsConstant(value)) return failure();

    ElementsAttr const_attr = CreateElementsAttrOfTypeValues(
        output_type.getElementType(), output_type.getShape(),
        {value->getAttrOfType<ElementsAttr>("value")});

    FailureOr<TFOp> const_op = ReplaceOpWithConstantTensor(
        rewriter, op, const_attr,
        /*exclude_attrs=*/ArrayRef<StringRef>({"T", "index_type"}));
    if (failed(const_op)) return failure();

    rewriter.replaceOp(op, (*const_op)->getResults());

    return success();
  }
};

// This implementation is mapped to the constant value materialization in
// ConstantFolding::MaterializeConstantValuedNode in
// grappler/optimizers/constant_folding.cc
class MaterializeConstantValuedNode
    : public FolderPatternBase<MaterializeConstantValuedNode> {
 public:
  explicit MaterializeConstantValuedNode(OpPropertyHelper &helper)
      : FolderPatternBase<MaterializeConstantValuedNode>(MatchAnyOpTypeTag(),
                                                         helper) {}
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    if (helper_.DisableCompressedTensorOptimization()) return failure();
    // Only handle single-result ops. Note that the other result is the control
    // token.
    if (op->getNumResults() != 2) return failure();

    // FillOp is handled in the MaterializeFillNode pattern.
    if (dialect_->IsFill(op)) return failure();
    if (!dialect_->IsZerosLike(op) && !dialect_->IsOnesLike(op))
      return failure();

    // TODO(chiahungduan): If op->getOperand(0) has static shape, can we use
    // that to materialize?
    auto output_type = op->getResult(0).getType().cast<ShapedType>();
    if (!output_type.hasStaticShape()) return failure();

    int value =
        dialect_->IsZerosLike(op) ? 0 : (dialect_->IsOnesLike(op) ? 1 : -1);
    if (value < 0) return failure();

    if (!output_type.getElementType().isIntOrIndexOrFloat()) return failure();

    ElementsAttr const_attr;
    if (output_type.getElementType().isIntOrIndex()) {
      const_attr = CreateElementsAttrOfTypeValues(output_type.getElementType(),
                                                  output_type.getShape(),
                                                  ArrayRef<int>(value));
    } else {
      const_attr = CreateElementsAttrOfTypeValues(output_type.getElementType(),
                                                  output_type.getShape(),
                                                  ArrayRef<double>(value));
    }

    FailureOr<TFOp> const_op =
        ReplaceOpWithConstantTensor(rewriter, op, const_attr);
    if (failed(const_op)) return failure();

    rewriter.replaceOp(op, (*const_op)->getResults());
    return success();
  }
};

// This implementation is mapped to the output value materialization in
// ConstantFolding::MaterializeOutputValues in
// grappler/optimizers/constant_folding.cc
class MaterializeOutputValue
    : public PropagationPatternBase<MaterializeOutputValue> {
 public:
  explicit MaterializeOutputValue(OpPropertyHelper &helper)
      : PropagationPatternBase<MaterializeOutputValue>(MatchAnyOpTypeTag(),
                                                       helper) {}
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    // In grappler, the shape information is stored in a separate structure and
    // this pass is used to materialize the shape inference information to the
    // node. But in MLIR, the shape inference information is stored in the
    // operation.
    return failure();
  }
};

1283 // This implementation is mapped to the merge node folding in
1284 // ConstantFolding::FoldMergeNode in
1285 // grappler/optimizers/constant_folding.cc
1286 template <typename ConcreteType>
1287 class MergeNodeFoldingBase : public PropagationPatternBase<ConcreteType> {
1288 protected:
MergeNodeFoldingBase(StringRef op_name,OpPropertyHelper & helper)1289 MergeNodeFoldingBase(StringRef op_name, OpPropertyHelper &helper)
1290 : PropagationPatternBase<ConcreteType>(op_name, helper),
1291 zero_dim_i32_tensor_type_(RankedTensorType::get(
1292 llvm::None,
1293 IntegerType::get(helper.getDialect()->getContext(), 32))) {}
1294
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const1295 LogicalResult matchAndRewrite(Operation *op,
1296 PatternRewriter &rewriter) const override {
1297 // Merge nodes are special, in the sense that they execute as soon as one of
1298 // their input is ready. We can therefore fold a merge node iff it has at
1299 // least one constant input without control dependency.
1300 // We still need to ensure that the nodes in the fanin of the merge node are
1301 // scheduled. We'll therefore add a control dependency from the merge node
1302 // to the folded constant. We end up with:
1303 // * the merge node and its inputs are preserved as is
1304 // * a new constant node C1, driven by the merge node through a control
1305 // dependency, initialized to the value of the folded input
1306 // * a new constant node C2, driven by the merge node through a control
1307 // dependency, initialized to the index of the folded input
1308 // * the fanout of the merge nodes is rewired to be driven by either C1 or
1309 // C2.
1310
1311 // The node may have been optimized.
1312 if (llvm::all_of(op->getResults().drop_back(),
1313 [](Value v) { return v.use_empty(); })) {
1314 return failure();
1315 }
1316
1317 int idx = 0;
1318 for (Value operand : TFOp(op).getNonControlOperands()) {
1319 Operation *operand_op = operand.getDefiningOp();
1320 if (!operand_op) continue;
1321 if (!this->dialect_->IsConstant(operand_op)) continue;
1322 if (!TFOp(operand_op).getControlOperands().empty()) continue;
1323
1324 FailureOr<TFOp> const_out = CreateConstantTensorOp(
1325 rewriter, op->getLoc(), TFOp(op).name(),
1326 *(operand_op->result_type_begin()), TFOp(op).controlRet(),
1327 operand_op->getAttrOfType<ElementsAttr>("value"), op->getAttrs());
1328 if (failed(const_out)) return failure();
1329 (*const_out).setName(Twine(TFOp(op).name(), "/const"));
1330 if (!TFOp(op).device().empty())
1331 (*const_out).setRequestedDevice(TFOp(op).device());
1332
1333 FailureOr<TFOp> const_index = CreateConstantTensorOp(
1334 rewriter, op->getLoc(), TFOp(op).name(), rewriter.getIntegerType(32),
1335 TFOp(op).controlRet(),
1336 DenseElementsAttr::get(zero_dim_i32_tensor_type_, idx++));
1337 if (failed(const_index)) return failure();
1338
1339 (*const_index).setName(Twine(TFOp(op).name(), "/index"));
1340 if (!TFOp(op).device().empty())
1341 (*const_index).setRequestedDevice(TFOp(op).device());
1342
1343 for (OpOperand &user :
1344 llvm::make_early_inc_range(op->getResults()[0].getUses())) {
1345 rewriter.startRootUpdate(user.getOwner());
1346 user.set((*const_out)->getResult(0));
1347 rewriter.finalizeRootUpdate(user.getOwner());
1348 }
1349 for (OpOperand &user :
1350 llvm::make_early_inc_range(op->getResults()[1].getUses())) {
1351 rewriter.startRootUpdate(user.getOwner());
1352 user.set((*const_index)->getResult(0));
1353 rewriter.finalizeRootUpdate(user.getOwner());
1354 }
1355
1356       // Already found an available input.
1357 return success();
1358 }
1359 return failure();
1360 }
1361
1362 RankedTensorType zero_dim_i32_tensor_type_;
1363 };
1364
1365 class MergeNodeFolding : public MergeNodeFoldingBase<MergeNodeFolding> {
1366 public:
1367   explicit MergeNodeFolding(OpPropertyHelper &helper)
1368 : MergeNodeFoldingBase("tfg.Merge", helper) {}
1369 };
1370
1371 class RefMergeNodeFolding : public MergeNodeFoldingBase<RefMergeNodeFolding> {
1372 public:
1373   explicit RefMergeNodeFolding(OpPropertyHelper &helper)
1374 : MergeNodeFoldingBase("tfg.RefMerge", helper) {}
1375 };
1376
1377 class XlaMergeNodeFolding : public MergeNodeFoldingBase<XlaMergeNodeFolding> {
1378 public:
1379   explicit XlaMergeNodeFolding(OpPropertyHelper &helper)
1380 : MergeNodeFoldingBase("tfg.XlaMerge", helper) {}
1381 };
1382
1383 // This implementation is mapped with ConstantFolding::RemoveSplitOrSplitV in
1384 // grappler/optimizers/constant_folding.cc
1385 class RemoveSplitOp : public FolderPatternBase<RemoveSplitOp> {
1386 public:
1387   explicit RemoveSplitOp(OpPropertyHelper &helper)
1388 : FolderPatternBase<RemoveSplitOp>("tfg.Split", helper) {}
1389   LogicalResult matchAndRewrite(Operation *op,
1390 PatternRewriter &rewriter) const override {
1391 auto num_split_attr = op->getAttrOfType<IntegerAttr>("num_split");
1392 if (!num_split_attr || num_split_attr.getInt() != 1) return failure();
1393 FailureOr<TFOp> identity = ReplaceOpWithIdentity(rewriter, op, 1);
1394 if (failed(identity)) return failure();
1395 rewriter.replaceOp(op, (*identity)->getResults());
1396 return success();
1397 }
1398 };
1399
1400 // This implementation is mapped with ConstantFolding::RemoveSplitOrSplitV in
1401 // grappler/optimizers/constant_folding.cc
1402 class RemoveSplitVOp : public FolderPatternBase<RemoveSplitVOp> {
1403 public:
1404   explicit RemoveSplitVOp(OpPropertyHelper &helper)
1405 : FolderPatternBase<RemoveSplitVOp>("tfg.SplitV", helper) {}
1406   LogicalResult matchAndRewrite(Operation *op,
1407 PatternRewriter &rewriter) const override {
1408 auto num_split_attr = op->getAttrOfType<IntegerAttr>("num_split");
1409 if (!num_split_attr || num_split_attr.getInt() != 1) return failure();
1410 FailureOr<TFOp> identity = ReplaceOpWithIdentity(rewriter, op, 0);
1411 if (failed(identity)) return failure();
1412 rewriter.replaceOp(op, (*identity)->getResults());
1413 return success();
1414 }
1415 };
1416
1417 // TODO(chiahungduan): Do we still have "Shuffle" op?
1418 // This implementation is mapped with ConstantFolding::RemoveShuffleOrTranspose
1419 // in grappler/optimizers/constant_folding.cc
1420 class RemoveShuffleOp : public FolderPatternBase<RemoveShuffleOp> {
1421 public:
1422   explicit RemoveShuffleOp(OpPropertyHelper &helper)
1423 : FolderPatternBase<RemoveShuffleOp>("tfg.Shuffle", helper) {}
1424   LogicalResult matchAndRewrite(Operation *op,
1425 PatternRewriter &rewriter) const override {
1426 Operation *perm_op = op->getOperand(1).getDefiningOp();
1427 if (!perm_op || !dialect_->IsConstant(perm_op)) return failure();
1428 ElementsAttr perm_tensor = perm_op->getAttrOfType<ElementsAttr>("value");
1429 if (!perm_tensor) return failure();
1430
1431 ShapedType x_shape = op->getOperand(0).getType().cast<ShapedType>();
1432 if (!x_shape.hasRank()) return failure();
1433 if (perm_tensor.getNumElements() != x_shape.getRank()) return failure();
1434
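    // A sketch of the check below: the permutation is a no-op when every axis
    // either stays in place or has size 1, so the data layout cannot change.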
1435 for (unsigned i = 0; i < x_shape.getRank(); ++i) {
1436 int64_t value = perm_tensor.getElementType().isInteger(32)
1437 ? perm_tensor.getValues<int32_t>()[i]
1438 : perm_tensor.getValues<int64_t>()[i];
1439 if (value != i && x_shape.getShape()[i] != 1) return failure();
1440 }
1441
1442 FailureOr<TFOp> identity = ReplaceOpWithIdentity(rewriter, op, 0);
1443 if (failed(identity)) return failure();
1444 rewriter.replaceOp(op, (*identity)->getResults());
1445
1446 return success();
1447 }
1448 };
1449
1450 // This implementation is mapped with ConstantFolding::RemoveShuffleOrTranspose
1451 // in grappler/optimizers/constant_folding.cc
1452 class RemoveTransposeOp : public FolderPatternBase<RemoveTransposeOp> {
1453 public:
1454   explicit RemoveTransposeOp(OpPropertyHelper &helper)
1455 : FolderPatternBase<RemoveTransposeOp>("tfg.Transpose", helper) {}
1456   LogicalResult matchAndRewrite(Operation *op,
1457 PatternRewriter &rewriter) const override {
1458 Operation *perm_op = op->getOperand(1).getDefiningOp();
1459 if (!perm_op || !dialect_->IsConstant(perm_op)) return failure();
1460 ElementsAttr perm_tensor = perm_op->getAttrOfType<ElementsAttr>("value");
1461 if (!perm_tensor) return failure();
1462
1463 ShapedType x_shape = op->getOperand(0).getType().cast<ShapedType>();
1464 if (!x_shape.hasRank()) return failure();
1465 if (perm_tensor.getNumElements() != x_shape.getRank()) return failure();
1466
1467 for (unsigned i = 0; i < x_shape.getRank(); ++i) {
1468 int64_t value = perm_tensor.getElementType().isInteger(32)
1469 ? perm_tensor.getValues<int32_t>()[i]
1470 : perm_tensor.getValues<int64_t>()[i];
1471 if (value != i && x_shape.getShape()[i] != 1) return failure();
1472 }
1473
1474 FailureOr<TFOp> identity = ReplaceOpWithIdentity(rewriter, op, 0);
1475 if (failed(identity)) return failure();
1476 rewriter.replaceOp(op, (*identity)->getResults());
1477
1478 return success();
1479 }
1480 };
1481
1482 // This implementation is mapped with ConstantFolding::RemoveRandomShuffle
1483 // in grappler/optimizers/constant_folding.cc
1484 class RemoveRandomShuffleOp : public FolderPatternBase<RemoveRandomShuffleOp> {
1485 public:
1486   explicit RemoveRandomShuffleOp(OpPropertyHelper &helper)
1487 : FolderPatternBase<RemoveRandomShuffleOp>("tfg.RandomShuffle", helper) {}
1488   LogicalResult matchAndRewrite(Operation *op,
1489 PatternRewriter &rewriter) const override {
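    // RandomShuffle permutes its input along the first dimension only, so it
    // is a no-op for scalars and for tensors whose first dimension has size 1.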
1490 auto shape = op->getOperand(0).getType().cast<ShapedType>();
1491 if (!shape.hasRank()) return failure();
1492 if (shape.getRank() != 0 && shape.getShape()[0] != 1) return failure();
1493
1494 FailureOr<TFOp> identity = ReplaceOpWithIdentity(rewriter, op, 0);
1495 if (failed(identity)) return failure();
1496 rewriter.replaceOp(op, (*identity)->getResults());
1497
1498 return success();
1499 }
1500 };
1501
1502 // This implementation is mapped with ConstantFolding::RemoveReverse
1503 // in grappler/optimizers/constant_folding.cc
1504 class RemoveReverse : public FolderPatternBase<RemoveReverse> {
1505 public:
1506   explicit RemoveReverse(OpPropertyHelper &helper)
1507 : FolderPatternBase<RemoveReverse>("tfg.ReverseV2", helper) {}
1508   LogicalResult matchAndRewrite(Operation *op,
1509 PatternRewriter &rewriter) const override {
1510 ShapedType tensor_type = op->getOperand(0).getType().cast<ShapedType>();
1511 if (!tensor_type.hasRank()) return failure();
1512
1513 Operation *dim_op = op->getOperand(1).getDefiningOp();
1514 if (!dim_op || !dialect_->IsConstant(dim_op)) return failure();
1515
1516 auto dim_attr = dim_op->getAttrOfType<ElementsAttr>("value");
1517 DenseSet<int> target_axis;
1518 for (unsigned i = 0; i < dim_attr.getNumElements(); ++i) {
1519 // Value of axis can be negative.
1520 if (dim_attr.getElementType().isInteger(32)) {
1521 target_axis.insert(
1522 (dim_attr.getValues<int32_t>()[i] + tensor_type.getRank()) %
1523 tensor_type.getRank());
1524 } else {
1525 target_axis.insert(
1526 (dim_attr.getValues<int64_t>()[i] + tensor_type.getRank()) %
1527 tensor_type.getRank());
1528 }
1529 }
1530
1531 for (unsigned i = 0; i < tensor_type.getRank(); ++i) {
1532 if (tensor_type.getShape()[i] != 1 && target_axis.contains(i))
1533 return failure();
1534 }
1535
1536 FailureOr<TFOp> identity = ReplaceOpWithIdentity(rewriter, op, 0);
1537 if (failed(identity)) return failure();
1538 rewriter.replaceOp(op, (*identity)->getResults());
1539
1540 return success();
1541 }
1542 };
1543
1544 // This implementation is mapped with ConstantFolding::SimplifySlice
1545 // in grappler/optimizers/constant_folding.cc
1546 class SimplifySliceOp : public FolderPatternBase<SimplifySliceOp> {
1547 public:
1548   explicit SimplifySliceOp(OpPropertyHelper &helper)
1549 : FolderPatternBase<SimplifySliceOp>("tfg.Slice", helper) {}
1550   LogicalResult matchAndRewrite(Operation *op,
1551 PatternRewriter &rewriter) const override {
1552 Operation *begin_op = op->getOperand(1).getDefiningOp();
1553 Operation *size_op = op->getOperand(2).getDefiningOp();
1554 if (!begin_op || !size_op) return failure();
1555
1556 if (!dialect_->IsConstant(begin_op) || !dialect_->IsConstant(size_op))
1557 return failure();
1558
1559 auto begin_attr = begin_op->getAttrOfType<ElementsAttr>("value");
1560 auto size_attr = size_op->getAttrOfType<ElementsAttr>("value");
1561
1562 ShapedType input_type = op->getOperand(0).getType().cast<ShapedType>();
1563 if (!input_type.hasRank()) return failure();
1564
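    // The Slice is a no-op when it starts at offset 0 in every dimension and
    // its size is either -1 or the full extent of that dimension.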
1565 for (unsigned i = 0; i < input_type.getRank(); ++i) {
1566 if (begin_attr.getElementType().isInteger(32)) {
1567 if (begin_attr.getValues<int32_t>()[i] != 0) return failure();
1568 } else {
1569 if (begin_attr.getValues<int64_t>()[i] != 0) return failure();
1570 }
1571
1572 if (size_attr.getElementType().isInteger(32)) {
1573 if (size_attr.getValues<int32_t>()[i] != -1 &&
1574 size_attr.getValues<int32_t>()[i] != input_type.getShape()[i])
1575 return failure();
1576 } else {
1577 if (size_attr.getValues<int64_t>()[i] != -1 &&
1578 size_attr.getValues<int64_t>()[i] != input_type.getShape()[i])
1579 return failure();
1580 }
1581 }
1582
1583 FailureOr<TFOp> identity = ReplaceOpWithIdentity(rewriter, op, 0);
1584 if (failed(identity)) return failure();
1585 rewriter.replaceOp(op, (*identity)->getResults());
1586
1587 return success();
1588 }
1589 };
1590
1591 // This implementation is mapped with ConstantFolding::SimplifyStridedSlice
1592 // in grappler/optimizers/constant_folding.cc
1593 class SimplifyStridedSlice : public FolderPatternBase<SimplifyStridedSlice> {
1594 public:
1595   explicit SimplifyStridedSlice(OpPropertyHelper &helper)
1596 : FolderPatternBase<SimplifyStridedSlice>("tfg.StridedSlice", helper) {}
1597   LogicalResult matchAndRewrite(Operation *op,
1598 PatternRewriter &rewriter) const override {
1599 // Skip ops with new/shrink axis mask, since they involve dimension changes.
1600 if (auto attr = op->getAttrOfType<IntegerAttr>("new_axis_mask")) {
1601 if (attr.getInt() != 0) return failure();
1602 } else {
1603 return failure();
1604 }
1605 if (auto attr = op->getAttrOfType<IntegerAttr>("shrink_axis_mask")) {
1606 if (attr.getInt() != 0) return failure();
1607 } else {
1608 return failure();
1609 }
1610
1611 auto begin_mask_attr = op->getAttrOfType<IntegerAttr>("begin_mask");
1612 auto end_mask_attr = op->getAttrOfType<IntegerAttr>("end_mask");
1613 auto ellipsis_mask_attr = op->getAttrOfType<IntegerAttr>("ellipsis_mask");
1614 if (!begin_mask_attr || !end_mask_attr || !ellipsis_mask_attr)
1615 return failure();
1616
1617 ShapedType input_type = op->getOperand(0).getType().cast<ShapedType>();
1618 if (!input_type.hasStaticShape()) return failure();
1619
1620 Operation *begin_op = op->getOperand(1).getDefiningOp();
1621 Operation *end_op = op->getOperand(2).getDefiningOp();
1622 Operation *strides_op = op->getOperand(3).getDefiningOp();
1623 if (!begin_op || !end_op || !strides_op) return failure();
1624
1625 if (!dialect_->IsConstant(begin_op) || !dialect_->IsConstant(end_op) ||
1626 !dialect_->IsConstant(strides_op))
1627 return failure();
1628
1629 ElementsAttr begin_attr = begin_op->getAttrOfType<ElementsAttr>("value");
1630 ElementsAttr end_attr = end_op->getAttrOfType<ElementsAttr>("value");
1631 ElementsAttr strides_attr =
1632 strides_op->getAttrOfType<ElementsAttr>("value");
1633
1634 const int64_t begin_mask = begin_mask_attr.getInt();
1635 const int64_t end_mask = end_mask_attr.getInt();
1636 const int64_t ellipsis_mask = ellipsis_mask_attr.getInt();
1637 const int64_t num_strides_elements = strides_attr.getNumElements();
1638
1639 DenseSet<int> expanded_ellipsis_indices;
1640 int ellipsis_index = -1;
1641
1642 for (unsigned i = 0; i < input_type.getRank(); ++i) {
1643 if (ellipsis_mask & 1 << i ||
1644 (ellipsis_index == -1 && i >= num_strides_elements)) {
1645 ellipsis_index = i;
1646 }
1647 if (ellipsis_index != -1 &&
1648 input_type.getRank() > num_strides_elements + i - ellipsis_index) {
1649 expanded_ellipsis_indices.insert(i);
1650 }
1651 }
1652
1653 for (unsigned i = 0; i < input_type.getRank(); ++i) {
1654 if (expanded_ellipsis_indices.contains(i)) {
1655         // ellipsis_mask is effective on the current dimension.
1656 continue;
1657 }
1658
1659 int j = i;
1660 int expanded_ellipsis_indices_size = expanded_ellipsis_indices.size();
1661 if (ellipsis_index != -1 &&
1662 i >= ellipsis_index + expanded_ellipsis_indices_size) {
1663 j = i - expanded_ellipsis_indices_size;
1664 }
1665 int b = begin_attr.getElementType().isInteger(32)
1666 ? begin_attr.getValues<int32_t>()[j]
1667 : begin_attr.getValues<int64_t>()[j];
1668 int e = end_attr.getElementType().isInteger(32)
1669 ? end_attr.getValues<int32_t>()[j]
1670 : end_attr.getValues<int64_t>()[j];
1671 int s = strides_attr.getElementType().isInteger(32)
1672 ? strides_attr.getValues<int32_t>()[j]
1673 : strides_attr.getValues<int64_t>()[j];
1674
1675 if (!(begin_mask & 1 << j || b == 0) ||
1676 !(end_mask & 1 << j || e == input_type.getShape()[i]) || s != 1) {
1677 return failure();
1678 }
1679 }
1680
1681 FailureOr<TFOp> identity = ReplaceOpWithIdentity(rewriter, op, 0);
1682 if (failed(identity)) return failure();
1683 rewriter.replaceOp(op, (*identity)->getResults());
1684
1685 return success();
1686 }
1687 };
1688
1689 // This implementation is mapped with ConstantFolding::SimplifyTile
1690 // in grappler/optimizers/constant_folding.cc
1691 class SimplifyTileOp : public FolderPatternBase<SimplifyTileOp> {
1692 public:
1693   explicit SimplifyTileOp(OpPropertyHelper &helper)
1694 : FolderPatternBase<SimplifyTileOp>("tfg.Tile", helper) {}
1695   LogicalResult matchAndRewrite(Operation *op,
1696 PatternRewriter &rewriter) const override {
1697 Operation *multiples_op = op->getOperand(1).getDefiningOp();
1698 if (!multiples_op || !dialect_->IsConstant(multiples_op)) return failure();
1699
1700 ElementsAttr multiples_attr =
1701 multiples_op->getAttrOfType<ElementsAttr>("value");
1702 if (multiples_attr.getElementType().isInteger(32)) {
1703 if (llvm::any_of(multiples_attr.getValues<int32_t>(),
1704 [](int v) { return v != 1; })) {
1705 return failure();
1706 }
1707 } else {
1708 if (llvm::any_of(multiples_attr.getValues<int64_t>(),
1709 [](int64_t v) { return v != 1; })) {
1710 return failure();
1711 }
1712 }
1713
1714 FailureOr<TFOp> identity = ReplaceOpWithIdentity(rewriter, op, 0);
1715 if (failed(identity)) return failure();
1716 rewriter.replaceOp(op, (*identity)->getResults());
1717
1718 return success();
1719 }
1720 };
1721
1722 // This implementation is mapped with ConstantFolding::SimplifyPad
1723 // in grappler/optimizers/constant_folding.cc
1724 template <typename ConcreteType>
1725 class SimplifyPadOpBase : public FolderPatternBase<ConcreteType> {
1726 protected:
1727   SimplifyPadOpBase(StringRef op_name, OpPropertyHelper &helper)
1728 : FolderPatternBase<ConcreteType>(op_name, helper) {}
1729   LogicalResult matchAndRewrite(Operation *op,
1730 PatternRewriter &rewriter) const override {
1731 Operation *paddings = op->getOperand(1).getDefiningOp();
1732 if (!paddings || !this->dialect_->IsConstant(paddings)) return failure();
1733
1734 ElementsAttr paddings_attr = paddings->getAttrOfType<ElementsAttr>("value");
1735 if (paddings_attr.getElementType().isInteger(32)) {
1736 if (llvm::any_of(paddings_attr.getValues<int32_t>(),
1737 [](int v) { return v != 0; })) {
1738 return failure();
1739 }
1740 } else {
1741 if (llvm::any_of(paddings_attr.getValues<int64_t>(),
1742 [](int64_t v) { return v != 0; })) {
1743 return failure();
1744 }
1745 }
1746
1747 FailureOr<TFOp> identity = ReplaceOpWithIdentity(rewriter, op, 0);
1748 if (failed(identity)) return failure();
1749 rewriter.replaceOp(op, (*identity)->getResults());
1750
1751 return success();
1752 }
1753 };
1754
1755 // This implementation is mapped with ConstantFolding::SimplifyPad
1756 // in grappler/optimizers/constant_folding.cc
1757 class SimplifyPadOp : public SimplifyPadOpBase<SimplifyPadOp> {
1758 public:
1759   explicit SimplifyPadOp(OpPropertyHelper &helper)
1760 : SimplifyPadOpBase("tfg.Pad", helper) {}
1761 };
1762
1763 // This implementation is mapped with ConstantFolding::SimplifyPad
1764 // in grappler/optimizers/constant_folding.cc
1765 class SimplifyPadV2Op : public SimplifyPadOpBase<SimplifyPadV2Op> {
1766 public:
1767   explicit SimplifyPadV2Op(OpPropertyHelper &helper)
1768 : SimplifyPadOpBase("tfg.PadV2", helper) {}
1769 };
1770
1771 // This implementation is mapped with ConstantFolding::SimplifySqueeze
1772 // in grappler/optimizers/constant_folding.cc
1773 class SimplifySqueezeOp : public FolderPatternBase<SimplifySqueezeOp> {
1774 public:
1775   explicit SimplifySqueezeOp(OpPropertyHelper &helper)
1776 : FolderPatternBase<SimplifySqueezeOp>("tfg.Squeeze", helper) {}
1777   LogicalResult matchAndRewrite(Operation *op,
1778 PatternRewriter &rewriter) const override {
1779 auto shape_type = op->getOperand(0).getType().cast<ShapedType>();
1780 if (!shape_type.hasRank()) return failure();
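    // Squeeze is a no-op only when every dimension is statically known to be
    // greater than 1; a dynamic or size-1 dimension could still be removed.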
1781 if (llvm::any_of(shape_type.getShape(), [](int64_t s) { return s <= 1; }))
1782 return failure();
1783
1784 FailureOr<TFOp> identity = ReplaceOpWithIdentity(rewriter, op, 0);
1785 if (failed(identity)) return failure();
1786 rewriter.replaceOp(op, (*identity)->getResults());
1787
1788 return success();
1789 }
1790 };
1791
1792 // This implementation is mapped with ConstantFolding::SimplifyPack
1793 // in grappler/optimizers/constant_folding.cc
1794 class SimplifyPackOp : public FolderPatternBase<SimplifyPackOp> {
1795 public:
1796   explicit SimplifyPackOp(OpPropertyHelper &helper)
1797 : FolderPatternBase<SimplifyPackOp>("tfg.Pack", helper) {}
1798   LogicalResult matchAndRewrite(Operation *op,
1799 PatternRewriter &rewriter) const override {
1800 auto [non_control_operands, control_operands] = TFOp(op).splitOperands();
1801 if (non_control_operands.size() != 1) return failure();
1802
1803 // It's unsafe to add a control dependency on the feed node, because it
1804     // might never have been executed otherwise.
1805 if (non_control_operands[0].isa<BlockArgument>()) return failure();
1806
1807 IntegerAttr axis = op->getAttrOfType<IntegerAttr>("axis");
1808 ElementsAttr const_attr = CreateElementsAttrOfTypeValues(
1809 rewriter.getIntegerType(32), /*shape=*/{},
1810 ArrayRef<int>(axis ? axis.getInt() : 0));
1811 // CreateConstantTensorOp cannot fail
1812 TFOp const_op = *CreateConstantTensorOp(
1813 rewriter, op->getLoc(), TFOp(op).name(), const_attr.getType(),
1814 GetControlDependency(rewriter, op->getOperand(0)), const_attr);
1815
1816 const_op.setName(Twine(TFOp(op).name(), "/_const_axis"));
1817 if (!TFOp(op).device().empty())
1818 const_op.setRequestedDevice(TFOp(op).deviceAttr());
1819
1820 OperationState state(op->getLoc(), "tfg.ExpandDims");
1821 state.addTypes(op->getResultTypes());
1822
1823 state.attributes = op->getAttrDictionary();
1824 state.attributes.erase("axis");
1825 state.attributes.erase("N");
1826 state.addAttribute("Tdim", TypeAttr::get(rewriter.getI32Type()));
1827
1828 state.addOperands({op->getOperand(0), const_op->getResult(0)});
1829 state.addOperands(control_operands);
1830 Operation *expand_dims_op = rewriter.create(state);
1831 rewriter.replaceOp(op, expand_dims_op->getResults());
1832 return success();
1833 }
1834 };
1835
1836 // This implementation is mapped with ConstantFolding::MoveConstantsPastEnter
1837 // in grappler/optimizers/constant_folding.cc
1838 template <typename ConcreteType>
1839 class MoveConstantsPastEnterOpBase
1840 : public PropagationPatternBase<ConcreteType> {
1841 protected:
1842   MoveConstantsPastEnterOpBase(StringRef op_name, OpPropertyHelper &helper)
1843 : PropagationPatternBase<ConcreteType>(op_name, helper) {}
1844   LogicalResult matchAndRewrite(Operation *op,
1845 PatternRewriter &rewriter) const override {
1846 auto is_constant_attr = op->getAttrOfType<BoolAttr>("is_constant");
1847 if (!is_constant_attr || !is_constant_attr.getValue()) return failure();
1848
1849 Operation *input = op->getOperand(0).getDefiningOp();
1850 if (!input || !this->dialect_->IsConstant(input)) return failure();
1851
1852 // Find non-constant nodes that consume the outputs of Enter.
1853 if (op->getResults()[0].use_empty()) return failure();
1854
1855 FailureOr<TFOp> cloned_const_op = CreateConstantTensorOp(
1856 rewriter, op->getLoc(), TFOp(op).name(), *(input->result_type_begin()),
1857 TFOp(op).controlRet(), input->getAttr("value"), input->getAttrs());
1858 if (failed(cloned_const_op)) return failure();
1859
1860 (*cloned_const_op).setName(Twine(TFOp(op).name(), "/_enter"));
1861 if (!TFOp(op).device().empty())
1862 (*cloned_const_op).setRequestedDevice(TFOp(op).deviceAttr());
1863
1864 rewriter.startRootUpdate(op);
1865 op->getResults()[0].replaceAllUsesWith((*cloned_const_op)->getResults()[0]);
1866 rewriter.finalizeRootUpdate(op);
1867 return success();
1868 }
1869 };
1870
1871 // This implementation is mapped with ConstantFolding::MoveConstantsPastEnter
1872 // in grappler/optimizers/constant_folding.cc
1873 class MoveConstantsPastEnterOp
1874 : public MoveConstantsPastEnterOpBase<MoveConstantsPastEnterOp> {
1875 public:
1876   explicit MoveConstantsPastEnterOp(OpPropertyHelper &helper)
1877 : MoveConstantsPastEnterOpBase("tfg.Enter", helper) {}
1878 };
1879
1880 // This implementation is mapped with ConstantFolding::MoveConstantsPastEnter
1881 // in grappler/optimizers/constant_folding.cc
1882 class MoveConstantsPastRefEnterOp
1883 : public MoveConstantsPastEnterOpBase<MoveConstantsPastRefEnterOp> {
1884 public:
1885   explicit MoveConstantsPastRefEnterOp(OpPropertyHelper &helper)
1886 : MoveConstantsPastEnterOpBase("tfg.RefEnter", helper) {}
1887 };
1888
1889 // This implementation is mapped with ConstantFolding::SimplifySwitch
1890 // in grappler/optimizers/constant_folding.cc
1891 class SimplifySwitchOp : public PropagationPatternBase<SimplifySwitchOp> {
1892 public:
1893   explicit SimplifySwitchOp(OpPropertyHelper &helper)
1894 : PropagationPatternBase<SimplifySwitchOp>("tfg.Switch", helper),
1895 zero_dim_i1_tensor_type_(RankedTensorType::get(
1896 {}, IntegerType::get(helper.getDialect()->getContext(), 1))) {}
1897   LogicalResult matchAndRewrite(Operation *op,
1898 PatternRewriter &rewriter) const override {
1899 if (op->getOperand(0) != op->getOperand(1)) return failure();
1900
1901 // If the optimization was already applied, the switch would have exactly
1902 // one Identity node consuming each of its outputs, each without any
1903 // non-control outputs.
1904 // TODO(tlongeri): This does not hold anymore as other patterns may need to
1905 // introduce an anchor. Fix this check, and handle both sides independently.
1906 if (llvm::any_of(op->getResults().drop_back(), [&](Value res) {
1907 return res.hasOneUse() &&
1908 IsControlAnchor(*res.getUsers().begin(), dialect_);
1909 })) {
1910 return failure();
1911 }
1912
1913 TFOp true_control_identity =
1914 GetControlAnchorForSwitchResult(rewriter, op->getResult(1), dialect_);
1915 TFOp false_control_identity =
1916 GetControlAnchorForSwitchResult(rewriter, op->getResult(0), dialect_);
1917
1918 FailureOr<TFOp> true_op = CreateConstantTensorOp(
1919 rewriter, op->getLoc(), TFOp(op).name(), op->getResultTypes()[1],
1920 true_control_identity.controlRet(),
1921 DenseElementsAttr::get(zero_dim_i1_tensor_type_, true));
1922 if (failed(true_op)) return failure();
1923
1924 (*true_op).setName(Twine(TFOp(op).name(), "/_const_true"));
1925 if (!TFOp(op).device().empty())
1926 (*true_op).setRequestedDevice(TFOp(op).device());
1927
1928 FailureOr<TFOp> false_op = CreateConstantTensorOp(
1929 rewriter, op->getLoc(), TFOp(op).name(), op->getResultTypes()[0],
1930 false_control_identity.controlRet(),
1931 DenseElementsAttr::get(zero_dim_i1_tensor_type_, false));
1932 if (failed(false_op)) return failure();
1933
1934 (*false_op).setName(Twine(TFOp(op).name(), "/_const_false"));
1935 if (!TFOp(op).device().empty())
1936 (*false_op).setRequestedDevice(TFOp(op).device().data());
1937
1938 // Note that we can't use replaceAllUsesWith here because we don't want to
1939     // replace the uses owned by the control identities.
1940 for (OpOperand &user :
1941 llvm::make_early_inc_range(op->getResult(1).getUses())) {
1942 if (user.getOwner() == &(*true_control_identity)) continue;
1943
1944 rewriter.startRootUpdate(user.getOwner());
1945 user.set((*true_op)->getResult(0));
1946 rewriter.finalizeRootUpdate(user.getOwner());
1947 }
1948 for (OpOperand &user :
1949 llvm::make_early_inc_range(op->getResult(0).getUses())) {
1950 if (user.getOwner() == &(*false_control_identity)) continue;
1951
1952 rewriter.startRootUpdate(user.getOwner());
1953 user.set((*false_op)->getResult(0));
1954 rewriter.finalizeRootUpdate(user.getOwner());
1955 }
1956
1957 return success();
1958 }
1959
1960 RankedTensorType zero_dim_i1_tensor_type_;
1961 };
1962
1963 // This implementation is mapped with ConstantFolding::SimplifyReduction
1964 // in grappler/optimizers/constant_folding.cc
1965 class SimplifyReductionOp : public FolderPatternBase<SimplifyReductionOp> {
1966 public:
1967   explicit SimplifyReductionOp(OpPropertyHelper &helper)
1968 : FolderPatternBase<SimplifyReductionOp>(MatchAnyOpTypeTag(), helper) {}
1969   LogicalResult matchAndRewrite(Operation *op,
1970 PatternRewriter &rewriter) const override {
1971 if (!dialect_->IsReduction(op)) return failure();
1972
1973 Operation *reduction_indices = op->getOperand(1).getDefiningOp();
1974 if (!reduction_indices) return failure();
1975
1976 ShapedType indices_type = *(reduction_indices->result_type_begin());
1977 if (indices_type.hasStaticShape() && indices_type.getNumElements() == 0) {
1978 Operation *identity_op = ReplaceReductionWithIdentity(rewriter, op);
1979 if (!identity_op) return failure();
1980
1981 rewriter.replaceOp(op, identity_op->getResults());
1982 return success();
1983 }
1984
1985 // Check `IsReductionCandidateForSimplification`
1986 auto input_type = op->getOperand(0).getType().cast<ShapedType>();
1987 auto op_type = (*op->result_type_begin()).cast<ShapedType>();
1988 if (!input_type.hasStaticShape() || !op_type.hasStaticShape())
1989 return failure();
1990
1991 bool is_single_element_op =
1992 (input_type.getNumElements() == 1) &&
1993 (op_type.hasStaticShape() && op_type.getNumElements() == 1);
1994
1995 bool keep_dims = false;
1996 if (auto attr = op->getAttrOfType<BoolAttr>("keep_dims")) {
1997 keep_dims = attr.getValue();
1998 }
1999 bool simplifiable_to_reshape =
2000 is_single_element_op && !keep_dims && op->hasAttr("T");
2001
2002 bool simplifiable_to_identity = keep_dims;
2003 // In grappler, they call EvaluateNode() to try to get the constant value of
2004     // reduction indices. But if it is a constant, the EvaluateConstant pattern
2005     // will already have folded it, so we don't need to evaluate the node here.
2006 if (dialect_->IsConstant(reduction_indices)) {
2007 ElementsAttr reduction_indices_attr =
2008 reduction_indices->getAttrOfType<ElementsAttr>("value");
2009
2010 if (reduction_indices_attr.getElementType().isInteger(32)) {
2011 for (int v : reduction_indices_attr.getValues<int32_t>()) {
2012 if (v < 0) v += input_type.getRank();
2013 if (v < 0 || v >= input_type.getRank() ||
2014 input_type.getShape()[v] != 1)
2015 simplifiable_to_identity = false;
2016 }
2017 } else {
2018 for (int64_t v : reduction_indices_attr.getValues<int64_t>()) {
2019 if (v < 0) v += input_type.getRank();
2020 if (v < 0 || v >= input_type.getRank() ||
2021 input_type.getShape()[v] != 1)
2022 simplifiable_to_identity = false;
2023 }
2024 }
2025 }
2026
2027 if (simplifiable_to_reshape) {
2028 Operation *reshape_op =
2029 ReplaceReductionWithReshape(rewriter, op, reduction_indices);
2030 if (!reshape_op) return failure();
2031
2032 rewriter.replaceOp(op, reshape_op->getResults());
2033 } else if (simplifiable_to_identity) {
2034 Operation *identity_op = ReplaceReductionWithIdentity(rewriter, op);
2035 if (!identity_op) return failure();
2036
2037 rewriter.replaceOp(op, identity_op->getResults());
2038 } else {
2039 return failure();
2040 }
2041
2042 return success();
2043 }
2044
2045 private:
2046   Operation *ReplaceReductionWithReshape(OpBuilder &builder, Operation *op,
2047 Operation *reduction_indices) const {
2048 const int new_num_dimensions =
2049 (*op->result_type_begin()).cast<ShapedType>().getRank();
2050 SmallVector<int64_t> elements(new_num_dimensions);
2051 std::iota(elements.begin(), elements.end(), 1);
2052 ElementsAttr const_attr = CreateElementsAttrOfTypeValues(
2053 builder.getIntegerType(32), {new_num_dimensions},
2054 llvm::makeArrayRef(elements));
2055 FailureOr<TFOp> const_op = CreateConstantTensorOp(
2056 builder, op->getLoc(), TFOp(op).name(),
2057 *(reduction_indices->result_type_begin()),
2058 TFOp(reduction_indices).controlRet(), const_attr);
2059 if (failed(const_op)) return nullptr;
2060
2061 (*const_op).setName(Twine(TFOp(op).name(), "/_shape_const"));
2062 if (!TFOp(op).device().empty())
2063 (*const_op).setRequestedDevice(TFOp(op).deviceAttr());
2064
2065 OperationState state(op->getLoc(), "tfg.Reshape");
2066 state.attributes = op->getAttrDictionary();
2067 state.attributes.erase("keep_dims");
2068 state.attributes.erase("Tidx");
2069 state.addAttribute("Tshape", TypeAttr::get(builder.getI32Type()));
2070
2071 state.addOperands(op->getOperands());
2072 state.operands[1] = (*const_op)->getResult(0);
2073 state.addTypes(op->getResultTypes());
2074
2075 Operation *reshape_op = builder.create(state);
2076 TFOp(reshape_op).setName(TFOp(op).nameAttr());
2077 if (!TFOp(op).device().empty())
2078 TFOp(reshape_op).setRequestedDevice(TFOp(op).deviceAttr());
2079 return reshape_op;
2080 }
2081
2082   Operation *ReplaceReductionWithIdentity(OpBuilder &builder,
2083 Operation *op) const {
2084 OperationState state(op->getLoc(), "tfg.Identity");
2085 Type t_attr_type;
2086 if (auto T_attr = op->getAttrOfType<TypeAttr>("T"))
2087 t_attr_type = T_attr.getValue();
2088 else if (dialect_->IsAny(op) || dialect_->IsAll(op))
2089 t_attr_type = builder.getI1Type();
2090 else
2091 return nullptr;
2092 state.attributes = op->getAttrDictionary();
2093 util::EraseRegularNodeAttributes(state.attributes);
2094 state.addAttribute("T", TypeAttr::get(t_attr_type));
2095 state.addTypes(op->getResultTypes());
2096 state.addOperands(
2097 {op->getOperand(0), GetControlDependency(builder, op->getOperand(1))});
2098
2099 Operation *identity_op = builder.create(state);
2100 TFOp(identity_op).setName(TFOp(op).nameAttr());
2101 if (!TFOp(op).device().empty())
2102 TFOp(identity_op).setRequestedDevice(TFOp(op).deviceAttr());
2103 return identity_op;
2104 }
2105 };
2106
2107 // This implementation is mapped with ConstantFolding::SimplifyReshapeOp
2108 // in grappler/optimizers/constant_folding.cc
2109 class SimplifyReshapeOp : public FolderPatternBase<SimplifyReshapeOp> {
2110 public:
2111   explicit SimplifyReshapeOp(OpPropertyHelper &helper)
2112 : FolderPatternBase<SimplifyReshapeOp>(MatchAnyOpTypeTag(), helper) {}
2113   LogicalResult matchAndRewrite(Operation *op,
2114 PatternRewriter &rewriter) const override {
2115 if (!dialect_->IsReshape(op) || !op->hasAttr("T")) return failure();
2116
2117 auto input_shape = op->getOperand(0).getType().cast<ShapedType>();
2118 if (!input_shape.hasStaticShape()) return failure();
2119
2120 Operation *shape_op = op->getOperand(1).getDefiningOp();
2121 if (!shape_op || !dialect_->IsConstant(shape_op)) return failure();
2122
2123 auto shape_attr = shape_op->getAttrOfType<ElementsAttr>("value");
2124     // TODO(tlongeri): the only reason for a SmallVector instead of the range
2125     // directly is that the llvm::zip implementation requires copy assignment (it shouldn't)
2126 SmallVector<APInt> new_shape(shape_attr.getValues<APInt>());
2127
2128 if (input_shape.getRank() != new_shape.size()) return failure();
2129 for (const auto &it : llvm::zip(input_shape.getShape(), new_shape)) {
2130 int64_t dim_0 = std::get<0>(it);
2131 int64_t dim_1 = std::get<1>(it).getSExtValue();
2132 if (dim_0 >= 0 && dim_1 >= 0 && dim_0 != dim_1) return failure();
2133 }
2134
2135 OperationState state(op->getLoc(), "tfg.Identity");
2136 state.addTypes(op->getResultTypes());
2137 state.addOperands(
2138 {op->getOperand(0), GetControlDependency(rewriter, op->getOperand(1))});
2139 state.addOperands(TFOp(op).getControlOperands());
2140
2141 state.attributes = op->getAttrDictionary();
2142 util::EraseRegularNodeAttributes(state.attributes);
2143 state.addAttribute("T", op->getAttrOfType<TypeAttr>("T"));
2144
2145 Operation *identity_op = rewriter.create(state);
2146 TFOp(identity_op).setName(TFOp(op).nameAttr());
2147 if (!TFOp(op).device().empty())
2148 TFOp(identity_op).setRequestedDevice(TFOp(op).deviceAttr());
2149 rewriter.replaceOp(op, identity_op->getResults());
2150
2151 return success();
2152 }
2153 };
2154
2155 // This implementation is mapped with
2156 // ConstantFolding::SimplifyArithmeticOperations in
2157 // grappler/optimizers/constant_folding.cc
2158 class SimplifyArithmeticOp
2159 : public ConstantPatternBase<SimplifyArithmeticOp, FolderTrait,
2160 PropagationTrait> {
2161 public:
2162   explicit SimplifyArithmeticOp(OpPropertyHelper &helper)
2163 : ConstantPatternBase(MatchAnyOpTypeTag(), helper) {}
2164   LogicalResult matchAndRewrite(Operation *op,
2165 PatternRewriter &rewriter) const override {
2166 const bool is_mul = dialect_->IsAnyMul(op) || dialect_->IsLogicalAnd(op);
2167 const bool is_matmul = dialect_->IsAnyMatMul(op);
2168 const bool is_add = dialect_->IsAdd(op) || dialect_->IsBiasAdd(op) ||
2169 dialect_->IsLogicalOr(op);
2170 const bool is_sub = dialect_->IsSub(op);
2171 const bool is_any_div = dialect_->IsAnyDiv(op) && !dialect_->IsFloorDiv(op);
2172
2173 if (!is_mul && !is_matmul && !is_add && !is_sub && !is_any_div)
2174 return failure();
2175
2176 Operation *x = op->getOperand(0).getDefiningOp();
2177 Operation *y = op->getOperand(1).getDefiningOp();
2178 if (!x || !y) return failure();
2179
2180 ShapedType op_type = (*op->result_type_begin()).cast<ShapedType>();
2181 ShapedType x_type = (*x->result_type_begin()).cast<ShapedType>();
2182 ShapedType y_type = (*y->result_type_begin()).cast<ShapedType>();
2183
2184 const bool y_matches_output_shape = op_type == y_type;
2185 const bool x_matches_output_shape = op_type == x_type;
2186
2187 const bool x_is_zero = helper_.IsZeros(x);
2188 const bool x_is_one = x_is_zero ? false : helper_.IsOnes(x);
2189
2190     // TODO(chiahungduan): Check if the optimization has already been applied.
2191
2192 if ((is_mul && x_is_one) || (is_add && x_is_zero)) {
2193 // 1 * y = y or 0 + y = y.
2194 if (y_matches_output_shape) {
2195 FailureOr<TFOp> snapshot_op =
2196 ReplaceOperationWithSnapshot(rewriter, op, 1);
2197 if (failed(snapshot_op)) return failure();
2198 rewriter.replaceOp(op, (*snapshot_op)->getResults());
2199 return success();
2200 } else if (x_matches_output_shape) {
2201 FailureOr<TFOp> broadcast_to_op =
2202             ReplaceOperationWithBroadcastTo(rewriter, op, 1);
        if (failed(broadcast_to_op)) return failure();
2203         rewriter.replaceOp(op, (*broadcast_to_op)->getResults());
2204 return success();
2205 }
2206 return failure();
2207 }
2208
2209 if (y_matches_output_shape && (is_sub && x_is_zero)) {
2210 // Replace 0 - y with Neg(y).
2211 OperationState state(op->getLoc(), "tfg.Neg");
2212 state.addOperands({op->getOperand(1),
2213 GetControlDependency(rewriter, op->getOperand(0))});
2214 state.addOperands(TFOp(op).getControlOperands());
2215 state.attributes = op->getAttrDictionary();
2216 state.addTypes(op->getResultTypes());
2217 Operation *neg = rewriter.create(state);
2218 rewriter.replaceOp(op, neg->getResults());
2219 return success();
2220 }
2221
2222 // Replace 1 / y with Reciprocal op.
2223 if (y_matches_output_shape && is_any_div && x_is_one) {
2224 TypeAttr type_attr = op->getAttrOfType<TypeAttr>("T");
2225 if (!type_attr) return failure();
2226
2227 if (type_attr.getValue().isa<FloatType>() ||
2228 type_attr.getValue().isa<ComplexType>()) {
2229 OperationState state(op->getLoc(), "tfg.Reciprocal");
2230 state.addOperands({op->getOperand(1),
2231 GetControlDependency(rewriter, op->getOperand(0))});
2232 state.addOperands(TFOp(op).getControlOperands());
2233 state.attributes = op->getAttrDictionary();
2234 state.addTypes(op->getResultTypes());
2235 Operation *reciprocal_op = rewriter.create(state);
2236 rewriter.replaceOp(op, reciprocal_op->getResults());
2237 return success();
2238 }
2239 }
2240
2241 const bool y_is_zero = helper_.IsZeros(y);
2242 const bool y_is_one = helper_.IsOnes(y);
2243
2244 if (((is_mul || is_any_div) && y_is_one) ||
2245 ((is_add || is_sub) && y_is_zero)) {
2246 // x * 1 = x or x / 1 = x or x +/- 0 = x
2247 if (x_matches_output_shape) {
2248 FailureOr<TFOp> snapshot_op =
2249 ReplaceOperationWithSnapshot(rewriter, op, 0);
2250 if (failed(snapshot_op)) return failure();
2251 rewriter.replaceOp(op, (*snapshot_op)->getResults());
2252 return success();
2253 } else if (y_matches_output_shape) {
2254 FailureOr<TFOp> broadcast_to_op =
2255 ReplaceOperationWithBroadcastTo(rewriter, op, 0);
2256 if (failed(broadcast_to_op)) return failure();
2257 rewriter.replaceOp(op, (*broadcast_to_op)->getResults());
2258 return success();
2259 }
2260 return failure();
2261 }
2262
2263 // x OR true = true OR y = true.
2264 if (op_type.hasStaticShape() && dialect_->IsLogicalOr(op) &&
2265 (y_is_one || x_is_one)) {
2266 FailureOr<TFOp> const_op = ReplaceOperationWithConstant(rewriter, op, 1);
2267 if (failed(const_op)) return failure();
2268 rewriter.replaceOp(op, (*const_op)->getResults());
2269 return success();
2270 }
2271
2272     // The TFG optimizer doesn't support aggressive mode.
2273 const bool is_aggressive = false;
2274 // Note that this is always false because of `is_aggressive`. Keep it in
2275 // this form to alleviate the effort of comparing the logic with the same
2276 // logic in grappler.
2277 bool optimize_zeros_divided_by_y = is_any_div && x_is_zero && is_aggressive;
2278 if ((x_is_zero || y_is_zero) &&
2279 (is_mul || is_matmul || optimize_zeros_divided_by_y)) {
2280 if (op_type.hasStaticShape()) {
2281 bool is_quantized = dialect_->IsQuantizedMatMul(op);
2282 if (is_quantized) {
2283 // TODO(chiahungduan): AddQuantizedMatMulMinMaxOutConstNodes
2284 return failure();
2285 }
2286
2287 FailureOr<TFOp> const_op =
2288 ReplaceOperationWithConstant(rewriter, op, 0);
2289 if (failed(const_op)) return failure();
2290
2291 rewriter.replaceOp(op, (*const_op)->getResults());
2292 return success();
2293 }
2294
2295 if ((is_mul || is_any_div) && x_is_zero) {
2296 if (x_matches_output_shape) {
2297 FailureOr<TFOp> identity = ReplaceOpWithIdentity(rewriter, op, 0);
2298 if (failed(identity)) return failure();
2299 rewriter.replaceOp(op, (*identity)->getResults());
2300 return success();
2301 } else if (y_matches_output_shape) {
2302 FailureOr<TFOp> broadcast_to_op =
2303 ReplaceOperationWithBroadcastTo(rewriter, op, 0);
2304 if (failed(broadcast_to_op)) return failure();
2305 rewriter.replaceOp(op, (*broadcast_to_op)->getResults());
2306 return success();
2307 }
2308 } else if (is_mul && y_is_zero) {
2309 if (y_matches_output_shape) {
2310 FailureOr<TFOp> identity = ReplaceOpWithIdentity(rewriter, op, 0);
2311 if (failed(identity)) return failure();
2312 rewriter.replaceOp(op, (*identity)->getResults());
2313 return success();
2314 } else if (x_matches_output_shape) {
2315 FailureOr<TFOp> broadcast_to_op =
2316 ReplaceOperationWithBroadcastTo(rewriter, op, 1);
2317 if (failed(broadcast_to_op)) return failure();
2318 rewriter.replaceOp(op, (*broadcast_to_op)->getResults());
2319 return success();
2320 }
2321 }
2322 }
2323
2324 return failure();
2325 }
2326 };
2327
2328 // This implementation is mapped with ConstantFolding::ReduceDivToReciprocalMul
2329 // in grappler/optimizers/constant_folding.cc
2330 class ReduceDivToReciprocalMul
2331 : public FolderPatternBase<ReduceDivToReciprocalMul> {
2332 public:
2333   explicit ReduceDivToReciprocalMul(OpPropertyHelper &helper)
2334 : FolderPatternBase<ReduceDivToReciprocalMul>(MatchAnyOpTypeTag(),
2335 helper) {}
2336   LogicalResult matchAndRewrite(Operation *op,
2337 PatternRewriter &rewriter) const override {
2338 // Strength reduce floating point division by a constant Div(x, const) to
2339 // multiplication by the reciprocal Mul(x, Reciprocal(const)). This in turn
2340 // will be constant folded to Mul(x, 1.0/const).
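    // For example, Div(x, 4.0) becomes Mul(x, Reciprocal(4.0)), which the
    // folder can then turn into Mul(x, 0.25).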
2341 if (!dialect_->IsDiv(op) && !dialect_->IsRealDiv(op) &&
2342 !dialect_->IsXdivy(op)) {
2343 return failure();
2344 }
2345
2346 Operation *y = op->getOperand(1).getDefiningOp();
2347 if (!y || !dialect_->IsConstant(y)) return failure();
2348
2349 TypeAttr type_attr = op->getAttrOfType<TypeAttr>("T");
2350 if (!type_attr) return failure();
2351
2352 // Skip integer division.
2353 if (dialect_->IsDiv(op) && !(type_attr.getValue().isa<FloatType>() ||
2354 type_attr.getValue().isa<ComplexType>())) {
2355 return failure();
2356 }
2357
2358 OperationState state(op->getLoc(), "tfg.Reciprocal");
2359 state.addOperands(y->getResult(0));
2360 state.addTypes({*(y->result_type_begin()), ControlType::get(getContext())});
2361 state.addAttribute("T", type_attr);
2362 TFOp reciprocal_op = rewriter.create(state);
2363 reciprocal_op.setName(Twine(TFOp(op).name(), "/") +
2364 Twine(TFOp(y).name(), "/_recip"));
2365 if (!TFOp(op).device().empty())
2366 reciprocal_op.setRequestedDevice(TFOp(op).deviceAttr());
2367
2368 StringRef new_op_name = dialect_->IsXdivy(op) ? "tfg.MulNoNan" : "tfg.Mul";
2369 OperationState new_op_state(op->getLoc(), new_op_name);
2370
2371 if (dialect_->IsXdivy(op)) {
2372 new_op_state.addOperands(
2373 {reciprocal_op->getResult(0), op->getOperand(0)});
2374 } else {
2375 new_op_state.addOperands(
2376 {op->getOperand(0), reciprocal_op->getResult(0)});
2377 }
2378 new_op_state.addOperands(TFOp(op).getControlOperands());
2379
2380 new_op_state.attributes = op->getAttrDictionary();
2381 new_op_state.addTypes(op->getResultTypes());
2382
2383 Operation *new_op = rewriter.create(new_op_state);
2384 rewriter.replaceOp(op, new_op->getResults());
2385
2386 return success();
2387 }
2388 };
2389
2390 namespace {
2391 template <typename ConcreteType>
2392 using Base = ConstantPatternBase<ConcreteType, FolderTrait, PropagationTrait>;
2393
2394 template <typename ConcreteType>
2395 class ConstantPushDownBase : public Base<ConcreteType> {
2396 protected:
2397 using Base<ConcreteType>::Base;
2398
2399   bool IsOperandsSafeToMove(Operation *op_child, Operation *const_child) const {
2400 // Don't rewrite the tree if it might create cycles.
2401 // TODO(chiahungduan): Remove the control dependency which may create
2402 // cycles.
2403 if (llvm::any_of(
2404 TFOp(const_child).getControlOperands(),
2405 [op_child](Value v) { return v.getDefiningOp() == op_child; })) {
2406 return false;
2407 }
2408
2409     // Moving operands may change the result shapes; only do it when each
2410     // non-control return value has exactly one user.
2411 if (llvm::any_of(op_child->getResults().drop_back(),
2412 [](Value v) { return !v.hasOneUse(); })) {
2413 return false;
2414 }
2415 return true;
2416 }
2417 };
2418 } // namespace
2419
2420 // Consider the transformation
2421 //
2422 // + + = parent
2423 // / \ / \
2424 // C + -- > X + = children
2425 // / \ / \
2426 // X Y C Y = leaves
2427 //
2428 // where C is constant, X is non-constant, Y may be constant or non-constant,
2429 // and '+' denotes an associative and commutative operator like addition or
2430 // multiplication. This optimization pushes constants down in the tree to
2431 // canonicalize it. Moreover, in cases where the child node has a second
2432 // constant input Y we will create a leaf node that can be folded, e.g.
2433 //
2434 // Add(C1, Add(C2, X)) -> Add(X, Add(C1, C2)) -> Add(X, C1 + C2)
2435 //
2436 // We also handle the non-commutative cases of subtraction and division
2437 // by rotating the tree locally, e.g.
2438 // Sub(C, Add(X, Y)) -> Sub(Sub(C, Y), X)
2439 // Mul(C, Div(X, Y)) -> Mul(X, Div(C, Y)).
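//
// A concrete instance of the rotation above, with illustrative values:
//   Sub(2, Add(X, 1)) -> Sub(Sub(2, 1), X) -> Sub(1, X)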
2440 class ConstantPushDown : public ConstantPushDownBase<ConstantPushDown> {
2441 public:
2442   explicit ConstantPushDown(OpPropertyHelper &helper)
2443 : ConstantPushDownBase(MatchAnyOpTypeTag(), helper) {}
2444   LogicalResult matchAndRewrite(Operation *op,
2445 PatternRewriter &rewriter) const override {
2446 // Get parent op type.
2447 const bool is_add = dialect_->IsAdd(op);
2448 const bool is_mul = dialect_->IsMul(op);
2449 const bool is_sub = dialect_->IsSub(op);
2450 const bool is_div = dialect_->IsDiv(op);
2451 if (!(is_add || is_sub || is_mul || is_div)) return failure();
2452 const bool is_symmetric = is_add || is_mul;
2453
2454 Operation *child_op = op->getOperand(0).getDefiningOp();
2455 Operation *const_op = op->getOperand(1).getDefiningOp();
2456 if (!child_op || !const_op) return failure();
2457
2458 // Don't move nodes across devices.
2459 if (TFOp(op).deviceAttr() != TFOp(child_op).deviceAttr() ||
2460 TFOp(op).deviceAttr() != TFOp(const_op).deviceAttr()) {
2461 return failure();
2462 }
2463
2464 const bool left_child_is_const = dialect_->IsConstant(child_op);
2465
2466 // One of the child op has to be constant.
2467 if (!dialect_->IsConstant(const_op)) std::swap(child_op, const_op);
2468 if (!dialect_->IsConstant(const_op)) return failure();
2469 if (helper_.ShouldPreserveOp(child_op)) return failure();
2470
2471 if (!IsOperandsSafeToMove(child_op, const_op)) return failure();
2472
2473 // Get child op type.
2474 const bool is_child_add = dialect_->IsAdd(child_op);
2475 const bool is_child_mul = dialect_->IsMul(child_op);
2476 const bool is_child_sub = dialect_->IsSub(child_op);
2477 const bool is_child_div = dialect_->IsDiv(child_op);
2478 const bool is_add_sub =
2479 (is_add || is_sub) && (is_child_add || is_child_sub);
2480 const bool is_mul_div =
2481 (is_mul || is_div) && (is_child_mul || is_child_div);
2482 if (!is_add_sub && !is_mul_div) return failure();
2483
2484 const bool is_child_symmetric = is_child_add || is_child_mul;
2485
2486 TypeAttr t_attr = op->getAttrOfType<TypeAttr>("T");
2487 if (!t_attr) return failure();
2488
2489 if (!(is_symmetric && is_child_symmetric) &&
2490 t_attr.getValue().isIntOrIndex()) {
2491 return failure();
2492 }
2493
2494 Operation *left_leaf_op = child_op->getOperand(0).getDefiningOp();
2495 Operation *right_leaf_op = child_op->getOperand(1).getDefiningOp();
2496 if (!left_leaf_op || !right_leaf_op) return failure();
2497
2498 // Don't move nodes across devices.
2499 if (TFOp(op).deviceAttr() != TFOp(left_leaf_op).deviceAttr() ||
2500 TFOp(op).deviceAttr() != TFOp(right_leaf_op).deviceAttr()) {
2501 return failure();
2502 }
2503
2504 const bool left_leaf_is_const = dialect_->IsConstant(left_leaf_op);
2505 Operation *y_node = left_leaf_is_const ? left_leaf_op : right_leaf_op;
2506
2507 if (!dialect_->IsConstant(y_node)) {
2508 // If we know the shapes of the nodes being swapped, make sure we don't
2509 // push down a larger node and create more work by broadcasting earlier
2510       // in the expression tree.
2511 auto c_shape = op->getOperand((left_child_is_const ? 0 : 1))
2512 .getType()
2513 .cast<ShapedType>();
2514 auto x_shape = child_op->getOperand((left_leaf_is_const ? 0 : 1))
2515 .getType()
2516 .cast<ShapedType>();
2517
2518 if (c_shape.hasStaticShape() && x_shape.hasStaticShape() &&
2519 c_shape.getNumElements() > x_shape.getNumElements()) {
2520 return failure();
2521 }
2522 if (c_shape.hasRank() && x_shape.hasRank() && c_shape.getRank() > 0) {
2523 for (auto it : llvm::zip(c_shape.getShape(), x_shape.getShape())) {
2524 int c_dim = std::get<0>(it);
2525 int x_dim = std::get<1>(it);
2526 if (x_dim >= 0 && c_dim > x_dim) return failure();
2527 }
2528 }
2529 }
2530
2531 // Child input
2532 Operation *input_x = left_leaf_is_const
2533 ? child_op->getOperand(1).getDefiningOp()
2534 : child_op->getOperand(0).getDefiningOp();
2535 Operation *input_y = left_leaf_is_const
2536 ? child_op->getOperand(0).getDefiningOp()
2537 : child_op->getOperand(1).getDefiningOp();
2538 if (!input_x || !input_y) return failure();
2539
2540 Operation *input_c = const_op;
2541 Operation *input_op = child_op;
2542
2543 if (op->getOperand(0).getDefiningOp() == input_c)
2544 op->setOperand(0, input_x->getResult(0));
2545 else
2546 op->setOperand(1, input_x->getResult(0));
2547
2548 if (is_symmetric && is_child_symmetric) {
2549 // Easy case (only commutative ops). We always write this as one of
2550 // +
2551 // / \
2552 // X +
2553 // / \
2554 // C Y
2555 rewriter.startRootUpdate(op);
2556 op->setOperand(0, input_x->getResult(0));
2557 op->setOperand(1, input_op->getResult(0));
2558 rewriter.finalizeRootUpdate(op);
2559 rewriter.startRootUpdate(child_op);
2560 child_op->setOperand(0, input_c->getResult(0));
2561 child_op->setOperand(1, input_y->getResult(0));
2562 rewriter.finalizeRootUpdate(child_op);
2563 } else {
2564 // More complicated case: When there are non-commutative operations like
2565 // subtractions or divisions involved, we may have to rotate the tree
2566 // and/or change op types. There are 6 non-trivial cases depending on
2567 // the effective generalized "sign" of each of the three terms C, Y, and
2568 // X. Here are the final trees we want to generate for those 6 cases:
2569 //
2570 // (CYX signs): ++- +-- -+- --+ +-+ -++
2571 //
2572 // - - - - + +
2573 // / \ / \ / \ / \ / \ / \
2574 // + X - X - X X + X - X -
2575 // / \ / \ / \ / \ / \ / \
2576 // C Y C Y Y C Y C C Y Y C
2577 //
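      // For example, the first column (signs ++-) corresponds to rewriting
      // C + (Y - X) as (C + Y) - X.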
2578
2579 // First, let's determine the effective sign of each term in the original
2580 // expression
2581 auto is_leaf_negated = [&](const bool is_right_leaf) -> bool {
2582 bool leaf_negated = !is_child_symmetric && is_right_leaf;
2583 bool child_negated = !is_symmetric && left_child_is_const;
2584 return leaf_negated != child_negated;
2585 };
2586
2587 StringRef symmetric_op = (is_add || is_sub) ? "tfg.Add" : "tfg.Mul";
2588 StringRef nonsymmetric_op = (is_add || is_sub) ? "tfg.Sub" : "tfg.Div";
2589 bool neg_c = !is_symmetric && !left_child_is_const;
2590 bool neg_x = is_leaf_negated(left_leaf_is_const);
2591 bool neg_y = is_leaf_negated(!left_leaf_is_const);
2592
2593 StringRef op_name =
2594 (neg_x || (neg_c && neg_y)) ? nonsymmetric_op : symmetric_op;
2595 OperationState state(op->getLoc(), op_name);
2596 state.addOperands({input_op->getResult(0), input_x->getResult(0)});
2597 if (!neg_x) std::swap(state.operands[0], state.operands[1]);
2598 state.addOperands(TFOp(op).getControlOperands());
2599 state.attributes = op->getAttrDictionary();
2600 state.addTypes(op->getResultTypes());
2601 Operation *new_op = rewriter.create(state);
2602 rewriter.replaceOp(op, new_op->getResults());
2603
2604 StringRef child_name = neg_c != neg_y ? nonsymmetric_op : symmetric_op;
2605 OperationState new_child_state(child_op->getLoc(), child_name);
2606 new_child_state.addOperands(
2607 {input_y->getResult(0), input_c->getResult(0)});
2608 if (!neg_c)
2609 std::swap(new_child_state.operands[0], new_child_state.operands[1]);
2610 new_child_state.addOperands(TFOp(child_op).getControlOperands());
2611 new_child_state.attributes = child_op->getAttrDictionary();
2612 new_child_state.addTypes(child_op->getResultTypes());
2613 rewriter.setInsertionPoint(child_op);
2614 Operation *new_child_op = rewriter.create(new_child_state);
2615 rewriter.replaceOp(child_op, new_child_op->getResults());
2616 }
2617 return success();
2618 }
2619 };
2620
2621 // This implementation is mapped with
2622 // ConstantFolding::PartialConstPropThroughIdentityN in
2623 // grappler/optimizers/constant_folding.cc
2624 class PartialConstPropThroughIdentityN
2625 : public PropagationPatternBase<PartialConstPropThroughIdentityN> {
2626 public:
2627   explicit PartialConstPropThroughIdentityN(OpPropertyHelper &helper)
2628 : PropagationPatternBase<PartialConstPropThroughIdentityN>(
2629 MatchAnyOpTypeTag(), helper) {}
2630   LogicalResult matchAndRewrite(Operation *op,
2631 PatternRewriter &rewriter) const override {
2632     // Grappler's constant folding propagates the values from IdentityN. Here,
2633     // we check each operand that is defined by an Identity/IdentityN.
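    // For illustration (hypothetical names): if %idn = IdentityN(%const, %x)
    // and this op consumes %idn#0 where %const is a Const, the operand is
    // rewritten to use %const directly and the IdentityN's control result is
    // added as a control operand below.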
2634
2635 SmallVector<Value> control_operands;
2636 for (OpOperand &operand : op->getOpOperands()) {
2637 Value v = operand.get();
2638 if (v.getType().isa<ControlType>()) break;
2639
2640 Operation *v_op = v.getDefiningOp();
2641 if (!v_op || !dialect_->IsIdentityN(v_op) ||
2642 dialect_->IsIdentityNSingleInput(v_op)) {
2643 continue;
2644 }
2645
2646 int res_index = v.cast<OpResult>().getResultNumber();
2647 Value value_to_forward = v_op->getOperand(res_index);
2648 if (!value_to_forward.getDefiningOp() ||
2649 !dialect_->IsConstant(value_to_forward.getDefiningOp())) {
2650 continue;
2651 }
2652
2653 rewriter.startRootUpdate(op);
2654 operand.set(value_to_forward);
2655 rewriter.finalizeRootUpdate(op);
2656
2657 // Add a control dependency on the Identity/IdentityN. Note that it's
2658 // possible to have multiple operands defined by the same
2659 // Identity/IdentityN. Given that the number is small and this propagation
2660 // is usually applied to an operation only once, do a linear scan before
2661 // insertion.
2662 Value control = TFOp(v_op).controlRet();
2663 if (!llvm::is_contained(control_operands, control))
2664 control_operands.push_back(control);
2665 }
2666
2667 // No new control operands implies that we didn't find constants that can be
2668 // propagated through Identity/IdentityN.
2669 if (control_operands.empty()) return failure();
2670
2671 OperationState state(op->getLoc(), op->getName());
2672 state.attributes = op->getAttrDictionary();
2673 state.addOperands(op->getOperands());
2674 // Append the newly added control operands from Identity/IdentityN.
2675 state.addOperands(control_operands);
2676 state.addTypes(op->getResultTypes());
2677
2678 Operation *new_op = rewriter.create(state);
2679 rewriter.replaceOp(op, new_op->getResults());
2680
2681 return success();
2682 }
2683 };
2684
2685 // This implementation is mapped with
2686 // ConstantFolding::PartialAssocOpConstFolding in
2687 // grappler/optimizers/constant_folding.cc
2688 class PartialAssocOpConstFolding
2689 : public FolderPatternBase<PartialAssocOpConstFolding> {
2690 public:
2691 explicit PartialAssocOpConstFolding(OpPropertyHelper &helper)
2692 : FolderPatternBase<PartialAssocOpConstFolding>(MatchAnyOpTypeTag(),
2693 helper) {}
2694 LogicalResult matchAndRewrite(Operation *op,
2695 PatternRewriter &rewriter) const override {
2696 // Partial constant folding for associative operators:
2697 // Split AddN/AccumulateNV2 to enable partial
2698 // folding of ops when more than one but not all inputs are constant.
2699 // For AddN and AccumulateNV2, we may furthermore reorder inputs, since
2700 // addition is commutative.
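// For example, AddN(c1, x, c2, y) with constants c1 and c2 becomes
// AddN(AddN(c1, c2), x, y), so the inner AddN can later be folded into a
// single constant.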
2701 if (!helper_.IsAggregate(op) || !helper_.IsCommutative(op))
2702 return failure();
2703
2704 SmallVector<Value> const_inputs;
2705 SmallVector<Value> non_const_inputs;
2706
2707 auto [non_control_operands, control_operands] = TFOp(op).splitOperands();
2708 int non_control_inputs_size = non_control_operands.size();
2709 if (non_control_inputs_size <= 2) return failure();
2710
2711 if (llvm::any_of(non_control_operands, [](Value v) {
2712 Operation *v_op = v.getDefiningOp();
2713 return v_op &&
2714 TFOp(v_op).name().rfind("_partial_split_") != StringRef::npos;
2715 })) {
2716 return failure();
2717 }
2718
2719 for (Value operand : non_control_operands) {
2720 Operation *may_const_op = operand.getDefiningOp();
2721 if (may_const_op && dialect_->IsConstant(may_const_op))
2722 const_inputs.push_back(operand);
2723 else
2724 non_const_inputs.push_back(operand);
2725 }
2726
2727 if (const_inputs.size() == non_control_inputs_size &&
2728 op->getName().stripDialect() == "AccumulateNV2") {
2729 OperationState state(op->getLoc(), "tfg.AddN");
2730 state.addTypes(op->getResultTypes());
2731 state.addOperands(op->getOperands());
2732 state.attributes = op->getAttrDictionary();
2733 state.attributes.erase("shape");
2734 Operation *add_n = rewriter.create(state);
2735 rewriter.replaceOp(op, add_n->getResults());
2736 return success();
2737 }
2738
2739 if (const_inputs.size() <= 1) return failure();
2740
2741 OperationState state(op->getLoc(), "tfg.AddN");
2742 state.addOperands(const_inputs);
2743 state.addTypes(op->getResultTypes());
2744 state.attributes = op->getAttrDictionary();
2745 state.attributes.erase("shape");
2746 state.attributes.set("N", IntegerAttr::get(rewriter.getIntegerType(32),
2747 const_inputs.size()));
2748 Operation *add_n = rewriter.create(state);
2749 TFOp(add_n).setName(Twine(TFOp(op).name(), "/_partial_split_") +
2750 std::to_string(const_inputs.size()));
2751 // The new op inherits all the attrs of op, so there is no need to update the device attr.
2752
2753 OperationState new_op_state(op->getLoc(), op->getName());
2754 // Note that grappler puts the new AddN at the position of the first constant
2755 // operand; here we always put it at the beginning.
2756 new_op_state.addOperands(add_n->getResult(0));
2757 new_op_state.addOperands(non_const_inputs);
2758 new_op_state.addOperands(control_operands);
2759 new_op_state.addTypes(op->getResultTypes());
2760 new_op_state.attributes = op->getAttrDictionary();
2761 new_op_state.attributes.set("N",
2762 IntegerAttr::get(rewriter.getIntegerType(32),
2763 non_const_inputs.size() + 1));
2764
2765 Operation *new_op = rewriter.create(new_op_state);
2766 rewriter.replaceOp(op, new_op->getResults());
2767
2768 return success();
2769 }
2770 };
2771
2772 // This implementation is mapped with ConstantFolding::MergeConcat in
2773 // grappler/optimizers/constant_folding.cc
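// For example, ConcatV2(ConcatV2(a, b, axis), c, axis) is merged into
// ConcatV2(a, b, c, axis) when both concats use the same constant axis.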
2774 class MergeConcatOp : public FolderPatternBase<MergeConcatOp> {
2775 public:
2776 explicit MergeConcatOp(OpPropertyHelper &helper)
2777 : FolderPatternBase<MergeConcatOp>("tfg.ConcatV2", helper) {}
2778 LogicalResult matchAndRewrite(Operation *op,
2779 PatternRewriter &rewriter) const override {
2780 if (helper_.ShouldPreserveOp(op)) return failure();
2781
2782 auto getAxis = [&](Operation *axis_op) {
2783 ElementsAttr axis_attr = axis_op->getAttrOfType<ElementsAttr>("value");
2784 return axis_attr.getElementType().isInteger(64)
2785 ? static_cast<int>(axis_attr.getSplatValue<int64_t>())
2786 : axis_attr.getSplatValue<int>();
2787 };
2788
2789 auto [non_control_operands, control_operands] = TFOp(op).splitOperands();
2790 Operation *axis_op = non_control_operands.back().getDefiningOp();
2791 if (!axis_op || !dialect_->IsConstant(axis_op)) return failure();
2792 int axis = getAxis(axis_op);
2793
2794 // In grappler, it checks whether the first user of a ConcatV2 is also a
2795 // ConcatV2. Here, we check the operands of the op instead. Another difference
2796 // is that grappler only checks the first user, while we check all operands.
2797 Operation *concat_operand = nullptr;
2798 for (Value operand : non_control_operands) {
2799 Operation *defining_op = operand.getDefiningOp();
2800 if (defining_op && dialect_->IsConcatV2(defining_op)) {
2801 concat_operand = defining_op;
2802 break;
2803 }
2804 }
2805 if (!concat_operand) return failure();
2806
2807 auto [concat_non_control_operands, concat_control_operands] =
2808 TFOp(concat_operand).splitOperands();
2809 Operation *concat_operand_axis_op =
2810 concat_non_control_operands.back().getDefiningOp();
2811 if (!concat_operand_axis_op ||
2812 !dialect_->IsConstant(concat_operand_axis_op)) {
2813 return failure();
2814 }
2815 if (axis != getAxis(concat_operand_axis_op)) return failure();
2816
2817 // If all inputs are constant, don't merge and let EvaluateConstant take
2818 // care of it.
2819 if (llvm::all_of(concat_non_control_operands.drop_back(), [&](Value v) {
2820 return v.getDefiningOp() && dialect_->IsConstant(v.getDefiningOp());
2821 })) {
2822 return failure();
2823 }
2824
2825 // Make a pass over the parent inputs to see if any of them have explicit
2826 // device() fields set, and if different inputs are on different tasks. If
2827 // so, this concat of concats may have been carefully constructed to be a
2828 // two-stage concat, and we don't want to undo that here.
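// For example, if one input lives on "/job:worker/task:0" and another on
// "/job:worker/task:1", the concats are likely a deliberate two-stage concat
// across tasks, so the merge is skipped.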
2829 std::string task, device;
2830 std::string unique_input_tasks;
2831 for (Value v : non_control_operands) {
2832 Operation *v_op = v.getDefiningOp();
2833 if (!v_op || v_op == axis_op) continue;
2834 StringRef op_device = TFOp(v_op).device();
2835 if (!op_device.empty() && tensorflow::DeviceNameUtils::SplitDeviceName(
2836 op_device.str(), &task, &device)) {
2837 if (unique_input_tasks.empty())
2838 unique_input_tasks = task;
2839 else if (unique_input_tasks != task)
2840 return failure();
2841 }
2842 }
2843
2844 OperationState state(op->getLoc(), "tfg.ConcatV2");
2845 for (Value operand : non_control_operands) {
2846 if (operand == concat_operand->getResult(0)) {
2847 // Inline the non-control operands of concat_operand.
2848 state.addOperands(ValueRange(concat_non_control_operands.drop_back()));
2849 } else {
2850 state.addOperands(operand);
2851 }
2852 }
2853 // Copy the control operands.
2854 state.addOperands(control_operands);
2855 state.addOperands(concat_control_operands);
2856 state.addTypes(op->getResultTypes());
2857 state.attributes = op->getAttrDictionary();
2858 state.attributes.set("N", IntegerAttr::get(rewriter.getIntegerType(32),
2859 state.operands.size() - control_operands.size() - concat_control_operands.size() - 1));
2860 Operation *concat_op = rewriter.create(state);
2861 rewriter.replaceOp(op, concat_op->getResults());
2862
2863 return success();
2864 }
2865 };
2866
2867 // This implementation is mapped with ConstantFolding::MulConvPushDown
2868 // in grappler/optimizers/constant_folding.cc
2869 class MulConvPushDown : public ConstantPatternBase<MulConvPushDown, FolderTrait,
2870 PropagationTrait> {
2871 public:
2872 explicit MulConvPushDown(OpPropertyHelper &helper)
2873 : ConstantPatternBase(MatchAnyOpTypeTag(), helper) {}
2874 LogicalResult matchAndRewrite(Operation *op,
2875 PatternRewriter &rewriter) const override {
2876 // Push down multiplication on ConvND.
2877 // * ConvND
2878 // / \ / \
2879 // ConvND C2 -- > X *
2880 // / \ / \
2881 // X C1 C1 C2
2882 //
2883 // where C1 and C2 are constants and X is non-constant.
2884 if (!dialect_->IsAnyMul(op)) return failure();
2885
2886 Operation *mul_left_child = op->getOperand(0).getDefiningOp();
2887 Operation *mul_right_child = op->getOperand(1).getDefiningOp();
2888 if (!mul_left_child || !mul_right_child) return failure();
2889
2890 const bool left_child_is_constant = dialect_->IsConstant(mul_left_child);
2891 const bool right_child_is_constant = dialect_->IsConstant(mul_right_child);
2892 // One child must be a constant and the other must be a Conv op.
2893 if (!left_child_is_constant && !right_child_is_constant) return failure();
2894
2895 Operation *conv_node =
2896 left_child_is_constant ? mul_right_child : mul_left_child;
2897 if (!dialect_->IsConv2D(conv_node) && !dialect_->IsConv3D(conv_node))
2898 return failure();
2899
2900 // Make sure that it is safe to change the value of the convolution
2901 // output.
2902 if (helper_.ShouldPreserveOp(conv_node)) return failure();
2903
2904 if (TFOp(op).deviceAttr() != TFOp(mul_left_child).deviceAttr() ||
2905 TFOp(op).deviceAttr() != TFOp(mul_right_child).deviceAttr()) {
2906 return failure();
2907 }
2908
2909 // Identify the nodes to swap.
2910 Operation *conv_left_child = conv_node->getOperand(0).getDefiningOp();
2911 Operation *conv_right_child = conv_node->getOperand(1).getDefiningOp();
2912 const bool conv_left_is_constant =
2913 conv_left_child && dialect_->IsConstant(conv_left_child);
2914 const bool conv_right_is_constant =
2915 conv_right_child && dialect_->IsConstant(conv_right_child);
2916 if (!conv_left_is_constant && !conv_right_is_constant) {
2917 // At least one of the convolution inputs should be constant.
2918 return failure();
2919 }
2920
2921 if (conv_left_is_constant && conv_right_is_constant) {
2922 // Operation evaluation will handle this.
2923 return failure();
2924 }
2925
2926 ShapedType mul_shape = (*op->result_type_begin()).cast<ShapedType>();
2927 ShapedType conv_shape =
2928 (*conv_node->result_type_begin()).cast<ShapedType>();
2929 // TODO(chiahungduan): Symbolic shape equivalence is acceptable.
2930 if (!mul_shape.hasStaticShape() || !conv_shape.hasStaticShape() ||
2931 mul_shape != conv_shape) {
2932 return failure();
2933 }
2934
2935 auto filter_shape = conv_node->getOperand(1).getType().cast<ShapedType>();
2936
2937 Operation *const_node =
2938 left_child_is_constant ? mul_left_child : mul_right_child;
2939 auto const_node_shape =
2940 (*const_node->result_type_begin()).cast<ShapedType>();
2941 if (!IsValidConstShapeForMulConvPushDown(
2942 conv_node->getAttrOfType<StringAttr>("data_format"), filter_shape,
2943 const_node_shape)) {
2944 return failure();
2945 }
2946
2947 Operation *conv_const_node =
2948 conv_left_is_constant ? conv_left_child : conv_right_child;
2949 // Make sure we don't introduce loops in the graph by removing control
2950 // dependencies from the conv2d node to c2.
2951 if (Operation *new_const_op =
2952 RemoveControlOperandIfExist(rewriter, const_node, conv_node)) {
2953 rewriter.replaceOp(const_node, new_const_op->getResults());
2954 const_node = new_const_op;
2955
2956 // Add a control dep from c1 to c2 to ensure c2 is in the right frame.
2957 AddControlOperand(const_node, TFOp(conv_const_node).controlRet(),
2958 rewriter);
2959 }
2960
2961 StringRef conv_node_name = TFOp(conv_node).name();
2962
2963 rewriter.startRootUpdate(conv_node);
2964 TFOp(conv_node).setName(TFOp(op).nameAttr());
2965 if (conv_left_is_constant)
2966 conv_node->setOperand(0, op->getResult(0));
2967 else
2968 conv_node->setOperand(1, op->getResult(0));
2969 rewriter.finalizeRootUpdate(conv_node);
2970
2971 rewriter.startRootUpdate(op);
2972 TFOp(op).setName(Twine(conv_node_name, "/merged_input"));
2973 if (left_child_is_constant)
2974 op->setOperand(1, conv_const_node->getResult(0));
2975 else
2976 op->setOperand(0, conv_const_node->getResult(0));
2977 rewriter.finalizeRootUpdate(op);
2978
2979 return success();
2980 }
2981
2982 private:
2983 // Remove the control dependency from `op` to `to_remove` if any.
2984 Operation *RemoveControlOperandIfExist(OpBuilder &builder, Operation *op,
2985 Operation *to_remove) const {
2986 auto [non_control_operands, control_operands] = TFOp(op).splitOperands();
2987 Value control_to_remove = TFOp(to_remove).controlRet();
2988 SmallVector<Value> new_control_operands(control_operands);
2989 auto it = llvm::remove_if(
2990 new_control_operands,
2991 [control_to_remove](Value v) { return v == control_to_remove; });
2992 if (it == new_control_operands.end()) return nullptr;
2993 new_control_operands.erase(it, new_control_operands.end());
2994
2995 OperationState state(op->getLoc(), op->getName());
2996 state.addOperands(non_control_operands);
2997 state.addOperands(new_control_operands);
2998 state.addAttributes(op->getAttrs());
2999 state.addTypes(op->getResultTypes());
3000
3001 return builder.create(state);
3002 }
3003 };
3004
3005 // This implementation is mapped with ConstantFolding::PartialConcatConstFolding
3006 // in grappler/optimizers/constant_folding.cc
3007 class PartialConcatConstFolding
3008 : public FolderPatternBase<PartialConcatConstFolding> {
3009 public:
3010 explicit PartialConcatConstFolding(OpPropertyHelper &helper)
3011 : FolderPatternBase<PartialConcatConstFolding>(MatchAnyOpTypeTag(),
3012 helper) {}
3013 LogicalResult matchAndRewrite(Operation *op,
3014 PatternRewriter &rewriter) const override {
3015 // Partial constant folding for Concat, which is not commutative: we have to
3016 // preserve the input order and can only push consecutive runs of constant
3017 // inputs into sub-nodes.
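// For example, ConcatV2(x, c1, c2, y, axis) with constants c1 and c2 becomes
// ConcatV2(x, ConcatV2(c1, c2, axis), y, axis), so the inner ConcatV2 can
// later be folded into a single constant.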
3018 if (!dialect_->IsConcat(op)) return failure();
3019 if (TFOp(op).name().rfind("_partial_split_") != StringRef::npos) {
3020 return failure();
3021 }
3022
3023 auto [non_control_operands, control_operands] = TFOp(op).splitOperands();
3024 const int num_non_control_inputs = non_control_operands.size();
3025 if (num_non_control_inputs <= 3) return failure();
3026
3027 int axis_arg = -1;
3028 int begin = 0;
3029 int end = num_non_control_inputs;
3030 // Note that IsConcat matches both Concat and ConcatV2, so we need to check
3031 // ConcatV2 first.
3032 if (dialect_->IsConcatV2(op)) {
3033 end = num_non_control_inputs - 1;
3034 axis_arg = num_non_control_inputs - 1;
3035 } else if (dialect_->IsConcat(op)) {
3036 begin = 1;
3037 axis_arg = 0;
3038 } else {
3039 return failure();
3040 }
3041
3042 // We search for consecutive runs of constant inputs in the range
3043 // [begin:end] and push them down into child nodes.
3044 SmallVector<std::pair<int, int>> constant_input_runs;
3045 int first = begin;
3046 int last = begin;
3047 while (last < end) {
3048 while (first < end) {
3049 Operation *v_op = op->getOperand(first).getDefiningOp();
3050 if (v_op && dialect_->IsConstant(v_op)) break;
3051 ++first;
3052 }
3053
3054 // Invariant: node[first] is constant || first >= end.
3055 last = first + 1;
3056 while (last < end) {
3057 Operation *v_op = op->getOperand(last).getDefiningOp();
3058 if (!v_op || !dialect_->IsConstant(v_op)) break;
3059 ++last;
3060 }
3061
3062 // Invariant: node[last] is not constant || last >= end
3063 // Discard intervals shorter than 2 elements.
3064 if (first < end && (last - first) > 1)
3065 constant_input_runs.emplace_back(first, last);
3066 first = last;
3067 }
3068
3069 // Skip if all inputs are constant, and let constant folding take over.
3070 if (constant_input_runs.empty() || (constant_input_runs.size() == 1 &&
3071 constant_input_runs[0].first == begin &&
3072 constant_input_runs[0].second == end)) {
3073 return failure();
3074 }
3075
3076 // TODO(chiahungduan): The optimization can be applied multiple times. Find
3077 // a better way to name the new ops without creating duplicate names. For now
3078 // we only apply it once.
3079 if (llvm::any_of(non_control_operands, [](Value v) {
3080 Operation *v_op = v.getDefiningOp();
3081 return v_op &&
3082 TFOp(v_op).name().rfind("_partial_split_") != StringRef::npos;
3083 })) {
3084 return failure();
3085 }
3086
3087 DenseSet<int> inputs_to_delete;
3088 for (auto interval : constant_input_runs) {
3089 // Push the constant inputs in the interval to a child node that can be
3090 // constant folded.
3091 OperationState state(op->getLoc(), "tfg.ConcatV2");
3092 state.addOperands(op->getOperand(interval.first));
3093 for (auto i : llvm::seq<int>(interval.first + 1, interval.second)) {
3094 state.addOperands(op->getOperand(i));
3095 inputs_to_delete.insert(i);
3096 }
3097 state.addOperands(op->getOperand(axis_arg));
3098 state.attributes = op->getAttrDictionary();
3099 state.attributes.set("N",
3100 IntegerAttr::get(rewriter.getI32Type(),
3101 interval.second - interval.first));
3102 state.addTypes(op->getResultTypes());
3103
3104 Operation *new_op = rewriter.create(state);
3105 TFOp(new_op).setName(Twine(TFOp(op).name(), "/_partial_split_") +
3106 std::to_string(interval.first));
3107 // The new op inherits all the attrs of op, so there is no need to update the device attr.
3108
3109 // Overwrite the first constant input with the result of the added
3110 // child node.
3111 rewriter.startRootUpdate(op);
3112 op->setOperand(interval.first, new_op->getResult(0));
3113 rewriter.finalizeRootUpdate(op);
3114 }
3115
3116 if (!inputs_to_delete.empty()) {
3117 OperationState state(op->getLoc(), op->getName());
3118 for (auto &it : llvm::enumerate(non_control_operands)) {
3119 if (inputs_to_delete.contains(it.index())) continue;
3120 state.addOperands(it.value());
3121 }
3122 assert(state.operands.size() != non_control_operands.size());
3123 state.addOperands(control_operands);
3124
3125 state.attributes = op->getAttrDictionary();
3126 state.attributes.set(
3127 "N", IntegerAttr::get(
3128 rewriter.getI32Type(),
3129 state.operands.size() - control_operands.size() - 1));
3130 state.addTypes(op->getResultTypes());
3131 Operation *new_op = rewriter.create(state);
3132 rewriter.replaceOp(op, new_op->getResults());
3133 }
3134
3135 return success();
3136 }
3137 };
3138
3139 // This implements constant push-down for BiasAdd. In the following "CV" is a
3140 // constant vector (tensor of rank 1), "V" is a (possibly) non-constant vector,
3141 // "CM" is a matrix (tensor of rank >= 2), "M" is a (possibly)
3142 // non-constant matrix, and "BA" is BiasAdd.
3143 // For a valid input graph, the following 4 rewrites are legal:
3144 //
3145 // 1) + +
3146 // / \ / \
3147 // BA CV -- > BA V
3148 // / \ / \
3149 // M V M CV
3150 //
3151 // 2) + +
3152 // / \ / \
3153 // BA CM -- > BA M
3154 // / \ / \
3155 // M V CM V
3156 //
3157 // 3) BA BA
3158 // / \ / \
3159 // + CV -- > + V
3160 // / \ / \
3161 // M V M CV
3162 //
3163 // 4) BA BA = parent
3164 // / \ / \
3165 // BA CV -- > BA V = children
3166 // / \ / \
3167 // M V M CV = leaves
3168 //
3169 // Cases 1 through 3 have additional sub-cases due to the symmetry of Add.
3170 class ConstantPushDownBiasAdd
3171 : public ConstantPushDownBase<ConstantPushDownBiasAdd> {
3172 public:
3173 explicit ConstantPushDownBiasAdd(OpPropertyHelper &helper)
3174 : ConstantPushDownBase(MatchAnyOpTypeTag(), helper) {}
3175 LogicalResult matchAndRewrite(Operation *op,
3176 PatternRewriter &rewriter) const override {
3177 if (!dialect_->IsBiasAdd(op)) return failure();
3178
3179 Operation *add_child = op->getOperand(0).getDefiningOp();
3180 if (!add_child) return failure();
3181
3182 Operation *const_child = op->getOperand(1).getDefiningOp();
3183 if (!const_child || !dialect_->IsConstant(const_child)) return failure();
3184
3185 if (helper_.ShouldPreserveOp(add_child)) return failure();
3186
3187 // Special case for BiasAdd: Since the left argument to BiasAdd must be rank
3188 // >= 2 and the leaves must be vectors, we cannot swap them.
3189 if (dialect_->IsConstant(add_child)) return failure();
3190 if (!dialect_->IsBiasAdd(add_child) && !dialect_->IsAdd(add_child))
3191 return failure();
3192
3193 if (!IsOperandsSafeToMove(add_child, const_child)) return failure();
3194
3195 auto hasRank = [&](Value value) {
3196 return value.getType().cast<ShapedType>().hasRank();
3197 };
3198
3199 if (!hasRank(op->getOperand(0)) || !hasRank(op->getOperand(1)) ||
3200 !hasRank(add_child->getOperand(0)) ||
3201 !hasRank(add_child->getOperand(1))) {
3202 return failure();
3203 }
3204
3205 // Now get the ranks and types of the 3 leaf nodes.
3206 const int left_leaf_rank =
3207 add_child->getOperand(0).getType().cast<ShapedType>().getRank();
3208 const int right_leaf_rank =
3209 add_child->getOperand(1).getType().cast<ShapedType>().getRank();
3210
3211 // At least one leaf must be a vector.
3212 if (left_leaf_rank != 1 && right_leaf_rank != 1) return failure();
3213
3214 const int vector_idx = left_leaf_rank == 1 ? 0 : 1;
3215 auto vector_type =
3216 add_child->getOperand(vector_idx).getType().cast<ShapedType>();
3217 Type vector_d_type = vector_type.getElementType();
3218
3219 auto const_type = const_child->getResultTypes()[0].cast<ShapedType>();
3220 const int const_rank = const_type.getRank();
3221 Type const_d_type = const_type.getElementType();
3222
3223 if (const_rank != 1 || const_d_type != vector_d_type) return failure();
3224
3225 // This covers cases #1, #3, and #4:
3226 int input_to_swap = vector_idx;
3227
3228 Value leaf_to_swap = add_child->getOperand(input_to_swap);
3229 if (leaf_to_swap.getDefiningOp() &&
3230 dialect_->IsConstant(leaf_to_swap.getDefiningOp())) {
3231 return failure();
3232 }
3233
3234 rewriter.startRootUpdate(op);
3235 op->setOperand(1, leaf_to_swap);
3236 rewriter.finalizeRootUpdate(op);
3237 rewriter.startRootUpdate(add_child);
3238 add_child->setOperand(input_to_swap, const_child->getResult(0));
3239 rewriter.finalizeRootUpdate(add_child);
3240
3241 return success();
3242 }
3243 };
3244
3245 // This implements constant push-down for Add. In the following "CV" is a
3246 // constant vector (tensor of rank 1), "V" is a (possibly) non-constant vector,
3247 // "CM" is a matrix (tensor of rank >= 2), "M" is a (possibly)
3248 // non-constant matrix, and "BA" is BiasAdd.
3249 // For a valid input graph, the following 4 rewrites are legal:
3250 //
3251 // 1) + +
3252 // / \ / \
3253 // BA CV -- > BA V
3254 // / \ / \
3255 // M V M CV
3256 //
3257 // 2) + +
3258 // / \ / \
3259 // BA CM -- > BA M
3260 // / \ / \
3261 // M V CM V
3262 //
3263 // 3) BA BA
3264 // / \ / \
3265 // + CV -- > + V
3266 // / \ / \
3267 // M V M CV
3268 //
3269 // 4) BA BA = parent
3270 // / \ / \
3271 // BA CV -- > BA V = children
3272 // / \ / \
3273 // M V M CV = leaves
3274 //
3275 // Cases 1 through 3 have additional sub-cases due to the symmetry of Add.
3276 class ConstantPushDownAdd : public ConstantPushDownBase<ConstantPushDownAdd> {
3277 public:
3278 explicit ConstantPushDownAdd(OpPropertyHelper &helper)
3279 : ConstantPushDownBase(MatchAnyOpTypeTag(), helper) {}
3280 LogicalResult matchAndRewrite(Operation *op,
3281 PatternRewriter &rewriter) const override {
3282 if (!dialect_->IsAdd(op)) return failure();
3283
3284 Operation *add_child = op->getOperand(0).getDefiningOp();
3285 Operation *const_child = op->getOperand(1).getDefiningOp();
3286 if (!add_child || !const_child) return failure();
3287
3288 if (!dialect_->IsConstant(const_child)) std::swap(add_child, const_child);
3289 if (!dialect_->IsConstant(const_child)) return failure();
3290
3291 if (!IsOperandsSafeToMove(add_child, const_child)) return failure();
3292
3293 bool child_is_bias_add = dialect_->IsBiasAdd(add_child);
3294 if (!child_is_bias_add && !dialect_->IsAdd(add_child)) return failure();
3295
3296 auto hasRank = [&](Value value) {
3297 return value.getType().cast<ShapedType>().hasRank();
3298 };
3299
3300 if (!hasRank(op->getOperand(0)) || !hasRank(op->getOperand(1)) ||
3301 !hasRank(add_child->getOperand(0)) ||
3302 !hasRank(add_child->getOperand(1))) {
3303 return failure();
3304 }
3305
3306 // Now get the ranks and types of the 3 leaf nodes.
3307 const int left_leaf_rank =
3308 add_child->getOperand(0).getType().cast<ShapedType>().getRank();
3309 const int right_leaf_rank =
3310 add_child->getOperand(1).getType().cast<ShapedType>().getRank();
3311 // At least one leaf must be a vector.
3312 if (left_leaf_rank != 1 && right_leaf_rank != 1) return failure();
3313
3314 const int vector_idx = left_leaf_rank == 1 ? 0 : 1;
3315 const int matrix_idx = 1 - vector_idx;
3316
3317 ShapedType vector_type =
3318 add_child->getOperand(vector_idx).getType().cast<ShapedType>();
3319 Type vector_d_type = vector_type.getElementType();
3320
3321 ShapedType matrix_type =
3322 add_child->getOperand(matrix_idx).getType().cast<ShapedType>();
3323 const int matrix_rank = matrix_type.getRank();
3324 Type matrix_d_type = matrix_type.getElementType();
3325
3326 const int const_index =
3327 op->getOperand(0).getDefiningOp() == const_child ? 0 : 1;
3328 ShapedType const_type =
3329 const_child->getResult(0).getType().cast<ShapedType>();
3330 const int const_rank = const_type.getRank();
3331 Type const_d_type = const_type.getElementType();
3332
3333 int input_to_swap = -1;
3334
3335 if (child_is_bias_add && const_rank == matrix_rank &&
3336 const_d_type == matrix_d_type) {
3337 // Case 2:
3338 input_to_swap = matrix_idx;
3339 } else if (const_rank == 1 && const_d_type == vector_d_type) {
3340 // Cases 1, 3, and 4:
3341 input_to_swap = vector_idx;
3342 } else {
3343 return failure();
3344 }
3345
3346 Value leaf_to_swap = add_child->getOperand(input_to_swap);
3347 if (leaf_to_swap.getDefiningOp() &&
3348 dialect_->IsConstant(leaf_to_swap.getDefiningOp())) {
3349 return failure();
3350 }
3351
3352 rewriter.startRootUpdate(op);
3353 op->setOperand(const_index, leaf_to_swap);
3354 rewriter.finalizeRootUpdate(op);
3355 rewriter.startRootUpdate(add_child);
3356 add_child->setOperand(input_to_swap, const_child->getResult(0));
3357 rewriter.finalizeRootUpdate(add_child);
3358
3359 return success();
3360 }
3361 };
3362
3363 // This implementation is mapped with ConstantFolding::SimplifyCase in
3364 // grappler/optimizers/constant_folding.cc
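// For example, Case(branch_index, args...) with a constant branch_index of 2
// and branches = [f0, f1, f2] is rewritten to PartitionedCall(args..., f = f2).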
3365 class SimplifyCaseOp : public FolderPatternBase<SimplifyCaseOp> {
3366 public:
3367 explicit SimplifyCaseOp(OpPropertyHelper &helper)
3368 : FolderPatternBase<SimplifyCaseOp>("tfg.Case", helper) {}
3369 LogicalResult matchAndRewrite(Operation *op,
3370 PatternRewriter &rewriter) const override {
3371 Operation *branch_index_op = op->getOperand(0).getDefiningOp();
3372 if (!branch_index_op) return failure();
3373
3374 ElementsAttr value_attr =
3375 branch_index_op->getAttrOfType<ElementsAttr>("value");
3376 if (!value_attr) return failure();
3377
3378 int output_idx = value_attr.getSplatValue<int>();
3379 ArrayAttr branch_attr = op->getAttrOfType<ArrayAttr>("branches");
3380 if (output_idx < 0 || output_idx >= branch_attr.size()) return failure();
3381
3382 OperationState state(op->getLoc(), "tfg.PartitionedCall");
3383 state.addOperands(ValueRange(op->getOperands()).drop_front());
3384
3385 state.attributes = op->getAttrDictionary();
3386 state.attributes.erase("branches");
3387 // In TFG canonical form, `output_shapes` has been consolidated into the op's
3388 // shape. Unlike grappler, we don't need to update the `output_shapes` attr
3389 // here.
3390 state.attributes.set("f", branch_attr[output_idx]);
3391
3392 state.addTypes(op->getResultTypes());
3393
3394 Operation *partitioned_call_op = rewriter.create(state);
3395 rewriter.replaceOp(op, partitioned_call_op->getResults());
3396
3397 return success();
3398 }
3399 };
3400
3401 // This implementation is mapped with ConstantFolding::SimplifySelect in
3402 // grappler/optimizers/constant_folding.cc
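// For example, Select(pred, t, e) with an all-true constant predicate becomes
// Identity(t), with the dead operands kept as control dependencies; if the
// output shape differs from the live operand's shape, the live operand is
// expanded with BroadcastTo instead.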
3403 template <typename ConcreteType>
3404 class SimplifySelectOpBase : public FolderPatternBase<ConcreteType> {
3405 protected:
3406 SimplifySelectOpBase(StringRef op_name, OpPropertyHelper &helper)
3407 : FolderPatternBase<ConcreteType>(op_name, helper) {}
3408
3409 LogicalResult matchAndRewrite(Operation *op,
3410 PatternRewriter &rewriter) const override {
3411 Operation *condition_op = op->getOperand(0).getDefiningOp();
3412 if (!condition_op) return failure();
3413
3414 bool is_all_true = this->helper_.IsOnes(condition_op);
3415 bool is_all_false = this->helper_.IsZeros(condition_op);
3416 if (!is_all_true && !is_all_false) return failure();
3417
3418 auto condition_type = op->getOperand(0).getType().cast<ShapedType>();
3419 auto t_type = op->getOperand(1).getType().cast<ShapedType>();
3420 auto e_type = op->getOperand(2).getType().cast<ShapedType>();
3421 if (!condition_type.hasStaticShape() || !t_type.hasStaticShape() ||
3422 !e_type.hasStaticShape()) {
3423 return failure();
3424 }
3425
3426 const int live_input_idx = is_all_true ? 1 : 2;
3427 bool predicate_is_scalar = condition_type.getRank() == 0;
3428
3429 if (t_type.getShape() == e_type.getShape() &&
3430 (condition_type.getShape() == t_type.getShape() ||
3431 predicate_is_scalar)) {
3432 Value live_operand = op->getOperand(live_input_idx);
3433 OperationState state(op->getLoc(), "tfg.Identity");
3434 state.addTypes(op->getResultTypes());
3435
3436 state.addOperands(live_operand);
3437 auto [non_control_operands, control_operands] = TFOp(op).splitOperands();
3438 for (Value operand : non_control_operands) {
3439 if (operand == live_operand) continue;
3440 // Add the remaining operands as control operands.
3441 state.addOperands(GetControlDependency(rewriter, operand));
3442 }
3443 // Append control operands
3444 state.addOperands(control_operands);
3445
3446 state.attributes = op->getAttrDictionary();
3447 Operation *identity = rewriter.create(state);
3448 rewriter.replaceOp(op, identity->getResults());
3449 } else {
3450 FailureOr<TFOp> broadcast_to_op =
3451 ReplaceOperationWithBroadcastTo(rewriter, op, live_input_idx);
3452 if (failed(broadcast_to_op)) return failure();
3453 rewriter.replaceOp(op, (*broadcast_to_op)->getResults());
3454 }
3455
3456 return success();
3457 }
3458 };
3459
3460 class SimplifySelectOp : public SimplifySelectOpBase<SimplifySelectOp> {
3461 public:
3462 explicit SimplifySelectOp(OpPropertyHelper &helper)
3463 : SimplifySelectOpBase("tfg.Select", helper) {}
3464 };
3465
3466 class SimplifySelectV2Op : public SimplifySelectOpBase<SimplifySelectV2Op> {
3467 public:
3468 explicit SimplifySelectV2Op(OpPropertyHelper &helper)
3469 : SimplifySelectOpBase("tfg.SelectV2", helper) {}
3470 };
3471
3472 namespace {
3473
3474 // Utilities for filtering desired patterns.
3475 template <bool>
3476 struct FilterPattern {
3477 template <class Pattern>
3478 using type = std::tuple<Pattern>;
3479 };
3480 template <>
3481 struct FilterPattern<false> {
3482 template <class Pattern>
3483 using type = std::tuple<>;
3484 };
3485 template <template <class> class Pred, class... Patterns>
3486 struct FilterPatterns {
3487 using type = decltype(std::tuple_cat(
3488 std::declval<typename FilterPattern<Pred<Patterns>::value>::template type<
3489 Patterns>>()...));
3490 };
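// For example, FilterPatterns<FolderPatterns, A, B, C>::type is
// std::tuple<A, C> when only A and C derive from FolderTrait.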
3491
3492 // Predicates for selecting the pattern kind.
3493 template <typename Pattern>
3494 using FolderPatterns = std::is_base_of<FolderTrait<Pattern>, Pattern>;
3495 template <typename Pattern>
3496 using PropagationPatterns = std::is_base_of<PropagationTrait<Pattern>, Pattern>;
3497 template <typename Pattern>
3498 using AllPatterns = std::true_type;
3499
3500 // Registers a set of patterns.
3501 template <typename... Patterns>
3502 struct TargetPatterns;
3503 template <typename... Patterns>
3504 struct TargetPatterns<std::tuple<Patterns...>> {
3505 static void Register(::mlir::RewritePatternSet &patterns,
3506 OpPropertyHelper &helper) {
3507 patterns.insert<Patterns...>(helper);
3508 }
3509 };
3510 template <template <class> class PatternsFilter>
3511 void RegisterPatterns(::mlir::RewritePatternSet &patterns,
3512 OpPropertyHelper &helper) {
3513 TargetPatterns<typename FilterPatterns<
3514 PatternsFilter, MaterializeBroadcastGradientArgsOp, MaterializeShapeNOp,
3515 SimplifySwitchOp, MergeNodeFolding, RefMergeNodeFolding,
3516 XlaMergeNodeFolding, MoveConstantsPastEnterOp,
3517 MoveConstantsPastRefEnterOp, MaterializeReductionIndices,
3518 PartialConstPropThroughIdentityN, ConstantPushDown, MulConvPushDown,
3519 ConstantPushDownBiasAdd, ConstantPushDownAdd, EvaluateConstant,
3520 PartialConcatConstFolding, PartialAssocOpConstFolding,
3521 SimplifyArithmeticOp, ReduceDivToReciprocalMul, SimplifyReshapeOp,
3522 RemoveReverse, SimplifyStridedSlice, SimplifyTileOp, SimplifySqueezeOp,
3523 SimplifySliceOp, RemoveTransposeOp, RemoveRandomShuffleOp,
3524 RemoveShuffleOp, SimplifyPackOp, SimplifyReductionOp, SimplifyPadOp,
3525 SimplifyPadV2Op, RemoveSplitOp, RemoveSplitVOp, MaterializeFillNode,
3526 MaterializeConstantValuedNode, MaterializeShapeOp, MaterializeRankOp,
3527 MaterializeSizeOp, MaterializeTensorArraySizeV3Op, MergeConcatOp,
3528 SimplifyCaseOp, SimplifySelectOp,
3529 SimplifySelectV2Op>::type>::Register(patterns, helper);
3530 }
3531 } // namespace
3532
3533 class ConstantFolding : public ConstantFoldingPassBase<ConstantFolding> {
3534 public:
3535 LogicalResult initialize(MLIRContext *context) override {
3536 helper_ = std::make_shared<OpPropertyHelper>(
3537 context->getOrLoadDialect<TFGraphDialect>(),
3538 disable_compressed_tensor_optimization_);
3539 RewritePatternSet patterns(context);
3540 populatePatterns(patterns);
3541 final_patterns_ = std::move(patterns);
3542 return success();
3543 }
3544
3545 void runOnOperation() override;
3546
3547 private:
3548 void populatePatterns(::mlir::RewritePatternSet &patterns) {
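// pattern_category_ selects which patterns to run: 0 = all patterns,
// 1 = folding patterns only, 2 = propagation patterns only.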
3549 switch (pattern_category_) {
3550 default:
3551 LOG(ERROR) << "unknown pattern category, will run all patterns";
3552 [[fallthrough]];
3553 case 0: {
3554 RegisterPatterns<AllPatterns>(patterns, *helper_);
3555 break;
3556 }
3557 case 1: {
3558 RegisterPatterns<FolderPatterns>(patterns, *helper_);
3559 break;
3560 }
3561 case 2: {
3562 RegisterPatterns<PropagationPatterns>(patterns, *helper_);
3563 break;
3564 }
3565 }
3566 }
3567
3568 FrozenRewritePatternSet final_patterns_;
3569 std::shared_ptr<OpPropertyHelper> helper_;
3570 };
3571
3572 void ConstantFolding::runOnOperation() {
3573 // TODO(chiahungduan): Set up the attributes before operation creation.
3574 // For convenience, in some cases we set up the device/name after operation
3575 // creation.
3576
3577 GraphFuncOp func = getOperation();
3578 Operation *return_op = func.getBody()->getTerminator();
3579 DenseSet<Operation *> unfoldable_ops;
3580 for (Value v : return_op->getOperands())
3581 unfoldable_ops.insert(v.getDefiningOp());
3582
3583 // The iteration limit is the same as the default maximum iteration count in
3584 // applyPatternsAndFoldGreedily.
3585 constexpr int max_iterations = 10;
3586 int iteration = 0;
3587
3588 SmallVector<Operation *> foldable_ops;
3589 do {
3590 // We need to collect the valid operations before each run because the ops
3591 // may be updated or removed.
3592 foldable_ops.clear();
3593 for (Operation &op : func.getBody()->without_terminator()) {
3594 if (unfoldable_ops.contains(&op)) continue;
3595 foldable_ops.push_back(&op);
3596 }
3597
3598 // Unfoldable ops can't be folded: their operands may be updated, but the op
3599 // kind needs to stay the same. For example, an operand of an AddOp may be
3600 // replaced with a constant, but the AddOp can't be folded into a ConstOp even
3601 // if all its operands are constants. Therefore, we can't use
3602 // applyPatternsAndFoldGreedily, which may optimize the ops as much as
3603 // possible.
3604 if (!applyOpPatternsAndFold(foldable_ops, final_patterns_, /*strict=*/true))
3605 break;
3606 } while (iteration++ < max_iterations);
3607
3608 // TODO(chiahungduan): This attribute is used to avoid evaluating a node
3609 // multiple times. See more details in the EvaluateConstant pattern. Maybe we
3610 // can remove this by checking whether an op has any users left.
3611 auto has_folded = StringAttr::get(&getContext(), "has_folded");
3612 getOperation()->walk([&](Operation *op) { op->removeAttr(has_folded); });
3613 }
3614
3615 std::unique_ptr<Pass> CreateConstantFoldingPass() {
3616 return std::make_unique<ConstantFolding>();
3617 }
3618
3619 } // namespace tfg
3620 } // namespace mlir
3621