1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/mlir/lite/tf_tfl_passes.h"
17
18 #include <string>
19
20 #include "llvm/ADT/Optional.h"
21 #include "llvm/ADT/SmallVector.h"
22 #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
23 #include "mlir/IR/Attributes.h" // from @llvm-project
24 #include "mlir/IR/BuiltinOps.h" // from @llvm-project
25 #include "mlir/Pass/Pass.h" // from @llvm-project
26 #include "mlir/Pass/PassManager.h" // from @llvm-project
27 #include "mlir/Transforms/Passes.h" // from @llvm-project
28 #include "tensorflow/compiler/mlir/lite/quantization/quantization_config.h"
29 #include "tensorflow/compiler/mlir/lite/quantization/quantization_passes.h"
30 #include "tensorflow/compiler/mlir/lite/quantization/tensorflow/passes.h"
31 #include "tensorflow/compiler/mlir/lite/transforms/passes.h"
32 #include "tensorflow/compiler/mlir/lite/utils/fake_quant_utils.h"
33 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.h"
34 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h"
35 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
36 #include "tensorflow/compiler/mlir/tensorflow/transforms/tf_saved_model_passes.h"
37 #include "tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h"
38 #include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h"
39
40 namespace mlir {
41 /// Create a pass to convert from the TFExecutor to the TF control dialect.
42 std::unique_ptr<OperationPass<func::FuncOp>>
43 CreateTFExecutorToControlDialectConversion();
44 } // namespace mlir
45
46 namespace tensorflow {
namespace {
// Data layout supported by TFLite. Used below to force the layout
// optimization pipeline to transpose data into NHWC.
const char kTFLiteDataLayout[] = "NHWC";
}  // namespace
51
AddQuantizationPasses(const mlir::quant::QuantizationSpecs & quant_specs,mlir::OpPassManager & pass_manager)52 void AddQuantizationPasses(const mlir::quant::QuantizationSpecs& quant_specs,
53 mlir::OpPassManager& pass_manager) {
54 pass_manager.addNestedPass<mlir::func::FuncOp>(
55 mlir::TFL::CreatePrepareQuantizePass(quant_specs));
56 if (quant_specs.default_ranges.first.has_value() ||
57 quant_specs.default_ranges.second.has_value()) {
58 pass_manager.addNestedPass<mlir::func::FuncOp>(
59 mlir::TFL::CreateDefaultQuantParamsPass(
60 quant_specs.default_ranges.first.getValueOr(0.0),
61 quant_specs.default_ranges.second.getValueOr(0.0),
62 quant_specs.IsSignedInferenceType()));
63 }
64 pass_manager.addNestedPass<mlir::func::FuncOp>(
65 mlir::TFL::CreateQuantizePass(quant_specs));
66 bool emit_quant_adaptor_ops =
67 quant_specs.inference_type != quant_specs.inference_input_type;
68 pass_manager.addNestedPass<mlir::func::FuncOp>(
69 mlir::TFL::CreatePostQuantizePass(emit_quant_adaptor_ops));
70 pass_manager.addNestedPass<mlir::func::FuncOp>(
71 mlir::TFL::CreateOptimizeOpOrderPass());
72 // Add optimization pass after quantization for additional fusing
73 // opportunities.
74 pass_manager.addNestedPass<mlir::func::FuncOp>(
75 mlir::TFL::CreateOptimizePass(/*enable_canonicalization=*/true));
76 }
77
AddDynamicRangeQuantizationPasses(const mlir::quant::QuantizationSpecs & quant_specs,mlir::OpPassManager & pass_manager)78 void AddDynamicRangeQuantizationPasses(
79 const mlir::quant::QuantizationSpecs& quant_specs,
80 mlir::OpPassManager& pass_manager) {
81 pass_manager.addNestedPass<mlir::func::FuncOp>(
82 mlir::TFL::CreatePrepareDynamicRangeQuantizePass(quant_specs));
83 pass_manager.addNestedPass<mlir::func::FuncOp>(
84 mlir::TFL::CreateQuantizePass(quant_specs));
85 bool emit_quant_adaptor_ops =
86 quant_specs.inference_type != quant_specs.inference_input_type;
87 pass_manager.addNestedPass<mlir::func::FuncOp>(
88 mlir::TFL::CreatePostQuantizePass(emit_quant_adaptor_ops,
89 quant_specs.custom_map));
90 pass_manager.addNestedPass<mlir::func::FuncOp>(
91 mlir::TFL::CreateOptimizeOpOrderPass());
92 // Add optimization pass after quantization for additional fusing
93 // opportunities.
94 pass_manager.addNestedPass<mlir::func::FuncOp>(
95 mlir::TFL::CreateOptimizePass(/*enable_canonicalization=*/true));
96 }
97
AddConvertHloToTfPass(std::string entry_function_name,mlir::OpPassManager * pass_manager)98 void AddConvertHloToTfPass(std::string entry_function_name,
99 mlir::OpPassManager* pass_manager) {
100 // Legalize jax random to tflite custom op.
101 // The CreateLegalizeJaxRandom Pass has to stay at because we need to replace
102 // the random function body before being inlined.
103 pass_manager->addNestedPass<mlir::func::FuncOp>(
104 mlir::TFL::CreateLegalizeJaxRandomPass());
105
106 // Canonicalize, CSE etc.
107 pass_manager->addNestedPass<mlir::func::FuncOp>(
108 mlir::createCanonicalizerPass());
109 pass_manager->addNestedPass<mlir::func::FuncOp>(mlir::createCSEPass());
110 // DCE for private symbols.
111 pass_manager->addPass(mlir::createSymbolDCEPass());
112
113 pass_manager->addPass(mlir::TF::CreateStripNoinlineAttributePass());
114 // Add inline pass.
115 pass_manager->addPass(mlir::createInlinerPass());
116
117 // Expands mhlo.tuple ops.
118 pass_manager->addPass(
119 mlir::mhlo::createExpandHloTuplesPass(entry_function_name));
120 // Flatten tuples for control flows.
121 pass_manager->addNestedPass<mlir::func::FuncOp>(
122 mlir::mhlo::createFlattenTuplePass());
123
124 // TF dialect passes
125 pass_manager->addNestedPass<mlir::func::FuncOp>(
126 mlir::TF::CreateLegalizeHloToTfPass());
127
128 // Canonicalization after TF legalization.
129 pass_manager->addNestedPass<mlir::func::FuncOp>(
130 mlir::createCanonicalizerPass());
131 }
132
133 // This is the early part of the conversion in isolation. This enables a caller
134 // to inject more information in the middle of the conversion before resuming
135 // it.
AddPreVariableFreezingTFToTFLConversionPasses(const mlir::TFL::PassConfig & pass_config,mlir::OpPassManager * pass_manager)136 void AddPreVariableFreezingTFToTFLConversionPasses(
137 const mlir::TFL::PassConfig& pass_config,
138 mlir::OpPassManager* pass_manager) {
139 if (pass_config.enable_hlo_to_tf_conversion) {
140 // TODO(b/194747383): We need to valid that indeed the "main" func is
141 // presented.
142 AddConvertHloToTfPass("main", pass_manager);
143 }
144 // This pass wraps all the tf.FakeQuant ops in a custom op so they are not
145 // folded before being converted to tfl.quantize and tfl.dequantize ops.
146 auto wrapped_ops = mlir::TFL::AllTfFakeQuantOps();
147 pass_manager->addNestedPass<mlir::func::FuncOp>(
148 mlir::TFL::CreateRaiseCustomOpsPass(wrapped_ops));
149
150 mlir::TF::StandardPipelineOptions standard_pipeline_options;
151 standard_pipeline_options.enable_inliner = false;
152 standard_pipeline_options.form_clusters = pass_config.form_clusters;
153 mlir::TF::CreateTFStandardPipeline(*pass_manager, standard_pipeline_options);
154 pass_manager->addNestedPass<mlir::func::FuncOp>(
155 mlir::TF::CreateDeviceIndexSelectorPass());
156
157 // Add canonicalize pass to remove no-op session initializer pass.
158 pass_manager->addPass(mlir::createCanonicalizerPass());
159
160 if (pass_config.guarantee_all_funcs_one_use) {
161 pass_manager->addPass(mlir::TF::CreateGuaranteeAllFuncsOneUsePass());
162 }
163 if (pass_config.shape_inference) {
164 pass_manager->addPass(mlir::TF::CreateTFShapeInferencePass());
165 }
166
167 // Keep this pass after the shape inference pass, which couldn't do shape
168 // inference for non-tf ops.
169 if (!pass_config.quant_specs.serialized_quant_stats.empty()) {
170 pass_manager->addNestedPass<mlir::func::FuncOp>(
171 mlir::quant::CreateImportQuantStatsPassForTFControlDialect(
172 pass_config.quant_specs.serialized_quant_stats));
173 }
174
175 pass_manager->addPass(mlir::TF::CreateTFFunctionalControlFlowToRegions());
176
177 // The conversion pipeline has to follow the following orders:
178 // 1) Saved model related optimization like decompose resource ops
179 // 2) Convert composite functions like lstm/rnns, along with proper function
180 // inlining & dce.
181 // 3) Lower static tensor list pass.
182
183 // This decomposes resource ops like ResourceGather into read-variable op
184 // followed by gather. This is used when the saved model import path is used
185 // during which resources dont get frozen in the python layer.
186 pass_manager->addNestedPass<mlir::func::FuncOp>(
187 mlir::TFDevice::CreateDecomposeResourceOpsPass());
188
189 pass_manager->addPass(mlir::TF::CreateTFRegionControlFlowToFunctional());
190 }
191
192 // This is the later part of the conversion in isolation. This enables a caller
193 // to resume the conversion after injecting more information in the middle of
194 // it.
// Appends the second half of the TF->TFLite conversion pipeline: everything
// that runs after variables have been frozen. The pass order here is
// deliberate; see the inline comments before changing it.
void AddPostVariableFreezingTFToTFLConversionPasses(
    llvm::StringRef saved_model_dir, const toco::TocoFlags& toco_flags,
    const mlir::TFL::PassConfig& pass_config,
    mlir::OpPassManager* pass_manager) {
  // Note:
  // We need to fuse composite ops before LowerStaticTensorList pass.
  // The tensorflow list is not supported right now by that pass.
  // Enable fusing composite ops that can be lowered to built-in TFLite ops.
  if (pass_config.emit_builtin_tflite_ops &&
      toco_flags.tf_quantization_mode().empty()) {
    pass_manager->addPass(mlir::TFL::CreatePrepareCompositeFunctionsPass());
  }

  // Inline function calls and remove dead private symbols before lowering.
  pass_manager->addPass(mlir::createInlinerPass());
  pass_manager->addPass(mlir::createSymbolDCEPass());

  if (pass_config.lower_tensor_list_ops &&
      toco_flags.tf_quantization_mode().empty()) {
    // TODO(haoliang): Add this pass by default.
    pass_manager->addPass(mlir::TFL::CreateLowerStaticTensorListPass(
        /*allow_tensorlist_pass_through=*/toco_flags.force_select_tf_ops() ||
            toco_flags.enable_select_tf_ops(),
        /*default_to_single_batch=*/
        toco_flags.default_to_single_batch_in_tensor_list_ops(),
        /*enable_dynamic_update_slice=*/
        toco_flags.enable_dynamic_update_slice()));
  }

  // This pass does resource analysis of saved model global tensors and marks
  // those deemed read-only as immutable.
  pass_manager->addPass(
      mlir::tf_saved_model::CreateOptimizeGlobalTensorsPass());

  if (pass_config.shape_inference) {
    // Add a shape inference pass to optimize away the unnecessary casts.
    pass_manager->addPass(mlir::TF::CreateTFShapeInferencePass());
  }

  // Legalize while early to allow further constant folding.
  // TODO(jpienaar): This may not actually matter as we do canonicalization
  // after the legalize below, for now it needs to be below the above passes
  // that work on TF dialect and before inliner so that the function calls in
  // body and cond are inlined for optimization.
  pass_manager->addPass(mlir::TFL::CreateLegalizeTFWhilePass());

  // Add function inlining pass. Both TF and TFLite dialects are opted into
  // function inliner interface.
  pass_manager->addPass(mlir::createInlinerPass());
  // Reduce operands of TFL::While without changing the outcome.
  // It needs to stay here because:
  // 1. WhileOps are in TFL dialect.
  // 2. The body and cond are inlined.
  // 3. We need to do this before while canonicalization, otherwise it would be
  //    difficult to find dependencies.
  pass_manager->addNestedPass<mlir::func::FuncOp>(
      mlir::TFL::CreateReduceWhileOperandsPass());
  // Canonicalization includes const folding, which is utilized here to optimize
  // away ops that can't get constant folded after PrepareTF pass. For example,
  // tf.Conv2D is split into tf.Transpose and tfl.Conv2D.
  pass_manager->addNestedPass<mlir::func::FuncOp>(
      mlir::createCanonicalizerPass());
  pass_manager->addNestedPass<mlir::func::FuncOp>(mlir::createCSEPass());
  // This pass does dead code elimination based on symbol visibility.
  pass_manager->addPass(mlir::createSymbolDCEPass());

  if (!pass_config.disable_variable_freezing) {
    // This pass 'freezes' immutable global tensors and inlines them as tf
    // constant ops.
    pass_manager->addPass(mlir::tf_saved_model::CreateFreezeGlobalTensorsPass(
        /*allow_mutable_tensors=*/pass_config.enable_tflite_variables));
  }

  if (!saved_model_dir.empty()) {
    // This pass 'freezes' tf saved model asset ops and inlines as string values
    // in a format of the tf constant op.
    pass_manager->addPass(
        mlir::tf_saved_model::CreateFreezeAssetsPass(saved_model_dir.str()));
  }
  // For TF Quantization, convert unsupported ops to Flex ops before other
  // conversion passes.
  if (!toco_flags.tf_quantization_mode().empty()) {
    pass_manager->addNestedPass<mlir::func::FuncOp>(
        mlir::TF::CreateFallbackToFlexOpsPass(
            toco_flags.tf_quantization_mode()));
  }
  // The below passes only make sense if Builtin TFLite ops are enabled
  // for emission.
  if (pass_config.emit_builtin_tflite_ops) {
    // Run shape inference after variables are converted to constants.
    if (pass_config.shape_inference) {
      pass_manager->addPass(mlir::TF::CreateTFShapeInferencePass());
    }
    // Force layout supported by TFLite, this will transpose the data
    // to match 'kTFLiteDataLayout' (NHWC).
    mlir::TF::LayoutOptimizationPipelineOptions layout_optimization_options;
    layout_optimization_options.force_data_format = kTFLiteDataLayout;
    layout_optimization_options.skip_fold_transpose_in_ops = true;
    mlir::TF::CreateLayoutOptimizationPipeline(
        pass_manager->nest<mlir::func::FuncOp>(), layout_optimization_options);
    // Prepare for TFLite dialect, rerun canonicalization, and then legalize to
    // the TFLite dialect.
    pass_manager->addNestedPass<mlir::func::FuncOp>(
        mlir::TFL::CreatePrepareTFPass(pass_config.unfold_batch_matmul,
                                       /*allow_bf16_and_f16_type_legalization=*/
                                       !pass_config.runtime_verification,
                                       toco_flags.use_fake_quant_num_bits()));
    pass_manager->addNestedPass<mlir::func::FuncOp>(
        mlir::createCanonicalizerPass());
    if (pass_config.shape_inference) {
      // Add a shape inference pass to optimize away the unnecessary casts.
      // This also fixes the unranked shapes due to TF ops constant folding.
      // TODO(fengliuai): remove this pass if TableGen patterns have a better
      // to control the shapes for the intermediate results.
      pass_manager->addPass(mlir::TF::CreateTFShapeInferencePass());
    }

    // Inline function calls that left in the graph after folding functional
    // control flow ops (IfOp, CaseOp).
    pass_manager->addPass(mlir::createInlinerPass());

    // This pass removes the asset file dependencies in hash table use cases.
    pass_manager->addNestedPass<mlir::func::FuncOp>(
        mlir::TF::CreateInitTextFileToImportPass(saved_model_dir.str()));

    // Legalize remaining TF ops to the TFLite dialect, then handle variables
    // and hash tables.
    pass_manager->addNestedPass<mlir::func::FuncOp>(
        mlir::TFL::CreateLegalizeTFPass(pass_config.runtime_verification,
                                        pass_config.preserve_assert_op));
    pass_manager->addPass(mlir::TFL::CreateAnalyzeVariablesPass());
    pass_manager->addPass(mlir::TFL::CreateLegalizeVariablesPass());
    pass_manager->addPass(mlir::TFL::CreateLegalizeHashTablesPass());
    pass_manager->addNestedPass<mlir::func::FuncOp>(
        mlir::TFL::CreateOptimizePass(/*enable_canonicalization=*/true));
    // This pass operates on TensorFlow ops but is triggered after legalization
    // so that it can target constants introduced once TensorFlow Identity ops
    // are removed during legalization.
    pass_manager->addPass(mlir::TFL::CreateOptimizeFunctionalOpsPass());
    std::vector<std::string> empty_wrapped_ops({});
    pass_manager->addNestedPass<mlir::func::FuncOp>(
        mlir::TFL::CreateRaiseCustomOpsPass(empty_wrapped_ops));
    pass_manager->addPass(mlir::createSymbolDCEPass());
    pass_manager->addNestedPass<mlir::func::FuncOp>(
        mlir::createCanonicalizerPass());
    pass_manager->addNestedPass<mlir::func::FuncOp>(mlir::createCSEPass());

    // Run quantization after all the floating point model conversion is
    // completed. Add either full integer quantization or dynamic range
    // quantization passes based on quant_specs.
    if (pass_config.quant_specs.RunPropagationAndRewriteQuantizationPasses()) {
      AddQuantizationPasses(pass_config.quant_specs, *pass_manager);
      // Remove unnecessary QDQs while handling QAT models.
      pass_manager->addNestedPass<mlir::func::FuncOp>(
          mlir::TFL::CreatePostQuantizeRemoveQDQPass());
    } else if (pass_config.quant_specs
                   .RunAndRewriteDynamicRangeQuantizationPasses()) {
      AddDynamicRangeQuantizationPasses(pass_config.quant_specs, *pass_manager);
    }
    pass_manager->addPass(mlir::createCanonicalizerPass());

    // This pass should be always at the end of the model
    // conversion (even after quantization). Some TFL ops like unidirectional
    // sequence lstm will have stateful operands and some optimization passes
    // will merge those operands if they have identical values & types. However,
    // it's not desired by TFL. This pass serves as a "fix" pass to split the
    // merged inputs until we have 1st class variable support or reuse
    // tf.variable to model this.
    pass_manager->addNestedPass<mlir::func::FuncOp>(
        mlir::TFL::CreateSplitMergedOperandsPass());

    // Add CallOnceOp when there is a session initializer function in tf saved
    // model dialect.
    pass_manager->addPass(
        mlir::TFL::CreateInsertCallOnceOpFromSessionInitializerPass());
  }
  if (pass_config.unfold_large_splat_constant) {
    pass_manager->addPass(mlir::TFL::CreateUnfoldLargeSplatConstantPass());
  }
  if (pass_config.outline_tf_while) {
    pass_manager->addPass(mlir::TFL::CreateWhileOutlinePass());
  }
  if (pass_config.runtime_verification) {
    pass_manager->addNestedPass<mlir::func::FuncOp>(
        mlir::TFL::CreateRuntimeVerifyPass());
  }
}
379
AddTFToTFLConversionPasses(llvm::StringRef saved_model_dir,const toco::TocoFlags & toco_flags,const mlir::TFL::PassConfig & pass_config,mlir::OpPassManager * pass_manager)380 void AddTFToTFLConversionPasses(llvm::StringRef saved_model_dir,
381 const toco::TocoFlags& toco_flags,
382 const mlir::TFL::PassConfig& pass_config,
383 mlir::OpPassManager* pass_manager) {
384 AddPreVariableFreezingTFToTFLConversionPasses(pass_config, pass_manager);
385 AddPostVariableFreezingTFToTFLConversionPasses(saved_model_dir, toco_flags,
386 pass_config, pass_manager);
387 }
AddTFToTFLConversionPasses(const mlir::TFL::PassConfig & pass_config,mlir::OpPassManager * pass_manager)388 void AddTFToTFLConversionPasses(const mlir::TFL::PassConfig& pass_config,
389 mlir::OpPassManager* pass_manager) {
390 const toco::TocoFlags toco_flags;
391 AddTFToTFLConversionPasses(/*saved_model_dir=*/"", toco_flags, pass_config,
392 pass_manager);
393 }
394
395 } // namespace tensorflow
396
397 namespace mlir {
398 namespace TFL {
399
// Options for the "tfl-standard-pipeline" registered below. Currently empty;
// pipeline behavior is still controlled by external flags.
struct StandardPipelineOptions
    : public PassPipelineOptions<StandardPipelineOptions> {
  // TODO(b/150915052): All the tf_tfl_translate_cl flags should
  // move inside this.
};
405
406 // NOLINTNEXTLINE
407 // This creates the standard pass pipeline for TF->TFLite. This
408 // represents a std configuration for TFLite, for use with APIs like
409 // tensorflow/python/pywrap_mlir.py::experimental_run_pass_pipeline
410 // This does not yet include quantization passes.
CreateTFLStandardPipeline(OpPassManager & pm,const StandardPipelineOptions & options)411 void CreateTFLStandardPipeline(OpPassManager& pm,
412 const StandardPipelineOptions& options) {
413 OpPassManager& func_pm = pm.nest<func::FuncOp>();
414
415 // tf_executor dialect passes - Cleaning up the IR.
416 mlir::TF::StandardPipelineOptions standard_pipeline_options;
417 mlir::TF::CreateTFStandardPipeline(func_pm, standard_pipeline_options);
418
419 // This is needed for control flow support with TF TensorList.
420 pm.addPass(mlir::TFL::CreateLowerStaticTensorListPass(
421 /*allow_tensorlist_pass_through=*/false,
422 /*default_to_single_batch=*/false,
423 /*enable_dynamic_update_slice=*/false));
424
425 // Saved model pass to mark global tensors immutable.
426 pm.addPass(mlir::tf_saved_model::CreateOptimizeGlobalTensorsPass());
427 // Op fusion pass.
428 pm.addPass(mlir::TFL::CreatePrepareCompositeFunctionsPass());
429
430 pm.addNestedPass<mlir::func::FuncOp>(mlir::TFL::CreateLegalizeTFWhilePass());
431
432 pm.addPass(mlir::createInlinerPass());
433
434 // Canonicalize, CSE etc.
435 pm.addNestedPass<mlir::func::FuncOp>(mlir::createCanonicalizerPass());
436 pm.addNestedPass<mlir::func::FuncOp>(mlir::createCSEPass());
437 // DCE for private symbols.
438 pm.addPass(mlir::createSymbolDCEPass());
439
440 // freeze global tensors.
441 pm.addPass(mlir::tf_saved_model::CreateFreezeGlobalTensorsPass());
442
443 // TFLite dialect passes.
444 pm.addPass(mlir::TFL::CreatePrepareTFPass(
445 /*unfold_batch_matmul=*/true,
446 /*allow_bf16_and_f16_type_legalization=*/false));
447 pm.addNestedPass<mlir::func::FuncOp>(mlir::createCanonicalizerPass());
448 pm.addPass(
449 mlir::TFL::CreateLegalizeTFPass(/*run_tfl_runtime_verification=*/true,
450 /*preserve_assert_op=*/false));
451 pm.addPass(mlir::TFL::CreateLegalizeHashTablesPass());
452 pm.addPass(mlir::TFL::CreateOptimizePass(/*enable_canonicalization=*/true));
453 pm.addPass(mlir::TFL::CreateOptimizeFunctionalOpsPass());
454 pm.addPass(mlir::createSymbolDCEPass());
455
456 // Canonicalize, CSE etc.
457 pm.addNestedPass<mlir::func::FuncOp>(mlir::createCanonicalizerPass());
458 pm.addNestedPass<mlir::tf_saved_model::SessionInitializerOp>(
459 mlir::createCanonicalizerPass());
460 pm.addNestedPass<mlir::func::FuncOp>(mlir::createCSEPass());
461
462 // Pass for stateful operands like LSTM.
463 pm.addPass(mlir::TFL::CreateSplitMergedOperandsPass());
464
465 pm.addPass(mlir::TFL::CreateWhileOutlinePass());
466
467 pm.addNestedPass<mlir::func::FuncOp>(mlir::TFL::CreateRuntimeVerifyPass());
468 }
469
// Registers the standard TFL pass pipeline under the name
// "tfl-standard-pipeline" so it can be invoked by name from pass-pipeline
// tooling.
static mlir::PassPipelineRegistration<StandardPipelineOptions> pipeline(
    "tfl-standard-pipeline",
    "Run the standard passes involved in transforming/optimizing the TF "
    "program to TFLite after "
    "importing into MLIR.",
    CreateTFLStandardPipeline);
477
478 } // namespace TFL
479 } // namespace mlir
480