1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #ifndef TENSORFLOW_COMPILER_TF2XLA_TF2XLA_UTIL_H_
17 #define TENSORFLOW_COMPILER_TF2XLA_TF2XLA_UTIL_H_
18
#include <map>
#include <optional>
#include <unordered_map>
#include <vector>

#include "absl/types/optional.h"
#include "tensorflow/compiler/tf2xla/tf2xla.pb.h"
#include "tensorflow/compiler/tf2xla/tf2xla_defs.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/kernel_def.pb.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/lib/core/status.h"
31
32 namespace tensorflow {
33
// ValidateConfig returns OK iff config is valid.
Status ValidateConfig(const tf2xla::Config& config);

// Modifies <graph_def> to include placeholders for each fed tensor, and
// updates references to the fed tensors to refer to the placeholders.
// The existing nodes referenced by the feeds are not removed or modified
// (except where their input edges are modified by the replacement of other
// feeds).
// NOTE(review): <feed_remapping> is a non-const pointer and looks like an
// output parameter recording how feeds were remapped to placeholders; its
// exact key/value format is not visible here — confirm in the implementation.
Status AddPlaceholdersForFeeds(
    const tf2xla::Config& config, const OpRegistryInterface* op_registry,
    std::unordered_map<string, string>* feed_remapping, GraphDef* graph_def);

// Returns in <out> a copy of <in>, pruned to only include fetches from
// <config>.
Status PruneGraphDefInto(const tf2xla::Config& config, const GraphDef& in,
                         GraphDef* out);

// Returns node:port for the given <id>.
string TensorIdToString(const tf2xla::TensorId& id);

// Updates the sharding of <n> based on the sharding of its neighbors.
// If <out_edges> is true, outgoing edges from <n> are considered; else incoming
// edges are considered.
Status SetNodeShardingFromNeighbors(Node* n, bool out_edges);

// Adds an allowed data type (<dtype>) to the AttrConstraint with the given
// <name>; <kdef> is the kernel definition being modified.
void AddDtypeToKernelDefConstraint(absl::string_view name, DataType dtype,
                                   KernelDef* kdef);

// Returns the next random seed to use for seeding xla rng.
uint32 GetXLARandomSeed();
65
66 // Indicates how a FunctionDef is associated with a graph node (e.g. the node is
67 // a function call, or the node has function attrs).
68 class AssociatedFunctionInfo {
69 public:
70 enum AssociatedFunctionType {
71 kFunctionAttr = 0,
72 kFunctionCallNode = 1,
73 kSymbolicGradient = 2,
74 };
75
76 // The function is an attr of the node.
FunctionAttr(const string & func_name,const AttrValueMap & attrs,const string & attr_name)77 static AssociatedFunctionInfo FunctionAttr(const string& func_name,
78 const AttrValueMap& attrs,
79 const string& attr_name) {
80 return AssociatedFunctionInfo(kFunctionAttr, func_name, attrs, attr_name);
81 }
82
83 // The node is a function call.
FunctionCall(const string & func_name,const AttrValueMap & attrs)84 static AssociatedFunctionInfo FunctionCall(const string& func_name,
85 const AttrValueMap& attrs) {
86 // attr_name will not be used in this case.
87 return AssociatedFunctionInfo(kFunctionCallNode, func_name, attrs,
88 /*attr_name=*/"");
89 }
90
91 // The node is a SymbolicGradient op.
SymbolicGradient(const string & func_name,const AttrValueMap & attrs)92 static AssociatedFunctionInfo SymbolicGradient(const string& func_name,
93 const AttrValueMap& attrs) {
94 // attr_name will not be used in this case.
95 return AssociatedFunctionInfo(kSymbolicGradient, func_name, attrs,
96 /*attr_name=*/"");
97 }
98
type()99 AssociatedFunctionType type() const { return type_; }
100
func_name()101 const string& func_name() const { return func_name_; }
102
attr_name()103 const string& attr_name() const { return attr_name_; }
104
attrs()105 const AttrValueMap& attrs() const { return attrs_; }
106
107 private:
AssociatedFunctionInfo(AssociatedFunctionType type,const string & func_name,const AttrValueMap & attrs,const string & attr_name)108 AssociatedFunctionInfo(AssociatedFunctionType type, const string& func_name,
109 const AttrValueMap& attrs, const string& attr_name)
110 : type_(type),
111 func_name_(func_name),
112 attrs_(attrs),
113 attr_name_(attr_name) {}
114
115 // Available for all instances.
116 AssociatedFunctionType type_;
117 string func_name_;
118 AttrValueMap attrs_;
119
120 // Only available if the function is defined in an attr.
121 string attr_name_;
122 };
123
// Returns true iff <node_def> has an associated function.
// NOTE(review): <fld> is presumably consulted to recognize function-call
// nodes — confirm in the implementation.
bool HasAssociatedFunction(const NodeDef& node_def,
                           const FunctionLibraryDefinition* fld);

// Gets the functions associated with the node. Current cases:
// 1. For a function call node, its function name;
// 2. For a SymbolicGradient op, the returned func_name will be
//    "SymbolicGradient", and the returned attrs will be this node's
//    attributes;
// 3. For nodes like XlaWhile/XlaIf, all their function attributes.
std::vector<AssociatedFunctionInfo> GetAssociatedFunctions(
    const Node& node, const FunctionLibraryDefinition* fld);

// Changes the associated functions for the node. Current cases:
// 1. For a function call node, creates a new node with the new function name
//    and removes the old node;
// 2. For a SymbolicGradient op, adds or replaces the GradientDef in the
//    FunctionLibraryDefinition;
// 3. For nodes like XlaWhile/XlaIf, modifies their function attributes.
Status RewriteAssociatedFunction(
    Graph* graph, Node* node, FunctionLibraryDefinition* fld,
    const AssociatedFunctionInfo& associated_function,
    const string& rewritten_function_name);
146
147 // Class to act as cache for FunctionLibraryRuntime::Handle objects.
148 class CachedFunctionHandles {
149 public:
CachedFunctionHandles(FunctionLibraryRuntime * flr)150 CachedFunctionHandles(FunctionLibraryRuntime* flr) : flr_(flr) {}
151
152 // Populates `handle` for requested function and attributes. If we have
153 // instantiated the function with the same attributes before, `handle` will be
154 // cached handle; otherwise instantiate the function and populate `handle`.
155 Status GetOrInstantiate(const string& func_name, AttrSlice attrs,
156 FunctionLibraryRuntime::Handle* handle);
157
158 // Releases all handles in the cache. Returns first non-OK status if any;
159 // returns OK otherwise.
160 Status ReleaseAllHandles();
161
~CachedFunctionHandles()162 ~CachedFunctionHandles() { ReleaseAllHandles().IgnoreError(); }
163
164 private:
165 FunctionLibraryRuntime* flr_;
166 std::map<string, FunctionLibraryRuntime::Handle> handles_;
167
168 TF_DISALLOW_COPY_AND_ASSIGN(CachedFunctionHandles);
169 };
170
171 // Struct for node's output edge info.
172 struct OutEdgeInfo {
173 Node* dst;
174 int src_output, dst_input;
175 };
176
// Replaces node `n` with a new node whose NodeDef is `node_def`.
// NOTE(review): presumably returns the newly created node — confirm in the
// implementation.
StatusOr<Node*> ReplaceNode(Graph* g, Node* n, const NodeDef& node_def);

// Helper function that builds an Identity node.
// NOTE(review): judging by the parameter names, the node is named
// `node_name`, has element type `dtype`, takes `input` as its input, and is
// assigned `requested_device` when one is given — confirm in the
// implementation.
StatusOr<Node*> BuildIdentityNode(Graph* graph, const string& node_name,
                                  DataType dtype, const Node* input,
                                  std::optional<string> requested_device);

// For "If"/"While" nodes, if some of their inputs are Const nodes, rewrites
// the body functions to use the Const nodes instead of the original _Arg
// nodes.
//
// For example, say we have the following computation:
//   shape = constant_op.constant([1])
//   return tf.cond(pred, lambda: tf.ones(shape), lambda: tf.zeros(shape))
// If we do not rewrite the then/else functions, they will use an _Arg node as
// the shape input for tf.ones/tf.zeros. But XLA requires that shape input to
// be a compile-time constant, so XLA compilation would fail. This rewriting
// process changes the shape input to a Const node.
// NOTE(review): `lookup_fld` is const and `fld` is mutable, so presumably
// functions are looked up in the former and rewritten functions are added to
// the latter — confirm.
Status PropagateConstIntoFunctionalNodes(
    Graph* g, const FunctionLibraryDefinition* lookup_fld,
    FunctionLibraryDefinition* fld);

// Prunes unreachable FunctionDefs from the FunctionLibraryDefinition `fld`.
// NOTE(review): reachability is presumably computed from the nodes of `g` —
// confirm in the implementation.
Status PruneUnreachableFunctionsFromGraph(const Graph& g,
                                          FunctionLibraryDefinition* fld);

// Finds the following pattern in the graph:
// 1) EmptyTensorList -> forward While op -> backward While op,
// 2) in forward While op, a Const node is pushed,
// 3) in backward While op, data is popped from the tensor list.
// And rewrites backward While op to use Const node instead of TensorListPopBack
// result.
// TODO(b/128633174) remove the TensorList and related TensorList ops.
Status RewriteTensorListWithConstElement(Graph* g,
                                         FunctionLibraryDefinition* fld);
212
IsConstTraversableOpType(const Node * node)213 inline bool IsConstTraversableOpType(const Node* node) {
214 return node->type_string() == "Identity" ||
215 node->type_string() == "IdentityN" || node->IsWhileNode();
216 }
217
// Determines whether a loop body is invariant for the given argument index.
// NOTE(review): presumably `true` means the body passes its input at `index`
// through unchanged to its output at `index`, and `lookup_fld` is used to
// resolve functions called inside the body — confirm in the implementation.
StatusOr<bool> IsLoopInvariant(const FunctionBody* loop_body, int index,
                               const FunctionLibraryDefinition* lookup_fld);
221
222 } // namespace tensorflow
223
224 #endif // TENSORFLOW_COMPILER_TF2XLA_TF2XLA_UTIL_H_
225