xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/tf2xla/resource_util.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_TF2XLA_RESOURCE_UTIL_H_
17 #define TENSORFLOW_COMPILER_TF2XLA_RESOURCE_UTIL_H_
18 
#include <optional>
#include <string>
#include <utility>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/hash/hash.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/stream_executor/lib/statusor.h"
29 
30 namespace tensorflow {
31 class ResourceUsageAnalysis {
32  public:
33   // NodeInfo is a triple of function_name:node_name:op to uniquely identity a
34   // node in graph. ResourceUsageAnalysis uses it to represent resource sources
35   // and users.
36   class NodeInfo {
37    public:
38     std::optional<std::string> function_name_;
39     std::string node_name_;
40     std::string op_;
41 
NodeInfo()42     NodeInfo() {}
43 
NodeInfo(const std::optional<std::string> & function_name,std::string node_name,std::string op)44     NodeInfo(const std::optional<std::string>& function_name,
45              std::string node_name, std::string op)
46         : function_name_(function_name),
47           node_name_(std::move(node_name)),
48           op_(std::move(op)) {}
49 
DebugString()50     std::string DebugString() const {
51       return absl::StrJoin({function_name_.value_or(""), node_name_, op_}, ":");
52     }
53 
54     bool operator==(const NodeInfo& o) const {
55       return function_name_ == o.function_name_ && node_name_ == o.node_name_ &&
56              op_ == o.op_;
57     }
58 
59     template <typename H>
AbslHashValue(H h,const NodeInfo & o)60     friend H AbslHashValue(H h, const NodeInfo& o) {
61       return H::combine(std::move(h), o.function_name_, o.node_name_, o.op_);
62     }
63   };
64 
65   // This method analyzes a Tensorflow graph and finds all operations that
66   // create Stack/TensorArray resources and all the operations that consume
67   // resource created by them.
68   //
69   // Note that _Arg nodes that introduce resources are not considered sources.
70   // Note again that Control Flow v1 nodes
71   // (Enter/Exit/Switch/Merge/NextIteration) are not supported. Graphs contain
72   // these nodes cause analysis failures. However Control Flow v2 nodes
73   // (While/If) will be supported.
74   //
75   // TODO(b/135628319): Support analyzing functional while/if as pass-through
76   // ops.
77   //
78   // For example, consider following subgraph:
79   //
80   // TensorArrayOp -> Identity -> TensorArrayWriteOp
81   //
82   // It should be able to tell that TensorArrayWriteOp actually operates on the
83   // resource created by TensorArrayOp even though there might be
84   // non-resource-specific operations like Identity (or other pass-through
85   // operations).
86   //
87   // source_to_path maps the nodes that creates resources to all nodes that
88   // operate on the corresponding resource, not including sources themselves. It
89   // is cleared upon calling this method.
90   static Status Analyze(
91       const Graph* graph, FunctionLibraryRuntime* lib_runtime,
92       absl::flat_hash_map<NodeInfo, absl::flat_hash_set<NodeInfo>>*
93           source_to_path);
94 };
95 
96 }  // namespace tensorflow
97 #endif  // TENSORFLOW_COMPILER_TF2XLA_RESOURCE_UTIL_H_
98