1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #define EIGEN_USE_THREADS
17
18 #include "tensorflow/core/grappler/optimizers/constant_folding.h"
19
20 #include <cmath>
21
22 #include "absl/strings/string_view.h"
23 #include "absl/strings/substitute.h"
24 #include "tensorflow/core/framework/allocator.h"
25 #include "tensorflow/core/framework/attr_value.pb.h"
26 #include "tensorflow/core/framework/function.pb.h"
27 #include "tensorflow/core/framework/node_def.pb.h"
28 #include "tensorflow/core/framework/op.h"
29 #include "tensorflow/core/framework/op_def.pb.h"
30 #include "tensorflow/core/framework/tensor.pb.h"
31 #include "tensorflow/core/framework/tensor_shape.pb.h"
32 #include "tensorflow/core/framework/tensor_util.h"
33 #include "tensorflow/core/framework/types.h"
34 #include "tensorflow/core/framework/types.pb.h"
35 #include "tensorflow/core/framework/versions.pb.h"
36 #include "tensorflow/core/grappler/clusters/cluster.h"
37 #include "tensorflow/core/grappler/costs/graph_properties.h"
38 #include "tensorflow/core/grappler/grappler_item.h"
39 #include "tensorflow/core/grappler/op_types.h"
40 #include "tensorflow/core/grappler/optimizers/evaluation_utils.h"
41 #include "tensorflow/core/grappler/utils.h"
42 #include "tensorflow/core/grappler/utils/symbolic_shapes.h"
43 #include "tensorflow/core/lib/core/errors.h"
44 #include "tensorflow/core/lib/core/stringpiece.h"
45 #include "tensorflow/core/lib/gtl/cleanup.h"
46 #include "tensorflow/core/lib/gtl/inlined_vector.h"
47 #include "tensorflow/core/lib/strings/numbers.h"
48 #include "tensorflow/core/lib/strings/strcat.h"
49 #include "tensorflow/core/platform/cpu_info.h"
50 #include "tensorflow/core/platform/denormal.h"
51 #include "tensorflow/core/platform/env.h"
52 #include "tensorflow/core/platform/setround.h"
53 #include "tensorflow/core/platform/tensor_coding.h"
54 #include "tensorflow/core/public/version.h"
55 #include "tensorflow/core/util/bcast.h"
56 #include "tensorflow/core/util/saved_tensor_slice_util.h"
57
58 namespace tensorflow {
59 namespace grappler {
60 using TensorVector = gtl::InlinedVector<TensorValue, 4>;
61
62 // We only fold/materialize constants smaller than 100kB.
63 const int64_t kMaxConstantSize = 100 * 1024;
64
65 namespace {
66 template <typename T>
AllValuesAre(const TensorProto & proto,const T & value)67 bool AllValuesAre(const TensorProto& proto, const T& value) {
68 Tensor tensor;
69 if (!tensor.FromProto(proto)) {
70 return false;
71 }
72 auto values = tensor.flat<T>();
73 for (int i = 0; i < tensor.NumElements(); ++i) {
74 if (values(i) != value) {
75 return false;
76 }
77 }
78 return true;
79 }
80
81 // Add new_input as a control input to node if it does not already depend on it.
82 // TODO(rmlarsen): Move the following two utility functions to utils.{h,cc} and
83 // clean up code that should be using them.
MaybeAddControlInput(const string & ctrl_input,NodeDef * node,GraphDef * graph,NodeMap * node_map)84 bool MaybeAddControlInput(const string& ctrl_input, NodeDef* node,
85 GraphDef* graph, NodeMap* node_map) {
86 bool already_exists = false;
87 for (const string& input : node->input()) {
88 if (input == ctrl_input || AsControlDependency(input) == ctrl_input) {
89 already_exists = true;
90 break;
91 }
92 }
93 if (!already_exists) {
94 const string ctrl_dep =
95 ConstantFolding::AddControlDependency(ctrl_input, graph, node_map);
96 node->add_input(ctrl_dep);
97 node_map->AddOutput(NodeName(ctrl_input), node->name());
98 }
99 return !already_exists;
100 }
101
102 // Remove old_input as a control input to node.
MaybeRemoveControlInput(const string & old_input,NodeDef * node,GraphDef * graph,NodeMap * node_map)103 bool MaybeRemoveControlInput(const string& old_input, NodeDef* node,
104 GraphDef* graph, NodeMap* node_map) {
105 bool removed_input = false;
106 bool update_node_map = true;
107 const string old_input_ctrl_dep = AsControlDependency(NodeName(old_input));
108 for (int i = 0; i < node->input_size(); ++i) {
109 const string& input = node->input(i);
110 if (old_input_ctrl_dep == input) {
111 if (IsControlInput(input)) {
112 node->mutable_input()->SwapElements(i, node->input_size() - 1);
113 node->mutable_input()->RemoveLast();
114 removed_input = true;
115 } else {
116 // There is a non-control input from the same node.
117 // Don't remove the output from the NodeMap.
118 update_node_map = false;
119 }
120 }
121 }
122 if (update_node_map) {
123 node_map->RemoveOutput(NodeName(old_input), node->name());
124 }
125 return removed_input;
126 }
127
HasTPUAttributes(const NodeDef & node)128 bool HasTPUAttributes(const NodeDef& node) {
129 AttrSlice attrs(node);
130 for (const auto& attr : attrs) {
131 if (attr.first.find("_tpu_") != attr.first.npos) {
132 return true;
133 }
134 }
135 return false;
136 }
137
138 template <typename T>
PackedValuesNotEqual(T a,T b)139 bool PackedValuesNotEqual(T a, T b) {
140 return a != b;
141 }
142
143 template <>
PackedValuesNotEqual(float a,float b)144 bool PackedValuesNotEqual(float a, float b) {
145 return reinterpret_cast<int32_t&>(a) != reinterpret_cast<int32_t&>(b);
146 }
147
148 template <>
PackedValuesNotEqual(double a,double b)149 bool PackedValuesNotEqual(double a, double b) {
150 return reinterpret_cast<int64_t&>(a) != reinterpret_cast<int64_t&>(b);
151 }
152
QuantizedTypeMinAsFloat(DataType data_type)153 float QuantizedTypeMinAsFloat(DataType data_type) {
154 switch (data_type) {
155 case DT_QINT8:
156 return Eigen::NumTraits<qint8>::lowest();
157 case DT_QUINT8:
158 return Eigen::NumTraits<quint8>::lowest();
159 case DT_QINT16:
160 return Eigen::NumTraits<qint16>::lowest();
161 case DT_QUINT16:
162 return Eigen::NumTraits<quint16>::lowest();
163 case DT_QINT32:
164 return Eigen::NumTraits<qint32>::lowest();
165 default:
166 return 0.0f;
167 }
168 }
169
QuantizedTypeMaxAsFloat(DataType data_type)170 float QuantizedTypeMaxAsFloat(DataType data_type) {
171 switch (data_type) {
172 case DT_QINT8:
173 return Eigen::NumTraits<qint8>::highest();
174 case DT_QUINT8:
175 return Eigen::NumTraits<quint8>::highest();
176 case DT_QINT16:
177 return Eigen::NumTraits<qint16>::highest();
178 case DT_QUINT16:
179 return Eigen::NumTraits<quint16>::highest();
180 case DT_QINT32:
181 return Eigen::NumTraits<qint32>::highest();
182 default:
183 return 0.0f;
184 }
185 }
186
187 } // namespace
188
ConstantFolding(RewriterConfig::Toggle opt_level,DeviceBase * cpu_device,bool disable_compressed_tensor_optimization,bool fold_quantization_emulation)189 ConstantFolding::ConstantFolding(RewriterConfig::Toggle opt_level,
190 DeviceBase* cpu_device,
191 bool disable_compressed_tensor_optimization,
192 bool fold_quantization_emulation)
193 : opt_level_(opt_level),
194 cpu_device_(cpu_device),
195 disable_compressed_tensor_optimization_(
196 disable_compressed_tensor_optimization),
197 fold_quantization_emulation_(fold_quantization_emulation) {
198 resource_mgr_.reset(new ResourceMgr());
199 }
200
ConstantFolding(DeviceBase * cpu_device,bool disable_compressed_tensor_optimization,bool fold_quantization_ops)201 ConstantFolding::ConstantFolding(DeviceBase* cpu_device,
202 bool disable_compressed_tensor_optimization,
203 bool fold_quantization_ops)
204 : ConstantFolding(RewriterConfig::ON, cpu_device,
205 disable_compressed_tensor_optimization,
206 fold_quantization_ops) {}
207
208 // static
AddControlDependency(const string & input_name,GraphDef * graph,NodeMap * node_map)209 string ConstantFolding::AddControlDependency(const string& input_name,
210 GraphDef* graph,
211 NodeMap* node_map) {
212 if (IsControlInput(input_name)) {
213 return input_name;
214 }
215 const NodeDef* node = node_map->GetNode(input_name);
216 // Sanity check for missing node.
217 if (!node) {
218 return input_name;
219 }
220 if (!IsSwitch(*node)) {
221 return AsControlDependency(*node);
222 } else {
223 // We can't anchor control dependencies directly on the switch node: unlike
224 // other nodes only one of the outputs of the switch node will be generated
225 // when the switch node is executed, and we need to make sure the control
226 // dependency is only triggered when the corresponding output is triggered.
227 // We start by looking for an identity node connected to the output of the
228 // switch node, and use it to anchor the control dependency.
229 for (const NodeDef* output : node_map->GetOutputs(node->name())) {
230 if (IsIdentity(*output) || IsIdentityNSingleInput(*output)) {
231 if (IsSameInput(output->name(), input_name)) {
232 return AsControlDependency(*output);
233 }
234 }
235 }
236 // We haven't found an existing node where we can anchor the control
237 // dependency: add a new identity node.
238 int port = 0;
239 string ctrl_dep_name = ParseNodeName(input_name, &port);
240 strings::StrAppend(&ctrl_dep_name, "_", port);
241 ctrl_dep_name = AddPrefixToNodeName(ctrl_dep_name, kConstantFoldingCtrl);
242 const DataType output_type = node->attr().at("T").type();
243
244 NodeDef* added_node = node_map->GetNode(ctrl_dep_name);
245 if (added_node == nullptr) {
246 added_node = graph->add_node();
247 added_node->set_name(ctrl_dep_name);
248 added_node->set_op("Identity");
249 added_node->set_device(node->device());
250
251 (*added_node->mutable_attr())["T"].set_type(output_type);
252 *added_node->add_input() = input_name;
253 node_map->AddNode(added_node->name(), added_node);
254 node_map->AddOutput(node->name(), added_node->name());
255 }
256 return AsControlDependency(*added_node);
257 }
258 }
259
260 // Forward inputs at the given indices to outputs and add a control dependency
261 // on node.
ForwardInputs(NodeDef * node,absl::Span<const int> inputs_to_forward)262 bool ConstantFolding::ForwardInputs(NodeDef* node,
263 absl::Span<const int> inputs_to_forward) {
264 for (int input_idx : inputs_to_forward) {
265 if (input_idx < 0 || input_idx >= node->input_size()) {
266 return false;
267 }
268 }
269
270 const auto& tmp = node_map_->GetOutputs(node->name());
271 const std::vector<NodeDef*> consumers(tmp.begin(), tmp.end());
272 bool updated_graph = false;
273 for (int input_idx : inputs_to_forward) {
274 const string& input = node->input(input_idx);
275 if (IsControlInput(input) && consumers.size() > 1) {
276 continue;
277 }
278 const NodeDef* input_node = node_map_->GetNode(NodeName(input));
279 if (input_node == nullptr) {
280 LOG(ERROR) << "Bad input: " << input;
281 break;
282 }
283 // Update each consumer.
284 for (NodeDef* consumer : consumers) {
285 bool add_dep = false;
286 for (int consumer_input_idx = 0;
287 consumer_input_idx < consumer->input_size(); ++consumer_input_idx) {
288 const string& consumer_input = consumer->input(consumer_input_idx);
289 if (IsControlInput(consumer_input)) {
290 break;
291 }
292 // It is illegal to add control dependencies to _Retval nodes, so we
293 // can't bypass value producing `node` and forward inputs to `consumer`.
294 if (IsRetval(*consumer)) {
295 break;
296 }
297 int output_idx;
298 const string input_node_name =
299 ParseNodeName(consumer_input, &output_idx);
300 if (input_node_name == node->name() && output_idx == input_idx) {
301 consumer->set_input(consumer_input_idx, input);
302 // We will keep the input from the node through a control
303 // dependency, so we only need to add the consumer as an output
304 // for the input node.
305 node_map_->AddOutput(NodeName(input), consumer->name());
306 add_dep = true;
307 }
308 }
309 if (add_dep) {
310 consumer->add_input(AsControlDependency(node->name()));
311 updated_graph = true;
312 }
313 }
314 }
315
316 if (updated_graph) {
317 for (NodeDef* consumer : consumers) {
318 DedupControlInputs(consumer);
319 }
320 }
321 return updated_graph;
322 }
323
324 // Puts the given value into the tensor at the given "flat" index.
PutValueIntoTensor(const int64_t value,const DataType & type,const int index,Tensor * tensor)325 static Status PutValueIntoTensor(const int64_t value, const DataType& type,
326 const int index, Tensor* tensor) {
327 if (type == DT_INT32) {
328 if (value >= INT_MAX) {
329 return Status(error::INVALID_ARGUMENT, "int32 overflow");
330 }
331 tensor->flat<int32>()(index) = static_cast<int32>(value);
332 } else {
333 tensor->flat<int64_t>()(index) = value;
334 }
335 return OkStatus();
336 }
337
338 // Writes the given tensor shape into the given tensor.
339 // Op is assumed to be Shape, ShapeN, Size or Rank.
ConvertShapeToConstant(const string & op,const DataType & type,const PartialTensorShape & shp,Tensor * tensor)340 static Status ConvertShapeToConstant(const string& op, const DataType& type,
341 const PartialTensorShape& shp,
342 Tensor* tensor) {
343 if (op == "Shape" || op == "ShapeN") {
344 *tensor = Tensor(type, TensorShape({shp.dims()}));
345 for (int i = 0; i < shp.dims(); ++i) {
346 TF_RETURN_IF_ERROR(PutValueIntoTensor(shp.dim_size(i), type, i, tensor));
347 }
348 } else if (op == "Size") {
349 int64_t size = 1;
350 for (int i = 0; i < shp.dims(); ++i) {
351 size *= shp.dim_size(i);
352 }
353 *tensor = Tensor(type, TensorShape({}));
354 TF_RETURN_IF_ERROR(PutValueIntoTensor(size, type, 0, tensor));
355 } else {
356 CHECK_EQ(op, "Rank");
357 *tensor = Tensor(type, TensorShape({}));
358 TF_RETURN_IF_ERROR(PutValueIntoTensor(shp.dims(), type, 0, tensor));
359 }
360 return OkStatus();
361 }
362
363 // TODO(rmlarsen): Perhaps we should move this to the GraphOptimizer base class.
OptimizedNodeExists(const NodeDef & node,StringPiece suffix) const364 bool ConstantFolding::OptimizedNodeExists(const NodeDef& node,
365 StringPiece suffix) const {
366 return node_map_->NodeExists(OptimizedNodeName(node, suffix));
367 }
368
OptimizedNodeName(const NodeDef & node,StringPiece suffix) const369 string ConstantFolding::OptimizedNodeName(const NodeDef& node,
370 StringPiece suffix) const {
371 return AddPrefixToNodeName(strings::StrCat(node.name(), suffix),
372 kConstantFoldingConst);
373 }
374
IsReallyConstant(const NodeDef & node) const375 bool ConstantFolding::IsReallyConstant(const NodeDef& node) const {
376 if (!IsConstant(node)) {
377 return false;
378 }
379 // If the node is fed it's not constant anymore.
380 return feed_nodes_.find(node.name()) == feed_nodes_.end();
381 }
382
383 // TODO(rmlarsen): Refactor to shared util.
GetTensorFromConstNode(const string & node_name_or_input,Tensor * tensor)384 bool ConstantFolding::GetTensorFromConstNode(const string& node_name_or_input,
385 Tensor* tensor) {
386 const NodeDef* node = node_map_->GetNode(node_name_or_input);
387 return node != nullptr && IsReallyConstant(*node) &&
388 CheckAttrExists(*node, "value").ok() &&
389 tensor->FromProto(node->attr().at("value").tensor());
390 }
391
392 // Materialize the shapes using constants whenever possible.
MaterializeShapes(const GraphProperties & properties)393 Status ConstantFolding::MaterializeShapes(const GraphProperties& properties) {
394 // We may add some nodes to the graph to encode control dependencies and hold
395 // the materialized shapes: there is no need to process these added nodes, so
396 // only iterate over the nodes of the input graph.
397 const int node_count = graph_->node_size();
398 for (int node_idx = 0; node_idx < node_count; ++node_idx) {
399 NodeDef* node = graph_->mutable_node(node_idx);
400 const string op = node->op();
401 if (op != "Shape" && op != "Size" && op != "Rank" && op != "ShapeN" &&
402 op != "TensorArraySizeV3") {
403 continue;
404 }
405 const std::vector<OpInfo::TensorProperties>& output =
406 properties.GetOutputProperties(node->name());
407 const std::vector<OpInfo::TensorProperties>& input =
408 properties.GetInputProperties(node->name());
409 if (input.empty() || output.empty()) {
410 continue;
411 }
412
413 if (op == "Shape" || op == "Size" || op == "Rank") {
414 CHECK_EQ(1, output.size());
415 CHECK_EQ(1, input.size());
416
417 const DataType type = output[0].dtype();
418 CHECK(type == DT_INT32 || type == DT_INT64);
419 const PartialTensorShape shape(input[0].shape());
420
421 if ((op != "Rank" && !shape.IsFullyDefined()) ||
422 (op == "Rank" && shape.unknown_rank())) {
423 continue;
424 }
425
426 Tensor constant_value(type);
427 if (!ConvertShapeToConstant(op, type, shape, &constant_value).ok()) {
428 continue;
429 }
430
431 // TODO(rmlarsen): Remove this workaround for b/150861569
432 // The bug involves an expression of the form Shape(ExpandDims(x)
433 // with an incorrectly inferred zero-size first dimension.
434 if (op == "Shape") {
435 if (shape.dims() > 0 && shape.dim_size(0) == 0) continue;
436 }
437
438 // Repurpose the existing node to be the constant.
439 // Device placement is preserved.
440 graph_modified_ = true;
441 node->set_op("Const");
442 EraseRegularNodeAttributes(node);
443 (*node->mutable_attr())["dtype"].set_type(type);
444 constant_value.AsProtoTensorContent(
445 (*node->mutable_attr())["value"].mutable_tensor());
446
447 // Turn the data input into a control dependency: this is needed to
448 // ensure that the constant value will only be run in the
449 // cases where the shape/rank/size would have been run in
450 // the original graph.
451 string ctrl_dep =
452 AddControlDependency(node->input(0), graph_, node_map_.get());
453 node_map_->UpdateInput(node->name(), node->input(0), ctrl_dep);
454 node->set_input(0, ctrl_dep);
455 // Done with the Shape/Size/Rank node, move to the next node.
456 continue;
457 }
458
459 if (op == "TensorArraySizeV3") {
460 const NodeDef* array = CHECK_NOTNULL(node_map_->GetNode(node->input(0)));
461 if (array->input_size() == 0 ||
462 (array->attr().count("dynamic_size") != 0 &&
463 array->attr().at("dynamic_size").b())) {
464 continue;
465 }
466 const NodeDef* array_size =
467 CHECK_NOTNULL(node_map_->GetNode(array->input(0)));
468 if (IsReallyConstant(*array_size)) {
469 // Don't materialize 0 sizes to avoid triggering incorrect static
470 // checks. A 0 sized array that can't grow isn't useful anyway.
471 if (array_size->attr().count("value") == 0) {
472 continue;
473 }
474 const TensorProto& raw_val = array_size->attr().at("value").tensor();
475 if (raw_val.dtype() != DT_INT32) {
476 continue;
477 }
478 Tensor value(raw_val.dtype(), raw_val.tensor_shape());
479 if (!value.FromProto(raw_val)) {
480 continue;
481 }
482 if (value.flat<int32>()(0) == 0) {
483 continue;
484 }
485
486 graph_modified_ = true;
487 node->set_op("Const");
488 *node->mutable_attr() = array_size->attr();
489 node->set_input(0, AsControlDependency(NodeName(node->input(0))));
490 node->set_input(1, AddControlDependency(NodeName(node->input(1)),
491 graph_, node_map_.get()));
492 }
493 continue;
494 }
495
496 // Handle ShapeN materialization case.
497 // It's possible that not all input tensors have known shapes.
498 CHECK_EQ(op, "ShapeN");
499 CHECK_EQ(input.size(), output.size());
500 const NodeDef* const shape_n_node = node;
501 for (int port_idx = 0, idx_limit = output.size(); port_idx < idx_limit;
502 ++port_idx) {
503 const DataType type = output[port_idx].dtype();
504 CHECK(type == DT_INT32 || type == DT_INT64);
505 const PartialTensorShape shape(input[port_idx].shape());
506 if (!shape.IsFullyDefined()) {
507 continue;
508 }
509 Tensor constant_value(type);
510 auto status = ConvertShapeToConstant(op, type, shape, &constant_value);
511 if (!status.ok()) {
512 continue;
513 }
514
515 // We make a copy because we mutate the nodes.
516 auto fanouts = node_map_->GetOutputs(shape_n_node->name());
517 // Find all nodes consuming this shape and connect them through the new
518 // constant node instead.
519 for (NodeDef* output : fanouts) {
520 // Track whether there are any direct edges left between shape_n_node
521 // and this output node after the transformation.
522 bool direct_edges_exist = false;
523 for (int k = 0; k < output->input_size(); ++k) {
524 int port;
525 const string node_name = ParseNodeName(output->input(k), &port);
526 if (node_name == shape_n_node->name() && port == port_idx) {
527 // Create a const node as ShapeN's output if not already.
528 const string const_name = OptimizedNodeName(
529 *shape_n_node, strings::StrCat("-matshapes-", port_idx));
530 if (node_map_->GetNode(const_name) == nullptr) {
531 NodeDef* added_node = graph_->add_node();
532 added_node->set_name(const_name);
533 added_node->set_op("Const");
534 added_node->set_device(shape_n_node->device());
535 node_map_->AddNode(added_node->name(), added_node);
536 (*added_node->mutable_attr())["dtype"].set_type(type);
537 constant_value.AsProtoTensorContent(
538 (*added_node->mutable_attr())["value"].mutable_tensor());
539 // We add a control dependency to the original ShapeN node,
540 // so that the node will only be run if all inputs of the
541 // original ShapeN node are run.
542 string ctrl_dep = AddControlDependency(shape_n_node->name(),
543 graph_, node_map_.get());
544 *added_node->add_input() = ctrl_dep;
545 node_map_->AddOutput(NodeName(ctrl_dep), added_node->name());
546 }
547 *output->mutable_input(k) = const_name;
548 node_map_->AddOutput(const_name, output->name());
549 graph_modified_ = true;
550 }
551 if (node_name == shape_n_node->name() && port != port_idx) {
552 direct_edges_exist = true;
553 }
554 }
555 if (!direct_edges_exist) {
556 node_map_->RemoveOutput(node->name(), output->name());
557 }
558 }
559 }
560 }
561
562 return OkStatus();
563 }
564
565 namespace {
ExtractShape(const NodeDef & shape_node,const GraphProperties & properties,BCast::Vec * shape,int64_t * min_id)566 bool ExtractShape(const NodeDef& shape_node, const GraphProperties& properties,
567 BCast::Vec* shape, int64_t* min_id) {
568 if (shape_node.op() == "Shape") {
569 const std::vector<OpInfo::TensorProperties>& prop1 =
570 properties.GetInputProperties(shape_node.name());
571 if (prop1.size() != 1) {
572 return false;
573 }
574 const TensorShapeProto& shp = prop1[0].shape();
575 if (shp.unknown_rank()) {
576 return false;
577 }
578 for (const auto& dim : shp.dim()) {
579 shape->push_back(dim.size());
580 *min_id = std::min<int64_t>(*min_id, dim.size());
581 }
582 } else {
583 if (shape_node.attr().count("value") == 0) {
584 return false;
585 }
586 const TensorProto& raw_val = shape_node.attr().at("value").tensor();
587 if (raw_val.dtype() != DT_INT64 && raw_val.dtype() != DT_INT32) {
588 return false;
589 }
590 Tensor value(raw_val.dtype(), raw_val.tensor_shape());
591 if (!value.FromProto(raw_val)) {
592 return false;
593 }
594 for (int j = 0; j < value.NumElements(); ++j) {
595 if (raw_val.dtype() == DT_INT64) {
596 shape->push_back(value.vec<int64_t>()(j));
597 } else {
598 shape->push_back(value.vec<int>()(j));
599 }
600 }
601 }
602 return true;
603 }
604 } // namespace
605
MaterializeBroadcastGradientArgs(const NodeDef & node,const GraphProperties & properties)606 Status ConstantFolding::MaterializeBroadcastGradientArgs(
607 const NodeDef& node, const GraphProperties& properties) {
608 const NodeDef* shape_node1 = node_map_->GetNode(node.input(0));
609 const NodeDef* shape_node2 = node_map_->GetNode(node.input(1));
610 if (shape_node1 == nullptr ||
611 (shape_node1->op() != "Shape" && !IsReallyConstant(*shape_node1)) ||
612 shape_node2 == nullptr ||
613 (shape_node2->op() != "Shape" && !IsReallyConstant(*shape_node2))) {
614 return OkStatus();
615 }
616
617 // Don't optimize this again if it was already optimized and folded.
618 if (OptimizedNodeExists(node, "-folded-1") ||
619 OptimizedNodeExists(node, "-folded-2")) {
620 return OkStatus();
621 }
622 int64_t min_id = 0;
623 BCast::Vec shape1;
624 if (!ExtractShape(*shape_node1, properties, &shape1, &min_id)) {
625 return OkStatus();
626 }
627 BCast::Vec shape2;
628 if (!ExtractShape(*shape_node2, properties, &shape2, &min_id)) {
629 return OkStatus();
630 }
631 // A value of -1 means we don't known anything about the dimension. Replace
632 // the -1 values with unique dimension ids since we don't want two '-1'
633 // dimensions to be considered equal.
634 for (auto& id : shape1) {
635 if (id == -1) {
636 id = --min_id;
637 }
638 }
639 for (auto& id : shape2) {
640 if (id == -1) {
641 id = --min_id;
642 }
643 }
644
645 // Beware: the reduction dimensions computed by the BCast class are valid iff
646 // we assume that two distinct symbolic dimensions can't be equal and a
647 // symbolic dimension can't be equal to 1. This is often but not always true,
648 // so to make this optimization safe we filter out these cases.
649 const int common_dims = std::min(shape1.size(), shape2.size());
650 for (int i = 0; i < common_dims; ++i) {
651 if (shape1[i] >= 0 && shape2[i] >= 0) {
652 continue;
653 }
654 if (shape1[i] != shape2[i]) {
655 // We're either dealing with 2 different symbolic dimensions or a symbolic
656 // and a know dimensions. We can't be sure whether both are equal or not,
657 // so we can't be sure whether we'll be broadcasting or not.
658 return OkStatus();
659 }
660 }
661 // These extra dims could be equal to 1, in which case there is no
662 // broadcasting. It could also be greater than 1, in which case there would
663 // be broadcasting. Since we don't know, we'll just punt.
664 for (int i = common_dims, end = shape1.size(); i < end; ++i) {
665 if (shape1[i] < 0) {
666 return OkStatus();
667 }
668 }
669 for (int i = common_dims, end = shape2.size(); i < end; ++i) {
670 if (shape2[i] < 0) {
671 return OkStatus();
672 }
673 }
674
675 BCast bcast(shape1, shape2);
676 if (!bcast.IsValid()) {
677 return OkStatus();
678 }
679
680 BCast::Vec reduce_dims[2];
681 reduce_dims[0] = bcast.grad_x_reduce_idx();
682 reduce_dims[1] = bcast.grad_y_reduce_idx();
683
684 TF_RETURN_IF_ERROR(CheckAttrExists(node, "T"));
685 const DataType type = node.attr().at("T").type();
686 NodeDef* out[2];
687 for (int j = 0; j < 2; ++j) {
688 int reduction_indices = reduce_dims[j].size();
689 Tensor value(type, TensorShape({reduction_indices}));
690 for (int i = 0; i < reduction_indices; ++i) {
691 if (type == DT_INT32) {
692 value.vec<int32>()(i) = reduce_dims[j][i];
693 } else {
694 value.vec<int64_t>()(i) = reduce_dims[j][i];
695 }
696 }
697 string const_name =
698 OptimizedNodeName(node, strings::StrCat("-bcastargs-", j));
699 out[j] = node_map_->GetNode(const_name);
700 if (out[j] == nullptr) {
701 out[j] = graph_->add_node();
702 TF_RETURN_IF_ERROR(
703 CreateNodeDef(const_name, TensorValue(&value), out[j]));
704 out[j]->set_device(node.device());
705 node_map_->AddNode(const_name, out[j]);
706 string ctrl_dep =
707 AddControlDependency(node.name(), graph_, node_map_.get());
708 *out[j]->add_input() = ctrl_dep;
709 node_map_->AddOutput(NodeName(ctrl_dep), const_name);
710 }
711 }
712
713 // We make a copy here since we might mutate the set.
714 const auto outputs = node_map_->GetOutputs(node.name());
715 for (NodeDef* output : outputs) {
716 for (int k = 0; k < output->input_size(); ++k) {
717 int port;
718 string node_name = ParseNodeName(output->input(k), &port);
719 if (node_name == node.name() && port >= 0 && port < 2 && out[port]) {
720 *output->mutable_input(k) = out[port]->name();
721 node_map_->UpdateInput(output->name(), node_name, out[port]->name());
722 }
723 }
724 }
725
726 return OkStatus();
727 }
728
MaterializeReductionIndices(NodeDef * node,const GraphProperties & properties)729 Status ConstantFolding::MaterializeReductionIndices(
730 NodeDef* node, const GraphProperties& properties) {
731 if (node->input_size() < 2) {
732 return OkStatus();
733 }
734 const NodeDef* indices = node_map_->GetNode(node->input(1));
735 if (!indices || IsReallyConstant(*indices)) {
736 // The reduction indices are already constant, there's nothing to do.
737 return OkStatus();
738 }
739
740 const std::vector<OpInfo::TensorProperties>& input_props =
741 properties.GetInputProperties(node->name());
742 if (input_props.size() != 2) {
743 return OkStatus();
744 }
745 const OpInfo::TensorProperties& input_prop = input_props[0];
746 if (input_prop.shape().unknown_rank()) {
747 // We can't do anything if we don't know the rank of the input.
748 return OkStatus();
749 }
750 const int input_rank = input_prop.shape().dim_size();
751 if (input_rank < 1) {
752 // Unexpected graph, don't try to change it.
753 return OkStatus();
754 }
755 const OpInfo::TensorProperties& reduction_indices_prop = input_props[1];
756 DataType dtype = reduction_indices_prop.dtype();
757 if (dtype != DT_INT32 && dtype != DT_INT64) {
758 return OkStatus();
759 }
760 PartialTensorShape reduction_indices_shape(reduction_indices_prop.shape());
761 const int num_reduction_indices = reduction_indices_shape.num_elements();
762
763 const std::vector<OpInfo::TensorProperties>& output_props =
764 properties.GetOutputProperties(node->name());
765 if (output_props.size() != 1) {
766 return OkStatus();
767 }
768 const OpInfo::TensorProperties& output_prop = output_props[0];
769 const int output_rank =
770 output_prop.shape().unknown_rank() ? -1 : output_prop.shape().dim_size();
771
772 bool full_reduction = output_rank == 0 || num_reduction_indices == input_rank;
773 if (!full_reduction) {
774 // A full reduction will generate a tensor of one of the shapes
775 // [], [1], [1, 1], [1, 1, ...]. Even if we do not know the number of
776 // elements in the output of the reduction, we may deduce it from reshape
777 // nodes following it.
778 for (const NodeDef* fanout : node_map_->GetOutputs(node->name())) {
779 full_reduction = false;
780 if (!IsReshape(*fanout)) {
781 return OkStatus();
782 }
783 const std::vector<OpInfo::TensorProperties>& reshape_props =
784 properties.GetOutputProperties(fanout->name());
785 if (reshape_props.size() != 1) {
786 return OkStatus();
787 }
788 const OpInfo::TensorProperties& reshape_prop = reshape_props[0];
789 PartialTensorShape shape(reshape_prop.shape());
790 if (shape.num_elements() != 1) {
791 return OkStatus();
792 } else {
793 full_reduction = true;
794 }
795 }
796 if (!full_reduction) {
797 return OkStatus();
798 }
799 }
800
801 // We know it's a full reduction. We can generate the full set of indices to
802 // reduce as a constant node.
803 string const_name = OptimizedNodeName(*node, "-reduction_indices");
804 if (node_map_->GetNode(const_name)) {
805 return OkStatus();
806 }
807 NodeDef* reduction_indices = graph_->add_node();
808 Tensor value(dtype, TensorShape({input_rank}));
809 for (int i = 0; i < input_rank; ++i) {
810 if (dtype == DT_INT32) {
811 value.vec<int32>()(i) = i;
812 } else {
813 value.vec<int64_t>()(i) = i;
814 }
815 }
816 TF_RETURN_IF_ERROR(
817 CreateNodeDef(const_name, TensorValue(&value), reduction_indices));
818
819 reduction_indices->set_device(node->device());
820 string ctrl_dep =
821 AddControlDependency(node->input(1), graph_, node_map_.get());
822 *reduction_indices->add_input() = ctrl_dep;
823 node_map_->AddNode(const_name, reduction_indices);
824 node_map_->AddOutput(NodeName(ctrl_dep), const_name);
825
826 node->set_input(1, reduction_indices->name());
827 node_map_->UpdateInput(node->name(), indices->name(),
828 reduction_indices->name());
829
830 return OkStatus();
831 }
832
MaterializeConstantValuedNode(NodeDef * node,const GraphProperties & properties)833 Status ConstantFolding::MaterializeConstantValuedNode(
834 NodeDef* node, const GraphProperties& properties) {
835 if (disable_compressed_tensor_optimization_) {
836 return OkStatus();
837 }
838 // Nodes that generate constant-valued outputs can be represented compactly in
839 // compressed format, regardless of their shape.
840 const std::vector<OpInfo::TensorProperties>& output_props =
841 properties.GetOutputProperties(node->name());
842 if (output_props.size() != 1) return OkStatus();
843 const auto& output_shape = output_props[0].shape();
844 if (!PartialTensorShape(output_shape).IsFullyDefined()) {
845 return OkStatus();
846 }
847 if (IsFill(*node)) {
848 const auto output_dtype = output_props[0].dtype();
849 NodeDef* input_node = nullptr;
850 for (int i = 0; i < 2; ++i) {
851 input_node = node_map_->GetNode(NodeName(node->input(i)));
852 if (input_node == nullptr || !IsReallyConstant(*input_node)) {
853 return OkStatus();
854 }
855 }
856 TF_RETURN_IF_ERROR(CheckAttrExists(*input_node, "value"));
857
858 // Copy the input tensor to the fill node, set the output shape and data
859 // type, and change the node type to Const.
860 TensorProto* tensor = (*node->mutable_attr())["value"].mutable_tensor();
861 const TensorProto& input_tensor = input_node->attr().at("value").tensor();
862 if (!input_tensor.tensor_content().empty()) {
863 // Convert the value to repeated field format, so we can use the
864 // decompression mechanism to store only a single value in the constant
865 // node, even if the shape specified in the original Fill is large.
866 Tensor t;
867 if (!t.FromProto(input_tensor)) {
868 return errors::InvalidArgument(
869 "Could not construct Tensor form TensorProto in node: ",
870 input_node->name());
871 }
872 tensor->clear_tensor_content();
873 t.AsProtoField(tensor);
874 } else {
875 *tensor = input_tensor;
876 }
877 *(tensor->mutable_tensor_shape()) = output_shape;
878 (*node->mutable_attr())["dtype"].set_type(output_dtype);
879 node->mutable_attr()->erase("T");
880 node->mutable_attr()->erase("index_type");
881 node->set_op("Const");
882 for (int i = 0; i < 2; i++) {
883 // Change inputs to a control inputs.
884 const string ctrl_dep = AsControlDependency(node->input(i));
885 node_map_->UpdateInput(node->name(), node->input(i), ctrl_dep);
886 node->set_input(i, ctrl_dep);
887 }
888 graph_modified_ = true;
889 } else {
890 double value =
891 (IsZerosLike(*node) ? 0.0 : (IsOnesLike(*node) ? 1.0 : -1.0));
892 if (value >= 0) {
893 TF_RETURN_IF_ERROR(ReplaceOperationWithConstant(
894 value, properties, output_shape, node, graph_));
895 }
896 }
897 return OkStatus();
898 }
899
900 // Materialize output values inferred by the shape inference.
MaterializeOutputValues(NodeDef * node,const GraphProperties & properties)901 Status ConstantFolding::MaterializeOutputValues(
902 NodeDef* node, const GraphProperties& properties) {
903 const std::vector<OpInfo::TensorProperties>& output =
904 properties.GetOutputProperties(node->name());
905 if (output.size() != 1 || !output[0].has_value() ||
906 !IsFoldable(*node, &properties)) {
907 return OkStatus();
908 }
909
910 // If this is a trivial Identity node with a constant input, just route the
911 // input around it.
912 if (IsIdentity(*node)) {
913 NodeDef* input = node_map_->GetNode(node->input(0));
914 if (IsReallyConstant(*input)) {
915 std::vector<int> inputs_to_forward;
916 std::iota(inputs_to_forward.begin(), inputs_to_forward.end(), 0);
917 graph_modified_ = ForwardInputs(node, inputs_to_forward);
918 return OkStatus();
919 }
920 }
921 // Repurpose the existing node to be the constant.
922 // Device placement is preserved.
923 TensorProto value_copy = output[0].value();
924 return ReplaceOperationWithConstantTensor(output[0].dtype(), &value_copy,
925 node, graph_);
926 }
927
MaterializeConstants(const GraphProperties & properties)928 Status ConstantFolding::MaterializeConstants(
929 const GraphProperties& properties) {
930 const int node_count = graph_->node_size();
931 for (int i = 0; i < node_count; ++i) {
932 NodeDef& node = *graph_->mutable_node(i);
933 const string& op = node.op();
934 if (op == "BroadcastGradientArgs") {
935 TF_RETURN_IF_ERROR(MaterializeBroadcastGradientArgs(node, properties));
936 } else if (IsReduction(node)) {
937 TF_RETURN_IF_ERROR(MaterializeReductionIndices(&node, properties));
938 } else if (IsFill(node) || IsZerosLike(node) || IsOnesLike(node)) {
939 TF_RETURN_IF_ERROR(MaterializeConstantValuedNode(&node, properties));
940 } else {
941 TF_RETURN_IF_ERROR(MaterializeOutputValues(&node, properties));
942 }
943 }
944 return OkStatus();
945 }
946
IsFoldable(const NodeDef & node,const GraphProperties * properties)947 bool ConstantFolding::IsFoldable(const NodeDef& node,
948 const GraphProperties* properties) {
949 string key = strings::StrCat(node.name(), "/", node.op());
950 auto it = maybe_foldable_nodes_.find(key);
951 if (it == maybe_foldable_nodes_.end()) {
952 it = maybe_foldable_nodes_
953 .emplace(std::move(key), MaybeFoldable(node, properties))
954 .first;
955 }
956 if (!it->second) {
957 return false;
958 } else {
959 return IsFoldableUncached(node, properties);
960 }
961 }
962
IsFoldableUncached(const NodeDef & node,const GraphProperties * properties) const963 bool ConstantFolding::IsFoldableUncached(
964 const NodeDef& node, const GraphProperties* properties) const {
965 // Folding not applicable to ops with no inputs.
966 if (node.input().empty()) {
967 return false;
968 }
969 // We can only fold nodes if all their inputs are known statically, except in
970 // the case of a merge node that propagate the first inputs that becomes
971 // available, and therefore only requires a single constant input to be
972 // foldable.
973 bool merge_has_constant_input = false;
974 const bool is_merge = IsMerge(node);
975 for (const auto& input : node.input()) {
976 if (IsControlInput(input)) {
977 continue;
978 }
979 const NodeDef* input_node = node_map_->GetNode(input);
980 if (!input_node) {
981 return false;
982 }
983 bool is_const = IsReallyConstant(*input_node);
984 if (is_const) {
985 // Don't fold strings constants for now since this causes problems with
986 // checkpointing.
987 if (input_node->attr().count("dtype") == 0 ||
988 input_node->attr().at("dtype").type() == DT_STRING) {
989 return false;
990 }
991 // Special case: If a Merge node has at least one constant input that
992 // does not depend on a control input, we can fold it.
993 merge_has_constant_input |= !HasControlInputs(*input_node);
994 } else if (!is_merge) {
995 return false;
996 }
997 }
998 if (is_merge && !merge_has_constant_input) return false;
999 if (disable_compressed_tensor_optimization_ &&
1000 (IsFill(node) || IsZerosLike(node) || IsOnesLike(node)))
1001 return false;
1002
1003 // If we know the output shapes, make sure that the outputs are small enough
1004 // to materialize.
1005 if (properties != nullptr && properties->HasOutputProperties(node.name())) {
1006 const std::vector<OpInfo::TensorProperties>& input_props =
1007 properties->GetInputProperties(node.name());
1008 const std::vector<OpInfo::TensorProperties>& output_props =
1009 properties->GetOutputProperties(node.name());
1010 // Compute total size of inputs.
1011 int64_t input_size_bytes = 0;
1012 for (const auto& input_prop : input_props) {
1013 const PartialTensorShape input_shape(input_prop.shape());
1014 if (input_shape.IsFullyDefined()) {
1015 input_size_bytes +=
1016 input_shape.num_elements() * DataTypeSize(input_prop.dtype());
1017 }
1018 }
1019 for (const auto& output_prop : output_props) {
1020 PartialTensorShape output_shape;
1021 if (!PartialTensorShape::BuildPartialTensorShape(output_prop.shape(),
1022 &output_shape)
1023 .ok()) {
1024 return false;
1025 }
1026 if (output_shape.IsFullyDefined()) {
1027 const int64_t num_bytes =
1028 output_shape.num_elements() * DataTypeSize(output_prop.dtype());
1029 if (num_bytes > input_size_bytes && num_bytes > kMaxConstantSize) {
1030 // Do not fold nodes if the in-memory size of output is too large.
1031 // Notice that this is not exactly the same check used in
1032 // CreateNodeDef() where the actual encoded size is checked.
1033 return false;
1034 }
1035 }
1036 }
1037 }
1038
1039 return true;
1040 }
1041
MaybeFoldable(const NodeDef & node,const GraphProperties * properties) const1042 bool ConstantFolding::MaybeFoldable(const NodeDef& node,
1043 const GraphProperties* properties) const {
1044 // Skip constants, they're already folded
1045 if (IsConstant(node)) {
1046 return false;
1047 }
1048 // Don't fold stateful ops such as TruncatedNormal.
1049 if (!IsFreeOfSideEffect(node)) {
1050 return false;
1051 }
1052
1053 // Skips nodes that must be preserved except allowlisted nodes.
1054 if (nodes_to_preserve_.find(node.name()) != nodes_to_preserve_.end() &&
1055 nodes_allowlist_.find(node.name()) == nodes_allowlist_.end()) {
1056 return false;
1057 }
1058
1059 // Skip control flow nodes, they can't be folded.
1060 if (ModifiesFrameInfo(node)) {
1061 return false;
1062 }
1063
1064 // Skips ops that don't benefit from folding.
1065 if (IsPlaceholder(node)) {
1066 return false;
1067 }
1068 // `FakeParam` op is used as a placeholder in If branch function. It doesn't
1069 // have a valid output when executed.
1070 if (IsFakeParam(node)) {
1071 return false;
1072 }
1073
1074 if (node.op() == "AccumulateNV2") {
1075 return false;
1076 }
1077 // Removing LoopCond nodes can screw up the partitioner.
1078 if (node.op() == "LoopCond") {
1079 return false;
1080 }
1081
1082 if (!fold_quantization_emulation_ && IsQuantizationEmulation(node)) {
1083 return false;
1084 }
1085
1086 const string& op = node.op();
1087 if (op.find("Save") != string::npos || op.find("Restore") != string::npos ||
1088 op.find("Reader") != string::npos) {
1089 return false;
1090 }
1091 if (op.find("Quantized") != string::npos || absl::StartsWith(op, "Sparse")) {
1092 return false;
1093 }
1094
1095 // Don't fold nodes that contain TPU attributes.
1096 // TODO(rmlarsen): We should be able to fold many of these nodes as long as we
1097 // properly forward custom attributes, b/119051778.
1098 if (HasTPUAttributes(node)) {
1099 return false;
1100 }
1101
1102 const OpDef* op_def = nullptr;
1103 Status status = OpRegistry::Global()->LookUpOpDef(node.op(), &op_def);
1104 if (!status.ok()) {
1105 return false;
1106 }
1107 // Don't fold ops without outputs.
1108 if (op_def->output_arg_size() == 0) {
1109 return false;
1110 }
1111 // Don't fold DT_VARIANT outputs as this can cause problems with XLA compile.
1112 // TODO(rmlarsen): Only do this for XLA_* devices.
1113 for (const OpDef::ArgDef& output_arg : op_def->output_arg()) {
1114 if (output_arg.type() == DT_VARIANT) {
1115 return false;
1116 }
1117 }
1118
1119 // Don't fold nodes that have no outgoing edges except allowlisted nodes.
1120 // Such nodes could be introduced by an earlier constant folding pass and are
1121 // preserved in case users want to fetch their values; re-processing them
1122 // would lead to an error of adding a duplicated node to graph.
1123 const auto& outputs = node_map_->GetOutputs(node.name());
1124 if (outputs.empty() &&
1125 nodes_allowlist_.find(node.name()) == nodes_allowlist_.end()) {
1126 return false;
1127 }
1128 return true;
1129 }
1130
1131 namespace {
1132
1133 #define SET_TENSOR_VAL_CASE(DTYPE, TYPE, NAME) \
1134 case DTYPE: \
1135 t->add_##NAME##_val(static_cast<TYPE>(value)); \
1136 break;
1137
CreateConstantTensorAttrValue(DataType type,double value,const TensorShapeProto & shape,AttrValue * attr_tensor)1138 Status CreateConstantTensorAttrValue(DataType type, double value,
1139 const TensorShapeProto& shape,
1140 AttrValue* attr_tensor) {
1141 TensorProto* t = attr_tensor->mutable_tensor();
1142 t->set_dtype(type);
1143 *t->mutable_tensor_shape() = shape;
1144 switch (type) {
1145 case DT_HALF:
1146 t->add_half_val(
1147 Eigen::numext::bit_cast<uint16>(static_cast<Eigen::half>(value)));
1148 break;
1149 case DT_BFLOAT16:
1150 t->add_half_val(
1151 Eigen::numext::bit_cast<uint16>(static_cast<bfloat16>(value)));
1152 break;
1153 SET_TENSOR_VAL_CASE(DT_FLOAT, float, float);
1154 SET_TENSOR_VAL_CASE(DT_DOUBLE, double, double);
1155 SET_TENSOR_VAL_CASE(DT_INT64, int64_t, int64);
1156 SET_TENSOR_VAL_CASE(DT_UINT64, int64_t, int64);
1157 SET_TENSOR_VAL_CASE(DT_INT32, int32, int);
1158 SET_TENSOR_VAL_CASE(DT_UINT32, int32, int);
1159 SET_TENSOR_VAL_CASE(DT_INT16, int32, int);
1160 SET_TENSOR_VAL_CASE(DT_UINT16, int32, int);
1161 SET_TENSOR_VAL_CASE(DT_INT8, int32, int);
1162 SET_TENSOR_VAL_CASE(DT_UINT8, int32, int);
1163 SET_TENSOR_VAL_CASE(DT_QINT32, int32, int);
1164 SET_TENSOR_VAL_CASE(DT_QINT16, int32, int);
1165 SET_TENSOR_VAL_CASE(DT_QUINT16, int32, int);
1166 SET_TENSOR_VAL_CASE(DT_QINT8, int32, int);
1167 SET_TENSOR_VAL_CASE(DT_QUINT8, int32, int);
1168 SET_TENSOR_VAL_CASE(DT_BOOL, bool, bool);
1169 default:
1170 return errors::InvalidArgument(
1171 "Unsupported type in CreateConstantTensorAttrValue: ",
1172 DataTypeString(type));
1173 }
1174 return OkStatus();
1175 }
1176
1177 #undef SET_TENSOR_CAL_CASE
1178
GetDataTypeFromNodeOrProps(const NodeDef & node,const GraphProperties & properties)1179 DataType GetDataTypeFromNodeOrProps(const NodeDef& node,
1180 const GraphProperties& properties) {
1181 DataType dtype = DT_INVALID;
1182 if (node.attr().count("T") == 1) {
1183 dtype = node.attr().at("T").type();
1184 } else if (node.attr().count("dtype") == 1) {
1185 dtype = node.attr().at("dtype").type();
1186 } else if (IsLogicalOr(node) || IsLogicalAnd(node)) {
1187 dtype = DT_BOOL;
1188 } else {
1189 auto output_props = properties.GetOutputProperties(node.name());
1190 if (!output_props.empty()) {
1191 dtype = output_props[0].dtype();
1192 }
1193 }
1194 return dtype;
1195 }
1196
1197 // Checks whether the shape of the const input of the Mul op is valid to perform
1198 // the MulConvPushDown optimization.
IsValidConstShapeForMulConvPushDown(const string & data_format,const TensorShapeProto & filter_shape,const TensorShapeProto & mul_const_input_shape)1199 bool IsValidConstShapeForMulConvPushDown(
1200 const string& data_format, const TensorShapeProto& filter_shape,
1201 const TensorShapeProto& mul_const_input_shape) {
1202 // If the const is a scalar, or it has fewer or same number of dimensions
1203 // than the filter and it only has single element, the optimization should
1204 // work.
1205 if (mul_const_input_shape.dim_size() <=
1206 static_cast<int>(data_format.size()) &&
1207 TensorShape(mul_const_input_shape).num_elements() == 1) {
1208 return true;
1209 }
1210
1211 // Otherwise, check the eligibility according to data format.
1212 if (data_format == "NHWC" || data_format == "NDHWC") {
1213 TensorShapeProto new_filter_shape;
1214 if (!ShapeAfterBroadcast(filter_shape, mul_const_input_shape,
1215 &new_filter_shape)) {
1216 return false;
1217 }
1218 if (!ShapesSymbolicallyEqual(filter_shape, new_filter_shape)) {
1219 return false;
1220 }
1221 // Only the last dimension could be larger than one, since broadcasting over
1222 // the last dimension (the output channel) will result in invalid filter.
1223 for (int i = 0; i < mul_const_input_shape.dim_size() - 1; ++i) {
1224 if (mul_const_input_shape.dim(i).size() > 1) return false;
1225 }
1226 return true;
1227 } else if (data_format == "NCHW" || data_format == "NCDHW") {
1228 // TODO(laigd): support NCHW and NCDHW (b/111214513).
1229 return false;
1230 }
1231 return false;
1232 }
1233
1234 } // namespace
1235
1236 // static
CreateNodeDef(const string & name,const TensorValue & tensor,NodeDef * node,size_t original_size)1237 Status ConstantFolding::CreateNodeDef(const string& name,
1238 const TensorValue& tensor, NodeDef* node,
1239 size_t original_size) {
1240 node->set_name(name);
1241 node->set_op("Const");
1242
1243 AttrValue attr_type;
1244 attr_type.set_type(tensor->dtype());
1245 node->mutable_attr()->insert({"dtype", attr_type});
1246
1247 AttrValue attr_tensor;
1248 TensorProto* t = attr_tensor.mutable_tensor();
1249 bool optimized = false;
1250 size_t encoded_size;
1251 // Use the packed representation whenever possible to avoid generating large
1252 // graphdefs. Moreover, avoid repeating the last values if they're equal.
1253 if (tensor->NumElements() > 4) {
1254 #define POPULATE_TENSOR_PROTO(tensor, t, TYPE, FIELDTYPE) \
1255 { \
1256 const auto* val_ptr = tensor->flat<TYPE>().data(); \
1257 auto last = *val_ptr; \
1258 int64_t last_index = 0; \
1259 for (int64_t i = 0; i < tensor->NumElements(); ++i) { \
1260 TYPE cur = *val_ptr++; \
1261 if (PackedValuesNotEqual(cur, last)) { \
1262 last = cur; \
1263 last_index = i; \
1264 } \
1265 } \
1266 encoded_size = (last_index + 1) * sizeof(FIELDTYPE); \
1267 if (encoded_size < kint32max) { \
1268 optimized = true; \
1269 t->mutable_##FIELDTYPE##_val()->Reserve(last_index + 1); \
1270 const auto* src_ptr = tensor->flat<TYPE>().data(); \
1271 auto* dst_ptr = \
1272 t->mutable_##FIELDTYPE##_val()->AddNAlreadyReserved(last_index + 1); \
1273 std::copy(src_ptr, src_ptr + last_index + 1, dst_ptr); \
1274 } \
1275 } \
1276 break
1277
1278 switch (tensor->dtype()) {
1279 case DT_FLOAT:
1280 POPULATE_TENSOR_PROTO(tensor, t, float, float);
1281 case DT_DOUBLE:
1282 POPULATE_TENSOR_PROTO(tensor, t, double, double);
1283 case DT_INT64:
1284 POPULATE_TENSOR_PROTO(tensor, t, int64_t, int64);
1285 case DT_UINT64:
1286 POPULATE_TENSOR_PROTO(tensor, t, uint64, uint64);
1287 case DT_INT32:
1288 POPULATE_TENSOR_PROTO(tensor, t, int32_t, int);
1289 case DT_UINT32:
1290 POPULATE_TENSOR_PROTO(tensor, t, uint32, uint32);
1291 case DT_INT16:
1292 POPULATE_TENSOR_PROTO(tensor, t, int16_t, int);
1293 case DT_UINT16:
1294 POPULATE_TENSOR_PROTO(tensor, t, uint16, int);
1295 case DT_INT8:
1296 POPULATE_TENSOR_PROTO(tensor, t, int8_t, int);
1297 case DT_UINT8:
1298 POPULATE_TENSOR_PROTO(tensor, t, uint8, int);
1299 case DT_BOOL:
1300 POPULATE_TENSOR_PROTO(tensor, t, bool, bool);
1301 default:
1302 /* Do nothing. */
1303 break;
1304 }
1305 }
1306 if (optimized) {
1307 // Also specify type and shape.
1308 t->set_dtype(tensor->dtype());
1309 tensor->shape().AsProto(t->mutable_tensor_shape());
1310 } else {
1311 // DT_HALF, DT_BFLOAT16, DT_QINT32, DT_QINT16, DT_QUINT16, DT_QINT8,
1312 // DT_QUINT8
1313 tensor->AsProtoTensorContent(t);
1314 encoded_size = t->tensor_content().size();
1315 }
1316 node->mutable_attr()->insert({"value", attr_tensor});
1317
1318 if (encoded_size > original_size && encoded_size >= kMaxConstantSize) {
1319 return errors::InvalidArgument(
1320 strings::StrCat("Can't fold ", name, ", its size would be too large (",
1321 encoded_size, " >= ", kMaxConstantSize, " bytes)"));
1322 }
1323 return OkStatus();
1324 }
1325
EvaluateNode(const NodeDef & node,const TensorVector & inputs,TensorVector * output) const1326 Status ConstantFolding::EvaluateNode(const NodeDef& node,
1327 const TensorVector& inputs,
1328 TensorVector* output) const {
1329 return ::tensorflow::grappler::EvaluateNode(node, inputs, cpu_device_,
1330 resource_mgr_.get(), output);
1331 }
1332
EvaluateOneFoldable(const NodeDef & node,std::vector<NodeDef> * outputs,bool * result_too_large)1333 Status ConstantFolding::EvaluateOneFoldable(const NodeDef& node,
1334 std::vector<NodeDef>* outputs,
1335 bool* result_too_large) {
1336 TensorVector inputs;
1337 TensorVector output_tensors;
1338 auto inputs_cleanup = gtl::MakeCleanup([&inputs, &output_tensors] {
1339 for (const auto& input : inputs) {
1340 delete input.tensor;
1341 }
1342 for (const auto& output : output_tensors) {
1343 if (output.tensor) {
1344 delete output.tensor;
1345 }
1346 }
1347 });
1348
1349 size_t total_inputs_size = 0;
1350 for (const auto& input : node.input()) {
1351 const TensorId input_tensor = ParseTensorName(input);
1352 if (input_tensor.index() < 0) {
1353 // Control dependency
1354 break;
1355 }
1356 const NodeDef* input_node = node_map_->GetNode(input);
1357 if (!IsReallyConstant(*input_node)) {
1358 return Status(error::INVALID_ARGUMENT,
1359 strings::StrCat("Can't fold ", node.name(), ", its ", input,
1360 " isn't constant"));
1361 }
1362 TF_RETURN_IF_ERROR(CheckAttrExists(*input_node, "value"));
1363 const TensorProto& raw_val = input_node->attr().at("value").tensor();
1364 if (raw_val.dtype() == DT_INVALID) {
1365 return Status(
1366 error::INVALID_ARGUMENT,
1367 strings::StrCat("A tensor in the input node, with TensorId of ",
1368 input_tensor.ToString(),
1369 " has a dtype of DT_INVALID."));
1370 }
1371 if (IsRefType(raw_val.dtype())) {
1372 return errors::InvalidArgument(
1373 "Not allowed to construct a tensor with reference dtype, got ",
1374 DataTypeString(raw_val.dtype()));
1375 }
1376 Tensor* value = new Tensor(raw_val.dtype(), raw_val.tensor_shape());
1377 if (!value->FromProto(raw_val)) {
1378 delete (value);
1379 return errors::InvalidArgument("Unable to make Tensor from proto for ",
1380 node.name(), " with shape ",
1381 raw_val.tensor_shape().DebugString());
1382 }
1383 inputs.emplace_back(value);
1384 total_inputs_size += value->TotalBytes();
1385 }
1386
1387 TF_RETURN_IF_ERROR(EvaluateNode(node, inputs, &output_tensors));
1388 if (output_tensors.empty()) {
1389 return Status(error::INVALID_ARGUMENT, "Expected at least one output.");
1390 }
1391
1392 outputs->resize(output_tensors.size());
1393 for (size_t i = 0; i < output_tensors.size(); i++) {
1394 string node_name = OptimizedNodeName(node, "-folded");
1395 if (output_tensors.size() > 1) {
1396 node_name = strings::StrCat(node_name, "-", i);
1397 }
1398 if (output_tensors[i].tensor) {
1399 Status s = CreateNodeDef(node_name, output_tensors[i], &outputs->at(i),
1400 total_inputs_size);
1401 if (!s.ok()) {
1402 *result_too_large = true;
1403 return s;
1404 }
1405 } else {
1406 // Create an empty NodeDef to identify dead outputs (e.g. the output of a
1407 // switch that's not selected by the switch predicate).
1408 outputs->at(i) = NodeDef();
1409 }
1410 }
1411 return OkStatus();
1412 }
1413
FoldMergeNode(NodeDef * node,GraphDef * output_graph)1414 Status ConstantFolding::FoldMergeNode(NodeDef* node, GraphDef* output_graph) {
1415 // Merge nodes are special, in the sense that they execute as soon as one of
1416 // their input is ready. We can therefore fold a merge node iff it has at
1417 // least one constant input without control dependency.
1418 // We still need to ensure that the nodes in the fanin of the merge node are
1419 // scheduled. We'll therefore add a control dependency from the merge node
1420 // to the folded constant. We end up with:
1421 // * the merge node and its inputs are preserved as is
1422 // * a new constant node C1, driven by the merge node through a control
1423 // dependency, initialized to the value of the folded input
1424 // * a new constant node C2, driven by the merge node through a control
1425 // dependency, initialized to the index of the folded input
1426 // * the fanout of the merge nodes is rewired to be driven by either C1 or
1427 // C2.
1428 for (int input_index = 0; input_index < node->input_size(); ++input_index) {
1429 const auto& input = node->input(input_index);
1430 if (IsControlInput(input)) {
1431 // Try the next input.
1432 continue;
1433 }
1434 NodeDef* input_node = node_map_->GetNode(input);
1435 if (!IsReallyConstant(*input_node)) {
1436 continue;
1437 }
1438 bool valid_input = true;
1439 for (const string& fanin_of_input : input_node->input()) {
1440 if (IsControlInput(fanin_of_input)) {
1441 valid_input = false;
1442 break;
1443 }
1444 }
1445 if (!valid_input) {
1446 // Try the next input
1447 continue;
1448 }
1449
1450 string const_out_name = OptimizedNodeName(*node, "_const");
1451 string const_index_name = OptimizedNodeName(*node, "_index");
1452 if (node_map_->GetNode(const_out_name) ||
1453 node_map_->GetNode(const_index_name)) {
1454 // Intended name already exists.
1455 return errors::AlreadyExists(
1456 strings::StrCat(const_out_name, " or ", const_index_name,
1457 " already present in the graph"));
1458 }
1459
1460 NodeDef* const_out = output_graph->add_node();
1461 *const_out = *input_node;
1462 const_out->set_name(const_out_name);
1463 const_out->set_device(node->device());
1464 *const_out->add_input() = AsControlDependency(*node);
1465 node_map_->AddNode(const_out->name(), const_out);
1466 node_map_->AddOutput(node->name(), const_out->name());
1467
1468 NodeDef* const_index = output_graph->add_node();
1469 const_index->set_op("Const");
1470 Tensor index(DT_INT32, TensorShape({}));
1471 index.flat<int32>()(0) = input_index;
1472 (*const_index->mutable_attr())["dtype"].set_type(DT_INT32);
1473 index.AsProtoTensorContent(
1474 (*const_index->mutable_attr())["value"].mutable_tensor());
1475 const_index->set_name(const_index_name);
1476 const_index->set_device(node->device());
1477 *const_index->add_input() = AsControlDependency(*node);
1478 node_map_->AddNode(const_index->name(), const_index);
1479 node_map_->AddOutput(node->name(), const_index->name());
1480
1481 // We make a copy because we mutate the nodes.
1482 auto outputs = node_map_->GetOutputs(node->name());
1483 for (NodeDef* output : outputs) {
1484 for (int i = 0; i < output->input_size(); i++) {
1485 int port;
1486 string node_name = ParseNodeName(output->input(i), &port);
1487 if (node_name == node->name()) {
1488 if (port == 0) {
1489 *output->mutable_input(i) = const_out->name();
1490 node_map_->AddOutput(const_out->name(), output->name());
1491 } else if (port == 1) {
1492 *output->mutable_input(i) = const_index->name();
1493 node_map_->AddOutput(const_index->name(), output->name());
1494 } else {
1495 // This is a control dependency (or an invalid edge since the
1496 // merge node has only 2 outputs): preserve them.
1497 }
1498 }
1499 }
1500 }
1501 return OkStatus();
1502 }
1503 return OkStatus();
1504 }
1505
FoldNode(NodeDef * node,GraphDef * output_graph,bool * result_too_large)1506 Status ConstantFolding::FoldNode(NodeDef* node, GraphDef* output_graph,
1507 bool* result_too_large) {
1508 *result_too_large = false;
1509 if (IsMerge(*node)) {
1510 return FoldMergeNode(node, output_graph);
1511 }
1512
1513 std::vector<NodeDef> const_nodes;
1514 TF_RETURN_IF_ERROR(
1515 EvaluateOneFoldable(*node, &const_nodes, result_too_large));
1516 VLOG(2) << "Folded node: " << SummarizeNodeDef(*node);
1517
1518 NodeDef* constant_output = nullptr;
1519 for (int i = 0, end = const_nodes.size(); i < end; i++) {
1520 NodeDef* const_node = &const_nodes[i];
1521 VLOG(3) << "Generated constant node: " << SummarizeNodeDef(*const_node);
1522 if (const_node->name().empty()) {
1523 // Dead output: we can't create a constant to encode its value, so we'll
1524 // just skip it. We'll preserve the edges that originate from that
1525 // output below to preserve the overall behavior of the graph wrt dead
1526 // edges.
1527 continue;
1528 }
1529
1530 // Returns `true` iff `const_node` already has control input named `input`.
1531 const auto is_duplicate_control_input = [&](const string& input) -> bool {
1532 auto it = absl::c_find(const_node->input(), input);
1533 return it != const_node->input().end();
1534 };
1535
1536 // Forward control dependencies.
1537 for (const string& input : node->input()) {
1538 // Forward control dependencies from folded node.
1539 if (IsControlInput(input)) {
1540 if (!is_duplicate_control_input(input)) {
1541 *const_node->add_input() = input;
1542 }
1543 }
1544
1545 // Forward control dependencies from constant inputs to folded node.
1546 if (!IsControlInput(input)) {
1547 NodeDef* input_node = node_map_->GetNode(input);
1548 for (const string& fanin_of_input : input_node->input()) {
1549 if (!is_duplicate_control_input(fanin_of_input)) {
1550 *const_node->add_input() = fanin_of_input;
1551 }
1552 }
1553 }
1554 }
1555
1556 // We rewrite the existing node if it only has a single output, and
1557 // create new nodes otherwise.
1558 if (const_nodes.size() == 1) {
1559 node->set_op("Const");
1560 // Note we need to clear the inputs in NodeMap before we clear the inputs
1561 // in the node, otherwise NodeMap would see empty inputs and effectively
1562 // does nothing.
1563 node_map_->RemoveInputs(node->name());
1564 node->clear_input();
1565 *node->mutable_input() = const_node->input();
1566 for (const auto& input : node->input()) {
1567 node_map_->AddOutput(NodeName(input), node->name());
1568 }
1569 *node->mutable_attr() = const_node->attr();
1570 break;
1571 } else {
1572 if (node_map_->GetNode(const_node->name())) {
1573 // Intended name already exists.
1574 return errors::AlreadyExists(strings::StrCat(
1575 const_node->name(), " already present in the graph"));
1576 }
1577 NodeDef* added_node = output_graph->add_node();
1578 *added_node = *const_node;
1579 added_node->set_device(node->device());
1580 node_map_->AddNode(added_node->name(), added_node);
1581 for (const auto& input : added_node->input()) {
1582 node_map_->AddOutput(NodeName(input), added_node->name());
1583 }
1584 // All the constant nodes encoding output values have the same control
1585 // dependencies (since these are the control dependencies of the node
1586 // we're trying to fold). Record one such constant node.
1587 constant_output = added_node;
1588 }
1589 }
1590
1591 if (const_nodes.size() > 1) {
1592 // We make a copy because we mutate the nodes.
1593 auto outputs = node_map_->GetOutputs(node->name());
1594 for (NodeDef* output : outputs) {
1595 for (int i = 0; i < output->input_size(); i++) {
1596 int port;
1597 string node_name = ParseNodeName(output->input(i), &port);
1598 if (node_name == node->name()) {
1599 if (port < 0) {
1600 // Propagate control dependencies if possible. If not, we'll just
1601 // preserve the existing control dependencies.
1602 if (constant_output != nullptr) {
1603 node_map_->UpdateInput(node_name, NodeName(output->input(i)),
1604 constant_output->name());
1605 *output->mutable_input(i) = AsControlDependency(*constant_output);
1606 }
1607 } else if (port < static_cast<int>(const_nodes.size()) &&
1608 !const_nodes[port].name().empty()) {
1609 // Replace alive outputs with the corresponding constant.
1610 node_map_->UpdateInput(output->name(), NodeName(output->input(i)),
1611 const_nodes[port].name());
1612 *output->mutable_input(i) = const_nodes[port].name();
1613 } else {
1614 // Leave this edge alone.
1615 VLOG(3) << "Preserving edge from " << node->name() << ":" << port
1616 << "[" << node->op() << "] to " << output->name() << ":"
1617 << i << "[" << output->op() << "]";
1618 }
1619 }
1620 }
1621 }
1622 outputs = node_map_->GetOutputs(node->name());
1623 if (outputs.empty() && has_fetch_ &&
1624 nodes_to_preserve_.find(node->name()) == nodes_to_preserve_.end()) {
1625 node_map_->RemoveInputs(node->name());
1626 node->clear_input();
1627 }
1628 }
1629 return OkStatus();
1630 }
1631
FoldGraph(const GraphProperties & properties,GraphDef * optimized_graph,absl::flat_hash_set<string> * nodes_to_not_simplify)1632 Status ConstantFolding::FoldGraph(
1633 const GraphProperties& properties, GraphDef* optimized_graph,
1634 absl::flat_hash_set<string>* nodes_to_not_simplify) {
1635 // We build a new optimized_graph by inserting the folded nodes into it, then
1636 // copy other nodes that might be needed at the end of this function.
1637 absl::flat_hash_set<string> processed_nodes;
1638 std::deque<NodeDef*> queue;
1639 for (int i = 0; i < graph_->node_size(); i++) {
1640 const NodeDef& node = graph_->node(i);
1641 if (IsFoldable(node, &properties) &&
1642 !nodes_to_not_simplify->count(node.name())) {
1643 queue.push_back(graph_->mutable_node(i));
1644 }
1645 }
1646 while (!queue.empty()) {
1647 NodeDef* node = queue.front();
1648 queue.pop_front();
1649 if (processed_nodes.count(node->name())) {
1650 continue;
1651 }
1652 // We need to record a copy of output nodes before FoldNode() modifies it.
1653 // We also need to ensure that the fanout is sorted deterministically.
1654 std::vector<NodeDef*> fanout =
1655 node_map_->GetOutputsOrderedByNodeName(node->name());
1656 bool result_too_large = false;
1657 Status s = FoldNode(node, optimized_graph, &result_too_large);
1658 processed_nodes.insert(node->name());
1659 if (!s.ok()) {
1660 VLOG(1) << "Failed to fold node " << node->DebugString()
1661 << "\nError message: " << s;
1662 if (result_too_large) {
1663 nodes_to_not_simplify->emplace(node->name());
1664 }
1665 } else {
1666 for (auto& fanout_node : fanout) {
1667 if (IsFoldable(*fanout_node, &properties) &&
1668 !nodes_to_not_simplify->count(fanout_node->name())) {
1669 queue.push_back(fanout_node);
1670 }
1671 }
1672 }
1673 }
1674
1675 // Delete the newly created nodes that don't feed anything.
1676 std::vector<int> nodes_to_delete;
1677 for (int i = 0; i < optimized_graph->node_size(); i++) {
1678 const auto& fanout = node_map_->GetOutputs(optimized_graph->node(i).name());
1679 if (fanout.empty()) nodes_to_delete.push_back(i);
1680 }
1681 EraseNodesFromGraph(std::move(nodes_to_delete), optimized_graph);
1682
1683 for (int i = 0; i < graph_->node_size(); ++i) {
1684 NodeDef* node = graph_->mutable_node(i);
1685 // If no fetch nodes is provided, we conservatively
1686 // move all nodes in the original graph to the output, in case users need
1687 // to fetch their values.
1688 const auto& fanout = node_map_->GetOutputs(node->name());
1689 if (!fanout.empty() || !has_fetch_ ||
1690 nodes_to_preserve_.find(node->name()) != nodes_to_preserve_.end()) {
1691 *(optimized_graph->add_node()) = std::move(*node);
1692 }
1693 }
1694 return OkStatus();
1695 }
1696
IsSimplifiableReshape(const NodeDef & node,const GraphProperties & properties) const1697 Status ConstantFolding::IsSimplifiableReshape(
1698 const NodeDef& node, const GraphProperties& properties) const {
1699 if (!IsReshape(node)) {
1700 return errors::Internal("Node ", node.name(), " is not a Reshape node");
1701 }
1702 if (2 > node.input_size()) {
1703 return errors::Internal("Node ", node.name(),
1704 " must have at most 2 inputs but has ",
1705 node.input_size());
1706 }
1707 const NodeDef* new_shape = node_map_->GetNode(node.input(1));
1708 if (!IsReallyConstant(*new_shape)) {
1709 return errors::Internal("Node ", node.name(), " has shape ",
1710 new_shape->DebugString(),
1711 " which is not a constant");
1712 }
1713 TensorVector outputs;
1714 auto outputs_cleanup = gtl::MakeCleanup([&outputs] {
1715 for (const auto& output : outputs) {
1716 delete output.tensor;
1717 }
1718 });
1719
1720 Status s = EvaluateNode(*new_shape, TensorVector(), &outputs);
1721 if (!s.ok()) {
1722 return errors::Internal("Could not evaluate node ", node.name());
1723 }
1724 if (outputs.size() != 1) {
1725 return errors::Internal("Node ", node.name(),
1726 " must have exactly 1 output but has ",
1727 outputs.size());
1728 }
1729
1730 const std::vector<OpInfo::TensorProperties>& props =
1731 properties.GetInputProperties(node.name());
1732 if (props.empty()) {
1733 return errors::Internal("Node ", node.name(), " has no properties");
1734 }
1735 const OpInfo::TensorProperties& prop = props[0];
1736 if (prop.dtype() == DT_INVALID) {
1737 return errors::Internal("Node ", node.name(), " has property ",
1738 prop.DebugString(), " with invalid dtype");
1739 }
1740 const PartialTensorShape shape(prop.shape());
1741 if (!shape.IsFullyDefined()) {
1742 return errors::Internal("Node ", node.name(), " has property ",
1743 prop.DebugString(), " with shape ",
1744 shape.DebugString(), " which is not fully defined");
1745 }
1746
1747 PartialTensorShape new_dims;
1748 if (outputs[0]->dtype() == DT_INT32) {
1749 std::vector<int32> shp;
1750 for (int i = 0; i < outputs[0]->NumElements(); ++i) {
1751 int32_t dim = outputs[0]->flat<int32>()(i);
1752 shp.push_back(dim);
1753 }
1754 s = TensorShapeUtils::MakeShape(shp, &new_dims);
1755 if (!s.ok()) return s;
1756 } else {
1757 std::vector<int64_t> shp;
1758 for (int i = 0; i < outputs[0]->NumElements(); ++i) {
1759 int64_t dim = outputs[0]->flat<int64_t>()(i);
1760 shp.push_back(dim);
1761 }
1762 s = TensorShapeUtils::MakeShape(shp, &new_dims);
1763 if (!s.ok()) return s;
1764 }
1765
1766 if (!shape.IsCompatibleWith(new_dims)) {
1767 return errors::Internal("Expected shape ", shape.DebugString(),
1768 "to be compatible with ", new_dims.DebugString());
1769 }
1770
1771 return OkStatus();
1772 }
1773
1774 #define IS_VALUE_CASE(DTYPE, VALUE) \
1775 case DTYPE: \
1776 return AllValuesAre<EnumToDataType<DTYPE>::Type>( \
1777 node.attr().at("value").tensor(), EnumToDataType<DTYPE>::Type(VALUE))
1778
1779 #define IS_ONES_CASE(TYPE) IS_VALUE_CASE(TYPE, 1)
1780 #define IS_ZEROS_CASE(TYPE) IS_VALUE_CASE(TYPE, 0)
1781
IsOnes(const NodeDef & node) const1782 bool ConstantFolding::IsOnes(const NodeDef& node) const {
1783 if (feed_nodes_.find(node.name()) != feed_nodes_.end()) {
1784 return false;
1785 }
1786 if (IsOnesLike(node)) return true;
1787 if (IsZerosLike(node)) return false;
1788 if (node.op() == "Fill") {
1789 NodeDef* values = node_map_->GetNode(NodeName(node.input(1)));
1790 return values != nullptr && IsOnes(*values);
1791 }
1792 if (node.op() != "Const") return false;
1793 if (node.attr().count("dtype") == 0) return false;
1794 const auto dtype = node.attr().at("dtype").type();
1795 switch (dtype) {
1796 IS_ONES_CASE(DT_BOOL);
1797 IS_ONES_CASE(DT_HALF);
1798 IS_ONES_CASE(DT_BFLOAT16);
1799 IS_ONES_CASE(DT_FLOAT);
1800 IS_ONES_CASE(DT_DOUBLE);
1801 IS_ONES_CASE(DT_COMPLEX64);
1802 IS_ONES_CASE(DT_COMPLEX128);
1803 IS_ONES_CASE(DT_UINT8);
1804 IS_ONES_CASE(DT_INT8);
1805 IS_ONES_CASE(DT_UINT16);
1806 IS_ONES_CASE(DT_INT16);
1807 IS_ONES_CASE(DT_INT32);
1808 IS_ONES_CASE(DT_INT64);
1809 IS_ONES_CASE(DT_QINT32);
1810 IS_ONES_CASE(DT_QINT16);
1811 IS_ONES_CASE(DT_QUINT16);
1812 IS_ONES_CASE(DT_QINT8);
1813 IS_ONES_CASE(DT_QUINT8);
1814 default:
1815 VLOG(1) << "Unsupported type " << DataTypeString(dtype);
1816 return false;
1817 }
1818 return false;
1819 }
1820
IsZeros(const NodeDef & node) const1821 bool ConstantFolding::IsZeros(const NodeDef& node) const {
1822 if (feed_nodes_.find(node.name()) != feed_nodes_.end()) {
1823 return false;
1824 }
1825 if (IsOnesLike(node)) return false;
1826 if (IsZerosLike(node)) return true;
1827 if (node.op() == "Fill") {
1828 NodeDef* values = node_map_->GetNode(NodeName(node.input(1)));
1829 return values != nullptr && IsZeros(*values);
1830 }
1831 if (!IsConstant(node)) return false;
1832 if (node.attr().count("dtype") == 0) return false;
1833 const auto dtype = node.attr().at("dtype").type();
1834 switch (dtype) {
1835 IS_ZEROS_CASE(DT_BOOL);
1836 IS_ZEROS_CASE(DT_HALF);
1837 IS_ZEROS_CASE(DT_BFLOAT16);
1838 IS_ZEROS_CASE(DT_FLOAT);
1839 IS_ZEROS_CASE(DT_DOUBLE);
1840 IS_ZEROS_CASE(DT_COMPLEX64);
1841 IS_ZEROS_CASE(DT_COMPLEX128);
1842 IS_ZEROS_CASE(DT_UINT8);
1843 IS_ZEROS_CASE(DT_INT8);
1844 IS_ZEROS_CASE(DT_UINT16);
1845 IS_ZEROS_CASE(DT_INT16);
1846 IS_ZEROS_CASE(DT_INT32);
1847 IS_ZEROS_CASE(DT_INT64);
1848 IS_ZEROS_CASE(DT_QINT32);
1849 IS_ZEROS_CASE(DT_QINT16);
1850 IS_ZEROS_CASE(DT_QUINT16);
1851 IS_ZEROS_CASE(DT_QINT8);
1852 IS_ZEROS_CASE(DT_QUINT8);
1853 default:
1854 VLOG(1) << "Unsupported type " << DataTypeString(dtype);
1855 return false;
1856 }
1857 return false;
1858 }
1859
ReplaceOperationWithBroadcastTo(int input_to_broadcast,const GraphProperties & properties,NodeDef * node,GraphDef * graph)1860 bool ConstantFolding::ReplaceOperationWithBroadcastTo(
1861 int input_to_broadcast, const GraphProperties& properties, NodeDef* node,
1862 GraphDef* graph) {
1863 const DataType dtype = GetDataTypeFromNodeOrProps(*node, properties);
1864 if (dtype == DT_INVALID) {
1865 return false;
1866 }
1867 const PartialTensorShape shape(
1868 properties.GetOutputProperties(node->name())[0].shape());
1869 if (!shape.IsFullyDefined()) {
1870 return false;
1871 }
1872 // Create constant node with shape.
1873 const string const_name = OptimizedNodeName(
1874 *node, strings::StrCat("-broadcastto_shape-", input_to_broadcast));
1875 if (node_map_->GetNode(const_name) != nullptr) {
1876 return false;
1877 }
1878
1879 Tensor shape_t;
1880 if (!ConvertShapeToConstant("Shape", DT_INT32, shape, &shape_t).ok()) {
1881 return false;
1882 }
1883 NodeDef tmp;
1884 if (!CreateNodeDef(const_name, TensorValue(&shape_t), &tmp).ok()) {
1885 return false;
1886 }
1887 NodeDef* const_node = graph->add_node();
1888 const_node->Swap(&tmp);
1889 const_node->set_device(node->device());
1890 node_map_->AddNode(const_name, const_node);
1891 for (int i = 0; i < node->input_size(); ++i) {
1892 if (i != input_to_broadcast) {
1893 // Add a control input on the unused input.
1894 string ctrl_dep = AddControlDependency(NodeName(node->input(i)), graph,
1895 node_map_.get());
1896 *const_node->add_input() = ctrl_dep;
1897 node_map_->AddOutput(NodeName(ctrl_dep), const_name);
1898 }
1899 }
1900
1901 // Rewrite `node` in-place to BroadcastTo.
1902 node->set_op("BroadcastTo");
1903 EraseRegularNodeAttributes(node);
1904 (*node->mutable_attr())["T"].set_type(dtype);
1905 (*node->mutable_attr())["Tidx"].set_type(DT_INT32);
1906 // Set the designated input to BroadcastTo.
1907 node->mutable_input()->SwapElements(0, input_to_broadcast);
1908 // Keep all other inputs as control dependencies.
1909 for (int i = 1; i < node->input_size(); ++i) {
1910 if (IsControlInput(node->input(i))) {
1911 break;
1912 }
1913 const string ctrl_dep =
1914 AddControlDependency(node->input(i), graph, node_map_.get());
1915 node_map_->UpdateInput(node->name(), node->input(i), ctrl_dep);
1916 node->set_input(i, ctrl_dep);
1917 }
1918 // Add the shape argument.
1919 *node->add_input() = const_node->name();
1920 node_map_->AddOutput(const_name, node->name());
1921 node->mutable_input()->SwapElements(1, node->input_size() - 1);
1922 return true;
1923 }
1924
1925 // Replace an operation with Identity.
ReplaceOperationWithIdentity(int input_to_forward,const GraphProperties & properties,NodeDef * node,GraphDef * graph)1926 void ConstantFolding::ReplaceOperationWithIdentity(
1927 int input_to_forward, const GraphProperties& properties, NodeDef* node,
1928 GraphDef* graph) {
1929 if (input_to_forward < 0 || input_to_forward >= node->input_size()) return;
1930 const DataType dtype = GetDataTypeFromNodeOrProps(*node, properties);
1931 if (dtype == DT_INVALID) return;
1932
1933 node->set_op("Identity");
1934 EraseRegularNodeAttributes(node);
1935 (*node->mutable_attr())["T"].set_type(dtype);
1936 // Propagate the designated input through the identity.
1937 node->mutable_input()->SwapElements(0, input_to_forward);
1938 // Add all other inputs as control dependencies.
1939 for (int i = 1; i < node->input_size(); ++i) {
1940 if (IsControlInput(node->input(i))) {
1941 break;
1942 }
1943 const string ctrl_dep =
1944 AddControlDependency(node->input(i), graph, node_map_.get());
1945 node_map_->UpdateInput(node->name(), node->input(i), ctrl_dep);
1946 node->set_input(i, ctrl_dep);
1947 }
1948 graph_modified_ = true;
1949 }
1950
ReplaceOperationWithSnapshot(int input_to_forward,const GraphProperties & properties,NodeDef * node,GraphDef * graph)1951 void ConstantFolding::ReplaceOperationWithSnapshot(
1952 int input_to_forward, const GraphProperties& properties, NodeDef* node,
1953 GraphDef* graph) {
1954 // If the graph contains no ops that mutate their inputs, we can
1955 // use Identity instead of Snapshot.
1956 if (!graph_contains_assign_or_inplace_op_) {
1957 ReplaceOperationWithIdentity(input_to_forward, properties, node, graph);
1958 return;
1959 }
1960
1961 const DataType dtype = GetDataTypeFromNodeOrProps(*node, properties);
1962 if (dtype == DT_INVALID) return;
1963
1964 node->set_op("Snapshot");
1965 EraseRegularNodeAttributes(node);
1966 (*node->mutable_attr())["T"].set_type(dtype);
1967 // Propagate the designated input through the Snapshot.
1968 node->mutable_input()->SwapElements(0, input_to_forward);
1969 // Add all other inputs as control dependencies.
1970 for (int i = 1; i < node->input_size(); ++i) {
1971 if (IsControlInput(node->input(i))) {
1972 break;
1973 }
1974 const string ctrl_dep =
1975 AddControlDependency(node->input(i), graph, node_map_.get());
1976 node_map_->UpdateInput(node->name(), node->input(i), ctrl_dep);
1977 node->set_input(i, ctrl_dep);
1978 }
1979 graph_modified_ = true;
1980 }
1981
1982 // Replace a node with NoOp. Change all inputs to control dependencies.
1983 // If the node has non-control outputs, no change will be performed.
ReplaceOperationWithNoOp(NodeDef * node,GraphProperties * properties,GraphDef * graph)1984 void ConstantFolding::ReplaceOperationWithNoOp(NodeDef* node,
1985 GraphProperties* properties,
1986 GraphDef* graph) {
1987 if (HasRegularOutputs(*node, *node_map_)) return;
1988 node->set_op("NoOp");
1989 EraseRegularNodeAttributes(node);
1990 EraseNodeOutputAttributes(node);
1991 // Erase attributes that describe output properties.
1992 properties->ClearOutputProperties(node->name());
1993 // Change all inputs to control dependencies.
1994 for (int i = 0; i < node->input_size(); ++i) {
1995 if (IsControlInput(node->input(i))) {
1996 break;
1997 }
1998 const string ctrl_dep =
1999 AddControlDependency(node->input(i), graph, node_map_.get());
2000 node_map_->UpdateInput(node->name(), node->input(i), ctrl_dep);
2001 node->set_input(i, ctrl_dep);
2002 }
2003 DedupControlInputs(node);
2004 graph_modified_ = true;
2005 }
2006
ReplaceBinaryOperationWithBroadcastTo(int input_to_broadcast,const GraphProperties & properties,NodeDef * node,GraphDef * graph)2007 void ConstantFolding::ReplaceBinaryOperationWithBroadcastTo(
2008 int input_to_broadcast, const GraphProperties& properties, NodeDef* node,
2009 GraphDef* graph) {
2010 if (!ReplaceOperationWithBroadcastTo(input_to_broadcast, properties, node,
2011 graph)) {
2012 return;
2013 }
2014 graph_modified_ = true;
2015 }
2016
ReplaceDivisionOfOnesByReciprocal(NodeDef * node,GraphDef * graph)2017 void ConstantFolding::ReplaceDivisionOfOnesByReciprocal(NodeDef* node,
2018 GraphDef* graph) {
2019 node->set_op("Reciprocal");
2020 node->mutable_input()->SwapElements(0, 1);
2021 const string ctrl_dep =
2022 AddControlDependency(node->input(1), graph, node_map_.get());
2023 node_map_->UpdateInput(node->name(), node->input(1), ctrl_dep);
2024 node->set_input(1, ctrl_dep);
2025 graph_modified_ = true;
2026 }
2027
ReplaceSubtractionFromZeroByNegation(NodeDef * node,GraphDef * graph)2028 void ConstantFolding::ReplaceSubtractionFromZeroByNegation(NodeDef* node,
2029 GraphDef* graph) {
2030 node->set_op("Neg");
2031 node->mutable_input()->SwapElements(0, 1);
2032 const string ctrl_dep =
2033 AddControlDependency(node->input(1), graph, node_map_.get());
2034 node_map_->UpdateInput(node->name(), node->input(1), ctrl_dep);
2035 node->set_input(1, ctrl_dep);
2036 graph_modified_ = true;
2037 }
2038
ReplaceOperationWithConstantTensor(DataType dtype,TensorProto * value,NodeDef * node,GraphDef * graph)2039 Status ConstantFolding::ReplaceOperationWithConstantTensor(DataType dtype,
2040 TensorProto* value,
2041 NodeDef* node,
2042 GraphDef* graph) {
2043 if (dtype == DT_VARIANT) return OkStatus();
2044 node->set_op("Const");
2045 EraseRegularNodeAttributes(node);
2046 (*node->mutable_attr())["dtype"].set_type(dtype);
2047 (*node->mutable_attr())["value"].mutable_tensor()->Swap(value);
2048 // Convert all inputs to control dependencies.
2049 for (int i = 0; i < node->input_size(); ++i) {
2050 if (IsControlInput(node->input(i))) {
2051 break;
2052 }
2053 const string ctrl_dep =
2054 AddControlDependency(node->input(i), graph, node_map_.get());
2055 node_map_->UpdateInput(node->name(), node->input(i), ctrl_dep);
2056 node->set_input(i, ctrl_dep);
2057 }
2058 DedupControlInputs(node);
2059 graph_modified_ = true;
2060 return OkStatus();
2061 }
2062
ReplaceOperationWithConstant(double value,const GraphProperties & properties,const TensorShapeProto & shape,NodeDef * node,GraphDef * graph)2063 Status ConstantFolding::ReplaceOperationWithConstant(
2064 double value, const GraphProperties& properties,
2065 const TensorShapeProto& shape, NodeDef* node, GraphDef* graph) {
2066 const DataType dtype = GetDataTypeFromNodeOrProps(*node, properties);
2067 if (dtype == DT_VARIANT) return OkStatus();
2068 AttrValue tensor_attr;
2069 Status s = CreateConstantTensorAttrValue(dtype, value, shape, &tensor_attr);
2070 if (!s.ok()) {
2071 // Fail gracefully without mutating the graph.
2072 VLOG(1) << "Failed to replace node " << node->name() << " of type "
2073 << DataTypeString(dtype) << " with constant tensor of value "
2074 << value;
2075 return OkStatus();
2076 }
2077 return ReplaceOperationWithConstantTensor(dtype, tensor_attr.mutable_tensor(),
2078 node, graph);
2079 }
2080
SimplifyGraph(GraphDef * optimized_graph,GraphProperties * properties,absl::flat_hash_set<string> * nodes_to_not_simplify)2081 Status ConstantFolding::SimplifyGraph(
2082 GraphDef* optimized_graph, GraphProperties* properties,
2083 absl::flat_hash_set<string>* nodes_to_not_simplify) {
2084 for (int i = 0; i < optimized_graph->node_size(); ++i) {
2085 NodeDef* node = optimized_graph->mutable_node(i);
2086 // TODO(lyandy): Move nodes to not simplify check into SimplifyNode and
2087 // generalize to only restrict certain simplifications.
2088 if (nodes_to_not_simplify->find(node->name()) ==
2089 nodes_to_not_simplify->end()) {
2090 if (HasTPUAttributes(*node)) {
2091 nodes_to_not_simplify->insert(node->name());
2092 continue;
2093 }
2094
2095 TF_RETURN_IF_ERROR(SimplifyNode(node, optimized_graph, properties));
2096 }
2097 }
2098 return OkStatus();
2099 }
2100
2101 #define RETURN_IF_ERROR_OR_MODIFIED(EXPR) \
2102 TF_RETURN_IF_ERROR(EXPR); \
2103 if (graph_modified_) return OkStatus()
2104
2105 #define SET_AND_RETURN_IF_MODIFIED(EXPR) \
2106 graph_modified_ = EXPR; \
2107 if (graph_modified_) return OkStatus()
2108
2109 #define RETURN_IF_MODIFIED(EXPR) \
2110 EXPR; \
2111 if (graph_modified_) return OkStatus()
2112
SimplifyNode(NodeDef * node,GraphDef * optimized_graph,GraphProperties * properties)2113 Status ConstantFolding::SimplifyNode(NodeDef* node, GraphDef* optimized_graph,
2114 GraphProperties* properties) {
2115 bool graph_modified_cached = graph_modified_;
2116 graph_modified_ = false;
2117
2118 bool use_shape_info = properties->has_properties();
2119 RETURN_IF_MODIFIED(RemoveSplitOrSplitV(*properties, optimized_graph, node));
2120 RETURN_IF_ERROR_OR_MODIFIED(RemoveShuffleOrTranspose(
2121 *properties, use_shape_info, optimized_graph, node));
2122 RETURN_IF_MODIFIED(
2123 RemoveRandomShuffle(*properties, use_shape_info, optimized_graph, node));
2124 RETURN_IF_ERROR_OR_MODIFIED(
2125 RemoveReverse(*properties, use_shape_info, optimized_graph, node));
2126 RETURN_IF_ERROR_OR_MODIFIED(
2127 SimplifySlice(*properties, use_shape_info, optimized_graph, node));
2128 RETURN_IF_ERROR_OR_MODIFIED(
2129 SimplifyStridedSlice(*properties, use_shape_info, optimized_graph, node));
2130 RETURN_IF_ERROR_OR_MODIFIED(
2131 SimplifyTile(*properties, use_shape_info, optimized_graph, node));
2132 RETURN_IF_ERROR_OR_MODIFIED(
2133 SimplifyPad(*properties, use_shape_info, optimized_graph, node));
2134 RETURN_IF_MODIFIED(
2135 SimplifySqueeze(*properties, use_shape_info, optimized_graph, node));
2136 SET_AND_RETURN_IF_MODIFIED(SimplifyPack(optimized_graph, node));
2137 SET_AND_RETURN_IF_MODIFIED(MoveConstantsPastEnter(optimized_graph, node));
2138 SET_AND_RETURN_IF_MODIFIED(SimplifySwitch(optimized_graph, node));
2139 SET_AND_RETURN_IF_MODIFIED(
2140 SimplifyReduction(optimized_graph, *properties, node));
2141 SET_AND_RETURN_IF_MODIFIED(
2142 SimplifyReshape(*properties, use_shape_info, node));
2143 RETURN_IF_ERROR_OR_MODIFIED(SimplifyArithmeticOperations(
2144 *properties, use_shape_info, optimized_graph, node));
2145 SET_AND_RETURN_IF_MODIFIED(ReduceDivToReciprocalMul(optimized_graph, node));
2146 SET_AND_RETURN_IF_MODIFIED(
2147 ConstantPushDown(properties, optimized_graph, node));
2148 SET_AND_RETURN_IF_MODIFIED(
2149 MulConvPushDown(optimized_graph, node, *properties));
2150 SET_AND_RETURN_IF_MODIFIED(PartialConstPropThroughIdentityN(node));
2151 SET_AND_RETURN_IF_MODIFIED(
2152 PartialAssocOpConstFolding(optimized_graph, properties, node));
2153 SET_AND_RETURN_IF_MODIFIED(
2154 MergeConcat(use_shape_info, properties, optimized_graph, node));
2155 SET_AND_RETURN_IF_MODIFIED(
2156 PartialConcatConstFolding(optimized_graph, properties, node));
2157 SET_AND_RETURN_IF_MODIFIED(
2158 ConstantPushDownBiasAdd(properties, optimized_graph, node));
2159 SET_AND_RETURN_IF_MODIFIED(SimplifyCase(optimized_graph, node));
2160 SET_AND_RETURN_IF_MODIFIED(
2161 SimplifySelect(*properties, optimized_graph, node));
2162 RETURN_IF_MODIFIED(
2163 RemoveRedundantVariableUpdates(properties, optimized_graph, node));
2164
2165 graph_modified_ = graph_modified_cached;
2166 return OkStatus();
2167 }
2168
RemoveSplitOrSplitV(const GraphProperties & properties,GraphDef * optimized_graph,NodeDef * node)2169 void ConstantFolding::RemoveSplitOrSplitV(const GraphProperties& properties,
2170 GraphDef* optimized_graph,
2171 NodeDef* node) {
2172 if (node->attr().count("num_split") == 0) return;
2173 if (IsSplit(*node) && node->attr().at("num_split").i() == 1) {
2174 ReplaceOperationWithIdentity(1, properties, node, optimized_graph);
2175 }
2176 if (IsSplitV(*node) && node->attr().at("num_split").i() == 1) {
2177 ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
2178 }
2179 }
2180
RemoveShuffleOrTranspose(const GraphProperties & properties,bool use_shape_info,GraphDef * optimized_graph,NodeDef * node)2181 Status ConstantFolding::RemoveShuffleOrTranspose(
2182 const GraphProperties& properties, bool use_shape_info,
2183 GraphDef* optimized_graph, NodeDef* node) {
2184 if (!use_shape_info || !(IsShuffle(*node) || IsTranspose(*node)))
2185 return OkStatus();
2186 Tensor permutation_tensor;
2187 if (GetTensorFromConstNode(node->input(1), &permutation_tensor) &&
2188 properties.HasInputProperties(node->name())) {
2189 const auto& shape = properties.GetInputProperties(node->name())[0].shape();
2190 std::vector<int> permutation;
2191 for (int j = 0; j < permutation_tensor.NumElements(); ++j) {
2192 if (permutation_tensor.dtype() == DT_INT64) {
2193 permutation.push_back(permutation_tensor.vec<int64_t>()(j));
2194 } else {
2195 permutation.push_back(permutation_tensor.vec<int>()(j));
2196 }
2197 }
2198 int permutation_size = permutation.size();
2199 if (permutation_size != shape.dim_size()) {
2200 // Number of elements in perm should be same as dim_size. Skip if not.
2201 return OkStatus();
2202 }
2203 // The node is replaceable iff
2204 // dim_size == 0 || all dims have size 1 ||
2205 // all dims with > 1 size are not permuted.
2206 bool replaceable = true;
2207 for (int j = 0; replaceable && j < shape.dim_size(); ++j) {
2208 replaceable &= shape.dim(j).size() == 1 || j == permutation[j];
2209 }
2210 if (replaceable) {
2211 ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
2212 }
2213 }
2214 return OkStatus();
2215 }
2216
RemoveRandomShuffle(const GraphProperties & properties,bool use_shape_info,GraphDef * optimized_graph,NodeDef * node)2217 void ConstantFolding::RemoveRandomShuffle(const GraphProperties& properties,
2218 bool use_shape_info,
2219 GraphDef* optimized_graph,
2220 NodeDef* node) {
2221 if (use_shape_info && IsRandomShuffle(*node) &&
2222 !properties.GetInputProperties(node->name()).empty()) {
2223 const auto& shape = properties.GetInputProperties(node->name())[0].shape();
2224 // The node is replaceable iff
2225 // unknown_rank == false && (dim_size == 0 || first dim is of size 1)
2226 if (!shape.unknown_rank() &&
2227 (shape.dim_size() == 0 || shape.dim(0).size() == 1)) {
2228 ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
2229 }
2230 }
2231 }
2232
RemoveReverse(const GraphProperties & properties,bool use_shape_info,GraphDef * optimized_graph,NodeDef * node)2233 Status ConstantFolding::RemoveReverse(const GraphProperties& properties,
2234 bool use_shape_info,
2235 GraphDef* optimized_graph,
2236 NodeDef* node) {
2237 if (!use_shape_info || node->op() != "ReverseV2") return OkStatus();
2238 Tensor axis;
2239 if (properties.HasInputProperties(node->name()) &&
2240 GetTensorFromConstNode(node->input(1), &axis)) {
2241 const auto& shape = properties.GetInputProperties(node->name())[0].shape();
2242 if (shape.unknown_rank()) return OkStatus();
2243 std::set<int> target_axes;
2244 for (int j = 0; j < axis.NumElements(); ++j) {
2245 // value of axis can be negative.
2246 if (axis.dtype() == DT_INT64) {
2247 target_axes.insert((axis.vec<int64_t>()(j) + shape.dim_size()) %
2248 shape.dim_size());
2249 } else {
2250 target_axes.insert((axis.vec<int>()(j) + shape.dim_size()) %
2251 shape.dim_size());
2252 }
2253 }
2254
2255 // The node is replaceable iff
2256 // unknown_rank == false &&
2257 // (dim_size == 0 || all dims have size 1 ||
2258 // all dims with > 1 size are not in target_axes)
2259 bool replaceable = true;
2260 for (int j = 0; replaceable && j < shape.dim_size(); ++j) {
2261 replaceable &=
2262 shape.dim(j).size() == 1 || target_axes.find(j) == target_axes.end();
2263 }
2264 if (replaceable) {
2265 ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
2266 }
2267 }
2268 return OkStatus();
2269 }
2270
SimplifySlice(const GraphProperties & properties,bool use_shape_info,GraphDef * optimized_graph,NodeDef * node)2271 Status ConstantFolding::SimplifySlice(const GraphProperties& properties,
2272 bool use_shape_info,
2273 GraphDef* optimized_graph,
2274 NodeDef* node) {
2275 if (!use_shape_info || !IsSlice(*node)) return OkStatus();
2276 Tensor begin;
2277 Tensor size;
2278 if (properties.HasInputProperties(node->name()) &&
2279 GetTensorFromConstNode(node->input(1), &begin) &&
2280 GetTensorFromConstNode(node->input(2), &size)) {
2281 const auto& input = properties.GetInputProperties(node->name())[0];
2282 // The node is replaceable iff unknown_rank == false &&
2283 // begin == 0 && (size == -1 || size == input_shape) for all dimensions
2284 bool replaceable = !input.shape().unknown_rank();
2285 for (int j = 0; replaceable && j < input.shape().dim_size(); ++j) {
2286 if (begin.dtype() == DT_INT32) {
2287 replaceable &= begin.vec<int>()(j) == 0;
2288 } else {
2289 replaceable &= begin.vec<int64_t>()(j) == 0;
2290 }
2291 if (size.dtype() == DT_INT32) {
2292 replaceable &= (size.vec<int>()(j) == -1 ||
2293 size.vec<int>()(j) == input.shape().dim(j).size());
2294 } else {
2295 replaceable &= (size.vec<int64_t>()(j) == -1 ||
2296 size.vec<int64_t>()(j) == input.shape().dim(j).size());
2297 }
2298 }
2299 if (replaceable) {
2300 ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
2301 }
2302 }
2303 return OkStatus();
2304 }
2305
SimplifyStridedSlice(const GraphProperties & properties,bool use_shape_info,GraphDef * optimized_graph,NodeDef * node)2306 Status ConstantFolding::SimplifyStridedSlice(const GraphProperties& properties,
2307 bool use_shape_info,
2308 GraphDef* optimized_graph,
2309 NodeDef* node) {
2310 if (use_shape_info && IsStridedSlice(*node) &&
2311 properties.GetInputProperties(node->name()).size() == 4) {
2312 TF_RETURN_IF_ERROR(
2313 CheckAttrsExist(*node, {"new_axis_mask", "shrink_axis_mask"}));
2314 if (node->attr().at("new_axis_mask").i() != 0 ||
2315 node->attr().at("shrink_axis_mask").i() != 0) {
2316 // Skip nodes with new/shrink axis mask, since they involve dimension
2317 // changes.
2318 return OkStatus();
2319 }
2320 const auto& input = properties.GetInputProperties(node->name())[0];
2321 for (int j = 0; j < input.shape().dim_size(); ++j) {
2322 // Skip if input shape is not fully determined.
2323 if (input.shape().dim(j).size() < 0) {
2324 return OkStatus();
2325 }
2326 }
2327
2328 std::vector<Tensor> input_tensors(3);
2329 for (int i = 1; i < 4; ++i) {
2330 if (!GetTensorFromConstNode(node->input(i), &input_tensors[i - 1])) {
2331 return OkStatus();
2332 }
2333 }
2334
2335 const Tensor& begin = input_tensors[0];
2336 const Tensor& end = input_tensors[1];
2337 const Tensor& strides = input_tensors[2];
2338
2339 TF_RETURN_IF_ERROR(
2340 CheckAttrsExist(*node, {"begin_mask", "end_mask", "ellipsis_mask"}));
2341 int begin_mask = node->attr().at("begin_mask").i();
2342 int end_mask = node->attr().at("end_mask").i();
2343 std::set<int> expanded_ellipsis_indices;
2344 int ellipsis_index = -1;
2345 for (int j = 0; j < input.shape().dim_size(); ++j) {
2346 // find the ellipsis_mask. If not found, insert one in the end if
2347 // necessary.
2348 if (node->attr().at("ellipsis_mask").i() & 1 << j ||
2349 (ellipsis_index == -1 && j >= strides.NumElements())) {
2350 ellipsis_index = j;
2351 }
2352 // insert the indices that are immediately after ellipsis_index if
2353 // necessary.
2354 if (ellipsis_index != -1 &&
2355 input.shape().dim_size() >
2356 strides.NumElements() + j - ellipsis_index) {
2357 expanded_ellipsis_indices.insert(j);
2358 }
2359 }
2360
2361 // The node is replaceable iff unknown_rank == false &&
2362 // ((begin_mask is set || begin == 0) && (end_mask is set || end == dim)
2363 // && strides == 1) for all dimensions.
2364 bool replaceable = !input.shape().unknown_rank();
2365 for (int j = 0; replaceable && j < input.shape().dim_size(); ++j) {
2366 if (expanded_ellipsis_indices.find(j) !=
2367 expanded_ellipsis_indices.end()) {
2368 // ellipsis_mask is effective on current dimension.
2369 continue;
2370 }
2371 // when we have ellipsis_mask in between, input.shape().dim_size() will
2372 // be greater than strides.NumElements(), since we will insert
2373 // as many as expanded_ellipsis_indices.size() axes during computation.
2374 // We need to subtract this number from j.
2375 int i = j;
2376 int expanded_ellipsis_indices_size = expanded_ellipsis_indices.size();
2377 if (ellipsis_index != -1 &&
2378 j >= ellipsis_index + expanded_ellipsis_indices_size) {
2379 i = j - expanded_ellipsis_indices_size;
2380 }
2381 int b = begin.dtype() == DT_INT32 ? begin.vec<int>()(i)
2382 : begin.vec<int64_t>()(i);
2383 int e =
2384 end.dtype() == DT_INT32 ? end.vec<int>()(i) : end.vec<int64_t>()(i);
2385 int s = strides.dtype() == DT_INT32 ? strides.vec<int>()(i)
2386 : strides.vec<int64_t>()(i);
2387 replaceable &= (begin_mask & 1 << i || b == 0) &&
2388 (end_mask & 1 << i || e == input.shape().dim(j).size()) &&
2389 s == 1;
2390 }
2391 if (replaceable) {
2392 ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
2393 }
2394 }
2395 return OkStatus();
2396 }
2397
SimplifyTile(const GraphProperties & properties,bool use_shape_info,GraphDef * optimized_graph,NodeDef * node)2398 Status ConstantFolding::SimplifyTile(const GraphProperties& properties,
2399 bool use_shape_info,
2400 GraphDef* optimized_graph, NodeDef* node) {
2401 Tensor multiplies;
2402 if (use_shape_info && IsTile(*node) &&
2403 GetTensorFromConstNode(node->input(1), &multiplies)) {
2404 // The node is replaceable iff all values in multiplies are 1.
2405 bool replaceable = true;
2406 if (multiplies.dtype() == DT_INT32) {
2407 for (int j = 0; replaceable && j < multiplies.vec<int>().size(); ++j) {
2408 replaceable &= multiplies.vec<int>()(j) == 1;
2409 }
2410 } else {
2411 for (int j = 0; replaceable && j < multiplies.vec<int64_t>().size();
2412 ++j) {
2413 replaceable &= multiplies.vec<int64_t>()(j) == 1;
2414 }
2415 }
2416 if (replaceable) {
2417 ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
2418 }
2419 }
2420 return OkStatus();
2421 }
2422
SimplifyPad(const GraphProperties & properties,bool use_shape_info,GraphDef * optimized_graph,NodeDef * node)2423 Status ConstantFolding::SimplifyPad(const GraphProperties& properties,
2424 bool use_shape_info,
2425 GraphDef* optimized_graph, NodeDef* node) {
2426 if (!use_shape_info || !IsPad(*node)) return OkStatus();
2427
2428 Tensor paddings;
2429 if (GetTensorFromConstNode(node->input(1), &paddings)) {
2430 // The node is replaceable iff all values in paddings are 0.
2431 bool replaceable = true;
2432 if (paddings.dtype() == DT_INT32) {
2433 const auto flatten = paddings.flat<int32>();
2434 for (int j = 0; replaceable && j < flatten.size(); ++j) {
2435 replaceable &= flatten(j) == 0;
2436 }
2437 } else {
2438 const auto flatten = paddings.flat<int64_t>();
2439 for (int j = 0; replaceable && j < flatten.size(); ++j) {
2440 replaceable &= flatten(j) == 0;
2441 }
2442 }
2443 if (replaceable) {
2444 ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
2445 }
2446 }
2447 return OkStatus();
2448 }
2449
SimplifySqueeze(const GraphProperties & properties,bool use_shape_info,GraphDef * optimized_graph,NodeDef * node)2450 void ConstantFolding::SimplifySqueeze(const GraphProperties& properties,
2451 bool use_shape_info,
2452 GraphDef* optimized_graph,
2453 NodeDef* node) {
2454 if (use_shape_info && IsSqueeze(*node) &&
2455 !properties.GetInputProperties(node->name()).empty()) {
2456 // https://www.tensorflow.org/api_docs/python/tf/squeeze mentions it's
2457 // error to squeeze a dimension that is not 1, so we only need to check
2458 // whether the input has > 1 size for each dimension.
2459 const auto& shape = properties.GetInputProperties(node->name())[0].shape();
2460 // The node is replaceable iff
2461 // unknown_rank == false && (dim_size == 0 || all dims have size > 1)
2462 bool replaceable = !shape.unknown_rank();
2463 for (int j = 0; replaceable && j < shape.dim_size(); ++j) {
2464 replaceable &= shape.dim(j).size() > 1;
2465 }
2466 if (replaceable) {
2467 ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
2468 }
2469 }
2470 }
2471
SimplifyPack(GraphDef * optimized_graph,NodeDef * node)2472 bool ConstantFolding::SimplifyPack(GraphDef* optimized_graph, NodeDef* node) {
2473 const string axis_node_name = OptimizedNodeName(*node, "_const_axis");
2474 if (!IsPack(*node) || NumNonControlInputs(*node) != 1 ||
2475 node_map_->NodeExists(axis_node_name)) {
2476 return false;
2477 }
2478
2479 // It's unsafe to add a control dependency on the feed node, because it might
2480 // have been never executed otherwiwise.
2481 if (feed_nodes_.find(NodeName(node->input(0))) != feed_nodes_.end()) {
2482 return false;
2483 }
2484
2485 // Create constant axis node.
2486 Tensor axis_t(DT_INT32, TensorShape({}));
2487 const int axis =
2488 node->attr().count("axis") == 0 ? 0 : node->attr().at("axis").i();
2489 NodeDef new_node;
2490 if (!SetTensorValue(DT_INT32, axis, &axis_t).ok() ||
2491 !CreateNodeDef(axis_node_name, TensorValue(&axis_t), &new_node).ok()) {
2492 return false;
2493 }
2494 NodeDef* axis_node = optimized_graph->add_node();
2495 *axis_node = std::move(new_node);
2496 axis_node->set_name(axis_node_name);
2497 node_map_->AddNode(axis_node->name(), axis_node);
2498 // Add a control dependency to make sure axis_node is in the right frame.
2499 const string ctrl_dep = ConstantFolding::AddControlDependency(
2500 node->input(0), optimized_graph, node_map_.get());
2501 axis_node->add_input(ctrl_dep);
2502 axis_node->set_device(node->device());
2503 node_map_->AddOutput(NodeName(node->input(0)), axis_node->name());
2504 node->set_op("ExpandDims");
2505 if (node->attr().count("axis") != 0) {
2506 node->mutable_attr()->erase("axis");
2507 }
2508 if (node->attr().count("N") != 0) {
2509 node->mutable_attr()->erase("N");
2510 }
2511 (*node->mutable_attr())["Tdim"].set_type(DT_INT32);
2512 node->add_input(axis_node->name());
2513 node_map_->AddOutput(axis_node->name(), node->name());
2514 if (node->input_size() > 2) {
2515 node->mutable_input()->SwapElements(1, node->input_size() - 1);
2516 }
2517 return true;
2518 }
2519
SimplifyCase(GraphDef * optimized_graph,NodeDef * node)2520 bool ConstantFolding::SimplifyCase(GraphDef* optimized_graph, NodeDef* node) {
2521 if (node->op() != "Case") return false;
2522 const NodeDef* output_idx_node = node_map_->GetNode(node->input(0));
2523 if (output_idx_node == nullptr ||
2524 !CheckAttrExists(*output_idx_node, "value").ok()) {
2525 return false;
2526 }
2527 Tensor output_idx_t;
2528 if (!output_idx_t.FromProto(output_idx_node->attr().at("value").tensor()))
2529 return false;
2530 int output_idx = output_idx_t.scalar<int>()();
2531 const auto& func_list = node->attr().at("branches").list();
2532 if (output_idx < 0 || output_idx >= func_list.func_size()) return false;
2533 NodeDef call_node = *node;
2534 call_node.set_op("PartitionedCall");
2535 call_node.clear_input();
2536 for (int i = 1; i < node->input_size(); ++i) {
2537 call_node.add_input(node->input(i));
2538 }
2539 auto* new_func = (*call_node.mutable_attr())["f"].mutable_func();
2540 *new_func = func_list.func(output_idx);
2541
2542 // Move the output shape of the branch to _output_shapes if it is known.
2543 const auto& output_shape_list =
2544 (*node->mutable_attr())["output_shapes"].list();
2545 if (output_shape_list.shape_size() > output_idx) {
2546 TensorShapeProto* new_output_shape =
2547 (*call_node.mutable_attr())["_output_shapes"]
2548 .mutable_list()
2549 ->add_shape();
2550 *new_output_shape =
2551 std::move(node->attr().at("output_shapes").list().shape(output_idx));
2552 }
2553
2554 call_node.mutable_attr()->erase("output_shapes");
2555 call_node.mutable_attr()->erase("branches");
2556
2557 *node = std::move(call_node);
2558 return true;
2559 }
2560
SimplifySelect(const GraphProperties & properties,GraphDef * optimized_graph,NodeDef * node)2561 bool ConstantFolding::SimplifySelect(const GraphProperties& properties,
2562 GraphDef* optimized_graph, NodeDef* node) {
2563 if (!IsSelect(*node)) return false;
2564 const std::vector<OpInfo::TensorProperties>& input_props =
2565 properties.GetInputProperties(node->name());
2566 if (input_props.size() < 3) return false;
2567 const NodeDef* predicate_node = node_map_->GetNode(node->input(0));
2568 const bool is_all_true = IsOnes(*predicate_node);
2569 const bool is_all_false = IsZeros(*predicate_node);
2570 if (!is_all_true && !is_all_false) {
2571 return false;
2572 }
2573 const int live_input_idx = is_all_true ? 1 : 2;
2574 const int ignored_input_idx = is_all_true ? 2 : 1;
2575 const TensorShapeProto& predicate_shape = input_props[0].shape();
2576 const bool predicate_is_scalar =
2577 !predicate_shape.unknown_rank() && predicate_shape.dim_size() == 0;
2578 if (ShapesSymbolicallyEqual(input_props[1], input_props[2]) &&
2579 (ShapesSymbolicallyEqual(input_props[0], input_props[1]) ||
2580 predicate_is_scalar)) {
2581 // Replace node with Identity if no broadcasting is involved.
2582 node->set_op("Identity");
2583 *node->mutable_input(0) =
2584 AddControlDependency(node->input(0), optimized_graph, node_map_.get());
2585 *node->mutable_input(ignored_input_idx) = AddControlDependency(
2586 node->input(ignored_input_idx), optimized_graph, node_map_.get());
2587 node->mutable_input()->SwapElements(0, live_input_idx);
2588 } else if (!ReplaceOperationWithBroadcastTo(live_input_idx, properties, node,
2589 optimized_graph)) {
2590 return false;
2591 }
2592 DedupControlInputs(node);
2593 return true;
2594 }
2595
RemoveRedundantVariableUpdates(GraphProperties * properties,GraphDef * optimized_graph,NodeDef * node)2596 void ConstantFolding::RemoveRedundantVariableUpdates(
2597 GraphProperties* properties, GraphDef* optimized_graph, NodeDef* node) {
2598 static const absl::flat_hash_set<string>* kVariableReadOps =
2599 new absl::flat_hash_set<string>{"AssignAddVariableOp",
2600 "AssignSubVariableOp",
2601 "AssignAdd",
2602 "AssignSub",
2603 "ScatterAdd",
2604 "ScatterSub",
2605 "ScatterMul",
2606 "ScatterDiv",
2607 "ScatterNdAdd",
2608 "ScatterNdSub",
2609 "ScatterNdMul",
2610 "ScatterNdDiv",
2611 "ResourceScatterAdd",
2612 "ResourceScatterSub",
2613 "ResourceScatterMul",
2614 "ResourceScatterDiv",
2615 "ResourceScatterNdAdd",
2616 "ResourceScatterNdSub",
2617 "ResourceScatterNdMul",
2618 "ResourceScatterNdDiv"};
2619 if (kVariableReadOps == nullptr ||
2620 kVariableReadOps->find(node->op()) == kVariableReadOps->end())
2621 return;
2622 const int value_index = absl::StrContains(node->op(), "Scatter") ? 2 : 1;
2623 const NodeDef* delta_node = node_map_->GetNode(node->input(value_index));
2624 if (delta_node == nullptr) return;
2625 const bool is_add_or_sub = absl::StrContains(node->op(), "Add") ||
2626 absl::StrContains(node->op(), "Sub");
2627 if ((is_add_or_sub && IsZeros(*delta_node)) ||
2628 (!is_add_or_sub && IsOnes(*delta_node))) {
2629 VLOG(1) << "Removing redundant variable update: " << node->DebugString();
2630 if (absl::StrContains(node->op(), "Variable") ||
2631 absl::StrContains(node->op(), "Resource")) {
2632 ReplaceOperationWithNoOp(node, properties, optimized_graph);
2633 } else {
2634 ReplaceOperationWithIdentity(0 /* input_to_forward */, *properties, node,
2635 optimized_graph);
2636 }
2637 }
2638 }
2639
MoveConstantsPastEnter(GraphDef * optimized_graph,NodeDef * node)2640 bool ConstantFolding::MoveConstantsPastEnter(GraphDef* optimized_graph,
2641 NodeDef* node) {
2642 if (!IsEnter(*node) || node->input_size() == 0 ||
2643 node->attr().count("is_constant") == 0 ||
2644 !node->attr().at("is_constant").b()) {
2645 return false;
2646 }
2647 const string& node_name = node->name();
2648 const NodeDef* input = node_map_->GetNode(node->input(0));
2649 if (input == nullptr || !IsReallyConstant(*input) ||
2650 OptimizedNodeExists(*input, "_enter")) {
2651 return false;
2652 }
2653 // Find non-constant nodes that consume the output of *node.
2654 std::vector<NodeDef*> consumers;
2655 for (const NodeDef* fanout : node_map_->GetOutputs(node_name)) {
2656 if (!IsConstant(*fanout)) {
2657 for (int i = 0; i < fanout->input_size(); ++i) {
2658 if (fanout->input(i) == node_name) {
2659 consumers.push_back(const_cast<NodeDef*>(fanout));
2660 break;
2661 }
2662 }
2663 }
2664 }
2665 if (consumers.empty()) {
2666 return false;
2667 }
2668 graph_modified_ = true;
2669 NodeDef* new_node = optimized_graph->add_node();
2670 *new_node = *input;
2671 new_node->set_name(OptimizedNodeName(*input, "_enter"));
2672 new_node->set_device(node->device());
2673 new_node->clear_input();
2674 new_node->add_input(AsControlDependency(node_name));
2675 node_map_->AddNode(new_node->name(), new_node);
2676 node_map_->AddOutput(node_name, new_node->name());
2677 for (NodeDef* consumer : consumers) {
2678 for (int i = 0; i < consumer->input_size(); ++i) {
2679 if (NodeName(consumer->input(i)) == node_name) {
2680 node_map_->UpdateInput(consumer->name(), node_name, new_node->name());
2681 consumer->set_input(i, new_node->name());
2682 }
2683 }
2684 }
2685 return true;
2686 }
2687
SimplifySwitch(GraphDef * optimized_graph,NodeDef * node)2688 bool ConstantFolding::SimplifySwitch(GraphDef* optimized_graph, NodeDef* node) {
2689 if (node->op() == "Switch" && node->input(0) == node->input(1) &&
2690 !OptimizedNodeExists(*node, "_const_false") &&
2691 !OptimizedNodeExists(*node, "_const_true")) {
2692 bool already_optimized = true;
2693 // If the optimization was already applied, the switch would have exactly
2694 // one Identity node consuming each of its outputs, each without any
2695 // non-control outputs.
2696 const auto& fanouts = node_map_->GetOutputs(node->name());
2697 if (fanouts.size() == 2) {
2698 for (const NodeDef* fanout : fanouts) {
2699 if ((!IsIdentity(*fanout) && !IsIdentityNSingleInput(*fanout)) ||
2700 HasRegularOutputs(*fanout, *node_map_)) {
2701 already_optimized = false;
2702 break;
2703 }
2704 }
2705 }
2706 Tensor false_t(DT_BOOL, TensorShape({}));
2707 Tensor true_t(DT_BOOL, TensorShape({}));
2708 // Make sure we don't proceed if this switch node was already optimized.
2709 if (!already_optimized && SetTensorValue(DT_BOOL, true, &true_t).ok() &&
2710 SetTensorValue(DT_BOOL, false, &false_t).ok()) {
2711 // Copy the set of consumers of the switch as they will be manipulated
2712 // below.
2713 std::vector<NodeDef*> consumers =
2714 node_map_->GetOutputsOrderedByNodeName(node->name());
2715 // Create constant false & true nodes.
2716 NodeDef tmp_false_node;
2717 tmp_false_node.set_name(OptimizedNodeName(*node, "_const_false"));
2718 if (!CreateNodeDef(tmp_false_node.name(), TensorValue(&false_t),
2719 &tmp_false_node)
2720 .ok()) {
2721 return false;
2722 }
2723 tmp_false_node.set_device(node->device());
2724 NodeDef tmp_true_node;
2725 tmp_true_node.set_name(OptimizedNodeName(*node, "_const_true"));
2726 if (!CreateNodeDef(tmp_true_node.name(), TensorValue(&true_t),
2727 &tmp_true_node)
2728 .ok()) {
2729 return false;
2730 }
2731 tmp_true_node.set_device(node->device());
2732
2733 // Add const nodes to graph.
2734 NodeDef* false_node = optimized_graph->add_node();
2735 false_node->Swap(&tmp_false_node);
2736 NodeDef* true_node = optimized_graph->add_node();
2737 true_node->Swap(&tmp_true_node);
2738
2739 // Add controls from the switch ports to the constants, and connect the
2740 // constants to the original switch outputs.
2741 const string false_port = node->name();
2742 const string true_port = strings::StrCat(node->name(), ":1");
2743 const string false_ctrl_dep =
2744 AddControlDependency(false_port, optimized_graph, node_map_.get());
2745 false_node->add_input(false_ctrl_dep);
2746 const string true_ctrl_dep =
2747 AddControlDependency(true_port, optimized_graph, node_map_.get());
2748 true_node->add_input(true_ctrl_dep);
2749
2750 node_map_->AddNode(false_node->name(), false_node);
2751 node_map_->AddNode(true_node->name(), true_node);
2752 node_map_->AddOutput(NodeName(false_ctrl_dep), false_node->name());
2753 node_map_->AddOutput(NodeName(true_ctrl_dep), true_node->name());
2754
2755 for (NodeDef* consumer : consumers) {
2756 for (int i = 0; i < consumer->input_size(); ++i) {
2757 const string& input = consumer->input(i);
2758 if (input == false_port) {
2759 consumer->set_input(i, false_node->name());
2760 node_map_->UpdateInput(consumer->name(), false_port,
2761 false_node->name());
2762 } else if (input == true_port) {
2763 consumer->set_input(i, true_node->name());
2764 node_map_->UpdateInput(consumer->name(), true_port,
2765 true_node->name());
2766 }
2767 }
2768 }
2769 return true;
2770 }
2771 }
2772 return false;
2773 }
2774
IsReductionWithConstantIndices(const NodeDef & node,bool * indices_is_empty) const2775 bool ConstantFolding::IsReductionWithConstantIndices(
2776 const NodeDef& node, bool* indices_is_empty) const {
2777 // Ensure its an appropriate Reduce node.
2778 if (!IsReduction(node) || node.input_size() < 2) {
2779 return false;
2780 }
2781 // Ensure that the axes to reduce by are constant.
2782 NodeDef* reductions_indices = node_map_->GetNode(node.input(1));
2783 if (!IsReallyConstant(*reductions_indices) ||
2784 !reductions_indices->attr().count("value")) {
2785 return false;
2786 }
2787 const TensorShapeProto& reduction_indices_shape =
2788 reductions_indices->attr().at("value").tensor().tensor_shape();
2789 *indices_is_empty = TensorShape(reduction_indices_shape).num_elements() == 0;
2790 return true;
2791 }
2792
IsReductionCandidateForSimplification(const NodeDef & node,const GraphProperties & properties,TensorShapeProto * input_tensor_shape,TensorShapeProto * output_tensor_shape,bool * is_single_element_op) const2793 bool ConstantFolding::IsReductionCandidateForSimplification(
2794 const NodeDef& node, const GraphProperties& properties,
2795 TensorShapeProto* input_tensor_shape, TensorShapeProto* output_tensor_shape,
2796 bool* is_single_element_op) const {
2797 // Get the properties of the input & output tensors and check if they both
2798 // contain a single element.
2799 if (!properties.HasInputProperties(node.name()) ||
2800 !properties.HasOutputProperties(node.name())) {
2801 return false;
2802 }
2803 const auto& input_props = properties.GetInputProperties(node.name())[0];
2804 const auto& output_props = properties.GetOutputProperties(node.name())[0];
2805 if (!input_props.has_shape() || input_props.shape().unknown_rank() ||
2806 !output_props.has_shape() || output_props.shape().unknown_rank()) {
2807 return false;
2808 }
2809 *input_tensor_shape = input_props.shape();
2810 *output_tensor_shape = output_props.shape();
2811 for (int i = 0; i < input_tensor_shape->dim_size(); ++i) {
2812 if (input_tensor_shape->dim(i).size() < 0) {
2813 return false;
2814 }
2815 }
2816 for (int i = 0; i < output_tensor_shape->dim_size(); ++i) {
2817 if (output_tensor_shape->dim(i).size() < 0) {
2818 return false;
2819 }
2820 }
2821 const int input_num_elements =
2822 TensorShape(*input_tensor_shape).num_elements();
2823 const int output_num_elements =
2824 TensorShape(*output_tensor_shape).num_elements();
2825 *is_single_element_op = input_num_elements == 1 && output_num_elements == 1;
2826
2827 return true;
2828 }
2829
IsReductionSimplifiableToIdentity(const NodeDef & node,const TensorShapeProto & input_shape,bool keep_dims,const TensorVector & reduction_indices_vector) const2830 bool ConstantFolding::IsReductionSimplifiableToIdentity(
2831 const NodeDef& node, const TensorShapeProto& input_shape, bool keep_dims,
2832 const TensorVector& reduction_indices_vector) const {
2833 int output_size = reduction_indices_vector[0]->NumElements();
2834 if (output_size == 0) {
2835 return true;
2836 }
2837
2838 if (!keep_dims) {
2839 return false;
2840 }
2841 bool simplifiable = true;
2842 for (int i = 0; i < output_size; ++i) {
2843 int64_t dim;
2844 if (reduction_indices_vector[0]->dtype() == DT_INT32) {
2845 dim = reduction_indices_vector[0]->flat<int32>()(i);
2846 } else {
2847 dim = reduction_indices_vector[0]->flat<int64_t>()(i);
2848 }
2849 if (dim < 0) {
2850 dim += input_shape.dim_size();
2851 }
2852 if (dim < 0 || dim >= input_shape.dim_size() ||
2853 input_shape.dim(dim).size() != 1) {
2854 simplifiable = false;
2855 break;
2856 }
2857 }
2858 return simplifiable;
2859 }
2860
ReplaceReductionWithIdentity(NodeDef * node) const2861 bool ConstantFolding::ReplaceReductionWithIdentity(NodeDef* node) const {
2862 // Replace the reduction node with an identity node, that can be further
2863 // optimized by other passes.
2864 DataType output_type;
2865 if (node->attr().count("T") != 0) {
2866 output_type = node->attr().at("T").type();
2867 } else if (IsAny(*node) || IsAll(*node)) {
2868 output_type = DT_BOOL;
2869 } else {
2870 return false;
2871 }
2872 node->set_op("Identity");
2873 EraseRegularNodeAttributes(node);
2874 (*node->mutable_attr())["T"].set_type(output_type);
2875 *node->mutable_input(1) = AsControlDependency(node->input(1));
2876 return true;
2877 }
2878
SimplifyReduction(GraphDef * optimized_graph,const GraphProperties & properties,NodeDef * node)2879 bool ConstantFolding::SimplifyReduction(GraphDef* optimized_graph,
2880 const GraphProperties& properties,
2881 NodeDef* node) {
2882 bool indices_is_empty = false;
2883 if (!IsReductionWithConstantIndices(*node, &indices_is_empty)) {
2884 return false;
2885 }
2886 if (indices_is_empty) {
2887 return ReplaceReductionWithIdentity(node);
2888 }
2889 bool is_single_element_op = false;
2890 TensorShapeProto input_tensor_shape, output_tensor_shape;
2891 if (!IsReductionCandidateForSimplification(
2892 *node, properties, &input_tensor_shape, &output_tensor_shape,
2893 &is_single_element_op)) {
2894 return false;
2895 }
2896
2897 // Get the reduction indices.
2898 string reduction_indices_input = node->input(1);
2899 NodeDef* reduction_indices = node_map_->GetNode(reduction_indices_input);
2900 TensorVector reduction_indices_vector;
2901 auto outputs_cleanup = gtl::MakeCleanup([&reduction_indices_vector] {
2902 for (const auto& out : reduction_indices_vector) {
2903 delete out.tensor;
2904 }
2905 });
2906 if (!EvaluateNode(*reduction_indices, TensorVector(),
2907 &reduction_indices_vector)
2908 .ok() ||
2909 reduction_indices_vector.size() != 1) {
2910 return false;
2911 }
2912
2913 bool keep_dims =
2914 node->attr().count("keep_dims") > 0 && node->attr().at("keep_dims").b();
2915 bool simplifiable_to_reshape =
2916 is_single_element_op && !keep_dims && (node->attr().count("T") > 0);
2917 bool simplifiable_to_identity = IsReductionSimplifiableToIdentity(
2918 *node, input_tensor_shape, keep_dims, reduction_indices_vector);
2919
2920 if (simplifiable_to_reshape) {
2921 // Const node to output shape.
2922 const int new_num_dimensions = output_tensor_shape.dim_size();
2923 Tensor tensor(DT_INT32, TensorShape({new_num_dimensions}));
2924 for (int i = 0; i < new_num_dimensions; i++) {
2925 tensor.flat<int>()(i) = 1;
2926 }
2927 TensorValue shape_value(&tensor);
2928 NodeDef* shape_node = optimized_graph->add_node();
2929 if (!CreateNodeDef(OptimizedNodeName(*node, "_shape_const"), shape_value,
2930 shape_node)
2931 .ok()) {
2932 return false;
2933 }
2934 shape_node->set_device(node->device());
2935 node_map_->AddNode(shape_node->name(), shape_node);
2936 // Control dependency to ensure shape_node is in the correct frame.
2937 shape_node->add_input(AsControlDependency(reduction_indices_input));
2938 node_map_->AddOutput(NodeName(reduction_indices_input), shape_node->name());
2939 // Optimize node to Reshape.
2940 node->set_op("Reshape");
2941 node_map_->UpdateInput(node->name(), node->input(1), shape_node->name());
2942 node->set_input(1, shape_node->name());
2943 node->mutable_attr()->erase("keep_dims");
2944 node->mutable_attr()->erase("Tidx");
2945 AttrValue attr_type_indices;
2946 attr_type_indices.set_type(DT_INT32);
2947 (*node->mutable_attr())["Tshape"] = attr_type_indices;
2948 return true;
2949 } else if (simplifiable_to_identity) {
2950 return ReplaceReductionWithIdentity(node);
2951 }
2952 return false;
2953 }
2954
SimplifyReshape(const GraphProperties & properties,bool use_shape_info,NodeDef * node)2955 bool ConstantFolding::SimplifyReshape(const GraphProperties& properties,
2956 bool use_shape_info, NodeDef* node) {
2957 if (!use_shape_info || node->attr().count("T") == 0 ||
2958 !IsSimplifiableReshape(*node, properties).ok()) {
2959 return false;
2960 }
2961 DataType output_type = node->attr().at("T").type();
2962 node->set_op("Identity");
2963 EraseRegularNodeAttributes(node);
2964 (*node->mutable_attr())["T"].set_type(output_type);
2965 *node->mutable_input(1) = AsControlDependency(node->input(1));
2966 return true;
2967 }
2968
SimplifyArithmeticOperations(const GraphProperties & properties,bool use_shape_info,GraphDef * optimized_graph,NodeDef * node)2969 Status ConstantFolding::SimplifyArithmeticOperations(
2970 const GraphProperties& properties, bool use_shape_info,
2971 GraphDef* optimized_graph, NodeDef* node) {
2972 const bool is_mul = IsAnyMul(*node) || IsLogicalAnd(*node);
2973 const bool is_matmul = IsAnyMatMul(*node);
2974 const bool is_add = IsAdd(*node) || IsBiasAdd(*node) || IsLogicalOr(*node);
2975 const bool is_sub = IsSub(*node);
2976 const bool is_any_div = IsAnyDiv(*node) && !IsFloorDiv(*node);
2977 // Simplify arithmetic operations with ones or zeros.
2978 if (use_shape_info &&
2979 (is_mul || is_matmul || is_add || is_sub || is_any_div) &&
2980 properties.HasInputProperties(node->name()) &&
2981 properties.HasOutputProperties(node->name())) {
2982 const NodeDef* x = node_map_->GetNode(node->input(0));
2983 const NodeDef* y = node_map_->GetNode(node->input(1));
2984 if (x == nullptr || y == nullptr) {
2985 return errors::InvalidArgument("Invalid inputs to node: ",
2986 node->DebugString());
2987 }
2988 const TensorShapeProto& output_shape =
2989 properties.GetOutputProperties(node->name())[0].shape();
2990
2991 // Simplify element-wise multiplication by ones or addition/subtraction
2992 // of zeros.
2993 const TensorShapeProto& y_shape =
2994 properties.GetInputProperties(node->name())[1].shape();
2995 const TensorShapeProto& x_shape =
2996 properties.GetInputProperties(node->name())[0].shape();
2997 const bool y_matches_output_shape =
2998 ShapesSymbolicallyEqual(output_shape, y_shape);
2999 const bool x_matches_output_shape =
3000 ShapesSymbolicallyEqual(output_shape, x_shape);
3001
3002 const bool x_is_zero = IsZeros(*x);
3003 const bool x_is_one = x_is_zero ? false : IsOnes(*x);
3004 if ((is_mul && x_is_one) || (is_add && x_is_zero)) {
3005 // 1 * y = y or 0 + y = y.
3006 if (y_matches_output_shape) {
3007 ReplaceOperationWithSnapshot(1, properties, node, optimized_graph);
3008 } else if (x_matches_output_shape) {
3009 ReplaceBinaryOperationWithBroadcastTo(1, properties, node,
3010 optimized_graph);
3011 }
3012 return OkStatus();
3013 }
3014
3015 if (y_matches_output_shape && (is_sub && x_is_zero)) {
3016 // Replace 0 - y with Neg(y).
3017 ReplaceSubtractionFromZeroByNegation(node, optimized_graph);
3018 return OkStatus();
3019 }
3020
3021 // Replace 1 / y with Reciprocal op.
3022 if (y_matches_output_shape && is_any_div && x_is_one) {
3023 TF_RETURN_IF_ERROR(CheckAttrExists(*node, "T"));
3024 DataType type = node->attr().at("T").type();
3025 if (DataTypeIsFloating(type) || DataTypeIsComplex(type)) {
3026 ReplaceDivisionOfOnesByReciprocal(node, optimized_graph);
3027 return OkStatus();
3028 }
3029 }
3030
3031 const bool y_is_zero = IsZeros(*y);
3032 const bool y_is_one = y_is_zero ? false : IsOnes(*y);
3033 if (((is_mul || is_any_div) && y_is_one) ||
3034 ((is_add || is_sub) && y_is_zero)) {
3035 // x * 1 = x or x / 1 = x or x +/- 0 = x
3036 if (x_matches_output_shape) {
3037 ReplaceOperationWithSnapshot(0, properties, node, optimized_graph);
3038 } else if (y_matches_output_shape) {
3039 ReplaceBinaryOperationWithBroadcastTo(0, properties, node,
3040 optimized_graph);
3041 }
3042 return OkStatus();
3043 }
3044
3045 // x OR true = true OR y = true.
3046 const PartialTensorShape shp(output_shape);
3047 if (shp.IsFullyDefined() && IsLogicalOr(*node) && (y_is_one || x_is_one)) {
3048 TF_RETURN_IF_ERROR(ReplaceOperationWithConstant(
3049 1, properties, output_shape, node, optimized_graph));
3050 return OkStatus();
3051 }
3052
3053 // Simplify multiplication and matmul by zeros.
3054 // Also optimize zeros divided by a tensor, but only if we are in
3055 // aggressive mode, since we might get rid of divisions by zero.
3056 const bool is_aggressive = opt_level_ == RewriterConfig::AGGRESSIVE;
3057 bool optimize_zeros_divided_by_y = is_any_div && x_is_zero && is_aggressive;
3058 if ((x_is_zero || y_is_zero) &&
3059 (is_mul || is_matmul || optimize_zeros_divided_by_y)) {
3060 if (shp.IsFullyDefined()) {
3061 bool is_quantized = IsQuantizedMatMul(*node);
3062 TF_RETURN_IF_ERROR(ReplaceOperationWithConstant(
3063 0, properties, output_shape, node, optimized_graph));
3064 if (is_quantized && graph_modified_) {
3065 TF_RETURN_IF_ERROR(
3066 AddQuantizedMatMulMinMaxOutConstNodes(node, optimized_graph));
3067 }
3068 return OkStatus();
3069 }
3070 // Even if an input shape is only partially known, we may known that it
3071 // matches the output shape and thus forward or broadcast the
3072 // corresponding zero input.
3073 if ((is_mul || is_any_div) && x_is_zero) {
3074 if (x_matches_output_shape) {
3075 ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
3076 } else if (y_matches_output_shape) {
3077 ReplaceBinaryOperationWithBroadcastTo(0, properties, node,
3078 optimized_graph);
3079 }
3080 return OkStatus();
3081 } else if (is_mul && y_is_zero) {
3082 if (y_matches_output_shape) {
3083 ReplaceOperationWithIdentity(1, properties, node, optimized_graph);
3084 } else if (x_matches_output_shape) {
3085 ReplaceBinaryOperationWithBroadcastTo(1, properties, node,
3086 optimized_graph);
3087 }
3088 return OkStatus();
3089 }
3090 }
3091 }
3092 return OkStatus();
3093 }
3094
ReduceDivToReciprocalMul(GraphDef * optimized_graph,NodeDef * node)3095 bool ConstantFolding::ReduceDivToReciprocalMul(GraphDef* optimized_graph,
3096 NodeDef* node) {
3097 // Strength reduce floating point division by a constant Div(x, const) to
3098 // multiplication by the reciprocal Mul(x, Reciprocal(const)). This in turn
3099 // will be constant folded to Mul(x, 1.0/const).
3100 if (node->input_size() >= 2 &&
3101 (IsDiv(*node) || IsRealDiv(*node) || IsXdivy(*node))) {
3102 const string& const_input = node->input(1);
3103 const NodeDef* denom = node_map_->GetNode(const_input);
3104 CHECK(denom != nullptr);
3105 if (!IsReallyConstant(*denom)) {
3106 return false;
3107 }
3108 if (node->attr().count("T") == 0) {
3109 return false;
3110 }
3111 DataType type = node->attr().at("T").type();
3112 // Skip integer division.
3113 if (IsDiv(*node) &&
3114 !(DataTypeIsFloating(type) || DataTypeIsComplex(type))) {
3115 return false;
3116 }
3117 // Insert new reciprocal op and change node from Div to Mul.
3118 NodeDef* reciprocal_node = optimized_graph->add_node();
3119 reciprocal_node->set_name(OptimizedNodeName(*node, "_recip"));
3120 reciprocal_node->set_op("Reciprocal");
3121 reciprocal_node->set_device(node->device());
3122 reciprocal_node->add_input(const_input);
3123 (*reciprocal_node->mutable_attr())["T"].set_type(type);
3124
3125 // Re-wire inputs and outputs.
3126 if (IsXdivy(*node)) {
3127 node->set_op("MulNoNan");
3128 node->set_input(1, node->input(0));
3129 node->set_input(0, reciprocal_node->name());
3130 } else {
3131 node->set_op("Mul");
3132 node->set_input(1, reciprocal_node->name());
3133 }
3134 node_map_->AddNode(reciprocal_node->name(), reciprocal_node);
3135 node_map_->UpdateOutput(node->name(), const_input, reciprocal_node->name());
3136
3137 return true;
3138 }
3139
3140 return false;
3141 }
3142
PrepareConstantPushDown(const NodeDef & parent,const GraphProperties & properties,bool must_have_properties,ConstantPushDownContext * ctx) const3143 bool ConstantFolding::PrepareConstantPushDown(
3144 const NodeDef& parent, const GraphProperties& properties,
3145 bool must_have_properties, ConstantPushDownContext* ctx) const {
3146 if (ctx == nullptr || !has_fetch_ || NumNonControlInputs(parent) != 2) {
3147 return false;
3148 }
3149 NodeDef* left_child = node_map_->GetNode(parent.input(0));
3150 NodeDef* right_child = node_map_->GetNode(parent.input(1));
3151
3152 // Sanity check for missing children.
3153 if (left_child == nullptr || right_child == nullptr) {
3154 return false;
3155 }
3156
3157 ctx->left_child_is_const = IsReallyConstant(*left_child);
3158 ctx->right_child_is_const = IsReallyConstant(*right_child);
3159 ctx->op_child = ctx->left_child_is_const ? right_child : left_child;
3160 ctx->const_child = ctx->left_child_is_const ? left_child : right_child;
3161
3162 // Nothing to do unless the parent has a constant child node.
3163 if (!ctx->left_child_is_const && !ctx->right_child_is_const) {
3164 return false;
3165 }
3166
3167 // Don't move nodes across devices.
3168 if (parent.device() != ctx->op_child->device() ||
3169 parent.device() != ctx->const_child->device()) {
3170 return false;
3171 }
3172
3173 // Make sure that it is safe to change the value of the child node result.
3174 if (ctx->op_child->input_size() < 2 ||
3175 nodes_to_preserve_.find(ctx->op_child->name()) !=
3176 nodes_to_preserve_.end() ||
3177 NumNonControlOutputs(*ctx->op_child, *node_map_) > 1) {
3178 return false;
3179 }
3180
3181 // Don't apply reassociation to floating point types of low precision.
3182 // The danger of significant numerical changes is too high.
3183 if (!CheckAttrExists(parent, "T").ok()) return false;
3184 DataType dtype = parent.attr().at("T").type();
3185 if (dtype == DT_BFLOAT16 || dtype == DT_HALF) {
3186 return false;
3187 }
3188
3189 // Don't rewrite the tree if it might create cycles.
3190 // TODO(rmlarsen): Add back handling of control dependency from op to C.
3191 const auto& child_output = node_map_->GetOutputs(ctx->op_child->name());
3192 if (child_output.find(ctx->const_child) != child_output.end()) {
3193 return false;
3194 }
3195
3196 // Get leaf nodes.
3197 ctx->left_leaf = node_map_->GetNode(ctx->op_child->input(0));
3198 ctx->right_leaf = node_map_->GetNode(ctx->op_child->input(1));
3199 ctx->left_leaf_is_const = IsReallyConstant(*ctx->left_leaf);
3200 ctx->right_leaf_is_const = IsReallyConstant(*ctx->right_leaf);
3201
3202 if (ctx->left_leaf_is_const && ctx->right_leaf_is_const) {
3203 // Child is already foldable, leave it alone.
3204 return false;
3205 }
3206
3207 // Don't move nodes across devices.
3208 if (parent.device() != ctx->left_leaf->device() ||
3209 parent.device() != ctx->right_leaf->device()) {
3210 return false;
3211 }
3212
3213 // Get shape and type information.
3214 ctx->parent_input_props = &properties.GetInputProperties(parent.name());
3215 ctx->op_child_input_props =
3216 &properties.GetInputProperties(ctx->op_child->name());
3217 if (must_have_properties && (ctx->parent_input_props == nullptr ||
3218 ctx->parent_input_props->size() < 2 ||
3219 ctx->op_child_input_props == nullptr ||
3220 ctx->op_child_input_props->size() < 2)) {
3221 return false;
3222 }
3223
3224 VLOG(1) << "\n++++++++ PushDown for node " << parent.name() << ": "
3225 << parent.op() << "(" << left_child->op() << ", " << right_child->op()
3226 << ")";
3227
3228 return true;
3229 }
3230
ConstantPushDownBiasAdd(GraphProperties * properties,GraphDef * optimized_graph,NodeDef * node)3231 bool ConstantFolding::ConstantPushDownBiasAdd(GraphProperties* properties,
3232 GraphDef* optimized_graph,
3233 NodeDef* node) {
3234 // This implements constant push-down for BiasAdd. In the following "CV" is a
3235 // constant vector (tensor of rank 1), "V" is a (possibly) non-constant
3236 // vector, "CM" is a matrix (tensor of rank >= 2), "M" is a (possibly)
3237 // non-constant matrix, and "BA" is BiasAdd.
3238 // For a valid input graph, the following 4 rewrites are legal:
3239 //
3240 // 1) + +
3241 // / \ / \
3242 // BA CV -- > BA V
3243 // / \ / \
3244 // M V M CV
3245 //
3246 // 2) + +
3247 // / \ / \
3248 // BA CM -- > BA M
3249 // / \ / \
3250 // M V CM V
3251 //
3252 // 3) BA BA
3253 // / \ / \
3254 // + CV -- > + V
3255 // / \ / \
3256 // M V M CV
3257 //
3258 // 4) BA BA = parent
3259 // / \ / \
3260 // BA CV -- > BA V = children
3261 // / \ / \
3262 // M V M CV = leaves
3263 //
3264 // Cases 1 through 3 have additional sub-cases due to the symmetry of Add.
3265
3266 const bool parent_is_bias_add = IsBiasAdd(*node);
3267 if (!parent_is_bias_add && !IsAdd(*node)) return false;
3268 ConstantPushDownContext ctx;
3269 if (!PrepareConstantPushDown(*node, *properties,
3270 /*must_have_properties=*/true, &ctx)) {
3271 return false;
3272 }
3273 // Special case for BiasAdd: Since the left argument to BiasAdd must be rank
3274 // >= 2 and the leaves must be vectors, we cannot swap them.
3275 if (ctx.left_child_is_const && parent_is_bias_add) return false;
3276 const bool child_is_bias_add = IsBiasAdd(*ctx.op_child);
3277 if (!child_is_bias_add && !IsAdd(*ctx.op_child)) return false;
3278
3279 // Get properties to validate rank and dtype constraints.
3280 if (ctx.parent_input_props->empty() || ctx.op_child_input_props->empty() ||
3281 (*ctx.parent_input_props)[0].shape().unknown_rank() ||
3282 (*ctx.parent_input_props)[1].shape().unknown_rank() ||
3283 (*ctx.op_child_input_props)[0].shape().unknown_rank() ||
3284 (*ctx.op_child_input_props)[1].shape().unknown_rank()) {
3285 return false;
3286 }
3287
3288 // Now get the ranks and types of the 3 leaf nodes.
3289 const int left_leaf_rank = (*ctx.op_child_input_props)[0].shape().dim_size();
3290 const int right_leaf_rank = (*ctx.op_child_input_props)[1].shape().dim_size();
3291 // At least one leaf must be a vector.
3292 if (left_leaf_rank != 1 && right_leaf_rank != 1) return false;
3293 const int vector_idx = left_leaf_rank == 1 ? 0 : 1;
3294 const int matrix_idx = 1 - vector_idx;
3295
3296 const auto& vector_prop = (*ctx.op_child_input_props)[vector_idx];
3297 const int vector_rank = vector_idx == 0 ? left_leaf_rank : right_leaf_rank;
3298 if (vector_rank != 1) return false; // this should never happen.
3299 const DataType vector_type = vector_prop.dtype();
3300
3301 const auto& matrix_prop = (*ctx.op_child_input_props)[matrix_idx];
3302 const int matrix_rank = matrix_prop.shape().dim_size();
3303 const DataType matrix_type = matrix_prop.dtype();
3304
3305 const int const_idx = ctx.left_child_is_const ? 0 : 1;
3306 const auto& const_prop = (*ctx.parent_input_props)[const_idx];
3307 const int const_rank = const_prop.shape().dim_size();
3308 const DataType const_type = const_prop.dtype();
3309
3310 int input_to_swap = -1;
3311
3312 if (!parent_is_bias_add && child_is_bias_add && const_rank == matrix_rank &&
3313 const_type == matrix_type) {
3314 // Case 2:
3315 input_to_swap = matrix_idx;
3316 } else if (const_rank == 1 && const_type == vector_type) {
3317 // Case 1, 3, and, 4:
3318 input_to_swap = vector_idx;
3319 }
3320 if (input_to_swap == -1) return false;
3321 const NodeDef* leaf_to_swap =
3322 node_map_->GetNode(ctx.op_child->input(input_to_swap));
3323 if (IsConstant(*leaf_to_swap)) return false;
3324
3325 node_map_->UpdateInput(node->name(), node->input(const_idx),
3326 ctx.op_child->input(input_to_swap));
3327 node_map_->AddOutput(node->input(const_idx), ctx.op_child->name());
3328 if (ctx.op_child->input(input_to_swap) !=
3329 ctx.op_child->input(1 - input_to_swap)) {
3330 node_map_->RemoveOutput(ctx.op_child->input(input_to_swap),
3331 ctx.op_child->name());
3332 }
3333 std::swap(*node->mutable_input(const_idx),
3334 *ctx.op_child->mutable_input(input_to_swap));
3335 properties->ClearInputProperties(node->name());
3336 properties->ClearInputProperties(ctx.op_child->name());
3337
3338 return true;
3339 }
3340
ConstantPushDown(GraphProperties * properties,GraphDef * optimized_graph,NodeDef * node)3341 bool ConstantFolding::ConstantPushDown(GraphProperties* properties,
3342 GraphDef* optimized_graph,
3343 NodeDef* node) {
3344 // Consider the transformation
3345 //
3346 // + + = parent
3347 // / \ / \
3348 // C + -- > X + = children
3349 // / \ / \
3350 // X Y C Y = leaves
3351 //
3352 // where C is constant, X is non-constant, Y may be constant or non-constant,
3353 // and '+' denotes an associative and commutative operator like addition or
3354 // multiplication. This optimization pushes constants down in the tree to
3355 // canonicalize it. Moreover, in cases where the child node has a second
3356 // constant input Y we will create a leaf node that can be folded, e.g.
3357 //
3358 // Add(C1, Add(C2, X)) -> Add(X, Add(C1, C2)) -> Add(X, C1 + C2)
3359 //
3360 // We also handle the non-commutative cases of subtraction and division
3361 // by rotating the tree locally, e.g.
3362 // Sub(C, Add(X, Y)) -> Sub(Sub(C, Y), X)
3363 // Mul(C, Div(X, Y)) -> Mul(X, Div(C, Y)).
3364
3365 // Get parent op type.
3366 const bool is_add = IsAdd(*node);
3367 const bool is_mul = IsMul(*node);
3368 const bool is_sub = IsSub(*node);
3369 const bool is_div = IsDiv(*node);
3370 if (!(is_add || is_sub || is_mul || is_div)) return false;
3371 const bool is_symmetric = is_add || is_mul;
3372
3373 ConstantPushDownContext ctx;
3374 if (!PrepareConstantPushDown(*node, *properties,
3375 /*must_have_properties=*/false, &ctx)) {
3376 return false;
3377 }
3378
3379 // Get child op type.
3380 const bool is_child_add = IsAdd(*ctx.op_child);
3381 const bool is_child_mul = IsMul(*ctx.op_child);
3382 const bool is_child_sub = IsSub(*ctx.op_child);
3383 const bool is_child_div = IsDiv(*ctx.op_child);
3384 const bool is_add_sub = (is_add || is_sub) && (is_child_add || is_child_sub);
3385 const bool is_mul_div = (is_mul || is_div) && (is_child_mul || is_child_div);
3386 if (!is_add_sub && !is_mul_div) {
3387 return false;
3388 }
3389 const bool is_child_symmetric = is_child_add || is_child_mul;
3390
3391 if (!CheckAttrExists(*node, "T").ok()) return false;
3392 DataType dtype = node->attr().at("T").type();
3393 if (!(is_symmetric && is_child_symmetric) &&
3394 !(DataTypeIsFloating(dtype) || DataTypeIsComplex(dtype))) {
3395 return false;
3396 }
3397
3398 const NodeDef* y_node =
3399 ctx.left_leaf_is_const ? ctx.left_leaf : ctx.right_leaf;
3400 if (!IsReallyConstant(*y_node) && !ctx.parent_input_props->empty() &&
3401 !ctx.op_child_input_props->empty()) {
3402 // If we know the shapes of the nodes being swapped, make sure we don't push
3403 // down a larger node and create more work by broadcasting earlier in the
3404 // expressions tree.
3405 const PartialTensorShape c_shape(
3406 (*ctx.parent_input_props)[ctx.left_child_is_const ? 0 : 1].shape());
3407 const PartialTensorShape x_shape(
3408 (*ctx.op_child_input_props)[ctx.left_leaf_is_const ? 0 : 1].shape());
3409
3410 if (c_shape.IsFullyDefined() && x_shape.IsFullyDefined() &&
3411 c_shape.num_elements() > x_shape.num_elements()) {
3412 return false;
3413 } else if (!c_shape.unknown_rank() && !x_shape.unknown_rank() &&
3414 c_shape.dims() > 0) {
3415 for (int idx = 0; idx < std::min(x_shape.dims(), c_shape.dims()); ++idx) {
3416 if (x_shape.dim_size(idx) >= 0 &&
3417 c_shape.dim_size(idx) > x_shape.dim_size(idx)) {
3418 return false;
3419 }
3420 }
3421 }
3422 }
3423
3424 // Get the node names corresponding to X, Y, and C.
3425 const string input_x =
3426 ctx.left_leaf_is_const ? ctx.op_child->input(1) : ctx.op_child->input(0);
3427 const string input_y = input_x == ctx.op_child->input(0)
3428 ? ctx.op_child->input(1)
3429 : ctx.op_child->input(0);
3430 const string input_c =
3431 ctx.left_child_is_const ? node->input(0) : node->input(1);
3432 const string input_op =
3433 ctx.left_child_is_const ? node->input(1) : node->input(0);
3434 VLOG(1) << "input_c = " << input_c << "\ninput_x = " << input_x;
3435
3436 // Now we have identified the nodes to swap, update the nodemap accordingly.
3437 node_map_->UpdateInput(node->name(), input_c, input_x);
3438 node_map_->AddOutput(input_c, ctx.op_child->name());
3439 if (input_x != input_y) {
3440 node_map_->RemoveOutput(input_x, ctx.op_child->name());
3441 }
3442 properties->ClearInputProperties(node->name());
3443 properties->ClearInputProperties(ctx.op_child->name());
3444
3445 if (is_symmetric && is_child_symmetric) {
3446 // Easy case (only commutative ops). We always write this as one of
3447 // +
3448 // / \
3449 // X +
3450 // / \
3451 // C Y
3452 node->set_input(0, input_x);
3453 node->set_input(1, input_op);
3454 ctx.op_child->set_input(0, input_c);
3455 ctx.op_child->set_input(1, input_y);
3456 } else {
3457 // More complicated case: When there are non-commutative operations like
3458 // subtractions or divisions involved, we may have to rotate the tree
3459 // and/or change op types. There are 6 non-trivial cases depending on
3460 // the effective generalized "sign" of each of the three terms C, Y, and X.
3461 // Here are the final trees we want to generate for those 6 cases:
3462 //
3463 // (CYX signs): ++- +-- -+- --+ +-+ -++
3464 //
3465 // - - - - + +
3466 // / \ / \ / \ / \ / \ / \
3467 // + X - X - X X + X - X -
3468 // / \ / \ / \ / \ / \ / \
3469 // C Y C Y Y C Y C C Y Y C
3470 //
3471
3472 // First, let's determine the effective sign of each term in the original
3473 // expression
3474 auto is_leaf_negated = [&](const bool is_right_leaf) -> bool {
3475 bool leaf_negated = !is_child_symmetric && is_right_leaf;
3476 bool child_negated = !is_symmetric && (ctx.left_child_is_const);
3477 return leaf_negated != child_negated;
3478 };
3479 const string symmetric_op = (is_add || is_sub) ? "Add" : "Mul";
3480 const string nonsymmetric_op = (is_add || is_sub) ? "Sub" : "Div";
3481 bool neg_c = !is_symmetric && !ctx.left_child_is_const;
3482 bool neg_x = is_leaf_negated(ctx.left_leaf_is_const);
3483 bool neg_y = is_leaf_negated(!ctx.left_leaf_is_const);
3484 // Rewrite the parent node.
3485 node->set_op((neg_x || (neg_c && neg_y)) ? nonsymmetric_op : symmetric_op);
3486 node->set_input(0, neg_x ? input_op : input_x);
3487 node->set_input(1, neg_x ? input_x : input_op);
3488 // Rewrite the child node.
3489 ctx.op_child->set_op(neg_c != neg_y ? nonsymmetric_op : symmetric_op);
3490 ctx.op_child->set_input(0, neg_c ? input_y : input_c);
3491 ctx.op_child->set_input(1, neg_c ? input_c : input_y);
3492 }
3493 return true;
3494 }
3495
MulConvPushDown(GraphDef * optimized_graph,NodeDef * node,const GraphProperties & properties)3496 bool ConstantFolding::MulConvPushDown(GraphDef* optimized_graph, NodeDef* node,
3497 const GraphProperties& properties) {
3498 // Push down multiplication on ConvND.
3499 // * ConvND
3500 // / \ / \
3501 // ConvND C2 -- > X *
3502 // / \ / \
3503 // X C1 C1 C2
3504 //
3505 // where C1 and C2 are constants and X is non-constant.
3506 //
3507 // TODO(rmlarsen): Use PrepareConstantPushDown() to simplify this code.
3508
3509 if (!IsAnyMul(*node) || NumNonControlInputs(*node) != 2) return false;
3510
3511 NodeDef* mul_left_child = node_map_->GetNode(node->input(0));
3512 NodeDef* mul_right_child = node_map_->GetNode(node->input(1));
3513 if (mul_left_child == nullptr || mul_right_child == nullptr) {
3514 return false;
3515 }
3516 // One child must be constant, and the second must be Conv op.
3517 const bool left_child_is_constant = IsReallyConstant(*mul_left_child);
3518 const bool right_child_is_constant = IsReallyConstant(*mul_right_child);
3519 if (!left_child_is_constant && !right_child_is_constant) {
3520 return false;
3521 }
3522 NodeDef* conv_node =
3523 left_child_is_constant ? mul_right_child : mul_left_child;
3524 if (!IsConv2D(*conv_node) && !IsConv3D(*conv_node)) {
3525 return false;
3526 }
3527 if (node->device() != mul_left_child->device() ||
3528 node->device() != mul_right_child->device()) {
3529 return false;
3530 }
3531
3532 // Make sure that it is safe to change the value of the convolution
3533 // output.
3534 if (conv_node->input_size() < 2 ||
3535 NumNonControlOutputs(*conv_node, *node_map_) > 1 ||
3536 nodes_to_preserve_.find(conv_node->name()) != nodes_to_preserve_.end()) {
3537 return false;
3538 }
3539
3540 // Identify the nodes to swap.
3541 NodeDef* conv_left_child = node_map_->GetNode(conv_node->input(0));
3542 NodeDef* conv_right_child = node_map_->GetNode(conv_node->input(1));
3543 const bool conv_left_is_constant = IsReallyConstant(*conv_left_child);
3544 const bool conv_right_is_constant = IsReallyConstant(*conv_right_child);
3545 if (!conv_left_is_constant && !conv_right_is_constant) {
3546 // At least one of the convolution inputs should be constant.
3547 return false;
3548 }
3549 if (conv_left_is_constant && conv_right_is_constant) {
3550 // Leverage regular constant folding to handle this.
3551 return false;
3552 }
3553 const auto& mul_props = properties.GetOutputProperties(node->name());
3554 const auto& conv_props = properties.GetOutputProperties(conv_node->name());
3555 if (mul_props.empty() || conv_props.empty()) {
3556 return false;
3557 }
3558 const auto& mul_shape = mul_props[0].shape();
3559 const auto& conv_shape = conv_props[0].shape();
3560 if (!ShapesSymbolicallyEqual(mul_shape, conv_shape)) {
3561 return false;
3562 }
3563
3564 const auto& input_props = properties.GetInputProperties(conv_node->name());
3565 if (input_props.size() < 2) {
3566 return false;
3567 }
3568 const auto& filter_shape = input_props[1].shape();
3569
3570 NodeDef* const_node =
3571 left_child_is_constant ? mul_left_child : mul_right_child;
3572 const auto& const_props = properties.GetOutputProperties(const_node->name());
3573 if (const_props.empty()) {
3574 return false;
3575 }
3576 const auto& const_shape = const_props[0].shape();
3577 if (!IsValidConstShapeForMulConvPushDown(
3578 conv_node->attr().at("data_format").s(), filter_shape, const_shape)) {
3579 return false;
3580 }
3581
3582 string mul_new_name = AddPrefixToNodeName("merged_input", conv_node->name());
3583 if (node_map_->NodeExists(mul_new_name)) {
3584 return false;
3585 }
3586 // Make sure we don't introduce loops in the graph by removing control
3587 // dependencies from the conv2d node to c2.
3588 string conv_const_input =
3589 conv_left_is_constant ? conv_node->input(0) : conv_node->input(1);
3590 if (MaybeRemoveControlInput(conv_node->name(), const_node, optimized_graph,
3591 node_map_.get())) {
3592 // Add a control dep from c1 to c2 to ensure c2 is in the right frame
3593 MaybeAddControlInput(conv_const_input, const_node, optimized_graph,
3594 node_map_.get());
3595 }
3596
3597 conv_node->set_name(node->name());
3598 node->set_name(mul_new_name);
3599 if (conv_left_is_constant) {
3600 node_map_->UpdateInput(conv_node->name(), node->input(0), mul_new_name);
3601 conv_node->set_input(0, mul_new_name);
3602 } else {
3603 node_map_->UpdateInput(conv_node->name(), node->input(1), mul_new_name);
3604 conv_node->set_input(1, mul_new_name);
3605 }
3606 NodeDef* conv_const_node =
3607 conv_left_is_constant ? conv_left_child : conv_right_child;
3608 if (left_child_is_constant) {
3609 node->set_input(1, conv_const_node->name());
3610 } else {
3611 node->set_input(0, conv_const_node->name());
3612 }
3613 node_map_->AddNode(mul_new_name, node);
3614
3615 return true;
3616 }
3617
PartialConstPropThroughIdentityN(NodeDef * node)3618 bool ConstantFolding::PartialConstPropThroughIdentityN(NodeDef* node) {
3619 // Partial constant propagation through IdentityN.
3620 if (!(IsIdentityN(*node) || IsIdentityNSingleInput(*node)) ||
3621 !HasRegularInputs(*node))
3622 return false;
3623
3624 std::vector<int> inputs_to_forward;
3625 for (int input_idx = 0; input_idx < node->input_size(); ++input_idx) {
3626 const string& input = node->input(input_idx);
3627 if (IsControlInput(input)) {
3628 return false;
3629 }
3630 const NodeDef* input_node = node_map_->GetNode(NodeName(input));
3631 if (input_node == nullptr) {
3632 LOG(ERROR) << "Bad input: " << input;
3633 return false;
3634 }
3635 // Forward constant inputs to outputs and add a control dependency on
3636 // the IdentityN node.
3637 if (IsReallyConstant(*input_node)) {
3638 inputs_to_forward.push_back(input_idx);
3639 }
3640 }
3641 return ForwardInputs(node, inputs_to_forward);
3642 }
3643
PartialAssocOpConstFolding(GraphDef * optimized_graph,GraphProperties * properties,NodeDef * node)3644 bool ConstantFolding::PartialAssocOpConstFolding(GraphDef* optimized_graph,
3645 GraphProperties* properties,
3646 NodeDef* node) {
3647 // Partial constant folding for associative operators:
3648 // Split AddN/AccumulateNV2 to enable partial
3649 // folding of ops when more than one but not all inputs are constant.
3650 // For AddN and AccumulateNV2, we may furthermore reorder inputs, since
3651 // addition is commutative.
3652 if (!IsAggregate(*node) || !IsCommutative(*node)) return false;
3653
3654 const int num_non_control_inputs = NumNonControlInputs(*node);
3655 if (num_non_control_inputs <= 2) return false;
3656 const int num_control_inputs = node->input_size() - num_non_control_inputs;
3657 std::vector<int> const_inputs;
3658 std::vector<int> nonconst_inputs;
3659 for (int i = 0; i < node->input_size(); ++i) {
3660 const string& input = node->input(i);
3661 const NodeDef* input_node = node_map_->GetNode(NodeName(input));
3662 if (input_node == nullptr) return false;
3663 if (!IsControlInput(input) && IsReallyConstant(*input_node)) {
3664 const_inputs.push_back(i);
3665 } else {
3666 // Non-const and control inputs.
3667 nonconst_inputs.push_back(i);
3668 }
3669 }
3670 // Promote AccumulateNV2 with all constant inputs to AddN, since it is
3671 // a fake node that cannot be constant folded by itself.
3672 int const_inputs_size = const_inputs.size();
3673 if (const_inputs_size == num_non_control_inputs &&
3674 node->op() == "AccumulateNV2") {
3675 node->set_op("AddN");
3676 node->mutable_attr()->erase("shape");
3677 return true;
3678 }
3679 const string new_node_name = OptimizedNodeName(
3680 *node, strings::StrCat("_partial_split_", const_inputs_size));
3681 if (const_inputs_size > 1 && const_inputs_size < num_non_control_inputs &&
3682 !node_map_->NodeExists(new_node_name)) {
3683 NodeDef* added_node = optimized_graph->add_node();
3684 *added_node = *node;
3685 // Always use AddN for the constant node, since AccumulateNV2 is a fake
3686 // node that cannot be constant folded, since it does not have a kernel.
3687 added_node->set_op("AddN");
3688 added_node->mutable_attr()->erase("shape");
3689 added_node->set_name(new_node_name);
3690 node_map_->AddNode(added_node->name(), added_node);
3691 added_node->clear_input();
3692 for (int i : const_inputs) {
3693 added_node->add_input(node->input(i));
3694 node_map_->UpdateOutput(NodeName(node->input(i)), node->name(),
3695 added_node->name());
3696 }
3697
3698 // Overwrite the first const input with the added node.
3699 node->set_input(const_inputs[0], added_node->name());
3700 node_map_->AddOutput(added_node->name(), node->name());
3701 nonconst_inputs.push_back(const_inputs[0]);
3702 // Compact the remaining inputs to the original node.
3703 std::sort(nonconst_inputs.begin(), nonconst_inputs.end());
3704 int idx = 0;
3705 for (int i : nonconst_inputs) {
3706 if (idx != i) {
3707 node->set_input(idx, node->input(i));
3708 }
3709 ++idx;
3710 }
3711 node->mutable_input()->DeleteSubrange(nonconst_inputs.size(),
3712 const_inputs.size() - 1);
3713 (*node->mutable_attr())["N"].set_i(node->input_size() - num_control_inputs);
3714 properties->ClearInputProperties(node->name());
3715 (*added_node->mutable_attr())["N"].set_i(const_inputs.size());
3716 return true;
3717 }
3718 return false;
3719 }
3720
PartialConcatConstFolding(GraphDef * optimized_graph,GraphProperties * properties,NodeDef * node)3721 bool ConstantFolding::PartialConcatConstFolding(GraphDef* optimized_graph,
3722 GraphProperties* properties,
3723 NodeDef* node) {
3724 // Partial constant folding for Concat which is not commutative, so
3725 // we have to preserve order and can only push consecutive runs of constant
3726 // inputs into sub-nodes.
3727 if (!IsConcat(*node) ||
3728 node->name().rfind("_partial_split_") != string::npos) {
3729 return false;
3730 }
3731 const int num_non_control_inputs = NumNonControlInputs(*node);
3732 if (num_non_control_inputs <= 3) return false;
3733 int axis_arg = -1;
3734 int begin = 0;
3735 int end = num_non_control_inputs;
3736 if (node->op() == "Concat") {
3737 begin = 1;
3738 axis_arg = 0;
3739 } else if (node->op() == "ConcatV2") {
3740 end = num_non_control_inputs - 1;
3741 axis_arg = num_non_control_inputs - 1;
3742 } else {
3743 return false;
3744 }
3745
3746 // We search for consecutive runs of constant inputs in the range
3747 // [begin:end[ and push then down into child nodes.
3748 std::vector<std::pair<int, int>> constant_input_runs;
3749 int first = begin;
3750 int last = begin;
3751 while (last < end) {
3752 while (first < end && !IsReallyConstant(*node_map_->GetNode(
3753 NodeName(node->input(first))))) {
3754 ++first;
3755 }
3756 // Invariant: node[first] is constant || first >= end.
3757 last = first + 1;
3758 while (last < end &&
3759 IsReallyConstant(*node_map_->GetNode(NodeName(node->input(last))))) {
3760 ++last;
3761 }
3762 // Invariant: node[last] is not constant || last >= end
3763 // Discard intervals shorter than 2 elements.
3764 if (first < end && (last - first) > 1) {
3765 constant_input_runs.emplace_back(first, last);
3766 }
3767 first = last;
3768 }
3769
3770 // Skip if all inputs are constant, and let constant folding take over.
3771 if (constant_input_runs.empty() || (constant_input_runs.size() == 1 &&
3772 constant_input_runs[0].first == begin &&
3773 constant_input_runs[0].second == end)) {
3774 return false;
3775 }
3776 std::set<int> inputs_to_delete;
3777 for (auto interval : constant_input_runs) {
3778 // Push the constant inputs in the interval to a child node than can be
3779 // constant folded.
3780 string new_node_name = OptimizedNodeName(*node, "_partial_split");
3781 do {
3782 new_node_name += strings::StrCat("_", interval.first);
3783 } while (node_map_->NodeExists(new_node_name));
3784
3785 NodeDef* added_node = optimized_graph->add_node();
3786 *added_node = *node;
3787 added_node->set_op("ConcatV2");
3788 added_node->set_name(new_node_name);
3789 node_map_->AddNode(added_node->name(), added_node);
3790 added_node->clear_input();
3791 for (int i = interval.first; i < interval.second; ++i) {
3792 added_node->add_input(node->input(i));
3793 node_map_->UpdateInput(node->name(), node->input(i), added_node->name());
3794 if (i != interval.first) {
3795 inputs_to_delete.insert(i);
3796 }
3797 }
3798 added_node->add_input(node->input(axis_arg));
3799 (*added_node->mutable_attr())["N"].set_i(interval.second - interval.first);
3800 node_map_->AddOutput(NodeName(node->input(axis_arg)), added_node->name());
3801
3802 // Overwrite the first constant input with the result of the added
3803 // child node.
3804 node->set_input(interval.first, added_node->name());
3805 }
3806 if (!inputs_to_delete.empty()) {
3807 // Fix up the inputs to the original node.
3808 protobuf::RepeatedPtrField<string> tmp;
3809 tmp.Swap(node->mutable_input());
3810 for (int i = 0; i < tmp.size(); ++i) {
3811 if (inputs_to_delete.find(i) == inputs_to_delete.end()) {
3812 node->add_input(tmp.Get(i));
3813 }
3814 }
3815 (*node->mutable_attr())["N"].set_i(node->input_size() - 1);
3816 properties->ClearInputProperties(node->name());
3817 }
3818 return true;
3819 }
3820
GetConcatAxis(const NodeDef & node,int * axis)3821 bool ConstantFolding::GetConcatAxis(const NodeDef& node, int* axis) {
3822 if (node.op() != "ConcatV2") {
3823 return false;
3824 }
3825 int axis_idx = node.input_size() - 1;
3826 while (axis_idx > 0 && IsControlInput(node.input(axis_idx))) {
3827 --axis_idx;
3828 }
3829 if (axis_idx <= 0) {
3830 return false;
3831 }
3832 Tensor axis_tensor;
3833 if (!GetTensorFromConstNode(node.input(axis_idx), &axis_tensor)) {
3834 return false;
3835 }
3836 *axis = axis_tensor.dtype() == DT_INT64
3837 ? static_cast<int>(axis_tensor.scalar<int64_t>()())
3838 : axis_tensor.scalar<int32>()();
3839 return true;
3840 }
3841
MergeConcat(bool use_shape_info,GraphProperties * properties,GraphDef * optimized_graph,NodeDef * node)3842 bool ConstantFolding::MergeConcat(bool use_shape_info,
3843 GraphProperties* properties,
3844 GraphDef* optimized_graph, NodeDef* node) {
3845 // We only optimize for ConcatV2.
3846 int axis;
3847 if (!use_shape_info || !GetConcatAxis(*node, &axis) ||
3848 nodes_to_preserve_.find(node->name()) != nodes_to_preserve_.end() ||
3849 node_map_->GetOutputs(node->name()).size() != 1) {
3850 return false;
3851 }
3852
3853 // If all inputs are constant, don't merge and let folding take case of it.
3854 const int num_regular_inputs = NumNonControlInputs(*node);
3855 bool all_inputs_are_const = true;
3856 for (int i = 0; i < num_regular_inputs - 1; ++i) {
3857 const NodeDef* input_node = node_map_->GetNode(node->input(i));
3858 if (!IsReallyConstant(*input_node)) {
3859 all_inputs_are_const = false;
3860 break;
3861 }
3862 }
3863 if (all_inputs_are_const) return false;
3864
3865 NodeDef* parent = *node_map_->GetOutputs(node->name()).begin();
3866 int parent_axis;
3867 if (!GetConcatAxis(*parent, &parent_axis) || axis != parent_axis) {
3868 return false;
3869 }
3870
3871 // Make a pass over the parent inputs to see if any of them have explicit
3872 // device() fields set, and if different inputs are on different tasks. If
3873 // so, this concat of concats may have been carefully constructed to be a
3874 // two-stage concat, and we don't want to undo that here.
3875 string task, device;
3876 absl::flat_hash_set<string> unique_input_tasks;
3877 const int n_parent_inputs = NumNonControlInputs(*parent);
3878 // Iterate over the real inputs to concatenate [0..n_parent_inputs - 1). The
3879 // input at n_parent_inputs - 1 is the concat axis argument for a ConcatV2
3880 // node, which we don't want to consider here.
3881 for (int i = 0; i < n_parent_inputs - 1; ++i) {
3882 const NodeDef* input_node = node_map_->GetNode(parent->input(i));
3883 if (!input_node->device().empty() &&
3884 tensorflow::DeviceNameUtils::SplitDeviceName(input_node->device(),
3885 &task, &device)) {
3886 unique_input_tasks.insert(task);
3887 if (unique_input_tasks.size() >= 2) {
3888 // More than one input task represented in the device specifications
3889 // of the parent's input nodes. Don't mess with this.
3890 return false;
3891 }
3892 }
3893 }
3894
3895 protobuf::RepeatedPtrField<string> parent_inputs;
3896 parent_inputs.Swap(parent->mutable_input());
3897 // TODO(rmlarsen): IF the child occurs more than once, is it beneficial to
3898 // collapse it into the parent multiple times? Probably not.
3899 for (const auto& input : parent_inputs) {
3900 if (IsSameInput(input, node->name())) {
3901 for (int j = 0; j < num_regular_inputs - 1; ++j) {
3902 // Add tensor inputs to first child concat tensors (except the final
3903 // axis input) to the parent's inputs.
3904 parent->add_input(node->input(j));
3905 node_map_->UpdateInput(parent->name(), node->name(), node->input(j));
3906 }
3907 } else {
3908 parent->add_input(input);
3909 }
3910 }
3911 // Forward Add control inputs
3912 const int num_inputs = node->input_size();
3913 for (int i = num_inputs - 1; i >= num_regular_inputs; --i) {
3914 parent->add_input(node->input(i));
3915 node_map_->UpdateInput(parent->name(), node->name(), node->input(i));
3916 node->mutable_input()->RemoveLast();
3917 }
3918 (*parent->mutable_attr())["N"].set_i(NumNonControlInputs(*parent) - 1);
3919 DedupControlInputs(parent);
3920 ReplaceOperationWithNoOp(node, properties, optimized_graph);
3921
3922 return true;
3923 }
3924
AddQuantizedMatMulMinMaxOutConstNodes(NodeDef * node,GraphDef * optimized_graph)3925 Status ConstantFolding::AddQuantizedMatMulMinMaxOutConstNodes(
3926 NodeDef* node, GraphDef* optimized_graph) {
3927 auto add_quantized_out = [this, node, optimized_graph](
3928 const string& out_const_name, int index) {
3929 NodeDef* out_node = optimized_graph->add_node();
3930 graph_modified_ = true;
3931 Tensor value(DT_FLOAT, TensorShape({}));
3932 const bool is_min = index == 1;
3933 const DataType type_attr = node->attr().at("dtype").type();
3934
3935 value.flat<float>()(0) = is_min ? QuantizedTypeMinAsFloat(type_attr)
3936 : QuantizedTypeMaxAsFloat(type_attr);
3937 TF_RETURN_IF_ERROR(
3938 CreateNodeDef(out_const_name, TensorValue(&value), out_node));
3939 node_map_->AddNode(out_const_name, out_node);
3940 out_node->set_device(node->device());
3941 // Copy all inputs from node.
3942 out_node->mutable_input()->CopyFrom(node->input());
3943 for (const string& input : out_node->input()) {
3944 node_map_->AddOutput(NodeName(input), out_const_name);
3945 }
3946
3947 // Update output nodes consuming node:index to new const node.
3948 string old_input = absl::StrCat(node->name(), ":", index);
3949 int old_node_count = 0;
3950 // We make a copy since the set might change.
3951 auto outputs = node_map_->GetOutputs(node->name());
3952 for (const auto& output : outputs) {
3953 for (int i = 0; i < output->input_size(); ++i) {
3954 if (output->input(i) == old_input) {
3955 output->set_input(i, out_const_name);
3956 node_map_->AddOutput(out_const_name, output->name());
3957 } else if (NodeName(output->input(i)) == node->name()) {
3958 ++old_node_count;
3959 }
3960 }
3961 if (old_node_count == 0) {
3962 node_map_->RemoveOutput(node->name(), output->name());
3963 }
3964 }
3965
3966 return OkStatus();
3967 };
3968 const string min_out_const_name =
3969 OptimizedNodeName(*node, "-quantized_matmul_min_out");
3970 const string max_out_const_name =
3971 OptimizedNodeName(*node, "-quantized_matmul_max_out");
3972 if (node_map_->GetNode(min_out_const_name) == nullptr &&
3973 node_map_->GetNode(max_out_const_name) == nullptr) {
3974 TF_RETURN_IF_ERROR(add_quantized_out(min_out_const_name, 1));
3975 TF_RETURN_IF_ERROR(add_quantized_out(max_out_const_name, 2));
3976 } else {
3977 return errors::Internal(absl::Substitute(
3978 "Can't create Const for QuantizedMatMul min_out/max_out of "
3979 "node '$0' because of node name conflict",
3980 node->name()));
3981 }
3982 return OkStatus();
3983 }
3984
RunOptimizationPass(Cluster * cluster,GrapplerItem * item,GraphProperties * properties,GraphDef * optimized_graph)3985 Status ConstantFolding::RunOptimizationPass(Cluster* cluster,
3986 GrapplerItem* item,
3987 GraphProperties* properties,
3988 GraphDef* optimized_graph) {
3989 optimized_graph->Clear();
3990 graph_ = &item->graph;
3991 node_map_.reset(new NodeMap(graph_));
3992 nodes_allowlist_.clear();
3993 // Fold fetch nodes iff it has a single fanout. Note that if a fetch node
3994 // has a single fanout, it would be rewritten as a constant with the same
3995 // node name, and therefore users are still able to fetch it. This is not
3996 // the case if the node has multiple fanouts, and constant folding would
3997 // replace the node with multiple constants (each for one fanout) with
3998 // new names, and as a result users would not be able to fetch the node any
3999 // more with the original node name.
4000 for (const auto& fetch : item->fetch) {
4001 const NodeDef* fetch_node = node_map_->GetNode(fetch);
4002 if (fetch_node && NumOutputs(*fetch_node, graph_) == 1) {
4003 nodes_allowlist_.insert(fetch_node->name());
4004 }
4005 }
4006
4007 absl::flat_hash_set<string> nodes_to_not_simplify;
4008 if (properties->has_properties()) {
4009 TF_RETURN_IF_ERROR(MaterializeShapes(*properties));
4010 TF_RETURN_IF_ERROR(MaterializeConstants(*properties));
4011 TF_RETURN_IF_ERROR(
4012 FoldGraph(*properties, optimized_graph, &nodes_to_not_simplify));
4013 } else {
4014 *optimized_graph = *graph_;
4015 }
4016 node_map_.reset(new NodeMap(optimized_graph));
4017
4018 TF_RETURN_IF_ERROR(
4019 SimplifyGraph(optimized_graph, properties, &nodes_to_not_simplify));
4020
4021 return OkStatus();
4022 }
4023
Optimize(Cluster * cluster,const GrapplerItem & item,GraphDef * optimized_graph)4024 Status ConstantFolding::Optimize(Cluster* cluster, const GrapplerItem& item,
4025 GraphDef* optimized_graph) {
4026 // TensorFlow flushes denormals to zero and rounds to nearest, so we do
4027 // the same here.
4028 port::ScopedFlushDenormal flush;
4029 port::ScopedSetRound round(FE_TONEAREST);
4030 nodes_to_preserve_ = item.NodesToPreserve();
4031 for (const auto& feed : item.feed) {
4032 feed_nodes_.insert(NodeName(feed.first));
4033 }
4034
4035 if (cpu_device_ == nullptr) {
4036 owned_device_.reset(new DeviceSimple());
4037 cpu_device_ = owned_device_.get();
4038 }
4039
4040 graph_contains_assign_or_inplace_op_ = false;
4041 for (const NodeDef& node : item.graph.node()) {
4042 if (ModifiesInputsInPlace(node) || HasRefInput(node)) {
4043 graph_contains_assign_or_inplace_op_ = true;
4044 break;
4045 }
4046 }
4047
4048 has_fetch_ = !item.fetch.empty();
4049 GrapplerItem item_to_optimize = item;
4050 GraphProperties properties(item_to_optimize);
4051 // It's possible to feed a placeholder with a tensor of any shape: make sure
4052 // that the shape inference deals with this conservatively unless we're in
4053 // aggressive mode.
4054 const bool assume_valid_feeds = opt_level_ == RewriterConfig::AGGRESSIVE;
4055 if (!properties
4056 .InferStatically(assume_valid_feeds,
4057 /*aggressive_shape_inference=*/false,
4058 /*include_input_tensor_values=*/false,
4059 /*include_output_tensor_values=*/true)
4060 .ok()) {
4061 properties.Clear();
4062 }
4063
4064 *optimized_graph = GraphDef();
4065 item_to_optimize.graph.Swap(optimized_graph);
4066 int64_t node_count;
4067
4068 do {
4069 GRAPPLER_RETURN_IF_DEADLINE_EXCEEDED();
4070 graph_modified_ = false;
4071 item_to_optimize.graph.Swap(optimized_graph);
4072 node_count = item_to_optimize.graph.node_size();
4073 TF_RETURN_IF_ERROR(RunOptimizationPass(cluster, &item_to_optimize,
4074 &properties, optimized_graph));
4075 } while (graph_modified_ || optimized_graph->node_size() != node_count);
4076 *optimized_graph->mutable_library() = item.graph.library();
4077 *optimized_graph->mutable_versions() = item.graph.versions();
4078
4079 return OkStatus();
4080 }
4081
4082 } // namespace grappler
4083 } // namespace tensorflow
4084