2 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
3
4 Licensed under the Apache License, Version 2.0 (the "License");
5 you may not use this file except in compliance with the License.
6 You may obtain a copy of the License at
7
8 http://www.apache.org/licenses/LICENSE-2.0
9
10 Unless required by applicable law or agreed to in writing, software
11 distributed under the License is distributed on an "AS IS" BASIS,
12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 See the License for the specific language governing permissions and
14 limitations under the License.
15 ==============================================================================*/
16
17 #include "tensorflow/core/grappler/costs/op_level_cost_estimator.h"
18
19 #include "absl/strings/match.h"
20 #include "third_party/eigen3/Eigen/Core"
21 #include "tensorflow/core/framework/attr_value.pb.h"
22 #include "tensorflow/core/framework/attr_value_util.h"
23 #include "tensorflow/core/framework/tensor.pb.h"
24 #include "tensorflow/core/framework/tensor_shape.pb.h"
25 #include "tensorflow/core/framework/types.h"
26 #include "tensorflow/core/grappler/clusters/utils.h"
27 #include "tensorflow/core/grappler/costs/op_context.h"
28 #include "tensorflow/core/grappler/costs/utils.h"
29 #include "tensorflow/core/platform/errors.h"
30 #include "tensorflow/core/util/overflow.h"
31
32 namespace tensorflow {
33 namespace grappler {
34
35 // TODO(dyoon): update op to Predict method map for TF ops with V2 or V3 suffix.
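// A single multiply-accumulate (MAC) is counted as two ops: one multiply and
// one add.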
36 constexpr int kOpsPerMac = 2;
37 constexpr char kGuaranteeConst[] = "GuaranteeConst";
38 constexpr char kAddN[] = "AddN";
39 constexpr char kBitCast[] = "BitCast";
40 constexpr char kConcatV2[] = "ConcatV2";
41 constexpr char kConv2d[] = "Conv2D";
42 constexpr char kConv2dBackpropFilter[] = "Conv2DBackpropFilter";
43 constexpr char kConv2dBackpropInput[] = "Conv2DBackpropInput";
44 constexpr char kFusedConv2dBiasActivation[] = "FusedConv2DBiasActivation";
45 constexpr char kDataFormatVecPermute[] = "DataFormatVecPermute";
46 constexpr char kDepthToSpace[] = "DepthToSpace";
47 constexpr char kDepthwiseConv2dNative[] = "DepthwiseConv2dNative";
48 constexpr char kDepthwiseConv2dNativeBackpropFilter[] =
49 "DepthwiseConv2dNativeBackpropFilter";
50 constexpr char kDepthwiseConv2dNativeBackpropInput[] =
51 "DepthwiseConv2dNativeBackpropInput";
52 constexpr char kMatMul[] = "MatMul";
53 constexpr char kXlaEinsum[] = "XlaEinsum";
54 constexpr char kEinsum[] = "Einsum";
55 constexpr char kExpandDims[] = "ExpandDims";
56 constexpr char kFill[] = "Fill";
57 constexpr char kSparseMatMul[] = "SparseMatMul";
58 constexpr char kSparseTensorDenseMatMul[] = "SparseTensorDenseMatMul";
59 constexpr char kPlaceholder[] = "Placeholder";
60 constexpr char kIdentity[] = "Identity";
61 constexpr char kIdentityN[] = "IdentityN";
62 constexpr char kRefIdentity[] = "RefIdentity";
63 constexpr char kNoOp[] = "NoOp";
64 constexpr char kReshape[] = "Reshape";
65 constexpr char kSplit[] = "Split";
66 constexpr char kSqueeze[] = "Squeeze";
67 constexpr char kRecv[] = "_Recv";
68 constexpr char kSend[] = "_Send";
69 constexpr char kBatchMatMul[] = "BatchMatMul";
70 constexpr char kBatchMatMulV2[] = "BatchMatMulV2";
71 constexpr char kOneHot[] = "OneHot";
72 constexpr char kPack[] = "Pack";
73 constexpr char kRank[] = "Rank";
74 constexpr char kRange[] = "Range";
75 constexpr char kShape[] = "Shape";
76 constexpr char kShapeN[] = "ShapeN";
77 constexpr char kSize[] = "Size";
78 constexpr char kStopGradient[] = "StopGradient";
79 constexpr char kPreventGradient[] = "PreventGradient";
80 constexpr char kGather[] = "Gather";
81 constexpr char kGatherNd[] = "GatherNd";
82 constexpr char kGatherV2[] = "GatherV2";
83 constexpr char kScatterAdd[] = "ScatterAdd";
84 constexpr char kScatterDiv[] = "ScatterDiv";
85 constexpr char kScatterMax[] = "ScatterMax";
86 constexpr char kScatterMin[] = "ScatterMin";
87 constexpr char kScatterMul[] = "ScatterMul";
88 constexpr char kScatterSub[] = "ScatterSub";
89 constexpr char kScatterUpdate[] = "ScatterUpdate";
90 constexpr char kSlice[] = "Slice";
91 constexpr char kStridedSlice[] = "StridedSlice";
92 constexpr char kSpaceToDepth[] = "SpaceToDepth";
93 constexpr char kTranspose[] = "Transpose";
94 constexpr char kTile[] = "Tile";
95 constexpr char kMaxPool[] = "MaxPool";
96 constexpr char kMaxPoolGrad[] = "MaxPoolGrad";
97 constexpr char kAvgPool[] = "AvgPool";
98 constexpr char kAvgPoolGrad[] = "AvgPoolGrad";
99 constexpr char kFusedBatchNorm[] = "FusedBatchNorm";
100 constexpr char kFusedBatchNormGrad[] = "FusedBatchNormGrad";
101 constexpr char kQuantizedMatMul[] = "QuantizedMatMul";
102 constexpr char kQuantizedMatMulV2[] = "QuantizedMatMulV2";
103 constexpr char kUnpack[] = "Unpack";
104 constexpr char kSoftmax[] = "Softmax";
105 constexpr char kResizeBilinear[] = "ResizeBilinear";
106 constexpr char kCropAndResize[] = "CropAndResize";
107 // Dynamic control flow ops.
108 constexpr char kSwitch[] = "Switch";
109 constexpr char kMerge[] = "Merge";
110 constexpr char kEnter[] = "Enter";
111 constexpr char kExit[] = "Exit";
112 constexpr char kNextIteration[] = "NextIteration";
113 // Persistent ops.
114 constexpr char kConst[] = "Const";
115 constexpr char kVariable[] = "Variable";
116 constexpr char kVariableV2[] = "VariableV2";
117 constexpr char kAutoReloadVariable[] = "AutoReloadVariable";
118 constexpr char kVarHandleOp[] = "VarHandleOp";
119 constexpr char kVarHandlesOp[] = "_VarHandlesOp";
120 constexpr char kReadVariableOp[] = "ReadVariableOp";
121 constexpr char kReadVariablesOp[] = "_ReadVariablesOp";
122 constexpr char kAssignVariableOp[] = "AssignVariableOp";
123 constexpr char kAssignAddVariableOp[] = "AssignAddVariableOp";
124 constexpr char kAssignSubVariableOp[] = "AssignSubVariableOp";
125
126 static const Costs::Duration kMinComputeTime(1);
127 static const int64_t kMinComputeOp = 1;
128
129 namespace {
130
std::string GetDataFormat(const OpInfo& op_info) {
132 std::string data_format = "NHWC"; // Default format.
133 if (op_info.attr().find("data_format") != op_info.attr().end()) {
134 data_format = op_info.attr().at("data_format").s();
135 }
136 return data_format;
137 }
138
std::string GetFilterFormat(const OpInfo& op_info) {
140 std::string filter_format = "HWIO"; // Default format.
141 if (op_info.attr().find("filter_format") != op_info.attr().end()) {
142 filter_format = op_info.attr().at("filter_format").s();
143 }
144 return filter_format;
145 }
146
Padding GetPadding(const OpInfo& op_info) {
148 if (op_info.attr().find("padding") != op_info.attr().end() &&
149 op_info.attr().at("padding").s() == "VALID") {
150 return Padding::VALID;
151 }
152 return Padding::SAME; // Default padding.
153 }
154
bool IsTraining(const OpInfo& op_info) {
156 if (op_info.attr().find("is_training") != op_info.attr().end() &&
157 op_info.attr().at("is_training").b()) {
158 return true;
159 }
160 return false;
161 }
162
163 // TODO(dyoon): support non-4D tensors in the cost functions of convolution
164 // related ops (Conv, Pool, BatchNorm, and their backprops) and the related
165 // helper functions.
std::vector<int64_t> GetStrides(const OpInfo& op_info) {
167 if (op_info.attr().find("strides") != op_info.attr().end()) {
168 const auto strides = op_info.attr().at("strides").list().i();
169 DCHECK(strides.size() == 4)
170 << "Attr strides is not a length-4 vector: " << op_info.DebugString();
171 if (strides.size() != 4) return {1, 1, 1, 1};
172 return {strides[0], strides[1], strides[2], strides[3]};
173 }
174 return {1, 1, 1, 1};
175 }
176
std::vector<int64_t> GetKernelSize(const OpInfo& op_info) {
178 if (op_info.attr().find("ksize") != op_info.attr().end()) {
179 const auto ksize = op_info.attr().at("ksize").list().i();
180 DCHECK(ksize.size() == 4)
181 << "Attr ksize is not a length-4 vector: " << op_info.DebugString();
182 if (ksize.size() != 4) return {1, 1, 1, 1};
183 return {ksize[0], ksize[1], ksize[2], ksize[3]};
184 }
185 // Note that FusedBatchNorm doesn't have ksize attr, but GetKernelSize returns
186 // {1, 1, 1, 1} in that case.
187 return {1, 1, 1, 1};
188 }
189
int64_t GetOutputSize(const int64_t input, const int64_t filter,
191 const int64_t stride, const Padding& padding) {
192 // Logic for calculating output shape is from GetWindowedOutputSizeVerbose()
193 // function in third_party/tensorflow/core/framework/common_shape_fns.cc.
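  // Illustrative example: with input = 10, filter = 3, and stride = 2, VALID
  // padding gives (10 - 3 + 2) / 2 = 4 output elements, while SAME padding
  // gives (10 + 2 - 1) / 2 = 5, i.e. ceil(input / stride).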
194 if (padding == Padding::VALID) {
195 return (input - filter + stride) / stride;
196 } else { // SAME.
197 return (input + stride - 1) / stride;
198 }
199 }
200
201 // Return the output element count of a multi-input element-wise op considering
202 // broadcasting.
int64_t CwiseOutputElementCount(const OpInfo& op_info) {
204 int max_rank = 1;
205 for (const OpInfo::TensorProperties& input_properties : op_info.inputs()) {
206 max_rank = std::max(max_rank, input_properties.shape().dim_size());
207 }
208
209 TensorShapeProto output_shape;
210 output_shape.mutable_dim()->Reserve(max_rank);
211 for (int i = 0; i < max_rank; ++i) {
212 output_shape.add_dim();
213 }
214
215 // Expand the shape of the output to follow the numpy-style broadcast rule
216 // which matches each input starting with the trailing dimensions and working
217 // its way forward. To do this, iterate through each input shape's dimensions
218 // in reverse order, and potentially increase the corresponding output
219 // dimension.
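  // For example, broadcasting input shapes [4, 1, 3] and [2, 1] (illustrative
  // values) yields an output shape of [4, 2, 3], i.e. 24 output elements.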
220 for (const OpInfo::TensorProperties& input_properties : op_info.inputs()) {
221 const TensorShapeProto& input_shape = input_properties.shape();
222 for (int i = input_shape.dim_size() - 1; i >= 0; --i) {
223 int output_shape_dim_index =
224 i + output_shape.dim_size() - input_shape.dim_size();
225 output_shape.mutable_dim(output_shape_dim_index)
226 ->set_size(std::max(output_shape.dim(output_shape_dim_index).size(),
227 input_shape.dim(i).size()));
228 }
229 }
230
231 int64_t count = 1;
232 for (int i = 0; i < output_shape.dim_size(); i++) {
233 count *= output_shape.dim(i).size();
234 }
235 return count;
236 }
237
238 // Helper function for determining whether there are repeated indices in the
239 // input Einsum equation.
bool CheckRepeatedDimensions(const absl::string_view dim_str) {
241 int str_size = dim_str.size();
242 for (int idx = 0; idx < str_size - 1; idx++) {
243 if (dim_str.find(dim_str[idx], idx + 1) != std::string::npos) {
244 return true;
245 }
246 }
247 return false;
248 }
249
250 // Auxiliary function for determining whether OpLevelCostEstimator is compatible
251 // with a given Einsum.
bool IsEinsumCorrectlyFormed(const OpContext& einsum_context) {
253 const auto& op_info = einsum_context.op_info;
254
255 auto it = op_info.attr().find("equation");
256 if (it == op_info.attr().end()) return false;
257 const absl::string_view equation = it->second.s();
258 std::vector<std::string> equation_split = absl::StrSplit(equation, "->");
259
260 if (equation_split.empty()) {
261 LOG(WARNING) << "Einsum with malformed equation";
262 return false;
263 }
264 std::vector<absl::string_view> input_split =
265 absl::StrSplit(equation_split[0], ',');
266
267 // The current model covers Einsum operations with two operands and a RHS
268 if (op_info.inputs_size() != 2 || equation_split.size() != 2) {
269 VLOG(1) << "Missing accurate estimator for op: " << op_info.op();
270 return false;
271 }
272 const auto& a_input = op_info.inputs(0);
273 const auto& b_input = op_info.inputs(1);
274 absl::string_view rhs_str = equation_split[1];
275 absl::string_view a_input_str = input_split[0];
276 absl::string_view b_input_str = input_split[1];
277
  // Ellipsis ("...") is not currently supported.
279 if (absl::StrContains(a_input_str, "...") ||
280 absl::StrContains(b_input_str, "...")) {
281 VLOG(1) << "Missing accurate estimator for op: " << op_info.op()
282 << ", ellipsis not supported";
283 return false;
284 }
285
286 constexpr int kMatrixRank = 2;
287
288 bool a_input_shape_unknown = false;
289 bool b_input_shape_unknown = false;
290
291 TensorShapeProto a_input_shape = MaybeGetMinimumShape(
292 a_input.shape(), std::max(kMatrixRank, a_input.shape().dim_size()),
293 &a_input_shape_unknown);
294 TensorShapeProto b_input_shape = MaybeGetMinimumShape(
295 b_input.shape(), std::max(kMatrixRank, b_input.shape().dim_size()),
296 &b_input_shape_unknown);
297
298 if (a_input_str.size() != static_cast<size_t>(a_input_shape.dim_size()) ||
299 b_input_str.size() != static_cast<size_t>(b_input_shape.dim_size())) {
300 VLOG(1) << "Missing accurate estimator for op: " << op_info.op()
301 << ", equation subscripts don't match tensor rank.";
302 return false;
303 }
304
  // Subscripts where an axis appears more than once for a single input are
  // not yet supported.
307 if (CheckRepeatedDimensions(a_input_str) ||
308 CheckRepeatedDimensions(b_input_str) ||
309 CheckRepeatedDimensions(rhs_str)) {
310 VLOG(1) << "Missing accurate estimator for op: " << op_info.op()
311 << ", Subscripts where axis appears more than once for a single "
312 "input are not yet supported";
313 return false;
314 }
315
316 return true;
317 }
318
319 } // namespace
320
321 // Return a minimum shape if the shape is unknown. If known, return the original
322 // shape.
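// For example (illustrative values): a shape of unknown rank requested at
// rank 4 becomes [1, 1, 1, 1], and a known shape [32, -1, 3] requested at
// rank 3 becomes [32, 1, 3]; in both cases *found_unknown_shapes is set.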
TensorShapeProto MaybeGetMinimumShape(const TensorShapeProto& original_shape,
324 int rank, bool* found_unknown_shapes) {
325 auto shape = original_shape;
326 bool is_scalar = !shape.unknown_rank() && shape.dim_size() == 0;
327
328 if (shape.unknown_rank() || (!is_scalar && shape.dim_size() < rank)) {
329 *found_unknown_shapes = true;
330 VLOG(2) << "Use minimum shape because the rank is unknown.";
331 // The size of each dimension is at least 1, if unknown.
332 for (int i = shape.dim_size(); i < rank; i++) {
333 shape.add_dim()->set_size(1);
334 }
335 } else if (is_scalar) {
336 for (int i = 0; i < rank; i++) {
337 shape.add_dim()->set_size(1);
338 }
339 } else if (shape.dim_size() > rank) {
340 *found_unknown_shapes = true;
341 shape.clear_dim();
342 for (int i = 0; i < rank; i++) {
343 shape.add_dim()->set_size(original_shape.dim(i).size());
344 }
345 } else {
346 for (int i = 0; i < shape.dim_size(); i++) {
347 if (shape.dim(i).size() < 0) {
348 *found_unknown_shapes = true;
349 VLOG(2) << "Use minimum dim size 1 because the shape is unknown.";
350 // The size of each dimension is at least 1, if unknown.
351 shape.mutable_dim(i)->set_size(1);
352 }
353 }
354 }
355 return shape;
356 }
357
OpLevelCostEstimator::OpLevelCostEstimator() {
359 // Syntactic sugar to build and return a lambda that takes an OpInfo and
360 // returns a cost.
361 typedef Status (OpLevelCostEstimator::*CostImpl)(const OpContext& op_context,
362 NodeCosts*) const;
363 auto wrap = [this](CostImpl impl)
364 -> std::function<Status(const OpContext&, NodeCosts*)> {
365 return [this, impl](const OpContext& op_context, NodeCosts* node_costs) {
366 return (this->*impl)(op_context, node_costs);
367 };
368 };
369
370 device_cost_impl_.emplace(kConv2d,
371 wrap(&OpLevelCostEstimator::PredictConv2D));
372 device_cost_impl_.emplace(
373 kConv2dBackpropFilter,
374 wrap(&OpLevelCostEstimator::PredictConv2DBackpropFilter));
375 device_cost_impl_.emplace(
376 kConv2dBackpropInput,
377 wrap(&OpLevelCostEstimator::PredictConv2DBackpropInput));
378 device_cost_impl_.emplace(
379 kFusedConv2dBiasActivation,
380 wrap(&OpLevelCostEstimator::PredictFusedConv2DBiasActivation));
  // Reuse Conv2D for DepthwiseConv2dNative because the calculation is the
  // same, although the actual meanings of the parameters differ. See comments
  // in PredictConv2D and related functions.
384 device_cost_impl_.emplace(kDepthwiseConv2dNative,
385 wrap(&OpLevelCostEstimator::PredictConv2D));
386 device_cost_impl_.emplace(
387 kDepthwiseConv2dNativeBackpropFilter,
388 wrap(&OpLevelCostEstimator::PredictConv2DBackpropFilter));
389 device_cost_impl_.emplace(
390 kDepthwiseConv2dNativeBackpropInput,
391 wrap(&OpLevelCostEstimator::PredictConv2DBackpropInput));
392 device_cost_impl_.emplace(kMatMul,
393 wrap(&OpLevelCostEstimator::PredictMatMul));
394 device_cost_impl_.emplace(kSparseMatMul,
395 wrap(&OpLevelCostEstimator::PredictMatMul));
396 device_cost_impl_.emplace(
397 kSparseTensorDenseMatMul,
398 wrap(&OpLevelCostEstimator::PredictSparseTensorDenseMatMul));
399 device_cost_impl_.emplace(kBatchMatMul,
400 wrap(&OpLevelCostEstimator::PredictBatchMatMul));
401 device_cost_impl_.emplace(kBatchMatMulV2,
402 wrap(&OpLevelCostEstimator::PredictBatchMatMul));
403 device_cost_impl_.emplace(kQuantizedMatMul,
404 wrap(&OpLevelCostEstimator::PredictMatMul));
405 device_cost_impl_.emplace(kQuantizedMatMulV2,
406 wrap(&OpLevelCostEstimator::PredictMatMul));
407 device_cost_impl_.emplace(kXlaEinsum,
408 wrap(&OpLevelCostEstimator::PredictEinsum));
409 device_cost_impl_.emplace(kEinsum,
410 wrap(&OpLevelCostEstimator::PredictEinsum));
411
412 device_cost_impl_.emplace(kNoOp, wrap(&OpLevelCostEstimator::PredictNoOp));
413 device_cost_impl_.emplace(kGuaranteeConst,
414 wrap(&OpLevelCostEstimator::PredictNoOp));
415
416 device_cost_impl_.emplace(kGather,
417 wrap(&OpLevelCostEstimator::PredictGatherOrSlice));
418 device_cost_impl_.emplace(kGatherNd,
419 wrap(&OpLevelCostEstimator::PredictGatherOrSlice));
420 device_cost_impl_.emplace(kGatherV2,
421 wrap(&OpLevelCostEstimator::PredictGatherOrSlice));
422 device_cost_impl_.emplace(kScatterAdd,
423 wrap(&OpLevelCostEstimator::PredictScatter));
424 device_cost_impl_.emplace(kScatterDiv,
425 wrap(&OpLevelCostEstimator::PredictScatter));
426 device_cost_impl_.emplace(kScatterMax,
427 wrap(&OpLevelCostEstimator::PredictScatter));
428 device_cost_impl_.emplace(kScatterMin,
429 wrap(&OpLevelCostEstimator::PredictScatter));
430 device_cost_impl_.emplace(kScatterMul,
431 wrap(&OpLevelCostEstimator::PredictScatter));
432 device_cost_impl_.emplace(kScatterSub,
433 wrap(&OpLevelCostEstimator::PredictScatter));
434 device_cost_impl_.emplace(kScatterUpdate,
435 wrap(&OpLevelCostEstimator::PredictScatter));
436
437 device_cost_impl_.emplace(kSlice,
438 wrap(&OpLevelCostEstimator::PredictGatherOrSlice));
439 device_cost_impl_.emplace(kStridedSlice,
440 wrap(&OpLevelCostEstimator::PredictGatherOrSlice));
441
442 device_cost_impl_.emplace(kPlaceholder,
443 wrap(&OpLevelCostEstimator::PredictIdentity));
444 device_cost_impl_.emplace(kIdentity,
445 wrap(&OpLevelCostEstimator::PredictIdentity));
446 device_cost_impl_.emplace(kIdentityN,
447 wrap(&OpLevelCostEstimator::PredictIdentity));
448 device_cost_impl_.emplace(kRefIdentity,
449 wrap(&OpLevelCostEstimator::PredictIdentity));
450 device_cost_impl_.emplace(kStopGradient,
451 wrap(&OpLevelCostEstimator::PredictIdentity));
452 device_cost_impl_.emplace(kPreventGradient,
453 wrap(&OpLevelCostEstimator::PredictIdentity));
454 device_cost_impl_.emplace(kReshape,
455 wrap(&OpLevelCostEstimator::PredictIdentity));
456 device_cost_impl_.emplace(kRecv,
457 wrap(&OpLevelCostEstimator::PredictIdentity));
458 device_cost_impl_.emplace(kSend,
459 wrap(&OpLevelCostEstimator::PredictIdentity));
460 device_cost_impl_.emplace(kSwitch,
461 wrap(&OpLevelCostEstimator::PredictIdentity));
462 device_cost_impl_.emplace(kMerge,
463 wrap(&OpLevelCostEstimator::PredictIdentity));
464 device_cost_impl_.emplace(kEnter,
465 wrap(&OpLevelCostEstimator::PredictIdentity));
466 device_cost_impl_.emplace(kExit,
467 wrap(&OpLevelCostEstimator::PredictIdentity));
468 device_cost_impl_.emplace(kNextIteration,
469 wrap(&OpLevelCostEstimator::PredictIdentity));
470 device_cost_impl_.emplace(kBitCast,
471 wrap(&OpLevelCostEstimator::PredictIdentity));
472
473 device_cost_impl_.emplace(kConcatV2,
474 wrap(&OpLevelCostEstimator::PredictPureMemoryOp));
475 device_cost_impl_.emplace(kDataFormatVecPermute,
476 wrap(&OpLevelCostEstimator::PredictPureMemoryOp));
477 device_cost_impl_.emplace(kDepthToSpace,
478 wrap(&OpLevelCostEstimator::PredictPureMemoryOp));
479 device_cost_impl_.emplace(kExpandDims,
480 wrap(&OpLevelCostEstimator::PredictPureMemoryOp));
481 device_cost_impl_.emplace(kFill,
482 wrap(&OpLevelCostEstimator::PredictPureMemoryOp));
483 device_cost_impl_.emplace(kOneHot,
484 wrap(&OpLevelCostEstimator::PredictPureMemoryOp));
485 device_cost_impl_.emplace(kPack,
486 wrap(&OpLevelCostEstimator::PredictPureMemoryOp));
487 device_cost_impl_.emplace(kRange,
488 wrap(&OpLevelCostEstimator::PredictPureMemoryOp));
489 device_cost_impl_.emplace(kSpaceToDepth,
490 wrap(&OpLevelCostEstimator::PredictPureMemoryOp));
491 device_cost_impl_.emplace(kSplit,
492 wrap(&OpLevelCostEstimator::PredictPureMemoryOp));
493 device_cost_impl_.emplace(kSqueeze,
494 wrap(&OpLevelCostEstimator::PredictPureMemoryOp));
495 device_cost_impl_.emplace(kTranspose,
496 wrap(&OpLevelCostEstimator::PredictPureMemoryOp));
497 device_cost_impl_.emplace(kTile,
498 wrap(&OpLevelCostEstimator::PredictPureMemoryOp));
499 device_cost_impl_.emplace(kUnpack,
500 wrap(&OpLevelCostEstimator::PredictPureMemoryOp));
501
502 device_cost_impl_.emplace(kRank,
503 wrap(&OpLevelCostEstimator::PredictMetadata));
504 device_cost_impl_.emplace(kShape,
505 wrap(&OpLevelCostEstimator::PredictMetadata));
506 device_cost_impl_.emplace(kShapeN,
507 wrap(&OpLevelCostEstimator::PredictMetadata));
508 device_cost_impl_.emplace(kSize,
509 wrap(&OpLevelCostEstimator::PredictMetadata));
510 device_cost_impl_.emplace(kMaxPool,
511 wrap(&OpLevelCostEstimator::PredictMaxPool));
512 device_cost_impl_.emplace(kMaxPoolGrad,
513 wrap(&OpLevelCostEstimator::PredictMaxPoolGrad));
514 device_cost_impl_.emplace(kAvgPool,
515 wrap(&OpLevelCostEstimator::PredictAvgPool));
516 device_cost_impl_.emplace(kAvgPoolGrad,
517 wrap(&OpLevelCostEstimator::PredictAvgPoolGrad));
518 device_cost_impl_.emplace(kFusedBatchNorm,
519 wrap(&OpLevelCostEstimator::PredictFusedBatchNorm));
520 device_cost_impl_.emplace(
521 kFusedBatchNormGrad,
522 wrap(&OpLevelCostEstimator::PredictFusedBatchNormGrad));
523 device_cost_impl_.emplace(kSoftmax,
524 wrap(&OpLevelCostEstimator::PredictSoftmax));
525 device_cost_impl_.emplace(kResizeBilinear,
526 wrap(&OpLevelCostEstimator::PredictResizeBilinear));
527 device_cost_impl_.emplace(kCropAndResize,
528 wrap(&OpLevelCostEstimator::PredictCropAndResize));
529 device_cost_impl_.emplace(
530 kAssignVariableOp, wrap(&OpLevelCostEstimator::PredictAssignVariableOps));
531 device_cost_impl_.emplace(
532 kAssignAddVariableOp,
533 wrap(&OpLevelCostEstimator::PredictAssignVariableOps));
534 device_cost_impl_.emplace(
535 kAssignSubVariableOp,
536 wrap(&OpLevelCostEstimator::PredictAssignVariableOps));
537 device_cost_impl_.emplace(kAddN, wrap(&OpLevelCostEstimator::PredictNaryOp));
538
539 persistent_ops_ = {
540 kConst, kVariable, kVariableV2, kAutoReloadVariable,
541 kVarHandleOp, kReadVariableOp, kVarHandlesOp, kReadVariablesOp};
542
543 #define EIGEN_COST(X) Eigen::internal::functor_traits<Eigen::internal::X>::Cost
544
545 // Quantize = apply min and max bounds, multiply by scale factor and round.
546 const int quantize_v2_cost =
547 EIGEN_COST(scalar_product_op<float>) + EIGEN_COST(scalar_max_op<float>) +
548 EIGEN_COST(scalar_min_op<float>) + EIGEN_COST(scalar_round_op<float>);
549 const int quantize_and_dequantize_v2_cost =
550 quantize_v2_cost + EIGEN_COST(scalar_product_op<float>);
551
552 // Unary ops alphabetically sorted
553 elementwise_ops_.emplace("Acos", EIGEN_COST(scalar_acos_op<float>));
554 elementwise_ops_.emplace("All", EIGEN_COST(scalar_boolean_and_op));
555 elementwise_ops_.emplace("ArgMax", EIGEN_COST(scalar_max_op<float>));
556 elementwise_ops_.emplace("Asin", EIGEN_COST(scalar_asin_op<float>));
557 elementwise_ops_.emplace("Atan", EIGEN_COST(scalar_atan_op<float>));
558 elementwise_ops_.emplace("Atan2", EIGEN_COST(scalar_quotient_op<float>) +
559 EIGEN_COST(scalar_atan_op<float>));
  // For now, we use the Eigen cost model for a float-to-int16 cast as an
  // example case; the Eigen cost model is zero when src and dst types are
  // identical, and it uses AddCost (1) when they differ. We may implement
  // separate cost functions for cast ops, using the actual input and output
  // types.
564 elementwise_ops_.emplace(
565 "Cast", Eigen::internal::functor_traits<
566 Eigen::internal::scalar_cast_op<float, int16>>::Cost);
567 elementwise_ops_.emplace("Ceil", EIGEN_COST(scalar_ceil_op<float>));
568 elementwise_ops_.emplace("Cos", EIGEN_COST(scalar_cos_op<float>));
569 elementwise_ops_.emplace("Dequantize", EIGEN_COST(scalar_product_op<float>));
570 elementwise_ops_.emplace("Erf", 1);
571 elementwise_ops_.emplace("Erfc", 1);
572 elementwise_ops_.emplace("Exp", EIGEN_COST(scalar_exp_op<float>));
573 elementwise_ops_.emplace("Expm1", EIGEN_COST(scalar_expm1_op<float>));
574 elementwise_ops_.emplace("Floor", EIGEN_COST(scalar_floor_op<float>));
575 elementwise_ops_.emplace("Inv", EIGEN_COST(scalar_inverse_op<float>));
576 elementwise_ops_.emplace("InvGrad", 1);
577 elementwise_ops_.emplace("Lgamma", 1);
578 elementwise_ops_.emplace("Log", EIGEN_COST(scalar_log_op<float>));
579 elementwise_ops_.emplace("Log1p", EIGEN_COST(scalar_log1p_op<float>));
580 elementwise_ops_.emplace("Max", EIGEN_COST(scalar_max_op<float>));
581 elementwise_ops_.emplace("Min", EIGEN_COST(scalar_min_op<float>));
582 elementwise_ops_.emplace("Neg", EIGEN_COST(scalar_opposite_op<float>));
583 elementwise_ops_.emplace("Prod", EIGEN_COST(scalar_product_op<float>));
584 elementwise_ops_.emplace("QuantizeAndDequantizeV2",
585 quantize_and_dequantize_v2_cost);
586 elementwise_ops_.emplace("QuantizeAndDequantizeV4",
587 quantize_and_dequantize_v2_cost);
588 elementwise_ops_.emplace("QuantizedSigmoid",
589 EIGEN_COST(scalar_logistic_op<float>));
590 elementwise_ops_.emplace("QuantizeV2", quantize_v2_cost);
591 elementwise_ops_.emplace("Reciprocal", EIGEN_COST(scalar_inverse_op<float>));
592 elementwise_ops_.emplace("Relu", EIGEN_COST(scalar_max_op<float>));
593 elementwise_ops_.emplace("Relu6", EIGEN_COST(scalar_max_op<float>));
594 elementwise_ops_.emplace("Rint", 1);
595 elementwise_ops_.emplace("Round", EIGEN_COST(scalar_round_op<float>));
596 elementwise_ops_.emplace("Rsqrt", EIGEN_COST(scalar_rsqrt_op<float>));
597 elementwise_ops_.emplace("Sigmoid", EIGEN_COST(scalar_logistic_op<float>));
598 elementwise_ops_.emplace("Sign", EIGEN_COST(scalar_sign_op<float>));
599 elementwise_ops_.emplace("Sin", EIGEN_COST(scalar_sin_op<float>));
600 elementwise_ops_.emplace("Sqrt", EIGEN_COST(scalar_sqrt_op<float>));
601 elementwise_ops_.emplace("Square", EIGEN_COST(scalar_square_op<float>));
602 elementwise_ops_.emplace("Sum", EIGEN_COST(scalar_sum_op<float>));
603 elementwise_ops_.emplace("Tan", EIGEN_COST(scalar_tan_op<float>));
604 elementwise_ops_.emplace("Tanh", EIGEN_COST(scalar_tanh_op<float>));
605 elementwise_ops_.emplace("TopKV2", EIGEN_COST(scalar_max_op<float>));
606 // Binary ops alphabetically sorted
607 elementwise_ops_.emplace("Add", EIGEN_COST(scalar_sum_op<float>));
608 elementwise_ops_.emplace("AddV2", EIGEN_COST(scalar_sum_op<float>));
609 elementwise_ops_.emplace("ApproximateEqual", 1);
610 elementwise_ops_.emplace("BiasAdd", EIGEN_COST(scalar_sum_op<float>));
611 elementwise_ops_.emplace("QuantizedBiasAdd",
612 EIGEN_COST(scalar_sum_op<float>));
613 elementwise_ops_.emplace("Div", EIGEN_COST(scalar_quotient_op<float>));
614 elementwise_ops_.emplace("Equal", 1);
615 elementwise_ops_.emplace("FloorDiv", EIGEN_COST(scalar_quotient_op<float>));
616 elementwise_ops_.emplace("FloorMod", EIGEN_COST(scalar_mod_op<float>));
617 elementwise_ops_.emplace("Greater", 1);
618 elementwise_ops_.emplace("GreaterEqual", 1);
619 elementwise_ops_.emplace("Less", 1);
620 elementwise_ops_.emplace("LessEqual", 1);
621 elementwise_ops_.emplace("LogicalAnd", EIGEN_COST(scalar_boolean_and_op));
622 elementwise_ops_.emplace("LogicalNot", 1);
623 elementwise_ops_.emplace("LogicalOr", EIGEN_COST(scalar_boolean_or_op));
624 elementwise_ops_.emplace("Maximum", EIGEN_COST(scalar_max_op<float>));
625 elementwise_ops_.emplace("Minimum", EIGEN_COST(scalar_min_op<float>));
626 elementwise_ops_.emplace("Mod", EIGEN_COST(scalar_mod_op<float>));
627 elementwise_ops_.emplace("Mul", EIGEN_COST(scalar_product_op<float>));
628 elementwise_ops_.emplace("NotEqual", 1);
629 elementwise_ops_.emplace("QuantizedAdd", EIGEN_COST(scalar_sum_op<float>));
630 elementwise_ops_.emplace("QuantizedMul",
631 EIGEN_COST(scalar_product_op<float>));
632 elementwise_ops_.emplace("RealDiv", EIGEN_COST(scalar_quotient_op<float>));
633 elementwise_ops_.emplace("ReluGrad", EIGEN_COST(scalar_max_op<float>));
634 elementwise_ops_.emplace("Select", EIGEN_COST(scalar_boolean_or_op));
635 elementwise_ops_.emplace("SelectV2", EIGEN_COST(scalar_boolean_or_op));
636 elementwise_ops_.emplace("SquaredDifference",
637 EIGEN_COST(scalar_square_op<float>) +
638 EIGEN_COST(scalar_difference_op<float>));
639 elementwise_ops_.emplace("Sub", EIGEN_COST(scalar_difference_op<float>));
640 elementwise_ops_.emplace("TruncateDiv",
641 EIGEN_COST(scalar_quotient_op<float>));
642 elementwise_ops_.emplace("TruncateMod", EIGEN_COST(scalar_mod_op<float>));
643 elementwise_ops_.emplace("Where", 1);
644
645 #undef EIGEN_COST
646
647 // By default, use sum of memory_time and compute_time for execution_time.
648 compute_memory_overlap_ = false;
649 }
650
Costs OpLevelCostEstimator::PredictCosts(const OpContext& op_context) const {
652 Costs costs;
653 NodeCosts node_costs;
654 if (PredictNodeCosts(op_context, &node_costs).ok()) {
655 if (node_costs.has_costs) {
656 return node_costs.costs;
657 }
658 // Convert NodeCosts to Costs.
659 if (node_costs.minimum_cost_op) {
      // Override to minimum cost; note that some ops with minimum cost may
      // have a non-typical device (e.g., channel for _Send), which may fail
      // in GetDeviceInfo(), called from PredictOpCountBasedCost(). Make sure
      // we set minimum values directly on Costs here, without calling
      // PredictOpCountBasedCost().
665 costs.compute_time = kMinComputeTime;
666 costs.execution_time = kMinComputeTime;
667 costs.memory_time = 0;
668 costs.intermediate_memory_time = 0;
669 costs.intermediate_memory_read_time = 0;
670 costs.intermediate_memory_write_time = 0;
671 } else {
672 // Convert NodeCosts to Costs.
673 costs = PredictOpCountBasedCost(
674 node_costs.num_compute_ops, node_costs.num_total_read_bytes(),
675 node_costs.num_total_write_bytes(), op_context.op_info);
676 }
677 VLOG(1) << "Operation " << op_context.op_info.op() << " takes "
678 << costs.execution_time.count() << " ns.";
679 // Copy additional stats from NodeCosts to Costs.
680 costs.max_memory = node_costs.max_memory;
681 costs.persistent_memory = node_costs.persistent_memory;
682 costs.temporary_memory = node_costs.temporary_memory;
683 costs.inaccurate = node_costs.inaccurate;
684 costs.num_ops_with_unknown_shapes =
685 node_costs.num_nodes_with_unknown_shapes;
686 costs.num_ops_total = node_costs.num_nodes;
687 return costs;
688 }
689 // Errors during node cost estimate.
690 LOG(WARNING) << "Error in PredictCost() for the op: "
691 << op_context.op_info.ShortDebugString();
692 costs = Costs::ZeroCosts(/*inaccurate=*/true);
693 costs.num_ops_with_unknown_shapes = node_costs.num_nodes_with_unknown_shapes;
694 return costs;
695 }
696
Status OpLevelCostEstimator::PredictNodeCosts(const OpContext& op_context,
698 NodeCosts* node_costs) const {
699 const auto& op_info = op_context.op_info;
700 auto it = device_cost_impl_.find(op_info.op());
701 if (it != device_cost_impl_.end()) {
702 std::function<Status(const OpContext&, NodeCosts*)> estimator = it->second;
703 return estimator(op_context, node_costs);
704 }
705
706 if (persistent_ops_.find(op_info.op()) != persistent_ops_.end()) {
707 return PredictVariable(op_context, node_costs);
708 }
709
710 if (elementwise_ops_.find(op_info.op()) != elementwise_ops_.end()) {
711 return PredictCwiseOp(op_context, node_costs);
712 }
713
714 VLOG(1) << "Missing accurate estimator for op: " << op_info.op();
715
716 node_costs->num_nodes_with_unknown_op_type = 1;
717 return PredictCostOfAnUnknownOp(op_context, node_costs);
718 }
719
720 // This method assumes a typical system composed of CPUs and GPUs, connected
721 // through PCIe. To define device info more precisely, override this method.
DeviceInfo OpLevelCostEstimator::GetDeviceInfo(
723 const DeviceProperties& device) const {
724 double gflops = -1;
725 double gb_per_sec = -1;
726
727 if (device.type() == "CPU") {
728 // Check if vector instructions are available, and refine performance
729 // prediction based on this.
730 // Frequencies are stored in MHz in the DeviceProperties.
731 gflops = device.num_cores() * device.frequency() * 1e-3;
732 if (gb_per_sec < 0) {
733 if (device.bandwidth() > 0) {
734 gb_per_sec = device.bandwidth() / 1e6;
735 } else {
736 gb_per_sec = 32;
737 }
738 }
739 } else if (device.type() == "GPU") {
740 const auto& device_env = device.environment();
741 auto it = device_env.find("architecture");
742 if (it != device_env.end()) {
743 const std::string architecture = device_env.at("architecture");
744 int cores_per_multiprocessor;
745 if (architecture < "3") {
746 // Fermi
747 cores_per_multiprocessor = 32;
748 } else if (architecture < "4") {
749 // Kepler
750 cores_per_multiprocessor = 192;
751 } else if (architecture < "6") {
752 // Maxwell
753 cores_per_multiprocessor = 128;
754 } else {
755 // Pascal (compute capability version 6) and Volta (compute capability
756 // version 7)
757 cores_per_multiprocessor = 64;
758 }
759 gflops = device.num_cores() * device.frequency() * 1e-3 *
760 cores_per_multiprocessor * kOpsPerMac;
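      // Illustrative example: a GPU reporting num_cores = 80 multiprocessors
      // at frequency = 1500 MHz with 64 cores per multiprocessor yields
      // 80 * 1500 * 1e-3 * 64 * 2 = 15360 gflops.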
761 if (device.bandwidth() > 0) {
762 gb_per_sec = device.bandwidth() / 1e6;
763 } else {
764 gb_per_sec = 100;
765 }
766 } else {
767 // Architecture is not available (ex: pluggable device), return default
768 // value.
769 gflops = 100; // Dummy value;
770 gb_per_sec = 12; // default PCIe x16 gen3.
771 }
772 } else {
773 LOG_EVERY_N(WARNING, 1000) << "Unknown device type: " << device.type()
774 << ", assuming PCIe between CPU and GPU.";
775 gflops = 1; // Dummy value; data transfer ops would not have compute ops.
776 gb_per_sec = 12; // default PCIe x16 gen3.
777 }
778 VLOG(1) << "Device: " << device.type() << " gflops: " << gflops
779 << " gb_per_sec: " << gb_per_sec;
780
781 return DeviceInfo(gflops, gb_per_sec);
782 }
783
Status OpLevelCostEstimator::PredictCwiseOp(const OpContext& op_context,
785 NodeCosts* node_costs) const {
786 const auto& op_info = op_context.op_info;
787 bool found_unknown_shapes = false;
  // For element-wise operations, the op count is the element count of any
  // input. We use the count for the largest input here to be more robust in
  // case the shape is unknown or only partially known for another input.
791 int64_t op_count = CalculateLargestInputCount(op_info, &found_unknown_shapes);
792 // If output shape is available, try to use the element count calculated from
793 // that.
794 if (op_info.outputs_size() > 0) {
795 op_count = std::max(
796 op_count,
797 CalculateTensorElementCount(op_info.outputs(0), &found_unknown_shapes));
798 }
799 // Calculate the output shape possibly resulting from broadcasting.
800 if (op_info.inputs_size() >= 2) {
801 op_count = std::max(op_count, CwiseOutputElementCount(op_info));
802 }
803
804 int op_cost = 1;
805 auto it = elementwise_ops_.find(op_info.op());
806 if (it != elementwise_ops_.end()) {
807 op_cost = it->second;
808 } else {
809 return errors::InvalidArgument("Not a cwise op: ", op_info.op());
810 }
811
812 return PredictDefaultNodeCosts(op_count * op_cost, op_context,
813 &found_unknown_shapes, node_costs);
814 }
815
Status OpLevelCostEstimator::PredictCostOfAnUnknownOp(
817 const OpContext& op_context, NodeCosts* node_costs) const {
  // Don't assume the operation is cwise; return a cost based on input/output
  // size and admit that it is inaccurate...
820 bool found_unknown_shapes = false;
821 node_costs->inaccurate = true;
822 return PredictDefaultNodeCosts(0, op_context, &found_unknown_shapes,
823 node_costs);
824 }
825
Costs OpLevelCostEstimator::PredictOpCountBasedCost(
827 double operations, const OpInfo& op_info) const {
828 bool unknown_shapes = false;
829 const double input_size = CalculateInputSize(op_info, &unknown_shapes);
830 const double output_size = CalculateOutputSize(op_info, &unknown_shapes);
831 Costs costs =
832 PredictOpCountBasedCost(operations, input_size, output_size, op_info);
833 costs.inaccurate = unknown_shapes;
834 costs.num_ops_with_unknown_shapes = unknown_shapes;
835 costs.max_memory = output_size;
836 return costs;
837 }
838
Costs OpLevelCostEstimator::PredictOpCountBasedCost(
840 double operations, double input_io_bytes, double output_io_bytes,
841 const OpInfo& op_info) const {
842 double total_io_bytes = input_io_bytes + output_io_bytes;
843 const DeviceInfo device_info = GetDeviceInfo(op_info.device());
844 if (device_info.gigaops <= 0 || device_info.gb_per_sec <= 0 ||
845 device_info.intermediate_read_gb_per_sec <= 0 ||
846 device_info.intermediate_write_gb_per_sec <= 0) {
847 VLOG(1) << "BAD DEVICE. Op:" << op_info.op()
848 << " device type:" << op_info.device().type()
849 << " device model:" << op_info.device().model();
850 }
851
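  // Note: gigaops is in units of 1e9 ops/sec and gb_per_sec in 1e9 bytes/sec,
  // so operations / gigaops and total_io_bytes / gb_per_sec both yield
  // durations in nanoseconds.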
852 Costs::NanoSeconds compute_cost(std::ceil(operations / device_info.gigaops));
853 VLOG(1) << "Op:" << op_info.op() << " GOps:" << operations / 1e9
854 << " Compute Time (ns):" << compute_cost.count();
855
856 Costs::NanoSeconds memory_cost(
857 std::ceil(total_io_bytes / device_info.gb_per_sec));
858 VLOG(1) << "Op:" << op_info.op() << " Size (KB):" << (total_io_bytes) / 1e3
859 << " Memory Time (ns):" << memory_cost.count();
860
  // Check if bytes > 0. If not, and the bandwidth is set to infinity, then
  // the result would be undefined.
863 double intermediate_read_time =
864 (input_io_bytes > 0)
865 ? std::ceil(input_io_bytes / device_info.intermediate_read_gb_per_sec)
866 : 0;
867
868 double intermediate_write_time =
869 (output_io_bytes > 0)
870 ? std::ceil(output_io_bytes /
871 device_info.intermediate_write_gb_per_sec)
872 : 0;
873
874 Costs::NanoSeconds intermediate_memory_cost =
875 compute_memory_overlap_
876 ? std::max(intermediate_read_time, intermediate_write_time)
877 : (intermediate_read_time + intermediate_write_time);
878 VLOG(1) << "Op:" << op_info.op() << " Size (KB):" << (total_io_bytes) / 1e3
879 << " Intermediate Memory Time (ns):"
880 << intermediate_memory_cost.count();
881
882 Costs costs = Costs::ZeroCosts();
883 costs.compute_time = compute_cost;
884 costs.memory_time = memory_cost;
885 costs.intermediate_memory_time = intermediate_memory_cost;
886 costs.intermediate_memory_read_time =
887 Costs::NanoSeconds(intermediate_read_time);
888 costs.intermediate_memory_write_time =
889 Costs::NanoSeconds(intermediate_write_time);
890 CombineCostsAndUpdateExecutionTime(compute_memory_overlap_, &costs);
891 return costs;
892 }
893
int64_t OpLevelCostEstimator::CountConv2DOperations(
895 const OpInfo& op_info, bool* found_unknown_shapes) {
896 return CountConv2DOperations(op_info, nullptr, found_unknown_shapes);
897 }
898
899 // Helper to translate the positional arguments into named fields.
900 /* static */
901 OpLevelCostEstimator::ConvolutionDimensions
OpLevelCostEstimator::ConvolutionDimensionsFromInputs(
903 const TensorShapeProto& original_image_shape,
904 const TensorShapeProto& original_filter_shape, const OpInfo& op_info,
905 bool* found_unknown_shapes) {
906 VLOG(2) << "op features: " << op_info.DebugString();
907 VLOG(2) << "Original image shape: " << original_image_shape.DebugString();
908 VLOG(2) << "Original filter shape: " << original_filter_shape.DebugString();
909
910 int x_index, y_index, major_channel_index, minor_channel_index = -1;
911 const std::string& data_format = GetDataFormat(op_info);
912 if (data_format == "NCHW") {
913 major_channel_index = 1;
914 y_index = 2;
915 x_index = 3;
916 } else if (data_format == "NCHW_VECT_C") {
917 // Use NCHW_VECT_C
918 minor_channel_index = 1;
919 y_index = 2;
920 x_index = 3;
921 major_channel_index = 4;
922 } else {
923 // Use NHWC.
924 y_index = 1;
925 x_index = 2;
926 major_channel_index = 3;
927 }
928 const std::string& filter_format = GetFilterFormat(op_info);
929 int filter_x_index, filter_y_index, in_major_channel_index, out_channel_index,
930 in_minor_channel_index = -1;
931 if (filter_format == "HWIO") {
932 filter_y_index = 0;
933 filter_x_index = 1;
934 in_major_channel_index = 2;
935 out_channel_index = 3;
936 } else if (filter_format == "OIHW_VECT_I") {
937 out_channel_index = 0;
938 in_minor_channel_index = 1;
939 filter_y_index = 2;
940 filter_x_index = 3;
941 in_major_channel_index = 4;
942 } else {
943 // Use OIHW
944 out_channel_index = 0;
945 in_major_channel_index = 1;
946 filter_y_index = 2;
947 filter_x_index = 3;
948 }
949
950 auto image_shape = MaybeGetMinimumShape(original_image_shape,
951 minor_channel_index >= 0 ? 5 : 4,
952 found_unknown_shapes);
953 auto filter_shape = MaybeGetMinimumShape(original_filter_shape,
954 in_minor_channel_index >= 0 ? 5 : 4,
955 found_unknown_shapes);
956 VLOG(2) << "Image shape: " << image_shape.DebugString();
957 VLOG(2) << "Filter shape: " << filter_shape.DebugString();
958
959 int64_t batch = image_shape.dim(0).size();
960 int64_t ix = image_shape.dim(x_index).size();
961 int64_t iy = image_shape.dim(y_index).size();
962 int64_t iz = minor_channel_index >= 0
963 ? image_shape.dim(minor_channel_index).size() *
964 image_shape.dim(major_channel_index).size()
965 : image_shape.dim(major_channel_index).size();
966 int64_t kx = filter_shape.dim(filter_x_index).size();
967 int64_t ky = filter_shape.dim(filter_y_index).size();
968 int64_t kz = in_minor_channel_index >= 0
969 ? filter_shape.dim(in_major_channel_index).size() *
970 filter_shape.dim(in_minor_channel_index).size()
971 : filter_shape.dim(in_major_channel_index).size();
972 std::vector<int64_t> strides = GetStrides(op_info);
973 const auto padding = GetPadding(op_info);
974 int64_t sx = strides[x_index];
975 int64_t sy = strides[y_index];
976 int64_t ox = GetOutputSize(ix, kx, sx, padding);
977 int64_t oy = GetOutputSize(iy, ky, sy, padding);
978 int64_t oz = filter_shape.dim(out_channel_index).size();
979 // Only check equality when both sizes are known (in other words, when
980 // neither is set to a minimum dimension size of 1).
981 if (iz != 1 && kz != 1) {
982 DCHECK_EQ(iz % kz, 0) << "Input channel " << iz
983 << " is not a multiple of filter channel " << kz
984 << ".";
985 if (iz % kz) {
986 *found_unknown_shapes = true;
987 }
988 } else {
989 iz = kz = std::max<int64_t>(iz, kz);
990 }
991 OpLevelCostEstimator::ConvolutionDimensions conv_dims = {
992 batch, ix, iy, iz, kx, ky, kz, oz, ox, oy, sx, sy, padding};
993
994 VLOG(1) << "Batch Size:" << batch;
995 VLOG(1) << "Image Dims:" << ix << "," << iy;
996 VLOG(1) << "Input Depth:" << iz;
997 VLOG(1) << "Kernel Dims:" << kx << "," << ky;
998 VLOG(1) << "Kernel Depth:" << kz;
999 VLOG(1) << "Output Dims:" << ox << "," << oy;
1000 VLOG(1) << "Output Depth:" << oz;
1001 VLOG(1) << "Strides:" << sx << "," << sy;
1002 VLOG(1) << "Padding:" << (padding == Padding::VALID ? "VALID" : "SAME");
1003 return conv_dims;
1004 }
1005
int64_t OpLevelCostEstimator::CountConv2DOperations(
1007 const OpInfo& op_info, ConvolutionDimensions* conv_info,
1008 bool* found_unknown_shapes) {
1009 DCHECK(op_info.op() == kConv2d || op_info.op() == kDepthwiseConv2dNative)
1010 << "Invalid Operation: not Conv2D nor DepthwiseConv2dNative";
1011
1012 if (op_info.inputs_size() < 2) { // Unexpected inputs.
1013 *found_unknown_shapes = true;
1014 return 0;
1015 }
1016
1017 ConvolutionDimensions conv_dims = ConvolutionDimensionsFromInputs(
1018 op_info.inputs(0).shape(), op_info.inputs(1).shape(), op_info,
1019 found_unknown_shapes);
1020
  // In DepthwiseConv2dNative, conv_dims.oz is actually the channel depth
  // multiplier; the effective output channel depth oz_effective is
  // conv_dims.iz * conv_dims.oz, thus # ops = N x H x W x oz_effective x 2RS.
  // Compare to Conv2D, where # ops = N x H x W x kz x oz x 2RS; with
  // oz = oz_effective, Conv2D_ops / Depthwise_conv2d_native_ops = kz.
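  // Illustrative example: with N=1, H=W=8, R=S=3, iz=16, and a channel
  // multiplier of 2 (oz_effective = 32), DepthwiseConv2dNative counts
  // 1*8*8*3*3*32*2 ops, a factor of kz=16 fewer than a Conv2D producing the
  // same 32 output channels.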
1026 int64_t ops = conv_dims.batch;
1027 ops *= conv_dims.ox * conv_dims.oy;
1028 ops *= conv_dims.kx * conv_dims.ky;
1029 if (op_info.op() == kConv2d) {
1030 ops *= conv_dims.kz * conv_dims.oz;
1031 } else {
    // Ensure the output tensor dims are correct for DepthwiseConv2dNative,
    // even though the op count formula is the same as for Conv2D.
1034 conv_dims.oz *= conv_dims.iz;
1035 ops *= conv_dims.oz;
1036 }
1037 ops *= kOpsPerMac;
1038
1039 if (conv_info != nullptr) {
1040 *conv_info = conv_dims;
1041 }
1042 return ops;
1043 }
1044
int64_t OpLevelCostEstimator::CountMatMulOperations(
1046 const OpInfo& op_info, bool* found_unknown_shapes) {
1047 return CountMatMulOperations(op_info, nullptr, found_unknown_shapes);
1048 }
1049
1050 // TODO(nishantpatil): Create separate estimator for Sparse Matmul
int64_t OpLevelCostEstimator::CountMatMulOperations(
1052 const OpInfo& op_info, MatMulDimensions* mat_mul,
1053 bool* found_unknown_shapes) {
1054 double ops = 0;
1055
1056 if (op_info.inputs_size() < 2) {
1057 LOG(ERROR) << "Need 2 inputs but got " << op_info.inputs_size();
1058 // TODO(pcma): Try to separate invalid inputs from unknown shapes
1059 *found_unknown_shapes = true;
1060 return 0;
1061 }
1062
1063 auto& a_matrix = op_info.inputs(0);
1064 auto& b_matrix = op_info.inputs(1);
1065
1066 bool transpose_a = false;
1067 bool transpose_b = false;
1068
1069 double m_dim, n_dim, k_dim, k_dim_b = 0;
1070
1071 for (const auto& item : op_info.attr()) {
1072 VLOG(1) << "Key:" << item.first
1073 << " Value:" << SummarizeAttrValue(item.second);
1074 if (item.first == "transpose_a" && item.second.b() == true)
1075 transpose_a = true;
1076 if (item.first == "transpose_b" && item.second.b() == true)
1077 transpose_b = true;
1078 }
1079 VLOG(1) << "transpose_a:" << transpose_a;
1080 VLOG(1) << "transpose_b:" << transpose_b;
1081 auto a_matrix_shape =
1082 MaybeGetMinimumShape(a_matrix.shape(), 2, found_unknown_shapes);
1083 auto b_matrix_shape =
1084 MaybeGetMinimumShape(b_matrix.shape(), 2, found_unknown_shapes);
1085 if (transpose_a) {
1086 m_dim = a_matrix_shape.dim(1).size();
1087 k_dim = a_matrix_shape.dim(0).size();
1088 } else {
1089 m_dim = a_matrix_shape.dim(0).size();
1090 k_dim = a_matrix_shape.dim(1).size();
1091 }
1092 if (transpose_b) {
1093 k_dim_b = b_matrix_shape.dim(1).size();
1094 n_dim = b_matrix_shape.dim(0).size();
1095 } else {
1096 k_dim_b = b_matrix_shape.dim(0).size();
1097 n_dim = b_matrix_shape.dim(1).size();
1098 }
1099
1100 VLOG(1) << "M, N, K: " << m_dim << "," << n_dim << "," << k_dim;
1101 // Only check equality when both sizes are known (in other words, when
1102 // neither is set to a minimum dimension size of 1).
1103 if (k_dim_b != 1 && k_dim != 1 && k_dim_b != k_dim) {
1104 LOG(ERROR) << "Incompatible Matrix dimensions";
1105 return ops;
1106 } else {
1107 // One of k_dim and k_dim_b might be 1 (minimum dimension size).
1108 k_dim = std::max(k_dim, k_dim_b);
1109 }
1110
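  // Each of the m*n output elements requires k multiply-accumulates, counted
  // as 2 ops each; e.g. (illustrative) a 128x256 by 256x512 MatMul counts
  // 128 * 512 * 256 * 2 = 33554432 ops.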
1111 ops = m_dim * n_dim * k_dim * 2;
1112 VLOG(1) << "Operations for Matmul: " << ops;
1113
1114 if (mat_mul != nullptr) {
1115 mat_mul->m = m_dim;
1116 mat_mul->n = n_dim;
1117 mat_mul->k = k_dim;
1118 }
1119 return ops;
1120 }
1121
bool OpLevelCostEstimator::GenerateBatchMatmulContextFromEinsum(
1123 const OpContext& einsum_context, OpContext* batch_matmul_context,
1124 bool* found_unknown_shapes) const {
  // This auxiliary function transforms an einsum OpContext into its
  // equivalent Batch Matmul OpContext. The function returns a boolean
  // indicating whether it was successful in generating the output OpContext.
1128
1129 // Einsum computes a generalized contraction between tensors of arbitrary
1130 // dimension as defined by the equation written in the Einstein summation
1131 // convention. The number of tensors in the computation and the number of
  // contractions can be arbitrarily large. The current model only contemplates
  // Einsum equations that can be translated into a single BatchMatMul
  // operation. Einsum operations with more than two operands are not
  // currently supported. Subscripts where an axis appears more than once for
  // a single input, and ellipses, are currently also excluded. See:
1137 // https://www.tensorflow.org/api_docs/python/tf/einsum
1138 // We distinguish four kinds of dimensions, depending on their placement in
1139 // the equation:
1140 // + B: Batch dimensions: Dimensions which appear in both operands and RHS.
1141 // + K: Contracting dimensions: These appear in both inputs but not RHS.
1142 // + M: Operand A dimensions: These appear in the first operand and the RHS.
1143 // + N: Operand B dimensions: These appear in the second operand and the RHS.
1144 // Then, the operation to estimate is BatchMatMul([B,M,K],[B,K,N])
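  // For example, for the equation "bij,bjk->bik" (illustrative), b is a batch
  // (B) dimension, j is a contracting (K) dimension, i maps to M, and k maps
  // to N, so the cost is modeled as BatchMatMul([B,M,K],[B,K,N]).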
1145
1146 if (batch_matmul_context == nullptr) {
1147 VLOG(1) << "Output context should not be a nullptr.";
1148 return false;
1149 }
1150 if (!IsEinsumCorrectlyFormed(einsum_context)) return false;
1151 const auto& op_info = einsum_context.op_info;
1152 std::vector<std::string> equation_split =
1153 absl::StrSplit(op_info.attr().find("equation")->second.s(), "->");
1154 std::vector<absl::string_view> input_split =
1155 absl::StrSplit(equation_split[0], ',');
1156 const auto& a_input = op_info.inputs(0);
1157 const auto& b_input = op_info.inputs(1);
1158 absl::string_view rhs_str = equation_split[1];
1159 absl::string_view a_input_str = input_split[0];
1160 absl::string_view b_input_str = input_split[1];
1161
1162 constexpr int kMatrixRank = 2;
1163
1164 bool a_input_shape_unknown = false;
1165 bool b_input_shape_unknown = false;
1166
1167 TensorShapeProto a_input_shape = MaybeGetMinimumShape(
1168 a_input.shape(), std::max(kMatrixRank, a_input.shape().dim_size()),
1169 &a_input_shape_unknown);
1170 TensorShapeProto b_input_shape = MaybeGetMinimumShape(
1171 b_input.shape(), std::max(kMatrixRank, b_input.shape().dim_size()),
1172 &b_input_shape_unknown);
1173
1174 *found_unknown_shapes = a_input_shape_unknown || b_input_shape_unknown ||
1175 (a_input.shape().dim_size() < kMatrixRank) ||
1176 (b_input.shape().dim_size() < kMatrixRank);
1177
1178 OpInfo batch_matmul_op_info = op_info;
1179 batch_matmul_op_info.mutable_inputs()->Clear();
1180 batch_matmul_op_info.set_op("BatchMatMul");
1181
1182 AttrValue transpose_attribute;
1183 transpose_attribute.set_b(false);
1184 (*batch_matmul_op_info.mutable_attr())["transpose_a"] = transpose_attribute;
1185 (*batch_matmul_op_info.mutable_attr())["transpose_b"] = transpose_attribute;
1186
1187 OpInfo::TensorProperties* a_matrix = batch_matmul_op_info.add_inputs();
1188 TensorShapeProto* a_matrix_shape = a_matrix->mutable_shape();
1189 a_matrix->set_dtype(a_input.dtype());
1190
1191 OpInfo::TensorProperties* b_matrix = batch_matmul_op_info.add_inputs();
1192 b_matrix->set_dtype(b_input.dtype());
1193 TensorShapeProto* b_matrix_shape = b_matrix->mutable_shape();
1194
1195 TensorShapeProto_Dim m_dim;
1196 TensorShapeProto_Dim n_dim;
1197 TensorShapeProto_Dim k_dim;
1198
1199 m_dim.set_size(1);
1200 n_dim.set_size(1);
1201 k_dim.set_size(1);
1202
1203 for (int i_idx = 0, a_input_str_size = a_input_str.size();
1204 i_idx < a_input_str_size; ++i_idx) {
1205 if (b_input_str.find(a_input_str[i_idx]) == std::string::npos) {
1206 if (rhs_str.find(a_input_str[i_idx]) == std::string::npos) {
1207 VLOG(1) << "Missing accurate estimator for op: " << op_info.op();
1208 return false;
1209 }
1210
1211 m_dim.set_size(m_dim.size() * a_input_shape.dim(i_idx).size());
1212 continue;
1213 } else if (rhs_str.find(a_input_str[i_idx]) == std::string::npos) {
      // The dimension does not appear in the RHS; therefore, it is a
      // contracting dimension.
1216 k_dim.set_size(k_dim.size() * a_input_shape.dim(i_idx).size());
1217 continue;
1218 }
    // It appears in both input operands; therefore, we place it as an outer
    // dimension for the Batch Matmul.
1221 *(a_matrix_shape->add_dim()) = a_input_shape.dim(i_idx);
1222 *(b_matrix_shape->add_dim()) = a_input_shape.dim(i_idx);
1223 }
1224 for (int i_idx = 0, b_input_str_size = b_input_str.size();
1225 i_idx < b_input_str_size; ++i_idx) {
1226 if (a_input_str.find(b_input_str[i_idx]) == std::string::npos) {
1227 if (rhs_str.find(b_input_str[i_idx]) == std::string::npos) {
1228 VLOG(1) << "Missing accurate estimator for op: " << op_info.op();
1229 return false;
1230 }
1231 n_dim.set_size(n_dim.size() * b_input_shape.dim(i_idx).size());
1232 }
1233 }
1234
1235 // The two inner-most dimensions of the Batch Matmul are added.
1236 *(a_matrix_shape->add_dim()) = m_dim;
1237 *(a_matrix_shape->add_dim()) = k_dim;
1238 *(b_matrix_shape->add_dim()) = k_dim;
1239 *(b_matrix_shape->add_dim()) = n_dim;
1240
1241 *batch_matmul_context = einsum_context;
1242 batch_matmul_context->op_info = batch_matmul_op_info;
1243 return true;
1244 }
1245
int64_t OpLevelCostEstimator::CountBatchMatMulOperations(
1247 const OpInfo& op_info, bool* found_unknown_shapes) {
1248 return CountBatchMatMulOperations(op_info, nullptr, found_unknown_shapes);
1249 }
1250
int64_t OpLevelCostEstimator::CountBatchMatMulOperations(
1252 const OpInfo& op_info, BatchMatMulDimensions* batch_mat_mul,
1253 bool* found_unknown_shapes) {
1254 if (op_info.op() != kBatchMatMul && op_info.op() != kBatchMatMulV2) {
1255 LOG(ERROR) << "Invalid Operation: " << op_info.op();
1256 // TODO(pcma): Try to separate invalid inputs from unknown shapes
1257 *found_unknown_shapes = true;
1258 return 0;
1259 }
1260 if (op_info.inputs_size() != 2) {
1261 LOG(ERROR) << "Expected 2 inputs but got " << op_info.inputs_size();
1262 // TODO(pcma): Try to separate invalid inputs from unknown shapes
1263 *found_unknown_shapes = true;
1264 return 0;
1265 }
1266
1267 double ops = 0;
1268 const auto& a_input = op_info.inputs(0);
1269 const auto& b_input = op_info.inputs(1);
1270
1271 // BatchMatMul requires inputs of at least matrix shape (rank 2).
1272 // The two most minor dimensions of each input are matrices that
1273 // need to be multiplied together. The other dimensions determine
1274 // the number of such MatMuls. For example, if the BatchMatMul has
1275 // inputs of shape:
1276 // a_input_shape = [2, 3, 4, 5]
1277 // b_input_shape = [2, 3, 5, 6]
1278 // then there are 2*3 = 6 MatMuls of dimensions m = 4, k = 5, n = 6
1279 // in this BatchMatMul.
1280 const int matrix_rank = 2;
1281
1282 bool a_input_shape_unknown = false;
1283 bool b_input_shape_unknown = false;
1284
1285 TensorShapeProto a_input_shape = MaybeGetMinimumShape(
1286 a_input.shape(), std::max(matrix_rank, a_input.shape().dim_size()),
1287 &a_input_shape_unknown);
1288 TensorShapeProto b_input_shape = MaybeGetMinimumShape(
1289 b_input.shape(), std::max(matrix_rank, b_input.shape().dim_size()),
1290 &b_input_shape_unknown);
1291
1292 *found_unknown_shapes = a_input_shape_unknown || b_input_shape_unknown ||
1293 (a_input.shape().dim_size() < matrix_rank) ||
1294 (b_input.shape().dim_size() < matrix_rank);
1295
1296 // Compute the number of matmuls as the max indicated at each dimension
1297 // by either input. Note that the shapes do not have to have
1298 // the same rank due to incompleteness.
1299 TensorShapeProto* bigger_rank_shape = &a_input_shape;
1300 TensorShapeProto* smaller_rank_shape = &b_input_shape;
1301 if (b_input_shape.dim_size() > a_input_shape.dim_size()) {
1302 bigger_rank_shape = &b_input_shape;
1303 smaller_rank_shape = &a_input_shape;
1304 }
1305 int num_matmuls = 1;
1306 for (int b_i = 0,
1307 s_i = smaller_rank_shape->dim_size() - bigger_rank_shape->dim_size();
1308 b_i < bigger_rank_shape->dim_size() - matrix_rank; ++b_i, ++s_i) {
1309 int b_dim = bigger_rank_shape->dim(b_i).size();
1310 int s_dim = 1;
1311 if (s_i >= 0) {
1312 s_dim = smaller_rank_shape->dim(s_i).size();
1313 }
1314 if (batch_mat_mul != nullptr) {
1315 batch_mat_mul->batch_dims.push_back(s_dim);
1316 }
1317 num_matmuls *= std::max(b_dim, s_dim);
1318 }
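  // Rough worked example for the loop above: with a_input_shape = [2, 3, 4, 5]
  // and b_input_shape = [3, 5, 6], the batch dimensions compared are [2, 3]
  // against [(implicit) 1, 3], so num_matmuls = 2 * 3 = 6.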
1319
1320 // Build the MatMul. Note that values are ignored here since we are just
1321   // counting ops (i.e., only shapes matter).
1322 OpInfo matmul_op_info;
1323 matmul_op_info.set_op("MatMul");
1324
1325 AttrValue transpose_a;
1326 transpose_a.set_b(false);
1327 if (op_info.attr().find("adj_x") != op_info.attr().end()) {
1328 transpose_a.set_b(op_info.attr().at("adj_x").b());
1329 }
1330 (*matmul_op_info.mutable_attr())["transpose_a"] = transpose_a;
1331
1332 AttrValue transpose_b;
1333 transpose_b.set_b(false);
1334 if (op_info.attr().find("adj_y") != op_info.attr().end()) {
1335 transpose_b.set_b(op_info.attr().at("adj_y").b());
1336 }
1337 (*matmul_op_info.mutable_attr())["transpose_b"] = transpose_b;
1338
1339 OpInfo::TensorProperties* a_matrix = matmul_op_info.add_inputs();
1340 a_matrix->set_dtype(a_input.dtype());
1341 TensorShapeProto* a_matrix_shape = a_matrix->mutable_shape();
1342 for (int i = std::max(0, a_input_shape.dim_size() - matrix_rank);
1343 i < a_input_shape.dim_size(); ++i) {
1344 *(a_matrix_shape->add_dim()) = a_input_shape.dim(i);
1345 }
1346
1347 OpInfo::TensorProperties* b_matrix = matmul_op_info.add_inputs();
1348 b_matrix->set_dtype(b_input.dtype());
1349 TensorShapeProto* b_matrix_shape = b_matrix->mutable_shape();
1350 for (int i = std::max(0, b_input_shape.dim_size() - matrix_rank);
1351 i < b_input_shape.dim_size(); ++i) {
1352 *(b_matrix_shape->add_dim()) = b_input_shape.dim(i);
1353 }
1354 if (batch_mat_mul != nullptr) {
1355 batch_mat_mul->matmul_dims.m = (transpose_a.b())
1356 ? a_matrix_shape->dim(1).size()
1357 : a_matrix_shape->dim(0).size();
1358 batch_mat_mul->matmul_dims.k = (transpose_a.b())
1359 ? a_matrix_shape->dim(0).size()
1360 : a_matrix_shape->dim(1).size();
1361 batch_mat_mul->matmul_dims.n = (transpose_b.b())
1362 ? b_matrix_shape->dim(0).size()
1363 : b_matrix_shape->dim(1).size();
1364 }
1365
1366 for (int i = 0; i < num_matmuls; ++i) {
1367 bool matmul_unknown_shapes = false;
1368 ops += CountMatMulOperations(matmul_op_info, &matmul_unknown_shapes);
1369 *found_unknown_shapes |= matmul_unknown_shapes;
1370 }
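  // Rough worked example: for the shapes in the comment above
  // ([2, 3, 4, 5] x [2, 3, 5, 6]) there are 6 MatMuls with m = 4, k = 5, n = 6;
  // assuming CountMatMulOperations charges kOpsPerMac * m * k * n per MatMul,
  // the total is 6 * 2 * 4 * 5 * 6 = 1440 ops.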
1371 return ops;
1372 }
1373
1374 bool GetTensorShapeProtoFromTensorProto(const TensorProto& tensor_proto,
1375 TensorShapeProto* tensor_shape_proto) {
1376 tensor_shape_proto->Clear();
1377   // First convert the TensorProto into a Tensor so that the data values are
1378   // parsed correctly regardless of where they are stored (int_val, int64_val,
1379   // tensor_content, or any other field).
1380 Tensor tensor(tensor_proto.dtype());
1381 if (!tensor.FromProto(tensor_proto)) {
1382 LOG(WARNING) << "GetTensorShapeProtoFromTensorProto() -- "
1383 << "failed to parse TensorProto: "
1384 << tensor_proto.DebugString();
1385 return false;
1386 }
1387 if (tensor.dims() != 1) {
1388 LOG(WARNING) << "GetTensorShapeProtoFromTensorProto() -- "
1389 << "tensor is not 1D: " << tensor.dims();
1390 return false;
1391 }
1392   // Then convert it back to a TensorProto using AsProtoField, which ensures
1393   // the data ends up in repeated fields such as int_val or int64_val rather
1394   // than in tensor_content.
1395 TensorProto temp_tensor;
1396 tensor.AsProtoField(&temp_tensor);
1397
1398 #define TENSOR_VALUES_TO_TENSOR_SHAPE_PROTO(type) \
1399 do { \
1400 for (const auto& value : temp_tensor.type##_val()) { \
1401 tensor_shape_proto->add_dim()->set_size(value); \
1402 } \
1403 } while (0)
1404
1405 if (tensor.dtype() == DT_INT32 || tensor.dtype() == DT_INT16 ||
1406 tensor.dtype() == DT_INT8 || tensor.dtype() == DT_UINT8) {
1407 TENSOR_VALUES_TO_TENSOR_SHAPE_PROTO(int);
1408 } else if (tensor.dtype() == DT_INT64) {
1409 TENSOR_VALUES_TO_TENSOR_SHAPE_PROTO(int64);
1410 } else if (tensor.dtype() == DT_UINT32) {
1411 TENSOR_VALUES_TO_TENSOR_SHAPE_PROTO(uint32);
1412 } else if (tensor.dtype() == DT_UINT64) {
1413 TENSOR_VALUES_TO_TENSOR_SHAPE_PROTO(uint64);
1414 } else {
1415 LOG(WARNING) << "GetTensorShapeProtoFromTensorProto() -- "
1416 << "Unsupported dtype: " << tensor.dtype();
1417 return false;
1418 }
1419 #undef TENSOR_VALUES_TO_TENSOR_SHAPE_PROTO
1420
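  // Example (sketch): a 1-D DT_INT32 tensor holding the values {4, 32, 32, 3}
  // produces a TensorShapeProto with four dims of sizes 4, 32, 32 and 3.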
1421 return true;
1422 }
1423
1424 // TODO(cliffy): Dedup this method and CountConv2DBackpropFilterOperations.
1425 int64_t OpLevelCostEstimator::CountConv2DBackpropInputOperations(
1426 const OpInfo& op_info, ConvolutionDimensions* returned_conv_dims,
1427 bool* found_unknown_shapes) {
1428 int64_t ops = 0;
1429
1430 DCHECK(op_info.op() == kConv2dBackpropInput ||
1431 op_info.op() == kDepthwiseConv2dNativeBackpropInput)
1432       << "Invalid Operation: not kConv2dBackpropInput nor "
1433          "kDepthwiseConv2dNativeBackpropInput";
1434
1435 if (op_info.inputs_size() < 2) {
1436 // TODO(pcma): Try to separate invalid inputs from unknown shapes
1437 *found_unknown_shapes = true;
1438 return ops;
1439 }
1440
1441 TensorShapeProto input_shape;
1442 bool shape_found = false;
1443 if (op_info.inputs(0).has_value()) {
1444 const TensorProto& value = op_info.inputs(0).value();
1445 shape_found = GetTensorShapeProtoFromTensorProto(value, &input_shape);
1446 }
1447 if (!shape_found && op_info.outputs_size() == 1) {
1448 input_shape = op_info.outputs(0).shape();
1449 shape_found = true;
1450 }
1451 if (!shape_found) {
1452     // Set the minimum input size that's feasible.
1453 input_shape.Clear();
1454 for (int i = 0; i < 4; ++i) {
1455 input_shape.add_dim()->set_size(1);
1456 }
1457 *found_unknown_shapes = true;
1458 }
1459
1460 ConvolutionDimensions conv_dims = ConvolutionDimensionsFromInputs(
1461 input_shape, op_info.inputs(1).shape(), op_info, found_unknown_shapes);
1462
1463 ops = conv_dims.batch;
1464 ops *= conv_dims.ox * conv_dims.oy;
1465 ops *= conv_dims.kx * conv_dims.ky;
1466 if (op_info.op() == kConv2dBackpropInput) {
1467 ops *= conv_dims.kz * conv_dims.oz;
1468 } else {
1469     // conv_dims always uses the forward path definition regardless of op type.
1470 conv_dims.oz *= conv_dims.iz;
1471 ops *= conv_dims.oz;
1472 }
1473 ops *= kOpsPerMac;
1474
1475   VLOG(1) << "Operations for " << op_info.op() << " " << ops;
1476
1477 if (returned_conv_dims != nullptr) {
1478 *returned_conv_dims = conv_dims;
1479 }
1480 return ops;
1481 }
1482
1483 int64_t OpLevelCostEstimator::CountConv2DBackpropFilterOperations(
1484 const OpInfo& op_info, ConvolutionDimensions* returned_conv_dims,
1485 bool* found_unknown_shapes) {
1486 int64_t ops = 0;
1487
1488 DCHECK(op_info.op() == kConv2dBackpropFilter ||
1489 op_info.op() == kDepthwiseConv2dNativeBackpropFilter)
1490       << "Invalid Operation: not kConv2dBackpropFilter nor "
1491          "kDepthwiseConv2dNativeBackpropFilter";
1492
1493 TensorShapeProto filter_shape;
1494 bool shape_found = false;
1495 if (op_info.inputs_size() >= 2 && op_info.inputs(1).has_value()) {
1496 const TensorProto& value = op_info.inputs(1).value();
1497 shape_found = GetTensorShapeProtoFromTensorProto(value, &filter_shape);
1498 }
1499 if (!shape_found && op_info.outputs_size() == 1) {
1500 filter_shape = op_info.outputs(0).shape();
1501 shape_found = true;
1502 }
1503 if (!shape_found) {
1504 // Set the minimum filter size that's feasible.
1505 filter_shape.Clear();
1506 for (int i = 0; i < 4; ++i) {
1507 filter_shape.add_dim()->set_size(1);
1508 }
1509 *found_unknown_shapes = true;
1510 }
1511
1512 if (op_info.inputs_size() < 1) {
1513 // TODO(pcma): Try to separate invalid inputs from unknown shapes
1514 *found_unknown_shapes = true;
1515 return ops;
1516 }
1517 ConvolutionDimensions conv_dims = ConvolutionDimensionsFromInputs(
1518 op_info.inputs(0).shape(), filter_shape, op_info, found_unknown_shapes);
1519
1520 ops = conv_dims.batch;
1521 ops *= conv_dims.ox * conv_dims.oy;
1522 ops *= conv_dims.kx * conv_dims.ky;
1523 if (op_info.op() == kConv2dBackpropFilter) {
1524 ops *= conv_dims.kz * conv_dims.oz;
1525 } else {
1526     // conv_dims always uses the forward path definition regardless of op type.
1527 conv_dims.oz *= conv_dims.iz;
1528 ops *= conv_dims.oz;
1529 }
1530 ops *= kOpsPerMac;
1531   VLOG(1) << "Operations for " << op_info.op() << " " << ops;
1532
1533 if (returned_conv_dims != nullptr) {
1534 *returned_conv_dims = conv_dims;
1535 }
1536 return ops;
1537 }
1538
1539 int64_t OpLevelCostEstimator::CalculateTensorElementCount(
1540 const OpInfo::TensorProperties& tensor, bool* found_unknown_shapes) {
1541 VLOG(2) << " with " << DataTypeString(tensor.dtype()) << " tensor of shape "
1542 << tensor.shape().DebugString();
1543 int64_t tensor_size = 1;
1544 int num_dims = std::max(1, tensor.shape().dim_size());
1545 auto tensor_shape =
1546 MaybeGetMinimumShape(tensor.shape(), num_dims, found_unknown_shapes);
1547 for (const auto& dim : tensor_shape.dim()) {
1548 int64_t new_tensor_size = MultiplyWithoutOverflow(tensor_size, dim.size());
1549 if (new_tensor_size < 0) {
1550 VLOG(1) << "Overflow encountered when computing element count of a "
1551 "tensor, multiplying "
1552 << tensor_size << " with " << dim.size();
1553 return -1;
1554 }
1555 tensor_size = new_tensor_size;
1556 }
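  // Example: a fully known tensor of shape [32, 128, 4] yields
  // 32 * 128 * 4 = 16384 elements.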
1557 return tensor_size;
1558 }
1559
1560 int64_t OpLevelCostEstimator::CalculateTensorSize(
1561 const OpInfo::TensorProperties& tensor, bool* found_unknown_shapes) {
1562 int64_t count = CalculateTensorElementCount(tensor, found_unknown_shapes);
1563 int size = DataTypeSize(BaseType(tensor.dtype()));
1564 VLOG(2) << "Count: " << count << " DataTypeSize: " << size;
1565 int64_t tensor_size = MultiplyWithoutOverflow(count, size);
1566 if (tensor_size < 0) {
1567 VLOG(1) << "Overflow encountered when computing tensor size, multiplying "
1568 << count << " with " << size;
1569 return -1;
1570 }
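  // Example: 16384 DT_FLOAT elements occupy 16384 * 4 = 65536 bytes.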
1571 return tensor_size;
1572 }
1573
1574 int64_t OpLevelCostEstimator::CalculateInputSize(const OpInfo& op_info,
1575 bool* found_unknown_shapes) {
1576 int64_t total_input_size = 0;
1577 for (auto& input : op_info.inputs()) {
1578 int64_t input_size = CalculateTensorSize(input, found_unknown_shapes);
1579 total_input_size += input_size;
1580 VLOG(1) << "Input Size: " << input_size
1581 << " Total Input Size:" << total_input_size;
1582 }
1583 return total_input_size;
1584 }
1585
1586 std::vector<int64_t> OpLevelCostEstimator::CalculateInputTensorSize(
1587 const OpInfo& op_info, bool* found_unknown_shapes) {
1588 std::vector<int64_t> input_tensor_size;
1589 input_tensor_size.reserve(op_info.inputs().size());
1590 for (auto& input : op_info.inputs()) {
1591 input_tensor_size.push_back(
1592 CalculateTensorSize(input, found_unknown_shapes));
1593 }
1594 return input_tensor_size;
1595 }
1596
1597 int64_t OpLevelCostEstimator::CalculateLargestInputCount(
1598 const OpInfo& op_info, bool* found_unknown_shapes) {
1599 int64_t largest_input_count = 0;
1600 for (auto& input : op_info.inputs()) {
1601 int64_t input_count =
1602 CalculateTensorElementCount(input, found_unknown_shapes);
1603 if (input_count > largest_input_count) {
1604 largest_input_count = input_count;
1605 }
1606 VLOG(1) << "Input Count: " << input_count
1607 << " Largest Input Count:" << largest_input_count;
1608 }
1609 return largest_input_count;
1610 }
1611
1612 int64_t OpLevelCostEstimator::CalculateOutputSize(const OpInfo& op_info,
1613 bool* found_unknown_shapes) {
1614 int64_t total_output_size = 0;
1615 // Use float as default for calculations.
1616 for (const auto& output : op_info.outputs()) {
1617 DataType dt = output.dtype();
1618 const auto& original_output_shape = output.shape();
1619 int64_t output_size = DataTypeSize(BaseType(dt));
1620 int num_dims = std::max(1, original_output_shape.dim_size());
1621 auto output_shape = MaybeGetMinimumShape(original_output_shape, num_dims,
1622 found_unknown_shapes);
1623 for (const auto& dim : output_shape.dim()) {
1624 int64_t new_output_size =
1625 MultiplyWithoutOverflow(output_size, dim.size());
1626 if (new_output_size < 0) {
1627 VLOG(1) << "Overflow encountered when estimating cost, multiplying "
1628 << output_size << " with " << dim.size();
1629 return -1;
1630 }
1631 output_size = new_output_size;
1632 }
1633 total_output_size += output_size;
1634 VLOG(1) << "Output Size: " << output_size
1635 << " Total Output Size:" << total_output_size;
1636 }
1637 return total_output_size;
1638 }
1639
1640 std::vector<int64_t> OpLevelCostEstimator::CalculateOutputTensorSize(
1641 const OpInfo& op_info, bool* found_unknown_shapes) {
1642 std::vector<int64_t> output_tensor_size;
1643 output_tensor_size.reserve(op_info.outputs().size());
1644 // Use float as default for calculations.
1645 for (const auto& output : op_info.outputs()) {
1646 DataType dt = output.dtype();
1647 const auto& original_output_shape = output.shape();
1648 int64_t output_size = DataTypeSize(BaseType(dt));
1649 int num_dims = std::max(1, original_output_shape.dim_size());
1650 auto output_shape = MaybeGetMinimumShape(original_output_shape, num_dims,
1651 found_unknown_shapes);
1652 for (const auto& dim : output_shape.dim()) {
1653 output_size *= dim.size();
1654 }
1655 output_tensor_size.push_back(output_size);
1656 }
1657 return output_tensor_size;
1658 }
1659
1660 Status OpLevelCostEstimator::PredictDefaultNodeCosts(
1661 const int64_t num_compute_ops, const OpContext& op_context,
1662 bool* found_unknown_shapes, NodeCosts* node_costs) {
1663 const auto& op_info = op_context.op_info;
1664 node_costs->num_compute_ops = num_compute_ops;
1665 node_costs->num_input_bytes_accessed =
1666 CalculateInputTensorSize(op_info, found_unknown_shapes);
1667 node_costs->num_output_bytes_accessed =
1668 CalculateOutputTensorSize(op_info, found_unknown_shapes);
1669 node_costs->max_memory = node_costs->num_total_output_bytes();
1670 if (*found_unknown_shapes) {
1671 node_costs->inaccurate = true;
1672 node_costs->num_nodes_with_unknown_shapes = 1;
1673 }
1674 return OkStatus();
1675 }
1676
1677 bool HasZeroDim(const OpInfo& op_info) {
1678 for (int i = 0; i < op_info.inputs_size(); ++i) {
1679 const auto& input = op_info.inputs(i);
1680 for (int j = 0; j < input.shape().dim_size(); ++j) {
1681 const auto& dim = input.shape().dim(j);
1682 if (dim.size() == 0) {
1683 VLOG(1) << "Convolution config has zero dim "
1684 << op_info.ShortDebugString();
1685 return true;
1686 }
1687 }
1688 }
1689 return false;
1690 }
1691
1692 Status OpLevelCostEstimator::PredictConv2D(const OpContext& op_context,
1693 NodeCosts* node_costs) const {
1694 const auto& op_info = op_context.op_info;
1695 if (HasZeroDim(op_info)) {
1696 node_costs->num_nodes_with_unknown_shapes = 1;
1697 return errors::InvalidArgument("Conv2D op includes zero dimension: ",
1698 op_info.ShortDebugString());
1699 }
1700 bool found_unknown_shapes = false;
1701 int64_t num_compute_ops =
1702 CountConv2DOperations(op_info, &found_unknown_shapes);
1703 return PredictDefaultNodeCosts(num_compute_ops, op_context,
1704 &found_unknown_shapes, node_costs);
1705 }
1706
1707 Status OpLevelCostEstimator::PredictConv2DBackpropInput(
1708 const OpContext& op_context, NodeCosts* node_costs) const {
1709 const auto& op_info = op_context.op_info;
1710 if (HasZeroDim(op_info)) {
1711 node_costs->num_nodes_with_unknown_shapes = 1;
1712 return errors::InvalidArgument(
1713 "Conv2DBackpropInput op includes zero dimension",
1714 op_info.ShortDebugString());
1715 }
1716 bool found_unknown_shapes = false;
1717 int64_t num_compute_ops = CountConv2DBackpropInputOperations(
1718 op_info, nullptr, &found_unknown_shapes);
1719 return PredictDefaultNodeCosts(num_compute_ops, op_context,
1720 &found_unknown_shapes, node_costs);
1721 }
1722
1723 Status OpLevelCostEstimator::PredictConv2DBackpropFilter(
1724 const OpContext& op_context, NodeCosts* node_costs) const {
1725 const auto& op_info = op_context.op_info;
1726 if (HasZeroDim(op_info)) {
1727 node_costs->num_nodes_with_unknown_shapes = 1;
1728 return errors::InvalidArgument(
1729 "Conv2DBackpropFilter op includes zero dimension",
1730 op_info.ShortDebugString());
1731 }
1732 bool found_unknown_shapes = false;
1733 int64_t num_compute_ops = CountConv2DBackpropFilterOperations(
1734 op_info, nullptr, &found_unknown_shapes);
1735 return PredictDefaultNodeCosts(num_compute_ops, op_context,
1736 &found_unknown_shapes, node_costs);
1737 }
1738
1739 Status OpLevelCostEstimator::PredictFusedConv2DBiasActivation(
1740 const OpContext& op_context, NodeCosts* node_costs) const {
1741   // FusedConv2DBiasActivation computes a fused kernel that performs a 2D
1742   // convolution, adds the side input with separate scaling for the convolution
1743   // and side inputs, then adds the bias, and finally applies the ReLU
1744   // activation function to the result:
1745 //
1746 // Input -> Conv2D -> Add -> BiasAdd -> ReLU
1747 // ^ ^ ^
1748 // Filter Side Input Bias
1749 //
1750 // Note that when adding the side input, the operation multiplies the output
1751 // of Conv2D by conv_input_scale, confusingly, and the side_input by
1752 // side_input_scale.
1753 //
1754 // Note that in the special case that side_input_scale is 0, which we infer
1755 // from side_input having dimensions [], we skip that addition operation.
1756 //
1757 // For more information, see
1758 // contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc
1759
1760 // TODO(yaozhang): Support NHWC_VECT_W.
1761 std::string data_format = GetDataFormat(op_context.op_info);
1762 if (data_format != "NCHW" && data_format != "NHWC" &&
1763 data_format != "NCHW_VECT_C") {
1764 return errors::InvalidArgument(
1765 "Unsupported data format (", data_format,
1766 ") for op: ", op_context.op_info.ShortDebugString());
1767 }
1768 std::string filter_format = GetFilterFormat(op_context.op_info);
1769 if (filter_format != "HWIO" && filter_format != "OIHW" &&
1770 filter_format != "OIHW_VECT_I") {
1771 return errors::InvalidArgument(
1772 "Unsupported filter format (", filter_format,
1773 ") for op: ", op_context.op_info.ShortDebugString());
1774 }
1775
1776 auto& conv_input = op_context.op_info.inputs(0);
1777 auto& filter = op_context.op_info.inputs(1);
1778 auto& side_input = op_context.op_info.inputs(3);
1779 auto& conv_input_scale = op_context.op_info.inputs(4);
1780 auto& side_input_scale = op_context.op_info.inputs(5);
1781
1782 // Manually compute our convolution dimensions.
1783 bool found_unknown_shapes = false;
1784 auto dims = ConvolutionDimensionsFromInputs(
1785 conv_input.shape(), filter.shape(), op_context.op_info,
1786 &found_unknown_shapes);
1787 OpInfo::TensorProperties output;
1788 if (data_format == "NCHW" || data_format == "NCHW_VECT_C") {
1789 output = DescribeTensor(DT_FLOAT, {dims.batch, dims.oz, dims.oy, dims.ox});
1790 } else if (data_format == "NHWC") {
1791 output = DescribeTensor(DT_FLOAT, {dims.batch, dims.oy, dims.ox, dims.oz});
1792 }
1793
1794 // Add the operations the fused op always computes.
1795 std::vector<OpContext> component_ops = {
1796 FusedChildContext(op_context, "Conv2D", output, {conv_input, filter}),
1797 FusedChildContext(op_context, "Mul", output, {output, conv_input_scale}),
1798 FusedChildContext(
1799 op_context, "BiasAdd", output,
1800 {output, output}), // Note we're no longer using bias at all
1801 FusedChildContext(op_context, "Relu", output, {output})};
1802
1803 // Add our side_input iff it's non-empty.
1804 if (side_input.shape().dim_size() > 0) {
1805 component_ops.push_back(FusedChildContext(op_context, "Mul", side_input,
1806 {side_input, side_input_scale}));
1807 component_ops.push_back(FusedChildContext(
1808 op_context, "Add", output,
1809 {output, output})); // Note that we're not using side_input here
1810 }
1811
1812 // Construct an op_context which definitely has our output shape.
1813 auto op_context_with_output = op_context;
1814 op_context_with_output.op_info.mutable_outputs()->Clear();
1815 *op_context_with_output.op_info.mutable_outputs()->Add() = output;
1816
1817 // Construct component operations and run the cost computation.
1818 if (found_unknown_shapes) {
1819 node_costs->inaccurate = true;
1820 node_costs->num_nodes_with_unknown_shapes = 1;
1821 }
1822 return PredictFusedOp(op_context_with_output, component_ops, node_costs);
1823 }
1824
1825 Status OpLevelCostEstimator::PredictMatMul(const OpContext& op_context,
1826 NodeCosts* node_costs) const {
1827 const auto& op_info = op_context.op_info;
1828 bool found_unknown_shapes = false;
1829 int64_t num_compute_ops =
1830 CountMatMulOperations(op_info, &found_unknown_shapes);
1831 return PredictDefaultNodeCosts(num_compute_ops, op_context,
1832 &found_unknown_shapes, node_costs);
1833 }
1834
1835 Status OpLevelCostEstimator::PredictEinsum(const OpContext& op_context,
1836 NodeCosts* node_costs) const {
1837 const auto& op_info = op_context.op_info;
1838
1839 auto it = op_info.attr().find("equation");
1840 if (it == op_info.attr().end()) {
1841 return errors::InvalidArgument("Einsum op doesn't have equation attr: ",
1842 op_info.ShortDebugString());
1843 }
1844
1845 OpContext batch_matmul_op_context;
1846 bool found_unknown_shapes = false;
1847 bool success = GenerateBatchMatmulContextFromEinsum(
1848 op_context, &batch_matmul_op_context, &found_unknown_shapes);
1849 if (found_unknown_shapes) {
1850 node_costs->inaccurate = true;
1851 node_costs->num_nodes_with_unknown_shapes = 1;
1852 }
1853 if (!success) {
1854 return PredictCostOfAnUnknownOp(op_context, node_costs);
1855 }
1856 return PredictNodeCosts(batch_matmul_op_context, node_costs);
1857 }
1858
1859 Status OpLevelCostEstimator::PredictSparseTensorDenseMatMul(
1860 const OpContext& op_context, NodeCosts* node_costs) const {
1861 const auto& op_info = op_context.op_info;
1862 bool found_unknown_shapes = false;
1863 // input[0]: indices in sparse matrix a
1864 // input[1]: values in sparse matrix a
1865 // input[2]: shape of matrix a
1866 // input[3]: matrix b
1867 // See
1868 // https://github.com/tensorflow/tensorflow/blob/9a43dfeac5/tensorflow/core/ops/sparse_ops.cc#L85
1869 int64_t num_elems_in_a =
1870 CalculateTensorElementCount(op_info.inputs(1), &found_unknown_shapes);
1871 auto b_matrix = op_info.inputs(3);
1872 auto b_matrix_shape =
1873 MaybeGetMinimumShape(b_matrix.shape(), 2, &found_unknown_shapes);
1874 int64_t n_dim = b_matrix_shape.dim(1).size();
1875
1876 // Each element in A is multiplied and added with an element from each column
1877 // in b.
1878 const int64_t op_count = kOpsPerMac * num_elems_in_a * n_dim;
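  // Rough worked example: with 10000 non-zero values in a and a dense b of
  // shape [k, 256], op_count = 2 * 10000 * 256 = 5120000.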
1879
1880 int64_t a_indices_input_size =
1881 CalculateTensorSize(op_info.inputs(0), &found_unknown_shapes);
1882 int64_t a_values_input_size =
1883 CalculateTensorSize(op_info.inputs(1), &found_unknown_shapes);
1884 int64_t a_shape_input_size =
1885 CalculateTensorSize(op_info.inputs(2), &found_unknown_shapes);
1886 int64_t b_input_size =
1887 num_elems_in_a * n_dim * DataTypeSize(BaseType(b_matrix.dtype()));
1888 int64_t output_size = CalculateOutputSize(op_info, &found_unknown_shapes);
1889
1890 node_costs->num_compute_ops = op_count;
1891 node_costs->num_input_bytes_accessed = {a_indices_input_size,
1892 a_values_input_size,
1893 a_shape_input_size, b_input_size};
1894 node_costs->num_output_bytes_accessed = {output_size};
1895 if (found_unknown_shapes) {
1896 node_costs->inaccurate = true;
1897 node_costs->num_nodes_with_unknown_shapes = 1;
1898 }
1899 return OkStatus();
1900 }
1901
1902 Status OpLevelCostEstimator::PredictNoOp(const OpContext& op_context,
1903 NodeCosts* node_costs) const {
1904 const auto& op_info = op_context.op_info;
1905 VLOG(1) << "Op:" << op_info.op() << " Execution Time 0 (ns)";
1906 // By default, NodeCosts is initialized to zero ops and bytes.
1907 return OkStatus();
1908 }
1909
1910 Status OpLevelCostEstimator::PredictPureMemoryOp(const OpContext& op_context,
1911 NodeCosts* node_costs) const {
1912 // Each output element is a copy of some element from input, with no required
1913 // computation, so just compute memory costs.
1914 bool found_unknown_shapes = false;
1915 node_costs->num_nodes_with_pure_memory_op = 1;
1916 return PredictDefaultNodeCosts(0, op_context, &found_unknown_shapes,
1917 node_costs);
1918 }
1919
1920 Status OpLevelCostEstimator::PredictIdentity(const OpContext& op_context,
1921 NodeCosts* node_costs) const {
1922 const auto& op_info = op_context.op_info;
1923 VLOG(1) << "Op:" << op_info.op() << " Minimum cost for Identity";
1924 node_costs->minimum_cost_op = true;
1925 node_costs->num_compute_ops = kMinComputeOp;
1926   // The Identity op internally passes the input tensor buffer's pointer to
1927   // the output tensor buffer; there is no actual memory operation.
1928 node_costs->num_input_bytes_accessed = {0};
1929 node_costs->num_output_bytes_accessed = {0};
1930 bool inaccurate = false;
1931 node_costs->max_memory = CalculateOutputSize(op_info, &inaccurate);
1932 if (inaccurate) {
1933 node_costs->inaccurate = true;
1934 node_costs->num_nodes_with_unknown_shapes = 1;
1935 }
1936 return OkStatus();
1937 }
1938
1939 Status OpLevelCostEstimator::PredictVariable(const OpContext& op_context,
1940 NodeCosts* node_costs) const {
1941 const auto& op_info = op_context.op_info;
1942 VLOG(1) << "Op:" << op_info.op() << " Minimum cost for Variable";
1943 node_costs->minimum_cost_op = true;
1944 node_costs->num_compute_ops = kMinComputeOp;
1945   // Variables are persistent ops: they are initialized before the step, so
1946   // they incur no per-step memory cost.
1947 node_costs->num_input_bytes_accessed = {0};
1948 node_costs->num_output_bytes_accessed = {0};
1949 bool inaccurate = false;
1950 node_costs->persistent_memory = CalculateOutputSize(op_info, &inaccurate);
1951 if (inaccurate) {
1952 node_costs->inaccurate = true;
1953 node_costs->num_nodes_with_unknown_shapes = 1;
1954 }
1955 return OkStatus();
1956 }
1957
1958 Status OpLevelCostEstimator::PredictBatchMatMul(const OpContext& op_context,
1959 NodeCosts* node_costs) const {
1960 const auto& op_info = op_context.op_info;
1961 bool found_unknown_shapes = false;
1962 int64_t num_compute_ops =
1963 CountBatchMatMulOperations(op_info, &found_unknown_shapes);
1964 return PredictDefaultNodeCosts(num_compute_ops, op_context,
1965 &found_unknown_shapes, node_costs);
1966 }
1967
1968 Status OpLevelCostEstimator::PredictMetadata(const OpContext& op_context,
1969 NodeCosts* node_costs) const {
1970 const auto& op_info = op_context.op_info;
1971 node_costs->minimum_cost_op = true;
1972 node_costs->num_compute_ops = kMinComputeOp;
1973 node_costs->num_input_bytes_accessed = {0};
1974 node_costs->num_output_bytes_accessed = {0};
1975 bool inaccurate = false;
1976 node_costs->max_memory = CalculateOutputSize(op_info, &inaccurate);
1977 if (inaccurate) {
1978 node_costs->inaccurate = true;
1979 node_costs->num_nodes_with_unknown_shapes = 1;
1980 }
1981 return OkStatus();
1982 }
1983
1984 Status OpLevelCostEstimator::PredictGatherOrSlice(const OpContext& op_context,
1985 NodeCosts* node_costs) const {
1986 // Gather & Slice ops can have a very large input, but only access a small
1987   // part of it. For these ops, the size of the output determines the memory cost.
1988 const auto& op_info = op_context.op_info;
1989
1990 const int inputs_needed = op_info.op() == "Slice" ? 3 : 2;
1991 if (op_info.outputs_size() == 0 || op_info.inputs_size() < inputs_needed) {
1992 return errors::InvalidArgument(
1993 op_info.op(),
1994 " Op doesn't have valid input / output: ", op_info.ShortDebugString());
1995 }
1996
1997 bool unknown_shapes = false;
1998
1999 // Each output element is a copy of some element from input.
2000 // For roofline estimate we assume each copy has a unit cost.
2001 const int64_t op_count =
2002 CalculateTensorElementCount(op_info.outputs(0), &unknown_shapes);
2003 node_costs->num_compute_ops = op_count;
2004
2005 const int64_t output_size = CalculateOutputSize(op_info, &unknown_shapes);
2006 node_costs->num_output_bytes_accessed = {output_size};
2007
2008 node_costs->num_input_bytes_accessed.reserve(op_info.inputs().size());
2009 int64_t input_size = output_size;
2010   // Note that the bytes accessed from input(0) are not equal to the input(0)
2011   // tensor size; they equal the output size, because the input is accessed
2012   // via an indexed gather or slice (duplicate indices are ignored).
2013 node_costs->num_input_bytes_accessed.push_back(input_size);
2014 int begin_input_index = 1;
2015 int end_input_index;
2016 if (op_info.op() == "Slice") {
2017 // Slice: 'input' (omitted), 'begin', 'size'
2018 end_input_index = 3;
2019 } else if (op_info.op() == "StridedSlice") {
2020 // StridedSlice: 'input' (omitted), 'begin', 'end', 'strides'
2021 end_input_index = 4;
2022 } else {
2023 // Gather, GatherV2, GatherNd: 'params' (omitted), 'indices'
2024 end_input_index = 2;
2025 }
2026 for (int i = begin_input_index; i < end_input_index; ++i) {
2027 node_costs->num_input_bytes_accessed.push_back(
2028 CalculateTensorElementCount(op_info.inputs(i), &unknown_shapes));
2029 }
2030 if (unknown_shapes) {
2031 node_costs->inaccurate = true;
2032 node_costs->num_nodes_with_unknown_shapes = 1;
2033 }
2034 return OkStatus();
2035 }
2036
2037 Status OpLevelCostEstimator::PredictScatter(const OpContext& op_context,
2038 NodeCosts* node_costs) const {
2039 // Scatter ops sparsely access a reference input and output tensor.
2040 const auto& op_info = op_context.op_info;
2041 bool found_unknown_shapes = false;
2042
2043 // input[0]: ref tensor that will be sparsely accessed
2044 // input[1]: indices - A tensor of indices into the first dimension of ref.
2045 // input[2]: updates where updates.shape = indices.shape + ref.shape[1:]
2046 // See
2047 // https://www.tensorflow.org/api_docs/python/tf/scatter_add and
2048 // https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/ops/state_ops.cc#L146
2049
2050 const int64_t num_indices =
2051 CalculateTensorElementCount(op_info.inputs(1), &found_unknown_shapes);
2052
2053 int64_t num_elems_in_ref_per_index = 1;
2054 auto ref_tensor_shape = MaybeGetMinimumShape(
2055 op_info.inputs(0).shape(), op_info.inputs(0).shape().dim_size(),
2056 &found_unknown_shapes);
2057 for (int i = 1; i < ref_tensor_shape.dim().size(); ++i) {
2058 num_elems_in_ref_per_index *= ref_tensor_shape.dim(i).size();
2059 }
2060 const int64_t op_count = num_indices * num_elems_in_ref_per_index;
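  // Rough worked example: scattering 128 indices into a ref tensor of shape
  // [1024, 256] gives op_count = 128 * 256 = 32768.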
2061 node_costs->num_compute_ops = op_count;
2062
2063   // Ref is accessed sparsely, so the input size depends on the op count.
2064 int64_t ref_input_size =
2065 op_count * DataTypeSize(BaseType(op_info.inputs(0).dtype()));
2066 int64_t indices_input_size =
2067 CalculateTensorSize(op_info.inputs(1), &found_unknown_shapes);
2068 int64_t updates_input_size =
2069 CalculateTensorSize(op_info.inputs(2), &found_unknown_shapes);
2070 node_costs->num_input_bytes_accessed = {ref_input_size, indices_input_size,
2071 updates_input_size};
2072
2073   // Ref is accessed sparsely, so the output size depends on the op count.
2074 int64_t output_size =
2075 op_count * DataTypeSize(BaseType(op_info.outputs(0).dtype()));
2076 node_costs->num_output_bytes_accessed = {output_size};
2077
2078 if (found_unknown_shapes) {
2079 node_costs->inaccurate = true;
2080 node_costs->num_nodes_with_unknown_shapes = 1;
2081 }
2082 return OkStatus();
2083 }
2084
2085 Status OpLevelCostEstimator::PredictFusedOp(
2086 const OpContext& op_context,
2087 const std::vector<OpContext>& fused_op_contexts,
2088 NodeCosts* node_costs) const {
2089   // PredictDefaultNodeCosts gets the correct memory costs from the node's
2090   // inputs and outputs. However, we don't want to re-implement the logic for
2091   // computing the operation count of each of our component operations here,
2092   // so we simply sum the compute ops of each component operation and then
2093   // update the cost.
2094 bool found_unknown_shapes = false;
2095 Status s =
2096 PredictDefaultNodeCosts(0, op_context, &found_unknown_shapes, node_costs);
2097
2098 for (auto& fused_op : fused_op_contexts) {
2099 NodeCosts fused_node_costs;
2100 s.Update(PredictNodeCosts(fused_op, &fused_node_costs));
2101 node_costs->num_compute_ops += fused_node_costs.num_compute_ops;
2102 node_costs->inaccurate |= fused_node_costs.inaccurate;
2103 // Set, not increment. Note that we are predicting the cost of one fused
2104 // node, not a function node composed of many nodes.
2105 node_costs->num_nodes_with_unknown_shapes |=
2106 fused_node_costs.num_nodes_with_unknown_shapes;
2107 node_costs->num_nodes_with_unknown_op_type |=
2108 fused_node_costs.num_nodes_with_unknown_op_type;
2109 node_costs->num_nodes_with_pure_memory_op |=
2110 fused_node_costs.num_nodes_with_pure_memory_op;
2111 }
2112
2113 return OkStatus();
2114 }
2115
2116 /* static */
2117 OpContext OpLevelCostEstimator::FusedChildContext(
2118 const OpContext& parent, const std::string& op_name,
2119 const OpInfo::TensorProperties& output,
2120 const std::vector<OpInfo::TensorProperties>& inputs) {
2121 // Setup the base parameters of our new context.
2122 OpContext new_context;
2123 new_context.name = op_name;
2124 new_context.device_name = parent.device_name;
2125 new_context.op_info = parent.op_info;
2126 new_context.op_info.set_op(op_name);
2127
2128 // Setup the inputs of our new context.
2129 new_context.op_info.mutable_inputs()->Clear();
2130 for (const auto& input : inputs) {
2131 *new_context.op_info.mutable_inputs()->Add() = input;
2132 }
2133
2134 // Setup the output of our new context.
2135 new_context.op_info.mutable_outputs()->Clear();
2136 *new_context.op_info.mutable_outputs()->Add() = output;
2137
2138 return new_context;
2139 }
2140
2141 /* static */
2142 OpInfo::TensorProperties OpLevelCostEstimator::DescribeTensor(
2143 DataType type, const std::vector<int64_t>& dims) {
2144 OpInfo::TensorProperties ret;
2145 ret.set_dtype(type);
2146
2147 auto shape = ret.mutable_shape();
2148 for (const int dim : dims) {
2149 shape->add_dim()->set_size(dim);
2150 }
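  // Usage sketch: DescribeTensor(DT_FLOAT, {32, 224, 224, 64}) returns a
  // TensorProperties with dtype DT_FLOAT and shape [32, 224, 224, 64].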
2151
2152 return ret;
2153 }
2154
2155 /* static */
2156 StatusOr<OpLevelCostEstimator::ConvolutionDimensions>
2157 OpLevelCostEstimator::OpDimensionsFromInputs(
2158 const TensorShapeProto& original_image_shape, const OpInfo& op_info,
2159 bool* found_unknown_shapes) {
2160 VLOG(2) << "op features: " << op_info.DebugString();
2161 VLOG(2) << "Original image shape: " << original_image_shape.DebugString();
2162 auto image_shape =
2163 MaybeGetMinimumShape(original_image_shape, 4, found_unknown_shapes);
2164 VLOG(2) << "Image shape: " << image_shape.DebugString();
2165
2166 int x_index, y_index, channel_index;
2167 const std::string& data_format = GetDataFormat(op_info);
2168 if (data_format == "NCHW") {
2169 channel_index = 1;
2170 y_index = 2;
2171 x_index = 3;
2172 } else {
2173 y_index = 1;
2174 x_index = 2;
2175 channel_index = 3;
2176 }
2177 int64_t batch = image_shape.dim(0).size();
2178 int64_t ix = image_shape.dim(x_index).size();
2179 int64_t iy = image_shape.dim(y_index).size();
2180 int64_t iz = image_shape.dim(channel_index).size();
2181
2182 // Note that FusedBatchNorm doesn't have ksize attr, but GetKernelSize returns
2183 // {1, 1, 1, 1} in that case.
2184 std::vector<int64_t> ksize = GetKernelSize(op_info);
2185 int64_t kx = ksize[x_index];
2186 int64_t ky = ksize[y_index];
2187 // These ops don't support groupwise operation, therefore kz == iz.
2188 int64_t kz = iz;
2189
2190 std::vector<int64_t> strides = GetStrides(op_info);
2191 int64_t sx = strides[x_index];
2192 int64_t sy = strides[y_index];
2193 if (sx == 0 || sy == 0) {
2194 return errors::InvalidArgument(
2195 "Stride must be > 0 for Height and Width, but got (", sy, ", ", sx,
2196 ")");
2197 }
2198 const auto padding = GetPadding(op_info);
2199
2200 int64_t ox = GetOutputSize(ix, kx, sx, padding);
2201 int64_t oy = GetOutputSize(iy, ky, sy, padding);
2202 int64_t oz = iz;
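  // Example (sketch): an NHWC input of shape [32, 224, 224, 64] with
  // ksize = {1, 2, 2, 1}, strides = {1, 2, 2, 1} and VALID padding yields
  // batch = 32, ix = iy = 224, iz = kz = oz = 64, kx = ky = sx = sy = 2, and
  // ox = oy = 112.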
2203
2204 OpLevelCostEstimator::ConvolutionDimensions conv_dims = {
2205 batch, ix, iy, iz, kx, ky, kz, oz, ox, oy, sx, sy, padding};
2206 return conv_dims;
2207 }
2208
2209 Status OpLevelCostEstimator::PredictMaxPool(const OpContext& op_context,
2210 NodeCosts* node_costs) const {
2211 bool found_unknown_shapes = false;
2212 const auto& op_info = op_context.op_info;
2213 // x: op_info.inputs(0)
2214 TF_ASSIGN_OR_RETURN(ConvolutionDimensions dims,
2215 OpDimensionsFromInputs(op_info.inputs(0).shape(), op_info,
2216 &found_unknown_shapes));
2217   // kx * ky - 1 comparisons per output (when kx * ky > 1),
2218   // or 1 copy per output (when kx * ky == 1).
2219 int per_output_ops = dims.kx * dims.ky == 1 ? 1 : dims.kx * dims.ky - 1;
2220 int64_t ops = dims.batch * dims.ox * dims.oy * dims.oz * per_output_ops;
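  // Rough worked example: for a 2x2 window with stride 2 on a
  // [32, 224, 224, 64] NHWC input, per_output_ops = 3 and
  // ops = 32 * 112 * 112 * 64 * 3, roughly 7.7e7 comparisons.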
2221 node_costs->num_compute_ops = ops;
2222
2223 int64_t input_size = 0;
2224 if (dims.ky >= dims.sy) {
2225 input_size = CalculateTensorSize(op_info.inputs(0), &found_unknown_shapes);
2226 } else { // dims.ky < dims.sy
2227     // Vertical stride is larger than vertical kernel; assuming row-major
2228     // format, skip unnecessary rows: only ky rows out of every sy rows are
2229     // read, as the others are not used for the output.
2230 const auto data_size = DataTypeSize(BaseType(op_info.inputs(0).dtype()));
2231 input_size = data_size * dims.batch * dims.ix * dims.ky * dims.oy * dims.iz;
2232 }
2233 node_costs->num_input_bytes_accessed = {input_size};
2234 const int64_t output_size =
2235 CalculateOutputSize(op_info, &found_unknown_shapes);
2236 node_costs->num_output_bytes_accessed = {output_size};
2237 node_costs->max_memory = output_size;
2238 if (found_unknown_shapes) {
2239 node_costs->inaccurate = true;
2240 node_costs->num_nodes_with_unknown_shapes = 1;
2241 }
2242 return OkStatus();
2243 }
2244
2245 Status OpLevelCostEstimator::PredictMaxPoolGrad(const OpContext& op_context,
2246 NodeCosts* node_costs) const {
2247 bool found_unknown_shapes = false;
2248 const auto& op_info = op_context.op_info;
2249 // x: op_info.inputs(0)
2250 // y: op_info.inputs(1)
2251 // y_grad: op_info.inputs(2)
2252 if (op_info.inputs_size() < 3) {
2253 return errors::InvalidArgument("MaxPoolGrad op has invalid inputs: ",
2254 op_info.ShortDebugString());
2255 }
2256
2257 TF_ASSIGN_OR_RETURN(ConvolutionDimensions dims,
2258 OpDimensionsFromInputs(op_info.inputs(0).shape(), op_info,
2259 &found_unknown_shapes));
2260
2261 int64_t ops = 0;
2262 if (dims.kx == 1 && dims.ky == 1) {
2263 // 1x1 window. No need to know which input was max.
2264 ops = dims.batch * dims.ix * dims.iy * dims.iz;
2265 } else if (dims.kx <= dims.sx && dims.ky <= dims.sy) {
2266 // Non-overlapping window: re-run maxpool, then assign zero or y_grad.
2267 ops = dims.batch * dims.iz *
2268 (dims.ox * dims.oy * (dims.kx * dims.ky - 1) + dims.ix * dims.iy);
2269 } else {
2270 // Overlapping window: initialize with zeros, re-run maxpool, then
2271     // accumulate y_grad to the proper x_grad locations.
2272 ops = dims.batch * dims.iz *
2273 (dims.ox * dims.oy * (dims.kx * dims.ky - 1) + dims.ix * dims.iy * 2);
2274 }
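  // Rough worked example: for the non-overlapping 2x2/2 case on a
  // [32, 224, 224, 64] input, ops = 32 * 64 * (112 * 112 * 3 + 224 * 224),
  // roughly 1.8e8 compute ops.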
2275 node_costs->num_compute_ops = ops;
2276
2277   // Just read x and y_grad; no need to read y, as we assume MaxPoolGrad
2278   // re-runs MaxPool internally.
2279 const int64_t input0_size =
2280 CalculateTensorSize(op_info.inputs(0), &found_unknown_shapes);
2281 const int64_t input2_size =
2282 CalculateTensorSize(op_info.inputs(2), &found_unknown_shapes);
2283 node_costs->num_input_bytes_accessed = {input0_size, 0, input2_size};
2284 // Write x_grad; size equal to x.
2285 const int64_t output_size =
2286 CalculateTensorSize(op_info.inputs(0), &found_unknown_shapes);
2287 node_costs->num_output_bytes_accessed = {output_size};
2288 node_costs->max_memory = output_size;
2289
2290 if (found_unknown_shapes) {
2291 node_costs->inaccurate = true;
2292 node_costs->num_nodes_with_unknown_shapes = 1;
2293 }
2294 return OkStatus();
2295 }
2296
2297 /* This predict function handles three types of TensorFlow ops:
2298  * AssignVariableOp, AssignAddVariableOp and AssignSubVariableOp. Broadcasting
2299  * is not supported by these ops, so the input tensor's shape is enough to
2300  * compute the cost. */
2301 Status OpLevelCostEstimator::PredictAssignVariableOps(
2302 const OpContext& op_context, NodeCosts* node_costs) const {
2303 bool found_unknown_shapes = false;
2304 const auto& op_info = op_context.op_info;
2305   /* The first input of these ops is a reference to the assignee. */
2306 if (op_info.inputs_size() != 2) {
2307 return errors::InvalidArgument("AssignVariable op has invalid input: ",
2308 op_info.ShortDebugString());
2309 }
2310
2311 const int64_t ops = op_info.op() == kAssignVariableOp
2312 ? 0
2313 : CalculateTensorElementCount(op_info.inputs(1),
2314 &found_unknown_shapes);
2315 node_costs->num_compute_ops = ops;
2316 const int64_t input_size = CalculateInputSize(op_info, &found_unknown_shapes);
2317 node_costs->num_input_bytes_accessed = {input_size};
2318   // TODO(dyoon): check whether these ops actually write data; the op itself
2319   // doesn't have an output tensor, but it may modify the input (ref or
2320 // resource). Maybe use node_costs->internal_write_bytes.
2321 node_costs->num_output_bytes_accessed = {0};
2322 if (found_unknown_shapes) {
2323 node_costs->inaccurate = true;
2324 node_costs->num_nodes_with_unknown_shapes = 1;
2325 }
2326 return OkStatus();
2327 }
2328
2329 Status OpLevelCostEstimator::PredictAvgPool(const OpContext& op_context,
2330 NodeCosts* node_costs) const {
2331 bool found_unknown_shapes = false;
2332 const auto& op_info = op_context.op_info;
2333 // x: op_info.inputs(0)
2334 TF_ASSIGN_OR_RETURN(ConvolutionDimensions dims,
2335 OpDimensionsFromInputs(op_info.inputs(0).shape(), op_info,
2336 &found_unknown_shapes));
2337
2338 // kx * ky - 1 additions and 1 multiplication per output.
2339 int64_t ops = dims.batch * dims.ox * dims.oy * dims.oz * dims.kx * dims.ky;
2340 node_costs->num_compute_ops = ops;
2341
2342 int64_t input_size;
2343 if (dims.ky >= dims.sy) {
2344 input_size = CalculateTensorSize(op_info.inputs(0), &found_unknown_shapes);
2345 } else { // dims.ky < dims.sy
2346     // Vertical stride is larger than vertical kernel; assuming row-major
2347     // format, skip unnecessary rows: only ky rows out of every sy rows are
2348     // read, as the others are not used for the output.
2349 const auto data_size = DataTypeSize(BaseType(op_info.inputs(0).dtype()));
2350 input_size = data_size * dims.batch * dims.ix * dims.ky * dims.oy * dims.iz;
2351 }
2352 node_costs->num_input_bytes_accessed = {input_size};
2353
2354 const int64_t output_size =
2355 CalculateOutputSize(op_info, &found_unknown_shapes);
2356 node_costs->num_output_bytes_accessed = {output_size};
2357 node_costs->max_memory = output_size;
2358
2359 if (found_unknown_shapes) {
2360 node_costs->inaccurate = true;
2361 node_costs->num_nodes_with_unknown_shapes = 1;
2362 }
2363 return OkStatus();
2364 }
2365
2366 Status OpLevelCostEstimator::PredictAvgPoolGrad(const OpContext& op_context,
2367 NodeCosts* node_costs) const {
2368 bool found_unknown_shapes = false;
2369 const auto& op_info = op_context.op_info;
2370 // x's shape: op_info.inputs(0)
2371 // y_grad: op_info.inputs(1)
2372
2373 // Extract x_shape from op_info.inputs(0).value() or op_info.outputs(0).
2374 bool shape_found = false;
2375 TensorShapeProto x_shape;
2376 if (op_info.inputs_size() >= 1 && op_info.inputs(0).has_value()) {
2377 const TensorProto& value = op_info.inputs(0).value();
2378 shape_found = GetTensorShapeProtoFromTensorProto(value, &x_shape);
2379 }
2380 if (!shape_found && op_info.outputs_size() > 0) {
2381 x_shape = op_info.outputs(0).shape();
2382 shape_found = true;
2383 }
2384 if (!shape_found) {
2385 // Set the minimum shape that's feasible.
2386 x_shape.Clear();
2387 for (int i = 0; i < 4; ++i) {
2388 x_shape.add_dim()->set_size(1);
2389 }
2390 found_unknown_shapes = true;
2391 }
2392
2393 TF_ASSIGN_OR_RETURN(
2394 ConvolutionDimensions dims,
2395 OpDimensionsFromInputs(x_shape, op_info, &found_unknown_shapes));
2396
2397 int64_t ops = 0;
2398 if (dims.kx <= dims.sx && dims.ky <= dims.sy) {
2399 // Non-overlapping window.
2400 ops = dims.batch * dims.iz * (dims.ix * dims.iy + dims.ox * dims.oy);
2401 } else {
2402 // Overlapping window.
2403 ops = dims.batch * dims.iz *
2404 (dims.ix * dims.iy + dims.ox * dims.oy * (dims.kx * dims.ky + 1));
2405 }
2406 auto s = PredictDefaultNodeCosts(ops, op_context, &found_unknown_shapes,
2407 node_costs);
2408 node_costs->max_memory = node_costs->num_total_output_bytes();
2409 return s;
2410 }
2411
2412 Status OpLevelCostEstimator::PredictFusedBatchNorm(
2413 const OpContext& op_context, NodeCosts* node_costs) const {
2414 bool found_unknown_shapes = false;
2415 const auto& op_info = op_context.op_info;
2416 // x: op_info.inputs(0)
2417 // scale: op_info.inputs(1)
2418 // offset: op_info.inputs(2)
2419 // mean: op_info.inputs(3) --> only for inference
2420 // variance: op_info.inputs(4) --> only for inference
2421 TF_ASSIGN_OR_RETURN(ConvolutionDimensions dims,
2422 OpDimensionsFromInputs(op_info.inputs(0).shape(), op_info,
2423 &found_unknown_shapes));
2424 const bool is_training = IsTraining(op_info);
2425
2426 int64_t ops = 0;
2427 const auto rsqrt_cost = Eigen::internal::functor_traits<
2428 Eigen::internal::scalar_rsqrt_op<float>>::Cost;
2429 if (is_training) {
2430 ops = dims.iz * (dims.batch * dims.ix * dims.iy * 4 + 6 + rsqrt_cost);
2431 } else {
2432 ops = dims.batch * dims.ix * dims.iy * dims.iz * 2;
2433 }
2434 node_costs->num_compute_ops = ops;
2435
2436 const int64_t size_nhwc =
2437 CalculateTensorSize(op_info.inputs(0), &found_unknown_shapes);
2438 const int64_t size_c =
2439 CalculateTensorSize(op_info.inputs(1), &found_unknown_shapes);
2440 if (is_training) {
2441 node_costs->num_input_bytes_accessed = {size_nhwc, size_c, size_c};
2442 node_costs->num_output_bytes_accessed = {size_nhwc, size_c, size_c, size_c,
2443 size_c};
2444 // FusedBatchNorm in training mode internally re-reads the input tensor:
2445     // once for mean/variance, and a second internal read for the actual
2446     // scaling. Assume small intermediate data such as mean / variance (size_c)
2447     // can be cached on-chip.
2448 node_costs->internal_read_bytes = size_nhwc;
2449 } else {
2450 node_costs->num_input_bytes_accessed = {size_nhwc, size_c, size_c, size_c,
2451 size_c};
2452 node_costs->num_output_bytes_accessed = {size_nhwc};
2453 }
2454 node_costs->max_memory = node_costs->num_total_output_bytes();
2455
2456 if (found_unknown_shapes) {
2457 node_costs->inaccurate = true;
2458 node_costs->num_nodes_with_unknown_shapes = 1;
2459 }
2460 return OkStatus();
2461 }
2462
2463 Status OpLevelCostEstimator::PredictFusedBatchNormGrad(
2464 const OpContext& op_context, NodeCosts* node_costs) const {
2465 bool found_unknown_shapes = false;
2466 const auto& op_info = op_context.op_info;
2467 // y_backprop: op_info.inputs(0)
2468 // x: op_info.inputs(1)
2469 // scale: op_info.inputs(2)
2470 // mean: op_info.inputs(3)
2471 // variance or inverse of variance: op_info.inputs(4)
2472 TF_ASSIGN_OR_RETURN(ConvolutionDimensions dims,
2473 OpDimensionsFromInputs(op_info.inputs(1).shape(), op_info,
2474 &found_unknown_shapes));
2475
2476 int64_t ops = 0;
2477 const auto rsqrt_cost = Eigen::internal::functor_traits<
2478 Eigen::internal::scalar_rsqrt_op<float>>::Cost;
2479 ops = dims.iz * (dims.batch * dims.ix * dims.iy * 11 + 5 + rsqrt_cost);
2480 node_costs->num_compute_ops = ops;
2481
2482 const int64_t size_nhwc =
2483 CalculateTensorSize(op_info.inputs(1), &found_unknown_shapes);
2484 const int64_t size_c =
2485 CalculateTensorSize(op_info.inputs(2), &found_unknown_shapes);
2486 // TODO(dyoon): fix missing memory cost for variance input (size_c) and
2487 // yet another read of y_backprop (size_nhwc) internally.
2488 node_costs->num_input_bytes_accessed = {size_nhwc, size_nhwc, size_c, size_c};
2489 node_costs->num_output_bytes_accessed = {size_nhwc, size_c, size_c};
2490 // FusedBatchNormGrad has to read y_backprop internally.
2491 node_costs->internal_read_bytes = size_nhwc;
2492 node_costs->max_memory = node_costs->num_total_output_bytes();
2493
2494 if (found_unknown_shapes) {
2495 node_costs->inaccurate = true;
2496 node_costs->num_nodes_with_unknown_shapes = 1;
2497 }
2498 return OkStatus();
2499 }
2500
2501 Status OpLevelCostEstimator::PredictNaryOp(const OpContext& op_context,
2502 NodeCosts* node_costs) const {
2503 const auto& op_info = op_context.op_info;
2504 bool found_unknown_shapes = false;
2505   // Calculate the largest known element count across all inputs.
2506 int64_t op_count = CalculateLargestInputCount(op_info, &found_unknown_shapes);
2507 // If output shape is available, try to use the element count calculated from
2508 // that.
2509 if (op_info.outputs_size() > 0) {
2510 op_count = std::max(
2511 op_count,
2512 CalculateTensorElementCount(op_info.outputs(0), &found_unknown_shapes));
2513 }
2514 // Also calculate the output shape possibly resulting from broadcasting.
2515   // Note that some Nary ops (such as AddN) do not support broadcasting,
2516 // but we're including this here for completeness.
2517 if (op_info.inputs_size() >= 2) {
2518 op_count = std::max(op_count, CwiseOutputElementCount(op_info));
2519 }
2520
2521 // Nary ops perform one operation for every element in every input tensor.
2522 op_count *= op_info.inputs_size() - 1;
2523
2524 const auto sum_cost = Eigen::internal::functor_traits<
2525 Eigen::internal::scalar_sum_op<float>>::Cost;
2526 return PredictDefaultNodeCosts(op_count * sum_cost, op_context,
2527 &found_unknown_shapes, node_costs);
2528 }
2529
2530 // softmax[i, j] = exp(logits[i, j]) / sum_j(exp(logits[i, j]))
2531 Status OpLevelCostEstimator::PredictSoftmax(const OpContext& op_context,
2532 NodeCosts* node_costs) const {
2533 bool found_unknown_shapes = false;
2534 const int64_t logits_size = CalculateTensorElementCount(
2535 op_context.op_info.inputs(0), &found_unknown_shapes);
2536 // Softmax input rank should be >=1.
2537 TensorShapeProto logits_shape = op_context.op_info.inputs(0).shape();
2538 if (logits_shape.unknown_rank() || logits_shape.dim_size() == 0) {
2539 return errors::InvalidArgument("Softmax op has invalid input: ",
2540 op_context.op_info.ShortDebugString());
2541 }
2542
2543 #define EIGEN_COST(X) Eigen::internal::functor_traits<Eigen::internal::X>::Cost
2544
2545 // Every element of <logits> will be exponentiated, have that result included
2546 // in a sum across j, and also have that result multiplied by the reciprocal
2547 // of the sum_j. In addition, we'll compute 1/sum_j for every i.
2548 auto ops =
2549 (EIGEN_COST(scalar_exp_op<float>) + EIGEN_COST(scalar_sum_op<float>) +
2550 EIGEN_COST(scalar_product_op<float>)) *
2551 logits_size +
2552 EIGEN_COST(scalar_inverse_op<float>) * logits_shape.dim(0).size();
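  // Rough worked example: for logits of shape [64, 1000] (64000 elements), the
  // per-element exp/add/mul costs are charged 64000 times and the reciprocal
  // cost is charged once per row, i.e. 64 times.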
2553
2554 #undef EIGEN_COST
2555 return PredictDefaultNodeCosts(ops, op_context, &found_unknown_shapes,
2556 node_costs);
2557 }
2558
2559 Status OpLevelCostEstimator::PredictResizeBilinear(
2560 const OpContext& op_context, NodeCosts* node_costs) const {
2561 bool found_unknown_shapes = false;
2562
2563 if (op_context.op_info.outputs().empty() ||
2564 op_context.op_info.inputs().empty()) {
2565 return errors::InvalidArgument(
2566 "ResizeBilinear op has invalid input / output ",
2567 op_context.op_info.ShortDebugString());
2568 }
2569
2570 const int64_t output_elements = CalculateTensorElementCount(
2571 op_context.op_info.outputs(0), &found_unknown_shapes);
2572
2573 const auto half_pixel_centers =
2574 op_context.op_info.attr().find("half_pixel_centers");
2575 bool use_half_pixel_centers = false;
2576 if (half_pixel_centers == op_context.op_info.attr().end()) {
2577 LOG(WARNING) << "half_pixel_centers attr not set for ResizeBilinear.";
2578 return PredictCostOfAnUnknownOp(op_context, node_costs);
2579 } else {
2580 use_half_pixel_centers = half_pixel_centers->second.b();
2581 }
2582
2583 // Compose cost of bilinear interpolation.
2584 int64_t ops = 0;
2585
2586 #define EIGEN_COST(X) Eigen::internal::functor_traits<Eigen::internal::X>::Cost
2587 const auto sub_cost_float = EIGEN_COST(scalar_difference_op<float>);
2588 const auto sub_cost_int = EIGEN_COST(scalar_difference_op<int64_t>);
2589 const auto add_cost = EIGEN_COST(scalar_sum_op<float>);
2590 const auto mul_cost = EIGEN_COST(scalar_product_op<float>);
2591 const auto floor_cost = EIGEN_COST(scalar_floor_op<float>);
2592 const auto max_cost = EIGEN_COST(scalar_max_op<int64_t>);
2593 const auto min_cost = EIGEN_COST(scalar_min_op<int64_t>);
2594 const auto cast_to_int_cost = Eigen::internal::functor_traits<
2595 Eigen::internal::scalar_cast_op<float, int64_t>>::Cost;
2596 const auto cast_to_float_cost = Eigen::internal::functor_traits<
2597 Eigen::internal::scalar_cast_op<int64_t, float>>::Cost;
2598 const auto ceil_cost = EIGEN_COST(scalar_ceil_op<float>);
2599 #undef EIGEN_COST
2600
2601 // Ops calculated from tensorflow/core/kernels/image/resize_bilinear_op.cc.
2602
2603 // Op counts taken from resize_bilinear implementation on 07/21/2020.
2604 // Computed op counts may become inaccurate if resize_bilinear implementation
2605 // changes.
2606
2607 // resize_bilinear has an optimization where the interpolation weights are
2608 // precomputed and cached. Given input tensors of size [B,H1,W1,C] and output
2609 // tensors of size [B,H2,W2,C], the input rows and columns that need to be
2610 // accessed for interpolation are identical at every point in the output
2611 // across the batch and channel dimensions. These values are cached in the
2612 // compute_interpolation_weights function. For a particular y in [0...H2-1],
2613 // the rows to be accessed in the input are the same. Likewise, for a
2614 // particular x in [0...W2-1], the columns to be accessed are the same. So the
2615 // precomputation only needs to be done for H2 + W2 values.
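// Illustration: for an output of shape [1, 8, 8, 3], interpolation weights are
// precomputed for only H2 + W2 = 16 positions, while the per-element lerp
// below runs over all 1 * 8 * 8 * 3 = 192 output elements.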
2616 const auto output_shape = MaybeGetMinimumShape(
2617 op_context.op_info.outputs(0).shape(), 4, &found_unknown_shapes);
2618 // Assume H is dim 1 and W is dim 2 to match logic in resize_bilinear, which
2619 // also makes this assumption.
2620 const int64_t output_height = output_shape.dim(1).size();
2621 const int64_t output_width = output_shape.dim(2).size();
2622 // Add the ops done outside of the scaler function in
2623 // compute_interpolation_weights.
2624 int64_t interp_weight_cost = floor_cost + max_cost + min_cost +
2625 sub_cost_float + sub_cost_int + ceil_cost +
2626 cast_to_int_cost * 2;
2627 // There are two options for computing the weight of each pixel in the
2628 // interpolation: the algorithm can use either pixel centers or pixel corners
2629 // for the weight. The op count depends on the scaler function passed into
2630 // compute_interpolation_weights.
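// HalfPixelScaler maps x to (x + 0.5) * scale - 0.5 (cast, add, multiply,
// subtract), while LegacyScaler maps x to x * scale (cast, multiply), hence
// the different op counts in the two branches below.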
2631 if (use_half_pixel_centers) {
2632 // Ops for HalfPixelScaler.
2633 interp_weight_cost +=
2634 add_cost + mul_cost + sub_cost_float + cast_to_float_cost;
2635 } else {
2636 // Ops for LegacyScaler.
2637 interp_weight_cost += cast_to_float_cost + mul_cost;
2638 }
2639 // Cost for the interpolation is multiplied by (H2 + W2), as mentioned above.
2640 ops += interp_weight_cost * (output_height + output_width);
2641
2642 // Ops for computing the new values, done for every element. Logic is from
2643 // compute_lerp in the inner loop of resize_image which consists of:
2644 // const float top = top_left + (top_right - top_left) * x_lerp;
2645 // const float bottom = bottom_left + (bottom_right - bottom_left) * x_lerp;
2646 // return top + (bottom - top) * y_lerp;
2647 ops += (add_cost * 3 + sub_cost_float * 3 + mul_cost * 3) * output_elements;
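// That is 3 additions, 3 subtractions, and 3 multiplications (9 ops) per
// output element.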
2648
2649 return PredictDefaultNodeCosts(ops, op_context, &found_unknown_shapes,
2650 node_costs);
2651 }
2652
2653 Status OpLevelCostEstimator::PredictCropAndResize(const OpContext& op_context,
2654 NodeCosts* node_costs) const {
2655 bool found_unknown_shapes = false;
2656
2657 const auto method = op_context.op_info.attr().find("method");
2658 bool use_bilinear_interp;
2659 if (method == op_context.op_info.attr().end() ||
2660 method->second.s() == "bilinear") {
2661 use_bilinear_interp = true;
2662 } else if (method->second.s() == "nearest") {
2663 use_bilinear_interp = false;
2664 } else {
2665 LOG(WARNING) << "method attr in CropAndResize invalid; expected bilinear "
2666 "or nearest.";
2667 return PredictCostOfAnUnknownOp(op_context, node_costs);
2668 }
2669
2670 const int64_t num_boxes = op_context.op_info.inputs(1).shape().dim(0).size();
2671 const auto crop_shape = MaybeGetMinimumShape(
2672 op_context.op_info.outputs(0).shape(), 4, &found_unknown_shapes);
2673 const int64_t crop_height = crop_shape.dim(1).size();
2674 const int64_t crop_width = crop_shape.dim(2).size();
2675 const int64_t output_elements = CalculateTensorElementCount(
2676 op_context.op_info.outputs(0), &found_unknown_shapes);
2677
2678 #define EIGEN_COST(X) Eigen::internal::functor_traits<Eigen::internal::X>::Cost
2679 const auto sub_cost = EIGEN_COST(scalar_difference_op<float>);
2680 const auto add_cost = EIGEN_COST(scalar_sum_op<float>);
2681 const auto mul_cost = EIGEN_COST(scalar_product_op<float>);
2682 const auto div_cost = EIGEN_COST(scalar_div_cost<float>);
2683 const auto floor_cost = EIGEN_COST(scalar_floor_op<float>);
2684 const auto ceil_cost = EIGEN_COST(scalar_ceil_op<float>);
2685 const auto round_cost = EIGEN_COST(scalar_round_op<float>);
2686 const auto cast_to_float_cost = Eigen::internal::functor_traits<
2687 Eigen::internal::scalar_cast_op<int64_t, float>>::Cost;
2688 #undef EIGEN_COST
2689
2690 // Computing ops following
2691 // tensorflow/core/kernels/image/crop_and_resize_op.cc at 08/25/2020. The op
2692 // count here differs from the rough estimate in that implementation because
2693 // it separates the per-box cost from the per-pixel and per-element costs.
2694
2695 // Since crop arguments are user controlled, check for overflow.
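// MultiplyWithoutOverflow returns a negative value when the product of its
// nonnegative arguments would overflow int64.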
2696 int64_t crop_area = MultiplyWithoutOverflow(crop_height, crop_width);
2697 if (crop_area < 0)
2698 return errors::InvalidArgument("Cannot estimate cost, multiplying ",
2699 crop_height, " with ", crop_width,
2700 " would overflow");
2701 int64_t crop_volume = MultiplyWithoutOverflow(crop_area, num_boxes);
2702 if (crop_volume < 0)
2703 return errors::InvalidArgument("Cannot estimate cost, multiplying ",
2704 crop_area, " with ", num_boxes,
2705 " would overflow");
2706 int64_t crop_depth = MultiplyWithoutOverflow(crop_height, num_boxes);
2707 if (crop_depth < 0)
2708 return errors::InvalidArgument("Cannot estimate cost, multiplying ",
2709 crop_height, " with ", num_boxes,
2710 " would overflow");
2711
2712 // Ops for variables height_scale and width_scale.
2713 int64_t ops = (sub_cost * 6 + mul_cost * 2 + div_cost * 2) * num_boxes;
2714 // Ops for variable in_y.
2715 ops += (mul_cost * 2 + sub_cost + add_cost) * crop_depth;
2716 // Ops for variable in_x (same computation across both branches).
2717 ops += (mul_cost * 2 + sub_cost + add_cost) * crop_volume;
2718 // Specify op_cost based on the method.
2719 if (use_bilinear_interp) {
2720 // Ops for variables top_y_index, bottom_y_index, y_lerp.
2721 ops += (floor_cost + ceil_cost + sub_cost) * crop_depth;
2722 // Ops for variables left_x, right_x, x_lerp.
2723 ops += (floor_cost + ceil_cost + sub_cost) * crop_volume;
2724 // Ops for innermost loop across depth.
2725 ops +=
2726 (cast_to_float_cost * 4 + add_cost * 3 + sub_cost * 3 + mul_cost * 3) *
2727 output_elements;
2728 } else /* method == "nearest" */ {
2729 // Ops for variables closest_x_index and closest_y_index.
2730 ops += round_cost * 2 * crop_volume;
2731 // Ops for innermost loop across depth.
2732 ops += cast_to_float_cost * output_elements;
2733 }
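// Illustration (not an exact kernel trace): with num_boxes = 2 and a 3x3 crop
// over depth 4, crop_depth = 6, crop_volume = 18, and output_elements = 72,
// so the per-box, per-row, per-pixel, and per-element terms above are charged
// 2, 6, 18, and 72 times respectively.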
2734 return PredictDefaultNodeCosts(ops, op_context, &found_unknown_shapes,
2735 node_costs);
2736 }
2737
2738 } // end namespace grappler
2739 } // end namespace tensorflow
2740