xref: /aosp_15_r20/external/tensorflow/tensorflow/core/grappler/costs/graph_properties.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/grappler/costs/graph_properties.h"
17 
18 #include "absl/types/optional.h"
19 #include "tensorflow/core/common_runtime/function.h"
20 #include "tensorflow/core/common_runtime/graph_constructor.h"
21 #include "tensorflow/core/framework/common_shape_fns.h"
22 #include "tensorflow/core/framework/function.pb.h"
23 #include "tensorflow/core/framework/node_def_util.h"
24 #include "tensorflow/core/framework/tensor.pb.h"
25 #include "tensorflow/core/framework/tensor_shape.pb.h"
26 #include "tensorflow/core/framework/types.h"
27 #include "tensorflow/core/framework/types.pb.h"
28 #include "tensorflow/core/framework/versions.pb.h"
29 #include "tensorflow/core/graph/tensor_id.h"
30 #include "tensorflow/core/grappler/costs/utils.h"
31 #include "tensorflow/core/grappler/mutable_graph_view.h"
32 #include "tensorflow/core/grappler/op_types.h"
33 #include "tensorflow/core/grappler/optimizers/evaluation_utils.h"
34 #include "tensorflow/core/grappler/utils.h"
35 #include "tensorflow/core/grappler/utils/functions.h"
36 #include "tensorflow/core/grappler/utils/topological_sort.h"
37 #include "tensorflow/core/lib/gtl/cleanup.h"
38 #include "tensorflow/core/lib/gtl/flatset.h"
39 #include "tensorflow/core/lib/strings/str_util.h"
40 
41 namespace tensorflow {
42 namespace grappler {
43 
44 namespace {
45 
46 using shape_inference::DimensionHandle;
47 using shape_inference::InferenceContext;
48 using shape_inference::ShapeAndType;
49 using shape_inference::ShapeHandle;
50 using TensorVector = gtl::InlinedVector<TensorValue, 4>;
51 
52 // A large sentinel value used in place of UnknownDim for a dim value that
53 // comes from a Const. Some ops treat "-1" specially, differently from
54 // UnknownDim: e.g., the shape input to the Reshape op.
55 const int64_t kUnknownDimFromConst = INT64_MAX;
56 
57 // Skip const value instantiation if the number of elements in a const tensor
58 // is greater than this threshold.
59 const int kThresholdToSkipConstTensorInstantiation = 128;
60 
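// Hash and equality functors so that shape inference handles (ShapeHandle,
// DimensionHandle) can be used as keys of the DisjointSet map defined below.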
61 template <typename Handle>
62 struct HashHandle {
63   std::size_t operator()(const Handle& h) const { return h.Handle(); }
64 };
65 template <typename Handle>
66 struct CompareHandle {
67   bool operator()(const Handle& h1, const Handle& h2) const {
68     return h1.SameHandle(h2);
69   }
70 };
71 
72 template <typename Handle>
73 struct HandleToObject {};
74 template <>
75 struct HandleToObject<ShapeHandle> {
76   typedef ShapeHandle Object;
77 
78   static ShapeHandle Unknown() { return ShapeHandle(); }
79 };
80 
81 template <>
82 struct HandleToObject<DimensionHandle> {
83   typedef int64_t Object;
84 
85   static int64_t Unknown() { return -1; }
86 };
87 
88 template <typename Handle>
89 struct Processor {};
90 
91 template <>
92 struct Processor<ShapeHandle> {
93   // Extract the shape or dim denoted by the handle.
94   void ExtractValue(ShapeHandle h, ShapeHandle* result) { *result = h; }
95   // Merge the shapes or dims.
96   Status Merge(ShapeHandle h1, ShapeHandle h2, ShapeHandle* result) {
97     if (InferenceContext::RankKnown(*result)) {
98       // The result was initialized in a previous merge to a shape of known
99       // rank, make sure we preserve that information.
100       return OkStatus();
101     }
102     if (InferenceContext::RankKnown(h1)) {
103       *result = h1;
104     } else {
105       *result = h2;
106     }
107     return OkStatus();
108   }
109 };
110 
111 template <>
112 struct Processor<DimensionHandle> {
113   // Assign a negative id to unknown dimensions, starting at -2 (the -1 id
114   // reserved by TensorFlow).
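  // For example, the first two unknown dimensions that are extracted receive
  // the symbolic ids -2 and -3; DisjointSet::Find() below calls ExtractValue()
  // only once per handle, so equal handles share one id.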
115   void ExtractValue(DimensionHandle d, int64_t* result) {
116     if (!InferenceContext::ValueKnown(d)) {
117       *result = -counter;
118       counter++;
119     } else {
120       int64_t val = InferenceContext::Value(d);
121       if (val >= 0) {
122         *result = val;
123       } else {
124         // A shape inference function generated an invalid dimension handle.
125         // Use a symbolic dimension to encode this.
126         *result = -counter;
127         counter++;
128       }
129     }
130   }
131 
132   // Merge the dimensions d1 and d2. Return the known dim if there is one,
133   // otherwise look for a symbolic dim. If there is no symbolic dim and no
134   // known dim, the dim is fully unknown, so return -1.
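  // For example, merging dims with values 5 and -1 refines the result to 5;
  // merging -2 (symbolic) with -1 keeps -2; merging -1 with -1 leaves the
  // result fully unknown (-1).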
135   Status Merge(DimensionHandle d1, DimensionHandle d2, int64_t* result) {
136     const int64_t dim1 = InferenceContext::Value(d1);
137     const int64_t dim2 = InferenceContext::Value(d2);
138 
139     if (dim1 >= 0 && dim2 >= 0) {
140       CHECK_EQ(dim1, dim2);
141       return RefineDim(dim1, result);
142     } else if (dim1 >= 0 && dim2 < 0) {
143       return RefineDim(dim1, result);
144     } else if (dim1 < 0 && dim2 >= 0) {
145       return RefineDim(dim2, result);
146     } else if (dim1 < -1) {
147       return RefineDim(dim1, result);
148     } else if (dim2 < -1) {
149       return RefineDim(dim2, result);
150     } else {
151       CHECK_EQ(dim1, dim2);
152       CHECK_EQ(-1, dim1);
153       return RefineDim(-1, result);
154     }
155     return OkStatus();
156   }
157 
158  private:
159   Status RefineDim(int64_t dim, int64_t* result) {
160     if (*result >= 0) {
161       if (!(*result == dim || dim < 0)) {
162         return errors::InvalidArgument("Inconsistent dimensions detected");
163       }
164     } else if (dim >= 0) {
165       *result = dim;
166     } else if (dim < *result) {
167       *result = dim;
168     }
169     return OkStatus();
170   }
171 
172   int64_t counter = 2;
173 };
174 
175 // Traditional Disjoint-Set data structure with path compression.
176 // (https://en.wikipedia.org/wiki/Disjoint-set_data_structure)
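// A typical use, roughly: handles known to denote the same shape or dimension
// are Merge()d, and GetMergedValue() later returns one canonical value per
// equivalence class (or Unknown() for a handle that was never seen).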
177 template <typename Handle>
178 class DisjointSet {
179  public:
180   DisjointSet() {}
181   ~DisjointSet() {
182     for (auto rep : nodes_) {
183       delete rep.second;
184     }
185   }
186 
187   Status Merge(Handle x, Handle y);
188   const typename HandleToObject<Handle>::Object GetMergedValue(Handle value);
189 
190  private:
191   // All the handles that belong to the same set are part of the same tree, and
192   // ultimately represented by the root of that tree.
193   struct Rep {
194     // Parent in the tree used to encode the set.
195     Rep* parent;
196     // Rank of the tree, used to decide which root becomes the parent when
197     // two sets are merged (see Merge()).
198     int rank;
199     // The handle.
200     typename HandleToObject<Handle>::Object value;
201   };
202 
203   // Create a new set for the value if none exists, or return its representative
204   // node otherwise.
205   Rep* Find(Handle value);
206 
207  private:
208   Processor<Handle> processor_;
209   absl::flat_hash_map<Handle, Rep*, HashHandle<Handle>, CompareHandle<Handle>>
210       nodes_;
211 };
212 
213 template <typename Handle>
214 const typename HandleToObject<Handle>::Object
215 DisjointSet<Handle>::GetMergedValue(Handle value) {
216   Rep* rep = Find(value);
217   if (!rep) {
218     // We don't know anything about this handle.
219     return HandleToObject<Handle>::Unknown();
220   }
221   return rep->value;
222 }
223 
224 template <typename Handle>
225 Status DisjointSet<Handle>::Merge(Handle x, Handle y) {
226   Rep* x_root = Find(x);
227   Rep* y_root = Find(y);
228 
229   // x and y are already in the same set
230   if (x_root == y_root) {
231     return OkStatus();
232   }
233   // x and y are not in same set, so we merge them
234   // Use the occasion to strengthen what we know about the handle by merging the
235   // information about the 2 subsets.
236   if (x_root->rank < y_root->rank) {
237     TF_RETURN_IF_ERROR(processor_.Merge(y, x, &y_root->value));
238     x_root->parent = y_root;
239   } else if (x_root->rank > y_root->rank) {
240     TF_RETURN_IF_ERROR(processor_.Merge(x, y, &x_root->value));
241     y_root->parent = x_root;
242   } else {
243     TF_RETURN_IF_ERROR(processor_.Merge(x, y, &x_root->value));
244     // Arbitrarily make one root the new parent
245     y_root->parent = x_root;
246     x_root->rank = x_root->rank + 1;
247   }
248   return OkStatus();
249 }
250 
251 template <typename Handle>
252 typename DisjointSet<Handle>::Rep* DisjointSet<Handle>::Find(Handle value) {
253   auto it = nodes_.find(value);
254   if (it == nodes_.end()) {
255     // This is the first time we process this handle, create an entry for it.
256     Rep* node = new Rep;
257     node->parent = node;
258     node->rank = 0;
259     processor_.ExtractValue(value, &node->value);
260     nodes_[value] = node;
261     return node;
262   }
263   // Return the representative for the set, which is the root of the tree. Apply
264   // path compression to speed up future queries.
265   Rep* node = it->second;
266   Rep* root = node->parent;
267   while (root != root->parent) {
268     root = root->parent;
269   }
270   while (node->parent != root) {
271     Rep* next = node->parent;
272     node->parent = root;
273     node = next;
274   }
275   return root;
276 }
277 
278 // TODO(dyoon): Move many helper functions in this file (including those within
279 // SymbolicShapeRefiner class) to shared utils.
280 bool IsEnqueue(const NodeDef& n) {
281   return (n.op().find("Enqueue") != string::npos &&
282           n.op().find("EnqueueMany") == string::npos);
283 }
284 
285 bool IsDequeue(const NodeDef& n) {
286   return (n.op().find("Dequeue") != string::npos &&
287           n.op().find("DequeueMany") == string::npos);
288 }
289 
290 bool HasAnyUnknownDimensions(const TensorShapeProto& proto) {
291   if (proto.unknown_rank()) {
292     return true;
293   }
294   for (const auto& dim : proto.dim()) {
295     if (dim.size() < 0) {
296       return true;
297     }
298   }
299   return false;
300 }
301 
302 // This really should be done in an external debugging tool
303 void VerboseLogUnknownDimensionSources(
304     const GraphDef& graph,
305     const absl::flat_hash_map<string, std::vector<OpInfo::TensorProperties>>&
306         input_properties_map,
307     const absl::flat_hash_map<string, std::vector<OpInfo::TensorProperties>>&
308         output_properties_map) {
309   if (!VLOG_IS_ON(2)) {
310     return;
311   }
312 
313   VLOG(2) << "Nodes with known inputs, but with unknown output dimensions:";
314 
315   // Find all nodes in the graph for which we
316   // do not have any unknown dimensions in their inputs, but
317   // we have some unknown dimensions in their outputs.
318   std::map<string, int> op_to_count;
319   for (const NodeDef& node : graph.node()) {
320     const auto& input_properties = input_properties_map.at(node.name());
321     const auto& output_properties = output_properties_map.at(node.name());
322 
323     bool has_unknown_inputs = false;
324     for (const auto& input_prop : input_properties) {
325       if (HasAnyUnknownDimensions(input_prop.shape())) {
326         has_unknown_inputs = true;
327         break;
328       }
329     }
330 
331     if (has_unknown_inputs) {
332       continue;
333     }
334 
335     for (const auto& output_prop : output_properties) {
336       if (HasAnyUnknownDimensions(output_prop.shape())) {
337         string inputs = "input_shapes=[";
338         for (const auto& input_prop : input_properties) {
339           inputs += PartialTensorShape::DebugString(input_prop.shape());
340         }
341         inputs += "]";
342 
343         string outputs = "output_shapes=[";
344         for (const auto& output_prop : output_properties) {
345           outputs += PartialTensorShape::DebugString(output_prop.shape());
346         }
347         outputs += "]";
348 
349         VLOG(2) << "Node: " << node.name() << ", Op: " << node.op() << ", "
350                 << inputs << ", " << outputs;
351 
352         op_to_count[node.op()]++;
353 
354         // don't log again for this node
355         break;
356       }
357     }
358   }
359   VLOG(2) << "Op types with known inputs, but with unknown output dimensions "
360           << "(format: <op_type> (<count>)):";
361   for (const auto& p : op_to_count) {
362     VLOG(2) << p.first << " (" << p.second << ")";
363   }
364 }
365 
366 // Helper function to convert kUnknownDimFromConst into UnknownDim.
367 std::vector<ShapeHandle> ReplaceUnknownDimFromConstWithUnknownDim(
368     InferenceContext* ic, const std::vector<ShapeHandle>& shapes) {
369   std::vector<ShapeHandle> converted_shapes(shapes.size());
370   for (int i = 0, shapes_size = shapes.size(); i < shapes_size; i++) {
371     const auto& shape = shapes[i];
372     if (!ic->RankKnown(shape)) {
373       converted_shapes[i] = shape;
374       continue;
375     }
376     bool just_copy = true;
377     std::vector<DimensionHandle> dims;
378     for (int32_t i = 0; i < ic->Rank(shape); ++i) {
379       DimensionHandle dim = ic->Dim(shape, i);
380       if (ic->ValueKnown(dim) && ic->Value(dim) == kUnknownDimFromConst) {
381         just_copy = false;
382         dims.push_back(ic->UnknownDim());
383       } else {
384         dims.push_back(dim);
385       }
386     }
387     if (just_copy) {
388       converted_shapes[i] = shape;
389       continue;
390     }
391     converted_shapes[i] = ic->MakeShape(dims);
392   }
393   return converted_shapes;
394 }
395 
396 // The returned tensor has the same rank as `shape` (scalar or 1-D), and its
397 // values and dtype come from `tensor_as_shape` and `dtype`.
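// For example, if `tensor_as_shape` is [10, 20, 30] and `shape` has rank 1,
// the result is a 1-D tensor proto of size 3 holding the values {10, 20, 30}.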
398 TensorProto MakeTensorProtoFromShape(InferenceContext* ic,
399                                      const ShapeHandle& shape,
400                                      const ShapeHandle& tensor_as_shape,
401                                      const DataType& dtype) {
402   TensorProto tensor_proto;
403   tensor_proto.set_dtype(dtype);
404   auto* shape_proto = tensor_proto.mutable_tensor_shape();
405   if (ic->Rank(shape) == 1) {
406     shape_proto->add_dim()->set_size(ic->Rank(tensor_as_shape));
407   }
408   // For a scalar tensor, tensor_shape field will be left empty; no dim.
409   for (int i = 0; i < ic->Rank(tensor_as_shape); i++) {
410     int64_t value = ic->Value(ic->Dim(tensor_as_shape, i));
411     if (dtype == DT_INT32) {
412       tensor_proto.add_int_val(value);
413     } else {
414       tensor_proto.add_int64_val(value);
415     }
416   }
417   return tensor_proto;
418 }
419 
420 // Returns a Const NodeDef with tensor `tensor_proto` and dtype = `dtype`.
421 NodeDef MakeConstNodeDefFromTensorProto(InferenceContext* ic,
422                                         const TensorProto& tensor_proto,
423                                         const DataType& dtype) {
424   NodeDef const_node;
425   const_node.set_name("const_from_shape");
426   const_node.set_op("Const");
427   auto* attr = const_node.mutable_attr();
428   (*attr)["dtype"].set_type(dtype);
429   auto* tensor = (*attr)["value"].mutable_tensor();
430   *tensor = tensor_proto;
431   return const_node;
432 }
433 
434 // Returns a Const NodeDef with shape = `shape`, values = `tensor_as_shape`,
435 // and dtype = `dtype`.
436 NodeDef MakeConstNodeDefFromShape(InferenceContext* ic,
437                                   const ShapeHandle& shape,
438                                   const ShapeHandle& tensor_as_shape,
439                                   const DataType& dtype) {
440   return MakeConstNodeDefFromTensorProto(
441       ic, MakeTensorProtoFromShape(ic, shape, tensor_as_shape, dtype), dtype);
442 }
443 
444 bool IsNumericType(const DataType dtype) {
445   static const gtl::FlatSet<DataType>* const kRealNumberTypes =
446       CHECK_NOTNULL((new gtl::FlatSet<DataType>{
447           // Floating point.
448           DT_BFLOAT16,
449           DT_HALF,
450           DT_FLOAT,
451           DT_DOUBLE,
452           // Int / UInt.
453           DT_INT8,
454           DT_INT16,
455           DT_INT32,
456           DT_INT64,
457           DT_UINT8,
458           DT_UINT16,
459           DT_UINT32,
460           DT_UINT64,
461           // Quantized Int.
462           DT_QINT8,
463           DT_QUINT8,
464           DT_QINT16,
465           DT_QUINT16,
466           DT_QINT32,
467           // Bool.
468           DT_BOOL,
469       }));
470   return kRealNumberTypes->find(dtype) != kRealNumberTypes->end();
471 }
472 
473 // Returns the number of elements in the input (const) tensor.
474 // Returns -1 if the tensor has no shape or its rank is unknown.
475 uint64 NumElementsFromTensorProto(const TensorProto& tensor_proto) {
476   if (!tensor_proto.has_tensor_shape()) {
477     return -1;
478   }
479   const auto& tensor_shape_proto = tensor_proto.tensor_shape();
480   if (tensor_shape_proto.unknown_rank()) {
481     return -1;
482   }
483   int64_t num_elements = 1;
484   for (const auto& dim : tensor_shape_proto.dim()) {
485     // Note that in some cases, dim.size() can be zero (e.g., empty vector).
486     num_elements *= dim.size();
487   }
488   return num_elements;
489 }
490 
491 }  // namespace
492 
493 // Note that the tensor_as_shape input should not include kUnknownDimFromConst.
494 // This function checks for kUnknownDimFromConst, but only logs a WARNING.
495 // When checking input_tensors_as_shapes_to_propagate or
496 // output_tensors_as_shapes, which may include kUnknownDimFromConst, convert
497 // them with ReplaceUnknownDimFromConstWithUnknownDim() first.
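// For example, shape = [3], tensor_as_shape = [10, 20, 30], dtype = DT_INT32
// returns true; an unknown rank, a rank > 1, an unknown dim, or a non-integer
// dtype returns false.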
498 bool IsShapeFullyDefinedIntegerVectorOrScalar(
499     InferenceContext* ic, const ShapeHandle& shape,
500     const ShapeHandle& tensor_as_shape, const DataType& dtype) {
501   if (!ic->FullyDefined(shape) || ic->Rank(shape) > 1 ||
502       !ic->FullyDefined(tensor_as_shape) ||
503       (dtype != DT_INT32 && dtype != DT_INT64)) {
504     return false;
505   }
506   // Also check whether any dim in tensor_as_shape is kUnknownDimFromConst.
507   for (int32_t i = 0; i < ic->Rank(tensor_as_shape); ++i) {
508     DimensionHandle dim = ic->Dim(tensor_as_shape, i);
509     if (ic->Value(dim) == kUnknownDimFromConst) {
510       LOG(WARNING) << "IsShapeFullyDefinedIntegerVectorOrScalar(): "
511                    << "tensor_as_shape input includes kUnknownDimFromConst -- "
512                    << ic->DebugString(tensor_as_shape);
513       return false;
514     }
515   }
516   return true;
517 }
518 
519 // Queue of nodes to process. Nodes can be enqueued in any order, but will be
520 // dequeued in (roughly) topological order. Propagating shapes following a
521 // topological ordering isn't required for correctness but helps speed things up
522 // since it avoids processing the same node multiple times as its inputs'
523 // information is refined.
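// For example, if nodes A, B and C have topological ids 0, 1 and 2, pushing
// them in the order C, A, B still pops A, then B, then C.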
524 class TopoQueue {
525  public:
526   explicit TopoQueue(const std::vector<const NodeDef*>& topo_order)
527       : topo_order_(TopoOrder(topo_order)) {}
528 
529   void push(const NodeDef* n) { queue_.emplace(n, topo_order_.at(n)); }
530 
531   const NodeDef* pop() {
532     CHECK(!empty());
533     auto it = queue_.begin();
534     const NodeDef* n = it->first;
535     queue_.erase(it);
536     return n;
537   }
538 
539   bool empty() const { return queue_.empty(); }
540   std::size_t size() const { return queue_.size(); }
541 
542  private:
543   using NodeAndId = std::pair<const NodeDef*, int>;
544   // Graph nodes are created in (roughly) topological order. Therefore we can
545   // use their id to ensure they're sorted topologically.
546   struct OrderByIdAscending {
547     bool operator()(const NodeAndId& lhs, const NodeAndId& rhs) const {
548       return lhs.second < rhs.second;
549     }
550   };
551 
552   const absl::flat_hash_map<const NodeDef*, int> TopoOrder(
553       const std::vector<const NodeDef*>& topo_order) const {
554     absl::flat_hash_map<const NodeDef*, int> map;
555     map.reserve(topo_order.size());
556     for (int i = 0, topo_order_size = topo_order.size(); i < topo_order_size;
557          ++i) {
558       map.emplace(topo_order[i], i);
559     }
560     return map;
561   }
562 
563   const absl::flat_hash_map<const NodeDef*, int> topo_order_;
564   std::set<NodeAndId, OrderByIdAscending> queue_;
565 };
566 
567 
568 bool IsAllowListedOpTypeForEvaluateNode(const string& op_type) {
569   static const gtl::FlatSet<string>* const kOpTpeAllowlist =
570       CHECK_NOTNULL((new gtl::FlatSet<string>{
571           // Unary arithmetic ops
572           "Floor",
573           "Round",
574           "Sqrt",
575           "Square",
576           "Sign",
577           // Binary arithmetic ops
578           "Add",
579           "AddV2",
580           "Div",
581           "FloorDiv",
582           "FloorMod",
583           "Greater",
584           "GreaterEqual",
585           "Less",
586           "LessEqual",
587           "LogicalAnd",
588           "LogicalNot",
589           "LogicalOr",
590           "Maximum",
591           "Minimum",
592           "Mod",
593           "Mul",
594           "NotEqual",
595           "QuantizedAdd",
596           "QuantizedMul",
597           "SquareDifference",
598           "Sub",
599           "TruncateDiv",
600           "TruncateMod",
601           "RealDiv",
602           // N-ary arithmetic ops
603           "AddN",
604           // Others
605           "StridedSlice",
606           "OnesLike",
607           "ZerosLike",
608           "Concat",
609           "ConcatV2",
610           "Split",
611           "Range",
612           "Fill",
613           "Cast",
614           "Prod",
615           "Unpack",
616           "GatherV2",
617           "Pack",
618           // Used in batch_gather_nd: tensorflow/python/ops/array_ops.py
619           "ExpandDims",
620       }));
621   return kOpTpeAllowlist->find(op_type) != kOpTpeAllowlist->end();
622 }
623 
624 // A shape size of -1 represents an unknown dimension, while sizes less than
625 // -1 represent unknown symbolic dimensions (e.g. the shape [-5, 5, -1, -5]
626 // really means [x, 5, ?, x]). Before we can output the tensors as shapes, we
627 // need to normalize them: mark all values < -1 as "unknown" (-1).
628 static void NormalizeShapeForOutput(TensorShapeProto* shape) {
629   for (int i = 0; i < shape->dim_size(); i++) {
630     if (shape->dim(i).size() < -1) {
631       VLOG(2) << "Normalizing dimension: " << i << " from "
632               << shape->dim(i).size() << " to -1";
633       shape->mutable_dim(i)->set_size(-1);
634     }
635   }
636 }
637 
638 // Processes symbolic shapes.
639 // Each symbolic shape or dimension is represented by a handle. Unlike the TF
640 // shape refiner which creates new handles every time it processes an unknown
641 // shape/dimension, the symbolic shape refiner assigns a specific handle to each
642 // unknown shape/dimension of a given node.
643 class SymbolicShapeRefiner {
644  public:
645   explicit SymbolicShapeRefiner(
646       const GraphView& graph,
647       const absl::flat_hash_map<string, absl::flat_hash_set<int>>& fed_ports,
648       const bool aggressive_shape_inference)
649       : graph_(graph),
650         function_library_(OpRegistry::Global(), graph.graph()->library()),
651         fed_ports_(fed_ports),
652         aggressive_shape_inference_(aggressive_shape_inference) {
653     graph_def_version_ = graph.graph()->versions().producer();
654     node_to_context_.reserve(graph.graph()->node_size());
655   }
656 
657   const GraphView& graph() const { return graph_; }
658 
659   struct NodeContext {
660     const OpRegistrationData* op_data;
661     DataTypeVector input_types;
662     DataTypeVector output_types;
663     std::unique_ptr<InferenceContext> inference_context;
664     // Additional info for propagating tensor values and tensor shapes.
665     std::vector<const TensorProto*> input_tensor_protos;
666     std::vector<const TensorProto*> output_tensor_protos;
667     // This is the same as inference_context->input_tensors_as_shapes, except
668     // that some UnknownDims (-1) can be kUnknownDimFromConst.
669     std::vector<ShapeHandle> input_tensors_as_shapes_to_propagate;
670     std::vector<ShapeHandle> output_tensors_as_shapes;
671 
672     // Output shapes incompatible between annotation and shape inference.
673     bool shape_incompatible = false;
674 
675     // Similar to DebugString() in InferenceContext, but prints out
676     // kUnknownDimFromConst properly.
677     std::string StringifyShapeHandle(ShapeHandle s) {
678       auto* ic = inference_context.get();
679       if (ic->RankKnown(s)) {
680         std::vector<std::string> vals;
681         for (int i = 0; i < ic->Rank(s); i++) {
682           DimensionHandle d = ic->Dim(s, i);
683           if (ic->ValueKnown(d) && ic->Value(d) == kUnknownDimFromConst) {
684             vals.push_back("?(Const)");
685           } else {
686             vals.push_back(ic->DebugString(d));
687           }
688         }
689         return strings::StrCat("[", absl::StrJoin(vals, ","), "]");
690       } else {
691         return "?";
692       }
693     }
694 
695     std::string DebugString(const NodeDef& node) {
696       std::string output;
697       auto* ic = inference_context.get();
698       absl::StrAppend(
699           &output, node.name(), " [", node.op(), "] has ", ic->num_inputs(),
700           (ic->num_inputs() > 1 ? " inputs and " : " input and "),
701           ic->num_outputs(), (ic->num_outputs() > 1 ? " outputs" : " output"));
702       if (op_data->is_function_op) {
703         absl::StrAppend(&output, " (function op)");
704       }
705       absl::StrAppend(&output, ": \n");
706 
707       for (int i = 0; i < ic->num_inputs(); i++) {
708         absl::StrAppend(&output, " input [", i, "] ", node.input(i),
709                         " -- type: ", DataTypeString(input_types.at(i)),
710                         ", shape: ", ic->DebugString(ic->input(i)),
711                         ", tensor: ");
712         Tensor t1;
713         int input_tensor_protos_size = input_tensor_protos.size();
714         if (input_tensor_protos_size > i &&
715             input_tensor_protos.at(i) != nullptr &&
716             t1.FromProto(*input_tensor_protos.at(i))) {
717           absl::StrAppend(&output, t1.DebugString(), ", tensor_as_shape: ");
718         } else {
719           absl::StrAppend(&output, " null, tensor_as_shape: ");
720         }
721         int input_tensors_as_shapes_to_propagate_size =
722             input_tensors_as_shapes_to_propagate.size();
723         if (input_tensors_as_shapes_to_propagate_size > i) {
724           absl::StrAppend(
725               &output,
726               StringifyShapeHandle(input_tensors_as_shapes_to_propagate.at(i)),
727               "\n");
728         } else {
729           absl::StrAppend(&output, " null\n");
730         }
731       }
732       for (int i = 0; i < ic->num_outputs(); i++) {
733         absl::StrAppend(&output, " output [", i,
734                         "] -- type: ", DataTypeString(output_types.at(i)),
735                         ", shape: ", ic->DebugString(ic->output(i)),
736                         ", tensor: ");
737         Tensor t2;
738         int output_tensor_protos_size = output_tensor_protos.size();
739         if (output_tensor_protos_size > i &&
740             output_tensor_protos.at(i) != nullptr &&
741             t2.FromProto(*output_tensor_protos.at(i))) {
742           absl::StrAppend(&output, t2.DebugString(), ", tensor_as_shape: ");
743         } else {
744           absl::StrAppend(&output, " null, tensor_as_shape: ");
745         }
746         int output_tensors_as_shapes_size = output_tensors_as_shapes.size();
747         if (output_tensors_as_shapes_size > i) {
748           absl::StrAppend(&output,
749                           StringifyShapeHandle(output_tensors_as_shapes.at(i)),
750                           "\n");
751         } else {
752           absl::StrAppend(&output, " null\n");
753         }
754       }
755       return output;
756     }
757   };
758 
759   NodeContext* GetNodeContext(const NodeDef* node) {
760     auto it = node_to_context_.find(node);
761     if (it == node_to_context_.end()) {
762       return nullptr;
763     }
764     return &it->second;
765   }
766 
767   InferenceContext* GetContext(const NodeDef* node) {
768     auto it = node_to_context_.find(node);
769     if (it == node_to_context_.end()) {
770       return nullptr;
771     }
772     return it->second.inference_context.get();
773   }
774 
775   // Forward the shapes from the inputs of the function call node
776   // (PartitionedCall or StatefulPartitionedCall) to the function's argument
777   // nodes (which are Placeholder nodes), then perform shape inference on the
778   // function body.
779   //
780   // Propagate the shape information of the final function body nodes to the
781   // function call node `function_node`.
782   //
783   // In the event of an error, UpdateNode will simply set `function_node`'s
784   // output shape to be Unknown.
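  // A rough outline of the steps below: (1) copy each input's inferred shape
  // onto the corresponding _Arg node via the _output_shapes attr, (2) replace
  // inputs whose values are known with Const nodes, (3) rewrite _Retval nodes
  // as Identity, (4) run GraphProperties::InferStatically() on the function
  // body, and (5) copy the resulting output shapes/values back to
  // `function_node`'s inference context.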
785   Status UpdateFunction(const NodeDef* function_node) {
786     NameAttrList function;
787     TF_RETURN_IF_ERROR(NameAndAttrsFromFunctionCall(*function_node, &function));
788     auto it = fun_to_grappler_function_item_.find(function.name());
789     if (it == fun_to_grappler_function_item_.end()) {
790       return errors::InvalidArgument(
791           function.name(),
792           " was not previously added to SymbolicShapeRefiner.");
793     }
794 
795     const absl::optional<GrapplerFunctionItem>& maybe_grappler_function_item =
796         it->second;
797     if (!maybe_grappler_function_item.has_value()) {
798       VLOG(3) << "Skip failed to instantiate function call: function_name="
799               << function.name();
800 
801       auto* ctx = GetNodeContext(function_node);
802       auto* ic = ctx->inference_context.get();
803       for (int i = 0; i < ic->num_outputs(); ++i) {
804         TF_RETURN_IF_ERROR(SetUnknownShape(function_node, i));
805       }
806 
807       return OkStatus();
808     }
809 
810     // Copy (not reference) so that changes we make here (e.g., replacing
811     // _Arg with Const and _Retval with Identity) don't affect the one in
812     // fun_to_grappler_function_item_.
813     GrapplerFunctionItem grappler_function_item = *maybe_grappler_function_item;
814     MutableGraphView gv(&grappler_function_item.graph);
815 
816     // Forward shapes from function input nodes to argument nodes.
817     for (int i = 0, end = grappler_function_item.inputs().size(); i < end;
818          ++i) {
819       auto& fun_input = grappler_function_item.input(i);
820       NodeDef* fun_node = gv.GetNode(fun_input.node_name);
821       const TensorId input_tensor = ParseTensorName(function_node->input(i));
822 
823       if (IsControlInput(input_tensor)) {
824         return errors::FailedPrecondition(
825             "Function inputs should not contain control nodes.");
826       }
827 
828       const NodeDef* input_node = graph_.GetNode(input_tensor.node());
829       if (input_node == nullptr) {
830         return errors::FailedPrecondition(input_tensor.node(),
831                                           " was not found in the graph.");
832       }
833 
834       InferenceContext* input_ic = GetContext(input_node);
835       if (input_ic == nullptr) {
836         return errors::FailedPrecondition(
837             "Inference context has not been created for ", input_tensor.node());
838       }
839 
840       int output_port_num = input_tensor.index();
841       AttrValue attr_output_shape;
842       TensorShapeProto proto;
843       const auto handle = input_ic->output(output_port_num);
844       input_ic->ShapeHandleToProto(handle, &proto);
845       // There may be dim.size < -1 in SymbolicShapeRefiner. Change those to -1.
846       NormalizeShapeForOutput(&proto);
847       // _Arg op's output shape uses _output_shapes attr.
848       AttrValue output_attr;
849       output_attr.mutable_list()->add_shape()->Swap(&proto);
850       (*fun_node->mutable_attr())["_output_shapes"] = output_attr;
851 
852       // If dtype is DT_RESOURCE, ops that read _Arg op use _handle_dtypes and
853       // _handle_shapes attr for its shapes and dtypes.
854       if (fun_input.data_type == DT_RESOURCE) {
855         auto* shapes_and_types =
856             input_ic->output_handle_shapes_and_types(output_port_num);
857         if (shapes_and_types != nullptr && !shapes_and_types->empty()) {
858           AttrValue dtype_attr;
859           AttrValue shape_attr;
860           for (const auto& shape_and_type : *shapes_and_types) {
861             const auto& dtype = shape_and_type.dtype;
862             const auto& shape_handle = shape_and_type.shape;
863             dtype_attr.mutable_list()->add_type(dtype);
864             input_ic->ShapeHandleToProto(
865                 shape_handle, shape_attr.mutable_list()->add_shape());
866           }
867           (*fun_node->mutable_attr())["_handle_dtypes"] = dtype_attr;
868           (*fun_node->mutable_attr())["_handle_shapes"] = shape_attr;
869         } else {
870           // Note that we do not return an error here, even if the input node does
871           // not have shapes_and_types. Within the function, we cannot infer the
872           // output shape of the DT_RESOURCE input; hence, potentially unknown
873           // shapes/dims in the function output shapes.
874           VLOG(2)
875               << "A function node (" << function_node->name()
876               << ") has input with DT_RESOURCE, but the input node does not "
877               << "have shapes_and_types information: \n"
878               << "function_node: " << function_node->ShortDebugString() << "\n"
879               << "function input: " << i
880               << ", input node's output: " << output_port_num << "\n"
881               << "input node: " << input_node->ShortDebugString();
882         }
883       }
884     }
885 
886     // ReplaceInputWithConst() may break GraphView's internal node mapping
887     // structure; hence, we separately build a node name to NodeDef* map for the
888     // output nodes (before GraphView becomes invalid). Note that we use string,
889     // not string_view.
890     absl::flat_hash_map<std::string, NodeDef*> output_nodes;
891     for (const auto& output_arg : grappler_function_item.outputs()) {
892       output_nodes[output_arg.node_name] = gv.GetNode(output_arg.node_name);
893     }
894 
895     // Replace input nodes with Consts, if values are known. Note that
896     // we don't check exceptions here as it's done in the above loop.
897     auto* ctx = GetNodeContext(function_node);
898     auto* ic = ctx->inference_context.get();
899     for (int i = grappler_function_item.inputs().size() - 1; i >= 0; --i) {
900       const string& input = function_node->input(i);
901       const string node_name = NodeName(input);
902       const NodeDef* input_node = graph_.GetNode(node_name);
903       if (IsConstant(*input_node)) {
904         TF_CHECK_OK(
905             ReplaceInputWithConst(*input_node, i, &grappler_function_item));
906       } else if (static_cast<int>(ctx->input_tensor_protos.size()) > i &&
907                  ctx->input_tensor_protos[i] != nullptr) {
908         NodeDef const_input_node = MakeConstNodeDefFromTensorProto(
909             ic, *ctx->input_tensor_protos[i], ctx->input_types[i]);
910         TF_CHECK_OK(ReplaceInputWithConst(const_input_node, i,
911                                           &grappler_function_item));
912       } else if (static_cast<int>(ic->input_tensors_as_shapes().size()) > i &&
913                  IsShapeFullyDefinedIntegerVectorOrScalar(
914                      ic, ic->input(i), ic->input_tensors_as_shapes()[i],
915                      ctx->input_types[i])) {
916         // We have fully defined input_tensors_as_shapes for this input; use it
917         // as a const input to the function node.
918         NodeDef const_input_node = MakeConstNodeDefFromShape(
919             ic, ic->input(i), ic->input_tensors_as_shapes()[i],
920             ctx->input_types[i]);
921         TF_CHECK_OK(ReplaceInputWithConst(const_input_node, i,
922                                           &grappler_function_item));
923       }
924     }
925     // node_name to NodeDef* map in GraphView gv can be broken due to
926     // ReplaceInputWithConst(). gv should not be used after this.
927 
928     // Replace output _Retval nodes with Identity nodes. _Retval is a system op
929     // without outputs or a registered shape function.
930     for (const auto& output_arg : grappler_function_item.outputs()) {
931       NodeDef* output_node = output_nodes[output_arg.node_name];
932       DCHECK_EQ(output_node->op(), "_Retval");
933       output_node->set_op("Identity");
934       output_node->mutable_attr()->erase("index");
935     }
936 
937     // Perform inference on function body.
938     GraphProperties gp(grappler_function_item);
939     TF_RETURN_IF_ERROR(gp.InferStatically(
940         /*assume_valid_feeds=*/true,
941         /*aggressive_shape_inference=*/aggressive_shape_inference_,
942         /*include_tensor_values=*/true));
943 
944     // Add return nodes for output shapes.
945     int output = 0;
946     ctx->output_tensors_as_shapes.resize(grappler_function_item.output_size());
947     ctx->output_tensor_protos.resize(grappler_function_item.output_size(),
948                                      nullptr);
949     for (auto const& out_arg : grappler_function_item.outputs()) {
950       // It is guaranteed that output_tensors does not contain any control
951       // inputs, so port_id >= 0.
952       TensorId out_tensor = ParseTensorName(out_arg.node_name);
953 
954       if (output_nodes.count(out_tensor.node()) <= 0) {
955         return errors::FailedPrecondition(
956             "Unable to find return function_node ", out_tensor.node(), " for ",
957             function_node->name());
958       }
959       const NodeDef* retnode = output_nodes[out_tensor.node()];
960 
961       auto output_properties = gp.GetOutputProperties(retnode->name());
962       int output_properties_size = output_properties.size();
963       if (out_tensor.index() >= output_properties_size) {
964         return errors::InvalidArgument(
965             out_tensor.ToString(), " has invalid position ", out_tensor.index(),
966             " (output_properties.size() = ", output_properties.size(), ").");
967       }
968       auto& outprop = output_properties[out_tensor.index()];
969       TensorShapeProto shape = outprop.shape();
970       NormalizeShapeForOutput(&shape);
971       ShapeHandle out;
972       TF_RETURN_IF_ERROR(ic->MakeShapeFromShapeProto(shape, &out));
973       ic->set_output(output, out);
974       if (outprop.has_value()) {
975         // Forward tensor value to output_tensors_as_shape.
976         MaybeTensorProtoToShape(ic, outprop.value(),
977                                 &ctx->output_tensors_as_shapes[output]);
978         const_tensors_to_propagate_.push_back(outprop.value());
979         ctx->output_tensor_protos[output] = &const_tensors_to_propagate_.back();
980       }
981       output++;
982     }
983 
984     return OkStatus();
985   }
986 
987   // Prepares input shapes/values/handles, then runs shape inference, and
988   // finally sets output shapes/values/handles.
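  // Roughly: (1) propagate each fanin's output shapes/tensors into this node's
  // context, (2) return early if nothing changed, (3) for function ops try
  // annotated output shapes and then UpdateFunction(), and (4) otherwise
  // materialize small constant inputs as Tensors and run InferShapes().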
989   Status UpdateNode(const NodeDef* node, bool* refined) {
990     NodeContext* ctx = GetNodeContext(node);
991     if (ctx == nullptr) {
992       TF_RETURN_IF_ERROR(AddNode(node));
993       ctx = CHECK_NOTNULL(GetNodeContext(node));
994       *refined = true;
995     }
996 
997     // Check if the shapes of the nodes in the fan-in of this node have changed,
998     // and if they have, update the node input shapes.
999     InferenceContext* ic = ctx->inference_context.get();
1000     ctx->input_tensors_as_shapes_to_propagate.resize(ic->num_inputs());
1001     ctx->input_tensor_protos.resize(ic->num_inputs(), nullptr);
1002 
1003     for (int dst_input = 0; dst_input < ic->num_inputs(); ++dst_input) {
1004       const GraphView::InputPort port(node, dst_input);
1005       const GraphView::OutputPort fanin = graph_.GetRegularFanin(port);
1006       int src_output = fanin.port_id;
1007       const NodeDef* src = fanin.node;
1008       NodeContext* src_ctx = GetNodeContext(src);
1009       if (src_ctx == nullptr) {
1010         return errors::FailedPrecondition(
1011             "Input ", dst_input, " for '", node->name(),
1012             "' was not previously added to SymbolicShapeRefiner.");
1013       }
1014 
1015       InferenceContext* src_ic = src_ctx->inference_context.get();
1016       if (src_output >= src_ic->num_outputs()) {
1017         return errors::OutOfRange("src_output = ", src_output,
1018                                   ", but num_outputs is only ",
1019                                   src_ic->num_outputs());
1020       }
1021 
1022       // Propagate input node's NodeContext info to the current node's
1023       // NodeContext:
1024       // output_tensor_protos to input_tensor_protos and input_tensors, and
1025       // output_tensors_as_shapes to input_tensors_as_shapes.
1026       if (static_cast<int>(src_ctx->output_tensors_as_shapes.size()) >
1027           src_output) {
1028         ctx->input_tensors_as_shapes_to_propagate[dst_input] =
1029             src_ctx->output_tensors_as_shapes[src_output];
1030       }
1031 
1032       if (static_cast<int>(src_ctx->output_tensor_protos.size()) > src_output) {
1033         const auto* tensor_proto = src_ctx->output_tensor_protos[src_output];
1034         if (tensor_proto != nullptr) {
1035           ctx->input_tensor_protos[dst_input] = tensor_proto;
1036         }
1037       }
1038 
1039       // NOTE: we only check whether the shape is refined; we do not (yet)
1040       // check whether the tensor value is refined.
1041       if (!*refined &&
1042           !ic->input(dst_input).SameHandle(src_ic->output(src_output))) {
1043         *refined = true;
1044       }
1045       ic->SetInput(dst_input, src_ic->output(src_output));
1046 
1047       if (!*refined && ic->requested_input_tensor_as_partial_shape(dst_input)) {
1048         // The input value may have changed. Since we have no way to know if
1049         // that's indeed the case, err on the safe side.
1050         *refined = true;
1051       }
1052 
1053       // Also propagate handle shape and dtype of edges which are carrying
1054       // resource handles.
1055       if (ctx->input_types[dst_input] == DT_RESOURCE) {
1056         auto* outputs = src_ic->output_handle_shapes_and_types(src_output);
1057         if (!outputs) continue;
1058         auto* inputs = ic->input_handle_shapes_and_types(dst_input);
1059 
1060         if (!inputs || !EquivalentShapesAndTypes(*outputs, *inputs))
1061           *refined = true;
1062         ic->set_input_handle_shapes_and_types(dst_input, *outputs);
1063       }
1064     }
1065 
1066     // Make sure we schedule the fanout of resources (which have no input)
1067     // whenever the resources are updated.
1068     *refined |= ic->num_inputs() == 0;
1069 
1070     if (!*refined) {
1071       // No input shape has changed, we're done.
1072       return OkStatus();
1073     }
1074 
1075     // Convert all kUnknownDimFromConst to -1 for shape inference.
1076     ic->set_input_tensors_as_shapes(ReplaceUnknownDimFromConstWithUnknownDim(
1077         ic, ctx->input_tensors_as_shapes_to_propagate));
1078     // Note: UpdateFunction uses input_tensors_as_shapes and
1079     // input_tensor_protos (not the Tensor object) for input values,
1080     // so for function nodes, we don't need to convert TensorProtos
1081     // to Tensors here. If the current op is not a function op, we convert
1082     // TensorProtos to Tensors before calling InferShapes.
1083 
1084     // Properly handle function nodes.
1085     if (ctx->op_data && ctx->op_data->is_function_op) {
1086       // TODO(jmdecker): Detect if the input shapes have changed for this
1087       // function. Note that when we hit a function call node, refined will be
1088       // true, as the updates to the call node will have changed, even if it's
1089       // the same function being called twice with the same input shapes.
1090       // Example: simple_function.pbtxt
1091       if (aggressive_shape_inference_) {
1092         // If output shapes are annotated, use them and skip UpdateFunction();
1093         // it can be very expensive when a function node has nested function
1094         // nodes internally. One downside with this approach is that we do not
1095         // get output values or output shapes as tensors from the function node.
1096         auto s = UpdateOutputShapesUsingAnnotatedInformation(*node, ctx);
1097         if (s.ok() && AllOutputShapesKnown(ctx)) {
1098           return OkStatus();
1099         }
1100         // If shape annotation was not available, incomplete, or incompatible,
1101         // fall through to call UpdateFunction().
1102       }
1103       auto s = UpdateFunction(node);
1104       if (s.ok()) {
1105         return OkStatus();
1106       } else {
1107         VLOG(1) << "UpdateFunction failed for " << node->op()
1108                 << ". Defaulting to ShapeUnknown.\n"
1109                 << s.ToString();
1110       }
1111     }
1112 
1113     //  Construct Tensors for constant inputs used by shape functions.
1114     std::vector<Tensor> const_values(ic->num_inputs());
1115     std::vector<const Tensor*> input_tensors(ic->num_inputs(), nullptr);
1116     for (int dst_input = 0; dst_input < ic->num_inputs(); ++dst_input) {
1117       const TensorProto* tensor_proto = ctx->input_tensor_protos[dst_input];
1118       if (tensor_proto != nullptr &&
1119           // Skip if the const tensor is too large.
1120           NumElementsFromTensorProto(*tensor_proto) <=
1121               kThresholdToSkipConstTensorInstantiation &&
1122           const_values[dst_input].FromProto(*tensor_proto)) {
1123         input_tensors[dst_input] = &const_values[dst_input];
1124       }
1125     }
1126     ic->set_input_tensors(input_tensors);
1127 
1128     // Update the shapes of the outputs.
1129     return InferShapes(*node, ctx);
1130   }
1131 
1132   Status SetUnknownShape(const NodeDef* node, int output_port) {
1133     shape_inference::ShapeHandle shape =
1134         GetUnknownOutputShape(node, output_port);
1135     InferenceContext* ctx = GetContext(node);
1136     if (ctx == nullptr) {
1137       return errors::InvalidArgument("SetUnknownShape: Missing context");
1138     }
1139     if (output_port < 0 || output_port >= ctx->num_outputs()) {
1140       return errors::InvalidArgument(
1141           "SetUnknownShape: output_port must be in [0, ", ctx->num_outputs(),
1142           ") but was ", output_port);
1143     }
1144     ctx->set_output(output_port, shape);
1145     return OkStatus();
1146   }
1147 
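  // Keys identifying a node's output port (ShapeId) or one dimension of a
  // node's output (DimId). They are used as hash-map keys for the placeholder
  // "unknown" shapes and dimensions handed out by GetUnknownOutputShape() and
  // GetUnknownOutputDim().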
1148   struct ShapeId {
1149     const NodeDef* node;
1150     int port_id;
1151     bool operator==(const ShapeId& other) const {
1152       return node == other.node && port_id == other.port_id;
1153     }
1154   };
1155   struct HashShapeId {
1156     std::size_t operator()(const ShapeId& shp) const {
1157       return std::hash<const NodeDef*>{}(shp.node) + shp.port_id;
1158     }
1159   };
1160 
1161   struct DimId {
1162     const NodeDef* node;
1163     int port_id;
1164     int dim_index;
1165     bool operator==(const DimId& other) const {
1166       return node == other.node && port_id == other.port_id &&
1167              dim_index == other.dim_index;
1168     }
1169   };
1170 
1171   struct HashDimId {
1172     std::size_t operator()(const DimId& dim) const {
1173       return std::hash<const NodeDef*>{}(dim.node) + dim.port_id +
1174              dim.dim_index;
1175     }
1176   };
1177 
1178   // Computes the output shape of 'node' at 'port_index' as the union of shape1 and shape2.
1179   ShapeHandle OutputAsUnion(const NodeDef* node, int port_index,
1180                             ShapeHandle shape1, ShapeHandle shape2) {
1181     if (shape1.SameHandle(shape2)) {
1182       return shape1;
1183     }
1184     InferenceContext* ctx = GetContext(node);
1185     ShapeHandle relaxed = shape1;
1186     const int rank = ctx->Rank(shape1);
1187     if (!ctx->RankKnown(shape2) || ctx->Rank(shape2) != rank) {
1188       relaxed = GetUnknownOutputShape(node, port_index);
1189     } else {
1190       for (int d = 0; d < rank; ++d) {
1191         if (!ctx->Dim(shape1, d).SameHandle(ctx->Dim(shape2, d))) {
1192           int64_t val1 = ctx->Value(ctx->Dim(shape1, d));
1193           int64_t val2 = ctx->Value(ctx->Dim(shape2, d));
1194           if (val1 != val2 || (val1 < 0 && val2 < 0)) {
1195             DimensionHandle new_dim = GetUnknownOutputDim(node, port_index, d);
1196             TF_CHECK_OK(ctx->ReplaceDim(relaxed, d, new_dim, &relaxed));
1197           }
1198         }
1199       }
1200     }
1201     return relaxed;
1202   }
1203 
1204   bool EquivalentShapes(ShapeHandle s1, ShapeHandle s2) const {
1205     if (s1.SameHandle(s2)) {
1206       return true;
1207     }
1208     if (InferenceContext::Rank(s1) != InferenceContext::Rank(s2)) {
1209       return false;
1210     }
1211     if (!InferenceContext::RankKnown(s1) && !InferenceContext::RankKnown(s2)) {
1212       return true;
1213     }
1214     const int rank = InferenceContext::Rank(s1);
1215     for (int i = 0; i < rank; ++i) {
1216       if (!InferenceContext::DimKnownRank(s1, i).SameHandle(
1217               InferenceContext::DimKnownRank(s2, i))) {
1218         int64_t val1 =
1219             InferenceContext::Value(InferenceContext::DimKnownRank(s1, i));
1220         int64_t val2 =
1221             InferenceContext::Value(InferenceContext::DimKnownRank(s2, i));
1222         if (val1 >= 0 && val2 >= 0 && val1 == val2) {
1223           continue;
1224         }
1225         return false;
1226       }
1227     }
1228     return true;
1229   }
1230 
1231   // Return true if the annotated shape is compatible with shape inference
1232   // result. Examples:
1233   // Inferred shape: ?, annotated shape: [10, 10] -> true;
1234   // Inferred shape: [-1, 10], annotated shape: [10, 10] -> true;
1235   // Inferred shape: [-1, 100], annotated shape: [10, 10] -> false;
1236   // Inferred shape: [-1, 10, 10], annotated shape: [10, 10] -> false.
1237   bool CompatibleShapes(ShapeHandle inferred_shape,
1238                         ShapeHandle annotated_shape) const {
1239     if (inferred_shape.SameHandle(annotated_shape)) {
1240       return true;
1241     }
1242     if (!InferenceContext::RankKnown(inferred_shape)) {
1243       return true;
1244     }
1245     if (InferenceContext::Rank(inferred_shape) !=
1246         InferenceContext::Rank(annotated_shape)) {
1247       return false;
1248     }
1249     const int rank = InferenceContext::Rank(inferred_shape);
1250     for (int i = 0; i < rank; ++i) {
1251       if (!InferenceContext::DimKnownRank(inferred_shape, i)
1252                .SameHandle(
1253                    InferenceContext::DimKnownRank(annotated_shape, i))) {
1254         int64_t val1 = InferenceContext::Value(
1255             InferenceContext::DimKnownRank(inferred_shape, i));
1256         int64_t val2 = InferenceContext::Value(
1257             InferenceContext::DimKnownRank(annotated_shape, i));
1258         if (val1 >= 0 && val1 != val2) {
1259           return false;
1260         }
1261       }
1262     }
1263     return true;
1264   }
1265 
1266   bool SameShapes(ShapeHandle inferred_shape,
1267                   ShapeHandle annotated_shape) const {
1268     if (inferred_shape.SameHandle(annotated_shape)) {
1269       return true;
1270     }
1271     if (InferenceContext::Rank(inferred_shape) !=
1272         InferenceContext::Rank(annotated_shape)) {
1273       return false;
1274     }
1275     const int rank = InferenceContext::Rank(inferred_shape);
1276     for (int i = 0; i < rank; ++i) {
1277       int64_t val1 = InferenceContext::Value(
1278           InferenceContext::DimKnownRank(inferred_shape, i));
1279       int64_t val2 = InferenceContext::Value(
1280           InferenceContext::DimKnownRank(annotated_shape, i));
1281       if (val1 != val2) {
1282         return false;
1283       }
1284     }
1285     return true;
1286   }
1287 
1288   bool EquivalentShapesAndTypes(const std::vector<ShapeAndType>& st1,
1289                                 const std::vector<ShapeAndType>& st2) const {
1290     if (st1.size() != st2.size()) {
1291       return false;
1292     }
1293     for (int i = 0, st1_size = st1.size(); i < st1_size; ++i) {
1294       const ShapeAndType& s1 = st1[i];
1295       const ShapeAndType& s2 = st2[i];
1296       if (s1.dtype != s2.dtype) {
1297         return false;
1298       }
1299       if (!EquivalentShapes(s1.shape, s2.shape)) {
1300         return false;
1301       }
1302     }
1303     return true;
1304   }
1305 
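  // Instantiates (once) the GrapplerFunctionItem for `function_name`, so that
  // UpdateFunction() can later run shape inference on the function body. If
  // instantiation fails, absl::nullopt is stored and the function is skipped.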
1306   Status AddFunction(const NodeDef* function_node,
1307                      const std::string& function_name) {
1308     auto it = fun_to_grappler_function_item_.find(function_name);
1309     if (it != fun_to_grappler_function_item_.end()) {
1310       return OkStatus();
1311     }
1312 
1313     const FunctionDef* function_def =
1314         CHECK_NOTNULL(function_library_.Find(function_name));
1315     GrapplerFunctionItem grappler_function_item;
1316     Status function_instantiated =
1317         MakeGrapplerFunctionItem(*function_def, function_library_,
1318                                  graph_def_version_, &grappler_function_item);
1319 
1320     // If function instantiation failed we will skip it during shape inference.
1321     if (!function_instantiated.ok()) {
1322       VLOG(3) << "Failed to instantiate a function. Error: "
1323               << function_instantiated.error_message();
1324       fun_to_grappler_function_item_[function_def->signature().name()] =
1325           absl::nullopt;
1326       return OkStatus();
1327     }
1328 
1329     if (static_cast<int>(grappler_function_item.inputs().size()) >
1330         function_node->input_size()) {
1331       return errors::FailedPrecondition(
1332           "Function input size should not be larger than node input size.");
1333     }
1334 
1335     for (int i = grappler_function_item.inputs().size(),
1336              end = function_node->input_size();
1337          i < end; ++i) {
1338       const string& input = function_node->input(i);
1339       if (!IsControlInput(input)) {
1340         return errors::FailedPrecondition(
1341             "Found regular input (", input,
1342             ") instead of control nodes for node ", function_node->name());
1343       }
1344     }
1345 
1346     fun_to_grappler_function_item_[function_def->signature().name()] =
1347         grappler_function_item;
1348 
1349     return OkStatus();
1350   }
1351 
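  // Create and initialize the NodeContext for this node: look up its op
  // registration data (and function info for function call ops), resolve its
  // input/output types, and build an InferenceContext with unknown inputs.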
1352   Status AddNode(const NodeDef* node) {
1353     NodeContext& node_ctx = node_to_context_[node];
1354     NameAttrList function;
1355     TF_RETURN_IF_ERROR(NameAndAttrsFromFunctionCall(*node, &function));
1356 
1357     // For PartitionedCall, op_data represents the function info.
1358     TF_RETURN_IF_ERROR(
1359         function_library_.LookUp(function.name(), &node_ctx.op_data));
1360 
1361     if (node_ctx.op_data->is_function_op) {
1362       TF_RETURN_IF_ERROR(AddFunction(node, function.name()));
1363     }
1364 
1365     TF_RETURN_IF_ERROR(InOutTypesForNode(*node, node_ctx.op_data->op_def,
1366                                          &node_ctx.input_types,
1367                                          &node_ctx.output_types));
1368 
1369     // Create the inference context for this node.
1370     const int num_inputs = node_ctx.input_types.size();
1371     std::vector<ShapeHandle> input_shapes(num_inputs);
1372     std::vector<std::unique_ptr<std::vector<ShapeAndType>>>
1373         input_handle_shapes_and_types(num_inputs);
1374     std::vector<const Tensor*> input_tensors(num_inputs, nullptr);
1375     std::vector<ShapeHandle> input_tensors_as_shapes;
1376 
1377     node_ctx.inference_context.reset(new InferenceContext(
1378         graph_def_version_, *node, node_ctx.op_data->op_def, input_shapes,
1379         input_tensors, input_tensors_as_shapes,
1380         std::move(input_handle_shapes_and_types)));
1381     const Status s = node_ctx.inference_context->construction_status();
1382     if (!s.ok()) {
1383       node_ctx.inference_context.reset(nullptr);
1384     }
1385     return s;
1386   }
1387 
1388  private:
1389   // Return the one ShapeHandle used to denote a fully unknown shape for a node
1390   // output.
1391   ShapeHandle GetUnknownOutputShape(const NodeDef* node, int index) {
1392     ShapeId id{node, index};
1393     auto it = unknown_shapes_.find(id);
1394     if (it != unknown_shapes_.end()) {
1395       return it->second;
1396     }
1397     InferenceContext* c = GetContext(node);
1398     ShapeHandle shp = c->UnknownShape();
1399     unknown_shapes_[id] = shp;
1400     return shp;
1401   }
1402   // Return the one ShapeHandle used to denote a fully unknown dimension for a
1403   // node output.
1404   DimensionHandle GetUnknownOutputDim(const NodeDef* node, int index,
1405                                       int dim_id) {
1406     DimId id{node, index, dim_id};
1407     auto it = unknown_dims_.find(id);
1408     if (it != unknown_dims_.end()) {
1409       return it->second;
1410     }
1411     InferenceContext* c = GetContext(node);
1412     DimensionHandle dim = c->UnknownDim();
1413     unknown_dims_[id] = dim;
1414     return dim;
1415   }
1416 
1417   // Returns true if all the output tensors have known values.
1418   bool AllOutputValuesKnown(NodeContext* c) {
1419     InferenceContext* ic = c->inference_context.get();
1420     int c_output_tensors_as_shapes_size = c->output_tensors_as_shapes.size();
1421     int c_output_tensor_protos_size = c->output_tensor_protos.size();
1422     if (c_output_tensors_as_shapes_size < ic->num_outputs() &&
1423         c_output_tensor_protos_size < ic->num_outputs()) {
1424       return false;
1425     } else {
1426       // Checks if we can get output value via either output_tensor_proto or
1427       // output_tensors_as_shapes.
1428       for (int i = 0; i < ic->num_outputs(); i++) {
1429         if (c_output_tensor_protos_size > i &&
1430             c->output_tensor_protos[i] != nullptr) {
1431           continue;
1432         }
1433         if (c_output_tensors_as_shapes_size > i &&
1434             ic->FullyDefined(c->output_tensors_as_shapes[i])) {
1435           bool no_unknown_dim_from_const = true;
1436           for (int32_t j = 0; j < ic->Rank(c->output_tensors_as_shapes[i]);
1437                ++j) {
1438             const auto dim = ic->Dim(c->output_tensors_as_shapes[i], j);
1439             if (ic->ValueKnown(dim) && ic->Value(dim) == kUnknownDimFromConst) {
1440               no_unknown_dim_from_const = false;
1441               break;
1442             }
1443           }
1444           if (no_unknown_dim_from_const) {
1445             continue;
1446           }
1447         }
1448         return false;
1449       }
1450     }
1451     return true;
1452   }
1453 
1454   // Returns true if all the output shapes are known.
1455   bool AllOutputShapesKnown(NodeContext* c) {
1456     InferenceContext* ic = c->inference_context.get();
1457     // Checks if all the output shapes are fully defined.
1458     for (int i = 0; i < ic->num_outputs(); i++) {
1459       if (!ic->FullyDefined(ic->output(i))) {
1460         return false;
1461       }
1462     }
1463     return true;
1464   }
1465 
1466   // Returns true if we can infer output tensors' values -- we know values of
1467   // all the input tensors.
1468   bool AllInputValuesKnown(NodeContext* c) {
1469     InferenceContext* ic = c->inference_context.get();
1470 
1471     // Check inputs are fully defined and values are known.
1472     for (int i = 0; i < ic->num_inputs(); i++) {
1473       const Tensor* tensor = ic->input_tensor(i);
1474       // Note that we don't check c->input_tensor_protos[i], as UpdateNode()
1475       // already converted it to ic->input_tensor(i);
1476       const ShapeHandle& input_tensors_as_shape =
1477           ic->input_tensors_as_shapes()[i];
1478       // Either input_tensor is valid or input_tensors_as_shape, which has
1479       // value of input tensors as shape format, should be fully defined.
1480       if (tensor == nullptr && !ic->FullyDefined(input_tensors_as_shape)) {
1481         return false;
1482       }
1483     }
1484     return true;
1485   }
1486 
1487   // Returns true if we want to update output shapes and values with running
1488   // EvaluateNode() for this op, based on op type, data type, and size.
1489   bool ShouldUpdateOutputShapesAndValues(NodeContext* c, int64_t max_size) {
1490     InferenceContext* ic = c->inference_context.get();
1491 
1492     // Due to the cost of running EvaluateNode(), we run it only for
1493     // allowlisted op types.
1494     if (!IsAllowListedOpTypeForEvaluateNode(c->op_data->op_def.name())) {
1495       return false;
1496     }
1497 
1498     // Check input dtypes are number types.
1499     for (const auto& input_type : c->input_types) {
1500       if (!IsNumericType(input_type)) {
1501         return false;
1502       }
1503     }
1504 
1505     // Check output dtypes are number types.
1506     for (const auto& output_type : c->output_types) {
1507       if (!IsNumericType(output_type)) {
1508         return false;
1509       }
1510     }
1511 
1512     // Check that the number of elements in each input tensor is no larger
1513     // than the given max size.
1514     for (int i = 0; i < ic->num_inputs(); i++) {
1515       const Tensor* tensor = ic->input_tensor(i);
1516       const ShapeHandle& input_shape_handle = ic->input(i);
1517       if (tensor != nullptr) {
1518         if (tensor->NumElements() > max_size) {
1519           return false;
1520         }
1521       } else if (ic->Value(ic->NumElements(input_shape_handle)) > max_size) {
1522         return false;
1523       }
1524     }
1525 
1526     // Check that the shape of each output tensor is fully defined and that
1527     // its number of elements is no larger than the given max size.
1528     for (int i = 0; i < ic->num_outputs(); i++) {
1529       const ShapeHandle& shape_handle = ic->output(i);
1530       if (!ic->FullyDefined(shape_handle) ||
1531           ic->Value(ic->NumElements(shape_handle)) > max_size) {
1532         return false;
1533       }
1534     }
1535     return true;
1536   }
1537 
1538   // Create input tensors from the NodeContext.
1539   void CreateInputTensors(NodeContext* c,
1540                           std::vector<Tensor>* input_tensor_vector,
1541                           TensorVector* inputs) {
1542     InferenceContext* ic = c->inference_context.get();
1543     for (int i = 0; i < ic->num_inputs(); i++) {
1544       if (ic->input_tensor(i)) {
1545         input_tensor_vector->at(i) = *ic->input_tensor(i);
1546         inputs->emplace_back(&input_tensor_vector->at(i));
1547         // Note that we don't check c->input_tensor_protos[i], as UpdateNode()
1548         // already converted it to ic->input_tensor(i);
1549       } else {
1550         // Create Tensor from input_tensors_as_shapes, and then emplace it
1551         // back to inputs.
1552         // Note that input_tensors_as_shapes is scalar or vector.
1553         const ShapeHandle& shape_handle = ic->input_tensors_as_shapes()[i];
1554         const DataType& data_type = c->input_types[i];
1555         int32_t rank = ic->Rank(shape_handle);
1556         if (rank < 1) {
1557           input_tensor_vector->at(i) = Tensor(data_type, {});
1558         } else {
1559           input_tensor_vector->at(i) = Tensor(data_type, {rank});
1560         }
1561         auto* tensor = &input_tensor_vector->at(i);
1562         if (data_type == DT_INT32) {
1563           auto flat = tensor->flat<int32>();
1564           for (int j = 0; j < rank; j++) {
1565             int32_t dim = ic->Value(ic->Dim(shape_handle, j));
1566             flat(j) = dim;
1567           }
1568         } else {
1569           auto flat = tensor->flat<int64_t>();
1570           for (int j = 0; j < rank; j++) {
1571             int64_t dim = ic->Value(ic->Dim(shape_handle, j));
1572             flat(j) = dim;
1573           }
1574         }
1575         inputs->emplace_back(tensor);
1576       }
1577     }
1578   }
1579 
1580   // Run a node to infer output shapes and values, and add it to the
1581   // NodeContext.
1582   Status UpdateOutputShapesAndValues(const NodeDef& node, NodeContext* c) {
1583     InferenceContext* ic = c->inference_context.get();
1584 
1585     // Input to EvaluateNode()
1586     TensorVector inputs;
1587     // Container for temporarily created tensor object.
1588     std::vector<Tensor> input_tensor_vector(ic->num_inputs());
1589     CreateInputTensors(c, &input_tensor_vector, &inputs);
1590 
1591     // Output for EvaluateNode() and output tensor clean up object.
1592     TensorVector outputs;
1593     auto outputs_cleanup = gtl::MakeCleanup([&outputs] {
1594       for (const auto& output : outputs) {
1595         if (output.tensor) {
1596           delete output.tensor;
1597         }
1598       }
1599     });
1600 
1601     TF_RETURN_IF_ERROR(EvaluateNode(node, inputs, /*cpu_device=*/nullptr,
1602                                     &resource_mgr_, &outputs));
1603     c->output_tensors_as_shapes.resize(outputs.size());
1604     c->output_tensor_protos.resize(outputs.size(), nullptr);
1605     for (int k = 0, outputs_size = outputs.size(); k < outputs_size; k++) {
1606       const auto& t = outputs[k];
1607       // Override output shape.
1608       ShapeHandle output_shape;
1609       TF_RETURN_IF_ERROR(
1610           ic->MakeShapeFromTensorShape(t->shape(), &output_shape));
1611       if (ic->FullyDefined(ic->output(k)) &&
1612           !EquivalentShapes(ic->output(k), output_shape)) {
1613         LOG(WARNING) << "UpdateOutputShapesAndValues() -- node: " << node.name()
1614                      << ", inferred output shape "
1615                      << "doesn't match for k=" << k << ": "
1616                      << "ic->output(k): " << ic->DebugString(ic->output(k))
1617                      << ", output_shape: " << ic->DebugString(output_shape)
1618                      << " -- " << node.DebugString();
1619       }
1620       ic->set_output(k, output_shape);
1621       // Set output_tensors_as_shape.
1622       MaybeTensorValueToShape(ic, *t.tensor, &c->output_tensors_as_shapes[k]);
1623 
1624       // Set output_tensor_protos.
1625       TensorProto tensor_proto;
1626       t->AsProtoTensorContent(&tensor_proto);
1627       const_tensors_to_propagate_.push_back(tensor_proto);
1628       c->output_tensor_protos[k] = &const_tensors_to_propagate_.back();
1629     }
1630     return OkStatus();
1631   }
1632 
1633   // Update output shapes with annotated information.
1634   // Currently only handle nodes with static shapes, i.e. shapes do not change
1635   // during execution.
1636   // TODO(andiryxu): Use annotated shapes in Enter/Merge etc as well.
1637   Status UpdateOutputShapesUsingAnnotatedInformation(const NodeDef& node,
1638                                                      NodeContext* c) const {
1639     const auto& attr = node.attr();
1640     if (attr.count(kOutputSame) == 0 || !attr.at(kOutputSame).b() ||
1641         attr.count(kOutputShapes) == 0)
1642       return OkStatus();
1643 
1644     InferenceContext* ic = c->inference_context.get();
1645     int output_size = attr.at(kOutputShapes).list().shape_size();
1646 
1647     for (int i = 0; i < ic->num_outputs(); i++) {
1648       // An annotated Switch node has only one annotated output shape; propagate
1649       // it to all the outputs.
1650       int shape_index = IsSwitch(node) ? 0 : i;
1651       if (shape_index >= output_size) {
1652         LOG(WARNING)
1653             << "UpdateOutputShapesUsingAnnotatedInformation() -- node: "
1654             << node.name() << ", inferred output shape size "
1655             << ic->num_outputs() << ", annotated output shape size "
1656             << output_size;
1657         break;
1658       }
1659 
1660       const TensorShapeProto& shape =
1661           attr.at(kOutputShapes).list().shape(shape_index);
1662       if (shape.dim().empty()) continue;
1663 
1664       ShapeHandle output_shape;
1665       TF_RETURN_IF_ERROR(ic->MakeShapeFromShapeProto(shape, &output_shape));
1666 
1667       // Check if annotated shapes are incompatible with inferred shapes.
1668       if ((ic->FullyDefined(ic->output(i)) &&
1669            !SameShapes(ic->output(i), output_shape)) ||
1670           (!ic->FullyDefined(ic->output(i)) &&
1671            !CompatibleShapes(ic->output(i), output_shape))) {
1672         LOG(WARNING)
1673             << "UpdateOutputShapesUsingAnnotatedInformation() -- node: "
1674             << node.name() << ", inferred output shape "
1675             << "doesn't match for i=" << i << ": "
1676             << "ic->output(k): " << ic->DebugString(ic->output(i))
1677             << ", annotated output shape: " << ic->DebugString(output_shape)
1678             << " -- " << node.DebugString();
1679         c->shape_incompatible = true;
1680       }
1681 
1682       // Only use the annotated shape if the inferred shape is not fully
1683       // defined and is compatible with the annotated shape.
1684       if (!ic->FullyDefined(ic->output(i)) &&
1685           CompatibleShapes(ic->output(i), output_shape)) {
1686         VLOG(3) << "UpdateOutputShapesUsingAnnotatedInformation() -- node: "
1687                 << node.name() << ", inferred output shape " << i << ": "
1688                 << "ic->output(i): " << ic->DebugString(ic->output(i))
1689                 << ", annotated output shape: " << ic->DebugString(output_shape)
1690                 << " -- " << node.ShortDebugString();
1691         ic->set_output(i, output_shape);
1692       }
1693     }
1694 
1695     return OkStatus();
1696   }
1697 
1698   Status MaybeUpdateNodeContextOutput(const NodeDef& node, const bool is_fed,
1699                                       NodeContext* c) {
1700     // Propagate tensors and shape tensors unless the node is fed.
1701     // TODO(bsteiner) We should still propagate the shapes to the ports that
1702     // aren't fed in the case of a ShapeN node.
1703 
1704     // Note that when propagating tensors_as_shapes, we use
1705     // c->input_tensors_as_shapes_to_propagate instead of
1706     // ic->input_tensors_as_shapes. The former uses kUnknownDimFromConst if
1707     // UnknownDim is from Const tensor, and it is propagated through shape
1708     // inference. Before calling shape functions, we convert it to UnknownDim,
1709     // but instantiate a new UnknownDim to prevent incorrect symbolic shape
1710     // inference through UnknownDim from Const.
1711     InferenceContext* ic = c->inference_context.get();
1712     if (!is_fed) {
1713       if (IsConstant(node)) {
1714         const TensorProto& tensor_proto = node.attr().at("value").tensor();
1715         c->output_tensor_protos.resize(1);
1716         c->output_tensor_protos[0] = &tensor_proto;
1717         c->output_tensors_as_shapes.resize(1);
1718         MaybeTensorProtoToShape(ic, tensor_proto,
1719                                 &c->output_tensors_as_shapes[0]);
1720       } else if (IsRank(node)) {
1721         if (ic->RankKnown(ic->input(0))) {
1722           // Propagate rank value.
1723           int32_t rank = ic->Rank(ic->input(0));
1724           const_tensors_to_propagate_.push_back(
1725               MakeIntegerScalarTensorProto(DT_INT32, rank));
1726           c->output_tensor_protos.resize(1);
1727           c->output_tensor_protos[0] = &const_tensors_to_propagate_.back();
1728         }
1729       } else if (IsSize(node)) {
1730         DimensionHandle size = ic->NumElements(ic->input(0));
1731         if (ic->ValueKnown(size)) {
1732           // Propagate size value.
1733           int64_t sz = ic->Value(size);
1734           bool valid = false;
1735           if (node.attr().at("out_type").type() == DT_INT32) {
1736             if (sz < std::numeric_limits<int32>::max()) {
1737               const_tensors_to_propagate_.push_back(
1738                   MakeIntegerScalarTensorProto(DT_INT32, sz));
1739               valid = true;
1740             }
1741           } else {
1742             const_tensors_to_propagate_.push_back(
1743                 MakeIntegerScalarTensorProto(DT_INT64, sz));
1744             valid = true;
1745           }
1746           if (valid) {
1747             c->output_tensor_protos.resize(1);
1748             c->output_tensor_protos[0] = &const_tensors_to_propagate_.back();
1749           }
1750         }
1751       } else if (IsShape(node)) {
1752         c->output_tensors_as_shapes.resize(1);
1753         c->output_tensors_as_shapes[0] = c->inference_context->input(0);
1754       } else if (IsShapeN(node)) {
1755         c->output_tensors_as_shapes.resize(c->inference_context->num_inputs());
1756         for (int i = 0; i < c->inference_context->num_inputs(); ++i) {
1757           c->output_tensors_as_shapes[i] = c->inference_context->input(i);
1758         }
1759       } else if (node.op() == "ConcatV2") {
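        // ConcatV2 over shape vectors: concatenate the shapes of all value
        // inputs (the last input is the concat axis and is skipped); give up
        // if any of them has unknown rank.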
1760         bool valid = true;
1761         ShapeHandle result;
1762         for (int i = 0; i < ic->num_inputs() - 1; ++i) {
1763           ShapeHandle input = c->input_tensors_as_shapes_to_propagate[i];
1764           if (!ic->RankKnown(input)) {
1765             valid = false;
1766             break;
1767           } else if (i == 0) {
1768             result = input;
1769           } else {
1770             TF_RETURN_IF_ERROR(ic->Concatenate(result, input, &result));
1771           }
1772         }
1773         if (valid) {
1774           c->output_tensors_as_shapes.resize(1);
1775           c->output_tensors_as_shapes[0] = result;
1776         }
1777       } else if (IsPack(node)) {
1778         // A Pack node concatenating scalars is often used to generate a shape.
1779         std::vector<DimensionHandle> dims;
1780         bool valid = true;
1781         for (int i = 0; i < ic->num_inputs(); ++i) {
1782           const Tensor* t = ic->input_tensor(i);
1783           if (t) {
1784             if (t->dims() != 0 ||
1785                 (t->dtype() != DT_INT32 && t->dtype() != DT_INT64)) {
1786               valid = false;
1787               break;
1788             }
1789             int64_t size = t->dtype() == DT_INT32 ? t->scalar<int32>()()
1790                                                   : t->scalar<int64_t>()();
1791             dims.push_back(size < 0 ? ic->MakeDim(kUnknownDimFromConst)
1792                                     : ic->MakeDim(size));
1793           } else {
1794             // Don't have tensor value, but use input_tensors_as_shapes, if
1795             // possible.
1796             const ShapeHandle& shape_handle =
1797                 c->input_tensors_as_shapes_to_propagate[i];
1798             if (ic->RankKnown(shape_handle) && ic->Rank(shape_handle) >= 1 &&
1799                 ic->ValueKnown(ic->Dim(shape_handle, 0))) {
1800               dims.push_back(ic->Dim(shape_handle, 0));
1801             } else {
1802             // This is not from a Const, but as it shouldn't be used as a
1803             // symbolic unknown dim for different ops, we use kUnknownDimFromConst.
1804               dims.push_back(ic->MakeDim(kUnknownDimFromConst));
1805             }
1806           }
1807         }
1808         if (valid) {
1809           c->output_tensors_as_shapes.resize(1);
1810           c->output_tensors_as_shapes[0] = ic->MakeShape(dims);
1811         }
1812       } else if (IsIdentity(node) || IsIdentityNSingleInput(node)) {
1813         c->output_tensors_as_shapes.resize(1);
1814         c->output_tensors_as_shapes[0] =
1815             c->input_tensors_as_shapes_to_propagate[0];
1816         if (c->input_tensor_protos[0] != nullptr) {
1817           c->output_tensor_protos.resize(1);
1818           c->output_tensor_protos[0] = c->input_tensor_protos[0];
1819         }
1820       } else if (IsSlice(node)) {
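        // Slice of a shape vector with scalar constant begin and size: the
        // output is a subshape of the input (size == -1 means up to the end).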
1821         ShapeHandle input = c->input_tensors_as_shapes_to_propagate[0];
1822         bool valid = ic->RankKnown(input);
1823         const Tensor* slice_offset = ic->input_tensor(1);
1824         valid &= slice_offset != nullptr && slice_offset->NumElements() == 1;
1825         const Tensor* slice_size = ic->input_tensor(2);
1826         valid &= slice_size != nullptr && slice_size->NumElements() == 1;
1827         if (valid) {
1828           int64_t start = slice_offset->dtype() == DT_INT32
1829                               ? slice_offset->flat<int32>()(0)
1830                               : slice_offset->flat<int64_t>()(0);
1831           int64_t size = (slice_size->dtype() == DT_INT32
1832                               ? slice_size->flat<int32>()(0)
1833                               : slice_size->flat<int64_t>()(0));
1834           ShapeHandle result;
1835           if (size == -1) {
1836             TF_RETURN_IF_ERROR(ic->Subshape(input, start, &result));
1837           } else {
1838             int64_t end = start + size;
1839             TF_RETURN_IF_ERROR(ic->Subshape(input, start, end, &result));
1840           }
1841           c->output_tensors_as_shapes.resize(1);
1842           c->output_tensors_as_shapes[0] = result;
1843         }
1844       } else if (IsStridedSlice(node)) {
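        // StridedSlice of a shape vector: only handle the simple case with
        // scalar constant begin/end/stride, no ellipsis/new_axis/shrink_axis
        // masks, and begin/end masks limited to 0 or 1.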
1845         ShapeHandle input = c->input_tensors_as_shapes_to_propagate[0];
1846         bool valid = ic->RankKnown(input);
1847         const Tensor* slice_begin = ic->input_tensor(1);
1848         valid &= slice_begin != nullptr && slice_begin->NumElements() == 1;
1849         const Tensor* slice_end = ic->input_tensor(2);
1850         valid &= slice_end != nullptr && slice_end->NumElements() == 1;
1851         const Tensor* slice_stride = ic->input_tensor(3);
1852         valid &= slice_stride != nullptr && slice_stride->NumElements() == 1;
1853 
1854         if (node.attr().count("ellipsis_mask") > 0 &&
1855             node.attr().at("ellipsis_mask").i() != 0) {
1856           valid = false;
1857         }
1858         if (node.attr().count("new_axis_mask") > 0 &&
1859             node.attr().at("new_axis_mask").i() != 0) {
1860           valid = false;
1861         }
1862         if (node.attr().count("shrink_axis_mask") > 0 &&
1863             node.attr().at("shrink_axis_mask").i() != 0) {
1864           valid = false;
1865         }
1866         int begin_mask = 0;
1867         if (node.attr().count("begin_mask") > 0) {
1868           begin_mask = node.attr().at("begin_mask").i();
1869         }
1870         int end_mask = 0;
1871         if (node.attr().count("end_mask") > 0) {
1872           end_mask = node.attr().at("end_mask").i();
1873         }
1874         if (begin_mask < 0 || begin_mask > 1 || end_mask < 0 || end_mask > 1) {
1875           valid = false;
1876         }
1877         if (valid) {
1878           int64_t begin = 0;
1879           if (begin_mask == 0) {
1880             begin = slice_begin->dtype() == DT_INT32
1881                         ? slice_begin->flat<int32>()(0)
1882                         : slice_begin->flat<int64_t>()(0);
1883           }
1884           int64_t end = std::numeric_limits<int64_t>::max();
1885           if (end_mask == 0) {
1886             end = (slice_end->dtype() == DT_INT32
1887                        ? slice_end->flat<int32>()(0)
1888                        : slice_end->flat<int64_t>()(0));
1889           }
1890           int64_t stride = slice_stride->dtype() == DT_INT32
1891                                ? slice_stride->flat<int32>()(0)
1892                                : slice_stride->flat<int64_t>()(0);
1893           ShapeHandle result;
1894           TF_RETURN_IF_ERROR(ic->Subshape(input, begin, end, stride, &result));
1895           c->output_tensors_as_shapes.resize(1);
1896           c->output_tensors_as_shapes[0] = result;
1897         }
1898       }
1899     }
1900 
1901     if (aggressive_shape_inference_) {
1902       // Update output shapes with annotated information. This is optional.
1903       UpdateOutputShapesUsingAnnotatedInformation(node, c).IgnoreError();
1904 
1905       // Update output tensor values using EvaluateNode() if we can.
1906       // Due to the cost of EvaluateNode(), we run it only for certain op types
1907       // (allowlisted) and small integer tensors.
1908 
1909       const int max_element_size = 17;  // Max up to 4x4 matrix or similar.
1910       if (AllOutputValuesKnown(c) || !AllInputValuesKnown(c) ||
1911           !ShouldUpdateOutputShapesAndValues(c, max_element_size)) {
1912         return OkStatus();
1913       }
1914       UpdateOutputShapesAndValues(node, c).IgnoreError();  // This is optional.
1915     }
1916     return OkStatus();
1917   }
1918 
1919   Status InferShapes(const NodeDef& node, NodeContext* c) {
1920     // Infer the shapes of output tensors.
1921     if (!c->op_data || c->op_data->shape_inference_fn == nullptr ||
1922         !c->inference_context->Run(c->op_data->shape_inference_fn).ok()) {
1923       // Annotate outputs with unknown shapes. Update output shapes with
1924       // annotated information later on if available.
1925       // Note that shape inference function may return an error, but we ignore
1926       // it, and use UnknownShape in that case.
1927       TF_RETURN_IF_ERROR(
1928           c->inference_context->Run(shape_inference::UnknownShape));
1929     }
1930     Status status = OkStatus();
1931     auto it = fed_ports_.find(node.name());
1932     const bool is_fed = it != fed_ports_.end();
1933     if (is_fed) {
1934       // It is possible to feed node output ports with tensors of any shape: as
1935       // a result, the shape of a fed port is completely unknown.
1936       for (const int output_port : it->second) {
1937         status.Update(SetUnknownShape(&node, output_port));
1938       }
1939     }
1940 
1941     // Update NodeContext output fields after shape inference function runs.
1942     status.Update(MaybeUpdateNodeContextOutput(node, is_fed, c));
1943 
1944     return status;
1945   }
1946 
1947  private:
1948   bool IsIntegerVector(const Tensor& tensor) {
1949     if (tensor.dims() == 1 &&
1950         (tensor.dtype() == DT_INT32 || tensor.dtype() == DT_INT64)) {
1951       return true;
1952     }
1953     return false;
1954   }
1955 
1956   bool IsIntegerScalar(const Tensor& tensor) {
1957     if (tensor.dims() == 0 &&
1958         (tensor.dtype() == DT_INT32 || tensor.dtype() == DT_INT64) &&
1959         tensor.NumElements() == 1) {
1960       return true;
1961     }
1962     return false;
1963   }
1964 
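  // Build a scalar TensorProto of the given integer dtype (DT_INT32 or
  // DT_INT64) holding `val`.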
1965   TensorProto MakeIntegerScalarTensorProto(const DataType dtype,
1966                                            const int64_t val) {
1967     TensorProto tensor_proto;
1968     tensor_proto.set_dtype(dtype);
1969     // Scalar TensorProto has an empty tensor_shape; no dim, no dim.size.
1970     tensor_proto.mutable_tensor_shape();
1971     if (dtype == DT_INT32) {
1972       tensor_proto.add_int_val(val);
1973     } else if (dtype == DT_INT64) {
1974       tensor_proto.add_int64_val(val);
1975     }
1976     return tensor_proto;
1977   }
1978 
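  // Try to interpret a small integer scalar or vector TensorProto as a shape.
  // Returns true and fills *tensors_as_shapes on success.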
1979   bool MaybeTensorProtoToShape(InferenceContext* ic,
1980                                const TensorProto& tensor_proto,
1981                                ShapeHandle* tensors_as_shapes) {
1982     // Skip if dtype is not integer.
1983     if (tensor_proto.dtype() != DT_INT32 && tensor_proto.dtype() != DT_INT64) {
1984       return false;
1985     }
1986     // Skip if the const tensor is too large.
1987     if (NumElementsFromTensorProto(tensor_proto) >
1988         kThresholdToSkipConstTensorInstantiation) {
1989       return false;
1990     }
1991     // Skip if shape is neither scalar nor vector.
1992     if (tensor_proto.tensor_shape().unknown_rank() ||
1993         tensor_proto.tensor_shape().dim_size() > 1) {
1994       return false;
1995     }
1996     Tensor tensor;
1997     if (!tensor.FromProto(tensor_proto)) {
1998       return false;
1999     }
2000     return MaybeTensorValueToShape(ic, tensor, tensors_as_shapes);
2001   }
2002 
2003   bool MaybeTensorValueToShape(InferenceContext* ic, const Tensor& tensor,
2004                                ShapeHandle* tensors_as_shapes) {
2005     // Integer tensors of rank one can also be interpreted as a shape
2006     // provided all their values are >= -1.
2007 
2008     if (IsIntegerVector(tensor)) {
2009       bool has_values_smaller_than_minus_1 = false;
2010       std::vector<DimensionHandle> dims;
2011       for (int i = 0; i < tensor.NumElements(); i++) {
2012         int64_t value = tensor.dtype() == DT_INT32 ? tensor.flat<int32>()(i)
2013                                                    : tensor.flat<int64_t>()(i);
2014         has_values_smaller_than_minus_1 |= (value < -1);
2015         // Mark this as UnknownDim from Const.
2016         dims.push_back(value < 0 ? ic->MakeDim(kUnknownDimFromConst)
2017                                  : ic->MakeDim(value));
2018       }
2019 
2020       if (!has_values_smaller_than_minus_1) {
2021         *tensors_as_shapes = ic->MakeShape(dims);
2022         return true;
2023       }
2024     } else if (IsIntegerScalar(tensor)) {
2025       // Scalar constant.
2026       int64_t value = tensor.dtype() == DT_INT32 ? tensor.flat<int32>()(0)
2027                                                  : tensor.flat<int64_t>()(0);
2028       if (value == -1) {
2029         // Scalar value -1 represents an unknown shape. If we tried to call
2030         // MakeShape(MakeDim) with it, we would get a vector of unknown size.
2031         *tensors_as_shapes = ic->UnknownShape();
2032         return true;
2033       } else if (value >= 0) {
2034         // Ideally we would handle values < -1 as well, but MakeDim() fails for
2035         // any value < -1. It's a limitation of using ShapeHandle to pass values.
2036         *tensors_as_shapes = ic->MakeShape({ic->MakeDim(value)});
2037         return true;
2038       }
2039     }
2040     return false;
2041   }
2042 
2043   const GraphView& graph_;
2044   int graph_def_version_;
2045   absl::flat_hash_map<const NodeDef*, NodeContext> node_to_context_;
2046   absl::flat_hash_map<ShapeId, ShapeHandle, HashShapeId> unknown_shapes_;
2047   absl::flat_hash_map<DimId, DimensionHandle, HashDimId> unknown_dims_;
2048   // Store function instantiations only for valid functions. If function
2049   // instantiation failed it will have an `absl::nullopt`.
2050   absl::flat_hash_map<string, absl::optional<GrapplerFunctionItem>>
2051       fun_to_grappler_function_item_;
2052   FunctionLibraryDefinition function_library_;
2053   const absl::flat_hash_map<string, absl::flat_hash_set<int>>& fed_ports_;
2054   // Store TensorProtos for tensor value propagation. Note that we use deque,
2055   // not vector, as we use pointers to the TensorProtos in this container.
2056   // A vector may resize and copy the objects into a new buffer, leaving the
2057   // existing pointers dangling.
2058   std::deque<TensorProto> const_tensors_to_propagate_;
2059 
2060   // For more aggressive shape and value inference.
2061   bool aggressive_shape_inference_;
2062   ResourceMgr resource_mgr_;
2063 };
2064 
2065 // Keep track of shapes and dimensions in a graph.
2066 // In particular, use disjoint sets to track equivalence between shapes and
2067 // dims, and consolidate the information globally.
2068 class SymbolicShapeManager {
2069  public:
2070   SymbolicShapeManager() {}
2071 
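  // Merge two shapes into the same equivalence class. If both ranks are known
  // they must match, and the corresponding dimensions are merged as well.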
2072   Status Merge(ShapeHandle s1, ShapeHandle s2) {
2073     if (!s1.IsSet() || !s2.IsSet()) {
2074       return OkStatus();
2075     }
2076     TF_RETURN_IF_ERROR(shapes_.Merge(s1, s2));
2077     if (InferenceContext::Rank(s1) > 0 && InferenceContext::Rank(s2) > 0) {
2078       CHECK_EQ(InferenceContext::Rank(s1), InferenceContext::Rank(s2));
2079       for (int i = 0; i < InferenceContext::Rank(s1); ++i) {
2080         TF_RETURN_IF_ERROR(dims_.Merge(InferenceContext::DimKnownRank(s1, i),
2081                                        InferenceContext::DimKnownRank(s2, i)));
2082       }
2083     }
2084     return OkStatus();
2085   }
2086   Status Merge(DimensionHandle d1, DimensionHandle d2) {
2087     if (!d1.IsSet() || !d2.IsSet()) {
2088       return OkStatus();
2089     }
2090     return dims_.Merge(d1, d2);
2091   }
2092 
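  // Fill in `properties` with the given dtype and the merged (canonical)
  // value of `shape`, resolving each dimension through the merged dims.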
2093   void AsTensorProperties(const ShapeHandle& shape, const DataType& type,
2094                           OpInfo::TensorProperties* properties) {
2095     properties->set_dtype(type);
2096     ShapeHandle actual_shape = shapes_.GetMergedValue(shape);
2097     if (!InferenceContext::RankKnown(actual_shape)) {
2098       properties->mutable_shape()->set_unknown_rank(true);
2099     } else {
2100       for (int j = 0; j < InferenceContext::Rank(actual_shape); ++j) {
2101         shape_inference::DimensionHandle dim =
2102             InferenceContext::DimKnownRank(actual_shape, j);
2103         int64_t d = dims_.GetMergedValue(dim);
2104         properties->mutable_shape()->add_dim()->set_size(d);
2105       }
2106     }
2107   }
2108 
2109   // Returns merged shape with merged dimensions.
2110   ShapeHandle GetMergedShape(InferenceContext* ic, ShapeHandle s) {
2111     const auto& actual_shape = shapes_.GetMergedValue(s);
2112     if (!InferenceContext::RankKnown(actual_shape)) {
2113       return ic->UnknownShape();
2114     } else {
2115       std::vector<DimensionHandle> dims;
2116       for (int j = 0; j < InferenceContext::Rank(actual_shape); ++j) {
2117         shape_inference::DimensionHandle dim =
2118             InferenceContext::DimKnownRank(actual_shape, j);
2119         int64_t d = dims_.GetMergedValue(dim);
2120         // The symbolic shape manager may have made some dims < -1, which causes
2121         // errors when creating a Dimension.
2122         if (d < -1) {
2123           d = -1;
2124         }
2125         dims.push_back(ic->MakeDim(d));
2126       }
2127       return ic->MakeShape(dims);
2128     }
2129   }
2130 
2131  private:
2132   DisjointSet<shape_inference::ShapeHandle> shapes_;
2133   DisjointSet<shape_inference::DimensionHandle> dims_;
2134 };
2135 
2136 // Checks whether there is any conflict in merged shapes and dims in
2137 // SymbolicShapeManager.
2138 Status ValidateSymbolicShapeManager(const GraphDef& graph_def,
2139                                     SymbolicShapeRefiner* refiner,
2140                                     SymbolicShapeManager* shape_manager) {
2141   if (!VLOG_IS_ON(1)) {
2142     return OkStatus();
2143   }
2144 
2145   VLOG(1) << "Checking any conflicts in shapes and dimensions ...";
2146   int64_t num_incompatible_shapes = 0;
2147   for (const NodeDef& node : graph_def.node()) {
2148     auto ctx = refiner->GetNodeContext(&node);
2149     if (!ctx) {
2150       continue;
2151     }
2152     auto* ic = ctx->inference_context.get();
2153     for (int i = 0; i < ic->num_inputs(); ++i) {
2154       const auto& shape = ic->input(i);
2155       const auto& merged_shape = shape_manager->GetMergedShape(ic, shape);
2156       if (!refiner->CompatibleShapes(shape, merged_shape)) {
2157         num_incompatible_shapes++;
2158         VLOG(1) << "**** Incompatible shape from SymbolicShapeManager "
2159                 << "for node " << node.name() << " input (" << i << ") "
2160                 << ic->DebugString(shape)
2161                 << " vs. merged: " << ic->DebugString(merged_shape);
2162       }
2163     }
2164     for (int i = 0; i < ic->num_outputs(); ++i) {
2165       const auto& shape = ic->output(i);
2166       const auto& merged_shape = shape_manager->GetMergedShape(ic, shape);
2167       if (!refiner->CompatibleShapes(shape, merged_shape)) {
2168         num_incompatible_shapes++;
2169         VLOG(1) << "**** Incompatible shape from SymbolicShapeManager "
2170                 << "for node " << node.name() << " output (" << i << ") "
2171                 << ic->DebugString(shape)
2172                 << " vs. merged: " << ic->DebugString(merged_shape);
2173       }
2174     }
2175   }
2176   if (num_incompatible_shapes > 0) {
2177     VLOG(1) << "**** WARNING: " << num_incompatible_shapes
2178             << " incompatible shapes from SymbolicShapeManager.";
2179   } else {
2180     VLOG(1) << "**** No incompatible shape found from SymbolicShapeManager.";
2181   }
2182 
2183   return OkStatus();
2184 }
2185 
2186 // Log shape inference and its merged shapes.
2187 Status VerboseShapeInferenceLogging(const GraphDef& graph_def,
2188                                     SymbolicShapeRefiner* refiner,
2189                                     SymbolicShapeManager* shape_manager) {
2190   // As logging all the nodes would generate too many lines, we skip this
2191   // detailed logging by default. Users may add nodes of interest to
2192   // node_names_for_logging to enable detailed logging.
2193   absl::flat_hash_set<std::string> node_names_for_logging = {};
2194   if (!VLOG_IS_ON(3) || node_names_for_logging.empty()) {
2195     return OkStatus();
2196   }
2197 
2198   auto should_log = [&node_names_for_logging](std::string node_name) {
2199     return node_names_for_logging.find(node_name) !=
2200            node_names_for_logging.end();
2201   };
2202 
2203   for (const NodeDef& node : graph_def.node()) {
2204     if (!should_log(node.name())) {
2205       continue;
2206     }
2207     auto ctx = refiner->GetNodeContext(&node);
2208     if (!ctx) {
2209       continue;
2210     }
2211     auto* ic = ctx->inference_context.get();
2212     VLOG(3) << "Shape inference for node : " << node.name();
2213     VLOG(3) << ctx->DebugString(node);
2214     std::string merged_shapes = "Merged shapes from SymbolicShapeManager:\n";
2215     for (int i = 0; i < ic->num_inputs(); ++i) {
2216       absl::StrAppend(
2217           &merged_shapes, " input[", i, "] -- ",
2218           ic->DebugString(shape_manager->GetMergedShape(ic, ic->input(i))),
2219           "\n");
2220     }
2221     for (int i = 0; i < ic->num_outputs(); ++i) {
2222       absl::StrAppend(
2223           &merged_shapes, " output[", i, "] -- ",
2224           ic->DebugString(shape_manager->GetMergedShape(ic, ic->output(i))),
2225           "\n");
2226     }
2227     VLOG(3) << merged_shapes;
2228     VLOG(3) << "--------------------------------";
2229     VLOG(3) << "";
2230   }
2231 
2232   return OkStatus();
2233 }
2234 
2235 Status GraphProperties::RelaxEnqueueShapesAndMergeTypes(
2236     SymbolicShapeRefiner* shape_refiner, const NodeDef* qnode,
2237     const std::vector<ShapeAndType>& shapes_and_types,
2238     std::vector<ShapeAndType>* queue_shapes_and_types) {
2239   if (shapes_and_types.size() != queue_shapes_and_types->size()) {
2240     return errors::InvalidArgument(
2241         "Enqueue nodes mixed number of tensors: ", shapes_and_types.size(),
2242         "  vs ", queue_shapes_and_types->size());
2243   }
2244   for (size_t i = 0; i < shapes_and_types.size(); ++i) {
2245     const ShapeAndType& a = shapes_and_types[i];
2246     ShapeAndType& b = (*queue_shapes_and_types)[i];
2247     if (a.dtype != b.dtype) {
2248       return errors::InvalidArgument("Enqueue nodes mixed dtypes for tensor ",
2249                                      i, ": ", DataTypeString(a.dtype), " vs ",
2250                                      DataTypeString(b.dtype));
2251     }
2252 
2253     b.shape = shape_refiner->OutputAsUnion(qnode, i, a.shape, b.shape);
2254   }
2255   return OkStatus();
2256 }
2257 
2258 // Compute the output shape of the merge node as the union of the available
2259 // input shapes.
2260 Status GraphProperties::UpdateMerge(SymbolicShapeRefiner* shape_refiner,
2261                                     const NodeDef* node,
2262                                     bool* new_shapes) const {
2263   InferenceContext* ic = shape_refiner->GetContext(node);
2264   if (!ic) {
2265     // Now we can run shape inference
2266     TF_RETURN_IF_ERROR(shape_refiner->AddNode(node));
2267     ic = CHECK_NOTNULL(shape_refiner->GetContext(node));
2268     *new_shapes = true;
2269 
2270     // Infer the shape of the second output once and for all since it never
2271     // changes.
2272     ShapeHandle out1 = ic->Scalar();
2273     if (ic->num_outputs() >= 2) ic->set_output(1, out1);
2274   }
2275 
2276   ShapeHandle out;
2277   const std::vector<ShapeAndType>* out_handle = nullptr;
2278   bool out_initialized = false;
2279   for (const GraphView::Edge fanin : shape_refiner->graph().GetFaninEdges(
2280            *node, /*include_controlling_edges=*/false)) {
2281     InferenceContext* src_ic = shape_refiner->GetContext(fanin.src.node);
2282     if (!src_ic) {
2283       // Handling a loop for the first time, the back edge won't have any shape
2284       // info.
2285       continue;
2286     }
2287     ShapeHandle input = src_ic->output(fanin.src.port_id);
2288     ic->SetInput(fanin.dst.port_id, input);
2289     auto* input_handle =
2290         src_ic->output_handle_shapes_and_types(fanin.src.port_id);
2291     if (input_handle)
2292       ic->set_input_handle_shapes_and_types(fanin.dst.port_id, *input_handle);
2293     if (!out_initialized) {
2294       out_initialized = true;
2295       out = input;
2296       out_handle = input_handle;
2297     } else {
2298       // Note here only out, not out_handle, is modified.
2299       out = shape_refiner->OutputAsUnion(node, 0, input, out);
2300     }
2301   }
2302 
2303   if (*new_shapes || !shape_refiner->EquivalentShapes(out, ic->output(0))) {
2304     ic->set_output(0, out);
2305     if (out_handle) ic->set_output_handle_shapes_and_types(0, *out_handle);
2306     *new_shapes = true;
2307   }
2308 
2309   return OkStatus();
2310 }
2311 
2312 // Manually propagate the input shape for Enter nodes.
2313 Status GraphProperties::UpdateEnter(SymbolicShapeRefiner* shape_refiner,
2314                                     const NodeDef* node, bool* new_shapes) {
2315   InferenceContext* ic = shape_refiner->GetContext(node);
2316   if (!ic) {
2317     TF_RETURN_IF_ERROR(shape_refiner->UpdateNode(node, new_shapes));
2318     ic = shape_refiner->GetContext(node);
2319   }
2320 
2321   GraphView::InputPort port(node, 0);
2322   GraphView::OutputPort fanin = shape_refiner->graph().GetRegularFanin(port);
2323 
2324   InferenceContext* src_ic = shape_refiner->GetContext(fanin.node);
2325   ShapeHandle input = src_ic->output(fanin.port_id);
2326   if (!ic->output(0).SameHandle(input)) {
2327     ic->SetInput(0, input);
2328     ic->set_output(0, input);
2329     *new_shapes = true;
2330   }
2331   auto* outputs = src_ic->output_handle_shapes_and_types(fanin.port_id);
2332   if (outputs) {
2333     ic->set_input_handle_shapes_and_types(0, *outputs);
2334     ic->set_output_handle_shapes_and_types(0, *outputs);
2335     *new_shapes = true;
2336   }
2337   return OkStatus();
2338 }
2339 
2340 Status GraphProperties::UpdateShapes(
2341     SymbolicShapeRefiner* shape_refiner,
2342     const absl::flat_hash_map<const NodeDef*, const NodeDef*>& resource_handles,
2343     const NodeDef* n, bool* new_shapes) const {
2344   if (IsEnter(*n)) {
2345     // The Enter shape function always forwards an UnknownShape, so do the right
2346     // thing here.
2347     TF_RETURN_IF_ERROR(UpdateEnter(shape_refiner, n, new_shapes));
2348   } else if (IsMerge(*n)) {
2349     // Properly handle merge nodes.
2350     TF_RETURN_IF_ERROR(UpdateMerge(shape_refiner, n, new_shapes));
2351   } else if (IsEnqueue(*n)) {
2352     // Make sure the shapes of enqueued tensors are propagated to the queue
2353     // itself.
2354     TF_RETURN_IF_ERROR(
2355         UpdateEnqueue(n, resource_handles, shape_refiner, new_shapes));
2356   } else if (IsQueue(*n)) {
2357     // Set shapes and types of Queue ops, if needed.
2358     TF_RETURN_IF_ERROR(UpdateQueue(n, shape_refiner, new_shapes));
2359   } else {
2360     // Rely on regular TF shape refinement for all the other nodes.
2361     // UpdateNode calls UpdateFunction if a function node is detected.
2362     TF_RETURN_IF_ERROR(shape_refiner->UpdateNode(n, new_shapes));
2363   }
2364 
2365   return OkStatus();
2366 }
2367 
2368 // Propagates the shapes in the transitive fan-out of <new_shapes>.
2369 Status GraphProperties::PropagateShapes(
2370     SymbolicShapeRefiner* shape_refiner, TopoQueue* new_shapes,
2371     const absl::flat_hash_map<const NodeDef*, const NodeDef*>& resource_handles,
2372     int num_loops) const {
2373   // Limit the number of iterations to prevent infinite loops in the presence of
2374   // incorrect shape functions. The algorithm should converge in at most
2375   // num_nested_loops^2 * max_rank. We approximate max_rank with the constant 4.
2376   // The same applies to resources.
2377   VLOG(1) << "Propagating " << new_shapes->size() << " new shapes through "
2378           << num_loops << " loops and " << resource_handles.size()
2379           << " resources" << std::endl;
2380 
2381   const int64_t max_loop_length = item_.graph.node_size();
2382   const int64_t max_rank = 4;
2383   const int64_t max_loop_iterations =
2384       max_rank * max_loop_length * std::max<int64_t>(1, num_loops * num_loops);
2385   const int64_t num_queues = resource_handles.size();
2386   const int64_t max_resource_iterations = num_queues * num_queues * max_rank;
2387 
2388   int64_t num_resource_iterations = 0;
2389   do {
2390     int64_t num_loop_iterations = 0;
2391     while (!new_shapes->empty() &&
2392            num_loop_iterations++ < max_loop_iterations) {
2393       const NodeDef* n = new_shapes->pop();
2394       bool updated = false;
2395       TF_RETURN_IF_ERROR(
2396           UpdateShapes(shape_refiner, resource_handles, n, &updated));
2397       if (updated) {
2398         for (const auto& fanout : shape_refiner->graph().GetFanouts(
2399                  *n, /*include_controlled_nodes=*/false)) {
2400           new_shapes->push(fanout.node);
2401         }
2402         // Make sure the corresponding queue nodes are (re)processed.
2403         if (IsEnqueue(*n)) {
2404           auto it = resource_handles.find(n);
2405           if (it != resource_handles.end()) {
2406             new_shapes->push(it->second);
2407           }
2408         }
2409       }
2410     }
2411   } while (!new_shapes->empty() &&
2412            num_resource_iterations++ < max_resource_iterations);
2413 
2414   if (!new_shapes->empty()) {
2415     return errors::Internal("Shape inference failed to converge");
2416   }
2417 
2418   return OkStatus();
2419 }
2420 
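// Set the shapes and types of a Queue node from its "shapes" and
// "component_types" attrs, unless they were already set by Enqueue ops.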
2421 Status GraphProperties::UpdateQueue(const NodeDef* queue_node,
2422                                     SymbolicShapeRefiner* shape_refiner,
2423                                     bool* new_shapes) {
2424   auto* ctx = shape_refiner->GetNodeContext(queue_node);
2425   if (!ctx) {
2426     TF_RETURN_IF_ERROR(shape_refiner->AddNode(queue_node));
2427     ctx = CHECK_NOTNULL(shape_refiner->GetNodeContext(queue_node));
2428   }
2429   auto* ic = ctx->inference_context.get();
2430 
2431   auto* outputs = ic->output_handle_shapes_and_types(0);
2432   if (outputs) {
2433     // Shapes and types are already set, presumably by Enqueue ops.
2434     return shape_refiner->UpdateNode(queue_node, new_shapes);
2435   }
2436 
2437   if (queue_node->attr().count("shapes") <= 0 ||
2438       queue_node->attr().count("component_types") <= 0 ||
2439       queue_node->attr().at("shapes").list().shape_size() !=
2440           queue_node->attr().at("component_types").list().type_size()) {
2441     // Errors in shapes and component_types attr.
2442     return shape_refiner->UpdateNode(queue_node, new_shapes);
2443   }
2444 
2445   // Extract types and shapes from Queue attr.
2446   const auto& shapes = queue_node->attr().at("shapes").list().shape();
2447   const auto& types = queue_node->attr().at("component_types").list().type();
2448   std::vector<ShapeAndType> shapes_and_types;
2449   for (int i = 0; i < types.size(); i++) {
2450     const auto& shape = shapes[i];
2451     ShapeHandle shape_handle;
2452     TF_RETURN_IF_ERROR(
2453         ic->MakeShapeFromPartialTensorShape(shape, &shape_handle));
2454     DataType data_type =
2455         queue_node->attr().at("component_types").list().type(i);
2456     ShapeAndType shape_and_type(shape_handle, data_type);
2457     shapes_and_types.push_back(shape_and_type);
2458   }
2459   ic->set_output_handle_shapes_and_types(0, shapes_and_types);
2460 
2461   // Queue node is updated with output_handle_shapes_and_types, so set
2462   // new_shapes here and ignore the new_shapes result from UpdateNode().
2463   *new_shapes = true;
2464   bool dummy_new_shapes = false;
2465   return shape_refiner->UpdateNode(queue_node, &dummy_new_shapes);
2466 }
2467 
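// Propagate the shapes and types of enqueued tensors to the queue's output
// handle, relaxing them against any previously recorded shapes and types.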
2468 Status GraphProperties::UpdateEnqueue(
2469     const NodeDef* enqueue_node,
2470     const absl::flat_hash_map<const NodeDef*, const NodeDef*>& resource_handles,
2471     SymbolicShapeRefiner* shape_refiner, bool* new_shapes) {
2472   auto ctx = shape_refiner->GetNodeContext(enqueue_node);
2473   if (!ctx) {
2474     TF_RETURN_IF_ERROR(shape_refiner->AddNode(enqueue_node));
2475     ctx = CHECK_NOTNULL(shape_refiner->GetNodeContext(enqueue_node));
2476   }
2477 
2478   auto it = resource_handles.find(enqueue_node);
2479   if (it == resource_handles.end()) {
2480     // The corresponding queue was not found, there isn't much we can do.
2481     return OkStatus();
2482   }
2483   const NodeDef* qnode = it->second;
2484   auto qctx = shape_refiner->GetContext(qnode);
2485   if (!qctx) {
2486     return OkStatus();
2487   }
2488   auto* queue_handle_data = qctx->output_handle_shapes_and_types(0);
2489 
2490   // TODO(bsteiner): handle EnqueueMany as well.
2491   std::vector<ShapeAndType> shapes_and_types;
2492   for (int i = 1, end = ctx->input_types.size(); i < end; ++i) {
2493     GraphView::InputPort inp(enqueue_node, i);
2494     GraphView::OutputPort fanin = shape_refiner->graph().GetRegularFanin(inp);
2495     InferenceContext* in = shape_refiner->GetContext(fanin.node);
2496     ShapeHandle input = in->output(fanin.port_id);
2497     ctx->inference_context->SetInput(i, input);
2498     shapes_and_types.push_back({input, ctx->input_types[i]});
2499   }
2500 
2501   if (queue_handle_data == nullptr) {
2502     qctx->set_output_handle_shapes_and_types(0, shapes_and_types);
2503     *new_shapes = true;
2504   } else {
2505     TF_RETURN_IF_ERROR(RelaxEnqueueShapesAndMergeTypes(
2506         shape_refiner, qnode, *queue_handle_data, &shapes_and_types));
2507     *new_shapes |= !shape_refiner->EquivalentShapesAndTypes(*queue_handle_data,
2508                                                             shapes_and_types);
2509     qctx->set_output_handle_shapes_and_types(0, shapes_and_types);
2510   }
2511 
2512   return OkStatus();
2513 }
2514 
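// Statically infer output shapes (and, optionally, tensor values) for every
// node in the graph without executing it.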
2515 Status GraphProperties::InferStatically(bool assume_valid_feeds,
2516                                         bool aggressive_shape_inference,
2517                                         bool include_input_tensor_values,
2518                                         bool include_output_tensor_values) {
2519   FunctionLibraryDefinition function_library(OpRegistry::Global(),
2520                                              item_.graph.library());
2521   absl::flat_hash_map<string, absl::flat_hash_set<int>> fed_ports;
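  // Unless feeds are assumed to produce tensors of the expected shapes, record
  // the fed output ports so that shapes flowing from them are not trusted
  // during inference.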
2522   if (!assume_valid_feeds) {
2523     for (const auto& feed : item_.feed) {
2524       SafeTensorId tensor_id = ParseTensorName(feed.first);
2525       fed_ports[tensor_id.node()].insert(tensor_id.index());
2526     }
2527   }
2528 
2529   GraphView graph_view(&item_.graph);
2530 
2531   // List the resources and the nodes using them. Also collect the Merge nodes,
2532   // fed nodes, and primary inputs.
2533   absl::flat_hash_map<const NodeDef*,
2534                       std::pair<absl::flat_hash_set<const NodeDef*>,
2535                                 absl::flat_hash_set<const NodeDef*>>>
2536       resources;
2537   absl::flat_hash_set<const NodeDef*> merge_nodes;
2538   absl::flat_hash_set<const NodeDef*> fed_nodes;
2539   absl::flat_hash_set<const NodeDef*> primary_inputs;
2540   int num_loops = 0;
2541   for (const NodeDef& node : item_.graph.node()) {
2542     if (IsQueue(node)) {
2543       for (const GraphView::InputPort& fanout :
2544            graph_view.GetFanouts(node, false)) {
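        // A queue handle may be passed into a loop body through an Enter node;
        // look through the Enter node to find the enqueue/dequeue ops that
        // actually use the queue.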
2545         if (IsEnter(*fanout.node)) {
2546           const NodeDef& enter = *fanout.node;
2547           for (const GraphView::InputPort& fanout :
2548                graph_view.GetFanouts(enter, false)) {
2549             if (IsEnqueue(*fanout.node)) {
2550               resources[&node].first.insert(fanout.node);
2551             } else if (IsDequeue(*fanout.node)) {
2552               resources[&node].second.insert(fanout.node);
2553             }
2554           }
2555         } else {
2556           if (IsEnqueue(*fanout.node)) {
2557             resources[&node].first.insert(fanout.node);
2558           } else if (IsDequeue(*fanout.node)) {
2559             resources[&node].second.insert(fanout.node);
2560           }
2561         }
2562       }
2563     }
2564     if (!HasRegularInputs(node)) {
2565       primary_inputs.insert(&node);
2566     } else if (IsMerge(node)) {
2567       merge_nodes.insert(&node);
2568     } else if (IsNextIteration(node)) {
2569       ++num_loops;
2570     }
2571     if (fed_ports.find(node.name()) != fed_ports.end()) {
2572       fed_nodes.insert(&node);
2573     }
2574   }
2575 
2576   absl::flat_hash_map<const NodeDef*, const NodeDef*> resource_handles;
2577   std::vector<TopologicalDependency> extra_deps;
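  // Map each enqueue op to the queue it feeds; UpdateEnqueue later uses this
  // mapping to propagate the enqueued component shapes to the queue's handle.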
2578   for (const auto& resource : resources) {
2579     for (const NodeDef* src : resource.second.first) {
2580       resource_handles[src] = resource.first;
2581       for (const NodeDef* dst : resource.second.second) {
2582         // Add control edges from enqueue to dequeue nodes to ensure they are
2583         // processed in their logical order.
2584         extra_deps.emplace_back(src, dst);
2585       }
2586     }
2587   }
2588 
2589   std::vector<const NodeDef*> topo_order;
2590   Status s = ComputeTopologicalOrder(item_.graph, extra_deps, &topo_order);
2591   if (!s.ok()) {
2592     if (extra_deps.empty()) {
2593       return s;
2594     } else {
2595       // There is a loop between queues: we'll just use the graph topological
2596       // order. This will make the shape inference less precise, but since this
2597       // isn't common it's not worth the effort to figure out where to break the
2598       // loop and do a proper relaxation.
2599       TF_RETURN_IF_ERROR(ComputeTopologicalOrder(item_.graph, &topo_order));
2600     }
2601   }
2602 
2603   // Heap-allocate SymbolicShapeRefiner to avoid consuming a large amount of
2604   // stack space.
2605   auto refiner = std::make_unique<SymbolicShapeRefiner>(
2606       graph_view, fed_ports, aggressive_shape_inference);
2607 
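  // Process nodes in topological order; PropagateShapes pushes the fanout of
  // any node whose inferred output shapes change back onto the queue.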
2608   TopoQueue new_shapes(topo_order);
2609   // Also seed the propagation of shapes in the fanout of primary inputs.
2610   for (const NodeDef* node : primary_inputs) {
2611     new_shapes.push(node);
2612   }
2613   // Also seed the propagation of shapes in the fanout of fed nodes.
2614   for (const NodeDef* node : fed_nodes) {
2615     new_shapes.push(node);
2616   }
2617   // Propagate shapes normally.
2618   TF_RETURN_IF_ERROR(
2619       PropagateShapes(refiner.get(), &new_shapes, resource_handles, num_loops));
2620 
2621   // Track shapes globally across the graph.
2622   std::unique_ptr<SymbolicShapeManager> shape_manager =
2623       std::make_unique<SymbolicShapeManager>();
2624   bool found_error = false;
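  // Merge the per-node symbolic shapes and dims so that dims known to be
  // equal are unified across the whole graph.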
2625   for (const NodeDef& node : item_.graph.node()) {
2626     auto node_ctx = refiner->GetContext(&node);
2627     if (!node_ctx) {
2628       continue;
2629     }
2630     // Skip any information that comes from fed nodes.
2631     if (fed_ports.find(node.name()) != fed_ports.end()) {
2632       VLOG(2) << "Skipping feed node shape: " << node.name();
2633       continue;
2634     }
2635     for (const auto& merged_shapes : node_ctx->MergedShapes()) {
2636       if (!shape_manager->Merge(merged_shapes.first, merged_shapes.second)
2637                .ok()) {
2638         found_error = true;
2639         break;
2640       }
2641     }
2642     for (const auto& merged_dims : node_ctx->MergedDims()) {
2643       if (!shape_manager->Merge(merged_dims.first, merged_dims.second).ok()) {
2644         found_error = true;
2645         break;
2646       }
2647     }
2648     if (found_error) {
2649       // The shapes aren't consistent, we can't infer safely: discard all the
2650       // information discovered so far.
2651       shape_manager = std::make_unique<SymbolicShapeManager>();
2652       break;
2653     }
2654   }
2655 
2656   TF_RETURN_IF_ERROR(ValidateSymbolicShapeManager(item_.graph, refiner.get(),
2657                                                   shape_manager.get()));
2658 
2659   for (const NodeDef& node : item_.graph.node()) {
2660     VLOG(4) << "Filling in graph properties for node: " << node.name();
2661     auto ctx = refiner->GetNodeContext(&node);
2662     if (!ctx) {
2663       continue;
2664     }
2665 
2666     auto* ic = ctx->inference_context.get();
2667 
2668     // Fill input properties.
2669     {
2670       auto& input_properties = input_properties_[node.name()];
2671 
2672       // Should always be empty; node names in the graph are supposed to be unique.
2673       CHECK_EQ(input_properties.size(), 0);
2674 
2675       input_properties.resize(ic->num_inputs());
2676       GraphView::InputPort input(&node, -1);
2677       for (int i = 0; i < ic->num_inputs(); ++i) {
2678         shape_manager->AsTensorProperties(ic->input(i), ctx->input_types[i],
2679                                           &input_properties[i]);
2680         input.port_id = i;
2681         GraphView::OutputPort fanin = graph_view.GetRegularFanin(input);
2682         if (include_input_tensor_values) {
2683           // Export tensor value to input_properties.value.
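          // Prefer a literal Const fanin; otherwise fall back to a tensor
          // proto captured during inference, and finally to a fully defined
          // integer vector/scalar derived from input_tensors_as_shapes.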
2684           if (IsConstant(*fanin.node)) {
2685             const TensorProto& raw_val =
2686                 fanin.node->attr().at("value").tensor();
2687             *input_properties[i].mutable_value() = raw_val;
2688           } else if (static_cast<int>(ctx->input_tensor_protos.size()) > i &&
2689                      ctx->input_tensor_protos[i] != nullptr) {
2690             *input_properties[i].mutable_value() = *ctx->input_tensor_protos[i];
2691           } else if (static_cast<int>(ic->input_tensors_as_shapes().size()) >
2692                          i &&
2693                      IsShapeFullyDefinedIntegerVectorOrScalar(
2694                          ic, ic->input(i), ic->input_tensors_as_shapes()[i],
2695                          ctx->input_types[i])) {
2696             *input_properties[i].mutable_value() = MakeTensorProtoFromShape(
2697                 ic, ic->input(i), ic->input_tensors_as_shapes()[i],
2698                 ctx->input_types[i]);
2699           }
2700         }
2701       }
2702     }
2703 
2704     // Fill output properties.
2705     {
2706       auto& output_properties = output_properties_[node.name()];
2707 
2708       // Should always be empty; node names in the graph are supposed to be unique.
2709       CHECK_EQ(output_properties.size(), 0);
2710 
2711       output_properties.resize(ic->num_outputs());
2712       for (int i = 0; i < ic->num_outputs(); ++i) {
2713         shape_manager->AsTensorProperties(ic->output(i), ctx->output_types[i],
2714                                           &output_properties[i]);
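        // Dims recorded as kUnknownDimFromConst (a placeholder for unknown
        // dims that came from const inputs) are converted back to regular
        // unknown dims before any tensor value is exported.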
2715         auto converted_output_tensors_as_shapes =
2716             ReplaceUnknownDimFromConstWithUnknownDim(
2717                 ic, ctx->output_tensors_as_shapes);
2718         if (include_output_tensor_values) {
2719           // Export tensor value to output_properties.value.
2720           if (IsConstant(node)) {
2721             // TODO(rmlarsen): Eliminate this copy.
2722             const TensorProto& raw_val = node.attr().at("value").tensor();
2723             *output_properties[i].mutable_value() = raw_val;
2724           } else if (static_cast<int>(ctx->output_tensor_protos.size()) > i &&
2725                      ctx->output_tensor_protos[i] != nullptr) {
2726             *output_properties[i].mutable_value() =
2727                 *ctx->output_tensor_protos[i];
2728           } else if (static_cast<int>(
2729                          converted_output_tensors_as_shapes.size()) > i &&
2730                      IsShapeFullyDefinedIntegerVectorOrScalar(
2731                          ic, ic->output(i),
2732                          converted_output_tensors_as_shapes[i],
2733                          ctx->output_types[i])) {
2734             *output_properties[i].mutable_value() = MakeTensorProtoFromShape(
2735                 ic, ic->output(i), converted_output_tensors_as_shapes[i],
2736                 ctx->output_types[i]);
2737           }
2738         }
2739       }
2740     }
2741 
2742     if (aggressive_shape_inference && ctx->shape_incompatible)
2743       incompatible_shape_nodes_.insert(node.name());
2744   }
2745 
2746   if (aggressive_shape_inference && !incompatible_shape_nodes_.empty())
2747     LOG(WARNING) << incompatible_shape_nodes_.size()
2748                  << " nodes have incompatible output shapes.";
2749 
2750   // Help trace the unknown dimensions to their origins.
2751   VerboseLogUnknownDimensionSources(item_.graph, input_properties_,
2752                                     output_properties_);
2753 
2754   TF_RETURN_IF_ERROR(VerboseShapeInferenceLogging(item_.graph, refiner.get(),
2755                                                   shape_manager.get()));
2756 
2757   return OkStatus();
2758 }
2759 
2760 Status GraphProperties::InferDynamically(Cluster* cluster) {
2761   TF_RETURN_IF_ERROR(cluster->Initialize(item_));
2762 
2763   // Runs the model once to collect the shapes in the cost model.
2764   RunMetadata metadata;
2765   TF_RETURN_IF_ERROR(
2766       cluster->Run(item_.graph, item_.feed, item_.fetch, &metadata));
2767 
2768   return InferFromCostGraph(metadata.cost_graph());
2769 }
2770 
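// Copies the graph and stores each node's inferred output shapes in its
// "_output_shapes" attribute.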
2771 Status GraphProperties::AnnotateOutputShapes(GraphDef* output_graph_def) const {
2772   *output_graph_def = item_.graph;
2773   for (int i = 0; i < output_graph_def->node_size(); i++) {
2774     auto node = output_graph_def->mutable_node(i);
2775     AttrValue attr_output_shape;
2776     auto tensor_properties = GetOutputProperties(node->name());
2777     for (const auto& tensor_property : tensor_properties) {
2778       TensorShapeProto* proto = attr_output_shape.mutable_list()->add_shape();
2779       *proto = tensor_property.shape();
2780       NormalizeShapeForOutput(proto);
2781     }
2782     (*node->mutable_attr())["_output_shapes"] = std::move(attr_output_shape);
2783   }
2784   return OkStatus();
2785 }
2786 
2787 Status GraphProperties::InferFromCostGraph(const CostGraphDef& cost_graph) {
2788   if (cost_graph.node_size() == 0) {
2789     LOG(WARNING) << "cost_graph is empty: nothing can be inferred!";
2790   }
2791   std::unordered_map<string, const CostGraphDef::Node*> name_to_cost;
2792   std::unordered_map<string, const NodeDef*> name_to_node;  // Empty
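  // First pass over the cost graph: record each node and convert its measured
  // output_info into output properties.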
2793   for (auto& node : cost_graph.node()) {
2794     name_to_cost[node.name()] = &node;
2795 
2796     std::vector<OpInfo::TensorProperties> output_properties;
2797     for (const auto& out : node.output_info()) {
2798       OpInfo::TensorProperties properties;
2799       properties.set_dtype(out.dtype());
2800       *properties.mutable_shape() = out.shape();
2801       output_properties.push_back(properties);
2802     }
2803     output_properties_[node.name()] = output_properties;
2804   }
2805 
2806   for (const auto& node : item_.graph.node()) {
2807     // Skip the nodes that are not in the cost graph: these are nodes that
2808     // aren't run, because they aren't in the intersection of transitive fan-in
2809     // of a fetch node and the transitive fan-out of an input, or nodes that
2810     // were optimized away by the optimizer.
2811     auto it = name_to_cost.find(node.name());
2812     if (it == name_to_cost.end()) {
2813       continue;
2814     }
2815     std::vector<OpInfo::TensorProperties> inputs =
2816         FindInputFeatures(node, name_to_cost, name_to_node);
2817 
2818     input_properties_[node.name()] = inputs;
2819   }
2820   return OkStatus();
2821 }
2822 
2823 bool GraphProperties::HasInputProperties(const string& node_name) const {
2824   return input_properties_.find(node_name) != input_properties_.end();
2825 }
2826 
2827 bool GraphProperties::HasOutputProperties(const string& node_name) const {
2828   return output_properties_.find(node_name) != output_properties_.end();
2829 }
2830 
2831 const std::vector<OpInfo::TensorProperties>&
2832 GraphProperties::GetInputProperties(const string& node_name) const {
2833   auto it = input_properties_.find(node_name);
2834   if (it != input_properties_.end()) {
2835     return it->second;
2836   }
2837   return missing_properties_;
2838 }
2839 
2840 const std::vector<OpInfo::TensorProperties>&
2841 GraphProperties::GetOutputProperties(const string& node_name) const {
2842   auto it = output_properties_.find(node_name);
2843   if (it != output_properties_.end()) {
2844     return it->second;
2845   }
2846   return missing_properties_;
2847 }
2848 
2849 void GraphProperties::ClearInputProperties(const string& node_name) {
2850   input_properties_.erase(node_name);
2851 }
2852 void GraphProperties::ClearOutputProperties(const string& node_name) {
2853   output_properties_.erase(node_name);
2854 }
2855 
2856 }  // end namespace grappler
2857 }  // end namespace tensorflow
2858