
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/grappler/costs/op_level_cost_estimator.h"

#include "absl/strings/match.h"
#include "third_party/eigen3/Eigen/Core"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/attr_value_util.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/grappler/clusters/utils.h"
#include "tensorflow/core/grappler/costs/op_context.h"
#include "tensorflow/core/grappler/costs/utils.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/util/overflow.h"

namespace tensorflow {
namespace grappler {

// TODO(dyoon): update op to Predict method map for TF ops with V2 or V3 suffix.
constexpr int kOpsPerMac = 2;
constexpr char kGuaranteeConst[] = "GuaranteeConst";
constexpr char kAddN[] = "AddN";
constexpr char kBitCast[] = "BitCast";
constexpr char kConcatV2[] = "ConcatV2";
constexpr char kConv2d[] = "Conv2D";
constexpr char kConv2dBackpropFilter[] = "Conv2DBackpropFilter";
constexpr char kConv2dBackpropInput[] = "Conv2DBackpropInput";
constexpr char kFusedConv2dBiasActivation[] = "FusedConv2DBiasActivation";
constexpr char kDataFormatVecPermute[] = "DataFormatVecPermute";
constexpr char kDepthToSpace[] = "DepthToSpace";
constexpr char kDepthwiseConv2dNative[] = "DepthwiseConv2dNative";
constexpr char kDepthwiseConv2dNativeBackpropFilter[] =
    "DepthwiseConv2dNativeBackpropFilter";
constexpr char kDepthwiseConv2dNativeBackpropInput[] =
    "DepthwiseConv2dNativeBackpropInput";
constexpr char kMatMul[] = "MatMul";
constexpr char kXlaEinsum[] = "XlaEinsum";
constexpr char kEinsum[] = "Einsum";
constexpr char kExpandDims[] = "ExpandDims";
constexpr char kFill[] = "Fill";
constexpr char kSparseMatMul[] = "SparseMatMul";
constexpr char kSparseTensorDenseMatMul[] = "SparseTensorDenseMatMul";
constexpr char kPlaceholder[] = "Placeholder";
constexpr char kIdentity[] = "Identity";
constexpr char kIdentityN[] = "IdentityN";
constexpr char kRefIdentity[] = "RefIdentity";
constexpr char kNoOp[] = "NoOp";
constexpr char kReshape[] = "Reshape";
constexpr char kSplit[] = "Split";
constexpr char kSqueeze[] = "Squeeze";
constexpr char kRecv[] = "_Recv";
constexpr char kSend[] = "_Send";
constexpr char kBatchMatMul[] = "BatchMatMul";
constexpr char kBatchMatMulV2[] = "BatchMatMulV2";
constexpr char kOneHot[] = "OneHot";
constexpr char kPack[] = "Pack";
constexpr char kRank[] = "Rank";
constexpr char kRange[] = "Range";
constexpr char kShape[] = "Shape";
constexpr char kShapeN[] = "ShapeN";
constexpr char kSize[] = "Size";
constexpr char kStopGradient[] = "StopGradient";
constexpr char kPreventGradient[] = "PreventGradient";
constexpr char kGather[] = "Gather";
constexpr char kGatherNd[] = "GatherNd";
constexpr char kGatherV2[] = "GatherV2";
constexpr char kScatterAdd[] = "ScatterAdd";
constexpr char kScatterDiv[] = "ScatterDiv";
constexpr char kScatterMax[] = "ScatterMax";
constexpr char kScatterMin[] = "ScatterMin";
constexpr char kScatterMul[] = "ScatterMul";
constexpr char kScatterSub[] = "ScatterSub";
constexpr char kScatterUpdate[] = "ScatterUpdate";
constexpr char kSlice[] = "Slice";
constexpr char kStridedSlice[] = "StridedSlice";
constexpr char kSpaceToDepth[] = "SpaceToDepth";
constexpr char kTranspose[] = "Transpose";
constexpr char kTile[] = "Tile";
constexpr char kMaxPool[] = "MaxPool";
constexpr char kMaxPoolGrad[] = "MaxPoolGrad";
constexpr char kAvgPool[] = "AvgPool";
constexpr char kAvgPoolGrad[] = "AvgPoolGrad";
constexpr char kFusedBatchNorm[] = "FusedBatchNorm";
constexpr char kFusedBatchNormGrad[] = "FusedBatchNormGrad";
constexpr char kQuantizedMatMul[] = "QuantizedMatMul";
constexpr char kQuantizedMatMulV2[] = "QuantizedMatMulV2";
constexpr char kUnpack[] = "Unpack";
constexpr char kSoftmax[] = "Softmax";
constexpr char kResizeBilinear[] = "ResizeBilinear";
constexpr char kCropAndResize[] = "CropAndResize";
// Dynamic control flow ops.
constexpr char kSwitch[] = "Switch";
constexpr char kMerge[] = "Merge";
constexpr char kEnter[] = "Enter";
constexpr char kExit[] = "Exit";
constexpr char kNextIteration[] = "NextIteration";
// Persistent ops.
constexpr char kConst[] = "Const";
constexpr char kVariable[] = "Variable";
constexpr char kVariableV2[] = "VariableV2";
constexpr char kAutoReloadVariable[] = "AutoReloadVariable";
constexpr char kVarHandleOp[] = "VarHandleOp";
constexpr char kVarHandlesOp[] = "_VarHandlesOp";
constexpr char kReadVariableOp[] = "ReadVariableOp";
constexpr char kReadVariablesOp[] = "_ReadVariablesOp";
constexpr char kAssignVariableOp[] = "AssignVariableOp";
constexpr char kAssignAddVariableOp[] = "AssignAddVariableOp";
constexpr char kAssignSubVariableOp[] = "AssignSubVariableOp";

static const Costs::Duration kMinComputeTime(1);
static const int64_t kMinComputeOp = 1;

namespace {

std::string GetDataFormat(const OpInfo& op_info) {
  std::string data_format = "NHWC";  // Default format.
  if (op_info.attr().find("data_format") != op_info.attr().end()) {
    data_format = op_info.attr().at("data_format").s();
  }
  return data_format;
}

std::string GetFilterFormat(const OpInfo& op_info) {
  std::string filter_format = "HWIO";  // Default format.
  if (op_info.attr().find("filter_format") != op_info.attr().end()) {
    filter_format = op_info.attr().at("filter_format").s();
  }
  return filter_format;
}

Padding GetPadding(const OpInfo& op_info) {
  if (op_info.attr().find("padding") != op_info.attr().end() &&
      op_info.attr().at("padding").s() == "VALID") {
    return Padding::VALID;
  }
  return Padding::SAME;  // Default padding.
}

bool IsTraining(const OpInfo& op_info) {
  if (op_info.attr().find("is_training") != op_info.attr().end() &&
      op_info.attr().at("is_training").b()) {
    return true;
  }
  return false;
}

// TODO(dyoon): support non-4D tensors in the cost functions of convolution
// related ops (Conv, Pool, BatchNorm, and their backprops) and the related
// helper functions.
std::vector<int64_t> GetStrides(const OpInfo& op_info) {
  if (op_info.attr().find("strides") != op_info.attr().end()) {
    const auto strides = op_info.attr().at("strides").list().i();
    DCHECK(strides.size() == 4)
        << "Attr strides is not a length-4 vector: " << op_info.DebugString();
    if (strides.size() != 4) return {1, 1, 1, 1};
    return {strides[0], strides[1], strides[2], strides[3]};
  }
  return {1, 1, 1, 1};
}

std::vector<int64_t> GetKernelSize(const OpInfo& op_info) {
  if (op_info.attr().find("ksize") != op_info.attr().end()) {
    const auto ksize = op_info.attr().at("ksize").list().i();
    DCHECK(ksize.size() == 4)
        << "Attr ksize is not a length-4 vector: " << op_info.DebugString();
    if (ksize.size() != 4) return {1, 1, 1, 1};
    return {ksize[0], ksize[1], ksize[2], ksize[3]};
  }
  // Note that FusedBatchNorm doesn't have ksize attr, but GetKernelSize returns
  // {1, 1, 1, 1} in that case.
  return {1, 1, 1, 1};
}

int64_t GetOutputSize(const int64_t input, const int64_t filter,
                      const int64_t stride, const Padding& padding) {
  // Logic for calculating output shape is from GetWindowedOutputSizeVerbose()
  // function in third_party/tensorflow/core/framework/common_shape_fns.cc.
  if (padding == Padding::VALID) {
    return (input - filter + stride) / stride;
  } else {  // SAME.
    return (input + stride - 1) / stride;
  }
}
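
// Illustrative example (not from the original source): for an input of width
// 112, a filter of width 3, and stride 2, VALID padding gives
// (112 - 3 + 2) / 2 = 55 output elements, while SAME padding gives
// (112 + 2 - 1) / 2 = 56, i.e. ceil(input / stride).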

// Return the output element count of a multi-input element-wise op considering
// broadcasting.
int64_t CwiseOutputElementCount(const OpInfo& op_info) {
  int max_rank = 1;
  for (const OpInfo::TensorProperties& input_properties : op_info.inputs()) {
    max_rank = std::max(max_rank, input_properties.shape().dim_size());
  }

  TensorShapeProto output_shape;
  output_shape.mutable_dim()->Reserve(max_rank);
  for (int i = 0; i < max_rank; ++i) {
    output_shape.add_dim();
  }

  // Expand the shape of the output to follow the numpy-style broadcast rule
  // which matches each input starting with the trailing dimensions and working
  // its way forward. To do this, iterate through each input shape's dimensions
  // in reverse order, and potentially increase the corresponding output
  // dimension.
  for (const OpInfo::TensorProperties& input_properties : op_info.inputs()) {
    const TensorShapeProto& input_shape = input_properties.shape();
    for (int i = input_shape.dim_size() - 1; i >= 0; --i) {
      int output_shape_dim_index =
          i + output_shape.dim_size() - input_shape.dim_size();
      output_shape.mutable_dim(output_shape_dim_index)
          ->set_size(std::max(output_shape.dim(output_shape_dim_index).size(),
                              input_shape.dim(i).size()));
    }
  }

  int64_t count = 1;
  for (int i = 0; i < output_shape.dim_size(); i++) {
    count *= output_shape.dim(i).size();
  }
  return count;
}
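
// Worked example (illustrative shapes, not from the original source): with
// inputs shaped [4, 1, 3] and [5, 1], aligning trailing dimensions gives an
// output shape of [4, 5, 3], so CwiseOutputElementCount would return
// 4 * 5 * 3 = 60.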

// Helper function for determining whether there are repeated indices in the
// input Einsum equation.
bool CheckRepeatedDimensions(const absl::string_view dim_str) {
  int str_size = dim_str.size();
  for (int idx = 0; idx < str_size - 1; idx++) {
    if (dim_str.find(dim_str[idx], idx + 1) != std::string::npos) {
      return true;
    }
  }
  return false;
}
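
// For example, CheckRepeatedDimensions("bij") is false, while
// CheckRepeatedDimensions("bii") is true because 'i' appears twice in the same
// subscript string.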

// Auxiliary function for determining whether OpLevelCostEstimator is compatible
// with a given Einsum.
bool IsEinsumCorrectlyFormed(const OpContext& einsum_context) {
  const auto& op_info = einsum_context.op_info;

  auto it = op_info.attr().find("equation");
  if (it == op_info.attr().end()) return false;
  const absl::string_view equation = it->second.s();
  std::vector<std::string> equation_split = absl::StrSplit(equation, "->");

  if (equation_split.empty()) {
    LOG(WARNING) << "Einsum with malformed equation";
    return false;
  }
  std::vector<absl::string_view> input_split =
      absl::StrSplit(equation_split[0], ',');

  // The current model covers Einsum operations with two operands and an RHS.
  if (op_info.inputs_size() != 2 || equation_split.size() != 2) {
    VLOG(1) << "Missing accurate estimator for op: " << op_info.op();
    return false;
  }
  const auto& a_input = op_info.inputs(0);
  const auto& b_input = op_info.inputs(1);
  absl::string_view rhs_str = equation_split[1];
  absl::string_view a_input_str = input_split[0];
  absl::string_view b_input_str = input_split[1];

  // Ellipses are not currently supported.
  if (absl::StrContains(a_input_str, "...") ||
      absl::StrContains(b_input_str, "...")) {
    VLOG(1) << "Missing accurate estimator for op: " << op_info.op()
            << ", ellipsis not supported";
    return false;
  }

  constexpr int kMatrixRank = 2;

  bool a_input_shape_unknown = false;
  bool b_input_shape_unknown = false;

  TensorShapeProto a_input_shape = MaybeGetMinimumShape(
      a_input.shape(), std::max(kMatrixRank, a_input.shape().dim_size()),
      &a_input_shape_unknown);
  TensorShapeProto b_input_shape = MaybeGetMinimumShape(
      b_input.shape(), std::max(kMatrixRank, b_input.shape().dim_size()),
      &b_input_shape_unknown);

  if (a_input_str.size() != static_cast<size_t>(a_input_shape.dim_size()) ||
      b_input_str.size() != static_cast<size_t>(b_input_shape.dim_size())) {
    VLOG(1) << "Missing accurate estimator for op: " << op_info.op()
            << ", equation subscripts don't match tensor rank.";
    return false;
  }

  // Subscripts where an axis appears more than once for a single input are not
  // yet supported.
  if (CheckRepeatedDimensions(a_input_str) ||
      CheckRepeatedDimensions(b_input_str) ||
      CheckRepeatedDimensions(rhs_str)) {
    VLOG(1) << "Missing accurate estimator for op: " << op_info.op()
            << ", Subscripts where axis appears more than once for a single "
               "input are not yet supported";
    return false;
  }

  return true;
}
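
// As an illustration: an Einsum with equation "ij,jk->ik" over two rank-2
// inputs is considered correctly formed, whereas "ij,jk,kl->il" (three
// operands), "i...j,jk->ik" (ellipsis), or "ii,ij->ij" (repeated subscript
// within one input) would all fall back to the less accurate default
// estimator.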

}  // namespace

// Return a minimum shape if the shape is unknown. If known, return the original
// shape.
TensorShapeProto MaybeGetMinimumShape(const TensorShapeProto& original_shape,
                                      int rank, bool* found_unknown_shapes) {
  auto shape = original_shape;
  bool is_scalar = !shape.unknown_rank() && shape.dim_size() == 0;

  if (shape.unknown_rank() || (!is_scalar && shape.dim_size() < rank)) {
    *found_unknown_shapes = true;
    VLOG(2) << "Use minimum shape because the rank is unknown.";
    // The size of each dimension is at least 1, if unknown.
    for (int i = shape.dim_size(); i < rank; i++) {
      shape.add_dim()->set_size(1);
    }
  } else if (is_scalar) {
    for (int i = 0; i < rank; i++) {
      shape.add_dim()->set_size(1);
    }
  } else if (shape.dim_size() > rank) {
    *found_unknown_shapes = true;
    shape.clear_dim();
    for (int i = 0; i < rank; i++) {
      shape.add_dim()->set_size(original_shape.dim(i).size());
    }
  } else {
    for (int i = 0; i < shape.dim_size(); i++) {
      if (shape.dim(i).size() < 0) {
        *found_unknown_shapes = true;
        VLOG(2) << "Use minimum dim size 1 because the shape is unknown.";
        // The size of each dimension is at least 1, if unknown.
        shape.mutable_dim(i)->set_size(1);
      }
    }
  }
  return shape;
}
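
// Example behavior (illustrative): with rank = 4, an unknown-rank shape or a
// scalar becomes [1, 1, 1, 1]; a fully ranked shape such as [8, -1, 28, 28]
// has its unknown dimension replaced to give [8, 1, 28, 28]; and a
// higher-rank shape such as [8, 3, 28, 28, 2] is truncated to [8, 3, 28, 28],
// with *found_unknown_shapes set to true in the latter two cases.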

OpLevelCostEstimator::OpLevelCostEstimator() {
  // Syntactic sugar to build and return a lambda that takes an OpInfo and
  // returns a cost.
  typedef Status (OpLevelCostEstimator::*CostImpl)(const OpContext& op_context,
                                                   NodeCosts*) const;
  auto wrap = [this](CostImpl impl)
      -> std::function<Status(const OpContext&, NodeCosts*)> {
    return [this, impl](const OpContext& op_context, NodeCosts* node_costs) {
      return (this->*impl)(op_context, node_costs);
    };
  };

  device_cost_impl_.emplace(kConv2d,
                            wrap(&OpLevelCostEstimator::PredictConv2D));
  device_cost_impl_.emplace(
      kConv2dBackpropFilter,
      wrap(&OpLevelCostEstimator::PredictConv2DBackpropFilter));
  device_cost_impl_.emplace(
      kConv2dBackpropInput,
      wrap(&OpLevelCostEstimator::PredictConv2DBackpropInput));
  device_cost_impl_.emplace(
      kFusedConv2dBiasActivation,
      wrap(&OpLevelCostEstimator::PredictFusedConv2DBiasActivation));
  // Reuse Conv2D for DepthwiseConv2dNative because the calculation is the
  // same, although the actual meanings of the parameters are different. See
  // the comments in PredictConv2D and related functions.
  device_cost_impl_.emplace(kDepthwiseConv2dNative,
                            wrap(&OpLevelCostEstimator::PredictConv2D));
  device_cost_impl_.emplace(
      kDepthwiseConv2dNativeBackpropFilter,
      wrap(&OpLevelCostEstimator::PredictConv2DBackpropFilter));
  device_cost_impl_.emplace(
      kDepthwiseConv2dNativeBackpropInput,
      wrap(&OpLevelCostEstimator::PredictConv2DBackpropInput));
  device_cost_impl_.emplace(kMatMul,
                            wrap(&OpLevelCostEstimator::PredictMatMul));
  device_cost_impl_.emplace(kSparseMatMul,
                            wrap(&OpLevelCostEstimator::PredictMatMul));
  device_cost_impl_.emplace(
      kSparseTensorDenseMatMul,
      wrap(&OpLevelCostEstimator::PredictSparseTensorDenseMatMul));
  device_cost_impl_.emplace(kBatchMatMul,
                            wrap(&OpLevelCostEstimator::PredictBatchMatMul));
  device_cost_impl_.emplace(kBatchMatMulV2,
                            wrap(&OpLevelCostEstimator::PredictBatchMatMul));
  device_cost_impl_.emplace(kQuantizedMatMul,
                            wrap(&OpLevelCostEstimator::PredictMatMul));
  device_cost_impl_.emplace(kQuantizedMatMulV2,
                            wrap(&OpLevelCostEstimator::PredictMatMul));
  device_cost_impl_.emplace(kXlaEinsum,
                            wrap(&OpLevelCostEstimator::PredictEinsum));
  device_cost_impl_.emplace(kEinsum,
                            wrap(&OpLevelCostEstimator::PredictEinsum));

  device_cost_impl_.emplace(kNoOp, wrap(&OpLevelCostEstimator::PredictNoOp));
  device_cost_impl_.emplace(kGuaranteeConst,
                            wrap(&OpLevelCostEstimator::PredictNoOp));

  device_cost_impl_.emplace(kGather,
                            wrap(&OpLevelCostEstimator::PredictGatherOrSlice));
  device_cost_impl_.emplace(kGatherNd,
                            wrap(&OpLevelCostEstimator::PredictGatherOrSlice));
  device_cost_impl_.emplace(kGatherV2,
                            wrap(&OpLevelCostEstimator::PredictGatherOrSlice));
  device_cost_impl_.emplace(kScatterAdd,
                            wrap(&OpLevelCostEstimator::PredictScatter));
  device_cost_impl_.emplace(kScatterDiv,
                            wrap(&OpLevelCostEstimator::PredictScatter));
  device_cost_impl_.emplace(kScatterMax,
                            wrap(&OpLevelCostEstimator::PredictScatter));
  device_cost_impl_.emplace(kScatterMin,
                            wrap(&OpLevelCostEstimator::PredictScatter));
  device_cost_impl_.emplace(kScatterMul,
                            wrap(&OpLevelCostEstimator::PredictScatter));
  device_cost_impl_.emplace(kScatterSub,
                            wrap(&OpLevelCostEstimator::PredictScatter));
  device_cost_impl_.emplace(kScatterUpdate,
                            wrap(&OpLevelCostEstimator::PredictScatter));

  device_cost_impl_.emplace(kSlice,
                            wrap(&OpLevelCostEstimator::PredictGatherOrSlice));
  device_cost_impl_.emplace(kStridedSlice,
                            wrap(&OpLevelCostEstimator::PredictGatherOrSlice));

  device_cost_impl_.emplace(kPlaceholder,
                            wrap(&OpLevelCostEstimator::PredictIdentity));
  device_cost_impl_.emplace(kIdentity,
                            wrap(&OpLevelCostEstimator::PredictIdentity));
  device_cost_impl_.emplace(kIdentityN,
                            wrap(&OpLevelCostEstimator::PredictIdentity));
  device_cost_impl_.emplace(kRefIdentity,
                            wrap(&OpLevelCostEstimator::PredictIdentity));
  device_cost_impl_.emplace(kStopGradient,
                            wrap(&OpLevelCostEstimator::PredictIdentity));
  device_cost_impl_.emplace(kPreventGradient,
                            wrap(&OpLevelCostEstimator::PredictIdentity));
  device_cost_impl_.emplace(kReshape,
                            wrap(&OpLevelCostEstimator::PredictIdentity));
  device_cost_impl_.emplace(kRecv,
                            wrap(&OpLevelCostEstimator::PredictIdentity));
  device_cost_impl_.emplace(kSend,
                            wrap(&OpLevelCostEstimator::PredictIdentity));
  device_cost_impl_.emplace(kSwitch,
                            wrap(&OpLevelCostEstimator::PredictIdentity));
  device_cost_impl_.emplace(kMerge,
                            wrap(&OpLevelCostEstimator::PredictIdentity));
  device_cost_impl_.emplace(kEnter,
                            wrap(&OpLevelCostEstimator::PredictIdentity));
  device_cost_impl_.emplace(kExit,
                            wrap(&OpLevelCostEstimator::PredictIdentity));
  device_cost_impl_.emplace(kNextIteration,
                            wrap(&OpLevelCostEstimator::PredictIdentity));
  device_cost_impl_.emplace(kBitCast,
                            wrap(&OpLevelCostEstimator::PredictIdentity));

  device_cost_impl_.emplace(kConcatV2,
                            wrap(&OpLevelCostEstimator::PredictPureMemoryOp));
  device_cost_impl_.emplace(kDataFormatVecPermute,
                            wrap(&OpLevelCostEstimator::PredictPureMemoryOp));
  device_cost_impl_.emplace(kDepthToSpace,
                            wrap(&OpLevelCostEstimator::PredictPureMemoryOp));
  device_cost_impl_.emplace(kExpandDims,
                            wrap(&OpLevelCostEstimator::PredictPureMemoryOp));
  device_cost_impl_.emplace(kFill,
                            wrap(&OpLevelCostEstimator::PredictPureMemoryOp));
  device_cost_impl_.emplace(kOneHot,
                            wrap(&OpLevelCostEstimator::PredictPureMemoryOp));
  device_cost_impl_.emplace(kPack,
                            wrap(&OpLevelCostEstimator::PredictPureMemoryOp));
  device_cost_impl_.emplace(kRange,
                            wrap(&OpLevelCostEstimator::PredictPureMemoryOp));
  device_cost_impl_.emplace(kSpaceToDepth,
                            wrap(&OpLevelCostEstimator::PredictPureMemoryOp));
  device_cost_impl_.emplace(kSplit,
                            wrap(&OpLevelCostEstimator::PredictPureMemoryOp));
  device_cost_impl_.emplace(kSqueeze,
                            wrap(&OpLevelCostEstimator::PredictPureMemoryOp));
  device_cost_impl_.emplace(kTranspose,
                            wrap(&OpLevelCostEstimator::PredictPureMemoryOp));
  device_cost_impl_.emplace(kTile,
                            wrap(&OpLevelCostEstimator::PredictPureMemoryOp));
  device_cost_impl_.emplace(kUnpack,
                            wrap(&OpLevelCostEstimator::PredictPureMemoryOp));

  device_cost_impl_.emplace(kRank,
                            wrap(&OpLevelCostEstimator::PredictMetadata));
  device_cost_impl_.emplace(kShape,
                            wrap(&OpLevelCostEstimator::PredictMetadata));
  device_cost_impl_.emplace(kShapeN,
                            wrap(&OpLevelCostEstimator::PredictMetadata));
  device_cost_impl_.emplace(kSize,
                            wrap(&OpLevelCostEstimator::PredictMetadata));
  device_cost_impl_.emplace(kMaxPool,
                            wrap(&OpLevelCostEstimator::PredictMaxPool));
  device_cost_impl_.emplace(kMaxPoolGrad,
                            wrap(&OpLevelCostEstimator::PredictMaxPoolGrad));
  device_cost_impl_.emplace(kAvgPool,
                            wrap(&OpLevelCostEstimator::PredictAvgPool));
  device_cost_impl_.emplace(kAvgPoolGrad,
                            wrap(&OpLevelCostEstimator::PredictAvgPoolGrad));
  device_cost_impl_.emplace(kFusedBatchNorm,
                            wrap(&OpLevelCostEstimator::PredictFusedBatchNorm));
  device_cost_impl_.emplace(
      kFusedBatchNormGrad,
      wrap(&OpLevelCostEstimator::PredictFusedBatchNormGrad));
  device_cost_impl_.emplace(kSoftmax,
                            wrap(&OpLevelCostEstimator::PredictSoftmax));
  device_cost_impl_.emplace(kResizeBilinear,
                            wrap(&OpLevelCostEstimator::PredictResizeBilinear));
  device_cost_impl_.emplace(kCropAndResize,
                            wrap(&OpLevelCostEstimator::PredictCropAndResize));
  device_cost_impl_.emplace(
      kAssignVariableOp, wrap(&OpLevelCostEstimator::PredictAssignVariableOps));
  device_cost_impl_.emplace(
      kAssignAddVariableOp,
      wrap(&OpLevelCostEstimator::PredictAssignVariableOps));
  device_cost_impl_.emplace(
      kAssignSubVariableOp,
      wrap(&OpLevelCostEstimator::PredictAssignVariableOps));
  device_cost_impl_.emplace(kAddN, wrap(&OpLevelCostEstimator::PredictNaryOp));

  persistent_ops_ = {
      kConst,       kVariable,       kVariableV2,   kAutoReloadVariable,
      kVarHandleOp, kReadVariableOp, kVarHandlesOp, kReadVariablesOp};

#define EIGEN_COST(X) Eigen::internal::functor_traits<Eigen::internal::X>::Cost

  // Quantize = apply min and max bounds, multiply by scale factor and round.
  const int quantize_v2_cost =
      EIGEN_COST(scalar_product_op<float>) + EIGEN_COST(scalar_max_op<float>) +
      EIGEN_COST(scalar_min_op<float>) + EIGEN_COST(scalar_round_op<float>);
  const int quantize_and_dequantize_v2_cost =
      quantize_v2_cost + EIGEN_COST(scalar_product_op<float>);

  // Unary ops alphabetically sorted
  elementwise_ops_.emplace("Acos", EIGEN_COST(scalar_acos_op<float>));
  elementwise_ops_.emplace("All", EIGEN_COST(scalar_boolean_and_op));
  elementwise_ops_.emplace("ArgMax", EIGEN_COST(scalar_max_op<float>));
  elementwise_ops_.emplace("Asin", EIGEN_COST(scalar_asin_op<float>));
  elementwise_ops_.emplace("Atan", EIGEN_COST(scalar_atan_op<float>));
  elementwise_ops_.emplace("Atan2", EIGEN_COST(scalar_quotient_op<float>) +
                                        EIGEN_COST(scalar_atan_op<float>));
  // For now, we use the Eigen cost model for a float to int16 cast as an
  // example case; the Eigen cost model is zero when src and dst types are
  // identical, and it uses AddCost (1) when they differ. We may implement
  // separate cost functions for cast ops, using the actual input and output
  // types.
  elementwise_ops_.emplace(
      "Cast", Eigen::internal::functor_traits<
                  Eigen::internal::scalar_cast_op<float, int16>>::Cost);
  elementwise_ops_.emplace("Ceil", EIGEN_COST(scalar_ceil_op<float>));
  elementwise_ops_.emplace("Cos", EIGEN_COST(scalar_cos_op<float>));
  elementwise_ops_.emplace("Dequantize", EIGEN_COST(scalar_product_op<float>));
  elementwise_ops_.emplace("Erf", 1);
  elementwise_ops_.emplace("Erfc", 1);
  elementwise_ops_.emplace("Exp", EIGEN_COST(scalar_exp_op<float>));
  elementwise_ops_.emplace("Expm1", EIGEN_COST(scalar_expm1_op<float>));
  elementwise_ops_.emplace("Floor", EIGEN_COST(scalar_floor_op<float>));
  elementwise_ops_.emplace("Inv", EIGEN_COST(scalar_inverse_op<float>));
  elementwise_ops_.emplace("InvGrad", 1);
  elementwise_ops_.emplace("Lgamma", 1);
  elementwise_ops_.emplace("Log", EIGEN_COST(scalar_log_op<float>));
  elementwise_ops_.emplace("Log1p", EIGEN_COST(scalar_log1p_op<float>));
  elementwise_ops_.emplace("Max", EIGEN_COST(scalar_max_op<float>));
  elementwise_ops_.emplace("Min", EIGEN_COST(scalar_min_op<float>));
  elementwise_ops_.emplace("Neg", EIGEN_COST(scalar_opposite_op<float>));
  elementwise_ops_.emplace("Prod", EIGEN_COST(scalar_product_op<float>));
  elementwise_ops_.emplace("QuantizeAndDequantizeV2",
                           quantize_and_dequantize_v2_cost);
  elementwise_ops_.emplace("QuantizeAndDequantizeV4",
                           quantize_and_dequantize_v2_cost);
  elementwise_ops_.emplace("QuantizedSigmoid",
                           EIGEN_COST(scalar_logistic_op<float>));
  elementwise_ops_.emplace("QuantizeV2", quantize_v2_cost);
  elementwise_ops_.emplace("Reciprocal", EIGEN_COST(scalar_inverse_op<float>));
  elementwise_ops_.emplace("Relu", EIGEN_COST(scalar_max_op<float>));
  elementwise_ops_.emplace("Relu6", EIGEN_COST(scalar_max_op<float>));
  elementwise_ops_.emplace("Rint", 1);
  elementwise_ops_.emplace("Round", EIGEN_COST(scalar_round_op<float>));
  elementwise_ops_.emplace("Rsqrt", EIGEN_COST(scalar_rsqrt_op<float>));
  elementwise_ops_.emplace("Sigmoid", EIGEN_COST(scalar_logistic_op<float>));
  elementwise_ops_.emplace("Sign", EIGEN_COST(scalar_sign_op<float>));
  elementwise_ops_.emplace("Sin", EIGEN_COST(scalar_sin_op<float>));
  elementwise_ops_.emplace("Sqrt", EIGEN_COST(scalar_sqrt_op<float>));
  elementwise_ops_.emplace("Square", EIGEN_COST(scalar_square_op<float>));
  elementwise_ops_.emplace("Sum", EIGEN_COST(scalar_sum_op<float>));
  elementwise_ops_.emplace("Tan", EIGEN_COST(scalar_tan_op<float>));
  elementwise_ops_.emplace("Tanh", EIGEN_COST(scalar_tanh_op<float>));
  elementwise_ops_.emplace("TopKV2", EIGEN_COST(scalar_max_op<float>));
  // Binary ops alphabetically sorted
  elementwise_ops_.emplace("Add", EIGEN_COST(scalar_sum_op<float>));
  elementwise_ops_.emplace("AddV2", EIGEN_COST(scalar_sum_op<float>));
  elementwise_ops_.emplace("ApproximateEqual", 1);
  elementwise_ops_.emplace("BiasAdd", EIGEN_COST(scalar_sum_op<float>));
  elementwise_ops_.emplace("QuantizedBiasAdd",
                           EIGEN_COST(scalar_sum_op<float>));
  elementwise_ops_.emplace("Div", EIGEN_COST(scalar_quotient_op<float>));
  elementwise_ops_.emplace("Equal", 1);
  elementwise_ops_.emplace("FloorDiv", EIGEN_COST(scalar_quotient_op<float>));
  elementwise_ops_.emplace("FloorMod", EIGEN_COST(scalar_mod_op<float>));
  elementwise_ops_.emplace("Greater", 1);
  elementwise_ops_.emplace("GreaterEqual", 1);
  elementwise_ops_.emplace("Less", 1);
  elementwise_ops_.emplace("LessEqual", 1);
  elementwise_ops_.emplace("LogicalAnd", EIGEN_COST(scalar_boolean_and_op));
  elementwise_ops_.emplace("LogicalNot", 1);
  elementwise_ops_.emplace("LogicalOr", EIGEN_COST(scalar_boolean_or_op));
  elementwise_ops_.emplace("Maximum", EIGEN_COST(scalar_max_op<float>));
  elementwise_ops_.emplace("Minimum", EIGEN_COST(scalar_min_op<float>));
  elementwise_ops_.emplace("Mod", EIGEN_COST(scalar_mod_op<float>));
  elementwise_ops_.emplace("Mul", EIGEN_COST(scalar_product_op<float>));
  elementwise_ops_.emplace("NotEqual", 1);
  elementwise_ops_.emplace("QuantizedAdd", EIGEN_COST(scalar_sum_op<float>));
  elementwise_ops_.emplace("QuantizedMul",
                           EIGEN_COST(scalar_product_op<float>));
  elementwise_ops_.emplace("RealDiv", EIGEN_COST(scalar_quotient_op<float>));
  elementwise_ops_.emplace("ReluGrad", EIGEN_COST(scalar_max_op<float>));
  elementwise_ops_.emplace("Select", EIGEN_COST(scalar_boolean_or_op));
  elementwise_ops_.emplace("SelectV2", EIGEN_COST(scalar_boolean_or_op));
  elementwise_ops_.emplace("SquaredDifference",
                           EIGEN_COST(scalar_square_op<float>) +
                               EIGEN_COST(scalar_difference_op<float>));
  elementwise_ops_.emplace("Sub", EIGEN_COST(scalar_difference_op<float>));
  elementwise_ops_.emplace("TruncateDiv",
                           EIGEN_COST(scalar_quotient_op<float>));
  elementwise_ops_.emplace("TruncateMod", EIGEN_COST(scalar_mod_op<float>));
  elementwise_ops_.emplace("Where", 1);

#undef EIGEN_COST

  // By default, use sum of memory_time and compute_time for execution_time.
  compute_memory_overlap_ = false;
}

Costs OpLevelCostEstimator::PredictCosts(const OpContext& op_context) const {
  Costs costs;
  NodeCosts node_costs;
  if (PredictNodeCosts(op_context, &node_costs).ok()) {
    if (node_costs.has_costs) {
      return node_costs.costs;
    }
    // Convert NodeCosts to Costs.
    if (node_costs.minimum_cost_op) {
      // Override to minimum cost; Note that some ops with minimum cost may have
      // non-typical device (e.g., channel for _Send), which may fail with
      // GetDeviceInfo(), called from PredictOpCountBasedCost(). Make sure we
      // directly set minimum values to Costs here, not calling
      // PredictOpCountBasedCost().
      costs.compute_time = kMinComputeTime;
      costs.execution_time = kMinComputeTime;
      costs.memory_time = 0;
      costs.intermediate_memory_time = 0;
      costs.intermediate_memory_read_time = 0;
      costs.intermediate_memory_write_time = 0;
    } else {
      // Convert NodeCosts to Costs.
      costs = PredictOpCountBasedCost(
          node_costs.num_compute_ops, node_costs.num_total_read_bytes(),
          node_costs.num_total_write_bytes(), op_context.op_info);
    }
    VLOG(1) << "Operation " << op_context.op_info.op() << " takes "
            << costs.execution_time.count() << " ns.";
    // Copy additional stats from NodeCosts to Costs.
    costs.max_memory = node_costs.max_memory;
    costs.persistent_memory = node_costs.persistent_memory;
    costs.temporary_memory = node_costs.temporary_memory;
    costs.inaccurate = node_costs.inaccurate;
    costs.num_ops_with_unknown_shapes =
        node_costs.num_nodes_with_unknown_shapes;
    costs.num_ops_total = node_costs.num_nodes;
    return costs;
  }
  // Errors during node cost estimate.
  LOG(WARNING) << "Error in PredictCost() for the op: "
               << op_context.op_info.ShortDebugString();
  costs = Costs::ZeroCosts(/*inaccurate=*/true);
  costs.num_ops_with_unknown_shapes = node_costs.num_nodes_with_unknown_shapes;
  return costs;
}

Status OpLevelCostEstimator::PredictNodeCosts(const OpContext& op_context,
                                              NodeCosts* node_costs) const {
  const auto& op_info = op_context.op_info;
  auto it = device_cost_impl_.find(op_info.op());
  if (it != device_cost_impl_.end()) {
    std::function<Status(const OpContext&, NodeCosts*)> estimator = it->second;
    return estimator(op_context, node_costs);
  }

  if (persistent_ops_.find(op_info.op()) != persistent_ops_.end()) {
    return PredictVariable(op_context, node_costs);
  }

  if (elementwise_ops_.find(op_info.op()) != elementwise_ops_.end()) {
    return PredictCwiseOp(op_context, node_costs);
  }

  VLOG(1) << "Missing accurate estimator for op: " << op_info.op();

  node_costs->num_nodes_with_unknown_op_type = 1;
  return PredictCostOfAnUnknownOp(op_context, node_costs);
}

// This method assumes a typical system composed of CPUs and GPUs, connected
// through PCIe. To define device info more precisely, override this method.
DeviceInfo OpLevelCostEstimator::GetDeviceInfo(
    const DeviceProperties& device) const {
  double gflops = -1;
  double gb_per_sec = -1;

  if (device.type() == "CPU") {
    // Check if vector instructions are available, and refine performance
    // prediction based on this.
    // Frequencies are stored in MHz in the DeviceProperties.
    gflops = device.num_cores() * device.frequency() * 1e-3;
    if (gb_per_sec < 0) {
      if (device.bandwidth() > 0) {
        gb_per_sec = device.bandwidth() / 1e6;
      } else {
        gb_per_sec = 32;
      }
    }
  } else if (device.type() == "GPU") {
    const auto& device_env = device.environment();
    auto it = device_env.find("architecture");
    if (it != device_env.end()) {
      const std::string architecture = device_env.at("architecture");
      int cores_per_multiprocessor;
      if (architecture < "3") {
        // Fermi
        cores_per_multiprocessor = 32;
      } else if (architecture < "4") {
        // Kepler
        cores_per_multiprocessor = 192;
      } else if (architecture < "6") {
        // Maxwell
        cores_per_multiprocessor = 128;
      } else {
        // Pascal (compute capability version 6) and Volta (compute capability
        // version 7)
        cores_per_multiprocessor = 64;
      }
      gflops = device.num_cores() * device.frequency() * 1e-3 *
               cores_per_multiprocessor * kOpsPerMac;
      if (device.bandwidth() > 0) {
        gb_per_sec = device.bandwidth() / 1e6;
      } else {
        gb_per_sec = 100;
      }
    } else {
      // Architecture is not available (ex: pluggable device), return default
      // value.
      gflops = 100;     // Dummy value;
      gb_per_sec = 12;  // default PCIe x16 gen3.
    }
  } else {
    LOG_EVERY_N(WARNING, 1000) << "Unknown device type: " << device.type()
                               << ", assuming PCIe between CPU and GPU.";
    gflops = 1;  // Dummy value; data transfer ops would not have compute ops.
    gb_per_sec = 12;  // default PCIe x16 gen3.
  }
  VLOG(1) << "Device: " << device.type() << " gflops: " << gflops
          << " gb_per_sec: " << gb_per_sec;

  return DeviceInfo(gflops, gb_per_sec);
}
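
// Rough illustrative calculation (hypothetical device, not from the source):
// a GPU reporting 80 multiprocessors at 1500 MHz with architecture "7" maps to
// 64 cores per multiprocessor, so gflops = 80 * 1500 * 1e-3 * 64 * 2 = 15360;
// a reported bandwidth of 9e8 (interpreted here as KB/s) would yield
// gb_per_sec = 900.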

Status OpLevelCostEstimator::PredictCwiseOp(const OpContext& op_context,
                                            NodeCosts* node_costs) const {
  const auto& op_info = op_context.op_info;
  bool found_unknown_shapes = false;
  // For element-wise operations, op count is the element count of any input. We
  // use the count for the largest input here to be more robust in case the
  // shape is unknown or only partially known for the other inputs.
  int64_t op_count = CalculateLargestInputCount(op_info, &found_unknown_shapes);
  // If output shape is available, try to use the element count calculated from
  // that.
  if (op_info.outputs_size() > 0) {
    op_count = std::max(
        op_count,
        CalculateTensorElementCount(op_info.outputs(0), &found_unknown_shapes));
  }
  // Calculate the output shape possibly resulting from broadcasting.
  if (op_info.inputs_size() >= 2) {
    op_count = std::max(op_count, CwiseOutputElementCount(op_info));
  }

  int op_cost = 1;
  auto it = elementwise_ops_.find(op_info.op());
  if (it != elementwise_ops_.end()) {
    op_cost = it->second;
  } else {
    return errors::InvalidArgument("Not a cwise op: ", op_info.op());
  }

  return PredictDefaultNodeCosts(op_count * op_cost, op_context,
                                 &found_unknown_shapes, node_costs);
}

Status OpLevelCostEstimator::PredictCostOfAnUnknownOp(
    const OpContext& op_context, NodeCosts* node_costs) const {
  // Don't assume the operation is cwise; return a cost based on input/output
  // size and admit that it is inaccurate.
  bool found_unknown_shapes = false;
  node_costs->inaccurate = true;
  return PredictDefaultNodeCosts(0, op_context, &found_unknown_shapes,
                                 node_costs);
}

Costs OpLevelCostEstimator::PredictOpCountBasedCost(
    double operations, const OpInfo& op_info) const {
  bool unknown_shapes = false;
  const double input_size = CalculateInputSize(op_info, &unknown_shapes);
  const double output_size = CalculateOutputSize(op_info, &unknown_shapes);
  Costs costs =
      PredictOpCountBasedCost(operations, input_size, output_size, op_info);
  costs.inaccurate = unknown_shapes;
  costs.num_ops_with_unknown_shapes = unknown_shapes;
  costs.max_memory = output_size;
  return costs;
}

Costs OpLevelCostEstimator::PredictOpCountBasedCost(
    double operations, double input_io_bytes, double output_io_bytes,
    const OpInfo& op_info) const {
  double total_io_bytes = input_io_bytes + output_io_bytes;
  const DeviceInfo device_info = GetDeviceInfo(op_info.device());
  if (device_info.gigaops <= 0 || device_info.gb_per_sec <= 0 ||
      device_info.intermediate_read_gb_per_sec <= 0 ||
      device_info.intermediate_write_gb_per_sec <= 0) {
    VLOG(1) << "BAD DEVICE. Op:" << op_info.op()
            << " device type:" << op_info.device().type()
            << " device model:" << op_info.device().model();
  }

  Costs::NanoSeconds compute_cost(std::ceil(operations / device_info.gigaops));
  VLOG(1) << "Op:" << op_info.op() << " GOps:" << operations / 1e9
          << " Compute Time (ns):" << compute_cost.count();

  Costs::NanoSeconds memory_cost(
      std::ceil(total_io_bytes / device_info.gb_per_sec));
  VLOG(1) << "Op:" << op_info.op() << " Size (KB):" << (total_io_bytes) / 1e3
          << " Memory Time (ns):" << memory_cost.count();

  // Check if bytes > 0.  If it's not and the bandwidth is set to infinity
  // then the result would be undefined.
  double intermediate_read_time =
      (input_io_bytes > 0)
          ? std::ceil(input_io_bytes / device_info.intermediate_read_gb_per_sec)
          : 0;

  double intermediate_write_time =
      (output_io_bytes > 0)
          ? std::ceil(output_io_bytes /
                      device_info.intermediate_write_gb_per_sec)
          : 0;

  Costs::NanoSeconds intermediate_memory_cost =
      compute_memory_overlap_
          ? std::max(intermediate_read_time, intermediate_write_time)
          : (intermediate_read_time + intermediate_write_time);
  VLOG(1) << "Op:" << op_info.op() << " Size (KB):" << (total_io_bytes) / 1e3
          << " Intermediate Memory Time (ns):"
          << intermediate_memory_cost.count();

  Costs costs = Costs::ZeroCosts();
  costs.compute_time = compute_cost;
  costs.memory_time = memory_cost;
  costs.intermediate_memory_time = intermediate_memory_cost;
  costs.intermediate_memory_read_time =
      Costs::NanoSeconds(intermediate_read_time);
  costs.intermediate_memory_write_time =
      Costs::NanoSeconds(intermediate_write_time);
  CombineCostsAndUpdateExecutionTime(compute_memory_overlap_, &costs);
  return costs;
}
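
// Worked example (illustrative numbers): with 2e9 operations on a device
// estimated at 1000 gigaops and 4e6 total IO bytes at 100 GB/s, compute time
// is ceil(2e9 / 1000) = 2e6 ns and memory time is ceil(4e6 / 100) = 4e4 ns;
// CombineCostsAndUpdateExecutionTime() then combines these (summing them when
// compute_memory_overlap_ is false).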

int64_t OpLevelCostEstimator::CountConv2DOperations(
    const OpInfo& op_info, bool* found_unknown_shapes) {
  return CountConv2DOperations(op_info, nullptr, found_unknown_shapes);
}

// Helper to translate the positional arguments into named fields.
/* static */
OpLevelCostEstimator::ConvolutionDimensions
OpLevelCostEstimator::ConvolutionDimensionsFromInputs(
    const TensorShapeProto& original_image_shape,
    const TensorShapeProto& original_filter_shape, const OpInfo& op_info,
    bool* found_unknown_shapes) {
  VLOG(2) << "op features: " << op_info.DebugString();
  VLOG(2) << "Original image shape: " << original_image_shape.DebugString();
  VLOG(2) << "Original filter shape: " << original_filter_shape.DebugString();

  int x_index, y_index, major_channel_index, minor_channel_index = -1;
  const std::string& data_format = GetDataFormat(op_info);
  if (data_format == "NCHW") {
    major_channel_index = 1;
    y_index = 2;
    x_index = 3;
  } else if (data_format == "NCHW_VECT_C") {
    // Use NCHW_VECT_C
    minor_channel_index = 1;
    y_index = 2;
    x_index = 3;
    major_channel_index = 4;
  } else {
    // Use NHWC.
    y_index = 1;
    x_index = 2;
    major_channel_index = 3;
  }
  const std::string& filter_format = GetFilterFormat(op_info);
  int filter_x_index, filter_y_index, in_major_channel_index, out_channel_index,
      in_minor_channel_index = -1;
  if (filter_format == "HWIO") {
    filter_y_index = 0;
    filter_x_index = 1;
    in_major_channel_index = 2;
    out_channel_index = 3;
  } else if (filter_format == "OIHW_VECT_I") {
    out_channel_index = 0;
    in_minor_channel_index = 1;
    filter_y_index = 2;
    filter_x_index = 3;
    in_major_channel_index = 4;
  } else {
    // Use OIHW
    out_channel_index = 0;
    in_major_channel_index = 1;
    filter_y_index = 2;
    filter_x_index = 3;
  }

  auto image_shape = MaybeGetMinimumShape(original_image_shape,
                                          minor_channel_index >= 0 ? 5 : 4,
                                          found_unknown_shapes);
  auto filter_shape = MaybeGetMinimumShape(original_filter_shape,
                                           in_minor_channel_index >= 0 ? 5 : 4,
                                           found_unknown_shapes);
  VLOG(2) << "Image shape: " << image_shape.DebugString();
  VLOG(2) << "Filter shape: " << filter_shape.DebugString();

  int64_t batch = image_shape.dim(0).size();
  int64_t ix = image_shape.dim(x_index).size();
  int64_t iy = image_shape.dim(y_index).size();
  int64_t iz = minor_channel_index >= 0
                   ? image_shape.dim(minor_channel_index).size() *
                         image_shape.dim(major_channel_index).size()
                   : image_shape.dim(major_channel_index).size();
  int64_t kx = filter_shape.dim(filter_x_index).size();
  int64_t ky = filter_shape.dim(filter_y_index).size();
  int64_t kz = in_minor_channel_index >= 0
                   ? filter_shape.dim(in_major_channel_index).size() *
                         filter_shape.dim(in_minor_channel_index).size()
                   : filter_shape.dim(in_major_channel_index).size();
  std::vector<int64_t> strides = GetStrides(op_info);
  const auto padding = GetPadding(op_info);
  int64_t sx = strides[x_index];
  int64_t sy = strides[y_index];
  int64_t ox = GetOutputSize(ix, kx, sx, padding);
  int64_t oy = GetOutputSize(iy, ky, sy, padding);
  int64_t oz = filter_shape.dim(out_channel_index).size();
  // Only check equality when both sizes are known (in other words, when
  // neither is set to a minimum dimension size of 1).
  if (iz != 1 && kz != 1) {
    DCHECK_EQ(iz % kz, 0) << "Input channel " << iz
                          << " is not a multiple of filter channel " << kz
                          << ".";
    if (iz % kz) {
      *found_unknown_shapes = true;
    }
  } else {
    iz = kz = std::max<int64_t>(iz, kz);
  }
  OpLevelCostEstimator::ConvolutionDimensions conv_dims = {
      batch, ix, iy, iz, kx, ky, kz, oz, ox, oy, sx, sy, padding};

  VLOG(1) << "Batch Size:" << batch;
  VLOG(1) << "Image Dims:" << ix << "," << iy;
  VLOG(1) << "Input Depth:" << iz;
  VLOG(1) << "Kernel Dims:" << kx << "," << ky;
  VLOG(1) << "Kernel Depth:" << kz;
  VLOG(1) << "Output Dims:" << ox << "," << oy;
  VLOG(1) << "Output Depth:" << oz;
  VLOG(1) << "Strides:" << sx << "," << sy;
  VLOG(1) << "Padding:" << (padding == Padding::VALID ? "VALID" : "SAME");
  return conv_dims;
}
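
// Illustrative mapping (hypothetical shapes): an NHWC image [32, 28, 28, 64]
// with an HWIO filter [3, 3, 64, 128], strides [1, 2, 2, 1], and SAME padding
// yields batch = 32, ix = iy = 28, iz = kz = 64, kx = ky = 3, oz = 128,
// sx = sy = 2, and ox = oy = (28 + 2 - 1) / 2 = 14.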

int64_t OpLevelCostEstimator::CountConv2DOperations(
    const OpInfo& op_info, ConvolutionDimensions* conv_info,
    bool* found_unknown_shapes) {
  DCHECK(op_info.op() == kConv2d || op_info.op() == kDepthwiseConv2dNative)
      << "Invalid Operation: not Conv2D nor DepthwiseConv2dNative";

  if (op_info.inputs_size() < 2) {  // Unexpected inputs.
    *found_unknown_shapes = true;
    return 0;
  }

  ConvolutionDimensions conv_dims = ConvolutionDimensionsFromInputs(
      op_info.inputs(0).shape(), op_info.inputs(1).shape(), op_info,
      found_unknown_shapes);

  // In DepthwiseConv2dNative, conv_dims.oz is actually the channel depth
  // multiplier; the effective output channel depth oz_effective is
  // conv_dims.iz * conv_dims.oz, thus # ops = N x H x W x oz_effective x 2RS.
  // Compare to Conv2D, where # ops = N x H x W x kz x oz x 2RS; with
  // oz = oz_effective, Conv2D_ops / Depthwise_conv2d_native_ops = kz.
  int64_t ops = conv_dims.batch;
  ops *= conv_dims.ox * conv_dims.oy;
  ops *= conv_dims.kx * conv_dims.ky;
  if (op_info.op() == kConv2d) {
    ops *= conv_dims.kz * conv_dims.oz;
  } else {
    // Adjust oz so that the output tensor dims are correct for
    // DepthwiseConv2dNative, although the ops are computed as for Conv2D.
    conv_dims.oz *= conv_dims.iz;
    ops *= conv_dims.oz;
  }
  ops *= kOpsPerMac;

  if (conv_info != nullptr) {
    *conv_info = conv_dims;
  }
  return ops;
}
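
// Continuing the illustrative shapes above, the Conv2D op count would be
// 32 * 14 * 14 * 3 * 3 * 64 * 128 * 2, i.e. batch * ox * oy * kx * ky * kz *
// oz * kOpsPerMac, counting a multiply and an add per MAC.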

int64_t OpLevelCostEstimator::CountMatMulOperations(
    const OpInfo& op_info, bool* found_unknown_shapes) {
  return CountMatMulOperations(op_info, nullptr, found_unknown_shapes);
}

// TODO(nishantpatil): Create separate estimator for Sparse Matmul
int64_t OpLevelCostEstimator::CountMatMulOperations(
    const OpInfo& op_info, MatMulDimensions* mat_mul,
    bool* found_unknown_shapes) {
  double ops = 0;

  if (op_info.inputs_size() < 2) {
    LOG(ERROR) << "Need 2 inputs but got " << op_info.inputs_size();
    // TODO(pcma): Try to separate invalid inputs from unknown shapes
    *found_unknown_shapes = true;
    return 0;
  }

  auto& a_matrix = op_info.inputs(0);
  auto& b_matrix = op_info.inputs(1);

  bool transpose_a = false;
  bool transpose_b = false;

  double m_dim, n_dim, k_dim, k_dim_b = 0;

  for (const auto& item : op_info.attr()) {
    VLOG(1) << "Key:" << item.first
            << " Value:" << SummarizeAttrValue(item.second);
    if (item.first == "transpose_a" && item.second.b() == true)
      transpose_a = true;
    if (item.first == "transpose_b" && item.second.b() == true)
      transpose_b = true;
  }
  VLOG(1) << "transpose_a:" << transpose_a;
  VLOG(1) << "transpose_b:" << transpose_b;
  auto a_matrix_shape =
      MaybeGetMinimumShape(a_matrix.shape(), 2, found_unknown_shapes);
  auto b_matrix_shape =
      MaybeGetMinimumShape(b_matrix.shape(), 2, found_unknown_shapes);
  if (transpose_a) {
    m_dim = a_matrix_shape.dim(1).size();
    k_dim = a_matrix_shape.dim(0).size();
  } else {
    m_dim = a_matrix_shape.dim(0).size();
    k_dim = a_matrix_shape.dim(1).size();
  }
  if (transpose_b) {
    k_dim_b = b_matrix_shape.dim(1).size();
    n_dim = b_matrix_shape.dim(0).size();
  } else {
    k_dim_b = b_matrix_shape.dim(0).size();
    n_dim = b_matrix_shape.dim(1).size();
  }

  VLOG(1) << "M, N, K: " << m_dim << "," << n_dim << "," << k_dim;
  // Only check equality when both sizes are known (in other words, when
  // neither is set to a minimum dimension size of 1).
  if (k_dim_b != 1 && k_dim != 1 && k_dim_b != k_dim) {
    LOG(ERROR) << "Incompatible Matrix dimensions";
    return ops;
  } else {
    // One of k_dim and k_dim_b might be 1 (minimum dimension size).
    k_dim = std::max(k_dim, k_dim_b);
  }

  ops = m_dim * n_dim * k_dim * 2;
  VLOG(1) << "Operations for Matmul: " << ops;

  if (mat_mul != nullptr) {
    mat_mul->m = m_dim;
    mat_mul->n = n_dim;
    mat_mul->k = k_dim;
  }
  return ops;
}
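
// For example (illustrative), a MatMul of a [128, 256] matrix with a
// [256, 512] matrix gives m = 128, n = 512, k = 256 and
// ops = 128 * 512 * 256 * 2, again counting a multiply and an add per MAC.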

bool OpLevelCostEstimator::GenerateBatchMatmulContextFromEinsum(
    const OpContext& einsum_context, OpContext* batch_matmul_context,
    bool* found_unknown_shapes) const {
  // This auxiliary function transforms an einsum OpContext into its equivalent
  // Batch Matmul OpContext. It returns a boolean indicating whether it
  // succeeded in generating the output OpContext.

  // Einsum computes a generalized contraction between tensors of arbitrary
  // dimension as defined by the equation written in the Einstein summation
  // convention. The number of tensors in the computation and the number of
  // contractions can be arbitrarily large. The current model only contemplates
  // Einsum equations that can be translated into a single BatchMatMul
  // operation. Einsum operations with more than two operands are not currently
  // supported. Subscripts where an axis appears more than once for a single
  // input and ellipses are currently also excluded. See:
  // https://www.tensorflow.org/api_docs/python/tf/einsum
  // We distinguish four kinds of dimensions, depending on their placement in
  // the equation:
  // + B: Batch dimensions: Dimensions which appear in both operands and RHS.
  // + K: Contracting dimensions: These appear in both inputs but not RHS.
  // + M: Operand A dimensions: These appear in the first operand and the RHS.
  // + N: Operand B dimensions: These appear in the second operand and the RHS.
  // Then, the operation to estimate is BatchMatMul([B,M,K],[B,K,N]).
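  // For instance (illustrative): in the equation "bij,bjk->bik", 'b' appears
  // in both operands and the RHS (batch), 'j' appears in both operands but not
  // the RHS (contracting), 'i' appears only in the first operand and the RHS
  // (M), and 'k' only in the second operand and the RHS (N), so the estimated
  // operation is BatchMatMul([b,i,j], [b,j,k]).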

  if (batch_matmul_context == nullptr) {
    VLOG(1) << "Output context should not be a nullptr.";
    return false;
  }
  if (!IsEinsumCorrectlyFormed(einsum_context)) return false;
  const auto& op_info = einsum_context.op_info;
  std::vector<std::string> equation_split =
      absl::StrSplit(op_info.attr().find("equation")->second.s(), "->");
  std::vector<absl::string_view> input_split =
      absl::StrSplit(equation_split[0], ',');
  const auto& a_input = op_info.inputs(0);
  const auto& b_input = op_info.inputs(1);
  absl::string_view rhs_str = equation_split[1];
  absl::string_view a_input_str = input_split[0];
  absl::string_view b_input_str = input_split[1];

  constexpr int kMatrixRank = 2;

  bool a_input_shape_unknown = false;
  bool b_input_shape_unknown = false;

  TensorShapeProto a_input_shape = MaybeGetMinimumShape(
      a_input.shape(), std::max(kMatrixRank, a_input.shape().dim_size()),
      &a_input_shape_unknown);
  TensorShapeProto b_input_shape = MaybeGetMinimumShape(
      b_input.shape(), std::max(kMatrixRank, b_input.shape().dim_size()),
      &b_input_shape_unknown);

  *found_unknown_shapes = a_input_shape_unknown || b_input_shape_unknown ||
                          (a_input.shape().dim_size() < kMatrixRank) ||
                          (b_input.shape().dim_size() < kMatrixRank);

  OpInfo batch_matmul_op_info = op_info;
  batch_matmul_op_info.mutable_inputs()->Clear();
  batch_matmul_op_info.set_op("BatchMatMul");

  AttrValue transpose_attribute;
  transpose_attribute.set_b(false);
  (*batch_matmul_op_info.mutable_attr())["transpose_a"] = transpose_attribute;
  (*batch_matmul_op_info.mutable_attr())["transpose_b"] = transpose_attribute;

  OpInfo::TensorProperties* a_matrix = batch_matmul_op_info.add_inputs();
  TensorShapeProto* a_matrix_shape = a_matrix->mutable_shape();
  a_matrix->set_dtype(a_input.dtype());

  OpInfo::TensorProperties* b_matrix = batch_matmul_op_info.add_inputs();
  b_matrix->set_dtype(b_input.dtype());
  TensorShapeProto* b_matrix_shape = b_matrix->mutable_shape();

  TensorShapeProto_Dim m_dim;
  TensorShapeProto_Dim n_dim;
  TensorShapeProto_Dim k_dim;

  m_dim.set_size(1);
  n_dim.set_size(1);
  k_dim.set_size(1);

  for (int i_idx = 0, a_input_str_size = a_input_str.size();
       i_idx < a_input_str_size; ++i_idx) {
    if (b_input_str.find(a_input_str[i_idx]) == std::string::npos) {
      if (rhs_str.find(a_input_str[i_idx]) == std::string::npos) {
        VLOG(1) << "Missing accurate estimator for op: " << op_info.op();
        return false;
      }

      m_dim.set_size(m_dim.size() * a_input_shape.dim(i_idx).size());
      continue;
    } else if (rhs_str.find(a_input_str[i_idx]) == std::string::npos) {
      // The dimension does not appear in the RHS, therefore it is a contracting
      // dimension.
      k_dim.set_size(k_dim.size() * a_input_shape.dim(i_idx).size());
      continue;
    }
    // It appears in both input operands, therefore we place it as an outer
    // dimension for the Batch Matmul.
    *(a_matrix_shape->add_dim()) = a_input_shape.dim(i_idx);
    *(b_matrix_shape->add_dim()) = a_input_shape.dim(i_idx);
  }
  for (int i_idx = 0, b_input_str_size = b_input_str.size();
       i_idx < b_input_str_size; ++i_idx) {
    if (a_input_str.find(b_input_str[i_idx]) == std::string::npos) {
      if (rhs_str.find(b_input_str[i_idx]) == std::string::npos) {
        VLOG(1) << "Missing accurate estimator for op: " << op_info.op();
        return false;
      }
      n_dim.set_size(n_dim.size() * b_input_shape.dim(i_idx).size());
    }
  }

  // The two inner-most dimensions of the Batch Matmul are added.
  *(a_matrix_shape->add_dim()) = m_dim;
1237   *(a_matrix_shape->add_dim()) = k_dim;
1238   *(b_matrix_shape->add_dim()) = k_dim;
1239   *(b_matrix_shape->add_dim()) = n_dim;
1240 
1241   *batch_matmul_context = einsum_context;
1242   batch_matmul_context->op_info = batch_matmul_op_info;
1243   return true;
1244 }
1245 
1246 int64_t OpLevelCostEstimator::CountBatchMatMulOperations(
1247     const OpInfo& op_info, bool* found_unknown_shapes) {
1248   return CountBatchMatMulOperations(op_info, nullptr, found_unknown_shapes);
1249 }
1250 
1251 int64_t OpLevelCostEstimator::CountBatchMatMulOperations(
1252     const OpInfo& op_info, BatchMatMulDimensions* batch_mat_mul,
1253     bool* found_unknown_shapes) {
1254   if (op_info.op() != kBatchMatMul && op_info.op() != kBatchMatMulV2) {
1255     LOG(ERROR) << "Invalid Operation: " << op_info.op();
1256     // TODO(pcma): Try to separate invalid inputs from unknown shapes
1257     *found_unknown_shapes = true;
1258     return 0;
1259   }
1260   if (op_info.inputs_size() != 2) {
1261     LOG(ERROR) << "Expected 2 inputs but got " << op_info.inputs_size();
1262     // TODO(pcma): Try to separate invalid inputs from unknown shapes
1263     *found_unknown_shapes = true;
1264     return 0;
1265   }
1266 
1267   double ops = 0;
1268   const auto& a_input = op_info.inputs(0);
1269   const auto& b_input = op_info.inputs(1);
1270 
1271   // BatchMatMul requires inputs of at least matrix shape (rank 2).
1272   // The two most minor dimensions of each input are matrices that
1273   // need to be multiplied together. The other dimensions determine
1274   // the number of such MatMuls.  For example, if the BatchMatMul has
1275   // inputs of shape:
1276   //   a_input_shape = [2, 3, 4, 5]
1277   //   b_input_shape = [2, 3, 5, 6]
1278   // then there are 2*3 = 6 MatMuls of dimensions m = 4, k = 5, n = 6
1279   // in this BatchMatMul.
1280   const int matrix_rank = 2;
1281 
1282   bool a_input_shape_unknown = false;
1283   bool b_input_shape_unknown = false;
1284 
1285   TensorShapeProto a_input_shape = MaybeGetMinimumShape(
1286       a_input.shape(), std::max(matrix_rank, a_input.shape().dim_size()),
1287       &a_input_shape_unknown);
1288   TensorShapeProto b_input_shape = MaybeGetMinimumShape(
1289       b_input.shape(), std::max(matrix_rank, b_input.shape().dim_size()),
1290       &b_input_shape_unknown);
1291 
1292   *found_unknown_shapes = a_input_shape_unknown || b_input_shape_unknown ||
1293                           (a_input.shape().dim_size() < matrix_rank) ||
1294                           (b_input.shape().dim_size() < matrix_rank);
1295 
1296   // Compute the number of matmuls as the product over the batch dimensions,
1297   // taking the max size indicated by either input at each dimension. Note
1298   // that the two shapes may have different ranks due to incomplete shapes.
1299   TensorShapeProto* bigger_rank_shape = &a_input_shape;
1300   TensorShapeProto* smaller_rank_shape = &b_input_shape;
1301   if (b_input_shape.dim_size() > a_input_shape.dim_size()) {
1302     bigger_rank_shape = &b_input_shape;
1303     smaller_rank_shape = &a_input_shape;
1304   }
1305   int num_matmuls = 1;
1306   for (int b_i = 0,
1307            s_i = smaller_rank_shape->dim_size() - bigger_rank_shape->dim_size();
1308        b_i < bigger_rank_shape->dim_size() - matrix_rank; ++b_i, ++s_i) {
1309     int b_dim = bigger_rank_shape->dim(b_i).size();
1310     int s_dim = 1;
1311     if (s_i >= 0) {
1312       s_dim = smaller_rank_shape->dim(s_i).size();
1313     }
1314     if (batch_mat_mul != nullptr) {
1315       batch_mat_mul->batch_dims.push_back(s_dim);
1316     }
1317     num_matmuls *= std::max(b_dim, s_dim);
1318   }
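  // Illustrative example: with a_input_shape = [2, 3, 4, 5] and
  // b_input_shape = [5, 6], smaller_rank_shape is b and s_i starts negative,
  // so the missing batch dimensions are treated as size 1 and
  // num_matmuls = max(2, 1) * max(3, 1) = 6.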
1319 
1320   // Build the MatMul. Note that values are ignored here since we are just
1321   // counting ops (i.e., only shapes matter).
1322   OpInfo matmul_op_info;
1323   matmul_op_info.set_op("MatMul");
1324 
1325   AttrValue transpose_a;
1326   transpose_a.set_b(false);
1327   if (op_info.attr().find("adj_x") != op_info.attr().end()) {
1328     transpose_a.set_b(op_info.attr().at("adj_x").b());
1329   }
1330   (*matmul_op_info.mutable_attr())["transpose_a"] = transpose_a;
1331 
1332   AttrValue transpose_b;
1333   transpose_b.set_b(false);
1334   if (op_info.attr().find("adj_y") != op_info.attr().end()) {
1335     transpose_b.set_b(op_info.attr().at("adj_y").b());
1336   }
1337   (*matmul_op_info.mutable_attr())["transpose_b"] = transpose_b;
1338 
1339   OpInfo::TensorProperties* a_matrix = matmul_op_info.add_inputs();
1340   a_matrix->set_dtype(a_input.dtype());
1341   TensorShapeProto* a_matrix_shape = a_matrix->mutable_shape();
1342   for (int i = std::max(0, a_input_shape.dim_size() - matrix_rank);
1343        i < a_input_shape.dim_size(); ++i) {
1344     *(a_matrix_shape->add_dim()) = a_input_shape.dim(i);
1345   }
1346 
1347   OpInfo::TensorProperties* b_matrix = matmul_op_info.add_inputs();
1348   b_matrix->set_dtype(b_input.dtype());
1349   TensorShapeProto* b_matrix_shape = b_matrix->mutable_shape();
1350   for (int i = std::max(0, b_input_shape.dim_size() - matrix_rank);
1351        i < b_input_shape.dim_size(); ++i) {
1352     *(b_matrix_shape->add_dim()) = b_input_shape.dim(i);
1353   }
1354   if (batch_mat_mul != nullptr) {
1355     batch_mat_mul->matmul_dims.m = (transpose_a.b())
1356                                        ? a_matrix_shape->dim(1).size()
1357                                        : a_matrix_shape->dim(0).size();
1358     batch_mat_mul->matmul_dims.k = (transpose_a.b())
1359                                        ? a_matrix_shape->dim(0).size()
1360                                        : a_matrix_shape->dim(1).size();
1361     batch_mat_mul->matmul_dims.n = (transpose_b.b())
1362                                        ? b_matrix_shape->dim(0).size()
1363                                        : b_matrix_shape->dim(1).size();
1364   }
1365 
1366   for (int i = 0; i < num_matmuls; ++i) {
1367     bool matmul_unknown_shapes = false;
1368     ops += CountMatMulOperations(matmul_op_info, &matmul_unknown_shapes);
1369     *found_unknown_shapes |= matmul_unknown_shapes;
1370   }
1371   return ops;
1372 }
1373 
1374 bool GetTensorShapeProtoFromTensorProto(const TensorProto& tensor_proto,
1375                                         TensorShapeProto* tensor_shape_proto) {
1376   tensor_shape_proto->Clear();
1377   // First convert TensorProto into the Tensor class so that it correctly
1378   // parses data values within TensorProto (whether they are in int_val,
1379   // int64_val, tensor_content, or any other field).
1380   Tensor tensor(tensor_proto.dtype());
1381   if (!tensor.FromProto(tensor_proto)) {
1382     LOG(WARNING) << "GetTensorShapeProtoFromTensorProto() -- "
1383                  << "failed to parse TensorProto: "
1384                  << tensor_proto.DebugString();
1385     return false;
1386   }
1387   if (tensor.dims() != 1) {
1388     LOG(WARNING) << "GetTensorShapeProtoFromTensorProto() -- "
1389                  << "tensor is not 1D: " << tensor.dims();
1390     return false;
1391   }
1392   // Then, convert it back to TensorProto using AsProtoField, which makes sure
1393   // the data is in int_val, int64_val, or such repeated data fields, not in
1394   // tensor_content.
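  // For example, a 1-D DT_INT32 tensor holding {2, 224, 224, 3} is converted
  // below into a TensorShapeProto with dimensions [2, 224, 224, 3].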
1395   TensorProto temp_tensor;
1396   tensor.AsProtoField(&temp_tensor);
1397 
1398 #define TENSOR_VALUES_TO_TENSOR_SHAPE_PROTO(type)        \
1399   do {                                                   \
1400     for (const auto& value : temp_tensor.type##_val()) { \
1401       tensor_shape_proto->add_dim()->set_size(value);    \
1402     }                                                    \
1403   } while (0)
1404 
1405   if (tensor.dtype() == DT_INT32 || tensor.dtype() == DT_INT16 ||
1406       tensor.dtype() == DT_INT8 || tensor.dtype() == DT_UINT8) {
1407     TENSOR_VALUES_TO_TENSOR_SHAPE_PROTO(int);
1408   } else if (tensor.dtype() == DT_INT64) {
1409     TENSOR_VALUES_TO_TENSOR_SHAPE_PROTO(int64);
1410   } else if (tensor.dtype() == DT_UINT32) {
1411     TENSOR_VALUES_TO_TENSOR_SHAPE_PROTO(uint32);
1412   } else if (tensor.dtype() == DT_UINT64) {
1413     TENSOR_VALUES_TO_TENSOR_SHAPE_PROTO(uint64);
1414   } else {
1415     LOG(WARNING) << "GetTensorShapeProtoFromTensorProto() -- "
1416                  << "Unsupported dtype: " << tensor.dtype();
1417     return false;
1418   }
1419 #undef TENSOR_VALUES_TO_TENSOR_SHAPE_PROTO
1420 
1421   return true;
1422 }
1423 
1424 // TODO(cliffy): Dedup this method and CountConv2DBackpropFilterOperations.
1425 int64_t OpLevelCostEstimator::CountConv2DBackpropInputOperations(
1426     const OpInfo& op_info, ConvolutionDimensions* returned_conv_dims,
1427     bool* found_unknown_shapes) {
1428   int64_t ops = 0;
1429 
1430   DCHECK(op_info.op() == kConv2dBackpropInput ||
1431          op_info.op() == kDepthwiseConv2dNativeBackpropInput)
1432       << "Invalid Operation: not kConv2dBackpropInput nor"
1433          "kDepthwiseConv2dNativeBackpropInput";
1434 
1435   if (op_info.inputs_size() < 2) {
1436     // TODO(pcma): Try to separate invalid inputs from unknown shapes
1437     *found_unknown_shapes = true;
1438     return ops;
1439   }
1440 
1441   TensorShapeProto input_shape;
1442   bool shape_found = false;
1443   if (op_info.inputs(0).has_value()) {
1444     const TensorProto& value = op_info.inputs(0).value();
1445     shape_found = GetTensorShapeProtoFromTensorProto(value, &input_shape);
1446   }
1447   if (!shape_found && op_info.outputs_size() == 1) {
1448     input_shape = op_info.outputs(0).shape();
1449     shape_found = true;
1450   }
1451   if (!shape_found) {
1452     // Set the minimum input size that's feasible.
1453     input_shape.Clear();
1454     for (int i = 0; i < 4; ++i) {
1455       input_shape.add_dim()->set_size(1);
1456     }
1457     *found_unknown_shapes = true;
1458   }
1459 
1460   ConvolutionDimensions conv_dims = ConvolutionDimensionsFromInputs(
1461       input_shape, op_info.inputs(1).shape(), op_info, found_unknown_shapes);
1462 
1463   ops = conv_dims.batch;
1464   ops *= conv_dims.ox * conv_dims.oy;
1465   ops *= conv_dims.kx * conv_dims.ky;
1466   if (op_info.op() == kConv2dBackpropInput) {
1467     ops *= conv_dims.kz * conv_dims.oz;
1468   } else {
1469     // conv_dims always uses the forward-path definition regardless.
1470     conv_dims.oz *= conv_dims.iz;
1471     ops *= conv_dims.oz;
1472   }
1473   ops *= kOpsPerMac;
1474 
1475   VLOG(1) << "Operations for" << op_info.op() << "  " << ops;
1476 
1477   if (returned_conv_dims != nullptr) {
1478     *returned_conv_dims = conv_dims;
1479   }
1480   return ops;
1481 }
1482 
1483 int64_t OpLevelCostEstimator::CountConv2DBackpropFilterOperations(
1484     const OpInfo& op_info, ConvolutionDimensions* returned_conv_dims,
1485     bool* found_unknown_shapes) {
1486   int64_t ops = 0;
1487 
1488   DCHECK(op_info.op() == kConv2dBackpropFilter ||
1489          op_info.op() == kDepthwiseConv2dNativeBackpropFilter)
1490       << "Invalid Operation: not kConv2dBackpropFilter nor"
1491          "kDepthwiseConv2dNativeBackpropFilter";
1492 
1493   TensorShapeProto filter_shape;
1494   bool shape_found = false;
1495   if (op_info.inputs_size() >= 2 && op_info.inputs(1).has_value()) {
1496     const TensorProto& value = op_info.inputs(1).value();
1497     shape_found = GetTensorShapeProtoFromTensorProto(value, &filter_shape);
1498   }
1499   if (!shape_found && op_info.outputs_size() == 1) {
1500     filter_shape = op_info.outputs(0).shape();
1501     shape_found = true;
1502   }
1503   if (!shape_found) {
1504     // Set the minimum filter size that's feasible.
1505     filter_shape.Clear();
1506     for (int i = 0; i < 4; ++i) {
1507       filter_shape.add_dim()->set_size(1);
1508     }
1509     *found_unknown_shapes = true;
1510   }
1511 
1512   if (op_info.inputs_size() < 1) {
1513     // TODO(pcma): Try to separate invalid inputs from unknown shapes
1514     *found_unknown_shapes = true;
1515     return ops;
1516   }
1517   ConvolutionDimensions conv_dims = ConvolutionDimensionsFromInputs(
1518       op_info.inputs(0).shape(), filter_shape, op_info, found_unknown_shapes);
1519 
1520   ops = conv_dims.batch;
1521   ops *= conv_dims.ox * conv_dims.oy;
1522   ops *= conv_dims.kx * conv_dims.ky;
1523   if (op_info.op() == kConv2dBackpropFilter) {
1524     ops *= conv_dims.kz * conv_dims.oz;
1525   } else {
1526     // conv_dims always uses the forward-path definition regardless.
1527     conv_dims.oz *= conv_dims.iz;
1528     ops *= conv_dims.oz;
1529   }
1530   ops *= kOpsPerMac;
1531   VLOG(1) << "Operations for" << op_info.op() << "  " << ops;
1532 
1533   if (returned_conv_dims != nullptr) {
1534     *returned_conv_dims = conv_dims;
1535   }
1536   return ops;
1537 }
1538 
1539 int64_t OpLevelCostEstimator::CalculateTensorElementCount(
1540     const OpInfo::TensorProperties& tensor, bool* found_unknown_shapes) {
1541   VLOG(2) << "   with " << DataTypeString(tensor.dtype()) << " tensor of shape "
1542           << tensor.shape().DebugString();
1543   int64_t tensor_size = 1;
1544   int num_dims = std::max(1, tensor.shape().dim_size());
1545   auto tensor_shape =
1546       MaybeGetMinimumShape(tensor.shape(), num_dims, found_unknown_shapes);
1547   for (const auto& dim : tensor_shape.dim()) {
1548     int64_t new_tensor_size = MultiplyWithoutOverflow(tensor_size, dim.size());
1549     if (new_tensor_size < 0) {
1550       VLOG(1) << "Overflow encountered when computing element count of a "
1551                  "tensor, multiplying "
1552               << tensor_size << " with " << dim.size();
1553       return -1;
1554     }
1555     tensor_size = new_tensor_size;
1556   }
1557   return tensor_size;
1558 }
1559 
1560 int64_t OpLevelCostEstimator::CalculateTensorSize(
1561     const OpInfo::TensorProperties& tensor, bool* found_unknown_shapes) {
1562   int64_t count = CalculateTensorElementCount(tensor, found_unknown_shapes);
1563   int size = DataTypeSize(BaseType(tensor.dtype()));
1564   VLOG(2) << "Count: " << count << " DataTypeSize: " << size;
1565   int64_t tensor_size = MultiplyWithoutOverflow(count, size);
1566   if (tensor_size < 0) {
1567     VLOG(1) << "Overflow encountered when computing tensor size, multiplying "
1568             << count << " with " << size;
1569     return -1;
1570   }
1571   return tensor_size;
1572 }
1573 
1574 int64_t OpLevelCostEstimator::CalculateInputSize(const OpInfo& op_info,
1575                                                  bool* found_unknown_shapes) {
1576   int64_t total_input_size = 0;
1577   for (auto& input : op_info.inputs()) {
1578     int64_t input_size = CalculateTensorSize(input, found_unknown_shapes);
1579     total_input_size += input_size;
1580     VLOG(1) << "Input Size: " << input_size
1581             << " Total Input Size:" << total_input_size;
1582   }
1583   return total_input_size;
1584 }
1585 
1586 std::vector<int64_t> OpLevelCostEstimator::CalculateInputTensorSize(
1587     const OpInfo& op_info, bool* found_unknown_shapes) {
1588   std::vector<int64_t> input_tensor_size;
1589   input_tensor_size.reserve(op_info.inputs().size());
1590   for (auto& input : op_info.inputs()) {
1591     input_tensor_size.push_back(
1592         CalculateTensorSize(input, found_unknown_shapes));
1593   }
1594   return input_tensor_size;
1595 }
1596 
1597 int64_t OpLevelCostEstimator::CalculateLargestInputCount(
1598     const OpInfo& op_info, bool* found_unknown_shapes) {
1599   int64_t largest_input_count = 0;
1600   for (auto& input : op_info.inputs()) {
1601     int64_t input_count =
1602         CalculateTensorElementCount(input, found_unknown_shapes);
1603     if (input_count > largest_input_count) {
1604       largest_input_count = input_count;
1605     }
1606     VLOG(1) << "Input Count: " << input_count
1607             << " Largest Input Count:" << largest_input_count;
1608   }
1609   return largest_input_count;
1610 }
1611 
1612 int64_t OpLevelCostEstimator::CalculateOutputSize(const OpInfo& op_info,
1613                                                   bool* found_unknown_shapes) {
1614   int64_t total_output_size = 0;
1615   // Use float as default for calculations.
1616   for (const auto& output : op_info.outputs()) {
1617     DataType dt = output.dtype();
1618     const auto& original_output_shape = output.shape();
1619     int64_t output_size = DataTypeSize(BaseType(dt));
1620     int num_dims = std::max(1, original_output_shape.dim_size());
1621     auto output_shape = MaybeGetMinimumShape(original_output_shape, num_dims,
1622                                              found_unknown_shapes);
1623     for (const auto& dim : output_shape.dim()) {
1624       int64_t new_output_size =
1625           MultiplyWithoutOverflow(output_size, dim.size());
1626       if (new_output_size < 0) {
1627         VLOG(1) << "Overflow encountered when estimating cost, multiplying "
1628                 << output_size << " with " << dim.size();
1629         return -1;
1630       }
1631       output_size = new_output_size;
1632     }
1633     total_output_size += output_size;
1634     VLOG(1) << "Output Size: " << output_size
1635             << " Total Output Size:" << total_output_size;
1636   }
1637   return total_output_size;
1638 }
1639 
1640 std::vector<int64_t> OpLevelCostEstimator::CalculateOutputTensorSize(
1641     const OpInfo& op_info, bool* found_unknown_shapes) {
1642   std::vector<int64_t> output_tensor_size;
1643   output_tensor_size.reserve(op_info.outputs().size());
1644   // Use float as default for calculations.
1645   for (const auto& output : op_info.outputs()) {
1646     DataType dt = output.dtype();
1647     const auto& original_output_shape = output.shape();
1648     int64_t output_size = DataTypeSize(BaseType(dt));
1649     int num_dims = std::max(1, original_output_shape.dim_size());
1650     auto output_shape = MaybeGetMinimumShape(original_output_shape, num_dims,
1651                                              found_unknown_shapes);
1652     for (const auto& dim : output_shape.dim()) {
1653       output_size *= dim.size();
1654     }
1655     output_tensor_size.push_back(output_size);
1656   }
1657   return output_tensor_size;
1658 }
1659 
1660 Status OpLevelCostEstimator::PredictDefaultNodeCosts(
1661     const int64_t num_compute_ops, const OpContext& op_context,
1662     bool* found_unknown_shapes, NodeCosts* node_costs) {
1663   const auto& op_info = op_context.op_info;
1664   node_costs->num_compute_ops = num_compute_ops;
1665   node_costs->num_input_bytes_accessed =
1666       CalculateInputTensorSize(op_info, found_unknown_shapes);
1667   node_costs->num_output_bytes_accessed =
1668       CalculateOutputTensorSize(op_info, found_unknown_shapes);
1669   node_costs->max_memory = node_costs->num_total_output_bytes();
1670   if (*found_unknown_shapes) {
1671     node_costs->inaccurate = true;
1672     node_costs->num_nodes_with_unknown_shapes = 1;
1673   }
1674   return OkStatus();
1675 }
1676 
1677 bool HasZeroDim(const OpInfo& op_info) {
1678   for (int i = 0; i < op_info.inputs_size(); ++i) {
1679     const auto& input = op_info.inputs(i);
1680     for (int j = 0; j < input.shape().dim_size(); ++j) {
1681       const auto& dim = input.shape().dim(j);
1682       if (dim.size() == 0) {
1683         VLOG(1) << "Convolution config has zero dim "
1684                 << op_info.ShortDebugString();
1685         return true;
1686       }
1687     }
1688   }
1689   return false;
1690 }
1691 
1692 Status OpLevelCostEstimator::PredictConv2D(const OpContext& op_context,
1693                                            NodeCosts* node_costs) const {
1694   const auto& op_info = op_context.op_info;
1695   if (HasZeroDim(op_info)) {
1696     node_costs->num_nodes_with_unknown_shapes = 1;
1697     return errors::InvalidArgument("Conv2D op includes zero dimension: ",
1698                                    op_info.ShortDebugString());
1699   }
1700   bool found_unknown_shapes = false;
1701   int64_t num_compute_ops =
1702       CountConv2DOperations(op_info, &found_unknown_shapes);
1703   return PredictDefaultNodeCosts(num_compute_ops, op_context,
1704                                  &found_unknown_shapes, node_costs);
1705 }
1706 
1707 Status OpLevelCostEstimator::PredictConv2DBackpropInput(
1708     const OpContext& op_context, NodeCosts* node_costs) const {
1709   const auto& op_info = op_context.op_info;
1710   if (HasZeroDim(op_info)) {
1711     node_costs->num_nodes_with_unknown_shapes = 1;
1712     return errors::InvalidArgument(
1713         "Conv2DBackpropInput op includes zero dimension",
1714         op_info.ShortDebugString());
1715   }
1716   bool found_unknown_shapes = false;
1717   int64_t num_compute_ops = CountConv2DBackpropInputOperations(
1718       op_info, nullptr, &found_unknown_shapes);
1719   return PredictDefaultNodeCosts(num_compute_ops, op_context,
1720                                  &found_unknown_shapes, node_costs);
1721 }
1722 
1723 Status OpLevelCostEstimator::PredictConv2DBackpropFilter(
1724     const OpContext& op_context, NodeCosts* node_costs) const {
1725   const auto& op_info = op_context.op_info;
1726   if (HasZeroDim(op_info)) {
1727     node_costs->num_nodes_with_unknown_shapes = 1;
1728     return errors::InvalidArgument(
1729         "Conv2DBackpropFilter op includes zero dimension",
1730         op_info.ShortDebugString());
1731   }
1732   bool found_unknown_shapes = false;
1733   int64_t num_compute_ops = CountConv2DBackpropFilterOperations(
1734       op_info, nullptr, &found_unknown_shapes);
1735   return PredictDefaultNodeCosts(num_compute_ops, op_context,
1736                                  &found_unknown_shapes, node_costs);
1737 }
1738 
1739 Status OpLevelCostEstimator::PredictFusedConv2DBiasActivation(
1740     const OpContext& op_context, NodeCosts* node_costs) const {
1741   // FusedConv2DBiasActivation computes a fused kernel which implements:
1742   // 2D convolution, adds side input with separate scaling on convolution and
1743   // side inputs, then adds bias, and finally applies the ReLU activation
1744   // function to the result:
1745   //
1746   // Input -> Conv2D  ->  Add  -> BiasAdd  -> ReLU
1747   //            ^          ^         ^
1748   //          Filter   Side Input   Bias
1749   //
1750   // Note that when adding the side input, the operation multiplies the output
1751   // of Conv2D by conv_input_scale, confusingly, and the side_input by
1752   // side_input_scale.
1753   //
1754   // Note that in the special case that side_input_scale is 0, which we infer
1755   // from side_input having dimensions [], we skip that addition operation.
1756   //
1757   // For more information, see
1758   // contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc
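  // As a rough sketch, the cost below is assembled from the component ops
  // Conv2D + Mul (conv_input_scale) + BiasAdd + Relu, plus an extra Mul and
  // Add when a non-empty side input is present.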
1759 
1760   // TODO(yaozhang): Support NHWC_VECT_W.
1761   std::string data_format = GetDataFormat(op_context.op_info);
1762   if (data_format != "NCHW" && data_format != "NHWC" &&
1763       data_format != "NCHW_VECT_C") {
1764     return errors::InvalidArgument(
1765         "Unsupported data format (", data_format,
1766         ") for op: ", op_context.op_info.ShortDebugString());
1767   }
1768   std::string filter_format = GetFilterFormat(op_context.op_info);
1769   if (filter_format != "HWIO" && filter_format != "OIHW" &&
1770       filter_format != "OIHW_VECT_I") {
1771     return errors::InvalidArgument(
1772         "Unsupported filter format (", filter_format,
1773         ") for op: ", op_context.op_info.ShortDebugString());
1774   }
1775 
1776   auto& conv_input = op_context.op_info.inputs(0);
1777   auto& filter = op_context.op_info.inputs(1);
1778   auto& side_input = op_context.op_info.inputs(3);
1779   auto& conv_input_scale = op_context.op_info.inputs(4);
1780   auto& side_input_scale = op_context.op_info.inputs(5);
1781 
1782   // Manually compute our convolution dimensions.
1783   bool found_unknown_shapes = false;
1784   auto dims = ConvolutionDimensionsFromInputs(
1785       conv_input.shape(), filter.shape(), op_context.op_info,
1786       &found_unknown_shapes);
1787   OpInfo::TensorProperties output;
1788   if (data_format == "NCHW" || data_format == "NCHW_VECT_C") {
1789     output = DescribeTensor(DT_FLOAT, {dims.batch, dims.oz, dims.oy, dims.ox});
1790   } else if (data_format == "NHWC") {
1791     output = DescribeTensor(DT_FLOAT, {dims.batch, dims.oy, dims.ox, dims.oz});
1792   }
1793 
1794   // Add the operations the fused op always computes.
1795   std::vector<OpContext> component_ops = {
1796       FusedChildContext(op_context, "Conv2D", output, {conv_input, filter}),
1797       FusedChildContext(op_context, "Mul", output, {output, conv_input_scale}),
1798       FusedChildContext(
1799           op_context, "BiasAdd", output,
1800           {output, output}),  // Note we're no longer using bias at all
1801       FusedChildContext(op_context, "Relu", output, {output})};
1802 
1803   // Add our side_input iff it's non-empty.
1804   if (side_input.shape().dim_size() > 0) {
1805     component_ops.push_back(FusedChildContext(op_context, "Mul", side_input,
1806                                               {side_input, side_input_scale}));
1807     component_ops.push_back(FusedChildContext(
1808         op_context, "Add", output,
1809         {output, output}));  // Note that we're not using side_input here
1810   }
1811 
1812   // Construct an op_context which definitely has our output shape.
1813   auto op_context_with_output = op_context;
1814   op_context_with_output.op_info.mutable_outputs()->Clear();
1815   *op_context_with_output.op_info.mutable_outputs()->Add() = output;
1816 
1817   // Construct component operations and run the cost computation.
1818   if (found_unknown_shapes) {
1819     node_costs->inaccurate = true;
1820     node_costs->num_nodes_with_unknown_shapes = 1;
1821   }
1822   return PredictFusedOp(op_context_with_output, component_ops, node_costs);
1823 }
1824 
1825 Status OpLevelCostEstimator::PredictMatMul(const OpContext& op_context,
1826                                            NodeCosts* node_costs) const {
1827   const auto& op_info = op_context.op_info;
1828   bool found_unknown_shapes = false;
1829   int64_t num_compute_ops =
1830       CountMatMulOperations(op_info, &found_unknown_shapes);
1831   return PredictDefaultNodeCosts(num_compute_ops, op_context,
1832                                  &found_unknown_shapes, node_costs);
1833 }
1834 
1835 Status OpLevelCostEstimator::PredictEinsum(const OpContext& op_context,
1836                                            NodeCosts* node_costs) const {
1837   const auto& op_info = op_context.op_info;
1838 
1839   auto it = op_info.attr().find("equation");
1840   if (it == op_info.attr().end()) {
1841     return errors::InvalidArgument("Einsum op doesn't have equation attr: ",
1842                                    op_info.ShortDebugString());
1843   }
1844 
1845   OpContext batch_matmul_op_context;
1846   bool found_unknown_shapes = false;
1847   bool success = GenerateBatchMatmulContextFromEinsum(
1848       op_context, &batch_matmul_op_context, &found_unknown_shapes);
1849   if (found_unknown_shapes) {
1850     node_costs->inaccurate = true;
1851     node_costs->num_nodes_with_unknown_shapes = 1;
1852   }
1853   if (!success) {
1854     return PredictCostOfAnUnknownOp(op_context, node_costs);
1855   }
1856   return PredictNodeCosts(batch_matmul_op_context, node_costs);
1857 }
1858 
1859 Status OpLevelCostEstimator::PredictSparseTensorDenseMatMul(
1860     const OpContext& op_context, NodeCosts* node_costs) const {
1861   const auto& op_info = op_context.op_info;
1862   bool found_unknown_shapes = false;
1863   // input[0]: indices in sparse matrix a
1864   // input[1]: values in sparse matrix a
1865   // input[2]: shape of matrix a
1866   // input[3]: matrix b
1867   // See
1868   // https://github.com/tensorflow/tensorflow/blob/9a43dfeac5/tensorflow/core/ops/sparse_ops.cc#L85
1869   int64_t num_elems_in_a =
1870       CalculateTensorElementCount(op_info.inputs(1), &found_unknown_shapes);
1871   auto b_matrix = op_info.inputs(3);
1872   auto b_matrix_shape =
1873       MaybeGetMinimumShape(b_matrix.shape(), 2, &found_unknown_shapes);
1874   int64_t n_dim = b_matrix_shape.dim(1).size();
1875 
1876   // Each element in A is multiplied and added with an element from each column
1877   // in b.
1878   const int64_t op_count = kOpsPerMac * num_elems_in_a * n_dim;
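  // Illustrative example: if A has 1,000 non-zero values and b has
  // n_dim = 64 columns, op_count = 2 * 1,000 * 64 = 128,000.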
1879 
1880   int64_t a_indices_input_size =
1881       CalculateTensorSize(op_info.inputs(0), &found_unknown_shapes);
1882   int64_t a_values_input_size =
1883       CalculateTensorSize(op_info.inputs(1), &found_unknown_shapes);
1884   int64_t a_shape_input_size =
1885       CalculateTensorSize(op_info.inputs(2), &found_unknown_shapes);
1886   int64_t b_input_size =
1887       num_elems_in_a * n_dim * DataTypeSize(BaseType(b_matrix.dtype()));
1888   int64_t output_size = CalculateOutputSize(op_info, &found_unknown_shapes);
1889 
1890   node_costs->num_compute_ops = op_count;
1891   node_costs->num_input_bytes_accessed = {a_indices_input_size,
1892                                           a_values_input_size,
1893                                           a_shape_input_size, b_input_size};
1894   node_costs->num_output_bytes_accessed = {output_size};
1895   if (found_unknown_shapes) {
1896     node_costs->inaccurate = true;
1897     node_costs->num_nodes_with_unknown_shapes = 1;
1898   }
1899   return OkStatus();
1900 }
1901 
1902 Status OpLevelCostEstimator::PredictNoOp(const OpContext& op_context,
1903                                          NodeCosts* node_costs) const {
1904   const auto& op_info = op_context.op_info;
1905   VLOG(1) << "Op:" << op_info.op() << " Execution Time 0 (ns)";
1906   // By default, NodeCosts is initialized to zero ops and bytes.
1907   return OkStatus();
1908 }
1909 
1910 Status OpLevelCostEstimator::PredictPureMemoryOp(const OpContext& op_context,
1911                                                  NodeCosts* node_costs) const {
1912   // Each output element is a copy of some element from input, with no required
1913   // computation, so just compute memory costs.
1914   bool found_unknown_shapes = false;
1915   node_costs->num_nodes_with_pure_memory_op = 1;
1916   return PredictDefaultNodeCosts(0, op_context, &found_unknown_shapes,
1917                                  node_costs);
1918 }
1919 
1920 Status OpLevelCostEstimator::PredictIdentity(const OpContext& op_context,
1921                                              NodeCosts* node_costs) const {
1922   const auto& op_info = op_context.op_info;
1923   VLOG(1) << "Op:" << op_info.op() << " Minimum cost for Identity";
1924   node_costs->minimum_cost_op = true;
1925   node_costs->num_compute_ops = kMinComputeOp;
1926   // The Identity op internally passes the input tensor buffer's pointer to the
1927   // output tensor buffer; there is no actual memory operation.
1928   node_costs->num_input_bytes_accessed = {0};
1929   node_costs->num_output_bytes_accessed = {0};
1930   bool inaccurate = false;
1931   node_costs->max_memory = CalculateOutputSize(op_info, &inaccurate);
1932   if (inaccurate) {
1933     node_costs->inaccurate = true;
1934     node_costs->num_nodes_with_unknown_shapes = 1;
1935   }
1936   return OkStatus();
1937 }
1938 
1939 Status OpLevelCostEstimator::PredictVariable(const OpContext& op_context,
1940                                              NodeCosts* node_costs) const {
1941   const auto& op_info = op_context.op_info;
1942   VLOG(1) << "Op:" << op_info.op() << " Minimum cost for Variable";
1943   node_costs->minimum_cost_op = true;
1944   node_costs->num_compute_ops = kMinComputeOp;
1945   // Variables are persistent ops; initialized before step; hence, no memory
1946   // cost.
1947   node_costs->num_input_bytes_accessed = {0};
1948   node_costs->num_output_bytes_accessed = {0};
1949   bool inaccurate = false;
1950   node_costs->persistent_memory = CalculateOutputSize(op_info, &inaccurate);
1951   if (inaccurate) {
1952     node_costs->inaccurate = true;
1953     node_costs->num_nodes_with_unknown_shapes = 1;
1954   }
1955   return OkStatus();
1956 }
1957 
1958 Status OpLevelCostEstimator::PredictBatchMatMul(const OpContext& op_context,
1959                                                 NodeCosts* node_costs) const {
1960   const auto& op_info = op_context.op_info;
1961   bool found_unknown_shapes = false;
1962   int64_t num_compute_ops =
1963       CountBatchMatMulOperations(op_info, &found_unknown_shapes);
1964   return PredictDefaultNodeCosts(num_compute_ops, op_context,
1965                                  &found_unknown_shapes, node_costs);
1966 }
1967 
1968 Status OpLevelCostEstimator::PredictMetadata(const OpContext& op_context,
1969                                              NodeCosts* node_costs) const {
1970   const auto& op_info = op_context.op_info;
1971   node_costs->minimum_cost_op = true;
1972   node_costs->num_compute_ops = kMinComputeOp;
1973   node_costs->num_input_bytes_accessed = {0};
1974   node_costs->num_output_bytes_accessed = {0};
1975   bool inaccurate = false;
1976   node_costs->max_memory = CalculateOutputSize(op_info, &inaccurate);
1977   if (inaccurate) {
1978     node_costs->inaccurate = true;
1979     node_costs->num_nodes_with_unknown_shapes = 1;
1980   }
1981   return OkStatus();
1982 }
1983 
1984 Status OpLevelCostEstimator::PredictGatherOrSlice(const OpContext& op_context,
1985                                                   NodeCosts* node_costs) const {
1986   // Gather & Slice ops can have a very large input, but only access a small
1987   // part of it. For these ops, the size of the output determines the memory cost.
1988   const auto& op_info = op_context.op_info;
1989 
1990   const int inputs_needed = op_info.op() == "Slice" ? 3 : 2;
1991   if (op_info.outputs_size() == 0 || op_info.inputs_size() < inputs_needed) {
1992     return errors::InvalidArgument(
1993         op_info.op(),
1994         " Op doesn't have valid input / output: ", op_info.ShortDebugString());
1995   }
1996 
1997   bool unknown_shapes = false;
1998 
1999   // Each output element is a copy of some element from input.
2000   // For roofline estimate we assume each copy has a unit cost.
2001   const int64_t op_count =
2002       CalculateTensorElementCount(op_info.outputs(0), &unknown_shapes);
2003   node_costs->num_compute_ops = op_count;
2004 
2005   const int64_t output_size = CalculateOutputSize(op_info, &unknown_shapes);
2006   node_costs->num_output_bytes_accessed = {output_size};
2007 
2008   node_costs->num_input_bytes_accessed.reserve(op_info.inputs().size());
2009   int64_t input_size = output_size;
2010   // Note that the bytes accessed from input(0) are not equal to the input(0)
2011   // tensor size. They equal the output size, since input access is an indexed
2012   // gather or slice (duplicate indices are ignored).
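  // For example, for a Gather with params of shape [1000000, 128] (float) and
  // 256 indices, only output_size = 256 * 128 * 4 bytes of input(0) is counted
  // as accessed, not the full params tensor.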
2013   node_costs->num_input_bytes_accessed.push_back(input_size);
2014   int begin_input_index = 1;
2015   int end_input_index;
2016   if (op_info.op() == "Slice") {
2017     // Slice: 'input' (omitted), 'begin', 'size'
2018     end_input_index = 3;
2019   } else if (op_info.op() == "StridedSlice") {
2020     // StridedSlice: 'input' (omitted), 'begin', 'end', 'strides'
2021     end_input_index = 4;
2022   } else {
2023     // Gather, GatherV2, GatherNd: 'params' (omitted), 'indices'
2024     end_input_index = 2;
2025   }
2026   for (int i = begin_input_index; i < end_input_index; ++i) {
2027     node_costs->num_input_bytes_accessed.push_back(
2028         CalculateTensorElementCount(op_info.inputs(i), &unknown_shapes));
2029   }
2030   if (unknown_shapes) {
2031     node_costs->inaccurate = true;
2032     node_costs->num_nodes_with_unknown_shapes = 1;
2033   }
2034   return OkStatus();
2035 }
2036 
2037 Status OpLevelCostEstimator::PredictScatter(const OpContext& op_context,
2038                                             NodeCosts* node_costs) const {
2039   // Scatter ops sparsely access a reference input and output tensor.
2040   const auto& op_info = op_context.op_info;
2041   bool found_unknown_shapes = false;
2042 
2043   // input[0]: ref tensor that will be sparsely accessed
2044   // input[1]: indices - A tensor of indices into the first dimension of ref.
2045   // input[2]: updates where updates.shape = indices.shape + ref.shape[1:]
2046   // See
2047   // https://www.tensorflow.org/api_docs/python/tf/scatter_add and
2048   // https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/ops/state_ops.cc#L146
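  // Illustrative example: with ref of shape [1000, 32] and 100 indices, each
  // scattered index touches 32 elements of ref, so
  // op_count = 100 * 32 = 3,200.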
2049 
2050   const int64_t num_indices =
2051       CalculateTensorElementCount(op_info.inputs(1), &found_unknown_shapes);
2052 
2053   int64_t num_elems_in_ref_per_index = 1;
2054   auto ref_tensor_shape = MaybeGetMinimumShape(
2055       op_info.inputs(0).shape(), op_info.inputs(0).shape().dim_size(),
2056       &found_unknown_shapes);
2057   for (int i = 1; i < ref_tensor_shape.dim().size(); ++i) {
2058     num_elems_in_ref_per_index *= ref_tensor_shape.dim(i).size();
2059   }
2060   const int64_t op_count = num_indices * num_elems_in_ref_per_index;
2061   node_costs->num_compute_ops = op_count;
2062 
2063   // ref is accessed sparsely, so the input size depends on the number of operations.
2064   int64_t ref_input_size =
2065       op_count * DataTypeSize(BaseType(op_info.inputs(0).dtype()));
2066   int64_t indices_input_size =
2067       CalculateTensorSize(op_info.inputs(1), &found_unknown_shapes);
2068   int64_t updates_input_size =
2069       CalculateTensorSize(op_info.inputs(2), &found_unknown_shapes);
2070   node_costs->num_input_bytes_accessed = {ref_input_size, indices_input_size,
2071                                           updates_input_size};
2072 
2073   // ref is accessed sparsely, so the output size depends on the number of operations.
2074   int64_t output_size =
2075       op_count * DataTypeSize(BaseType(op_info.outputs(0).dtype()));
2076   node_costs->num_output_bytes_accessed = {output_size};
2077 
2078   if (found_unknown_shapes) {
2079     node_costs->inaccurate = true;
2080     node_costs->num_nodes_with_unknown_shapes = 1;
2081   }
2082   return OkStatus();
2083 }
2084 
2085 Status OpLevelCostEstimator::PredictFusedOp(
2086     const OpContext& op_context,
2087     const std::vector<OpContext>& fused_op_contexts,
2088     NodeCosts* node_costs) const {
2089   // Note that PredictDefaultNodeCosts will get the correct memory costs from
2090   // the node's inputs and outputs; but we don't want to have to re-implement
2091   // the logic for computing the operation count of each of our component
2092   // operations here; so we simply add the compute times of each component
2093   // operation, then update the cost.
2094   bool found_unknown_shapes = false;
2095   Status s =
2096       PredictDefaultNodeCosts(0, op_context, &found_unknown_shapes, node_costs);
2097 
2098   for (auto& fused_op : fused_op_contexts) {
2099     NodeCosts fused_node_costs;
2100     s.Update(PredictNodeCosts(fused_op, &fused_node_costs));
2101     node_costs->num_compute_ops += fused_node_costs.num_compute_ops;
2102     node_costs->inaccurate |= fused_node_costs.inaccurate;
2103     // Set, not increment. Note that we are predicting the cost of one fused
2104     // node, not a function node composed of many nodes.
2105     node_costs->num_nodes_with_unknown_shapes |=
2106         fused_node_costs.num_nodes_with_unknown_shapes;
2107     node_costs->num_nodes_with_unknown_op_type |=
2108         fused_node_costs.num_nodes_with_unknown_op_type;
2109     node_costs->num_nodes_with_pure_memory_op |=
2110         fused_node_costs.num_nodes_with_pure_memory_op;
2111   }
2112 
2113   return OkStatus();
2114 }
2115 
2116 /* static */
2117 OpContext OpLevelCostEstimator::FusedChildContext(
2118     const OpContext& parent, const std::string& op_name,
2119     const OpInfo::TensorProperties& output,
2120     const std::vector<OpInfo::TensorProperties>& inputs) {
2121   // Setup the base parameters of our new context.
2122   OpContext new_context;
2123   new_context.name = op_name;
2124   new_context.device_name = parent.device_name;
2125   new_context.op_info = parent.op_info;
2126   new_context.op_info.set_op(op_name);
2127 
2128   // Setup the inputs of our new context.
2129   new_context.op_info.mutable_inputs()->Clear();
2130   for (const auto& input : inputs) {
2131     *new_context.op_info.mutable_inputs()->Add() = input;
2132   }
2133 
2134   // Setup the output of our new context.
2135   new_context.op_info.mutable_outputs()->Clear();
2136   *new_context.op_info.mutable_outputs()->Add() = output;
2137 
2138   return new_context;
2139 }
2140 
2141 /* static */
2142 OpInfo::TensorProperties OpLevelCostEstimator::DescribeTensor(
2143     DataType type, const std::vector<int64_t>& dims) {
2144   OpInfo::TensorProperties ret;
2145   ret.set_dtype(type);
2146 
2147   auto shape = ret.mutable_shape();
2148   for (const int dim : dims) {
2149     shape->add_dim()->set_size(dim);
2150   }
2151 
2152   return ret;
2153 }
2154 
2155 /* static */
2156 StatusOr<OpLevelCostEstimator::ConvolutionDimensions>
2157 OpLevelCostEstimator::OpDimensionsFromInputs(
2158     const TensorShapeProto& original_image_shape, const OpInfo& op_info,
2159     bool* found_unknown_shapes) {
2160   VLOG(2) << "op features: " << op_info.DebugString();
2161   VLOG(2) << "Original image shape: " << original_image_shape.DebugString();
2162   auto image_shape =
2163       MaybeGetMinimumShape(original_image_shape, 4, found_unknown_shapes);
2164   VLOG(2) << "Image shape: " << image_shape.DebugString();
2165 
2166   int x_index, y_index, channel_index;
2167   const std::string& data_format = GetDataFormat(op_info);
2168   if (data_format == "NCHW") {
2169     channel_index = 1;
2170     y_index = 2;
2171     x_index = 3;
2172   } else {
2173     y_index = 1;
2174     x_index = 2;
2175     channel_index = 3;
2176   }
2177   int64_t batch = image_shape.dim(0).size();
2178   int64_t ix = image_shape.dim(x_index).size();
2179   int64_t iy = image_shape.dim(y_index).size();
2180   int64_t iz = image_shape.dim(channel_index).size();
2181 
2182   // Note that FusedBatchNorm doesn't have ksize attr, but GetKernelSize returns
2183   // {1, 1, 1, 1} in that case.
2184   std::vector<int64_t> ksize = GetKernelSize(op_info);
2185   int64_t kx = ksize[x_index];
2186   int64_t ky = ksize[y_index];
2187   // These ops don't support groupwise operation, therefore kz == iz.
2188   int64_t kz = iz;
2189 
2190   std::vector<int64_t> strides = GetStrides(op_info);
2191   int64_t sx = strides[x_index];
2192   int64_t sy = strides[y_index];
2193   if (sx == 0 || sy == 0) {
2194     return errors::InvalidArgument(
2195         "Stride must be > 0 for Height and Width, but got (", sy, ", ", sx,
2196         ")");
2197   }
2198   const auto padding = GetPadding(op_info);
2199 
2200   int64_t ox = GetOutputSize(ix, kx, sx, padding);
2201   int64_t oy = GetOutputSize(iy, ky, sy, padding);
2202   int64_t oz = iz;
2203 
2204   OpLevelCostEstimator::ConvolutionDimensions conv_dims = {
2205       batch, ix, iy, iz, kx, ky, kz, oz, ox, oy, sx, sy, padding};
2206   return conv_dims;
2207 }
2208 
2209 Status OpLevelCostEstimator::PredictMaxPool(const OpContext& op_context,
2210                                             NodeCosts* node_costs) const {
2211   bool found_unknown_shapes = false;
2212   const auto& op_info = op_context.op_info;
2213   // x: op_info.inputs(0)
2214   TF_ASSIGN_OR_RETURN(ConvolutionDimensions dims,
2215                       OpDimensionsFromInputs(op_info.inputs(0).shape(), op_info,
2216                                              &found_unknown_shapes));
2217   // kx * ky - 1 comparisons per output (kx * ky > 1)
2218   // or 1 copy per output (kx * ky == 1).
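  // For example, a 3x3 window needs 8 comparisons per output element, while a
  // 1x1 window is a single copy.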
2219   int per_output_ops = dims.kx * dims.ky == 1 ? 1 : dims.kx * dims.ky - 1;
2220   int64_t ops = dims.batch * dims.ox * dims.oy * dims.oz * per_output_ops;
2221   node_costs->num_compute_ops = ops;
2222 
2223   int64_t input_size = 0;
2224   if (dims.ky >= dims.sy) {
2225     input_size = CalculateTensorSize(op_info.inputs(0), &found_unknown_shapes);
2226   } else {  // dims.ky < dims.sy
2227     // Vertical stride is larger than vertical kernel; assuming row-major
2228     // format, skip unnecessary rows (or read every kx rows per sy rows, as the
2229     // others are not used for output).
2230     const auto data_size = DataTypeSize(BaseType(op_info.inputs(0).dtype()));
2231     input_size = data_size * dims.batch * dims.ix * dims.ky * dims.oy * dims.iz;
2232   }
2233   node_costs->num_input_bytes_accessed = {input_size};
2234   const int64_t output_size =
2235       CalculateOutputSize(op_info, &found_unknown_shapes);
2236   node_costs->num_output_bytes_accessed = {output_size};
2237   node_costs->max_memory = output_size;
2238   if (found_unknown_shapes) {
2239     node_costs->inaccurate = true;
2240     node_costs->num_nodes_with_unknown_shapes = 1;
2241   }
2242   return OkStatus();
2243 }
2244 
2245 Status OpLevelCostEstimator::PredictMaxPoolGrad(const OpContext& op_context,
2246                                                 NodeCosts* node_costs) const {
2247   bool found_unknown_shapes = false;
2248   const auto& op_info = op_context.op_info;
2249   // x: op_info.inputs(0)
2250   // y: op_info.inputs(1)
2251   // y_grad: op_info.inputs(2)
2252   if (op_info.inputs_size() < 3) {
2253     return errors::InvalidArgument("MaxPoolGrad op has invalid inputs: ",
2254                                    op_info.ShortDebugString());
2255   }
2256 
2257   TF_ASSIGN_OR_RETURN(ConvolutionDimensions dims,
2258                       OpDimensionsFromInputs(op_info.inputs(0).shape(), op_info,
2259                                              &found_unknown_shapes));
2260 
2261   int64_t ops = 0;
2262   if (dims.kx == 1 && dims.ky == 1) {
2263     // 1x1 window. No need to know which input was max.
2264     ops = dims.batch * dims.ix * dims.iy * dims.iz;
2265   } else if (dims.kx <= dims.sx && dims.ky <= dims.sy) {
2266     // Non-overlapping window: re-run maxpool, then assign zero or y_grad.
2267     ops = dims.batch * dims.iz *
2268           (dims.ox * dims.oy * (dims.kx * dims.ky - 1) + dims.ix * dims.iy);
2269   } else {
2270     // Overlapping window: initialize with zeros, re-run maxpool, then
2271     // accumulate y_grad to the proper x_grad locations.
2272     ops = dims.batch * dims.iz *
2273           (dims.ox * dims.oy * (dims.kx * dims.ky - 1) + dims.ix * dims.iy * 2);
2274   }
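  // Illustrative example (non-overlapping 2x2 window, stride 2, on a
  // [1, 8, 8, 16] NHWC input): ops = 1 * 16 * (4 * 4 * (2 * 2 - 1) + 8 * 8)
  //                                = 16 * 112 = 1,792.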
2275   node_costs->num_compute_ops = ops;
2276 
2277   // Just read x and y_grad; no need to read y, as we assume MaxPoolGrad
2278   // re-runs MaxPool internally.
2279   const int64_t input0_size =
2280       CalculateTensorSize(op_info.inputs(0), &found_unknown_shapes);
2281   const int64_t input2_size =
2282       CalculateTensorSize(op_info.inputs(2), &found_unknown_shapes);
2283   node_costs->num_input_bytes_accessed = {input0_size, 0, input2_size};
2284   // Write x_grad; size equal to x.
2285   const int64_t output_size =
2286       CalculateTensorSize(op_info.inputs(0), &found_unknown_shapes);
2287   node_costs->num_output_bytes_accessed = {output_size};
2288   node_costs->max_memory = output_size;
2289 
2290   if (found_unknown_shapes) {
2291     node_costs->inaccurate = true;
2292     node_costs->num_nodes_with_unknown_shapes = 1;
2293   }
2294   return OkStatus();
2295 }
2296 
2297 /* This predict function handles three types of TensorFlow ops:
2298  * AssignVariableOp, AssignAddVariableOp, and AssignSubVariableOp. Broadcasting
2299  * is not possible for these ops, so the input tensor's shape is enough to
2300  * compute the cost. */
2301 Status OpLevelCostEstimator::PredictAssignVariableOps(
2302     const OpContext& op_context, NodeCosts* node_costs) const {
2303   bool found_unknown_shapes = false;
2304   const auto& op_info = op_context.op_info;
2305   /* First input of these ops are reference to the assignee. */
2306   if (op_info.inputs_size() != 2) {
2307     return errors::InvalidArgument("AssignVariable op has invalid input: ",
2308                                    op_info.ShortDebugString());
2309   }
2310 
2311   const int64_t ops = op_info.op() == kAssignVariableOp
2312                           ? 0
2313                           : CalculateTensorElementCount(op_info.inputs(1),
2314                                                         &found_unknown_shapes);
2315   node_costs->num_compute_ops = ops;
2316   const int64_t input_size = CalculateInputSize(op_info, &found_unknown_shapes);
2317   node_costs->num_input_bytes_accessed = {input_size};
2318   // TODO(dyoon): check these ops' behavior whether it writes data;
2319   // Op itself doesn't have output tensor, but it may modify the input (ref or
2320   // resource). Maybe use node_costs->internal_write_bytes.
2321   node_costs->num_output_bytes_accessed = {0};
2322   if (found_unknown_shapes) {
2323     node_costs->inaccurate = true;
2324     node_costs->num_nodes_with_unknown_shapes = 1;
2325   }
2326   return OkStatus();
2327 }
2328 
2329 Status OpLevelCostEstimator::PredictAvgPool(const OpContext& op_context,
2330                                             NodeCosts* node_costs) const {
2331   bool found_unknown_shapes = false;
2332   const auto& op_info = op_context.op_info;
2333   // x: op_info.inputs(0)
2334   TF_ASSIGN_OR_RETURN(ConvolutionDimensions dims,
2335                       OpDimensionsFromInputs(op_info.inputs(0).shape(), op_info,
2336                                              &found_unknown_shapes));
2337 
2338   // kx * ky - 1 additions and 1 multiplication per output.
2339   int64_t ops = dims.batch * dims.ox * dims.oy * dims.oz * dims.kx * dims.ky;
2340   node_costs->num_compute_ops = ops;
2341 
2342   int64_t input_size;
2343   if (dims.ky >= dims.sy) {
2344     input_size = CalculateTensorSize(op_info.inputs(0), &found_unknown_shapes);
2345   } else {  // dims.ky < dims.sy
2346     // vertical stride is larger than vertical kernel; assuming row-major
2347     // format, skip unnecessary rows (or read every kx rows per sy rows, as the
2348     // others are not used for output).
2349     const auto data_size = DataTypeSize(BaseType(op_info.inputs(0).dtype()));
2350     input_size = data_size * dims.batch * dims.ix * dims.ky * dims.oy * dims.iz;
2351   }
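  // A sketch of the reduced-read case above, assuming hypothetical values
  // iy = 9, ky = 2, sy = 3 (VALID padding): oy = (9 - 2) / 3 + 1 = 3, so only
  // ky * oy = 6 rows are read per image instead of all 9.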
2352   node_costs->num_input_bytes_accessed = {input_size};
2353 
2354   const int64_t output_size =
2355       CalculateOutputSize(op_info, &found_unknown_shapes);
2356   node_costs->num_output_bytes_accessed = {output_size};
2357   node_costs->max_memory = output_size;
2358 
2359   if (found_unknown_shapes) {
2360     node_costs->inaccurate = true;
2361     node_costs->num_nodes_with_unknown_shapes = 1;
2362   }
2363   return OkStatus();
2364 }
2365 
2366 Status OpLevelCostEstimator::PredictAvgPoolGrad(const OpContext& op_context,
2367                                                 NodeCosts* node_costs) const {
2368   bool found_unknown_shapes = false;
2369   const auto& op_info = op_context.op_info;
2370   // x's shape: op_info.inputs(0)
2371   // y_grad: op_info.inputs(1)
2372 
2373   // Extract x_shape from op_info.inputs(0).value() or op_info.outputs(0).
2374   bool shape_found = false;
2375   TensorShapeProto x_shape;
2376   if (op_info.inputs_size() >= 1 && op_info.inputs(0).has_value()) {
2377     const TensorProto& value = op_info.inputs(0).value();
2378     shape_found = GetTensorShapeProtoFromTensorProto(value, &x_shape);
2379   }
2380   if (!shape_found && op_info.outputs_size() > 0) {
2381     x_shape = op_info.outputs(0).shape();
2382     shape_found = true;
2383   }
2384   if (!shape_found) {
2385     // Set the minimum shape that's feasible.
2386     x_shape.Clear();
2387     for (int i = 0; i < 4; ++i) {
2388       x_shape.add_dim()->set_size(1);
2389     }
2390     found_unknown_shapes = true;
2391   }
2392 
2393   TF_ASSIGN_OR_RETURN(
2394       ConvolutionDimensions dims,
2395       OpDimensionsFromInputs(x_shape, op_info, &found_unknown_shapes));
2396 
2397   int64_t ops = 0;
2398   if (dims.kx <= dims.sx && dims.ky <= dims.sy) {
2399     // Non-overlapping window.
2400     ops = dims.batch * dims.iz * (dims.ix * dims.iy + dims.ox * dims.oy);
2401   } else {
2402     // Overlapping window.
2403     ops = dims.batch * dims.iz *
2404           (dims.ix * dims.iy + dims.ox * dims.oy * (dims.kx * dims.ky + 1));
2405   }
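  // A worked example, assuming hypothetical dimensions batch=1, iz=1,
  // ix=iy=4, ox=oy=2: a non-overlapping 2x2 window with stride 2 gives
  // ops = 1 * 1 * (4 * 4 + 2 * 2) = 20, while an overlapping 3x3 window with
  // stride 2 gives ops = 1 * 1 * (4 * 4 + 2 * 2 * (3 * 3 + 1)) = 56.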
2406   auto s = PredictDefaultNodeCosts(ops, op_context, &found_unknown_shapes,
2407                                    node_costs);
2408   node_costs->max_memory = node_costs->num_total_output_bytes();
2409   return s;
2410 }
2411 
2412 Status OpLevelCostEstimator::PredictFusedBatchNorm(
2413     const OpContext& op_context, NodeCosts* node_costs) const {
2414   bool found_unknown_shapes = false;
2415   const auto& op_info = op_context.op_info;
2416   // x: op_info.inputs(0)
2417   // scale: op_info.inputs(1)
2418   // offset: op_info.inputs(2)
2419   // mean: op_info.inputs(3)  --> only for inference
2420   // variance: op_info.inputs(4) --> only for inference
2421   TF_ASSIGN_OR_RETURN(ConvolutionDimensions dims,
2422                       OpDimensionsFromInputs(op_info.inputs(0).shape(), op_info,
2423                                              &found_unknown_shapes));
2424   const bool is_training = IsTraining(op_info);
2425 
2426   int64_t ops = 0;
2427   const auto rsqrt_cost = Eigen::internal::functor_traits<
2428       Eigen::internal::scalar_rsqrt_op<float>>::Cost;
2429   if (is_training) {
2430     ops = dims.iz * (dims.batch * dims.ix * dims.iy * 4 + 6 + rsqrt_cost);
2431   } else {
2432     ops = dims.batch * dims.ix * dims.iy * dims.iz * 2;
2433   }
2434   node_costs->num_compute_ops = ops;
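  // A worked example, assuming a hypothetical NHWC input of shape
  // [8, 16, 16, 32]: in training mode the estimate is
  // ops = 32 * (8 * 16 * 16 * 4 + 6 + rsqrt_cost); in inference mode it is
  // 8 * 16 * 16 * 32 * 2 = 131072 ops.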
2435 
2436   const int64_t size_nhwc =
2437       CalculateTensorSize(op_info.inputs(0), &found_unknown_shapes);
2438   const int64_t size_c =
2439       CalculateTensorSize(op_info.inputs(1), &found_unknown_shapes);
2440   if (is_training) {
2441     node_costs->num_input_bytes_accessed = {size_nhwc, size_c, size_c};
2442     node_costs->num_output_bytes_accessed = {size_nhwc, size_c, size_c, size_c,
2443                                              size_c};
2444     // FusedBatchNorm in training mode internally re-reads the input tensor:
2445     // once for mean/variance, and a second internal read for the actual scaling.
2446     // Assume small intermediate data such as mean / variance (size_c) can be
2447     // cached on-chip.
2448     node_costs->internal_read_bytes = size_nhwc;
2449   } else {
2450     node_costs->num_input_bytes_accessed = {size_nhwc, size_c, size_c, size_c,
2451                                             size_c};
2452     node_costs->num_output_bytes_accessed = {size_nhwc};
2453   }
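  // Continuing the hypothetical [8, 16, 16, 32] float example: size_nhwc =
  // 65536 * 4 bytes and size_c = 32 * 4 bytes, so the training branch charges
  // one size_nhwc plus two size_c input reads, one size_nhwc internal
  // re-read, and five output writes (y plus four per-channel size_c tensors).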
2454   node_costs->max_memory = node_costs->num_total_output_bytes();
2455 
2456   if (found_unknown_shapes) {
2457     node_costs->inaccurate = true;
2458     node_costs->num_nodes_with_unknown_shapes = 1;
2459   }
2460   return OkStatus();
2461 }
2462 
2463 Status OpLevelCostEstimator::PredictFusedBatchNormGrad(
2464     const OpContext& op_context, NodeCosts* node_costs) const {
2465   bool found_unknown_shapes = false;
2466   const auto& op_info = op_context.op_info;
2467   // y_backprop: op_info.inputs(0)
2468   // x: op_info.inputs(1)
2469   // scale: op_info.inputs(2)
2470   // mean: op_info.inputs(3)
2471   // variance or inverse of variance: op_info.inputs(4)
2472   TF_ASSIGN_OR_RETURN(ConvolutionDimensions dims,
2473                       OpDimensionsFromInputs(op_info.inputs(1).shape(), op_info,
2474                                              &found_unknown_shapes));
2475 
2476   int64_t ops = 0;
2477   const auto rsqrt_cost = Eigen::internal::functor_traits<
2478       Eigen::internal::scalar_rsqrt_op<float>>::Cost;
2479   ops = dims.iz * (dims.batch * dims.ix * dims.iy * 11 + 5 + rsqrt_cost);
2480   node_costs->num_compute_ops = ops;
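  // A worked example, assuming a hypothetical x of shape [8, 16, 16, 32]:
  // ops = 32 * (8 * 16 * 16 * 11 + 5 + rsqrt_cost), i.e. roughly 11 ops per
  // element plus a small per-channel term.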
2481 
2482   const int64_t size_nhwc =
2483       CalculateTensorSize(op_info.inputs(1), &found_unknown_shapes);
2484   const int64_t size_c =
2485       CalculateTensorSize(op_info.inputs(2), &found_unknown_shapes);
2486   // TODO(dyoon): fix missing memory cost for variance input (size_c) and
2487   // yet another read of y_backprop (size_nhwc) internally.
2488   node_costs->num_input_bytes_accessed = {size_nhwc, size_nhwc, size_c, size_c};
2489   node_costs->num_output_bytes_accessed = {size_nhwc, size_c, size_c};
2490   // FusedBatchNormGrad has to read y_backprop internally.
2491   node_costs->internal_read_bytes = size_nhwc;
2492   node_costs->max_memory = node_costs->num_total_output_bytes();
2493 
2494   if (found_unknown_shapes) {
2495     node_costs->inaccurate = true;
2496     node_costs->num_nodes_with_unknown_shapes = 1;
2497   }
2498   return OkStatus();
2499 }
2500 
2501 Status OpLevelCostEstimator::PredictNaryOp(const OpContext& op_context,
2502                                            NodeCosts* node_costs) const {
2503   const auto& op_info = op_context.op_info;
2504   bool found_unknown_shapes = false;
2505   // Calculate the largest known element count across all inputs and the output.
2506   int64_t op_count = CalculateLargestInputCount(op_info, &found_unknown_shapes);
2507   // If output shape is available, try to use the element count calculated from
2508   // that.
2509   if (op_info.outputs_size() > 0) {
2510     op_count = std::max(
2511         op_count,
2512         CalculateTensorElementCount(op_info.outputs(0), &found_unknown_shapes));
2513   }
2514   // Also calculate the output shape possibly resulting from broadcasting.
2515   // Note that some Nary ops (such as AddN) do not support broadcasting,
2516   // but we're including this here for completeness.
2517   if (op_info.inputs_size() >= 2) {
2518     op_count = std::max(op_count, CwiseOutputElementCount(op_info));
2519   }
2520 
2521   // Nary ops perform (n - 1) operations per output element when combining n inputs.
2522   op_count *= op_info.inputs_size() - 1;
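  // A worked example, assuming a hypothetical AddN with 4 inputs of shape
  // [32, 64]: the largest element count is 32 * 64 = 2048, so
  // op_count = 2048 * (4 - 1) = 6144, which is then scaled by the Eigen
  // scalar_sum_op cost below.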
2523 
2524   const auto sum_cost = Eigen::internal::functor_traits<
2525       Eigen::internal::scalar_sum_op<float>>::Cost;
2526   return PredictDefaultNodeCosts(op_count * sum_cost, op_context,
2527                                  &found_unknown_shapes, node_costs);
2528 }
2529 
2530 // softmax[i, j] = exp(logits[i, j]) / sum_j(exp(logits[i, j]))
2531 Status OpLevelCostEstimator::PredictSoftmax(const OpContext& op_context,
2532                                             NodeCosts* node_costs) const {
2533   bool found_unknown_shapes = false;
2534   const int64_t logits_size = CalculateTensorElementCount(
2535       op_context.op_info.inputs(0), &found_unknown_shapes);
2536   // Softmax input rank should be >=1.
2537   TensorShapeProto logits_shape = op_context.op_info.inputs(0).shape();
2538   if (logits_shape.unknown_rank() || logits_shape.dim_size() == 0) {
2539     return errors::InvalidArgument("Softmax op has invalid input: ",
2540                                    op_context.op_info.ShortDebugString());
2541   }
2542 
2543 #define EIGEN_COST(X) Eigen::internal::functor_traits<Eigen::internal::X>::Cost
2544 
2545   // Every element of <logits> will be exponentiated, have that result included
2546   // in a sum across j, and also have that result multiplied by the reciprocal
2547   // of the sum_j. In addition, we'll compute 1/sum_j for every i.
2548   auto ops =
2549       (EIGEN_COST(scalar_exp_op<float>) + EIGEN_COST(scalar_sum_op<float>) +
2550        EIGEN_COST(scalar_product_op<float>)) *
2551           logits_size +
2552       EIGEN_COST(scalar_inverse_op<float>) * logits_shape.dim(0).size();
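  // A worked example, assuming hypothetical logits of shape [32, 1000]:
  // logits_size = 32000 and dim(0) = 32, so the per-element exp, sum, and
  // multiply costs are charged 32000 times and the reciprocal cost 32 times.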
2553 
2554 #undef EIGEN_COST
2555   return PredictDefaultNodeCosts(ops, op_context, &found_unknown_shapes,
2556                                  node_costs);
2557 }
2558 
2559 Status OpLevelCostEstimator::PredictResizeBilinear(
2560     const OpContext& op_context, NodeCosts* node_costs) const {
2561   bool found_unknown_shapes = false;
2562 
2563   if (op_context.op_info.outputs().empty() ||
2564       op_context.op_info.inputs().empty()) {
2565     return errors::InvalidArgument(
2566         "ResizeBilinear op has invalid input / output ",
2567         op_context.op_info.ShortDebugString());
2568   }
2569 
2570   const int64_t output_elements = CalculateTensorElementCount(
2571       op_context.op_info.outputs(0), &found_unknown_shapes);
2572 
2573   const auto half_pixel_centers =
2574       op_context.op_info.attr().find("half_pixel_centers");
2575   bool use_half_pixel_centers = false;
2576   if (half_pixel_centers == op_context.op_info.attr().end()) {
2577     LOG(WARNING) << "half_pixel_centers attr not set for ResizeBilinear.";
2578     return PredictCostOfAnUnknownOp(op_context, node_costs);
2579   } else {
2580     use_half_pixel_centers = half_pixel_centers->second.b();
2581   }
2582 
2583   // Compose cost of bilinear interpolation.
2584   int64_t ops = 0;
2585 
2586 #define EIGEN_COST(X) Eigen::internal::functor_traits<Eigen::internal::X>::Cost
2587   const auto sub_cost_float = EIGEN_COST(scalar_difference_op<float>);
2588   const auto sub_cost_int = EIGEN_COST(scalar_difference_op<int64_t>);
2589   const auto add_cost = EIGEN_COST(scalar_sum_op<float>);
2590   const auto mul_cost = EIGEN_COST(scalar_product_op<float>);
2591   const auto floor_cost = EIGEN_COST(scalar_floor_op<float>);
2592   const auto max_cost = EIGEN_COST(scalar_max_op<int64_t>);
2593   const auto min_cost = EIGEN_COST(scalar_min_op<int64_t>);
2594   const auto cast_to_int_cost = Eigen::internal::functor_traits<
2595       Eigen::internal::scalar_cast_op<float, int64_t>>::Cost;
2596   const auto cast_to_float_cost = Eigen::internal::functor_traits<
2597       Eigen::internal::scalar_cast_op<int64_t, float>>::Cost;
2598   const auto ceil_cost = EIGEN_COST(scalar_ceil_op<float>);
2599 #undef EIGEN_COST
2600 
2601   // Ops calculated from tensorflow/core/kernels/image/resize_bilinear_op.cc.
2602 
2603   // Op counts taken from the resize_bilinear implementation on 07/21/2020.
2604   // Computed op counts may become inaccurate if that implementation
2605   // changes.
2606 
2607   // resize_bilinear has an optimization where the interpolation weights are
2608   // precomputed and cached. Given input tensors of size [B,H1,W1,C] and output
2609   // tensors of size [B,H2,W2,C], the input positions that need to be accessed
2610   // for interpolation are the same at every point along the last dimension C
2611   // of the output. These values are cached in the
2612   // compute_interpolation_weights function. For a particular y in [0...H2-1],
2613   // the rows to be accessed in the input are the same. Likewise, for a
2614   // particular x in [0...W2-1], the columns to be accessed are the same. So
2615   // the precomputation only needs to be done for H2 + W2 values.
2616   const auto output_shape = MaybeGetMinimumShape(
2617       op_context.op_info.outputs(0).shape(), 4, &found_unknown_shapes);
2618   // Assume H is dim 1 and W is dim 2 to match logic in resize_bilinear, which
2619   // also makes this assumption.
2620   const int64_t output_height = output_shape.dim(1).size();
2621   const int64_t output_width = output_shape.dim(2).size();
2622   // Add the ops done outside of the scaler function in
2623   // compute_interpolation_weights.
2624   int64_t interp_weight_cost = floor_cost + max_cost + min_cost +
2625                                sub_cost_float + sub_cost_int + ceil_cost +
2626                                cast_to_int_cost * 2;
2627   // There are two options for computing the weight of each pixel in the
2628   // interpolation: the algorithm can use either pixel centers or corners for
2629   // the weight. Ops depend on the scaler function passed into
2630   // compute_interpolation_weights.
2631   if (use_half_pixel_centers) {
2632     // Ops for HalfPixelScaler.
2633     interp_weight_cost +=
2634         add_cost + mul_cost + sub_cost_float + cast_to_float_cost;
2635   } else {
2636     // Ops for LegacyScaler.
2637     interp_weight_cost += cast_to_float_cost + mul_cost;
2638   }
2639   // Cost for the interpolation is multiplied by (H2 + W2), as mentioned above.
2640   ops += interp_weight_cost * (output_height + output_width);
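  // A worked example, assuming a hypothetical output of shape [1, 16, 16, 3]:
  // the interpolation-weight term is charged (16 + 16) = 32 times rather than
  // once per output pixel (16 * 16 = 256), reflecting the caching above.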
2641 
2642   // Ops for computing the new values, done for every element. Logic is from
2643   // compute_lerp in the inner loop of resize_image which consists of:
2644   //   const float top = top_left + (top_right - top_left) * x_lerp;
2645   //   const float bottom = bottom_left + (bottom_right - bottom_left) * x_lerp;
2646   //   return top + (bottom - top) * y_lerp;
2647   ops += (add_cost * 3 + sub_cost_float * 3 + mul_cost * 3) * output_elements;
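  // Continuing the hypothetical [1, 16, 16, 3] output: output_elements = 768,
  // so the lerp term adds (3 * add_cost + 3 * sub_cost_float + 3 * mul_cost)
  // * 768 ops on top of the interpolation-weight term.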
2648 
2649   return PredictDefaultNodeCosts(ops, op_context, &found_unknown_shapes,
2650                                  node_costs);
2651 }
2652 
2653 Status OpLevelCostEstimator::PredictCropAndResize(const OpContext& op_context,
2654                                                   NodeCosts* node_costs) const {
2655   bool found_unknown_shapes = false;
2656 
2657   const auto method = op_context.op_info.attr().find("method");
2658   bool use_bilinear_interp;
2659   if (method == op_context.op_info.attr().end() ||
2660       method->second.s() == "bilinear") {
2661     use_bilinear_interp = true;
2662   } else if (method->second.s() == "nearest") {
2663     use_bilinear_interp = false;
2664   } else {
2665     LOG(WARNING) << "method attr in CropAndResize invalid; expected bilinear "
2666                     "or nearest.";
2667     return PredictCostOfAnUnknownOp(op_context, node_costs);
2668   }
2669 
2670   const int64_t num_boxes = op_context.op_info.inputs(1).shape().dim(0).size();
2671   const auto crop_shape = MaybeGetMinimumShape(
2672       op_context.op_info.outputs(0).shape(), 4, &found_unknown_shapes);
2673   const int64_t crop_height = crop_shape.dim(1).size();
2674   const int64_t crop_width = crop_shape.dim(2).size();
2675   const int64_t output_elements = CalculateTensorElementCount(
2676       op_context.op_info.outputs(0), &found_unknown_shapes);
2677 
2678 #define EIGEN_COST(X) Eigen::internal::functor_traits<Eigen::internal::X>::Cost
2679   const auto sub_cost = EIGEN_COST(scalar_difference_op<float>);
2680   const auto add_cost = EIGEN_COST(scalar_sum_op<float>);
2681   const auto mul_cost = EIGEN_COST(scalar_product_op<float>);
2682   auto div_cost = EIGEN_COST(scalar_div_cost<float>);
2683   const auto floor_cost = EIGEN_COST(scalar_floor_op<float>);
2684   const auto ceil_cost = EIGEN_COST(scalar_ceil_op<float>);
2685   auto round_cost = EIGEN_COST(scalar_round_op<float>);
2686   const auto cast_to_float_cost = Eigen::internal::functor_traits<
2687       Eigen::internal::scalar_cast_op<int64_t, float>>::Cost;
2688 #undef EIGEN_COST
2689 
2690   // Op counts follow
2691   // tensorflow/core/kernels/image/crop_and_resize_op.cc as of 08/25/2020. The
2692   // calculation differs from the rough estimate in the implementation, as it
2693   // separates out cost per box from cost per pixel and cost per element.
2694 
2695   // Since crop arguments are user controlled, check for overflow.
2696   int64_t crop_area = MultiplyWithoutOverflow(crop_height, crop_width);
2697   if (crop_area < 0)
2698     return errors::InvalidArgument("Cannot estimate cost, multiplying ",
2699                                    crop_height, " with ", crop_width,
2700                                    " would overflow");
2701   int64_t crop_volume = MultiplyWithoutOverflow(crop_area, num_boxes);
2702   if (crop_volume < 0)
2703     return errors::InvalidArgument("Cannot estimate cost, multiplying ",
2704                                    crop_area, " with ", num_boxes,
2705                                    " would overflow");
2706   int64_t crop_depth = MultiplyWithoutOverflow(crop_height, num_boxes);
2707   if (crop_depth < 0)
2708     return errors::InvalidArgument("Cannot estimate cost, multiplying ",
2709                                    crop_height, " with ", num_boxes,
2710                                    " would overflow");
2711 
2712   // Ops for variables height_scale and width_scale.
2713   int64_t ops = (sub_cost * 6 + mul_cost * 2 + div_cost * 2) * num_boxes;
2714   // Ops for variable in_y.
2715   ops += (mul_cost * 2 + sub_cost + add_cost) * crop_depth;
2716   // Ops for variable in_x (same computation across both branches).
2717   ops += (mul_cost * 2 + sub_cost + add_cost) * crop_volume;
2718   // Specify op_cost based on the method.
2719   if (use_bilinear_interp) {
2720     // Ops for variables top_y_index, bottom_y_index, y_lerp.
2721     ops += (floor_cost + ceil_cost + sub_cost) * crop_depth;
2722     // Ops for variables left_x, right_x, x_lerp;
2723     ops += (floor_cost + ceil_cost + sub_cost) * crop_volume;
2724     // Ops for innermost loop across depth.
2725     ops +=
2726         (cast_to_float_cost * 4 + add_cost * 3 + sub_cost * 3 + mul_cost * 3) *
2727         output_elements;
2728   } else /* method == "nearest" */ {
2729     // Ops for variables closest_x_index and closest_y_index.
2730     ops += round_cost * 2 * crop_volume;
2731     // Ops for innermost loop across depth.
2732     ops += cast_to_float_cost * output_elements;
2733   }
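  // A worked example, assuming a hypothetical CropAndResize with
  // num_boxes = 8, 16x16 crops, and 3 channels: crop_depth = 16 * 8 = 128,
  // crop_volume = 256 * 8 = 2048, and output_elements = 8 * 16 * 16 * 3 =
  // 6144, which scale the per-box, per-row, per-pixel, and per-element terms
  // accumulated above.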
2734   return PredictDefaultNodeCosts(ops, op_context, &found_unknown_shapes,
2735                                  node_costs);
2736 }
2737 
2738 }  // end namespace grappler
2739 }  // end namespace tensorflow
2740