/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/graph/costmodel.h"

#include <algorithm>
#include <vector>

#include "tensorflow/core/framework/allocation_description.pb.h"
#include "tensorflow/core/framework/cost_graph.pb.h"
#include "tensorflow/core/framework/step_stats.pb.h"
#include "tensorflow/core/framework/tensor_description.pb.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/platform/logging.h"

namespace tensorflow {
namespace {
const Microseconds kDefaultTimeEstimate(1);
const Microseconds kMinTimeEstimate(1);
}  // namespace

void CostModel::SuppressInfrequent() {
  // Find the median of the non-zero counts, and use half of its value
  // as the cutoff for a "normal" execution mode node.
  if (count_.empty()) return;
  std::vector<int32> non_zero;
  for (auto v : count_) {
    if (v > 0) non_zero.push_back(v);
  }
  const size_t sz = non_zero.size();
  if (sz > 0) {
    std::nth_element(non_zero.begin(), non_zero.begin() + sz / 2,
                     non_zero.end());
    int32_t median_value = non_zero[sz / 2];
    min_count_ = median_value / 2;
    VLOG(1) << "num non_zero vals: " << non_zero.size() << " median_value "
            << median_value;
  } else {
    min_count_ = 1;
  }
}
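
// Example (hypothetical counts): with count_ = {8, 1, 6, 0, 4} the non-zero
// counts are {8, 1, 6, 4}; nth_element places the upper median, 6, at index
// sz / 2 == 2, so min_count_ becomes 6 / 2 == 3.  Nodes executed fewer than
// three times are then suppressed by SizeEstimate() and TimeEstimate() below.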

void CostModel::MergeFromLocal(const Graph& g, const CostModel& cm) {
  CHECK(is_global_);
  CHECK(!cm.is_global());
  for (const Node* n : g.nodes()) {
    const int local_id = cm.Id(n);
    const int global_id = Id(n);
    if (local_id < 0 || global_id < 0) continue;
    int num_slots = cm.slot_bytes_[local_id].size();
    Ensure(global_id, num_slots);
    count_[global_id] += cm.count_[local_id];
    time_[global_id] += cm.time_[local_id];
    if (num_slots > 0) {
      if (slot_bytes_[global_id].empty()) {
        slot_bytes_[global_id].resize(num_slots);
      } else {
        CHECK_EQ(num_slots, slot_bytes_[global_id].size());
      }
      for (int s = 0; s < num_slots; ++s) {
        slot_bytes_[global_id][s] += cm.slot_bytes_[local_id][s];
      }
    }
  }
}
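
// Usage sketch (hypothetical caller; the partition container is illustrative):
// a session that keeps one global model and one local model per partition
// graph might aggregate them as
//
//   CostModel global(/*is_global=*/true);
//   for (const auto& p : partitions) {
//     global.MergeFromLocal(*p.graph, *p.local_cost_model);
//   }
//
// Node ids differ between the two models, so entries are matched by Node
// pointer via Id() / cm.Id() above.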

void CostModel::MergeFromGlobal(const CostModel& cm) {
  CHECK(is_global_);
  CHECK(cm.is_global());
  const int num_nodes = cm.count_.size();
  for (int i = num_nodes - 1; i >= 0; --i) {
    int num_slots = cm.slot_bytes_[i].size();
    // Grow this model before reading count_[i] / time_[i]; iterating from the
    // largest id downwards means the first Ensure() call does all the work.
    Ensure(i, num_slots);
    count_[i] += cm.count_[i];
    time_[i] += cm.time_[i];
    if (num_slots > 0) {
      if (slot_bytes_[i].empty()) {
        slot_bytes_[i].resize(num_slots);
      } else {
        CHECK_EQ(num_slots, slot_bytes_[i].size());
      }
      for (int s = 0; s < num_slots; ++s) {
        slot_bytes_[i][s] += cm.slot_bytes_[i][s];
      }
    }
  }
}

void CostModel::MergeFromStats(const NodeNameToCostIdMap& map,
                               const StepStats& ss) {
  CHECK(is_global_);
  for (auto& ds : ss.dev_stats()) {
    for (auto& ns : ds.node_stats()) {
      NodeNameToCostIdMap::const_iterator iter = map.find(ns.node_name());
      // We don't keep stats for nodes not in the global graph, i.e.
      // copy/send/recv nodes, feed/fetch, etc.
      if (iter == map.end()) continue;
      int32_t global_id = iter->second;
      Ensure(global_id, ns.output_size());
      int64_t elapsed_micros =
          ns.op_end_rel_micros() - ns.op_start_rel_micros();
      count_[global_id]++;
      time_[global_id] += elapsed_micros;
      for (auto& no : ns.output()) {
        int si = no.slot();
        if (static_cast<size_t>(si) >= slot_bytes_[global_id].size()) {
          slot_bytes_[global_id].resize(1 + si);
        }
        slot_bytes_[global_id][si] +=
            no.tensor_description().allocation_description().requested_bytes();
      }
    }
  }
}
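
// Worked example (assumed NodeExecStats contents): a node with
// op_start_rel_micros = 100, op_end_rel_micros = 350, and one output at
// slot 0 whose allocation requested 4096 bytes contributes one execution,
// 250us to time_[id], and 4096 bytes to slot_bytes_[id][0].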

void CostModel::Ensure(int id, int num_outputs) {
  if (slot_bytes_.size() <= static_cast<size_t>(id)) {
    slot_bytes_.resize(id + 1);
    count_.resize(id + 1);
    time_.resize(id + 1);
    max_mem_usage_.resize(id + 1);
    max_exec_time_.resize(id + 1);
    output_port_alloc_ids_.resize(id + 1);
  }
  if (num_outputs > 0) {
    auto perslot = &slot_bytes_[id];
    auto output_port_alloc_ids = &output_port_alloc_ids_[id];
    auto max_mem_usage = &max_mem_usage_[id];

    CHECK_LE(perslot->size(), num_outputs);
    DCHECK_EQ(output_port_alloc_ids->size(), perslot->size());
    DCHECK_EQ(max_mem_usage->output_port_mem.size(), perslot->size());
    DCHECK_EQ(max_mem_usage->output_port_shape.size(), perslot->size());
    DCHECK_EQ(max_mem_usage->output_port_type.size(), perslot->size());

    perslot->resize(num_outputs, Bytes(-1));
    output_port_alloc_ids->resize(num_outputs, -1);
    max_mem_usage->output_port_mem.resize(num_outputs, Bytes(-1));
    max_mem_usage->output_port_shape.resize(num_outputs, unknown_shape_);
    max_mem_usage->output_port_type.resize(num_outputs, DT_INVALID);
  }
}
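
// For example, Ensure(5, 2) grows every per-node vector to size 6 (if it was
// shorter) and gives node 5 two output slots initialized to the "unknown"
// sentinels: Bytes(-1), allocation id -1, unknown_shape_, and DT_INVALID.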

void CostModel::SetNumOutputs(const Node* node, int num_outputs) {
  const int id = Id(node);
  if (id < 0) return;
  // Do not resize the number of slots before checking its existing number of
  // slots.
  Ensure(id, 0);
  auto perslot = &slot_bytes_[id];
  if (!perslot->empty()) {
    CHECK_EQ(num_outputs, perslot->size())
        << "Cannot resize slot_bytes, node=" << node->name();
  }
  Ensure(id, num_outputs);
}

void CostModel::RecordCount(const Node* node, int count) {
  const int id = Id(node);
  if (id < 0) return;
  CHECK_LT(id, slot_bytes_.size());
  count_[id] += count;
}

int32 CostModel::TotalCount(const Node* node) const {
  const int id = Id(node);
  if (id < 0) return 0;
  return (static_cast<size_t>(id) < slot_bytes_.size()) ? count_[id] : 0;
}

void CostModel::RecordSize(const Node* node, int slot, Bytes bytes) {
  const int id = Id(node);
  if (id < 0) return;
  CHECK_LT(id, slot_bytes_.size());
  auto perslot = &slot_bytes_[id];
  CHECK_LT(slot, perslot->size());
  auto v = &(*perslot)[slot];
  if (*v >= 0) {
    *v += bytes;
  } else {
    *v = bytes;
  }
}

Bytes CostModel::TotalBytes(const Node* node, int slot) const {
  const int id = Id(node);
  if (id < 0 || static_cast<size_t>(id) >= slot_bytes_.size() ||
      slot_bytes_[id].size() <= static_cast<size_t>(slot)) {
    return Bytes(0);
  }
  return slot_bytes_[id][slot];
}

Bytes CostModel::SizeEstimate(const Node* node, int slot) const {
  int32_t count = TotalCount(node);
  if (count < min_count_) return Bytes(0);
  return TotalBytes(node, slot) / std::max(1, count);
}

void CostModel::RecordTime(const Node* node, Microseconds time) {
  const int id = Id(node);
  if (id < 0) return;
  DCHECK(node->IsOp()) << node->DebugString();
  Ensure(id, node->num_outputs());
  time_[id] += time;
}

Microseconds CostModel::TotalTime(const Node* node) const {
  DCHECK(node->IsOp()) << node->DebugString();
  const int id = Id(node);
  if (id < 0 || static_cast<size_t>(id) >= time_.size() ||
      time_[id] < Microseconds(0)) {
    return Microseconds(0);
  }
  return time_[id];
}

Microseconds CostModel::TimeEstimate(const Node* node) const {
  int32_t count = TotalCount(node);
  if (count <= min_count_) return kMinTimeEstimate;
  return std::max(kMinTimeEstimate, TotalTime(node) / std::max(1, count));
}
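
// Example: with min_count_ = 3, a node observed count = 10 times with
// TotalTime = 500us and TotalBytes(node, 0) = 10240 gets
// TimeEstimate = 500 / 10 = 50us and SizeEstimate = 10240 / 10 = 1024 bytes.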

void CostModel::CheckInitialized(const Graph& graph) const {
  for (const Node* n : graph.op_nodes()) {
    CHECK(static_cast<size_t>(n->id()) < time_.size() &&
          time_[n->id()] >= Microseconds(0))
        << ": no time estimate for " << n->DebugString();

    CHECK(static_cast<size_t>(n->id()) < slot_bytes_.size())
        << ": no size estimate for " << n->DebugString();
    const auto& perslot = slot_bytes_[n->id()];
    for (size_t i = 0; i < perslot.size(); i++) {
      CHECK_GE(perslot[i], Bytes(0)) << ": no size estimate for output# " << i
                                     << " of " << n->DebugString();
    }
  }
}

void CostModel::RecordMaxMemorySize(const Node* node, int output_slot,
                                    Bytes bytes,
                                    const TensorShapeProto& tensor_shape,
                                    const DataType& dtype) {
  const int id = Id(node);
  if (id < 0) return;
  if (output_slot >= node->num_outputs()) {
    LOG(ERROR) << "Unexpected output slot for node " << node->DebugString()
               << ". Got " << output_slot << " but its num_outputs is "
               << node->num_outputs();
    return;
  }
  Ensure(id, node->num_outputs());
  auto& current_max = max_mem_usage_[id].output_port_mem[output_slot];
  // If the memory allocator doesn't track memory usage, let's infer a lower
  // bound from the tensor shape and its data type.
  if (bytes.value() < 0) {
    bytes = MinTensorMemoryUsage(tensor_shape, dtype);
  }
  if (bytes.value() > current_max.value()) {
    current_max = bytes.value();
    max_mem_usage_[id].output_port_shape[output_slot] = tensor_shape;
    max_mem_usage_[id].output_port_type[output_slot] = dtype;
  }
}
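
// Example: if the allocator reports bytes = -1 for a DT_FLOAT output of shape
// [2, 3], the recorded size falls back to MinTensorMemoryUsage, i.e.
// 2 * 3 * sizeof(float) = 24 bytes.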

Bytes CostModel::MaxMemorySize(const Node* node, int slot) const {
  const int id = Id(node);
  if (id < 0 || static_cast<size_t>(id) >= max_mem_usage_.size() ||
      max_mem_usage_[id].output_port_mem.size() <= static_cast<size_t>(slot)) {
    return Bytes(0);
  }
  return max_mem_usage_[id].output_port_mem[slot];
}

const TensorShapeProto& CostModel::MaxMemoryShape(const Node* node,
                                                  int slot) const {
  const int id = Id(node);
  if (id < 0 || static_cast<size_t>(id) >= max_mem_usage_.size() ||
      max_mem_usage_[id].output_port_shape.size() <=
          static_cast<size_t>(slot)) {
    return unknown_shape_;
  }
  return max_mem_usage_[id].output_port_shape[slot];
}

DataType CostModel::MaxMemoryType(const Node* node, int slot) const {
  const int id = Id(node);
  if (id < 0 || static_cast<size_t>(id) >= max_mem_usage_.size() ||
      max_mem_usage_[id].output_port_type.size() <= static_cast<size_t>(slot)) {
    return DT_INVALID;
  }
  return max_mem_usage_[id].output_port_type[slot];
}

Bytes CostModel::TempMemorySize(const Node* node) const {
  const int id = Id(node);
  // Guard against ids that were never recorded, matching the other accessors.
  if (id < 0 || static_cast<size_t>(id) >= max_mem_usage_.size()) {
    return Bytes(0);
  }
  return max_mem_usage_[id].temp_memory_size;
}

Bytes CostModel::PersistentMemorySize(const Node* node) const {
  const int id = Id(node);
  if (id < 0 || static_cast<size_t>(id) >= max_mem_usage_.size()) {
    return Bytes(0);
  }
  return max_mem_usage_[id].persistent_memory_size;
}

void CostModel::RecordMemoryStats(const Node* node,
                                  const MemoryStats& memory_stats) {
  const int id = Id(node);
  if (id < 0) return;
  // Grow the per-node vectors if this node has not been seen before.
  Ensure(id, node->num_outputs());
  max_mem_usage_[id].temp_memory_size = memory_stats.temp_memory_size();
  max_mem_usage_[id].persistent_memory_size =
      memory_stats.persistent_memory_size();
  for (int64_t alloc_id : memory_stats.persistent_tensor_alloc_ids()) {
    if (alloc_id > 0) {
      persistent_alloc_ids_.insert(alloc_id);
    }
  }
}

void CostModel::RecordMaxExecutionTime(const Node* node, Microseconds time) {
  const int id = Id(node);
  if (id < 0) return;
  Ensure(id, node->num_outputs());
  max_exec_time_[id] = std::max(max_exec_time_[id], time);
}

Microseconds CostModel::MaxExecutionTime(const Node* node) const {
  const int id = Id(node);
  if (id < 0 || static_cast<size_t>(id) >= max_exec_time_.size()) {
    return Microseconds(0);
  }
  return max_exec_time_[id];
}

void CostModel::RecordAllocationId(const Node* node, int output_slot,
                                   int64_t alloc_id) {
  const int id = Id(node);
  if (id < 0) return;
  Ensure(id, node->num_outputs());
  output_port_alloc_ids_[id][output_slot] = alloc_id;
}

int64_t CostModel::AllocationId(const Node* node, int slot) const {
  const int id = Id(node);
  if (id < 0 || static_cast<size_t>(id) >= output_port_alloc_ids_.size() ||
      output_port_alloc_ids_[id].size() <= static_cast<size_t>(slot)) {
    return -1;
  }
  return output_port_alloc_ids_[id][slot];
}

bool CostModel::IsPersistentTensor(const Node* node, int64_t alloc_id) const {
  if (persistent_alloc_ids_.count(alloc_id) > 0) {
    return true;
  }
  if (persistent_alloc_ids_by_devices_.find(node->assigned_device_name()) ==
      persistent_alloc_ids_by_devices_.end()) {
    return false;
  }
  return persistent_alloc_ids_by_devices_.at(node->assigned_device_name())
      .count(alloc_id);
}

Microseconds CostModel::CopyTimeEstimate(Bytes b, double network_latency_millis,
                                         double estimated_gbps) {
  // TODO(jeff,sanjay): estimate cost based on bandwidth along the
  // communication path and the type of transport we are using between
  // devices.
  //
  // We assume the copy time follows a linear model:
  //    copy_time = copy_bytes / rate + min_time
  int64_t copy_bytes = b.value();
  const double bytes_per_usec = estimated_gbps * 1000.0 / 8;
  const double min_micros = network_latency_millis * 1000.0;
  return Microseconds(
      static_cast<int64_t>(copy_bytes / bytes_per_usec + min_micros));
}
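
// Example: copying b = 1 MiB at estimated_gbps = 8 gives
// bytes_per_usec = 8 * 1000 / 8 = 1000; with network_latency_millis = 1 the
// estimate is 1048576 / 1000 + 1000, about 2048us.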

Microseconds CostModel::ComputationTimeEstimate(int64_t math_ops) {
  // TODO(jeff,sanjay): Eventually we should pass in the type of device
  // (GPU vs. CPU) and use that to affect the estimate.

  // Convert the estimated number of multiply-adds into microseconds by
  // dividing by 1000, assuming roughly 1000 madds per microsecond
  // (~1 GHz for one core).
  return Microseconds(math_ops / 1000);
}
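
// Example: math_ops = 2,000,000 multiply-adds is estimated to take
// 2,000,000 / 1000 = 2000us under this model.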

void CostModel::IncrementUpdateTimes() { update_times_++; }

int32 CostModel::GetUpdateTimes() const { return update_times_; }

// ----------------------------------------------------------------------------
// InitCostModel
// ----------------------------------------------------------------------------

namespace {

static void AddNodesToCostModel(const Graph& g, CostModel* cost_model) {
  for (Node* n : g.nodes()) {
    const int num_outputs = n->num_outputs();
    cost_model->SetNumOutputs(n, num_outputs);
    for (int output = 0; output < num_outputs; output++) {
      // Set up an initial bogus estimate for the node's outputs
      cost_model->RecordSize(n, output, Bytes(1));
    }
  }
}

static void AssignSizes(const Graph& g, CostModel* cost_model) {
  for (const Edge* e : g.edges()) {
    // Skip if it is a control edge.
    if (e->IsControlEdge()) {
      continue;
    }
    const Node* src = e->src();

    // TODO(josh11b): Get an estimate from the Op
    Bytes size(1);
    cost_model->RecordSize(src, e->src_output(), size);
  }
}

// This generates an extremely simple initial guess for the
// computation cost of each node. For ordinary Ops, its value should quickly
// be wiped out by the real runtime measurements.  For other Ops we don't
// actually generate measurements, so suppression of infrequent Ops ends up
// giving them 0 costs.  So, this is not of much consequence except perhaps
// in tests.
static Microseconds TimeEstimateForNode(CostModel* cost_model, Node* n) {
  CHECK(n->IsOp());
  VLOG(2) << "Node " << n->id() << ": " << n->name()
          << " type_string: " << n->type_string();
  if (IsConstant(n) || IsVariable(n)) {
    return Microseconds(0);
  }
  return kDefaultTimeEstimate;
}

static void EstimateComputationCosts(const Graph& g, CostModel* cost_model) {
  for (Node* n : g.nodes()) {
    if (!n->IsOp()) continue;
    cost_model->RecordTime(n, TimeEstimateForNode(cost_model, n));
  }
}

}  // namespace

void CostModel::InitFromGraph(const Graph& g) {
  const int num_node_ids = g.num_node_ids();
  slot_bytes_.reserve(num_node_ids);
  count_.reserve(num_node_ids);
  time_.reserve(num_node_ids);
  max_mem_usage_.reserve(num_node_ids);
  max_exec_time_.reserve(num_node_ids);
  output_port_alloc_ids_.reserve(num_node_ids);

  AddNodesToCostModel(g, this);
  AssignSizes(g, this);
  EstimateComputationCosts(g, this);
  CheckInitialized(g);
}
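
// Usage sketch (hypothetical caller): seed a per-partition model with the
// bogus 1-byte / 1us priors above, then refine it with real measurements:
//
//   CostModel cm(/*is_global=*/false);
//   cm.InitFromGraph(*graph);
//   // After each step:
//   //   cm.RecordTime(node, measured_micros);
//   //   cm.RecordSize(node, slot, measured_bytes);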

void CostModel::AddToCostGraphDef(const Graph* graph,
                                  CostGraphDef* cost_graph) const {
  std::vector<const Edge*> inputs;
  std::vector<const Edge*> control_inputs;
  int offset = cost_graph->node_size();
  for (const Node* n : graph->nodes()) {
    CostGraphDef::Node* cnode = cost_graph->add_node();
    cnode->set_name(n->name());
    cnode->set_device(n->assigned_device_name());
    cnode->set_id(GlobalId(n, offset));

    inputs.clear();
    inputs.resize(n->num_inputs(), nullptr);
    control_inputs.clear();
    for (const Edge* e : n->in_edges()) {
      if (e->IsControlEdge()) {
        control_inputs.push_back(e);
      } else {
        inputs[e->dst_input()] = e;
      }
    }
    std::sort(control_inputs.begin(), control_inputs.end(),
              [this](Edge const* a, Edge const* b) {
                return Id(a->src()) < Id(b->src());
              });

    for (const Edge* e : inputs) {
      CostGraphDef::Node::InputInfo* input_info = cnode->add_input_info();
      input_info->set_preceding_node(GlobalId(e->src(), offset));
      input_info->set_preceding_port(e->src_output());
    }

    for (int i = 0; i < n->num_outputs(); i++) {
      CostGraphDef::Node::OutputInfo* output_info = cnode->add_output_info();
      int64_t alloc_id = AllocationId(n, i);
      int64_t alias_to_input = -1;
      for (const Edge* e : inputs) {
        int64_t input_alloc_id = AllocationId(e->src(), e->src_output());
        if (input_alloc_id == alloc_id) {
          alias_to_input = e->dst_input();
          break;
        }
      }
      output_info->set_alias_input_port(alias_to_input);
      output_info->set_dtype(MaxMemoryType(n, i));
      *output_info->mutable_shape() = MaxMemoryShape(n, i);
      if (alias_to_input < 0 && IsPersistentTensor(n, alloc_id)) {
        output_info->set_size(0);
      } else {
        output_info->set_size(MaxMemorySize(n, i).value());
      }
    }

    for (const Edge* e : control_inputs) {
      cnode->add_control_input(GlobalId(e->src(), offset));
    }

    cnode->set_temporary_memory_size(TempMemorySize(n).value());
    cnode->set_persistent_memory_size(PersistentMemorySize(n).value());

    cnode->set_compute_cost(MaxExecutionTime(n).value());

    // For now we treat all send nodes as final.
    // TODO(yuanbyu): Send nodes for fetches shouldn't be treated as final.
    cnode->set_is_final(n->IsSend());
  }
}
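
// Note on the emitted CostGraphDef: an output whose allocation id matches one
// of the node's inputs is reported as an alias (alias_input_port >= 0), and a
// non-aliased persistent buffer (e.g. a variable) is reported with size 0, so
// consumers of the proto do not count that memory against the node itself.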

void CostModel::WriteSummaryToLog() const {
  LOG(INFO) << " min_count_=" << min_count_;
  for (size_t i = 0; i < count_.size(); ++i) {
    LOG(INFO) << "Node " << i << " count " << count_[i] << " total time "
              << time_[i] << " avg time "
              << (time_[i] / (std::max(1, count_[i])));
  }
}

Bytes CostModel::MinTensorMemoryUsage(const TensorShapeProto& tensor_shape,
                                      const DataType& dtype) {
  if (tensor_shape.unknown_rank()) {
    return Bytes(-1);
  }

  int64_t num_coefficients = 1;
  for (const TensorShapeProto::Dim& dim : tensor_shape.dim()) {
    // If the dimension is unknown, it has to be at least 1.  Clamp in a
    // signed type: casting a -1 "unknown" dim to size_t would wrap around
    // to a huge value instead.
    num_coefficients *= std::max<int64_t>(dim.size(), 1);
  }
  return Bytes(num_coefficients * DataTypeSize(dtype));
}
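
// Example: shape [?, 128] with DT_INT64 yields 1 * 128 * 8 = 1024 bytes (the
// unknown dim counts as 1), while an unknown-rank shape yields Bytes(-1),
// meaning no estimate is available.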

}  // namespace tensorflow