1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_TF2XLA_RESOURCE_UTIL_H_ 17 #define TENSORFLOW_COMPILER_TF2XLA_RESOURCE_UTIL_H_ 18 19 #include <string> 20 21 #include "absl/container/flat_hash_map.h" 22 #include "absl/container/flat_hash_set.h" 23 #include "absl/hash/hash.h" 24 #include "absl/strings/str_cat.h" 25 #include "tensorflow/core/common_runtime/function.h" 26 #include "tensorflow/core/graph/graph.h" 27 #include "tensorflow/core/lib/core/errors.h" 28 #include "tensorflow/stream_executor/lib/statusor.h" 29 30 namespace tensorflow { 31 class ResourceUsageAnalysis { 32 public: 33 // NodeInfo is a triple of function_name:node_name:op to uniquely identity a 34 // node in graph. ResourceUsageAnalysis uses it to represent resource sources 35 // and users. 36 class NodeInfo { 37 public: 38 std::optional<std::string> function_name_; 39 std::string node_name_; 40 std::string op_; 41 NodeInfo()42 NodeInfo() {} 43 NodeInfo(const std::optional<std::string> & function_name,std::string node_name,std::string op)44 NodeInfo(const std::optional<std::string>& function_name, 45 std::string node_name, std::string op) 46 : function_name_(function_name), 47 node_name_(std::move(node_name)), 48 op_(std::move(op)) {} 49 DebugString()50 std::string DebugString() const { 51 return absl::StrJoin({function_name_.value_or(""), node_name_, op_}, ":"); 52 } 53 54 bool operator==(const NodeInfo& o) const { 55 return function_name_ == o.function_name_ && node_name_ == o.node_name_ && 56 op_ == o.op_; 57 } 58 59 template <typename H> AbslHashValue(H h,const NodeInfo & o)60 friend H AbslHashValue(H h, const NodeInfo& o) { 61 return H::combine(std::move(h), o.function_name_, o.node_name_, o.op_); 62 } 63 }; 64 65 // This method analyzes a Tensorflow graph and finds all operations that 66 // create Stack/TensorArray resources and all the operations that consume 67 // resource created by them. 68 // 69 // Note that _Arg nodes that introduce resources are not considered sources. 70 // Note again that Control Flow v1 nodes 71 // (Enter/Exit/Switch/Merge/NextIteration) are not supported. Graphs contain 72 // these nodes cause analysis failures. However Control Flow v2 nodes 73 // (While/If) will be supported. 74 // 75 // TODO(b/135628319): Support analyzing functional while/if as pass-through 76 // ops. 77 // 78 // For example, consider following subgraph: 79 // 80 // TensorArrayOp -> Identity -> TensorArrayWriteOp 81 // 82 // It should be able to tell that TensorArrayWriteOp actually operates on the 83 // resource created by TensorArrayOp even though there might be 84 // non-resource-specific operations like Identity (or other pass-through 85 // operations). 86 // 87 // source_to_path maps the nodes that creates resources to all nodes that 88 // operate on the corresponding resource, not including sources themselves. It 89 // is cleared upon calling this method. 90 static Status Analyze( 91 const Graph* graph, FunctionLibraryRuntime* lib_runtime, 92 absl::flat_hash_map<NodeInfo, absl::flat_hash_set<NodeInfo>>* 93 source_to_path); 94 }; 95 96 } // namespace tensorflow 97 #endif // TENSORFLOW_COMPILER_TF2XLA_RESOURCE_UTIL_H_ 98