/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/common_runtime/constant_folding.h"

#include <algorithm>
#include <atomic>
#include <climits>  // for INT_MAX
#include <set>
#include <unordered_map>
#include <vector>

#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/executor.h"
#include "tensorflow/core/common_runtime/function_utils.h"
#include "tensorflow/core/common_runtime/graph_runner.h"
#include "tensorflow/core/common_runtime/memory_types.h"
#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
#include "tensorflow/core/framework/log_memory.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/graph/subgraph.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/denormal.h"
#include "tensorflow/core/platform/setround.h"
#include "tensorflow/core/public/session_options.h"

namespace tensorflow {

namespace {

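// Nodes whose output buffers are managed by a ScopedAllocator carry this
// attribute; IsConstantFoldable() below skips them.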
const char kScopedAllocatorAttrName[] = "_scoped_allocator";

// Test to see if the Op is one that turns into a constant when its
// inputs' shapes are known.
bool IsShapeOp(const Node* n) {
  const auto& ts = n->type_string();
  return ts == "Shape" || ts == "ShapeN" || ts == "Rank" || ts == "Size";
}

// Reads the partially-known shape of each of n's inputs from shape_map, and
// stores it to input_shapes. Returns false if any input does not have a shape
// in shape_map.
bool ReadPartialShapesFromShapeMap(
    const Node* n,
    const std::unordered_map<string, std::vector<PartialTensorShape>>*
        shape_map,
    std::vector<PartialTensorShape>* input_shapes) {
  CHECK(shape_map != nullptr);
  input_shapes->resize(n->num_inputs());
  for (const Edge* in : n->in_edges()) {
    // Don't need to check if incoming control edges have known shapes.
    if (in->IsControlEdge()) continue;
    const auto known_shape_iter = shape_map->find(in->src()->name());
    if (known_shape_iter == shape_map->end()) {
      // One of n's inputs doesn't have known shapes, so don't replace n.
      return false;
    }
    const auto& known_shape = known_shape_iter->second;
    CHECK_GT(known_shape.size(), in->src_output()) << known_shape_iter->first;
    DCHECK_GE(in->dst_input(), 0);
    DCHECK_LT(in->dst_input(), input_shapes->size());
    (*input_shapes)[in->dst_input()] = known_shape[in->src_output()];
  }
  return true;
}

// If all of n's inputs have fully-defined shapes, inserts those shapes as a
// vector of Tensors in the shape_replacement_map.
bool MaybeReplaceShapeOrShapeNOp(
    const Node* n, const std::vector<PartialTensorShape>& input_shapes,
    std::unordered_map<const Node*, std::vector<Tensor>>*
        shape_replacement_map) {
  std::vector<Tensor> defined_shape;
  for (const auto& shape : input_shapes) {
    if (!shape.IsFullyDefined()) {
      return false;
    }
    const int rank = shape.dims();
    DataType op_type = n->output_type(0);
    Tensor t(op_type, TensorShape({rank}));
    if (op_type == DT_INT64) {
      auto vec = t.vec<int64_t>();
      for (int i = 0; i < rank; ++i) {
        vec(i) = shape.dim_size(i);
      }
    } else {
      CHECK(op_type == DT_INT32);
      auto vec = t.vec<int32>();
      for (int i = 0; i < rank; ++i) {
        if (shape.dim_size(i) > INT_MAX) {
          VLOG(1) << "Node " << n->name() << " has input shape dimension " << i
                  << " of " << shape.dim_size(i) << " but type INT32, "
                  << "so not replacing as constant: this will trigger a "
                     "runtime error later.";
          return false;
        }
        vec(i) = static_cast<int32>(shape.dim_size(i));
      }
    }
    defined_shape.push_back(t);
  }
  // All the inputs had known shapes so we can replace the node by constants
  // later in the rewrite.
  shape_replacement_map->insert({n, defined_shape});
  return true;
}

// If n's input has defined rank, inserts that rank as a Tensor in the
// shape_replacement_map.
bool MaybeReplaceRankOp(const Node* n,
                        const std::vector<PartialTensorShape>& input_shapes,
                        std::unordered_map<const Node*, std::vector<Tensor>>*
                            shape_replacement_map) {
  CHECK_EQ(input_shapes.size(), 1);
  if (input_shapes[0].unknown_rank()) {
    return false;
  }
  Tensor t(DT_INT32, TensorShape({}));
  t.scalar<int32>()() = input_shapes[0].dims();
  shape_replacement_map->insert({n, {t}});
  return true;
}

// If n's input has defined size, inserts that size as a Tensor in the
// shape_replacement_map.
bool MaybeReplaceSizeOp(const Node* n,
                        const std::vector<PartialTensorShape>& input_shapes,
                        std::unordered_map<const Node*, std::vector<Tensor>>*
                            shape_replacement_map) {
  CHECK_EQ(input_shapes.size(), 1);
  if (!input_shapes[0].IsFullyDefined()) {
    return false;
  }
  DataType op_type = n->output_type(0);
  Tensor t(op_type, TensorShape({}));
  int64_t size = input_shapes[0].num_elements();
  if (op_type == DT_INT64) {
    t.scalar<int64_t>()() = size;
  } else {
    CHECK(op_type == DT_INT32);
    if (size > INT_MAX) {
      VLOG(1) << "Node " << n->name() << " has input shape size " << size
              << " but type INT32, "
              << "so not replacing as constant: this will trigger a runtime "
                 "error later.";
      return false;
    }
    t.scalar<int32>()() = static_cast<int32>(size);
  }
  shape_replacement_map->insert({n, {t}});
  return true;
}

// If n is a shape Op (Shape, ShapeN, Rank, or Size) and its inputs have their
// shapes specified in shape_map, then adds to shape_replacement_map a mapping
// from n to a vector of Tensors, where Tensor k is the (statically known) value
// on n's kth output edge. shape_replacement_map has an entry for n iff
// MaybeReplaceShapeOp returns true, so it's valid to use
// shape_replacement_map->count(n) as a test to see if n is a shape op that can
// be replaced.
bool MaybeReplaceShapeOp(
    const Node* n,
    const std::unordered_map<string, std::vector<PartialTensorShape>>*
        shape_map,
    std::unordered_map<const Node*, std::vector<Tensor>>*
        shape_replacement_map) {
  if (shape_map == nullptr || !IsShapeOp(n)) {
    return false;
  }
  // input_shapes will contain the shapes of each of n's inputs.
  std::vector<PartialTensorShape> input_shapes;
  if (!ReadPartialShapesFromShapeMap(n, shape_map, &input_shapes)) {
    return false;
  }
  const auto& ts = n->type_string();
  if (ts == "Shape" || ts == "ShapeN") {
    if (!MaybeReplaceShapeOrShapeNOp(n, input_shapes, shape_replacement_map)) {
      return false;
    }
  } else if (ts == "Rank") {
    if (!MaybeReplaceRankOp(n, input_shapes, shape_replacement_map)) {
      return false;
    }
  } else {
    CHECK_EQ(ts, "Size");
    if (!MaybeReplaceSizeOp(n, input_shapes, shape_replacement_map)) {
      return false;
    }
  }
  return true;
}

// Returns true if n can be evaluated as constant. shape_map maps from
// nodes to the partially-known shapes of their outputs. consider, if
// non-null, is a predicate that returns whether a given (non-Const,
// non-Shape) node is eligible to be constant-propagated.
// shape_replacement_map is filled in with a vector of constant output
// tensors for constant-foldable shape nodes (Shape, ShapeN, Size, or Rank).
bool IsConstantFoldable(
    const Node* n,
    const std::unordered_map<string, std::vector<PartialTensorShape>>*
        shape_map,
    const std::function<bool(const Node*)>& consider,
    int64_t max_constant_size_in_bytes,
    std::unordered_map<const Node*, std::vector<Tensor>>*
        shape_replacement_map) {
  if (n->IsConstant()) {
    // Skip constant folding resources as they cannot be deep copied.
    return n->output_type(0) != DT_RESOURCE;
  }
  if (MaybeReplaceShapeOp(n, shape_map, shape_replacement_map)) {
    return true;
  }
  if (n->op_def().is_stateful()) {
    return false;
  }
  if (consider && !consider(n)) {
    return false;
  }
  if (shape_map != nullptr) {
    // We can skip the node if an output is known to be oversized.
    auto shape_it = shape_map->find(n->name());
    if (shape_it != shape_map->end()) {
      for (int64_t i = 0; i < shape_it->second.size(); ++i) {
        const auto& out_shape = shape_it->second[i];
        if (out_shape.IsFullyDefined() &&
            out_shape.num_elements() * DataTypeSize(n->output_type(i)) >
                max_constant_size_in_bytes) {
          return false;
        }
      }
    }
  }
  if (n->IsControlFlow() || n->IsSend() || n->IsRecv()) {
    return false;
  }
  // TODO(yuanbyu): For now disable these session handle operations.
  if (n->IsGetSessionHandle() || n->IsGetSessionTensor() ||
      n->IsDeleteSessionTensor()) {
    return false;
  }
  if (n->IsSource()) {
    return false;
  }
  if (n->IsSink()) {
    return false;
  }
  if (n->IsFakeParam()) {
    return false;
  }
  // Since constant-folding runs on the CPU, do not attempt to constant-fold
  // operators that have no CPU kernel. Also implies that we will not
  // constant-fold functions.
  // TODO(phawkins): allow constant-folding for functions; functions may
  // be arbitrarily expensive to execute.
  if (!KernelDefAvailable(DeviceType(DEVICE_CPU), n->def())) {
    return false;
  }
  // Do not constant fold nodes which will be allocated by ScopedAllocator.
  // This is because the constant-folding graph will not contain the
  // `_ScopedAllocator` node, and that is necessary to be able to run a node
  // that will use this allocator.
  if (n->attrs().Find(kScopedAllocatorAttrName) != nullptr) {
    VLOG(2) << "Skip node [" << n->DebugString()
            << "] for constant folding due to scoped allocator";
    return false;
  }
  return true;
}

// If n is eligible for constant-folding, adds it to nodes, and places its
// control dependencies and those transitively of its constant-foldable inputs
// into constant_control_deps. If n is a constant-foldable shape node (Shape,
// ShapeN, Rank, or Size), also puts its outputs into shape_replacement_map.
void ConsiderConstantFoldableNode(
    Node* n, const ConstantFoldingOptions& opts, std::vector<Node*>* nodes,
    std::unordered_map<const Node*, gtl::FlatSet<Node*>>* constant_control_deps,
    std::unordered_map<const Node*, std::vector<Tensor>>* shape_replacement_map,
    bool* internal_node_inserted) {
  if (IsConstantFoldable(n, opts.shape_map, opts.consider,
                         opts.max_constant_size_in_bytes,
                         shape_replacement_map)) {
    // A node is constant provided all of its non-control incoming Tensors come
    // from constant nodes, or it's a shape Op with statically known inputs in
    // which case it is placed in shape_replacement_map.
    //
    // We allow control dependencies from non-constant nodes to constant nodes,
    // but to preserve the graph structure we must transfer the control
    // dependency onto any constant replacement.
    bool all_parents_constant = true;
    for (const Edge* in : n->in_edges()) {
      // Allows non-constant -> constant control edges.
      if (!in->IsControlEdge() &&
          constant_control_deps->count(in->src()) == 0) {
        all_parents_constant = false;
        break;
      }
    }
    if (all_parents_constant || shape_replacement_map->count(n) != 0) {
      gtl::FlatSet<Node*>& control_deps = (*constant_control_deps)[n];
      for (const Edge* e : n->in_edges()) {
        if (constant_control_deps->count(e->src()) == 0) {
          // This branch is taken if the incoming edge is a control dependency,
          // in which case we want to add it to the dependencies being
          // accumulated for this node, or the incoming edge is not
          // constant. The latter may happen when n is a shape node and the
          // source has known shape. In that case add a control dependency from
          // the source node, since there was previously a data dependency and
          // we want to preserve sequencing constraints.
          if (!e->src()->IsSource()) {
            control_deps.insert(e->src());
          }
        } else {
          // If the parent has been accumulating control dependencies, add all
          // of its transitive control deps.
          const gtl::FlatSet<Node*>& parent_deps =
              (*constant_control_deps)[e->src()];
          control_deps.insert(parent_deps.begin(), parent_deps.end());
        }
      }
      nodes->push_back(n);
      if (!n->IsConstant()) {
        *internal_node_inserted = true;
      }
    }
  }
}

// Returns, in `nodes`, the constant-foldable nodes of `graph` in topological
// order. Populates `constant_control_deps` with the non-constant control
// dependencies of each constant node.
void FindConstantFoldableNodes(
    const Graph* graph, const ConstantFoldingOptions& opts,
    std::vector<Node*>* nodes,
    std::unordered_map<const Node*, gtl::FlatSet<Node*>>* constant_control_deps,
    std::unordered_map<const Node*, std::vector<Tensor>>*
        shape_replacement_map) {
  bool internal_node_inserted = false;
  // Walk the nodes in data flow order.
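  // The leave callback of ReverseDFS sees each node after all of its inputs,
  // and NodeComparatorName() makes the visit order deterministic across runs.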
  ReverseDFS(
      *graph, nullptr,
      [nodes, constant_control_deps, shape_replacement_map,
       &internal_node_inserted, &opts](Node* n) {
        ConsiderConstantFoldableNode(n, opts, nodes, constant_control_deps,
                                     shape_replacement_map,
                                     &internal_node_inserted);
      },
      NodeComparatorName());
  // If we have inserted just leaf level nodes, then there is nothing to fold.
  if (!internal_node_inserted) {
    nodes->clear();
    constant_control_deps->clear();
  }
}

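// A (node, output index) pair naming one output tensor of a node.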
typedef std::pair<Node*, int> NodeAndOutput;
376 
377 // Adds n to constant_graph which is being built up for subsequent evaluation of
378 // constant propagation. node_map is the mapping of nodes in the original graph
379 // to nodes in the constant graph. The value of an entry in node_map is a vector
380 // of nodes because a ShapeN node in the original graph is replaced by a vector
381 // of Constant nodes in the constant graph.
AddNodeToConstantGraph(Node * n,std::unordered_map<Node *,std::vector<Node * >> * node_map,Graph * constant_graph)382 void AddNodeToConstantGraph(
383     Node* n, std::unordered_map<Node*, std::vector<Node*>>* node_map,
384     Graph* constant_graph) {
385   std::vector<Node*>& added = (*node_map)[n];
386   added.push_back(constant_graph->CopyNode(n));
387   for (const Edge* in_edge : n->in_edges()) {
388     // Don't copy control edges to the constant graph.
389     if (!in_edge->IsControlEdge()) {
390       Node* in = in_edge->src();
391       auto it = node_map->find(in);
392       CHECK(it != node_map->end())
393           << n->DebugString() << " <-" << in->DebugString();
394       if (it->second.size() == 1) {
395         constant_graph->AddEdge(it->second[0], in_edge->src_output(), added[0],
396                                 in_edge->dst_input());
397       } else {
398         // The original source node had multiple outputs and was replaced by a
399         // vector of constants, so the edge comes from the 0th output of the kth
400         // added constant, rather than the kth output of the added node as in
401         // the standard case above.
402         constant_graph->AddEdge(it->second[in_edge->src_output()], 0, added[0],
403                                 in_edge->dst_input());
404       }
405     }
406   }
407 }
408 
// Replaces constant-foldable shape node n by a vector of constants in
// constant_graph, which is being built up for subsequent evaluation of constant
// propagation. node_map is the mapping of nodes in the original graph to nodes
// in the constant graph. The value of an entry in node_map is a vector of nodes
// because a ShapeN node in the original graph is replaced by a vector of
// Constant nodes in the constant graph.
void AddShapeNodeToConstantGraph(
    Node* n,
    const std::unordered_map<const Node*, std::vector<Tensor>>&
        shape_replacement_map,
    std::unordered_map<Node*, std::vector<Node*>>* node_map,
    const ConstantFoldNameGenerator& generate_new_name, Graph* constant_graph) {
  std::vector<Node*>& added = (*node_map)[n];
  const string& node_name = n->name();
  for (const Tensor& t : shape_replacement_map.at(n)) {
    auto builder =
        NodeDefBuilder(generate_new_name(constant_graph, node_name), "Const")
            .Attr("dtype", t.dtype())
            .Attr("value", t);
    NodeDef def;
    CHECK(builder.Finalize(&def).ok());
    Node* constant_node;
    CHECK(NodeBuilder(builder).Finalize(constant_graph, &constant_node).ok());
    added.push_back(constant_node);
  }
  // Don't copy incoming edges to shape nodes that are being replaced.
}

// Given the constant foldable nodes in 'nodes', returns a new graph 'g'. 'g'
// will contain copies of the nodes in 'nodes'. In addition, if there is an edge
// going from a node 'n' in 'nodes' to another node in 'orig_graph' but not in
// 'nodes', then 'tensors_to_fetch' will map the corresponding
// (copy of 'n', output index) in 'g' to the original ('n', output index) in
// 'orig_graph'.
Graph* GetConstantGraph(
    const Graph* orig_graph, const std::vector<Node*>& nodes,
    const std::unordered_map<const Node*, std::vector<Tensor>>&
        shape_replacement_map,
    std::map<NodeAndOutput, NodeAndOutput>* tensors_to_fetch,
    const ConstantFoldNameGenerator& generate_new_name) {
  Graph* constant_graph = new Graph(orig_graph->op_registry());
  std::unordered_map<Node*, std::vector<Node*>> node_map;
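  // Seed node_map with the source and sink so that edges touching them are
  // treated as already inside constant_graph rather than as tensors to fetch.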
  node_map[orig_graph->source_node()] = {constant_graph->source_node()};
  node_map[orig_graph->sink_node()] = {constant_graph->sink_node()};
  for (Node* n : nodes) {
    if (shape_replacement_map.count(n) == 0) {
      AddNodeToConstantGraph(n, &node_map, constant_graph);
    } else {
      AddShapeNodeToConstantGraph(n, shape_replacement_map, &node_map,
                                  generate_new_name, constant_graph);
    }
  }

  for (auto const& added_nodes : node_map) {
    for (const Edge* out_edge : added_nodes.first->out_edges()) {
      if (node_map.count(out_edge->dst()) == 0) {
        if (out_edge->IsControlEdge()) continue;
        if (added_nodes.second.size() == 1) {
          tensors_to_fetch->insert(
              {{added_nodes.second[0], out_edge->src_output()},
               {added_nodes.first, out_edge->src_output()}});
        } else {
          // The node had multiple outputs and was replaced by a
          // vector of constants, so the NodeAndOutput is the 0th
          // output of the kth added constant, rather than the kth
          // output of the added node as in the standard case above.
          tensors_to_fetch->insert(
              {{added_nodes.second[out_edge->src_output()], 0},
               {added_nodes.first, out_edge->src_output()}});
        }
      }
    }
  }

  return constant_graph;
}

// Replaces the identified Tensor in 'graph' by a 'Const' node with
// the value supplied in 'constant'. 'partition_device', if non-null,
// is the device where the graph executes. Returns true if the
// replacement was successful, false otherwise.
// 'control_deps' is the set of nodes that should be control predecessors of the
// new constant node.
bool ReplaceTensorWithConstant(
    Graph* graph, const Device* partition_device, NodeAndOutput tensor,
    const Tensor& constant, const gtl::FlatSet<Node*>& control_deps,
    int64_t max_constant_size_in_bytes,
    const ConstantFoldNameGenerator& generate_new_name) {
  // Be conservative when replacing a tensor with a constant, when not
  // running on CPU.
  // 1) Do not replace another constant.
  // 2) If the destination tensor or any other tensor from the same node is not
  // an int32 tensor, and has HOST_MEMORY constraint, do not replace it.
  // 3) If the destination tensor or any other tensor from the same node is an
  // int32 tensor, and has DEVICE_MEMORY constraint, do not replace it.
  // 4) If the size of the constant in bytes is too large (>
  // max_constant_size_in_bytes), do not replace it. This prevents the size of
  // the Graph from growing too large.
  // 5) If the constant op created does not have a kernel implementation
  // for the device, do not use it.
  // TODO(keveman): Consider adding a new constant op that has a kernel
  // implementation for all types, but with HostMemory constraint on its
  // output.
  if (tensor.first->IsConstant()) {
    return false;
  }
  DeviceType device_type = partition_device
                               ? DeviceType{partition_device->device_type()}
                               : DEVICE_CPU;
  if (partition_device && device_type != DEVICE_CPU) {
    MemoryTypeVector input_mvec;
    MemoryTypeVector output_mvec;
    if (!MemoryTypesForNode(graph->op_registry(), device_type,
                            tensor.first->def(), &input_mvec, &output_mvec)
             .ok()) {
      return false;
    }
    for (int i = 0; i < output_mvec.size(); i++) {
      MemoryType memory_type = output_mvec[i];
      bool is_int32 = tensor.first->output_type(i) == DT_INT32;
      if ((memory_type == HOST_MEMORY && !is_int32) ||
          (memory_type == DEVICE_MEMORY && is_int32)) {
        return false;
      }
    }
  }
  if (constant.TotalBytes() > max_constant_size_in_bytes) {
    return false;
  }

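  // Collect the data edges that currently consume the output being replaced;
  // they are rewired to the new constant node below.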
  Node* n = tensor.first;
  std::vector<const Edge*> edges_to_remove;
  for (const Edge* out_edge : n->out_edges()) {
    if (out_edge->src_output() == tensor.second) {
      edges_to_remove.push_back(out_edge);
    }
  }
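  // Build the replacement Const node; give up if its NodeDef cannot be
  // finalized or if no Const kernel is registered for the device type.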
  const string& node_name = n->name();
  Node* constant_node;
  auto builder = NodeDefBuilder(generate_new_name(graph, node_name), "Const")
                     .Attr("dtype", constant.dtype())
                     .Attr("value", constant);
  if (partition_device) {
    builder.Device(partition_device->name());
  }
  NodeDef def;
  if (!builder.Finalize(&def).ok()) {
    return false;
  }
  const KernelDef* kdef;
  if (!FindKernelDef(device_type, def, &kdef, nullptr).ok()) {
    return false;
  }

  VLOG(1) << "Replacing " << tensor.first->name() << " :: " << tensor.second
          << " with a constant";

  if (!NodeBuilder(builder).Finalize(graph, &constant_node).ok()) {
    return false;
  }
  for (auto edge : edges_to_remove) {
    graph->AddEdge(constant_node, 0, edge->dst(), edge->dst_input());
    graph->RemoveEdge(edge);
  }
  if (control_deps.empty()) {
    graph->AddControlEdge(graph->source_node(), constant_node);
  } else {
    for (Node* node : control_deps) {
      graph->AddControlEdge(node, constant_node);
    }
  }
  if (partition_device) {
    constant_node->set_assigned_device_name(partition_device->name());
  }
  return true;
}

}  // namespace

Status ConstantFold(const ConstantFoldingOptions& opts,
                    FunctionLibraryRuntime* function_library, Env* env,
                    const Device* partition_device, Graph* graph,
                    bool* was_mutated) {
  // TensorFlow flushes denormals to zero and rounds to nearest, so we do
  // the same here.
  port::ScopedFlushDenormal flush;
  port::ScopedSetRound round(FE_TONEAREST);

  DumpGraph("Before", graph);

  ConstantFoldNameGenerator generate_new_name = opts.generate_new_name;
  std::atomic_int_fast64_t constant_unique_id{0};
  if (generate_new_name == nullptr) {
    generate_new_name = [&constant_unique_id](Graph* graph, string old_name) {
      return strings::StrCat(graph->NewName(old_name), "__cf__",
                             constant_unique_id.fetch_add(1));
    };
  }

  std::vector<Node*> constant_foldable_nodes;
  std::unordered_map<const Node*, gtl::FlatSet<Node*>> constant_control_deps;
  std::unordered_map<const Node*, std::vector<Tensor>> shape_replacement_map;
  FindConstantFoldableNodes(graph, opts, &constant_foldable_nodes,
                            &constant_control_deps, &shape_replacement_map);
  if (constant_foldable_nodes.empty()) {
    VLOG(1) << "No constant foldable nodes found";
    *was_mutated = false;
    // This is not an error, so return the status as OK.
    return OkStatus();
  }

  std::map<NodeAndOutput, NodeAndOutput> tensors_to_fetch;
  std::unique_ptr<Graph> constant_graph(
      GetConstantGraph(graph, constant_foldable_nodes, shape_replacement_map,
                       &tensors_to_fetch, generate_new_name));
  DumpGraph("Constant graph", constant_graph.get());

  if (tensors_to_fetch.empty()) {
    VLOG(1) << "No constant nodes found that feed into the original graph.";
    *was_mutated = false;
    // This is not an error, so return the status as OK.
    return OkStatus();
  }
  VLOG(1) << "Constant foldable " << constant_graph->num_node_ids() << " : "
          << graph->num_node_ids();

  std::vector<string> tensors_to_fetch_names;
  std::vector<NodeAndOutput> tensors_to_replace;
  // Sorting the nodes based on the name gives us a stable ordering between runs
  // for the same graph.
  std::vector<std::pair<NodeAndOutput, NodeAndOutput>> tensors_to_fetch_sorted(
      tensors_to_fetch.begin(), tensors_to_fetch.end());
  std::sort(tensors_to_fetch_sorted.begin(), tensors_to_fetch_sorted.end(),
            [](const std::pair<NodeAndOutput, NodeAndOutput>& n1,
               const std::pair<NodeAndOutput, NodeAndOutput>& n2) {
              return std::tie(n1.first.first->name(), n1.first.second) <
                     std::tie(n2.first.first->name(), n2.first.second);
            });
  for (auto n : tensors_to_fetch_sorted) {
    tensors_to_fetch_names.push_back(
        strings::StrCat(n.first.first->name(), ":", n.first.second));
    tensors_to_replace.push_back(n.second);
  }

  auto graph_runner = std::unique_ptr<GraphRunner>(new GraphRunner(env));
  // Evaluate the constant foldable nodes.
  std::vector<Tensor> outputs;
  auto delete_tensors = gtl::MakeCleanup([&graph_runner, &outputs] {
    // Output tensors need to be cleared before the GraphRunner is deleted.
    outputs.clear();
    graph_runner.reset(nullptr);
  });

  Status s =
      graph_runner->Run(constant_graph.get(), function_library,
                        {} /* inputs */, tensors_to_fetch_names, &outputs);
  if (!s.ok()) {
    VLOG(1) << "Could not fetch constants: " << s;
    *was_mutated = false;
    return s;
  }

  // Fetch the constant tensors and replace the corresponding tensors in the
  // original graph with those constants.
  int32_t num_nodes_replaced = 0;
  for (size_t c = 0; c < outputs.size(); ++c) {
    const gtl::FlatSet<Node*>& control_deps =
        constant_control_deps[tensors_to_replace[c].first];
    if (ReplaceTensorWithConstant(
            graph, partition_device, tensors_to_replace[c], outputs[c],
            control_deps, opts.max_constant_size_in_bytes, generate_new_name)) {
      ++num_nodes_replaced;
    }
  }

  DumpGraph("After", graph);

  *was_mutated = (num_nodes_replaced > 0);
  return OkStatus();
}

}  // namespace tensorflow