/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_TF2XLA_TF2XLA_UTIL_H_
#define TENSORFLOW_COMPILER_TF2XLA_TF2XLA_UTIL_H_

#include <unordered_map>

#include "absl/types/optional.h"
#include "tensorflow/compiler/tf2xla/tf2xla.pb.h"
#include "tensorflow/compiler/tf2xla/tf2xla_defs.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/kernel_def.pb.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/lib/core/status.h"

namespace tensorflow {

// ValidateConfig returns OK iff config is valid.
Status ValidateConfig(const tf2xla::Config& config);
// Modifies <graph_def> to include placeholders for each fed tensor, and
// updates references to the fed tensors to refer to the placeholders.
// The existing nodes referenced by the feeds are not removed or modified
// (except where their input edges are modified by the replacement of other
// feeds).
Status AddPlaceholdersForFeeds(
    const tf2xla::Config& config, const OpRegistryInterface* op_registry,
    std::unordered_map<string, string>* feed_remapping, GraphDef* graph_def);
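//
// Example usage of AddPlaceholdersForFeeds (illustrative sketch; `config` and
// `graph_def` are assumed to be populated by the caller):
//   std::unordered_map<string, string> feed_remapping;
//   TF_RETURN_IF_ERROR(AddPlaceholdersForFeeds(config, OpRegistry::Global(),
//                                              &feed_remapping, &graph_def));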

// Returns in <out> a copy of <in>, pruned to only include fetches from
// <config>.
Status PruneGraphDefInto(const tf2xla::Config& config, const GraphDef& in,
                         GraphDef* out);

// Returns node:port for the given <id>.
string TensorIdToString(const tf2xla::TensorId& id);
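//
// For example (sketch): a TensorId referring to output 0 of a node named
// "images" is rendered as "images:0".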

// Updates the sharding of <n> based on the sharding of its neighbors.
// If <out_edges> is true, outgoing edges from <n> are considered; else incoming
// edges are considered.
Status SetNodeShardingFromNeighbors(Node* n, bool out_edges);

// Adds an allowed data type to the AttrConstraint with the given name.
void AddDtypeToKernelDefConstraint(absl::string_view name, DataType dtype,
                                   KernelDef* kdef);

// Returns the next random seed to use for seeding the XLA RNG.
uint32 GetXLARandomSeed();

// Indicates how a FunctionDef is associated with a graph node (e.g. the node is
// a function call, or the node has function attrs).
class AssociatedFunctionInfo {
 public:
  enum AssociatedFunctionType {
    kFunctionAttr = 0,
    kFunctionCallNode = 1,
    kSymbolicGradient = 2,
  };

  // The function is an attr of the node.
  static AssociatedFunctionInfo FunctionAttr(const string& func_name,
                                             const AttrValueMap& attrs,
                                             const string& attr_name) {
    return AssociatedFunctionInfo(kFunctionAttr, func_name, attrs, attr_name);
  }

  // The node is a function call.
  static AssociatedFunctionInfo FunctionCall(const string& func_name,
                                             const AttrValueMap& attrs) {
    // attr_name will not be used in this case.
    return AssociatedFunctionInfo(kFunctionCallNode, func_name, attrs,
                                  /*attr_name=*/"");
  }

  // The node is a SymbolicGradient op.
  static AssociatedFunctionInfo SymbolicGradient(const string& func_name,
                                                 const AttrValueMap& attrs) {
    // attr_name will not be used in this case.
    return AssociatedFunctionInfo(kSymbolicGradient, func_name, attrs,
                                  /*attr_name=*/"");
  }

  AssociatedFunctionType type() const { return type_; }

  const string& func_name() const { return func_name_; }

  const string& attr_name() const { return attr_name_; }

  const AttrValueMap& attrs() const { return attrs_; }

 private:
  AssociatedFunctionInfo(AssociatedFunctionType type, const string& func_name,
                         const AttrValueMap& attrs, const string& attr_name)
      : type_(type),
        func_name_(func_name),
        attrs_(attrs),
        attr_name_(attr_name) {}

  // Available for all instances.
  AssociatedFunctionType type_;
  string func_name_;
  AttrValueMap attrs_;

  // Only available if the function is defined in an attr.
  string attr_name_;
};
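//
// Example construction (illustrative sketch; `node` is assumed to be a
// function call node in the graph being processed):
//   AssociatedFunctionInfo call_info = AssociatedFunctionInfo::FunctionCall(
//       node->type_string(), node->def().attr());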

// Returns true if the NodeDef has an associated function.
bool HasAssociatedFunction(const NodeDef& node_def,
                           const FunctionLibraryDefinition* fld);

// Gets functions associated with the node. Current cases:
// 1. For a function call node, its function name;
// 2. For a SymbolicGradient op, the returned func_name will be
//    "SymbolicGradient" and the returned attrs will be this node's attributes;
// 3. For nodes like XlaWhile/XlaIf, all their function attributes.
std::vector<AssociatedFunctionInfo> GetAssociatedFunctions(
    const Node& node, const FunctionLibraryDefinition* fld);
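//
// Example (illustrative sketch; `node` and `fld` are assumed to come from the
// graph being rewritten):
//   for (const AssociatedFunctionInfo& info :
//        GetAssociatedFunctions(*node, fld)) {
//     VLOG(2) << "Associated function: " << info.func_name();
//   }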

// Changes associated functions for the node. Current cases:
// 1. For a function call node, creates a new node with the new function name
//    and removes the old node;
// 2. For a SymbolicGradient op, adds or replaces the GradientDef in the
//    FunctionLibraryDefinition;
// 3. For nodes like XlaWhile/XlaIf, modifies their function attributes.
Status RewriteAssociatedFunction(
    Graph* graph, Node* node, FunctionLibraryDefinition* fld,
    const AssociatedFunctionInfo& associated_function,
    const string& rewritten_function_name);
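//
// Example (illustrative sketch; `graph`, `node`, `fld`, and the rewritten
// function name "my_func_rewritten" are assumptions, with the rewritten
// function already added to `fld`):
//   std::vector<AssociatedFunctionInfo> assoc =
//       GetAssociatedFunctions(*node, fld);
//   if (!assoc.empty()) {
//     TF_RETURN_IF_ERROR(RewriteAssociatedFunction(
//         graph, node, fld, assoc[0],
//         /*rewritten_function_name=*/"my_func_rewritten"));
//   }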

// Class that acts as a cache for FunctionLibraryRuntime::Handle objects.
class CachedFunctionHandles {
 public:
  CachedFunctionHandles(FunctionLibraryRuntime* flr) : flr_(flr) {}

  // Populates `handle` for the requested function and attributes. If the
  // function has already been instantiated with the same attributes, `handle`
  // will be the cached handle; otherwise the function is instantiated and
  // `handle` is populated.
  Status GetOrInstantiate(const string& func_name, AttrSlice attrs,
                          FunctionLibraryRuntime::Handle* handle);

  // Releases all handles in the cache. Returns the first non-OK status if any;
  // returns OK otherwise.
  Status ReleaseAllHandles();

  ~CachedFunctionHandles() { ReleaseAllHandles().IgnoreError(); }

 private:
  FunctionLibraryRuntime* flr_;
  std::map<string, FunctionLibraryRuntime::Handle> handles_;

  TF_DISALLOW_COPY_AND_ASSIGN(CachedFunctionHandles);
};
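//
// Example usage (illustrative sketch, assuming `flr` is a valid
// FunctionLibraryRuntime* and "MyFunc" is defined in its function library):
//   CachedFunctionHandles cache(flr);
//   FunctionLibraryRuntime::Handle handle;
//   TF_RETURN_IF_ERROR(cache.GetOrInstantiate("MyFunc", AttrSlice(), &handle));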

// Struct for node's output edge info.
struct OutEdgeInfo {
  Node* dst;
  int src_output, dst_input;
};

// Replaces node `n` with a new node whose NodeDef is `node_def`.
StatusOr<Node*> ReplaceNode(Graph* g, Node* n, const NodeDef& node_def);

// Helper function that builds an Identity node.
StatusOr<Node*> BuildIdentityNode(Graph* graph, const string& node_name,
                                  DataType dtype, const Node* input,
                                  std::optional<string> requested_device);
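//
// Example (illustrative sketch; `graph`, `input_node`, and the node name
// "x_identity" are assumptions):
//   TF_ASSIGN_OR_RETURN(
//       Node* identity,
//       BuildIdentityNode(graph, "x_identity", DT_FLOAT, input_node,
//                         /*requested_device=*/std::nullopt));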

// For "If"/"While" nodes, if some of their inputs are Const nodes, rewrites the
// body functions to use the Const nodes instead of the original _Arg nodes.
//
// For example, say we have the following computation:
//     shape = constant_op.constant([1])
//     return tf.cond(pred, lambda: tf.ones(shape), lambda: tf.zeros(shape))
// If we do not rewrite the then/else functions, they will use an _Arg node as
// the shape input for tf.ones/tf.zeros. But XLA requires that shape input to
// be a compile-time constant, so XLA compilation will fail. This rewriting
// process changes the shape input to a Const node.
Status PropagateConstIntoFunctionalNodes(
    Graph* g, const FunctionLibraryDefinition* lookup_fld,
    FunctionLibraryDefinition* fld);

// Prunes unreachable FunctionDefs from the FunctionLibraryDefinition.
Status PruneUnreachableFunctionsFromGraph(const Graph& g,
                                          FunctionLibraryDefinition* fld);

// Finds the following pattern in the graph:
// 1) EmptyTensorList -> forward While op -> backward While op,
// 2) in the forward While op, a Const node is pushed,
// 3) in the backward While op, data is popped from the tensor list.
// Rewrites the backward While op to use the Const node instead of the
// TensorListPopBack result.
// TODO(b/128633174): remove the TensorList and related TensorList ops.
Status RewriteTensorListWithConstElement(Graph* g,
                                         FunctionLibraryDefinition* fld);

inline bool IsConstTraversableOpType(const Node* node) {
  return node->type_string() == "Identity" ||
         node->type_string() == "IdentityN" || node->IsWhileNode();
}

// Determines whether a loop body is invariant for the given argument index.
StatusOr<bool> IsLoopInvariant(const FunctionBody* loop_body, int index,
                               const FunctionLibraryDefinition* lookup_fld);

}  // namespace tensorflow

#endif  // TENSORFLOW_COMPILER_TF2XLA_TF2XLA_UTIL_H_