xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/mlir/lite/tf_tfl_passes.h"
17 
18 #include <string>
19 
20 #include "llvm/ADT/Optional.h"
21 #include "llvm/ADT/SmallVector.h"
22 #include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
23 #include "mlir/IR/Attributes.h"  // from @llvm-project
24 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
25 #include "mlir/Pass/Pass.h"  // from @llvm-project
26 #include "mlir/Pass/PassManager.h"  // from @llvm-project
27 #include "mlir/Transforms/Passes.h"  // from @llvm-project
28 #include "tensorflow/compiler/mlir/lite/quantization/quantization_config.h"
29 #include "tensorflow/compiler/mlir/lite/quantization/quantization_passes.h"
30 #include "tensorflow/compiler/mlir/lite/quantization/tensorflow/passes.h"
31 #include "tensorflow/compiler/mlir/lite/transforms/passes.h"
32 #include "tensorflow/compiler/mlir/lite/utils/fake_quant_utils.h"
33 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.h"
34 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h"
35 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
36 #include "tensorflow/compiler/mlir/tensorflow/transforms/tf_saved_model_passes.h"
37 #include "tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h"
38 #include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h"
39 
40 namespace mlir {
41 /// Create a pass to convert from the TFExecutor to the TF control dialect.
42 std::unique_ptr<OperationPass<func::FuncOp>>
43 CreateTFExecutorToControlDialectConversion();
44 }  // namespace mlir
45 
46 namespace tensorflow {
namespace {
// Data layout supported by TFLite. The layout optimization pipeline is forced
// to this format before ops are legalized to the TFLite dialect.
constexpr char kTFLiteDataLayout[] = "NHWC";
}  // namespace
51 
AddQuantizationPasses(const mlir::quant::QuantizationSpecs & quant_specs,mlir::OpPassManager & pass_manager)52 void AddQuantizationPasses(const mlir::quant::QuantizationSpecs& quant_specs,
53                            mlir::OpPassManager& pass_manager) {
54   pass_manager.addNestedPass<mlir::func::FuncOp>(
55       mlir::TFL::CreatePrepareQuantizePass(quant_specs));
56   if (quant_specs.default_ranges.first.has_value() ||
57       quant_specs.default_ranges.second.has_value()) {
58     pass_manager.addNestedPass<mlir::func::FuncOp>(
59         mlir::TFL::CreateDefaultQuantParamsPass(
60             quant_specs.default_ranges.first.getValueOr(0.0),
61             quant_specs.default_ranges.second.getValueOr(0.0),
62             quant_specs.IsSignedInferenceType()));
63   }
64   pass_manager.addNestedPass<mlir::func::FuncOp>(
65       mlir::TFL::CreateQuantizePass(quant_specs));
66   bool emit_quant_adaptor_ops =
67       quant_specs.inference_type != quant_specs.inference_input_type;
68   pass_manager.addNestedPass<mlir::func::FuncOp>(
69       mlir::TFL::CreatePostQuantizePass(emit_quant_adaptor_ops));
70   pass_manager.addNestedPass<mlir::func::FuncOp>(
71       mlir::TFL::CreateOptimizeOpOrderPass());
72   // Add optimization pass after quantization for additional fusing
73   // opportunities.
74   pass_manager.addNestedPass<mlir::func::FuncOp>(
75       mlir::TFL::CreateOptimizePass(/*enable_canonicalization=*/true));
76 }
77 
AddDynamicRangeQuantizationPasses(const mlir::quant::QuantizationSpecs & quant_specs,mlir::OpPassManager & pass_manager)78 void AddDynamicRangeQuantizationPasses(
79     const mlir::quant::QuantizationSpecs& quant_specs,
80     mlir::OpPassManager& pass_manager) {
81   pass_manager.addNestedPass<mlir::func::FuncOp>(
82       mlir::TFL::CreatePrepareDynamicRangeQuantizePass(quant_specs));
83   pass_manager.addNestedPass<mlir::func::FuncOp>(
84       mlir::TFL::CreateQuantizePass(quant_specs));
85   bool emit_quant_adaptor_ops =
86       quant_specs.inference_type != quant_specs.inference_input_type;
87   pass_manager.addNestedPass<mlir::func::FuncOp>(
88       mlir::TFL::CreatePostQuantizePass(emit_quant_adaptor_ops,
89                                         quant_specs.custom_map));
90   pass_manager.addNestedPass<mlir::func::FuncOp>(
91       mlir::TFL::CreateOptimizeOpOrderPass());
92   // Add optimization pass after quantization for additional fusing
93   // opportunities.
94   pass_manager.addNestedPass<mlir::func::FuncOp>(
95       mlir::TFL::CreateOptimizePass(/*enable_canonicalization=*/true));
96 }
97 
AddConvertHloToTfPass(std::string entry_function_name,mlir::OpPassManager * pass_manager)98 void AddConvertHloToTfPass(std::string entry_function_name,
99                            mlir::OpPassManager* pass_manager) {
100   // Legalize jax random to tflite custom op.
101   // The CreateLegalizeJaxRandom Pass has to stay at because we need to replace
102   // the random function body before being inlined.
103   pass_manager->addNestedPass<mlir::func::FuncOp>(
104       mlir::TFL::CreateLegalizeJaxRandomPass());
105 
106   // Canonicalize, CSE etc.
107   pass_manager->addNestedPass<mlir::func::FuncOp>(
108       mlir::createCanonicalizerPass());
109   pass_manager->addNestedPass<mlir::func::FuncOp>(mlir::createCSEPass());
110   // DCE for private symbols.
111   pass_manager->addPass(mlir::createSymbolDCEPass());
112 
113   pass_manager->addPass(mlir::TF::CreateStripNoinlineAttributePass());
114   // Add inline pass.
115   pass_manager->addPass(mlir::createInlinerPass());
116 
117   // Expands mhlo.tuple ops.
118   pass_manager->addPass(
119       mlir::mhlo::createExpandHloTuplesPass(entry_function_name));
120   // Flatten tuples for control flows.
121   pass_manager->addNestedPass<mlir::func::FuncOp>(
122       mlir::mhlo::createFlattenTuplePass());
123 
124   // TF dialect passes
125   pass_manager->addNestedPass<mlir::func::FuncOp>(
126       mlir::TF::CreateLegalizeHloToTfPass());
127 
128   // Canonicalization after TF legalization.
129   pass_manager->addNestedPass<mlir::func::FuncOp>(
130       mlir::createCanonicalizerPass());
131 }
132 
133 // This is the early part of the conversion in isolation. This enables a caller
134 // to inject more information in the middle of the conversion before resuming
135 // it.
AddPreVariableFreezingTFToTFLConversionPasses(const mlir::TFL::PassConfig & pass_config,mlir::OpPassManager * pass_manager)136 void AddPreVariableFreezingTFToTFLConversionPasses(
137     const mlir::TFL::PassConfig& pass_config,
138     mlir::OpPassManager* pass_manager) {
139   if (pass_config.enable_hlo_to_tf_conversion) {
140     // TODO(b/194747383): We need to valid that indeed the "main" func is
141     // presented.
142     AddConvertHloToTfPass("main", pass_manager);
143   }
144   // This pass wraps all the tf.FakeQuant ops in a custom op so they are not
145   // folded before being converted to tfl.quantize and tfl.dequantize ops.
146   auto wrapped_ops = mlir::TFL::AllTfFakeQuantOps();
147   pass_manager->addNestedPass<mlir::func::FuncOp>(
148       mlir::TFL::CreateRaiseCustomOpsPass(wrapped_ops));
149 
150   mlir::TF::StandardPipelineOptions standard_pipeline_options;
151   standard_pipeline_options.enable_inliner = false;
152   standard_pipeline_options.form_clusters = pass_config.form_clusters;
153   mlir::TF::CreateTFStandardPipeline(*pass_manager, standard_pipeline_options);
154   pass_manager->addNestedPass<mlir::func::FuncOp>(
155       mlir::TF::CreateDeviceIndexSelectorPass());
156 
157   // Add canonicalize pass to remove no-op session initializer pass.
158   pass_manager->addPass(mlir::createCanonicalizerPass());
159 
160   if (pass_config.guarantee_all_funcs_one_use) {
161     pass_manager->addPass(mlir::TF::CreateGuaranteeAllFuncsOneUsePass());
162   }
163   if (pass_config.shape_inference) {
164     pass_manager->addPass(mlir::TF::CreateTFShapeInferencePass());
165   }
166 
167   // Keep this pass after the shape inference pass, which couldn't do shape
168   // inference for non-tf ops.
169   if (!pass_config.quant_specs.serialized_quant_stats.empty()) {
170     pass_manager->addNestedPass<mlir::func::FuncOp>(
171         mlir::quant::CreateImportQuantStatsPassForTFControlDialect(
172             pass_config.quant_specs.serialized_quant_stats));
173   }
174 
175   pass_manager->addPass(mlir::TF::CreateTFFunctionalControlFlowToRegions());
176 
177   // The conversion pipeline has to follow the following orders:
178   // 1) Saved model related optimization like decompose resource ops
179   // 2) Convert composite functions like lstm/rnns, along with proper function
180   // inlining & dce.
181   // 3) Lower static tensor list pass.
182 
183   // This decomposes resource ops like ResourceGather into read-variable op
184   // followed by gather. This is used when the saved model import path is used
185   // during which resources dont get frozen in the python layer.
186   pass_manager->addNestedPass<mlir::func::FuncOp>(
187       mlir::TFDevice::CreateDecomposeResourceOpsPass());
188 
189   pass_manager->addPass(mlir::TF::CreateTFRegionControlFlowToFunctional());
190 }
191 
192 // This is the later part of the conversion in isolation. This enables a caller
193 // to resume the conversion after injecting more information in the middle of
194 // it.
AddPostVariableFreezingTFToTFLConversionPasses(llvm::StringRef saved_model_dir,const toco::TocoFlags & toco_flags,const mlir::TFL::PassConfig & pass_config,mlir::OpPassManager * pass_manager)195 void AddPostVariableFreezingTFToTFLConversionPasses(
196     llvm::StringRef saved_model_dir, const toco::TocoFlags& toco_flags,
197     const mlir::TFL::PassConfig& pass_config,
198     mlir::OpPassManager* pass_manager) {
199   // Note:
200   // We need to fuse composite ops before LowerStaticTensorList pass.
201   // The tensorflow list is not supported right now by that pass.
202   // Enable fusing composite ops that can be lowered to built-in TFLite ops.
203   if (pass_config.emit_builtin_tflite_ops &&
204       toco_flags.tf_quantization_mode().empty()) {
205     pass_manager->addPass(mlir::TFL::CreatePrepareCompositeFunctionsPass());
206   }
207 
208   pass_manager->addPass(mlir::createInlinerPass());
209   pass_manager->addPass(mlir::createSymbolDCEPass());
210 
211   if (pass_config.lower_tensor_list_ops &&
212       toco_flags.tf_quantization_mode().empty()) {
213     // TODO(haoliang): Add this pass by default.
214     pass_manager->addPass(mlir::TFL::CreateLowerStaticTensorListPass(
215         /*allow_tensorlist_pass_through=*/toco_flags.force_select_tf_ops() ||
216             toco_flags.enable_select_tf_ops(),
217         /*default_to_single_batch=*/
218         toco_flags.default_to_single_batch_in_tensor_list_ops(),
219         /*enable_dynamic_update_slice=*/
220         toco_flags.enable_dynamic_update_slice()));
221   }
222 
223   // This pass does resource analysis of saved model global tensors and marks
224   // those deemed read-only as immutable.
225   pass_manager->addPass(
226       mlir::tf_saved_model::CreateOptimizeGlobalTensorsPass());
227 
228   if (pass_config.shape_inference) {
229     // Add a shape inference pass to optimize away the unnecessary casts.
230     pass_manager->addPass(mlir::TF::CreateTFShapeInferencePass());
231   }
232 
233   // Legalize while early to allow further constant folding.
234   // TODO(jpienaar): This may not actually matter as we do canonicalization
235   // after the legalize below, for now it needs to be below the above passes
236   // that work on TF dialect and before inliner so that the function calls in
237   // body and cond are inlined for optimization.
238   pass_manager->addPass(mlir::TFL::CreateLegalizeTFWhilePass());
239 
240   // Add function inlining pass. Both TF and TFLite dialects are opted into
241   // function inliner interface.
242   pass_manager->addPass(mlir::createInlinerPass());
243   // Reduce operands of TFL::While without changing the outcome.
244   // It needs to stay here because:
245   // 1. WhileOps are in TFL dialect.
246   // 2. The body and cond are inlined.
247   // 3. We need to do this before while canonicalization, otherwise it would be
248   //   difficult to find dependencies.
249   pass_manager->addNestedPass<mlir::func::FuncOp>(
250       mlir::TFL::CreateReduceWhileOperandsPass());
251   // Canonicalization includes const folding, which is utilized here to optimize
252   // away ops that can't get constant folded after PrepareTF pass. For example,
253   // tf.Conv2D is split into tf.Transpose and tfl.Conv2D.
254   pass_manager->addNestedPass<mlir::func::FuncOp>(
255       mlir::createCanonicalizerPass());
256   pass_manager->addNestedPass<mlir::func::FuncOp>(mlir::createCSEPass());
257   // This pass does dead code elimination based on symbol visibility.
258   pass_manager->addPass(mlir::createSymbolDCEPass());
259 
260   if (!pass_config.disable_variable_freezing) {
261     // This pass 'freezes' immutable global tensors and inlines them as tf
262     // constant ops.
263     pass_manager->addPass(mlir::tf_saved_model::CreateFreezeGlobalTensorsPass(
264         /*allow_mutable_tensors=*/pass_config.enable_tflite_variables));
265   }
266 
267   if (!saved_model_dir.empty()) {
268     // This pass 'freezes' tf saved model asset ops and inlines as string values
269     // in a format of the tf constant op.
270     pass_manager->addPass(
271         mlir::tf_saved_model::CreateFreezeAssetsPass(saved_model_dir.str()));
272   }
273   // For TF Quantization, convert unsupported ops to Flex ops before other
274   // conversion passes.
275   if (!toco_flags.tf_quantization_mode().empty()) {
276     pass_manager->addNestedPass<mlir::func::FuncOp>(
277         mlir::TF::CreateFallbackToFlexOpsPass(
278             toco_flags.tf_quantization_mode()));
279   }
280   // The below passes only make sense if Builtin TFLite ops are enabled
281   // for emission.
282   if (pass_config.emit_builtin_tflite_ops) {
283     // Run shape inference after variables are converted to constants.
284     if (pass_config.shape_inference) {
285       pass_manager->addPass(mlir::TF::CreateTFShapeInferencePass());
286     }
287     // Force layout supported by TFLite, this will transpose the data
288     // to match 'kTFLiteDataLayout'
289     mlir::TF::LayoutOptimizationPipelineOptions layout_optimization_options;
290     layout_optimization_options.force_data_format = kTFLiteDataLayout;
291     layout_optimization_options.skip_fold_transpose_in_ops = true;
292     mlir::TF::CreateLayoutOptimizationPipeline(
293         pass_manager->nest<mlir::func::FuncOp>(), layout_optimization_options);
294     // Prepare for TFLite dialect, rerun canonicalization, and then legalize to
295     // the TFLite dialect.
296     pass_manager->addNestedPass<mlir::func::FuncOp>(
297         mlir::TFL::CreatePrepareTFPass(pass_config.unfold_batch_matmul,
298                                        /*allow_bf16_and_f16_type_legalization=*/
299                                        !pass_config.runtime_verification,
300                                        toco_flags.use_fake_quant_num_bits()));
301     pass_manager->addNestedPass<mlir::func::FuncOp>(
302         mlir::createCanonicalizerPass());
303     if (pass_config.shape_inference) {
304       // Add a shape inference pass to optimize away the unnecessary casts.
305       // This also fixes the unranked shapes due to TF ops constant folding.
306       // TODO(fengliuai): remove this pass if TableGen patterns have a better
307       // to control the shapes for the intermediate results.
308       pass_manager->addPass(mlir::TF::CreateTFShapeInferencePass());
309     }
310 
311     // Inline function calls that left in the graph after folding functional
312     // control flow ops (IfOp, CaseOp).
313     pass_manager->addPass(mlir::createInlinerPass());
314 
315     // This pass removes the asset file dependencies in hash table use cases.
316     pass_manager->addNestedPass<mlir::func::FuncOp>(
317         mlir::TF::CreateInitTextFileToImportPass(saved_model_dir.str()));
318 
319     pass_manager->addNestedPass<mlir::func::FuncOp>(
320         mlir::TFL::CreateLegalizeTFPass(pass_config.runtime_verification,
321                                         pass_config.preserve_assert_op));
322     pass_manager->addPass(mlir::TFL::CreateAnalyzeVariablesPass());
323     pass_manager->addPass(mlir::TFL::CreateLegalizeVariablesPass());
324     pass_manager->addPass(mlir::TFL::CreateLegalizeHashTablesPass());
325     pass_manager->addNestedPass<mlir::func::FuncOp>(
326         mlir::TFL::CreateOptimizePass(/*enable_canonicalization=*/true));
327     // This pass operates on TensorFlow ops but is triggered after legalization
328     // so that it can target constants introduced once TensorFlow Identity ops
329     // are removed during legalization.
330     pass_manager->addPass(mlir::TFL::CreateOptimizeFunctionalOpsPass());
331     std::vector<std::string> empty_wrapped_ops({});
332     pass_manager->addNestedPass<mlir::func::FuncOp>(
333         mlir::TFL::CreateRaiseCustomOpsPass(empty_wrapped_ops));
334     pass_manager->addPass(mlir::createSymbolDCEPass());
335     pass_manager->addNestedPass<mlir::func::FuncOp>(
336         mlir::createCanonicalizerPass());
337     pass_manager->addNestedPass<mlir::func::FuncOp>(mlir::createCSEPass());
338 
339     // Run quantization after all the floating point model conversion is
340     // completed. Add either full integer quantization or dynamic range
341     // quantization passes based on quant_specs.
342     if (pass_config.quant_specs.RunPropagationAndRewriteQuantizationPasses()) {
343       AddQuantizationPasses(pass_config.quant_specs, *pass_manager);
344       // Remove unnecessary QDQs while handling QAT models.
345       pass_manager->addNestedPass<mlir::func::FuncOp>(
346           mlir::TFL::CreatePostQuantizeRemoveQDQPass());
347     } else if (pass_config.quant_specs
348                    .RunAndRewriteDynamicRangeQuantizationPasses()) {
349       AddDynamicRangeQuantizationPasses(pass_config.quant_specs, *pass_manager);
350     }
351     pass_manager->addPass(mlir::createCanonicalizerPass());
352 
353     // This pass should be always at the end of the model
354     // conversion (even after quantization). Some TFL ops like unidirectional
355     // sequence lstm will have stateful operands and some optimization passes
356     // will merge those operands if they have identical values & types. However,
357     // it's not desired by TFL. This pass serves as a "fix" pass to split the
358     // merged inputs until we have 1st class variable support or reuse
359     // tf.variable to model this.
360     pass_manager->addNestedPass<mlir::func::FuncOp>(
361         mlir::TFL::CreateSplitMergedOperandsPass());
362 
363     // Add CallOnceOp when there is a session initializer function in tf saved
364     // model dialect.
365     pass_manager->addPass(
366         mlir::TFL::CreateInsertCallOnceOpFromSessionInitializerPass());
367   }
368   if (pass_config.unfold_large_splat_constant) {
369     pass_manager->addPass(mlir::TFL::CreateUnfoldLargeSplatConstantPass());
370   }
371   if (pass_config.outline_tf_while) {
372     pass_manager->addPass(mlir::TFL::CreateWhileOutlinePass());
373   }
374   if (pass_config.runtime_verification) {
375     pass_manager->addNestedPass<mlir::func::FuncOp>(
376         mlir::TFL::CreateRuntimeVerifyPass());
377   }
378 }
379 
AddTFToTFLConversionPasses(llvm::StringRef saved_model_dir,const toco::TocoFlags & toco_flags,const mlir::TFL::PassConfig & pass_config,mlir::OpPassManager * pass_manager)380 void AddTFToTFLConversionPasses(llvm::StringRef saved_model_dir,
381                                 const toco::TocoFlags& toco_flags,
382                                 const mlir::TFL::PassConfig& pass_config,
383                                 mlir::OpPassManager* pass_manager) {
384   AddPreVariableFreezingTFToTFLConversionPasses(pass_config, pass_manager);
385   AddPostVariableFreezingTFToTFLConversionPasses(saved_model_dir, toco_flags,
386                                                  pass_config, pass_manager);
387 }
AddTFToTFLConversionPasses(const mlir::TFL::PassConfig & pass_config,mlir::OpPassManager * pass_manager)388 void AddTFToTFLConversionPasses(const mlir::TFL::PassConfig& pass_config,
389                                 mlir::OpPassManager* pass_manager) {
390   const toco::TocoFlags toco_flags;
391   AddTFToTFLConversionPasses(/*saved_model_dir=*/"", toco_flags, pass_config,
392                              pass_manager);
393 }
394 
395 }  // namespace tensorflow
396 
397 namespace mlir {
398 namespace TFL {
399 
400 struct StandardPipelineOptions
401     : public PassPipelineOptions<StandardPipelineOptions> {
402   // TODO(b/150915052): All the tf_tfl_translate_cl flags should
403   // move inside this.
404 };
405 
406 // NOLINTNEXTLINE
407 // This creates the standard pass pipeline for TF->TFLite. This
408 // represents a std configuration for TFLite, for use with APIs like
409 // tensorflow/python/pywrap_mlir.py::experimental_run_pass_pipeline
410 // This does not yet include quantization passes.
CreateTFLStandardPipeline(OpPassManager & pm,const StandardPipelineOptions & options)411 void CreateTFLStandardPipeline(OpPassManager& pm,
412                                const StandardPipelineOptions& options) {
413   OpPassManager& func_pm = pm.nest<func::FuncOp>();
414 
415   // tf_executor dialect passes - Cleaning up the IR.
416   mlir::TF::StandardPipelineOptions standard_pipeline_options;
417   mlir::TF::CreateTFStandardPipeline(func_pm, standard_pipeline_options);
418 
419   // This is needed for control flow support with TF TensorList.
420   pm.addPass(mlir::TFL::CreateLowerStaticTensorListPass(
421       /*allow_tensorlist_pass_through=*/false,
422       /*default_to_single_batch=*/false,
423       /*enable_dynamic_update_slice=*/false));
424 
425   // Saved model pass to mark global tensors immutable.
426   pm.addPass(mlir::tf_saved_model::CreateOptimizeGlobalTensorsPass());
427   // Op fusion pass.
428   pm.addPass(mlir::TFL::CreatePrepareCompositeFunctionsPass());
429 
430   pm.addNestedPass<mlir::func::FuncOp>(mlir::TFL::CreateLegalizeTFWhilePass());
431 
432   pm.addPass(mlir::createInlinerPass());
433 
434   // Canonicalize, CSE etc.
435   pm.addNestedPass<mlir::func::FuncOp>(mlir::createCanonicalizerPass());
436   pm.addNestedPass<mlir::func::FuncOp>(mlir::createCSEPass());
437   // DCE for private symbols.
438   pm.addPass(mlir::createSymbolDCEPass());
439 
440   // freeze global tensors.
441   pm.addPass(mlir::tf_saved_model::CreateFreezeGlobalTensorsPass());
442 
443   // TFLite dialect passes.
444   pm.addPass(mlir::TFL::CreatePrepareTFPass(
445       /*unfold_batch_matmul=*/true,
446       /*allow_bf16_and_f16_type_legalization=*/false));
447   pm.addNestedPass<mlir::func::FuncOp>(mlir::createCanonicalizerPass());
448   pm.addPass(
449       mlir::TFL::CreateLegalizeTFPass(/*run_tfl_runtime_verification=*/true,
450                                       /*preserve_assert_op=*/false));
451   pm.addPass(mlir::TFL::CreateLegalizeHashTablesPass());
452   pm.addPass(mlir::TFL::CreateOptimizePass(/*enable_canonicalization=*/true));
453   pm.addPass(mlir::TFL::CreateOptimizeFunctionalOpsPass());
454   pm.addPass(mlir::createSymbolDCEPass());
455 
456   // Canonicalize, CSE etc.
457   pm.addNestedPass<mlir::func::FuncOp>(mlir::createCanonicalizerPass());
458   pm.addNestedPass<mlir::tf_saved_model::SessionInitializerOp>(
459       mlir::createCanonicalizerPass());
460   pm.addNestedPass<mlir::func::FuncOp>(mlir::createCSEPass());
461 
462   // Pass for stateful operands like LSTM.
463   pm.addPass(mlir::TFL::CreateSplitMergedOperandsPass());
464 
465   pm.addPass(mlir::TFL::CreateWhileOutlinePass());
466 
467   pm.addNestedPass<mlir::func::FuncOp>(mlir::TFL::CreateRuntimeVerifyPass());
468 }
469 
470 // Registers a pass pipeline for the standard TFL passes.
471 static mlir::PassPipelineRegistration<StandardPipelineOptions> pipeline(
472     "tfl-standard-pipeline",
473     "Run the standard passes involved in transforming/optimizing the TF "
474     "program to TFLite after "
475     "importing into MLIR.",
476     CreateTFLStandardPipeline);
477 
478 }  // namespace TFL
479 }  // namespace mlir
480