/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/grappler/costs/graph_properties.h"

#include "absl/types/optional.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/graph/tensor_id.h"
#include "tensorflow/core/grappler/costs/utils.h"
#include "tensorflow/core/grappler/mutable_graph_view.h"
#include "tensorflow/core/grappler/op_types.h"
#include "tensorflow/core/grappler/optimizers/evaluation_utils.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/grappler/utils/functions.h"
#include "tensorflow/core/grappler/utils/topological_sort.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/lib/strings/str_util.h"

namespace tensorflow {
namespace grappler {

namespace {

using shape_inference::DimensionHandle;
using shape_inference::InferenceContext;
using shape_inference::ShapeAndType;
using shape_inference::ShapeHandle;
using TensorVector = gtl::InlinedVector<TensorValue, 4>;

// A large value used for UnknownDim when the dim value comes from a Const
// tensor used as a shape. Some ops treat -1 specially, differently from
// UnknownDim: e.g., the shape input to the Reshape op.
const int64_t kUnknownDimFromConst = INT64_MAX;
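
// For example (illustrative): a -1 in the shape input of Reshape means
// "infer this dimension from the number of elements", so an unknown value
// read from a Const tensor cannot be encoded as -1 there. It is instead
// encoded as kUnknownDimFromConst and converted back to UnknownDim (-1) via
// ReplaceUnknownDimFromConstWithUnknownDim() before regular shape inference.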

// Skip const value instantiation if the number of elements in a const tensor
// is greater than this threshold.
const int kThresholdToSkipConstTensorInstantiation = 128;

template <typename Handle>
struct HashHandle {
  std::size_t operator()(const Handle& h) const { return h.Handle(); }
};
template <typename Handle>
struct CompareHandle {
  bool operator()(const Handle& h1, const Handle& h2) const {
    return h1.SameHandle(h2);
  }
};

template <typename Handle>
struct HandleToObject {};
template <>
struct HandleToObject<ShapeHandle> {
  typedef ShapeHandle Object;

  static ShapeHandle Unknown() { return ShapeHandle(); }
};

template <>
struct HandleToObject<DimensionHandle> {
  typedef int64_t Object;

  static int64_t Unknown() { return -1; }
};

template <typename Handle>
struct Processor {};

template <>
struct Processor<ShapeHandle> {
  // Extract the shape or dim denoted by the handle.
  void ExtractValue(ShapeHandle h, ShapeHandle* result) { *result = h; }
  // Merge the shapes or dims.
  Status Merge(ShapeHandle h1, ShapeHandle h2, ShapeHandle* result) {
    if (InferenceContext::RankKnown(*result)) {
      // The result was initialized in a previous merge to a shape of known
      // rank; make sure we preserve that information.
      return OkStatus();
    }
    if (InferenceContext::RankKnown(h1)) {
      *result = h1;
    } else {
      *result = h2;
    }
    return OkStatus();
  }
};

template <>
struct Processor<DimensionHandle> {
  // Assign a negative id to unknown dimensions, starting at -2 (the -1 id is
  // reserved by TensorFlow).
  void ExtractValue(DimensionHandle d, int64_t* result) {
    if (!InferenceContext::ValueKnown(d)) {
      *result = -counter;
      counter++;
    } else {
      int64_t val = InferenceContext::Value(d);
      if (val >= 0) {
        *result = val;
      } else {
        // A shape inference function generated an invalid dimension handle.
        // Use a symbolic dimension to encode this.
        *result = -counter;
        counter++;
      }
    }
  }

  // Merge the dimensions d1 and d2. Return the known shape if there is one,
  // otherwise look for a symbolic shape. If there is no symbolic shape and no
  // known shape, the shape is fully unknown, so return -1.
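  //
  // Illustrative behavior: merging a known dimension (e.g., 5) with an
  // unknown one yields the known value; merging two symbolic unknowns
  // (values < -1) keeps one of the symbols; only when neither side carries
  // any information does the result stay -1.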
  Status Merge(DimensionHandle d1, DimensionHandle d2, int64_t* result) {
    const int64_t dim1 = InferenceContext::Value(d1);
    const int64_t dim2 = InferenceContext::Value(d2);

    if (dim1 >= 0 && dim2 >= 0) {
      CHECK_EQ(dim1, dim2);
      return RefineDim(dim1, result);
    } else if (dim1 >= 0 && dim2 < 0) {
      return RefineDim(dim1, result);
    } else if (dim1 < 0 && dim2 >= 0) {
      return RefineDim(dim2, result);
    } else if (dim1 < -1) {
      return RefineDim(dim1, result);
    } else if (dim2 < -1) {
      return RefineDim(dim2, result);
    } else {
      CHECK_EQ(dim1, dim2);
      CHECK_EQ(-1, dim1);
      return RefineDim(-1, result);
    }
    return OkStatus();
  }

 private:
  Status RefineDim(int64_t dim, int64_t* result) {
    if (*result >= 0) {
      if (!(*result == dim || dim < 0)) {
        return errors::InvalidArgument("Inconsistent dimensions detected");
      }
    } else if (dim >= 0) {
      *result = dim;
    } else if (dim < *result) {
      *result = dim;
    }
    return OkStatus();
  }

  int64_t counter = 2;
};

// Traditional Disjoint-Set data structure with path compression.
// (https://en.wikipedia.org/wiki/Disjoint-set_data_structure)
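//
// A minimal usage sketch (illustrative, not taken from this file): handles
// that are known to denote the same shape or dimension are merged, and
// GetMergedValue() then returns one canonical value per equivalence class.
//
//   DisjointSet<ShapeHandle> shapes;
//   TF_RETURN_IF_ERROR(shapes.Merge(h1, h2));   // h1 and h2 alias each other.
//   ShapeHandle canonical = shapes.GetMergedValue(h1);  // Same for h2.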
template <typename Handle>
class DisjointSet {
 public:
  DisjointSet() {}
  ~DisjointSet() {
    for (auto rep : nodes_) {
      delete rep.second;
    }
  }

  Status Merge(Handle x, Handle y);
  const typename HandleToObject<Handle>::Object GetMergedValue(Handle value);

 private:
  // All the handles that belong to the same set are part of the same tree, and
  // ultimately represented by the root of that tree.
  struct Rep {
    // Parent in the tree used to encode the set.
    Rep* parent;
    // Rank of the tree, used to decide which root becomes the new parent when
    // two trees are merged.
    int rank;
    // The handle.
    typename HandleToObject<Handle>::Object value;
  };

  // Create a new set for the value if none exists, or return its
  // representative node otherwise.
  Rep* Find(Handle value);

 private:
  Processor<Handle> processor_;
  absl::flat_hash_map<Handle, Rep*, HashHandle<Handle>, CompareHandle<Handle>>
      nodes_;
};

template <typename Handle>
const typename HandleToObject<Handle>::Object
DisjointSet<Handle>::GetMergedValue(Handle value) {
  Rep* rep = Find(value);
  if (!rep) {
    // We don't know anything about this handle.
    return HandleToObject<Handle>::Unknown();
  }
  return rep->value;
}

template <typename Handle>
Status DisjointSet<Handle>::Merge(Handle x, Handle y) {
  Rep* x_root = Find(x);
  Rep* y_root = Find(y);

  // x and y are already in the same set.
  if (x_root == y_root) {
    return OkStatus();
  }
  // x and y are not in the same set, so we merge them.
  // Use the occasion to strengthen what we know about the handle by merging
  // the information about the 2 subsets.
  if (x_root->rank < y_root->rank) {
    TF_RETURN_IF_ERROR(processor_.Merge(y, x, &y_root->value));
    x_root->parent = y_root;
  } else if (x_root->rank > y_root->rank) {
    TF_RETURN_IF_ERROR(processor_.Merge(x, y, &x_root->value));
    y_root->parent = x_root;
  } else {
    TF_RETURN_IF_ERROR(processor_.Merge(x, y, &x_root->value));
    // Arbitrarily make one root the new parent.
    y_root->parent = x_root;
    x_root->rank = x_root->rank + 1;
  }
  return OkStatus();
}

template <typename Handle>
typename DisjointSet<Handle>::Rep* DisjointSet<Handle>::Find(Handle value) {
  auto it = nodes_.find(value);
  if (it == nodes_.end()) {
    // This is the first time we process this handle; create an entry for it.
    Rep* node = new Rep;
    node->parent = node;
    node->rank = 0;
    processor_.ExtractValue(value, &node->value);
    nodes_[value] = node;
    return node;
  }
  // Return the representative for the set, which is the root of the tree.
  // Apply path compression to speed up future queries.
  Rep* node = it->second;
  Rep* root = node->parent;
  while (root != root->parent) {
    root = root->parent;
  }
  while (node->parent != root) {
    Rep* next = node->parent;
    node->parent = root;
    node = next;
  }
  return root;
}

// TODO(dyoon): Move many helper functions in this file (including those within
// SymbolicShapeRefiner class) to shared utils.
bool IsEnqueue(const NodeDef& n) {
  return (n.op().find("Enqueue") != string::npos &&
          n.op().find("EnqueueMany") == string::npos);
}

bool IsDequeue(const NodeDef& n) {
  return (n.op().find("Dequeue") != string::npos &&
          n.op().find("DequeueMany") == string::npos);
}

bool HasAnyUnknownDimensions(const TensorShapeProto& proto) {
  if (proto.unknown_rank()) {
    return true;
  }
  for (const auto& dim : proto.dim()) {
    if (dim.size() < 0) {
      return true;
    }
  }
  return false;
}

// This really should be done in an external debugging tool.
void VerboseLogUnknownDimensionSources(
    const GraphDef& graph,
    const absl::flat_hash_map<string, std::vector<OpInfo::TensorProperties>>&
        input_properties_map,
    const absl::flat_hash_map<string, std::vector<OpInfo::TensorProperties>>&
        output_properties_map) {
  if (!VLOG_IS_ON(2)) {
    return;
  }

  VLOG(2) << "Nodes with known inputs, but with unknown output dimensions:";

  // Find all nodes in the graph for which we do not have any unknown
  // dimensions in their inputs, but we have some unknown dimensions in their
  // outputs.
  std::map<string, int> op_to_count;
  for (const NodeDef& node : graph.node()) {
    const auto& input_properties = input_properties_map.at(node.name());
    const auto& output_properties = output_properties_map.at(node.name());

    bool has_unknown_inputs = false;
    for (const auto& input_prop : input_properties) {
      if (HasAnyUnknownDimensions(input_prop.shape())) {
        has_unknown_inputs = true;
        break;
      }
    }

    if (has_unknown_inputs) {
      continue;
    }

    for (const auto& output_prop : output_properties) {
      if (HasAnyUnknownDimensions(output_prop.shape())) {
        string inputs = "input_shapes=[";
        for (const auto& input_prop : input_properties) {
          inputs += PartialTensorShape::DebugString(input_prop.shape());
        }
        inputs += "]";

        string outputs = "output_shapes=[";
        for (const auto& output_prop : output_properties) {
          outputs += PartialTensorShape::DebugString(output_prop.shape());
        }
        outputs += "]";

        VLOG(2) << "Node: " << node.name() << ", Op: " << node.op() << ", "
                << inputs << ", " << outputs;

        op_to_count[node.op()]++;

        // Don't log again for this node.
        break;
      }
    }
  }
  VLOG(2) << "Op types with known inputs, but with unknown output dimensions "
          << "(format: <op_type> (<count>)):";
  for (const auto& p : op_to_count) {
    VLOG(2) << p.first << " (" << p.second << ")";
  }
}

// Helper function to convert kUnknownDimFromConst into UnknownDim.
std::vector<ShapeHandle> ReplaceUnknownDimFromConstWithUnknownDim(
    InferenceContext* ic, const std::vector<ShapeHandle>& shapes) {
  std::vector<ShapeHandle> converted_shapes(shapes.size());
  for (int i = 0, shapes_size = shapes.size(); i < shapes_size; i++) {
    const auto& shape = shapes[i];
    if (!ic->RankKnown(shape)) {
      converted_shapes[i] = shape;
      continue;
    }
    bool just_copy = true;
    std::vector<DimensionHandle> dims;
    for (int32_t i = 0; i < ic->Rank(shape); ++i) {
      DimensionHandle dim = ic->Dim(shape, i);
      if (ic->ValueKnown(dim) && ic->Value(dim) == kUnknownDimFromConst) {
        just_copy = false;
        dims.push_back(ic->UnknownDim());
      } else {
        dims.push_back(dim);
      }
    }
    if (just_copy) {
      converted_shapes[i] = shape;
      continue;
    }
    converted_shapes[i] = ic->MakeShape(dims);
  }
  return converted_shapes;
}

// Returned tensor's shape is like `shape`, and its values and dtype are from
// `tensor_as_shape` and `dtype`.
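//
// For example (illustrative): with shape = [3], tensor_as_shape = (10,20,30),
// and dtype = DT_INT32, the result is an int32 tensor proto of shape [3]
// holding the values {10, 20, 30}.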
TensorProto MakeTensorProtoFromShape(InferenceContext* ic,
                                     const ShapeHandle& shape,
                                     const ShapeHandle& tensor_as_shape,
                                     const DataType& dtype) {
  TensorProto tensor_proto;
  tensor_proto.set_dtype(dtype);
  auto* shape_proto = tensor_proto.mutable_tensor_shape();
  if (ic->Rank(shape) == 1) {
    shape_proto->add_dim()->set_size(ic->Rank(tensor_as_shape));
  }
  // For a scalar tensor, tensor_shape field will be left empty; no dim.
  for (int i = 0; i < ic->Rank(tensor_as_shape); i++) {
    int64_t value = ic->Value(ic->Dim(tensor_as_shape, i));
    if (dtype == DT_INT32) {
      tensor_proto.add_int_val(value);
    } else {
      tensor_proto.add_int64_val(value);
    }
  }
  return tensor_proto;
}

// Returns a Const NodeDef with tensor `tensor_proto` and dtype = `dtype`.
NodeDef MakeConstNodeDefFromTensorProto(InferenceContext* ic,
                                        const TensorProto& tensor_proto,
                                        const DataType& dtype) {
  NodeDef const_node;
  const_node.set_name("const_from_shape");
  const_node.set_op("Const");
  auto* attr = const_node.mutable_attr();
  (*attr)["dtype"].set_type(dtype);
  auto* tensor = (*attr)["value"].mutable_tensor();
  *tensor = tensor_proto;
  return const_node;
}

// Returns a Const NodeDef with shape = `shape`, values = `tensor_as_shape`,
// and dtype = `dtype`.
NodeDef MakeConstNodeDefFromShape(InferenceContext* ic,
                                  const ShapeHandle& shape,
                                  const ShapeHandle& tensor_as_shape,
                                  const DataType& dtype) {
  return MakeConstNodeDefFromTensorProto(
      ic, MakeTensorProtoFromShape(ic, shape, tensor_as_shape, dtype), dtype);
}

bool IsNumericType(const DataType dtype) {
  static const gtl::FlatSet<DataType>* const kRealNumberTypes =
      CHECK_NOTNULL((new gtl::FlatSet<DataType>{
          // Floating point.
          DT_BFLOAT16,
          DT_HALF,
          DT_FLOAT,
          DT_DOUBLE,
          // Int / UInt.
          DT_INT8,
          DT_INT16,
          DT_INT32,
          DT_INT64,
          DT_UINT8,
          DT_UINT16,
          DT_UINT32,
          DT_UINT64,
          // Quantized Int.
          DT_QINT8,
          DT_QUINT8,
          DT_QINT16,
          DT_QUINT16,
          DT_QINT32,
          // Bool.
          DT_BOOL,
      }));
  return kRealNumberTypes->find(dtype) != kRealNumberTypes->end();
}

// Returns the number of elements in the input (const) tensor.
// Returns -1 if the tensor has no shape or unknown rank.
uint64 NumElementsFromTensorProto(const TensorProto& tensor_proto) {
  if (!tensor_proto.has_tensor_shape()) {
    return -1;
  }
  const auto& tensor_shape_proto = tensor_proto.tensor_shape();
  if (tensor_shape_proto.unknown_rank()) {
    return -1;
  }
  int64_t num_elements = 1;
  for (const auto& dim : tensor_shape_proto.dim()) {
    // Note that in some cases, dim.size() can be zero (e.g., empty vector).
    num_elements *= dim.size();
  }
  return num_elements;
}

}  // namespace

// Note that the tensor_as_shape input should not include kUnknownDimFromConst.
// This function checks for kUnknownDimFromConst, but only logs a WARNING.
// When checking input_tensors_as_shapes_to_propagate or
// output_tensors_as_shapes, which may include kUnknownDimFromConst, convert
// them first using ReplaceUnknownDimFromConstWithUnknownDim().
bool IsShapeFullyDefinedIntegerVectorOrScalar(
    InferenceContext* ic, const ShapeHandle& shape,
    const ShapeHandle& tensor_as_shape, const DataType& dtype) {
  if (!ic->FullyDefined(shape) || ic->Rank(shape) > 1 ||
      !ic->FullyDefined(tensor_as_shape) ||
      (dtype != DT_INT32 && dtype != DT_INT64)) {
    return false;
  }
  // Also check whether any dim in tensor_as_shape is kUnknownDimFromConst.
  for (int32_t i = 0; i < ic->Rank(tensor_as_shape); ++i) {
    DimensionHandle dim = ic->Dim(tensor_as_shape, i);
    if (ic->Value(dim) == kUnknownDimFromConst) {
      LOG(WARNING) << "IsShapeFullyDefinedIntegerVectorOrScalar(): "
                   << "tensor_as_shape input includes kUnknownDimFromConst -- "
                   << ic->DebugString(tensor_as_shape);
      return false;
    }
  }
  return true;
}

// Queue of nodes to process. Nodes can be enqueued in any order, but will be
// dequeued in (roughly) topological order. Propagating shapes following a
// topological ordering isn't required for correctness, but it speeds things up
// since it avoids processing the same node multiple times as the information
// about its inputs is refined.
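//
// A minimal usage sketch (illustrative, not taken from this file):
//
//   std::vector<const NodeDef*> topo_order;
//   TF_RETURN_IF_ERROR(ComputeTopologicalOrder(item.graph, &topo_order));
//   TopoQueue queue(topo_order);
//   queue.push(some_node);                // Enqueue in any order.
//   while (!queue.empty()) {
//     const NodeDef* node = queue.pop();  // Dequeued in topological order.
//   }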
class TopoQueue {
 public:
  explicit TopoQueue(const std::vector<const NodeDef*>& topo_order)
      : topo_order_(TopoOrder(topo_order)) {}

  void push(const NodeDef* n) { queue_.emplace(n, topo_order_.at(n)); }

  const NodeDef* pop() {
    CHECK(!empty());
    auto it = queue_.begin();
    const NodeDef* n = it->first;
    queue_.erase(it);
    return n;
  }

  bool empty() const { return queue_.empty(); }
  std::size_t size() const { return queue_.size(); }

 private:
  using NodeAndId = std::pair<const NodeDef*, int>;
  // Graph nodes are created in (roughly) topological order. Therefore we can
  // use their id to ensure they're sorted topologically.
  struct OrderByIdAscending {
    bool operator()(const NodeAndId& lhs, const NodeAndId& rhs) const {
      return lhs.second < rhs.second;
    }
  };

  const absl::flat_hash_map<const NodeDef*, int> TopoOrder(
      const std::vector<const NodeDef*>& topo_order) const {
    absl::flat_hash_map<const NodeDef*, int> map;
    map.reserve(topo_order.size());
    for (int i = 0, topo_order_size = topo_order.size(); i < topo_order_size;
         ++i) {
      map.emplace(topo_order[i], i);
    }
    return map;
  }

  const absl::flat_hash_map<const NodeDef*, int> topo_order_;
  std::set<NodeAndId, OrderByIdAscending> queue_;
};


bool IsAllowListedOpTypeForEvaluateNode(const string& op_type) {
  static const gtl::FlatSet<string>* const kOpTpeAllowlist =
      CHECK_NOTNULL((new gtl::FlatSet<string>{
          // Unary arithmetic ops
          "Floor",
          "Round",
          "Sqrt",
          "Square",
          "Sign",
          // Binary arithmetic ops
          "Add",
          "AddV2",
          "Div",
          "FloorDiv",
          "FloorMod",
          "Greater",
          "GreaterEqual",
          "Less",
          "LessEqual",
          "LogicalAnd",
          "LogicalNot",
          "LogicalOr",
          "Maximum",
          "Minimum",
          "Mod",
          "Mul",
          "NotEqual",
          "QuantizedAdd",
          "QuantizedMul",
          "SquareDifference",
          "Sub",
          "TruncateDiv",
          "TruncateMod",
          "RealDiv",
          // N-ary arithmetic ops
          "AddN",
          // Others
          "StridedSlice",
          "OnesLike",
          "ZerosLike",
          "Concat",
          "ConcatV2",
          "Split",
          "Range",
          "Fill",
          "Cast",
          "Prod",
          "Unpack",
          "GatherV2",
          "Pack",
          // Used in batch_gather_nd: tensorflow/python/ops/array_ops.py
          "ExpandDims",
      }));
  return kOpTpeAllowlist->find(op_type) != kOpTpeAllowlist->end();
}

// Negative shape size of '-1' represents unknown, while negative shape sizes
// less than -1 represent unknown symbolic shapes (e.g. the shape of [-5, 5, -1,
// -5] really means [x, 5, ?, x]). Before we can output the tensors as shapes,
// we need to normalize them: mark all values <-1 as "unknown" (-1).
static void NormalizeShapeForOutput(TensorShapeProto* shape) {
  for (int i = 0; i < shape->dim_size(); i++) {
    if (shape->dim(i).size() < -1) {
      VLOG(2) << "Normalizing dimension: " << i << " from "
              << shape->dim(i).size() << " to -1";
      shape->mutable_dim(i)->set_size(-1);
    }
  }
}

// Processes symbolic shapes.
// Each symbolic shape or dimension is represented by a handle. Unlike the TF
// shape refiner, which creates new handles every time it processes an unknown
// shape/dimension, the symbolic shape refiner assigns a specific handle to each
// unknown shape/dimension of a given node.
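//
// For example (illustrative): if node A's output shape is inferred as [?, 32]
// and that output feeds two consumers, both consumers see the same handle for
// the unknown dimension, so a later pass can conclude that the two consumer
// inputs share their first dimension even though its value is unknown.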
class SymbolicShapeRefiner {
 public:
  explicit SymbolicShapeRefiner(
      const GraphView& graph,
      const absl::flat_hash_map<string, absl::flat_hash_set<int>>& fed_ports,
      const bool aggressive_shape_inference)
      : graph_(graph),
        function_library_(OpRegistry::Global(), graph.graph()->library()),
        fed_ports_(fed_ports),
        aggressive_shape_inference_(aggressive_shape_inference) {
    graph_def_version_ = graph.graph()->versions().producer();
    node_to_context_.reserve(graph.graph()->node_size());
  }

  const GraphView& graph() const { return graph_; }

  struct NodeContext {
    const OpRegistrationData* op_data;
    DataTypeVector input_types;
    DataTypeVector output_types;
    std::unique_ptr<InferenceContext> inference_context;
    // Additional info for propagating tensor values and tensor shapes.
    std::vector<const TensorProto*> input_tensor_protos;
    std::vector<const TensorProto*> output_tensor_protos;
    // This is the same as inference_context->input_tensors_as_shapes, except
    // that some UnknownDims (-1) can be kUnknownDimFromConst.
    std::vector<ShapeHandle> input_tensors_as_shapes_to_propagate;
    std::vector<ShapeHandle> output_tensors_as_shapes;

    // Output shapes incompatible between annotation and shape inference.
    bool shape_incompatible = false;

    // Similar to DebugString() in InferenceContext, but prints out
    // kUnknownDimFromConst properly.
    std::string StringifyShapeHandle(ShapeHandle s) {
      auto* ic = inference_context.get();
      if (ic->RankKnown(s)) {
        std::vector<std::string> vals;
        for (int i = 0; i < ic->Rank(s); i++) {
          DimensionHandle d = ic->Dim(s, i);
          if (ic->ValueKnown(d) && ic->Value(d) == kUnknownDimFromConst) {
            vals.push_back("?(Const)");
          } else {
            vals.push_back(ic->DebugString(d));
          }
        }
        return strings::StrCat("[", absl::StrJoin(vals, ","), "]");
      } else {
        return "?";
      }
    }

    std::string DebugString(const NodeDef& node) {
      std::string output;
      auto* ic = inference_context.get();
      absl::StrAppend(
          &output, node.name(), " [", node.op(), "] has ", ic->num_inputs(),
          (ic->num_inputs() > 1 ? " inputs and " : " input and "),
          ic->num_outputs(), (ic->num_outputs() > 1 ? " outputs" : " output"));
      if (op_data->is_function_op) {
        absl::StrAppend(&output, " (function op)");
      }
      absl::StrAppend(&output, ": \n");

      for (int i = 0; i < ic->num_inputs(); i++) {
        absl::StrAppend(&output, " input [", i, "] ", node.input(i),
                        " -- type: ", DataTypeString(input_types.at(i)),
                        ", shape: ", ic->DebugString(ic->input(i)),
                        ", tensor: ");
        Tensor t1;
        int input_tensor_protos_size = input_tensor_protos.size();
        if (input_tensor_protos_size > i &&
            input_tensor_protos.at(i) != nullptr &&
            t1.FromProto(*input_tensor_protos.at(i))) {
          absl::StrAppend(&output, t1.DebugString(), ", tensor_as_shape: ");
        } else {
          absl::StrAppend(&output, " null, tensor_as_shape: ");
        }
        int input_tensors_as_shapes_to_propagate_size =
            input_tensors_as_shapes_to_propagate.size();
        if (input_tensors_as_shapes_to_propagate_size > i) {
          absl::StrAppend(
              &output,
              StringifyShapeHandle(input_tensors_as_shapes_to_propagate.at(i)),
              "\n");
        } else {
          absl::StrAppend(&output, " null\n");
        }
      }
      for (int i = 0; i < ic->num_outputs(); i++) {
        absl::StrAppend(&output, " output [", i,
                        "] -- type: ", DataTypeString(output_types.at(i)),
                        ", shape: ", ic->DebugString(ic->output(i)),
                        ", tensor: ");
        Tensor t2;
        int output_tensor_protos_size = output_tensor_protos.size();
        if (output_tensor_protos_size > i &&
            output_tensor_protos.at(i) != nullptr &&
            t2.FromProto(*output_tensor_protos.at(i))) {
          absl::StrAppend(&output, t2.DebugString(), ", tensor_as_shape: ");
        } else {
          absl::StrAppend(&output, " null, tensor_as_shape: ");
        }
        int output_tensors_as_shapes_size = output_tensors_as_shapes.size();
        if (output_tensors_as_shapes_size > i) {
          absl::StrAppend(&output,
                          StringifyShapeHandle(output_tensors_as_shapes.at(i)),
                          "\n");
        } else {
          absl::StrAppend(&output, " null\n");
        }
      }
      return output;
    }
  };

  NodeContext* GetNodeContext(const NodeDef* node) {
    auto it = node_to_context_.find(node);
    if (it == node_to_context_.end()) {
      return nullptr;
    }
    return &it->second;
  }

  InferenceContext* GetContext(const NodeDef* node) {
    auto it = node_to_context_.find(node);
    if (it == node_to_context_.end()) {
      return nullptr;
    }
    return it->second.inference_context.get();
  }

  // Forward the shapes from the function input nodes (PartitionedCall or
  // StatefulPartitionedCall nodes) to the argument nodes (which are
  // Placeholder nodes), then perform shape inference on the function body.
  //
  // Propagate shape information of the final function body node
  // to the function node `function_node`.
  //
  // In the event of an error, UpdateNode will simply set `function_node`'s
  // output shape to be Unknown.
  Status UpdateFunction(const NodeDef* function_node) {
    NameAttrList function;
    TF_RETURN_IF_ERROR(NameAndAttrsFromFunctionCall(*function_node, &function));
    auto it = fun_to_grappler_function_item_.find(function.name());
    if (it == fun_to_grappler_function_item_.end()) {
      return errors::InvalidArgument(
          function.name(),
          " was not previously added to SymbolicShapeRefiner.");
    }

    const absl::optional<GrapplerFunctionItem>& maybe_grappler_function_item =
        it->second;
    if (!maybe_grappler_function_item.has_value()) {
      VLOG(3) << "Skip failed to instantiate function call: function_name="
              << function.name();

      auto* ctx = GetNodeContext(function_node);
      auto* ic = ctx->inference_context.get();
      for (int i = 0; i < ic->num_outputs(); ++i) {
        TF_RETURN_IF_ERROR(SetUnknownShape(function_node, i));
      }

      return OkStatus();
    }

    // Copy (not reference) so that changes we make here (e.g., replacing
    // _Arg with Const and _Retval with Identity) don't affect the one in
    // fun_to_grappler_function_item_.
    GrapplerFunctionItem grappler_function_item = *maybe_grappler_function_item;
    MutableGraphView gv(&grappler_function_item.graph);

    // Forward shapes from function input nodes to argument nodes.
    for (int i = 0, end = grappler_function_item.inputs().size(); i < end;
         ++i) {
      auto& fun_input = grappler_function_item.input(i);
      NodeDef* fun_node = gv.GetNode(fun_input.node_name);
      const TensorId input_tensor = ParseTensorName(function_node->input(i));

      if (IsControlInput(input_tensor)) {
        return errors::FailedPrecondition(
            "Function inputs should not contain control nodes.");
      }

      const NodeDef* input_node = graph_.GetNode(input_tensor.node());
      if (input_node == nullptr) {
        return errors::FailedPrecondition(input_tensor.node(),
                                          " was not found in the graph.");
      }

      InferenceContext* input_ic = GetContext(input_node);
      if (input_ic == nullptr) {
        return errors::FailedPrecondition(
            "Inference context has not been created for ", input_tensor.node());
      }

      int output_port_num = input_tensor.index();
      AttrValue attr_output_shape;
      TensorShapeProto proto;
      const auto handle = input_ic->output(output_port_num);
      input_ic->ShapeHandleToProto(handle, &proto);
      // There may be dim.size < -1 in SymbolicShapeRefiner. Change those to -1.
      NormalizeShapeForOutput(&proto);
      // The _Arg op's output shape uses the _output_shapes attr.
      AttrValue output_attr;
      output_attr.mutable_list()->add_shape()->Swap(&proto);
      (*fun_node->mutable_attr())["_output_shapes"] = output_attr;

      // If dtype is DT_RESOURCE, ops that read the _Arg op use the
      // _handle_dtypes and _handle_shapes attrs for their shapes and dtypes.
      if (fun_input.data_type == DT_RESOURCE) {
        auto* shapes_and_types =
            input_ic->output_handle_shapes_and_types(output_port_num);
        if (shapes_and_types != nullptr && !shapes_and_types->empty()) {
          AttrValue dtype_attr;
          AttrValue shape_attr;
          for (const auto& shape_and_type : *shapes_and_types) {
            const auto& dtype = shape_and_type.dtype;
            const auto& shape_handle = shape_and_type.shape;
            dtype_attr.mutable_list()->add_type(dtype);
            input_ic->ShapeHandleToProto(
                shape_handle, shape_attr.mutable_list()->add_shape());
          }
          (*fun_node->mutable_attr())["_handle_dtypes"] = dtype_attr;
          (*fun_node->mutable_attr())["_handle_shapes"] = shape_attr;
        } else {
          // Note that we do not return an error here, even if the input node
          // does not have shapes_and_types. Within the function, we cannot
          // infer the output shape of the DT_RESOURCE input; hence, potentially
          // unknown shapes/dims in the function output shapes.
          VLOG(2)
              << "A function node (" << function_node->name()
              << ") has input with DT_RESOURCE, but the input node does not "
              << "have shapes_and_types information: \n"
              << "function_node: " << function_node->ShortDebugString() << "\n"
              << "function input: " << i
              << ", input node's output: " << output_port_num << "\n"
              << "input node: " << input_node->ShortDebugString();
        }
      }
    }

    // ReplaceInputWithConst() may break GraphView's internal node mapping
    // structure; hence, we separately build a node name to NodeDef* map for
    // the output nodes (before GraphView becomes invalid). Note that we use
    // string, not string_view.
    absl::flat_hash_map<std::string, NodeDef*> output_nodes;
    for (const auto& output_arg : grappler_function_item.outputs()) {
      output_nodes[output_arg.node_name] = gv.GetNode(output_arg.node_name);
    }

    // Replace input nodes with Consts, if values are known. Note that
    // we don't check for errors here, as that is done in the loop above.
    auto* ctx = GetNodeContext(function_node);
    auto* ic = ctx->inference_context.get();
    for (int i = grappler_function_item.inputs().size() - 1; i >= 0; --i) {
      const string& input = function_node->input(i);
      const string node_name = NodeName(input);
      const NodeDef* input_node = graph_.GetNode(node_name);
      if (IsConstant(*input_node)) {
        TF_CHECK_OK(
            ReplaceInputWithConst(*input_node, i, &grappler_function_item));
      } else if (static_cast<int>(ctx->input_tensor_protos.size()) > i &&
                 ctx->input_tensor_protos[i] != nullptr) {
        NodeDef const_input_node = MakeConstNodeDefFromTensorProto(
            ic, *ctx->input_tensor_protos[i], ctx->input_types[i]);
        TF_CHECK_OK(ReplaceInputWithConst(const_input_node, i,
                                          &grappler_function_item));
      } else if (static_cast<int>(ic->input_tensors_as_shapes().size()) > i &&
                 IsShapeFullyDefinedIntegerVectorOrScalar(
                     ic, ic->input(i), ic->input_tensors_as_shapes()[i],
                     ctx->input_types[i])) {
        // We have fully defined input_tensors_as_shapes for this input; use it
        // as a const input to the function node.
        NodeDef const_input_node = MakeConstNodeDefFromShape(
            ic, ic->input(i), ic->input_tensors_as_shapes()[i],
            ctx->input_types[i]);
        TF_CHECK_OK(ReplaceInputWithConst(const_input_node, i,
                                          &grappler_function_item));
      }
    }
    // The node_name to NodeDef* map in GraphView gv can be broken due to
    // ReplaceInputWithConst(). gv should not be used after this.

    // Replace output _Retval nodes with Identity nodes. _Retval is a system op
    // without outputs and without a registered shape function.
    for (const auto& output_arg : grappler_function_item.outputs()) {
      NodeDef* output_node = output_nodes[output_arg.node_name];
      DCHECK_EQ(output_node->op(), "_Retval");
      output_node->set_op("Identity");
      output_node->mutable_attr()->erase("index");
    }

    // Perform inference on function body.
    GraphProperties gp(grappler_function_item);
    TF_RETURN_IF_ERROR(gp.InferStatically(
        /*assume_valid_feeds=*/true,
        /*aggressive_shape_inference=*/aggressive_shape_inference_,
        /*include_tensor_values=*/true));

    // Add return nodes for output shapes.
    int output = 0;
    ctx->output_tensors_as_shapes.resize(grappler_function_item.output_size());
    ctx->output_tensor_protos.resize(grappler_function_item.output_size(),
                                     nullptr);
    for (auto const& out_arg : grappler_function_item.outputs()) {
      // It is guaranteed that output_tensors does not contain any control
      // inputs, so port_id >= 0.
      TensorId out_tensor = ParseTensorName(out_arg.node_name);

      if (output_nodes.count(out_tensor.node()) <= 0) {
        return errors::FailedPrecondition(
            "Unable to find return function_node ", out_tensor.node(), " for ",
            function_node->name());
      }
      const NodeDef* retnode = output_nodes[out_tensor.node()];

      auto output_properties = gp.GetOutputProperties(retnode->name());
      int output_properties_size = output_properties.size();
      if (out_tensor.index() >= output_properties_size) {
        return errors::InvalidArgument(
            out_tensor.ToString(), " has invalid position ", out_tensor.index(),
            " (output_properties.size() = ", output_properties.size(), ").");
      }
      auto& outprop = output_properties[out_tensor.index()];
      TensorShapeProto shape = outprop.shape();
      NormalizeShapeForOutput(&shape);
      ShapeHandle out;
      TF_RETURN_IF_ERROR(ic->MakeShapeFromShapeProto(shape, &out));
      ic->set_output(output, out);
      if (outprop.has_value()) {
        // Forward tensor value to output_tensors_as_shape.
        MaybeTensorProtoToShape(ic, outprop.value(),
                                &ctx->output_tensors_as_shapes[output]);
        const_tensors_to_propagate_.push_back(outprop.value());
        ctx->output_tensor_protos[output] = &const_tensors_to_propagate_.back();
      }
      output++;
    }

    return OkStatus();
  }

  // Prepares input shapes/values/handles, then runs shape inference, and
  // finally sets output shapes/values/handles.
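  //
  // At a high level: (1) pull shapes, tensor protos, and tensors-as-shapes
  // from the contexts of this node's fan-in nodes, (2) return early if no
  // input information changed, (3) delegate function ops to UpdateFunction()
  // (or to annotated output shapes in aggressive mode), and (4) otherwise
  // materialize small constant inputs and run the op's shape function via
  // InferShapes().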
  Status UpdateNode(const NodeDef* node, bool* refined) {
    NodeContext* ctx = GetNodeContext(node);
    if (ctx == nullptr) {
      TF_RETURN_IF_ERROR(AddNode(node));
      ctx = CHECK_NOTNULL(GetNodeContext(node));
      *refined = true;
    }

    // Check if the shapes of the nodes in the fan-in of this node have
    // changed, and if they have, update the node input shapes.
    InferenceContext* ic = ctx->inference_context.get();
    ctx->input_tensors_as_shapes_to_propagate.resize(ic->num_inputs());
    ctx->input_tensor_protos.resize(ic->num_inputs(), nullptr);

    for (int dst_input = 0; dst_input < ic->num_inputs(); ++dst_input) {
      const GraphView::InputPort port(node, dst_input);
      const GraphView::OutputPort fanin = graph_.GetRegularFanin(port);
      int src_output = fanin.port_id;
      const NodeDef* src = fanin.node;
      NodeContext* src_ctx = GetNodeContext(src);
      if (src_ctx == nullptr) {
        return errors::FailedPrecondition(
            "Input ", dst_input, " for '", node->name(),
            "' was not previously added to SymbolicShapeRefiner.");
      }

      InferenceContext* src_ic = src_ctx->inference_context.get();
      if (src_output >= src_ic->num_outputs()) {
        return errors::OutOfRange("src_output = ", src_output,
                                  ", but num_outputs is only ",
                                  src_ic->num_outputs());
      }

      // Propagate the input node's NodeContext info to the current node's
      // NodeContext:
      // output_tensor_protos to input_tensor_protos and input_tensors, and
      // output_tensors_as_shapes to input_tensors_as_shapes.
      if (static_cast<int>(src_ctx->output_tensors_as_shapes.size()) >
          src_output) {
        ctx->input_tensors_as_shapes_to_propagate[dst_input] =
            src_ctx->output_tensors_as_shapes[src_output];
      }

      if (static_cast<int>(src_ctx->output_tensor_protos.size()) > src_output) {
        const auto* tensor_proto = src_ctx->output_tensor_protos[src_output];
        if (tensor_proto != nullptr) {
          ctx->input_tensor_protos[dst_input] = tensor_proto;
        }
      }

      // NOTE: we only check whether the shape is refined; we do not (yet)
      // check whether the tensor value is refined.
      if (!*refined &&
          !ic->input(dst_input).SameHandle(src_ic->output(src_output))) {
        *refined = true;
      }
      ic->SetInput(dst_input, src_ic->output(src_output));

      if (!*refined && ic->requested_input_tensor_as_partial_shape(dst_input)) {
        // The input value may have changed. Since we have no way to know if
        // that's indeed the case, err on the safe side.
        *refined = true;
      }

      // Also propagate handle shape and dtype of edges which are carrying
      // resource handles.
      if (ctx->input_types[dst_input] == DT_RESOURCE) {
        auto* outputs = src_ic->output_handle_shapes_and_types(src_output);
        if (!outputs) continue;
        auto* inputs = ic->input_handle_shapes_and_types(dst_input);

        if (!inputs || !EquivalentShapesAndTypes(*outputs, *inputs))
          *refined = true;
        ic->set_input_handle_shapes_and_types(dst_input, *outputs);
      }
    }

    // Make sure we schedule the fanout of resources (which have no input)
    // whenever the resources are updated.
    *refined |= ic->num_inputs() == 0;

    if (!*refined) {
      // No input shape has changed, we're done.
      return OkStatus();
    }

    // Convert all kUnknownDimFromConst to -1 for shape inference.
    ic->set_input_tensors_as_shapes(ReplaceUnknownDimFromConstWithUnknownDim(
        ic, ctx->input_tensors_as_shapes_to_propagate));
    // Note: UpdateFunction uses input_tensors_as_shapes and
    // input_tensor_protos (not the Tensor object) for input values, so for
    // function nodes we don't need to convert TensorProtos to Tensors here.
    // If the current op is not a function op, we convert TensorProtos to
    // Tensors before calling InferShapes.

    // Properly handle function nodes.
    if (ctx->op_data && ctx->op_data->is_function_op) {
      // TODO(jmdecker): Detect if the input shapes have changed for this
      // function. Note that when we hit a function call node, refined will be
      // true, as the updates to the call node will have changed, even if it's
      // the same function being called twice with the same input shapes.
      // Example: simple_function.pbtxt
      if (aggressive_shape_inference_) {
        // If output shapes are annotated, use them and skip UpdateFunction();
        // it can be very expensive when a function node has nested function
        // nodes internally. One downside with this approach is that we do not
        // get output values or output shapes as tensors from the function
        // node.
        auto s = UpdateOutputShapesUsingAnnotatedInformation(*node, ctx);
        if (s.ok() && AllOutputShapesKnown(ctx)) {
          return OkStatus();
        }
        // If shape annotation was not available, incomplete, or incompatible,
        // fall through to call UpdateFunction().
      }
      auto s = UpdateFunction(node);
      if (s.ok()) {
        return OkStatus();
      } else {
        VLOG(1) << "UpdateFunction failed for " << node->op()
                << ". Defaulting to ShapeUnknown.\n"
                << s.ToString();
      }
    }

    // Construct Tensors for constant inputs used by shape functions.
    std::vector<Tensor> const_values(ic->num_inputs());
    std::vector<const Tensor*> input_tensors(ic->num_inputs(), nullptr);
    for (int dst_input = 0; dst_input < ic->num_inputs(); ++dst_input) {
      const TensorProto* tensor_proto = ctx->input_tensor_protos[dst_input];
      if (tensor_proto != nullptr &&
          // Skip if the const tensor is too large.
          NumElementsFromTensorProto(*tensor_proto) <=
              kThresholdToSkipConstTensorInstantiation &&
          const_values[dst_input].FromProto(*tensor_proto)) {
        input_tensors[dst_input] = &const_values[dst_input];
      }
    }
    ic->set_input_tensors(input_tensors);

    // Update the shapes of the outputs.
    return InferShapes(*node, ctx);
  }

  Status SetUnknownShape(const NodeDef* node, int output_port) {
    shape_inference::ShapeHandle shape =
        GetUnknownOutputShape(node, output_port);
    InferenceContext* ctx = GetContext(node);
    if (ctx == nullptr) {
      return errors::InvalidArgument("SetUnknownShape: Missing context");
    }
    if (output_port < 0 || output_port >= ctx->num_outputs()) {
      return errors::InvalidArgument(
          "SetUnknownShape: output_port must be in [0, ", ctx->num_outputs(),
          ") but was ", output_port);
    }
    ctx->set_output(output_port, shape);
    return OkStatus();
  }

  struct ShapeId {
    const NodeDef* node;
    int port_id;
    bool operator==(const ShapeId& other) const {
      return node == other.node && port_id == other.port_id;
    }
  };
  struct HashShapeId {
    std::size_t operator()(const ShapeId& shp) const {
      return std::hash<const NodeDef*>{}(shp.node) + shp.port_id;
    }
  };

  struct DimId {
    const NodeDef* node;
    int port_id;
    int dim_index;
    bool operator==(const DimId& other) const {
      return node == other.node && port_id == other.port_id &&
             dim_index == other.dim_index;
    }
  };

  struct HashDimId {
    std::size_t operator()(const DimId& dim) const {
      return std::hash<const NodeDef*>{}(dim.node) + dim.port_id +
             dim.dim_index;
    }
  };

  // Returns the shape of the output at 'port_index' as the union of shape1
  // and shape2.
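  //
  // Illustrative examples: the union of [2, 3] and [2, 3] is [2, 3]; the
  // union of [2, 3] and [2, 4] is [2, ?]; the union of [2, 3] and [2, 3, 5]
  // (rank mismatch) is a fully unknown shape.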
  ShapeHandle OutputAsUnion(const NodeDef* node, int port_index,
                            ShapeHandle shape1, ShapeHandle shape2) {
    if (shape1.SameHandle(shape2)) {
      return shape1;
    }
    InferenceContext* ctx = GetContext(node);
    ShapeHandle relaxed = shape1;
    const int rank = ctx->Rank(shape1);
    if (!ctx->RankKnown(shape2) || ctx->Rank(shape2) != rank) {
      relaxed = GetUnknownOutputShape(node, port_index);
    } else {
      for (int d = 0; d < rank; ++d) {
        if (!ctx->Dim(shape1, d).SameHandle(ctx->Dim(shape2, d))) {
          int64_t val1 = ctx->Value(ctx->Dim(shape1, d));
          int64_t val2 = ctx->Value(ctx->Dim(shape2, d));
          if (val1 != val2 || (val1 < 0 && val2 < 0)) {
            DimensionHandle new_dim = GetUnknownOutputDim(node, port_index, d);
            TF_CHECK_OK(ctx->ReplaceDim(relaxed, d, new_dim, &relaxed));
          }
        }
      }
    }
    return relaxed;
  }

  bool EquivalentShapes(ShapeHandle s1, ShapeHandle s2) const {
    if (s1.SameHandle(s2)) {
      return true;
    }
    if (InferenceContext::Rank(s1) != InferenceContext::Rank(s2)) {
      return false;
    }
    if (!InferenceContext::RankKnown(s1) && !InferenceContext::RankKnown(s2)) {
      return true;
    }
    const int rank = InferenceContext::Rank(s1);
    for (int i = 0; i < rank; ++i) {
      if (!InferenceContext::DimKnownRank(s1, i).SameHandle(
              InferenceContext::DimKnownRank(s2, i))) {
        int64_t val1 =
            InferenceContext::Value(InferenceContext::DimKnownRank(s1, i));
        int64_t val2 =
            InferenceContext::Value(InferenceContext::DimKnownRank(s2, i));
        if (val1 >= 0 && val2 >= 0 && val1 == val2) {
          continue;
        }
        return false;
      }
    }
    return true;
  }

  // Return true if the annotated shape is compatible with shape inference
  // result. Examples:
  // Inferred shape: ?, annotated shape: [10, 10] -> true;
  // Inferred shape: [-1, 10], annotated shape: [10, 10] -> true;
  // Inferred shape: [-1, 100], annotated shape: [10, 10] -> false;
  // Inferred shape: [-1, 10, 10], annotated shape: [10, 10] -> false.
  bool CompatibleShapes(ShapeHandle inferred_shape,
                        ShapeHandle annotated_shape) const {
    if (inferred_shape.SameHandle(annotated_shape)) {
      return true;
    }
    if (!InferenceContext::RankKnown(inferred_shape)) {
      return true;
    }
    if (InferenceContext::Rank(inferred_shape) !=
        InferenceContext::Rank(annotated_shape)) {
      return false;
    }
    const int rank = InferenceContext::Rank(inferred_shape);
    for (int i = 0; i < rank; ++i) {
      if (!InferenceContext::DimKnownRank(inferred_shape, i)
               .SameHandle(
                   InferenceContext::DimKnownRank(annotated_shape, i))) {
        int64_t val1 = InferenceContext::Value(
            InferenceContext::DimKnownRank(inferred_shape, i));
        int64_t val2 = InferenceContext::Value(
            InferenceContext::DimKnownRank(annotated_shape, i));
        if (val1 >= 0 && val1 != val2) {
          return false;
        }
      }
    }
    return true;
  }

  bool SameShapes(ShapeHandle inferred_shape,
                  ShapeHandle annotated_shape) const {
    if (inferred_shape.SameHandle(annotated_shape)) {
      return true;
    }
    if (InferenceContext::Rank(inferred_shape) !=
        InferenceContext::Rank(annotated_shape)) {
      return false;
    }
    const int rank = InferenceContext::Rank(inferred_shape);
    for (int i = 0; i < rank; ++i) {
      int64_t val1 = InferenceContext::Value(
          InferenceContext::DimKnownRank(inferred_shape, i));
      int64_t val2 = InferenceContext::Value(
          InferenceContext::DimKnownRank(annotated_shape, i));
      if (val1 != val2) {
        return false;
      }
    }
    return true;
  }

  bool EquivalentShapesAndTypes(const std::vector<ShapeAndType>& st1,
                                const std::vector<ShapeAndType>& st2) const {
    if (st1.size() != st2.size()) {
      return false;
    }
    for (int i = 0, st1_size = st1.size(); i < st1_size; ++i) {
      const ShapeAndType& s1 = st1[i];
      const ShapeAndType& s2 = st2[i];
      if (s1.dtype != s2.dtype) {
        return false;
      }
      if (!EquivalentShapes(s1.shape, s2.shape)) {
        return false;
      }
    }
    return true;
  }

  Status AddFunction(const NodeDef* function_node,
                     const std::string& function_name) {
    auto it = fun_to_grappler_function_item_.find(function_name);
    if (it != fun_to_grappler_function_item_.end()) {
      return OkStatus();
    }

    const FunctionDef* function_def =
        CHECK_NOTNULL(function_library_.Find(function_name));
    GrapplerFunctionItem grappler_function_item;
    Status function_instantiated =
        MakeGrapplerFunctionItem(*function_def, function_library_,
                                 graph_def_version_, &grappler_function_item);

    // If function instantiation failed we will skip it during shape inference.
    if (!function_instantiated.ok()) {
      VLOG(3) << "Failed to instantiate a function. Error: "
              << function_instantiated.error_message();
      fun_to_grappler_function_item_[function_def->signature().name()] =
          absl::nullopt;
      return OkStatus();
    }

    if (static_cast<int>(grappler_function_item.inputs().size()) >
        function_node->input_size()) {
      return errors::FailedPrecondition(
          "Function input size should be smaller than node input size.");
    }

    for (int i = grappler_function_item.inputs().size(),
             end = function_node->input_size();
         i < end; ++i) {
      const string& input = function_node->input(i);
      if (!IsControlInput(input)) {
        return errors::FailedPrecondition(
            "Found regular input (", input,
            ") instead of control nodes for node ", function_node->name());
      }
    }

    fun_to_grappler_function_item_[function_def->signature().name()] =
        grappler_function_item;

    return OkStatus();
  }

  Status AddNode(const NodeDef* node) {
    NodeContext& node_ctx = node_to_context_[node];
    NameAttrList function;
    TF_RETURN_IF_ERROR(NameAndAttrsFromFunctionCall(*node, &function));

    // For PartitionedCall, op_data represents the function info.
    TF_RETURN_IF_ERROR(
        function_library_.LookUp(function.name(), &node_ctx.op_data));

    if (node_ctx.op_data->is_function_op) {
      TF_RETURN_IF_ERROR(AddFunction(node, function.name()));
    }

    TF_RETURN_IF_ERROR(InOutTypesForNode(*node, node_ctx.op_data->op_def,
                                         &node_ctx.input_types,
                                         &node_ctx.output_types));

    // Create the inference context for this node.
    const int num_inputs = node_ctx.input_types.size();
    std::vector<ShapeHandle> input_shapes(num_inputs);
    std::vector<std::unique_ptr<std::vector<ShapeAndType>>>
        input_handle_shapes_and_types(num_inputs);
    std::vector<const Tensor*> input_tensors(num_inputs, nullptr);
    std::vector<ShapeHandle> input_tensors_as_shapes;

    node_ctx.inference_context.reset(new InferenceContext(
        graph_def_version_, *node, node_ctx.op_data->op_def, input_shapes,
        input_tensors, input_tensors_as_shapes,
        std::move(input_handle_shapes_and_types)));
    const Status s = node_ctx.inference_context->construction_status();
    if (!s.ok()) {
      node_ctx.inference_context.reset(nullptr);
    }
    return s;
  }

 private:
  // Returns the one ShapeHandle used to denote a fully unknown shape for a
  // node output.
  ShapeHandle GetUnknownOutputShape(const NodeDef* node, int index) {
    ShapeId id{node, index};
    auto it = unknown_shapes_.find(id);
    if (it != unknown_shapes_.end()) {
      return it->second;
    }
    InferenceContext* c = GetContext(node);
    ShapeHandle shp = c->UnknownShape();
    unknown_shapes_[id] = shp;
    return shp;
  }
  // Returns the one DimensionHandle used to denote a fully unknown dimension
  // for a node output.
  DimensionHandle GetUnknownOutputDim(const NodeDef* node, int index,
                                      int dim_id) {
    DimId id{node, index, dim_id};
    auto it = unknown_dims_.find(id);
    if (it != unknown_dims_.end()) {
      return it->second;
    }
    InferenceContext* c = GetContext(node);
    DimensionHandle dim = c->UnknownDim();
    unknown_dims_[id] = dim;
    return dim;
  }

  // Returns true if all the output tensors have known values.
  bool AllOutputValuesKnown(NodeContext* c) {
    InferenceContext* ic = c->inference_context.get();
    int c_output_tensors_as_shapes_size = c->output_tensors_as_shapes.size();
    int c_output_tensor_protos_size = c->output_tensor_protos.size();
    if (c_output_tensors_as_shapes_size < ic->num_outputs() &&
        c_output_tensor_protos_size < ic->num_outputs()) {
      return false;
    } else {
      // Checks if we can get the output value via either output_tensor_protos
      // or output_tensors_as_shapes.
      for (int i = 0; i < ic->num_outputs(); i++) {
        if (c_output_tensor_protos_size > i &&
            c->output_tensor_protos[i] != nullptr) {
          continue;
        }
        if (c_output_tensors_as_shapes_size > i &&
            ic->FullyDefined(c->output_tensors_as_shapes[i])) {
          bool no_unknown_dim_from_const = true;
          for (int32_t j = 0; j < ic->Rank(c->output_tensors_as_shapes[i]);
               ++j) {
            const auto dim = ic->Dim(c->output_tensors_as_shapes[i], j);
            if (ic->ValueKnown(dim) && ic->Value(dim) == kUnknownDimFromConst) {
              no_unknown_dim_from_const = false;
              break;
            }
          }
          if (no_unknown_dim_from_const) {
            continue;
          }
        }
        return false;
      }
    }
    return true;
  }

  // Returns true if all the output shapes are known.
  bool AllOutputShapesKnown(NodeContext* c) {
    InferenceContext* ic = c->inference_context.get();
    // Checks if all the output shapes are fully defined.
    for (int i = 0; i < ic->num_outputs(); i++) {
      if (!ic->FullyDefined(ic->output(i))) {
        return false;
      }
    }
    return true;
  }
1465
1466 // Returns true if we can infer output tensors' values -- we know values of
1467 // all the input tensors.
AllInputValuesKnown(NodeContext * c)1468 bool AllInputValuesKnown(NodeContext* c) {
1469 InferenceContext* ic = c->inference_context.get();
1470
1471 // Check inputs are fully defined and values are known.
1472 for (int i = 0; i < ic->num_inputs(); i++) {
1473 const Tensor* tensor = ic->input_tensor(i);
1474 // Note that we don't check c->input_tensor_protos[i], as UpdateNode()
1475 // already converted it to ic->input_tensor(i);
1476 const ShapeHandle& input_tensors_as_shape =
1477 ic->input_tensors_as_shapes()[i];
1478 // Either input_tensor is valid or input_tensors_as_shape, which has
1479 // value of input tensors as shape format, should be fully defined.
1480 if (tensor == nullptr && !ic->FullyDefined(input_tensors_as_shape)) {
1481 return false;
1482 }
1483 }
1484 return true;
1485 }
1486
1487 // Returns true if we want to update output shapes and values with running
1488 // EvaluateNode() for this op, based on op type, data type, and size.
ShouldUpdateOutputShapesAndValues(NodeContext * c,int64_t max_size)1489 bool ShouldUpdateOutputShapesAndValues(NodeContext* c, int64_t max_size) {
1490 InferenceContext* ic = c->inference_context.get();
1491
1492     // Due to the cost of running EvaluateNode(), we run it only for
1493     // allowlisted op types.
1494 if (!IsAllowListedOpTypeForEvaluateNode(c->op_data->op_def.name())) {
1495 return false;
1496 }
1497
1498 // Check input dtypes are number types.
1499 for (const auto& input_type : c->input_types) {
1500 if (!IsNumericType(input_type)) {
1501 return false;
1502 }
1503 }
1504
1505 // Check output dtypes are number types.
1506 for (const auto& output_type : c->output_types) {
1507 if (!IsNumericType(output_type)) {
1508 return false;
1509 }
1510 }
1511
1512     // Check that the number of elements of each input tensor is no larger
1513     // than the given max size.
1514 for (int i = 0; i < ic->num_inputs(); i++) {
1515 const Tensor* tensor = ic->input_tensor(i);
1516 const ShapeHandle& input_shape_handle = ic->input(i);
1517 if (tensor != nullptr) {
1518 if (tensor->NumElements() > max_size) {
1519 return false;
1520 }
1521 } else if (ic->Value(ic->NumElements(input_shape_handle)) > max_size) {
1522 return false;
1523 }
1524 }
1525
1526     // Check that we know the shape of each output tensor, and that its number
1527     // of elements is no larger than the given max size.
1528 for (int i = 0; i < ic->num_outputs(); i++) {
1529 const ShapeHandle& shape_handle = ic->output(i);
1530 if (!ic->FullyDefined(shape_handle) ||
1531 ic->Value(ic->NumElements(shape_handle)) > max_size) {
1532 return false;
1533 }
1534 }
1535 return true;
1536 }
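  // Illustrative sketch of the gating above (not authoritative): with
  // max_size = 17, an allowlisted elementwise op on int32 inputs of shape
  // [4, 4] passes (numeric dtypes, 16 <= 17 elements per input and output),
  // while the same op on [32, 32] inputs is rejected by the element-count
  // checks. The concrete op set is defined by
  // IsAllowListedOpTypeForEvaluateNode().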
1537
1538 // Create input tensors from the NodeContext.
1539   void CreateInputTensors(NodeContext* c,
1540 std::vector<Tensor>* input_tensor_vector,
1541 TensorVector* inputs) {
1542 InferenceContext* ic = c->inference_context.get();
1543 for (int i = 0; i < ic->num_inputs(); i++) {
1544 if (ic->input_tensor(i)) {
1545 input_tensor_vector->at(i) = *ic->input_tensor(i);
1546 inputs->emplace_back(&input_tensor_vector->at(i));
1547 // Note that we don't check c->input_tensor_protos[i], as UpdateNode()
1548 // already converted it to ic->input_tensor(i);
1549 } else {
1550 // Create Tensor from input_tensors_as_shapes, and then emplace it
1551 // back to inputs.
1552 // Note that input_tensors_as_shapes is scalar or vector.
1553 const ShapeHandle& shape_handle = ic->input_tensors_as_shapes()[i];
1554 const DataType& data_type = c->input_types[i];
1555 int32_t rank = ic->Rank(shape_handle);
1556 if (rank < 1) {
1557 input_tensor_vector->at(i) = Tensor(data_type, {});
1558 } else {
1559 input_tensor_vector->at(i) = Tensor(data_type, {rank});
1560 }
1561 auto* tensor = &input_tensor_vector->at(i);
1562 if (data_type == DT_INT32) {
1563 auto flat = tensor->flat<int32>();
1564 for (int j = 0; j < rank; j++) {
1565 int32_t dim = ic->Value(ic->Dim(shape_handle, j));
1566 flat(j) = dim;
1567 }
1568 } else {
1569 auto flat = tensor->flat<int64_t>();
1570 for (int j = 0; j < rank; j++) {
1571 int64_t dim = ic->Value(ic->Dim(shape_handle, j));
1572 flat(j) = dim;
1573 }
1574 }
1575 inputs->emplace_back(tensor);
1576 }
1577 }
1578 }
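  // Example (illustrative only): if input i has no materialized tensor but its
  // input_tensors_as_shapes entry is the fully defined shape [2, 3] and
  // input_types[i] is DT_INT32, the code above builds a rank-1 int32 tensor of
  // length 2 holding {2, 3} and appends a pointer to it to `inputs`.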
1579
1580   // Runs a node to infer output shapes and values, and adds the results to
1581   // the NodeContext.
1582   Status UpdateOutputShapesAndValues(const NodeDef& node, NodeContext* c) {
1583 InferenceContext* ic = c->inference_context.get();
1584
1585 // Input to EvaluateNode()
1586 TensorVector inputs;
1587 // Container for temporarily created tensor object.
1588 std::vector<Tensor> input_tensor_vector(ic->num_inputs());
1589 CreateInputTensors(c, &input_tensor_vector, &inputs);
1590
1591 // Output for EvaluateNode() and output tensor clean up object.
1592 TensorVector outputs;
1593 auto outputs_cleanup = gtl::MakeCleanup([&outputs] {
1594 for (const auto& output : outputs) {
1595 if (output.tensor) {
1596 delete output.tensor;
1597 }
1598 }
1599 });
1600
1601 TF_RETURN_IF_ERROR(EvaluateNode(node, inputs, /*cpu_device=*/nullptr,
1602 &resource_mgr_, &outputs));
1603 c->output_tensors_as_shapes.resize(outputs.size());
1604 c->output_tensor_protos.resize(outputs.size(), nullptr);
1605 for (int k = 0, outputs_size = outputs.size(); k < outputs_size; k++) {
1606 const auto& t = outputs[k];
1607 // Override output shape.
1608 ShapeHandle output_shape;
1609 TF_RETURN_IF_ERROR(
1610 ic->MakeShapeFromTensorShape(t->shape(), &output_shape));
1611 if (ic->FullyDefined(ic->output(k)) &&
1612 !EquivalentShapes(ic->output(k), output_shape)) {
1613 LOG(WARNING) << "UpdateOutputShapesAndValues() -- node: " << node.name()
1614 << ", inferred output shape "
1615 << "doesn't match for k=" << k << ": "
1616 << "ic->output(k): " << ic->DebugString(ic->output(k))
1617 << ", output_shape: " << ic->DebugString(output_shape)
1618 << " -- " << node.DebugString();
1619 }
1620 ic->set_output(k, output_shape);
1621 // Set output_tensors_as_shape.
1622 MaybeTensorValueToShape(ic, *t.tensor, &c->output_tensors_as_shapes[k]);
1623
1624 // Set output_tensor_protos.
1625 TensorProto tensor_proto;
1626 t->AsProtoTensorContent(&tensor_proto);
1627 const_tensors_to_propagate_.push_back(tensor_proto);
1628 c->output_tensor_protos[k] = &const_tensors_to_propagate_.back();
1629 }
1630 return OkStatus();
1631 }
1632
1633 // Update output shapes with annotated information.
1634   // Currently this only handles nodes with static shapes, i.e. shapes that do
1635   // not change during execution.
1636 // TODO(andiryxu): Use annotated shapes in Enter/Merge etc as well.
1637   Status UpdateOutputShapesUsingAnnotatedInformation(const NodeDef& node,
1638 NodeContext* c) const {
1639 const auto& attr = node.attr();
1640 if (attr.count(kOutputSame) == 0 || !attr.at(kOutputSame).b() ||
1641 attr.count(kOutputShapes) == 0)
1642 return OkStatus();
1643
1644 InferenceContext* ic = c->inference_context.get();
1645 int output_size = attr.at(kOutputShapes).list().shape_size();
1646
1647 for (int i = 0; i < ic->num_outputs(); i++) {
1648       // An annotated Switch node has only one annotated shape. Propagate it to
1649       // all the outputs.
1650 int shape_index = IsSwitch(node) ? 0 : i;
1651 if (shape_index >= output_size) {
1652 LOG(WARNING)
1653 << "UpdateOutputShapesUsingAnnotatedInformation() -- node: "
1654 << node.name() << ", inferred output shape size "
1655 << ic->num_outputs() << ", annotated output shape size "
1656 << output_size;
1657 break;
1658 }
1659
1660 const TensorShapeProto& shape =
1661 attr.at(kOutputShapes).list().shape(shape_index);
1662 if (shape.dim().empty()) continue;
1663
1664 ShapeHandle output_shape;
1665 TF_RETURN_IF_ERROR(ic->MakeShapeFromShapeProto(shape, &output_shape));
1666
1667 // Check if annotated shapes are incompatible with inferred shapes.
1668 if ((ic->FullyDefined(ic->output(i)) &&
1669 !SameShapes(ic->output(i), output_shape)) ||
1670 (!ic->FullyDefined(ic->output(i)) &&
1671 !CompatibleShapes(ic->output(i), output_shape))) {
1672 LOG(WARNING)
1673 << "UpdateOutputShapesUsingAnnotatedInformation() -- node: "
1674 << node.name() << ", inferred output shape "
1675 << "doesn't match for i=" << i << ": "
1676               << "ic->output(i): " << ic->DebugString(ic->output(i))
1677 << ", annotated output shape: " << ic->DebugString(output_shape)
1678 << " -- " << node.DebugString();
1679 c->shape_incompatible = true;
1680 }
1681
1682 // Only use annotated shapes if the inference shape is unknown and
1683 // compatible with annotated shapes.
1684 if (!ic->FullyDefined(ic->output(i)) &&
1685 CompatibleShapes(ic->output(i), output_shape)) {
1686 VLOG(3) << "UpdateOutputShapesUsingAnnotatedInformation() -- node: "
1687 << node.name() << ", inferred output shape " << i << ": "
1688 << "ic->output(i): " << ic->DebugString(ic->output(i))
1689 << ", annotated output shape: " << ic->DebugString(output_shape)
1690 << " -- " << node.ShortDebugString();
1691 ic->set_output(i, output_shape);
1692 }
1693 }
1694
1695 return OkStatus();
1696 }
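  // Illustrative example (the exact attribute encoding is an assumption): a
  // node annotated with kOutputSame = true and a kOutputShapes list containing
  // `shape { dim { size: 32 } dim { size: 128 } }` would have an inferred
  // output of [?, 128] upgraded to [32, 128] above, since the annotation is
  // compatible with the partially unknown inferred shape.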
1697
1698   Status MaybeUpdateNodeContextOutput(const NodeDef& node, const bool is_fed,
1699 NodeContext* c) {
1700 // Propagate tensors and shape tensors unless the node is fed.
1701 // TODO(bsteiner) We should still propagate the shapes to the ports that
1702 // aren't fed in the case of a ShapeN node.
1703
1704 // Note that when propagating tensors_as_shapes, we use
1705     // c->input_tensors_as_shapes_to_propagate instead of
1706 // ic->input_tensors_as_shapes. The former uses kUnknownDimFromConst if
1707 // UnknownDim is from Const tensor, and it is propagated through shape
1708 // inference. Before calling shape functions, we convert it to UnknownDim,
1709 // but instantiate a new UnknownDim to prevent incorrect symbolic shape
1710 // inference through UnknownDim from Const.
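    // For example (sketch): a Const shape tensor with values [2, -1, 3] is
    // propagated as [2, kUnknownDimFromConst, 3]; before running a shape
    // function the kUnknownDimFromConst entry is replaced by a fresh
    // UnknownDim, so the Reshape-style "-1" is never merged with unrelated
    // symbolic unknown dims.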
1711 InferenceContext* ic = c->inference_context.get();
1712 if (!is_fed) {
1713 if (IsConstant(node)) {
1714 const TensorProto& tensor_proto = node.attr().at("value").tensor();
1715 c->output_tensor_protos.resize(1);
1716 c->output_tensor_protos[0] = &tensor_proto;
1717 c->output_tensors_as_shapes.resize(1);
1718 MaybeTensorProtoToShape(ic, tensor_proto,
1719 &c->output_tensors_as_shapes[0]);
1720 } else if (IsRank(node)) {
1721 if (ic->RankKnown(ic->input(0))) {
1722 // Propagate rank value.
1723 int32_t rank = ic->Rank(ic->input(0));
1724 const_tensors_to_propagate_.push_back(
1725 MakeIntegerScalarTensorProto(DT_INT32, rank));
1726 c->output_tensor_protos.resize(1);
1727 c->output_tensor_protos[0] = &const_tensors_to_propagate_.back();
1728 }
1729 } else if (IsSize(node)) {
1730 DimensionHandle size = ic->NumElements(ic->input(0));
1731 if (ic->ValueKnown(size)) {
1732 // Propagate size value.
1733 int64_t sz = ic->Value(size);
1734 bool valid = false;
1735 if (node.attr().at("out_type").type() == DT_INT32) {
1736 if (sz < std::numeric_limits<int32>::max()) {
1737 const_tensors_to_propagate_.push_back(
1738 MakeIntegerScalarTensorProto(DT_INT32, sz));
1739 valid = true;
1740 }
1741 } else {
1742 const_tensors_to_propagate_.push_back(
1743 MakeIntegerScalarTensorProto(DT_INT64, sz));
1744 valid = true;
1745 }
1746 if (valid) {
1747 c->output_tensor_protos.resize(1);
1748 c->output_tensor_protos[0] = &const_tensors_to_propagate_.back();
1749 }
1750 }
1751 } else if (IsShape(node)) {
1752 c->output_tensors_as_shapes.resize(1);
1753 c->output_tensors_as_shapes[0] = c->inference_context->input(0);
1754 } else if (IsShapeN(node)) {
1755 c->output_tensors_as_shapes.resize(c->inference_context->num_inputs());
1756 for (int i = 0; i < c->inference_context->num_inputs(); ++i) {
1757 c->output_tensors_as_shapes[i] = c->inference_context->input(i);
1758 }
1759 } else if (node.op() == "ConcatV2") {
1760 bool valid = true;
1761 ShapeHandle result;
1762 for (int i = 0; i < ic->num_inputs() - 1; ++i) {
1763 ShapeHandle input = c->input_tensors_as_shapes_to_propagate[i];
1764 if (!ic->RankKnown(input)) {
1765 valid = false;
1766 break;
1767 } else if (i == 0) {
1768 result = input;
1769 } else {
1770 TF_RETURN_IF_ERROR(ic->Concatenate(result, input, &result));
1771 }
1772 }
1773 if (valid) {
1774 c->output_tensors_as_shapes.resize(1);
1775 c->output_tensors_as_shapes[0] = result;
1776 }
1777 } else if (IsPack(node)) {
1778 // A Pack node concatenating scalars is often used to generate a shape.
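        // For example (sketch): Pack(batch_size, 3, 224, 224) with scalar
        // int32 inputs, where batch_size has no known value, produces
        // output_tensors_as_shapes[0] = [kUnknownDimFromConst, 3, 224, 224].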
1779 std::vector<DimensionHandle> dims;
1780 bool valid = true;
1781 for (int i = 0; i < ic->num_inputs(); ++i) {
1782 const Tensor* t = ic->input_tensor(i);
1783 if (t) {
1784 if (t->dims() != 0 ||
1785 (t->dtype() != DT_INT32 && t->dtype() != DT_INT64)) {
1786 valid = false;
1787 break;
1788 }
1789 int64_t size = t->dtype() == DT_INT32 ? t->scalar<int32>()()
1790 : t->scalar<int64_t>()();
1791 dims.push_back(size < 0 ? ic->MakeDim(kUnknownDimFromConst)
1792 : ic->MakeDim(size));
1793 } else {
1794 // Don't have tensor value, but use input_tensors_as_shapes, if
1795 // possible.
1796 const ShapeHandle& shape_handle =
1797 c->input_tensors_as_shapes_to_propagate[i];
1798 if (ic->RankKnown(shape_handle) && ic->Rank(shape_handle) >= 1 &&
1799 ic->ValueKnown(ic->Dim(shape_handle, 0))) {
1800 dims.push_back(ic->Dim(shape_handle, 0));
1801 } else {
1802             // This is not from a Const, but as it shouldn't be used as a symbolic
1803             // unknown dim for different ops, we use kUnknownDimFromConst.
1804 dims.push_back(ic->MakeDim(kUnknownDimFromConst));
1805 }
1806 }
1807 }
1808 if (valid) {
1809 c->output_tensors_as_shapes.resize(1);
1810 c->output_tensors_as_shapes[0] = ic->MakeShape(dims);
1811 }
1812 } else if (IsIdentity(node) || IsIdentityNSingleInput(node)) {
1813 c->output_tensors_as_shapes.resize(1);
1814 c->output_tensors_as_shapes[0] =
1815 c->input_tensors_as_shapes_to_propagate[0];
1816 if (c->input_tensor_protos[0] != nullptr) {
1817 c->output_tensor_protos.resize(1);
1818 c->output_tensor_protos[0] = c->input_tensor_protos[0];
1819 }
1820 } else if (IsSlice(node)) {
1821 ShapeHandle input = c->input_tensors_as_shapes_to_propagate[0];
1822 bool valid = ic->RankKnown(input);
1823 const Tensor* slice_offset = ic->input_tensor(1);
1824 valid &= slice_offset != nullptr && slice_offset->NumElements() == 1;
1825 const Tensor* slice_size = ic->input_tensor(2);
1826 valid &= slice_size != nullptr && slice_size->NumElements() == 1;
1827 if (valid) {
1828 int64_t start = slice_offset->dtype() == DT_INT32
1829 ? slice_offset->flat<int32>()(0)
1830 : slice_offset->flat<int64_t>()(0);
1831 int64_t size = (slice_size->dtype() == DT_INT32
1832 ? slice_size->flat<int32>()(0)
1833 : slice_size->flat<int64_t>()(0));
1834 ShapeHandle result;
1835 if (size == -1) {
1836 TF_RETURN_IF_ERROR(ic->Subshape(input, start, &result));
1837 } else {
1838 int64_t end = start + size;
1839 TF_RETURN_IF_ERROR(ic->Subshape(input, start, end, &result));
1840 }
1841 c->output_tensors_as_shapes.resize(1);
1842 c->output_tensors_as_shapes[0] = result;
1843 }
1844 } else if (IsStridedSlice(node)) {
1845 ShapeHandle input = c->input_tensors_as_shapes_to_propagate[0];
1846 bool valid = ic->RankKnown(input);
1847 const Tensor* slice_begin = ic->input_tensor(1);
1848 valid &= slice_begin != nullptr && slice_begin->NumElements() == 1;
1849 const Tensor* slice_end = ic->input_tensor(2);
1850 valid &= slice_end != nullptr && slice_end->NumElements() == 1;
1851 const Tensor* slice_stride = ic->input_tensor(3);
1852 valid &= slice_stride != nullptr && slice_stride->NumElements() == 1;
1853
1854 if (node.attr().count("ellipsis_mask") > 0 &&
1855 node.attr().at("ellipsis_mask").i() != 0) {
1856 valid = false;
1857 }
1858 if (node.attr().count("new_axis_mask") > 0 &&
1859 node.attr().at("new_axis_mask").i() != 0) {
1860 valid = false;
1861 }
1862 if (node.attr().count("shrink_axis_mask") > 0 &&
1863 node.attr().at("shrink_axis_mask").i() != 0) {
1864 valid = false;
1865 }
1866 int begin_mask = 0;
1867 if (node.attr().count("begin_mask") > 0) {
1868 begin_mask = node.attr().at("begin_mask").i();
1869 }
1870 int end_mask = 0;
1871 if (node.attr().count("end_mask") > 0) {
1872 end_mask = node.attr().at("end_mask").i();
1873 }
1874 if (begin_mask < 0 || begin_mask > 1 || end_mask < 0 || end_mask > 1) {
1875 valid = false;
1876 }
1877 if (valid) {
1878 int64_t begin = 0;
1879 if (begin_mask == 0) {
1880 begin = slice_begin->dtype() == DT_INT32
1881 ? slice_begin->flat<int32>()(0)
1882 : slice_begin->flat<int64_t>()(0);
1883 }
1884 int64_t end = std::numeric_limits<int64_t>::max();
1885 if (end_mask == 0) {
1886 end = (slice_end->dtype() == DT_INT32
1887 ? slice_end->flat<int32>()(0)
1888 : slice_end->flat<int64_t>()(0));
1889 }
1890 int64_t stride = slice_stride->dtype() == DT_INT32
1891 ? slice_stride->flat<int32>()(0)
1892 : slice_stride->flat<int64_t>()(0);
1893 ShapeHandle result;
1894 TF_RETURN_IF_ERROR(ic->Subshape(input, begin, end, stride, &result));
1895 c->output_tensors_as_shapes.resize(1);
1896 c->output_tensors_as_shapes[0] = result;
1897 }
1898 }
1899 }
1900
1901 if (aggressive_shape_inference_) {
1902 // Update output shapes with annotated information. This is optional.
1903 UpdateOutputShapesUsingAnnotatedInformation(node, c).IgnoreError();
1904
1905 // Update output tensor values using EvaluateNode() if we can.
1906       // Due to the cost of EvaluateNode(), we run it only for certain op types
1907       // (allowlisted) and for small integer tensors.
1908
1909 const int max_element_size = 17; // Max up to 4x4 matrix or similar.
1910 if (AllOutputValuesKnown(c) || !AllInputValuesKnown(c) ||
1911 !ShouldUpdateOutputShapesAndValues(c, max_element_size)) {
1912 return OkStatus();
1913 }
1914 UpdateOutputShapesAndValues(node, c).IgnoreError(); // This is optional.
1915 }
1916 return OkStatus();
1917 }
1918
1919   Status InferShapes(const NodeDef& node, NodeContext* c) {
1920 // Infer the shapes of output tensors.
1921 if (!c->op_data || c->op_data->shape_inference_fn == nullptr ||
1922 !c->inference_context->Run(c->op_data->shape_inference_fn).ok()) {
1923 // Annotate outputs with unknown shapes. Update output shapes with
1924 // annotated information later on if available.
1925       // Note that the shape inference function may return an error, but we
1926       // ignore it and use UnknownShape in that case.
1927 TF_RETURN_IF_ERROR(
1928 c->inference_context->Run(shape_inference::UnknownShape));
1929 }
1930 Status status = OkStatus();
1931 auto it = fed_ports_.find(node.name());
1932 const bool is_fed = it != fed_ports_.end();
1933 if (is_fed) {
1934 // It is possible to feed node output ports with tensors of any shape: as
1935 // a result, the shape of a fed port is completely unknown.
1936 for (const int output_port : it->second) {
1937 status.Update(SetUnknownShape(&node, output_port));
1938 }
1939 }
1940
1941 // Update NodeContext output fields after shape inference function runs.
1942 status.Update(MaybeUpdateNodeContextOutput(node, is_fed, c));
1943
1944 return status;
1945 }
1946
1947 private:
1948   bool IsIntegerVector(const Tensor& tensor) {
1949 if (tensor.dims() == 1 &&
1950 (tensor.dtype() == DT_INT32 || tensor.dtype() == DT_INT64)) {
1951 return true;
1952 }
1953 return false;
1954 }
1955
1956   bool IsIntegerScalar(const Tensor& tensor) {
1957 if (tensor.dims() == 0 &&
1958 (tensor.dtype() == DT_INT32 || tensor.dtype() == DT_INT64) &&
1959 tensor.NumElements() == 1) {
1960 return true;
1961 }
1962 return false;
1963 }
1964
1965   TensorProto MakeIntegerScalarTensorProto(const DataType dtype,
1966 const int64_t val) {
1967 TensorProto tensor_proto;
1968 tensor_proto.set_dtype(dtype);
1969 // Scalar TensorProto has an empty tensor_shape; no dim, no dim.size.
1970 tensor_proto.mutable_tensor_shape();
1971 if (dtype == DT_INT32) {
1972 tensor_proto.add_int_val(val);
1973 } else if (dtype == DT_INT64) {
1974 tensor_proto.add_int64_val(val);
1975 }
1976 return tensor_proto;
1977 }
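  // For instance, MakeIntegerScalarTensorProto(DT_INT32, 7) yields a proto
  // equivalent to the text form `dtype: DT_INT32 tensor_shape { } int_val: 7`
  // (illustrative; exact text-format rendering may differ).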
1978
1979   bool MaybeTensorProtoToShape(InferenceContext* ic,
1980 const TensorProto& tensor_proto,
1981 ShapeHandle* tensors_as_shapes) {
1982 // Skip if dtype is not integer.
1983 if (tensor_proto.dtype() != DT_INT32 && tensor_proto.dtype() != DT_INT64) {
1984 return false;
1985 }
1986 // Skip if the const tensor is too large.
1987 if (NumElementsFromTensorProto(tensor_proto) >
1988 kThresholdToSkipConstTensorInstantiation) {
1989 return false;
1990 }
1991 // Skip if shape is neither scalar nor vector.
1992 if (tensor_proto.tensor_shape().unknown_rank() ||
1993 tensor_proto.tensor_shape().dim_size() > 1) {
1994 return false;
1995 }
1996 Tensor tensor;
1997 if (!tensor.FromProto(tensor_proto)) {
1998 return false;
1999 }
2000 return MaybeTensorValueToShape(ic, tensor, tensors_as_shapes);
2001 }
2002
2003   bool MaybeTensorValueToShape(InferenceContext* ic, const Tensor& tensor,
2004 ShapeHandle* tensors_as_shapes) {
2005 // Integer tensors of rank one can also be interpreted as a shape
2006 // provided all their values are >= -1.
2007
2008 if (IsIntegerVector(tensor)) {
2009 bool has_values_smaller_than_minus_1 = false;
2010 std::vector<DimensionHandle> dims;
2011 for (int i = 0; i < tensor.NumElements(); i++) {
2012 int64_t value = tensor.dtype() == DT_INT32 ? tensor.flat<int32>()(i)
2013 : tensor.flat<int64_t>()(i);
2014 has_values_smaller_than_minus_1 |= (value < -1);
2015 // Mark this as UnknownDim from Const.
2016 dims.push_back(value < 0 ? ic->MakeDim(kUnknownDimFromConst)
2017 : ic->MakeDim(value));
2018 }
2019
2020 if (!has_values_smaller_than_minus_1) {
2021 *tensors_as_shapes = ic->MakeShape(dims);
2022 return true;
2023 }
2024 } else if (IsIntegerScalar(tensor)) {
2025 // Scalar constant.
2026 int64_t value = tensor.dtype() == DT_INT32 ? tensor.flat<int32>()(0)
2027 : tensor.flat<int64_t>()(0);
2028 if (value == -1) {
2029         // Scalar value -1 represents an unknown shape. If we tried
2030         // MakeShape(MakeDim) with it, we would get a vector of unknown size.
2031 *tensors_as_shapes = ic->UnknownShape();
2032 return true;
2033 } else if (value >= 0) {
2034         // In principle values could be < -1, but MakeDim() fails for such values;
2035         // this is a limitation of using ShapeHandle as a means to pass values.
2036 *tensors_as_shapes = ic->MakeShape({ic->MakeDim(value)});
2037 return true;
2038 }
2039 }
2040 return false;
2041 }
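  // Examples (illustrative): the int32 vector {2, -1, 3} becomes the shape
  // [2, kUnknownDimFromConst, 3]; the scalar -1 becomes a fully unknown shape;
  // the vector {2, -7, 3} is rejected because -7 < -1.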
2042
2043 const GraphView& graph_;
2044 int graph_def_version_;
2045 absl::flat_hash_map<const NodeDef*, NodeContext> node_to_context_;
2046 absl::flat_hash_map<ShapeId, ShapeHandle, HashShapeId> unknown_shapes_;
2047 absl::flat_hash_map<DimId, DimensionHandle, HashDimId> unknown_dims_;
2048   // Store function instantiations only for valid functions. If function
2049   // instantiation failed, the entry holds `absl::nullopt`.
2050 absl::flat_hash_map<string, absl::optional<GrapplerFunctionItem>>
2051 fun_to_grappler_function_item_;
2052 FunctionLibraryDefinition function_library_;
2053 const absl::flat_hash_map<string, absl::flat_hash_set<int>>& fed_ports_;
2054 // Store TensorProtos for tensor value propagation. Note that we use deque,
2055 // not vector, as we use pointers to the TensorProtos in this container.
2056   // A vector may resize and copy its objects into a new buffer, leaving the
2057   // existing pointers dangling.
2058 std::deque<TensorProto> const_tensors_to_propagate_;
2059
2060 // For more aggressive shape and value inference.
2061 bool aggressive_shape_inference_;
2062 ResourceMgr resource_mgr_;
2063 };
2064
2065 // Keep track of shapes and dimensions in a graph.
2066 // In particular, use disjoint sets to track equivalence between shapes and
2067 // dims, and consolidate the information globally.
2068 class SymbolicShapeManager {
2069 public:
2070   SymbolicShapeManager() {}
2071
2072   Status Merge(ShapeHandle s1, ShapeHandle s2) {
2073 if (!s1.IsSet() || !s2.IsSet()) {
2074 return OkStatus();
2075 }
2076 TF_RETURN_IF_ERROR(shapes_.Merge(s1, s2));
2077 if (InferenceContext::Rank(s1) > 0 && InferenceContext::Rank(s2) > 0) {
2078 CHECK_EQ(InferenceContext::Rank(s1), InferenceContext::Rank(s2));
2079 for (int i = 0; i < InferenceContext::Rank(s1); ++i) {
2080 TF_RETURN_IF_ERROR(dims_.Merge(InferenceContext::DimKnownRank(s1, i),
2081 InferenceContext::DimKnownRank(s2, i)));
2082 }
2083 }
2084 return OkStatus();
2085 }
2086   Status Merge(DimensionHandle d1, DimensionHandle d2) {
2087 if (!d1.IsSet() || !d2.IsSet()) {
2088 return OkStatus();
2089 }
2090 return dims_.Merge(d1, d2);
2091 }
2092
2093   void AsTensorProperties(const ShapeHandle& shape, const DataType& type,
2094 OpInfo::TensorProperties* properties) {
2095 properties->set_dtype(type);
2096 ShapeHandle actual_shape = shapes_.GetMergedValue(shape);
2097 if (!InferenceContext::RankKnown(actual_shape)) {
2098 properties->mutable_shape()->set_unknown_rank(true);
2099 } else {
2100 for (int j = 0; j < InferenceContext::Rank(actual_shape); ++j) {
2101 shape_inference::DimensionHandle dim =
2102 InferenceContext::DimKnownRank(actual_shape, j);
2103 int64_t d = dims_.GetMergedValue(dim);
2104 properties->mutable_shape()->add_dim()->set_size(d);
2105 }
2106 }
2107 }
2108
2109 // Returns merged shape with merged dimensions.
2110   ShapeHandle GetMergedShape(InferenceContext* ic, ShapeHandle s) {
2111 const auto& actual_shape = shapes_.GetMergedValue(s);
2112 if (!InferenceContext::RankKnown(actual_shape)) {
2113 return ic->UnknownShape();
2114 } else {
2115 std::vector<DimensionHandle> dims;
2116 for (int j = 0; j < InferenceContext::Rank(actual_shape); ++j) {
2117 shape_inference::DimensionHandle dim =
2118 InferenceContext::DimKnownRank(actual_shape, j);
2119 int64_t d = dims_.GetMergedValue(dim);
2120         // The symbolic shape manager may have made some dims < -1, which causes
2121         // errors when creating a Dimension.
2122 if (d < -1) {
2123 d = -1;
2124 }
2125 dims.push_back(ic->MakeDim(d));
2126 }
2127 return ic->MakeShape(dims);
2128 }
2129 }
2130
2131 private:
2132 DisjointSet<shape_inference::ShapeHandle> shapes_;
2133 DisjointSet<shape_inference::DimensionHandle> dims_;
2134 };
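// A minimal usage sketch of SymbolicShapeManager (illustrative only; the
// ShapeHandles are assumed to come from the refiner's InferenceContexts):
// merging a producer's output shape with its consumer's input shape also
// merges the dimensions pairwise, so both report the same symbolic sizes:
//
//   SymbolicShapeManager mgr;
//   TF_RETURN_IF_ERROR(mgr.Merge(producer_output_shape, consumer_input_shape));
//   OpInfo::TensorProperties props;
//   mgr.AsTensorProperties(consumer_input_shape, DT_FLOAT, &props);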
2135
2136 // Checks whether there is any conflict in merged shapes and dims in
2137 // SymbolicShapeManager.
2138 Status ValidateSymbolicShapeManager(const GraphDef& graph_def,
2139 SymbolicShapeRefiner* refiner,
2140 SymbolicShapeManager* shape_manager) {
2141 if (!VLOG_IS_ON(1)) {
2142 return OkStatus();
2143 }
2144
2145 VLOG(1) << "Checking any conflicts in shapes and dimensions ...";
2146 int64_t num_incompatible_shapes = 0;
2147 for (const NodeDef& node : graph_def.node()) {
2148 auto ctx = refiner->GetNodeContext(&node);
2149 if (!ctx) {
2150 continue;
2151 }
2152 auto* ic = ctx->inference_context.get();
2153 for (int i = 0; i < ic->num_inputs(); ++i) {
2154 const auto& shape = ic->input(i);
2155 const auto& merged_shape = shape_manager->GetMergedShape(ic, shape);
2156 if (!refiner->CompatibleShapes(shape, merged_shape)) {
2157 num_incompatible_shapes++;
2158 VLOG(1) << "**** Incompatible shape from SymbolicShapeManager "
2159 << "for node " << node.name() << " input (" << i << ") "
2160 << ic->DebugString(shape)
2161 << " vs. merged: " << ic->DebugString(merged_shape);
2162 }
2163 }
2164 for (int i = 0; i < ic->num_outputs(); ++i) {
2165 const auto& shape = ic->output(i);
2166 const auto& merged_shape = shape_manager->GetMergedShape(ic, shape);
2167 if (!refiner->CompatibleShapes(shape, merged_shape)) {
2168 num_incompatible_shapes++;
2169 VLOG(1) << "**** Incompatible shape from SymbolicShapeManager "
2170 << "for node " << node.name() << " output (" << i << ") "
2171 << ic->DebugString(shape)
2172 << " vs. merged: " << ic->DebugString(merged_shape);
2173 }
2174 }
2175 }
2176 if (num_incompatible_shapes > 0) {
2177 VLOG(1) << "**** WARNING: " << num_incompatible_shapes
2178 << " incompatible shapes from SymbolicShapeManager.";
2179 } else {
2180 VLOG(1) << "**** No incompatible shape found from SymbolicShapeManager.";
2181 }
2182
2183 return OkStatus();
2184 }
2185
2186 // Log shape inference and its merged shapes.
2187 Status VerboseShapeInferenceLogging(const GraphDef& graph_def,
2188 SymbolicShapeRefiner* refiner,
2189 SymbolicShapeManager* shape_manager) {
2190   // As logging all the nodes would generate too many lines, we skip this
2191   // detailed logging by default. Users may add nodes of interest to
2192 // node_names_for_logging to enable detailed logging.
2193 absl::flat_hash_set<std::string> node_names_for_logging = {};
2194 if (!VLOG_IS_ON(3) || node_names_for_logging.empty()) {
2195 return OkStatus();
2196 }
2197
2198 auto should_log = [&node_names_for_logging](std::string node_name) {
2199 return node_names_for_logging.find(node_name) !=
2200 node_names_for_logging.end();
2201 };
2202
2203 for (const NodeDef& node : graph_def.node()) {
2204 if (!should_log(node.name())) {
2205 continue;
2206 }
2207 auto ctx = refiner->GetNodeContext(&node);
2208 if (!ctx) {
2209 continue;
2210 }
2211 auto* ic = ctx->inference_context.get();
2212 VLOG(3) << "Shape inference for node : " << node.name();
2213 VLOG(3) << ctx->DebugString(node);
2214     std::string merged_shapes = "Merged shapes from SymbolicShapeManager:\n";
2215 for (int i = 0; i < ic->num_inputs(); ++i) {
2216 absl::StrAppend(
2217 &merged_shapes, " input[", i, "] -- ",
2218 ic->DebugString(shape_manager->GetMergedShape(ic, ic->input(i))),
2219 "\n");
2220 }
2221 for (int i = 0; i < ic->num_outputs(); ++i) {
2222 absl::StrAppend(
2223 &merged_shapes, " output[", i, "] -- ",
2224 ic->DebugString(shape_manager->GetMergedShape(ic, ic->output(i))),
2225 "\n");
2226 }
2227 VLOG(3) << merged_shapes;
2228 VLOG(3) << "--------------------------------";
2229 VLOG(3) << "";
2230 }
2231
2232 return OkStatus();
2233 }
2234
2235 Status GraphProperties::RelaxEnqueueShapesAndMergeTypes(
2236 SymbolicShapeRefiner* shape_refiner, const NodeDef* qnode,
2237 const std::vector<ShapeAndType>& shapes_and_types,
2238 std::vector<ShapeAndType>* queue_shapes_and_types) {
2239 if (shapes_and_types.size() != queue_shapes_and_types->size()) {
2240 return errors::InvalidArgument(
2241 "Enqueue nodes mixed number of tensors: ", shapes_and_types.size(),
2242 " vs ", queue_shapes_and_types->size());
2243 }
2244 for (size_t i = 0; i < shapes_and_types.size(); ++i) {
2245 const ShapeAndType& a = shapes_and_types[i];
2246 ShapeAndType& b = (*queue_shapes_and_types)[i];
2247 if (a.dtype != b.dtype) {
2248 return errors::InvalidArgument("Enqueue nodes mixed dtypes for tensor ",
2249 i, ": ", DataTypeString(a.dtype), " vs ",
2250 DataTypeString(b.dtype));
2251 }
2252
2253 b.shape = shape_refiner->OutputAsUnion(qnode, i, a.shape, b.shape);
2254 }
2255 return OkStatus();
2256 }
2257
2258 // Compute the output shape of the merge node as the union of the available
2259 // input shapes.
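// For example (sketch): if the available inputs of a Merge have shapes [2, 3]
// and [2, 5], the output is relaxed to [2, ?]; a fanin whose source context
// does not exist yet (a not-yet-visited back edge) is simply skipped.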
2260 Status GraphProperties::UpdateMerge(SymbolicShapeRefiner* shape_refiner,
2261 const NodeDef* node,
2262 bool* new_shapes) const {
2263 InferenceContext* ic = shape_refiner->GetContext(node);
2264 if (!ic) {
2265 // Now we can run shape inference
2266 TF_RETURN_IF_ERROR(shape_refiner->AddNode(node));
2267 ic = CHECK_NOTNULL(shape_refiner->GetContext(node));
2268 *new_shapes = true;
2269
2270 // Infer the shape of the second output once and for all since it never
2271 // changes.
2272 ShapeHandle out1 = ic->Scalar();
2273 if (ic->num_outputs() >= 2) ic->set_output(1, out1);
2274 }
2275
2276 ShapeHandle out;
2277 const std::vector<ShapeAndType>* out_handle = nullptr;
2278 bool out_initialized = false;
2279 for (const GraphView::Edge fanin : shape_refiner->graph().GetFaninEdges(
2280 *node, /*include_controlling_edges=*/false)) {
2281 InferenceContext* src_ic = shape_refiner->GetContext(fanin.src.node);
2282 if (!src_ic) {
2283 // Handling a loop for the first time, the back edge won't have any shape
2284 // info.
2285 continue;
2286 }
2287 ShapeHandle input = src_ic->output(fanin.src.port_id);
2288 ic->SetInput(fanin.dst.port_id, input);
2289 auto* input_handle =
2290 src_ic->output_handle_shapes_and_types(fanin.src.port_id);
2291 if (input_handle)
2292 ic->set_input_handle_shapes_and_types(fanin.dst.port_id, *input_handle);
2293 if (!out_initialized) {
2294 out_initialized = true;
2295 out = input;
2296 out_handle = input_handle;
2297 } else {
2298 // Note here only out, not out_handle, is modified.
2299 out = shape_refiner->OutputAsUnion(node, 0, input, out);
2300 }
2301 }
2302
2303 if (*new_shapes || !shape_refiner->EquivalentShapes(out, ic->output(0))) {
2304 ic->set_output(0, out);
2305 if (out_handle) ic->set_output_handle_shapes_and_types(0, *out_handle);
2306 *new_shapes = true;
2307 }
2308
2309 return OkStatus();
2310 }
2311
2312 // Manually propagate the input shape for Enter nodes.
2313 Status GraphProperties::UpdateEnter(SymbolicShapeRefiner* shape_refiner,
2314 const NodeDef* node, bool* new_shapes) {
2315 InferenceContext* ic = shape_refiner->GetContext(node);
2316 if (!ic) {
2317 TF_RETURN_IF_ERROR(shape_refiner->UpdateNode(node, new_shapes));
2318 ic = shape_refiner->GetContext(node);
2319 }
2320
2321 GraphView::InputPort port(node, 0);
2322 GraphView::OutputPort fanin = shape_refiner->graph().GetRegularFanin(port);
2323
2324 InferenceContext* src_ic = shape_refiner->GetContext(fanin.node);
2325 ShapeHandle input = src_ic->output(fanin.port_id);
2326 if (!ic->output(0).SameHandle(input)) {
2327 ic->SetInput(0, input);
2328 ic->set_output(0, input);
2329 *new_shapes = true;
2330 }
2331 auto* outputs = src_ic->output_handle_shapes_and_types(fanin.port_id);
2332 if (outputs) {
2333 ic->set_input_handle_shapes_and_types(0, *outputs);
2334 ic->set_output_handle_shapes_and_types(0, *outputs);
2335 *new_shapes = true;
2336 }
2337 return OkStatus();
2338 }
2339
2340 Status GraphProperties::UpdateShapes(
2341 SymbolicShapeRefiner* shape_refiner,
2342 const absl::flat_hash_map<const NodeDef*, const NodeDef*>& resource_handles,
2343 const NodeDef* n, bool* new_shapes) const {
2344 if (IsEnter(*n)) {
2345 // The Enter shape function always forwards an UnknownShape, so do the right
2346 // thing here.
2347 TF_RETURN_IF_ERROR(UpdateEnter(shape_refiner, n, new_shapes));
2348 } else if (IsMerge(*n)) {
2349 // Properly handle merge nodes.
2350 TF_RETURN_IF_ERROR(UpdateMerge(shape_refiner, n, new_shapes));
2351 } else if (IsEnqueue(*n)) {
2352 // Make sure the shapes of enqueued tensors are propagated to the queue
2353 // itself.
2354 TF_RETURN_IF_ERROR(
2355 UpdateEnqueue(n, resource_handles, shape_refiner, new_shapes));
2356 } else if (IsQueue(*n)) {
2357 // Set shapes and types of Queue ops, if needed.
2358 TF_RETURN_IF_ERROR(UpdateQueue(n, shape_refiner, new_shapes));
2359 } else {
2360 // Rely on regular TF shape refinement for all the other nodes.
2361 // UpdateNode calls UpdateFunction if a function node is detected.
2362 TF_RETURN_IF_ERROR(shape_refiner->UpdateNode(n, new_shapes));
2363 }
2364
2365 return OkStatus();
2366 }
2367
2368 // Propagates the shapes in the transitive fan-out of <new_shapes>.
2369 Status GraphProperties::PropagateShapes(
2370 SymbolicShapeRefiner* shape_refiner, TopoQueue* new_shapes,
2371 const absl::flat_hash_map<const NodeDef*, const NodeDef*>& resource_handles,
2372 int num_loops) const {
2373 // Limit the number of iterations to prevent infinite loops in the presence of
2374 // incorrect shape functions. The algorithm should converge in at most
2375 // num_nested_loops^2 * max_rank. We approximate max_rank with the constant 4.
2376 // The same applies to resources.
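  // For example (illustrative arithmetic): a 1,000-node graph with 2 nested
  // loops allows at most 4 * 1000 * max(1, 2 * 2) = 16,000 loop iterations per
  // pass of the inner while-loop below.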
2377 VLOG(1) << "Propagating " << new_shapes->size() << " new shapes through "
2378 << num_loops << " loops and " << resource_handles.size()
2379 << " resources" << std::endl;
2380
2381 const int64_t max_loop_length = item_.graph.node_size();
2382 const int64_t max_rank = 4;
2383 const int64_t max_loop_iterations =
2384 max_rank * max_loop_length * std::max<int64_t>(1, num_loops * num_loops);
2385 const int64_t num_queues = resource_handles.size();
2386 const int64_t max_resource_iterations = num_queues * num_queues * max_rank;
2387
2388 int64_t num_resource_iterations = 0;
2389 do {
2390 int64_t num_loop_iterations = 0;
2391 while (!new_shapes->empty() &&
2392 num_loop_iterations++ < max_loop_iterations) {
2393 const NodeDef* n = new_shapes->pop();
2394 bool updated = false;
2395 TF_RETURN_IF_ERROR(
2396 UpdateShapes(shape_refiner, resource_handles, n, &updated));
2397 if (updated) {
2398 for (const auto& fanout : shape_refiner->graph().GetFanouts(
2399 *n, /*include_controlled_nodes=*/false)) {
2400 new_shapes->push(fanout.node);
2401 }
2402 // Make sure the corresponding queue nodes are (re)processed.
2403 if (IsEnqueue(*n)) {
2404 auto it = resource_handles.find(n);
2405 if (it != resource_handles.end()) {
2406 new_shapes->push(it->second);
2407 }
2408 }
2409 }
2410 }
2411 } while (!new_shapes->empty() &&
2412 num_resource_iterations++ < max_resource_iterations);
2413
2414 if (!new_shapes->empty()) {
2415 return errors::Internal("Shape inference failed to converge");
2416 }
2417
2418 return OkStatus();
2419 }
2420
2421 Status GraphProperties::UpdateQueue(const NodeDef* queue_node,
2422 SymbolicShapeRefiner* shape_refiner,
2423 bool* new_shapes) {
2424 auto* ctx = shape_refiner->GetNodeContext(queue_node);
2425 if (!ctx) {
2426 TF_RETURN_IF_ERROR(shape_refiner->AddNode(queue_node));
2427 ctx = CHECK_NOTNULL(shape_refiner->GetNodeContext(queue_node));
2428 }
2429 auto* ic = ctx->inference_context.get();
2430
2431 auto* outputs = ic->output_handle_shapes_and_types(0);
2432 if (outputs) {
2433 // Shapes and types are already set, presumably by Enqueue ops.
2434 return shape_refiner->UpdateNode(queue_node, new_shapes);
2435 }
2436
2437 if (queue_node->attr().count("shapes") <= 0 ||
2438 queue_node->attr().count("component_types") <= 0 ||
2439 queue_node->attr().at("shapes").list().shape_size() !=
2440 queue_node->attr().at("component_types").list().type_size()) {
2441 // Errors in shapes and component_types attr.
2442 return shape_refiner->UpdateNode(queue_node, new_shapes);
2443 }
2444
2445 // Extract types and shapes from Queue attr.
2446 const auto& shapes = queue_node->attr().at("shapes").list().shape();
2447 const auto& types = queue_node->attr().at("component_types").list().type();
2448 std::vector<ShapeAndType> shapes_and_types;
2449 for (int i = 0; i < types.size(); i++) {
2450 const auto& shape = shapes[i];
2451 ShapeHandle shape_handle;
2452 TF_RETURN_IF_ERROR(
2453 ic->MakeShapeFromPartialTensorShape(shape, &shape_handle));
2454 DataType data_type =
2455 queue_node->attr().at("component_types").list().type(i);
2456 ShapeAndType shape_and_type(shape_handle, data_type);
2457 shapes_and_types.push_back(shape_and_type);
2458 }
2459 ic->set_output_handle_shapes_and_types(0, shapes_and_types);
2460
2461   // The queue node is updated with output_handle_shapes_and_types, so set
2462   // new_shapes here and ignore the new_shapes result from UpdateNode().
2463 *new_shapes = true;
2464 bool dummy_new_shapes = false;
2465 return shape_refiner->UpdateNode(queue_node, &dummy_new_shapes);
2466 }
2467
2468 Status GraphProperties::UpdateEnqueue(
2469 const NodeDef* enqueue_node,
2470 const absl::flat_hash_map<const NodeDef*, const NodeDef*>& resource_handles,
2471 SymbolicShapeRefiner* shape_refiner, bool* new_shapes) {
2472 auto ctx = shape_refiner->GetNodeContext(enqueue_node);
2473 if (!ctx) {
2474 TF_RETURN_IF_ERROR(shape_refiner->AddNode(enqueue_node));
2475 ctx = CHECK_NOTNULL(shape_refiner->GetNodeContext(enqueue_node));
2476 }
2477
2478 auto it = resource_handles.find(enqueue_node);
2479 if (it == resource_handles.end()) {
2480 // The corresponding queue was not found, there isn't much we can do.
2481 return OkStatus();
2482 }
2483 const NodeDef* qnode = it->second;
2484 auto qctx = shape_refiner->GetContext(qnode);
2485 if (!qctx) {
2486 return OkStatus();
2487 }
2488 auto* queue_handle_data = qctx->output_handle_shapes_and_types(0);
2489
2490 // TODO(bsteiner): handle EnqueueMany as well.
2491 std::vector<ShapeAndType> shapes_and_types;
2492 for (int i = 1, end = ctx->input_types.size(); i < end; ++i) {
2493 GraphView::InputPort inp(enqueue_node, i);
2494 GraphView::OutputPort fanin = shape_refiner->graph().GetRegularFanin(inp);
2495 InferenceContext* in = shape_refiner->GetContext(fanin.node);
2496 ShapeHandle input = in->output(fanin.port_id);
2497 ctx->inference_context->SetInput(i, input);
2498 shapes_and_types.push_back({input, ctx->input_types[i]});
2499 }
2500
2501 if (queue_handle_data == nullptr) {
2502 qctx->set_output_handle_shapes_and_types(0, shapes_and_types);
2503 *new_shapes = true;
2504 } else {
2505 TF_RETURN_IF_ERROR(RelaxEnqueueShapesAndMergeTypes(
2506 shape_refiner, qnode, *queue_handle_data, &shapes_and_types));
2507 *new_shapes |= !shape_refiner->EquivalentShapesAndTypes(*queue_handle_data,
2508 shapes_and_types);
2509 qctx->set_output_handle_shapes_and_types(0, shapes_and_types);
2510 }
2511
2512 return OkStatus();
2513 }
2514
2515 Status GraphProperties::InferStatically(bool assume_valid_feeds,
2516 bool aggressive_shape_inference,
2517 bool include_input_tensor_values,
2518 bool include_output_tensor_values) {
2519 FunctionLibraryDefinition function_library(OpRegistry::Global(),
2520 item_.graph.library());
2521 absl::flat_hash_map<string, absl::flat_hash_set<int>> fed_ports;
2522 if (!assume_valid_feeds) {
2523 for (const auto& feed : item_.feed) {
2524 SafeTensorId tensor_id = ParseTensorName(feed.first);
2525 fed_ports[tensor_id.node()].insert(tensor_id.index());
2526 }
2527 }
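  // For example (sketch): a feed entry named "images:0" marks output port 0 of
  // node "images" as fed, so that port's shape is treated as completely
  // unknown unless assume_valid_feeds is set.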
2528
2529 GraphView graph_view(&item_.graph);
2530
2531 // List the resources and the nodes using them. Also collect the Merge nodes,
2532 // fed nodes, and primary inputs.
2533 absl::flat_hash_map<const NodeDef*,
2534 std::pair<absl::flat_hash_set<const NodeDef*>,
2535 absl::flat_hash_set<const NodeDef*>>>
2536 resources;
2537 absl::flat_hash_set<const NodeDef*> merge_nodes;
2538 absl::flat_hash_set<const NodeDef*> fed_nodes;
2539 absl::flat_hash_set<const NodeDef*> primary_inputs;
2540 int num_loops = 0;
2541 for (const NodeDef& node : item_.graph.node()) {
2542 if (IsQueue(node)) {
2543 for (const GraphView::InputPort& fanout :
2544 graph_view.GetFanouts(node, false)) {
2545 if (IsEnter(*fanout.node)) {
2546 const NodeDef& enter = *fanout.node;
2547 for (const GraphView::InputPort& fanout :
2548 graph_view.GetFanouts(enter, false)) {
2549 if (IsEnqueue(*fanout.node)) {
2550 resources[&node].first.insert(fanout.node);
2551 } else if (IsDequeue(*fanout.node)) {
2552 resources[&node].second.insert(fanout.node);
2553 }
2554 }
2555 } else {
2556 if (IsEnqueue(*fanout.node)) {
2557 resources[&node].first.insert(fanout.node);
2558 } else if (IsDequeue(*fanout.node)) {
2559 resources[&node].second.insert(fanout.node);
2560 }
2561 }
2562 }
2563 }
2564 if (!HasRegularInputs(node)) {
2565 primary_inputs.insert(&node);
2566 } else if (IsMerge(node)) {
2567 merge_nodes.insert(&node);
2568 } else if (IsNextIteration(node)) {
2569 ++num_loops;
2570 }
2571 if (fed_ports.find(node.name()) != fed_ports.end()) {
2572 fed_nodes.insert(&node);
2573 }
2574 }
2575
2576 absl::flat_hash_map<const NodeDef*, const NodeDef*> resource_handles;
2577 std::vector<TopologicalDependency> extra_deps;
2578 for (const auto& resource : resources) {
2579 for (const NodeDef* src : resource.second.first) {
2580 resource_handles[src] = resource.first;
2581 for (const NodeDef* dst : resource.second.second) {
2582 // Add control edges from enqueue to dequeue nodes to ensure they are
2583 // processed in their logical order.
2584 extra_deps.emplace_back(src, dst);
2585 }
2586 }
2587 }
2588
2589 std::vector<const NodeDef*> topo_order;
2590 Status s = ComputeTopologicalOrder(item_.graph, extra_deps, &topo_order);
2591 if (!s.ok()) {
2592 if (extra_deps.empty()) {
2593 return s;
2594 } else {
2595 // There is a loop between queues: we'll just use the graph topological
2596 // order. This will make the shape inference less precise but since this
2597       // isn't common it's not worth figuring out where to break the loop and
2598       // doing a proper relaxation.
2599 TF_RETURN_IF_ERROR(ComputeTopologicalOrder(item_.graph, &topo_order));
2600 }
2601 }
2602
2603 // Heap-allocate SymbolicShapeRefiner in order to not consume a large amount
2604 // of stack space.
2605 auto refiner = std::make_unique<SymbolicShapeRefiner>(
2606 graph_view, fed_ports, aggressive_shape_inference);
2607
2608 TopoQueue new_shapes(topo_order);
2609 // Also seed the propagation of shapes in the fanout of primary inputs.
2610 for (const NodeDef* node : primary_inputs) {
2611 new_shapes.push(node);
2612 }
2613 // Also seed the propagation of shapes in the fanout of fed nodes.
2614 for (const NodeDef* node : fed_nodes) {
2615 new_shapes.push(node);
2616 }
2617 // Propagate shapes normally.
2618 TF_RETURN_IF_ERROR(
2619 PropagateShapes(refiner.get(), &new_shapes, resource_handles, num_loops));
2620
2621 // Track shapes globally across the graph.
2622 std::unique_ptr<SymbolicShapeManager> shape_manager =
2623 std::make_unique<SymbolicShapeManager>();
2624 bool found_error = false;
2625 for (const NodeDef& node : item_.graph.node()) {
2626 auto node_ctx = refiner->GetContext(&node);
2627 if (!node_ctx) {
2628 continue;
2629 }
2630 // Skip any information that comes from fed nodes.
2631 if (fed_ports.find(node.name()) != fed_ports.end()) {
2632 VLOG(2) << "Skipping feed node shape: " << node.name();
2633 continue;
2634 }
2635 for (const auto& merged_shapes : node_ctx->MergedShapes()) {
2636 if (!shape_manager->Merge(merged_shapes.first, merged_shapes.second)
2637 .ok()) {
2638 found_error = true;
2639 break;
2640 }
2641 }
2642 for (const auto& merged_dims : node_ctx->MergedDims()) {
2643 if (!shape_manager->Merge(merged_dims.first, merged_dims.second).ok()) {
2644 found_error = true;
2645 break;
2646 }
2647 }
2648 if (found_error) {
2649 // The shapes aren't consistent, we can't infer safely: discard all the
2650 // information discovered so far.
2651 shape_manager = std::make_unique<SymbolicShapeManager>();
2652 break;
2653 }
2654 }
2655
2656 TF_RETURN_IF_ERROR(ValidateSymbolicShapeManager(item_.graph, refiner.get(),
2657 shape_manager.get()));
2658
2659 for (const NodeDef& node : item_.graph.node()) {
2660 VLOG(4) << "Filling in graph properties for node: " << node.name();
2661 auto ctx = refiner->GetNodeContext(&node);
2662 if (!ctx) {
2663 continue;
2664 }
2665
2666 auto* ic = ctx->inference_context.get();
2667
2668 // Fill input properties.
2669 {
2670 auto& input_properties = input_properties_[node.name()];
2671
2672       // Should always be empty; node names in the graph are supposed to be unique.
2673 CHECK_EQ(input_properties.size(), 0);
2674
2675 input_properties.resize(ic->num_inputs());
2676 GraphView::InputPort input(&node, -1);
2677 for (int i = 0; i < ic->num_inputs(); ++i) {
2678 shape_manager->AsTensorProperties(ic->input(i), ctx->input_types[i],
2679 &input_properties[i]);
2680 input.port_id = i;
2681 GraphView::OutputPort fanin = graph_view.GetRegularFanin(input);
2682 if (include_input_tensor_values) {
2683 // Export tensor value to input_properties.value.
2684 if (IsConstant(*fanin.node)) {
2685 const TensorProto& raw_val =
2686 fanin.node->attr().at("value").tensor();
2687 *input_properties[i].mutable_value() = raw_val;
2688 } else if (static_cast<int>(ctx->input_tensor_protos.size()) > i &&
2689 ctx->input_tensor_protos[i] != nullptr) {
2690 *input_properties[i].mutable_value() = *ctx->input_tensor_protos[i];
2691 } else if (static_cast<int>(ic->input_tensors_as_shapes().size()) >
2692 i &&
2693 IsShapeFullyDefinedIntegerVectorOrScalar(
2694 ic, ic->input(i), ic->input_tensors_as_shapes()[i],
2695 ctx->input_types[i])) {
2696 *input_properties[i].mutable_value() = MakeTensorProtoFromShape(
2697 ic, ic->input(i), ic->input_tensors_as_shapes()[i],
2698 ctx->input_types[i]);
2699 }
2700 }
2701 }
2702 }
2703
2704 // Fill output properties.
2705 {
2706 auto& output_properties = output_properties_[node.name()];
2707
2708       // Should always be empty; node names in the graph are supposed to be unique.
2709 CHECK_EQ(output_properties.size(), 0);
2710
2711 output_properties.resize(ic->num_outputs());
2712 for (int i = 0; i < ic->num_outputs(); ++i) {
2713 shape_manager->AsTensorProperties(ic->output(i), ctx->output_types[i],
2714 &output_properties[i]);
2715 auto converted_output_tensors_as_shapes =
2716 ReplaceUnknownDimFromConstWithUnknownDim(
2717 ic, ctx->output_tensors_as_shapes);
2718 if (include_output_tensor_values) {
2719 // Export tensor value to output_properties.value.
2720 if (IsConstant(node)) {
2721 // TODO(rmlarsen): Eliminate this copy.
2722 const TensorProto& raw_val = node.attr().at("value").tensor();
2723 *output_properties[i].mutable_value() = raw_val;
2724 } else if (static_cast<int>(ctx->output_tensor_protos.size()) > i &&
2725 ctx->output_tensor_protos[i] != nullptr) {
2726 *output_properties[i].mutable_value() =
2727 *ctx->output_tensor_protos[i];
2728 } else if (static_cast<int>(
2729 converted_output_tensors_as_shapes.size()) > i &&
2730 IsShapeFullyDefinedIntegerVectorOrScalar(
2731 ic, ic->output(i),
2732 converted_output_tensors_as_shapes[i],
2733 ctx->output_types[i])) {
2734 *output_properties[i].mutable_value() = MakeTensorProtoFromShape(
2735 ic, ic->output(i), converted_output_tensors_as_shapes[i],
2736 ctx->output_types[i]);
2737 }
2738 }
2739 }
2740 }
2741
2742 if (aggressive_shape_inference && ctx->shape_incompatible)
2743 incompatible_shape_nodes_.insert(node.name());
2744 }
2745
2746 if (aggressive_shape_inference && !incompatible_shape_nodes_.empty())
2747 LOG(WARNING) << incompatible_shape_nodes_.size()
2748 << " nodes have incompatible output shapes.";
2749
2750 // Help trace the unknown dimensions to their origins.
2751 VerboseLogUnknownDimensionSources(item_.graph, input_properties_,
2752 output_properties_);
2753
2754 TF_RETURN_IF_ERROR(VerboseShapeInferenceLogging(item_.graph, refiner.get(),
2755 shape_manager.get()));
2756
2757 return OkStatus();
2758 }
2759
2760 Status GraphProperties::InferDynamically(Cluster* cluster) {
2761 TF_RETURN_IF_ERROR(cluster->Initialize(item_));
2762
2763 // Runs the model once to collect the shapes in the cost model.
2764 RunMetadata metadata;
2765 TF_RETURN_IF_ERROR(
2766 cluster->Run(item_.graph, item_.feed, item_.fetch, &metadata));
2767
2768 return InferFromCostGraph(metadata.cost_graph());
2769 }
2770
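// Writes each node's inferred output shapes into its "_output_shapes" attr.
// Possible usage (illustrative sketch; `item` is an assumed GrapplerItem):
//
//   GraphProperties properties(item);
//   TF_RETURN_IF_ERROR(properties.InferStatically(
//       /*assume_valid_feeds=*/false, /*aggressive_shape_inference=*/false,
//       /*include_input_tensor_values=*/true,
//       /*include_output_tensor_values=*/true));
//   GraphDef annotated;
//   TF_RETURN_IF_ERROR(properties.AnnotateOutputShapes(&annotated));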
2771 Status GraphProperties::AnnotateOutputShapes(GraphDef* output_graph_def) const {
2772 *output_graph_def = item_.graph;
2773 for (int i = 0; i < output_graph_def->node_size(); i++) {
2774 auto node = output_graph_def->mutable_node(i);
2775 AttrValue attr_output_shape;
2776 auto tensor_properties = GetOutputProperties(node->name());
2777 for (const auto& tensor_property : tensor_properties) {
2778 TensorShapeProto* proto = attr_output_shape.mutable_list()->add_shape();
2779 *proto = tensor_property.shape();
2780 NormalizeShapeForOutput(proto);
2781 }
2782 (*node->mutable_attr())["_output_shapes"] = std::move(attr_output_shape);
2783 }
2784 return OkStatus();
2785 }
2786
2787 Status GraphProperties::InferFromCostGraph(const CostGraphDef& cost_graph) {
2788 if (cost_graph.node_size() == 0) {
2789 LOG(WARNING) << "cost_graph is empty: nothing can be inferred!";
2790 }
2791 std::unordered_map<string, const CostGraphDef::Node*> name_to_cost;
2792 std::unordered_map<string, const NodeDef*> name_to_node; // Empty
2793 for (auto& node : cost_graph.node()) {
2794 name_to_cost[node.name()] = &node;
2795
2796 std::vector<OpInfo::TensorProperties> output_properties;
2797 for (const auto& out : node.output_info()) {
2798 OpInfo::TensorProperties properties;
2799 properties.set_dtype(out.dtype());
2800 *properties.mutable_shape() = out.shape();
2801 output_properties.push_back(properties);
2802 }
2803 output_properties_[node.name()] = output_properties;
2804 }
2805
2806 for (const auto& node : item_.graph.node()) {
2807 // Skip the nodes that are not in the cost graph: these are nodes that
2808 // aren't run, because they aren't in the intersection of transitive fan-in
2809 // of a fetch node and the transitive fan-out of an input, or nodes that
2810 // were optimized away by the optimizer.
2811 auto it = name_to_cost.find(node.name());
2812 if (it == name_to_cost.end()) {
2813 continue;
2814 }
2815 std::vector<OpInfo::TensorProperties> inputs =
2816 FindInputFeatures(node, name_to_cost, name_to_node);
2817
2818 input_properties_[node.name()] = inputs;
2819 }
2820 return OkStatus();
2821 }
2822
2823 bool GraphProperties::HasInputProperties(const string& node_name) const {
2824 return input_properties_.find(node_name) != input_properties_.end();
2825 }
2826
2827 bool GraphProperties::HasOutputProperties(const string& node_name) const {
2828 return output_properties_.find(node_name) != output_properties_.end();
2829 }
2830
2831 const std::vector<OpInfo::TensorProperties>&
2832 GraphProperties::GetInputProperties(const string& node_name) const {
2833 auto it = input_properties_.find(node_name);
2834 if (it != input_properties_.end()) {
2835 return it->second;
2836 }
2837 return missing_properties_;
2838 }
2839
2840 const std::vector<OpInfo::TensorProperties>&
2841 GraphProperties::GetOutputProperties(const string& node_name) const {
2842 auto it = output_properties_.find(node_name);
2843 if (it != output_properties_.end()) {
2844 return it->second;
2845 }
2846 return missing_properties_;
2847 }
2848
2849 void GraphProperties::ClearInputProperties(const string& node_name) {
2850 input_properties_.erase(node_name);
2851 }
2852 void GraphProperties::ClearOutputProperties(const string& node_name) {
2853 output_properties_.erase(node_name);
2854 }
2855
2856 } // end namespace grappler
2857 } // end namespace tensorflow
2858