xref: /aosp_15_r20/external/tensorflow/tensorflow/core/tpu/graph_rewrite/variable_merger_pass.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // Optimization pass that merges VarHandleOps and ReadVariableOps into their
17 // fused forms.
18 //
19 // The goal of this pass is to fix a latency problem sometimes observed in
20 // inference benchmarks. Often an inference step starts by reading the value of
21 // many weights. Reading a resource variable requires a VarHandleOp and a
22 // ReadVariableOp per variable. Running hundreds of trivial ops can add hundreds
23 // of microseconds of latency to the critical path of an inference step. The
24 // inter-op latency of the executor can be easily hundreds of nanoseconds, which
25 // rapidly adds up over many inexpensive ops.
26 //
27 // This pass merges VarHandleOps that have only the graph source node as a
28 // predecessor into a single VarHandlesOp that reads all at once.
29 // It then merges ReadVariableOps that have no control inputs and originate from
30 // the same handle op into a single large ReadVariablesOp.
31 
32 #ifndef TENSORFLOW_CORE_TPU_GRAPH_REWRITE_VARIABLE_MERGER_PASS_H_
33 #define TENSORFLOW_CORE_TPU_GRAPH_REWRITE_VARIABLE_MERGER_PASS_H_
34 
35 #include "tensorflow/core/common_runtime/optimization_registry.h"
36 #include "tensorflow/core/graph/graph.h"
37 
38 namespace tensorflow {
39 
// Graph optimization pass that merges VarHandleOps and ReadVariableOps into
// their fused forms; see the file-level comment above for the motivation.
// Only the declaration is visible in this header; the definition of Run is
// expected to live in the corresponding source file.
40 class VariableMergerPass : public GraphOptimizationPass {
41  public:
// Entry point of the pass. Overrides the virtual Run inherited from the
// GraphOptimizationPass base. Takes the pass options by const reference and
// reports the outcome through the returned Status.
// NOTE(review): which fields of `options` are read or mutated is not visible
// from this header -- confirm against the implementation.
42   Status Run(const GraphOptimizationPassOptions& options) override;
43 };
44 
45 }  // namespace tensorflow
46 
47 #endif  // TENSORFLOW_CORE_TPU_GRAPH_REWRITE_VARIABLE_MERGER_PASS_H_
48