xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/mlir/tensorflow/utils/bridge_logger.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_BRIDGE_LOGGER_H_
17 #define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_BRIDGE_LOGGER_H_
18 
#include <string>
#include <vector>

#include "mlir/IR/Operation.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Pass/PassManager.h"  // from @llvm-project
#include "mlir/Support/Timing.h"  // from @llvm-project
23 
24 namespace tensorflow {
25 
26 // Logger for logging MLIR modules before and after passes in MLIR TPU bridge.
27 //
28 // The IR logging can be restricted to a particular set of pass invocations via
29 // filters that are specified with the `MLIR_BRIDGE_LOG_PASS_FILTER` and
30 // `MLIR_BRIDGE_LOG_STRING_FILTER` environment variables.
31 // `MLIR_BRIDGE_LOG_PASS_FILTER` takes a semicolon-separated list of pass class
32 // names, `MLIR_BRIDGE_LOG_STRING_FILTER` takes a semicolon-separated list of
33 // strings, and IR is only dumped for a pass invocation if the pass name exactly
34 // matches any of the provided pass names and if the serialized operation on
35 // which the pass is invoked contains any of the specified strings as a
36 // substring. An empty list is interpreted as no restriction. The string filter
37 // can be handy e.g. if one is only interested in a certain function or when
38 // checking where a certain attribute gets lost. Note that we use a semicolon
39 // instead of comma as the separator to allow strings that contain commas (which
40 // frequently appear in MLIR). The strings can contain any characters (including
41 // spaces) except semicolons.
42 //
43 // Example: Setting the environment variables
44 // `MLIR_BRIDGE_LOG_PASS_FILTER="LegalizeTF;Canonicalizer"` and
45 // `MLIR_BRIDGE_LOG_STRING_FILTER="my_string"` will dump IR only for invocations
46 // of `LegalizeTF` and `Canonicalizer` where the string `my_string` is contained
47 // in the serialized operation on which the pass is invoked. For verbose log
48 // level >= 1, `bridge_logger.cc` prints details about pass invocations for
49 // which the IR dumping was skipped because of a filter.
50 class BridgeLoggerConfig : public mlir::PassManager::IRPrinterConfig {
51  public:
52   explicit BridgeLoggerConfig(bool print_module_scope = false,
53                               bool print_after_only_on_change = true);
54 
55   // A hook that may be overridden by a derived config that checks if the IR
56   // of 'operation' should be dumped *before* the pass 'pass' has been
57   // executed. If the IR should be dumped, 'print_callback' should be invoked
58   // with the stream to dump into.
59   void printBeforeIfEnabled(mlir::Pass* pass, mlir::Operation* op,
60                             PrintCallbackFn print_callback) override;
61 
62   // A hook that may be overridden by a derived config that checks if the IR
63   // of 'operation' should be dumped *after* the pass 'pass' has been
64   // executed. If the IR should be dumped, 'print_callback' should be invoked
65   // with the stream to dump into.
66   void printAfterIfEnabled(mlir::Pass* pass, mlir::Operation* op,
67                            PrintCallbackFn print_callback) override;
68 
69   // Returns `true` iff we should log IR for given `pass` and `op`.
70   // Note: Visibility of this function is public for use in unit testing.
71   bool ShouldPrint(mlir::Pass* pass, mlir::Operation* op);
72 
73  private:
74   // Get `filter` encoded by environment variable `env_var`.
75   static std::vector<std::string> GetFilter(const std::string& env_var);
76   // Returns `true` iff any of the strings in `filter` matches `str`, either
77   // exactly or as a substring, depending on `exact_match`.
78   static bool MatchesFilter(const std::string& str,
79                             const std::vector<std::string>& filter,
80                             bool exact_match);
81 
82   // Only log pass invocations whose pass name exactly matches any string in
83   // `pass_filter_` (or when `pass_filter_` is empty).
84   const std::vector<std::string> pass_filter_;
85   // Only log pass invocations where the serialized operation on which the pass
86   // is invoked contains any of the specified strings as a substring (or when
87   // `string_filter_` is empty).
88   const std::vector<std::string> string_filter_;
89 };
90 
91 }  // namespace tensorflow
92 
93 #endif  // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_BRIDGE_LOGGER_H_
94