1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/grappler/costs/op_level_cost_estimator.h"
17
18 #include <unordered_set>
19
20 #include "tensorflow/core/framework/attr_value.pb.h"
21 #include "tensorflow/core/framework/attr_value_util.h"
22 #include "tensorflow/core/framework/tensor.h"
23 #include "tensorflow/core/framework/tensor.pb.h"
24 #include "tensorflow/core/framework/tensor_shape.h"
25 #include "tensorflow/core/framework/tensor_shape.pb.h"
26 #include "tensorflow/core/framework/types.h"
27 #include "tensorflow/core/platform/status_matchers.h"
28 #include "tensorflow/core/platform/test.h"
29 #include "tensorflow/core/protobuf/device_properties.pb.h"
30
31 namespace tensorflow {
32 namespace grappler {
33
34 namespace {
35
36 // TODO(dyoon): Consider to use this Test class for all the test cases, and then
37 // remove friend in the OpLevelCostEstimator class header.
38 class TestOpLevelCostEstimator : public OpLevelCostEstimator {
39 public:
TestOpLevelCostEstimator()40 TestOpLevelCostEstimator() {
41 compute_memory_overlap_ = true;
42 device_info_ = DeviceInfo();
43 }
~TestOpLevelCostEstimator()44 ~TestOpLevelCostEstimator() override {}
45
SetDeviceInfo(const DeviceInfo & device_info)46 void SetDeviceInfo(const DeviceInfo& device_info) {
47 device_info_ = device_info;
48 }
49
SetComputeMemoryOverlap(bool value)50 void SetComputeMemoryOverlap(bool value) { compute_memory_overlap_ = value; }
51
52 protected:
GetDeviceInfo(const DeviceProperties & device) const53 DeviceInfo GetDeviceInfo(const DeviceProperties& device) const override {
54 return device_info_;
55 }
56
57 DeviceInfo device_info_;
58 };
59
ExpectZeroCost(const Costs & cost)60 void ExpectZeroCost(const Costs& cost) {
61 EXPECT_TRUE(cost.inaccurate);
62 EXPECT_EQ(cost.compute_time, Costs::Duration::zero());
63 EXPECT_EQ(cost.execution_time, Costs::Duration::zero());
64 EXPECT_EQ(cost.memory_time, Costs::Duration::zero());
65 }
66
67 // Wrangles the minimum number of proto fields to set up a matrix.
DescribeMatrix(int rows,int columns,OpInfo * op_info)68 void DescribeMatrix(int rows, int columns, OpInfo* op_info) {
69 auto input = op_info->add_inputs();
70 auto shape = input->mutable_shape();
71 auto shape_rows = shape->add_dim();
72 shape_rows->set_size(rows);
73 auto shape_columns = shape->add_dim();
74 shape_columns->set_size(columns);
75 input->set_dtype(DT_FLOAT);
76 }
77
SetCpuDevice(OpInfo * op_info)78 void SetCpuDevice(OpInfo* op_info) {
79 auto device = op_info->mutable_device();
80 device->set_type("CPU");
81 device->set_num_cores(10);
82 device->set_bandwidth(10000000); // 10000000 KB/s = 10 GB/s
83 device->set_frequency(1000); // 1000 Mhz = 1 GHz
84 }
85
86 // Returns an OpInfo for MatMul with the minimum set of fields set up.
DescribeMatMul(int m,int n,int l,int k)87 OpContext DescribeMatMul(int m, int n, int l, int k) {
88 OpContext op_context;
89 SetCpuDevice(&op_context.op_info);
90 op_context.op_info.set_op("MatMul");
91
92 DescribeMatrix(m, l, &op_context.op_info);
93 DescribeMatrix(k, n, &op_context.op_info);
94 return op_context;
95 }
96
97 // Wrangles the minimum number of proto fields to set up an input of
98 // arbitrary rank and type.
DescribeArbitraryRankInput(const std::vector<int> & dims,DataType dtype,OpInfo * op_info)99 void DescribeArbitraryRankInput(const std::vector<int>& dims, DataType dtype,
100 OpInfo* op_info) {
101 auto input = op_info->add_inputs();
102 input->set_dtype(dtype);
103 auto shape = input->mutable_shape();
104 for (auto d : dims) {
105 shape->add_dim()->set_size(d);
106 }
107 }
108
109 // Wrangles the minimum number of proto fields to set up an output of
110 // arbitrary rank and type.
DescribeArbitraryRankOutput(const std::vector<int> & dims,DataType dtype,OpInfo * op_info)111 void DescribeArbitraryRankOutput(const std::vector<int>& dims, DataType dtype,
112 OpInfo* op_info) {
113 auto output = op_info->add_outputs();
114 output->set_dtype(dtype);
115 auto shape = output->mutable_shape();
116 for (auto d : dims) {
117 shape->add_dim()->set_size(d);
118 }
119 }
120
121 // Returns an OpInfo for a SparseTensorDenseMatMul
DescribeSparseTensorDenseMatMul(const int nnz_a,const std::vector<int> & dims_b,const std::vector<int> & dims_out)122 OpContext DescribeSparseTensorDenseMatMul(const int nnz_a,
123 const std::vector<int>& dims_b,
124 const std::vector<int>& dims_out) {
125 OpContext op_context;
126 SetCpuDevice(&op_context.op_info);
127 op_context.op_info.set_op("SparseTensorDenseMatMul");
128
129 DescribeArbitraryRankInput({nnz_a, 2}, DT_INT64, &op_context.op_info);
130 DescribeArbitraryRankInput({nnz_a}, DT_FLOAT, &op_context.op_info);
131 DescribeArbitraryRankInput({2}, DT_INT64, &op_context.op_info);
132 DescribeArbitraryRankInput(dims_b, DT_FLOAT, &op_context.op_info);
133 DescribeArbitraryRankOutput(dims_out, DT_FLOAT, &op_context.op_info);
134 return op_context;
135 }
136
137 // Returns an OpInfo for an XlaEinsum
DescribeXlaEinsum(const std::vector<int> & dims_a,const std::vector<int> & dims_b,const string & equation)138 OpContext DescribeXlaEinsum(const std::vector<int>& dims_a,
139 const std::vector<int>& dims_b,
140 const string& equation) {
141 OpContext op_context;
142 SetCpuDevice(&op_context.op_info);
143 op_context.op_info.set_op("XlaEinsum");
144 AttrValue equation_attribute;
145 equation_attribute.set_s(equation);
146 (*op_context.op_info.mutable_attr())["equation"] = equation_attribute;
147 if (!dims_a.empty())
148 DescribeArbitraryRankInput(dims_a, DT_FLOAT, &op_context.op_info);
149 if (!dims_b.empty())
150 DescribeArbitraryRankInput(dims_b, DT_FLOAT, &op_context.op_info);
151 return op_context;
152 }
153
154 // Returns an OpInfo for an Einsum
DescribeEinsum(const std::vector<int> & dims_a,const std::vector<int> & dims_b,const string & equation)155 OpContext DescribeEinsum(const std::vector<int>& dims_a,
156 const std::vector<int>& dims_b,
157 const string& equation) {
158 OpContext op_context = DescribeXlaEinsum(dims_a, dims_b, equation);
159 op_context.op_info.set_op("Einsum");
160 return op_context;
161 }
162
DescribeDummyTensor(OpInfo::TensorProperties * tensor)163 void DescribeDummyTensor(OpInfo::TensorProperties* tensor) {
164 // Intentionally leave the tensor shape and type information missing.
165 }
166
167 // Wrangles the minimum number of proto fields to set up a 1D Tensor for cost
168 // estimation purposes.
DescribeTensor1D(int dim0,OpInfo::TensorProperties * tensor)169 void DescribeTensor1D(int dim0, OpInfo::TensorProperties* tensor) {
170 auto shape = tensor->mutable_shape();
171 shape->add_dim()->set_size(dim0);
172 tensor->set_dtype(DT_FLOAT);
173 }
174
175 // Wrangles the minimum number of proto fields to set up a 4D Tensor for cost
176 // estimation purposes.
DescribeTensor4D(int dim0,int dim1,int dim2,int dim3,OpInfo::TensorProperties * tensor)177 void DescribeTensor4D(int dim0, int dim1, int dim2, int dim3,
178 OpInfo::TensorProperties* tensor) {
179 auto shape = tensor->mutable_shape();
180 shape->add_dim()->set_size(dim0);
181 shape->add_dim()->set_size(dim1);
182 shape->add_dim()->set_size(dim2);
183 shape->add_dim()->set_size(dim3);
184 tensor->set_dtype(DT_FLOAT);
185 }
186
187 // Wrangles the minimum number of proto fields to set up a 4D Tensor for cost
188 // estimation purposes.
DescribeTensor5D(int dim0,int dim1,int dim2,int dim3,int dim4,OpInfo::TensorProperties * tensor)189 void DescribeTensor5D(int dim0, int dim1, int dim2, int dim3, int dim4,
190 OpInfo::TensorProperties* tensor) {
191 auto shape = tensor->mutable_shape();
192 shape->add_dim()->set_size(dim0);
193 shape->add_dim()->set_size(dim1);
194 shape->add_dim()->set_size(dim2);
195 shape->add_dim()->set_size(dim3);
196 shape->add_dim()->set_size(dim4);
197 tensor->set_dtype(DT_FLOAT);
198 }
199
200 // DescribeConvolution constructs an OpContext for a Conv2D applied to an input
201 // tensor with shape (batch, ix, iy, iz1) and a kernel tensor with shape
202 // (kx, ky, iz2, oz).
DescribeConvolution(int batch,int ix,int iy,int iz1,int iz2,int kx,int ky,int oz)203 OpContext DescribeConvolution(int batch, int ix, int iy, int iz1, int iz2,
204 int kx, int ky, int oz) {
205 OpContext op_context;
206 SetCpuDevice(&op_context.op_info);
207 op_context.op_info.set_op("Conv2D");
208
209 DescribeTensor4D(batch, ix, iy, iz1, op_context.op_info.add_inputs());
210 DescribeTensor4D(kx, ky, iz2, oz, op_context.op_info.add_inputs());
211
212 return op_context;
213 }
214
215 // Describe DepthwiseConvolution constructs an OpContext for a
216 // DepthwiseConv2dNative applied to an input
217 // tensor with shape (batch, ix, iy, iz1) and a kernel tensor with shape
218 // (kx, ky, iz2, cm). cm is channel multiplier
219
DescribeDepthwiseConv2dNative(int batch,int ix,int iy,int iz1,int iz2,int kx,int ky,int cm)220 OpContext DescribeDepthwiseConv2dNative(int batch, int ix, int iy, int iz1,
221 int iz2, int kx, int ky, int cm) {
222 OpContext op_context;
223 SetCpuDevice(&op_context.op_info);
224 op_context.op_info.set_op("DepthwiseConv2dNative");
225
226 DescribeTensor4D(batch, ix, iy, iz1, op_context.op_info.add_inputs());
227 DescribeTensor4D(kx, ky, iz2, cm, op_context.op_info.add_inputs());
228
229 return op_context;
230 }
231
232 // DescribeFusedConv2DBiasActivation constructs an OpContext for a
233 // FusedConv2DBiasActivation applied to a convolution input tensor with shape
234 // (batch, ix, iy, iz1), a kernel tensor with shape (kx, ky, iz2, oz), a
235 // bias tensor with shape (oz), a side input tensor with shape
236 // (batch, ox, oy, oz) if has_side_input is set, and two scaling tensors with
237 // shape (1). If a vectorized channel format is chosen (NCHW_VECT_C, e.g.) we'll
238 // default to 4 (the vector size most often used with this format on NVIDIA
239 // platforms) for the major channel size, and divide the input channel size by
240 // that amount.
241 //
242 // Note that this assumes the NHWC data format.
DescribeFusedConv2DBiasActivation(int batch,int ix,int iy,int iz1,int iz2,int kx,int ky,int ox,int oy,int oz,bool has_side_input,const string & data_format,const string & filter_format)243 OpContext DescribeFusedConv2DBiasActivation(int batch, int ix, int iy, int iz1,
244 int iz2, int kx, int ky, int ox,
245 int oy, int oz, bool has_side_input,
246 const string& data_format,
247 const string& filter_format) {
248 const int kVecWidth = 4;
249 OpContext op_context;
250 SetCpuDevice(&op_context.op_info);
251 op_context.op_info.set_op("FusedConv2DBiasActivation");
252 auto* attr_data_format = op_context.op_info.mutable_attr();
253 SetAttrValue(data_format, &(*attr_data_format)["data_format"]);
254 auto* attr_filter_format = op_context.op_info.mutable_attr();
255 SetAttrValue(filter_format, &(*attr_filter_format)["filter_format"]);
256 if (data_format == "NHWC") {
257 DescribeTensor4D(batch, ix, iy, iz1, op_context.op_info.add_inputs());
258 } else if (data_format == "NCHW") {
259 DescribeTensor4D(batch, iz1, ix, iy, op_context.op_info.add_inputs());
260 } else {
261 // Use the NCHW_VECT_C format.
262 EXPECT_EQ(data_format, "NCHW_VECT_C");
263 EXPECT_EQ(iz1 % kVecWidth, 0);
264 DescribeTensor5D(batch, iz1 / kVecWidth, ix, iy, kVecWidth,
265 op_context.op_info.add_inputs());
266 }
267 if (filter_format == "HWIO") {
268 DescribeTensor4D(kx, ky, iz2, oz, op_context.op_info.add_inputs());
269 } else if (filter_format == "OIHW") {
270 DescribeTensor4D(oz, iz2, kx, ky, op_context.op_info.add_inputs());
271 } else {
272 EXPECT_EQ(filter_format, "OIHW_VECT_I");
273 EXPECT_EQ(iz2 % kVecWidth, 0);
274 // Use the OIHW_VECT_I format.
275 DescribeTensor5D(oz, iz2 / kVecWidth, kx, ky, kVecWidth,
276 op_context.op_info.add_inputs());
277 }
278 DescribeTensor1D(oz, op_context.op_info.add_inputs());
279
280 // Add the side_input, if any.
281 auto side_input = op_context.op_info.add_inputs();
282 if (has_side_input) {
283 if (data_format == "NHWC") {
284 DescribeTensor4D(batch, ox, oy, oz, side_input);
285 } else if (data_format == "NCHW") {
286 DescribeTensor4D(batch, oz, ox, oy, side_input);
287 } else {
288 // Use the NCHW_VECT_C format.
289 EXPECT_EQ(data_format, "NCHW_VECT_C");
290 EXPECT_EQ(oz % kVecWidth, 0);
291 DescribeTensor5D(batch, oz / kVecWidth, ox, oy, kVecWidth, side_input);
292 }
293 }
294
295 // Add the scaling tensors.
296 DescribeTensor1D(1, op_context.op_info.add_inputs());
297 DescribeTensor1D(1, op_context.op_info.add_inputs());
298
299 return op_context;
300 }
301
302 // DescribeUnaryOp constructs an OpContext for the given operation applied to
303 // a 4-tensor with shape (size1, 1, 1, 1).
DescribeUnaryOp(const string & op,int size1)304 OpContext DescribeUnaryOp(const string& op, int size1) {
305 OpContext op_context;
306 SetCpuDevice(&op_context.op_info);
307 op_context.op_info.set_op(op);
308
309 DescribeTensor4D(size1, 1, 1, 1, op_context.op_info.add_inputs());
310 DescribeTensor4D(size1, 1, 1, 1, op_context.op_info.add_outputs());
311
312 return op_context;
313 }
314
315 // DescribeBinaryOp constructs an OpContext for the given operation applied to
316 // a 4-tensor with dimensions (size1, 1, 1, 1) and a 4-tensor with dimensions
317 // (2 * size1, size2, 1, 1).
318 //
319 // The choice of dimension here is arbitrary, and is used strictly to test the
320 // cost model for applying elementwise operations to tensors with unequal
321 // dimension values.
DescribeBinaryOp(const string & op,int size1,int size2)322 OpContext DescribeBinaryOp(const string& op, int size1, int size2) {
323 OpContext op_context;
324 SetCpuDevice(&op_context.op_info);
325 op_context.op_info.set_op(op);
326
327 DescribeTensor4D(size1, 1, 1, 1, op_context.op_info.add_inputs());
328 DescribeTensor4D(2 * size1, size2, 1, 1, op_context.op_info.add_inputs());
329 DescribeTensor4D(2 * size1, size2, 1, 1, op_context.op_info.add_outputs());
330
331 return op_context;
332 }
333
334 // DescribeBiasAdd constructs an OpContext for a BiasAdd applied to a 4-tensor
335 // with dimensions (1, 1, size2, size1) and a bias with dimension (size1),
336 // according to the constraint that the bias must be 1D with size equal to that
337 // of the last dimension of the input value.
DescribeBiasAdd(int size1,int size2)338 OpContext DescribeBiasAdd(int size1, int size2) {
339 OpContext op_context;
340 SetCpuDevice(&op_context.op_info);
341 op_context.op_info.set_op("BiasAdd");
342
343 DescribeTensor4D(1, 1, size2, size1, op_context.op_info.add_inputs());
344 DescribeTensor1D(size1, op_context.op_info.add_inputs());
345 DescribeTensor4D(1, 1, size2, size1, op_context.op_info.add_outputs());
346
347 return op_context;
348 }
349
// Returns the output extent along one spatial dimension for an input extent
// `x`, window size `k` and stride `s`.
//   padding == "SAME": (x + s - 1) / s, i.e. x / s rounded up.
//   anything else    : (x - k + s) / s, which for x >= k equals
//                      (x - k) / s + 1 (the VALID formula).
int GetOutputSize(const int x, const int k, const int s,
                  const string& padding) {
  const int numerator = (padding == "SAME") ? (x + s - 1) : (x - k + s);
  return numerator / s;
}
358
GetPoolingOutputSize(const std::vector<int> & input,const std::vector<int> & ksize,const std::vector<int> & strides,const string & data_format,const string & padding)359 std::vector<int> GetPoolingOutputSize(const std::vector<int>& input,
360 const std::vector<int>& ksize,
361 const std::vector<int>& strides,
362 const string& data_format,
363 const string& padding) {
364 // h, w, and c indices: default with NHWC.
365 int h_index = 1;
366 int w_index = 2;
367 int c_index = 3;
368 if (data_format == "NCHW") {
369 h_index = 2;
370 w_index = 3;
371 c_index = 1;
372 }
373 // Extract parameters.
374 int n = input[0];
375 int h = input[h_index];
376 int w = input[w_index];
377 int c = input[c_index];
378 int sx = strides[h_index];
379 int sy = strides[w_index];
380 int kx = ksize[h_index];
381 int ky = ksize[w_index];
382
383 // Output activation size: default with VALID padding.
384 int ho = GetOutputSize(h, kx, sx, padding);
385 int wo = GetOutputSize(w, ky, sy, padding);
386
387 std::vector<int> output;
388 if (data_format == "NHWC") {
389 output = {n, ho, wo, c};
390 } else {
391 output = {n, c, ho, wo};
392 }
393 return output;
394 }
395
396 // Helper functions for testing GetTensorShapeProtoFromTensorProto().
GetTensorProto(const DataType dtype,const std::vector<int64_t> & shape,const std::vector<int64_t> values,const bool tensor_content,TensorProto * tensor_proto)397 void GetTensorProto(const DataType dtype, const std::vector<int64_t>& shape,
398 const std::vector<int64_t> values,
399 const bool tensor_content, TensorProto* tensor_proto) {
400 tensor_proto->Clear();
401 TensorProto temp_tensor_proto;
402 temp_tensor_proto.set_dtype(dtype);
403 for (const auto& x : shape) {
404 temp_tensor_proto.mutable_tensor_shape()->add_dim()->set_size(x);
405 }
406 for (const auto& x : values) {
407 if (dtype == DT_INT64) {
408 temp_tensor_proto.add_int64_val(x);
409 } else if (dtype == DT_INT32 || dtype == DT_INT16 || dtype == DT_INT8 ||
410 dtype == DT_UINT8) {
411 temp_tensor_proto.add_int_val(x);
412 } else if (dtype == DT_UINT32) {
413 temp_tensor_proto.add_uint32_val(x);
414 } else if (dtype == DT_UINT64) {
415 temp_tensor_proto.add_uint64_val(x);
416 } else {
417 CHECK(false) << "Unsupported dtype: " << dtype;
418 }
419 }
420 Tensor tensor(dtype);
421 CHECK(tensor.FromProto(temp_tensor_proto));
422 if (tensor_content) {
423 tensor.AsProtoTensorContent(tensor_proto);
424 } else {
425 tensor.AsProtoField(tensor_proto);
426 }
427 }
428
DescribePoolingOp(const string & op_name,const std::vector<int> & x,const std::vector<int> & ksize,const std::vector<int> & strides,const string & data_format,const string & padding)429 OpContext DescribePoolingOp(const string& op_name, const std::vector<int>& x,
430 const std::vector<int>& ksize,
431 const std::vector<int>& strides,
432 const string& data_format, const string& padding) {
433 OpContext op_context;
434 auto& op_info = op_context.op_info;
435 SetCpuDevice(&op_info);
436 op_info.set_op(op_name);
437
438 const std::vector<int> y =
439 GetPoolingOutputSize(x, ksize, strides, data_format, padding);
440 if (op_name == "AvgPool" || op_name == "MaxPool") {
441 // input: x, output: y.
442 DescribeTensor4D(x[0], x[1], x[2], x[3], op_info.add_inputs());
443 DescribeTensor4D(y[0], y[1], y[2], y[3], op_info.add_outputs());
444 } else if (op_name == "AvgPoolGrad") {
445 // input: x's shape, y_grad, output: x_grad.
446 DescribeArbitraryRankInput({4}, DT_INT32, &op_info);
447 auto* tensor_proto = op_info.mutable_inputs(0)->mutable_value();
448 GetTensorProto(DT_INT32, {4}, {x[0], x[1], x[2], x[3]},
449 /*tensor_content=*/false, tensor_proto);
450 DescribeTensor4D(y[0], y[1], y[2], y[3], op_info.add_inputs());
451 DescribeTensor4D(x[0], x[1], x[2], x[3], op_info.add_outputs());
452 } else if (op_name == "MaxPoolGrad") {
453 // input: x, y, y_grad, output: x_grad.
454 DescribeTensor4D(x[0], x[1], x[2], x[3], op_info.add_inputs());
455 DescribeTensor4D(y[0], y[1], y[2], y[3], op_info.add_inputs());
456 DescribeTensor4D(y[0], y[1], y[2], y[3], op_info.add_inputs());
457 DescribeTensor4D(x[0], x[1], x[2], x[3], op_info.add_outputs());
458 }
459 auto* attr = op_info.mutable_attr();
460 SetAttrValue(data_format, &(*attr)["data_format"]);
461 SetAttrValue(padding, &(*attr)["padding"]);
462 SetAttrValue(strides, &(*attr)["strides"]);
463 SetAttrValue(ksize, &(*attr)["ksize"]);
464 return op_context;
465 }
466
DescribeFusedBatchNorm(const bool is_training,const bool is_grad,const std::vector<int> & x,const string & data_format)467 OpContext DescribeFusedBatchNorm(const bool is_training, const bool is_grad,
468 const std::vector<int>& x,
469 const string& data_format) {
470 // First, get MaxPool op info with unit stride and unit window.
471 OpContext op_context = DescribePoolingOp("MaxPool", x, {1, 1, 1, 1},
472 {1, 1, 1, 1}, data_format, "SAME");
473 auto& op_info = op_context.op_info;
474 // Override op name.
475 if (is_grad) {
476 op_info.set_op("FusedBatchNormGrad");
477 } else {
478 op_info.set_op("FusedBatchNorm");
479 }
480
481 // Add additional input output tensors.
482 if (is_grad) {
483 DescribeTensor4D(x[0], x[1], x[2], x[3], op_info.add_inputs());
484 }
485 int num_1d_inputs = is_grad ? 3 : 4;
486 for (int i = 0; i < num_1d_inputs; i++) {
487 auto* tensor = op_info.add_inputs();
488 auto* shape = tensor->mutable_shape();
489 shape->add_dim()->set_size(x[3]);
490 tensor->set_dtype(DT_FLOAT);
491 }
492 for (int i = 0; i < 4; i++) {
493 auto* tensor = op_info.add_outputs();
494 auto* shape = tensor->mutable_shape();
495 shape->add_dim()->set_size(x[3]);
496 tensor->set_dtype(DT_FLOAT);
497 }
498
499 // Delete unnecessary attr.
500 auto* attr = op_context.op_info.mutable_attr();
501 attr->erase("ksize");
502 attr->erase("strides");
503 attr->erase("padding");
504
505 // Additional attrs for FusedBatchNorm.
506 SetAttrValue(is_training, &(*attr)["is_training"]);
507
508 return op_context;
509 }
510 } // namespace
511
512 class OpLevelCostEstimatorTest : public ::testing::Test {
513 protected:
514 using BatchMatMulDimensions = OpLevelCostEstimator::BatchMatMulDimensions;
515
PredictCosts(const OpContext & op_context) const516 Costs PredictCosts(const OpContext& op_context) const {
517 return estimator_.PredictCosts(op_context);
518 }
519
CountMatMulOperations(const OpInfo & op_info,bool * found_unknown_shapes) const520 int64_t CountMatMulOperations(const OpInfo& op_info,
521 bool* found_unknown_shapes) const {
522 return estimator_.CountMatMulOperations(op_info, found_unknown_shapes);
523 }
524
CountBatchMatMulOperations(const OpInfo & op_info,bool * found_unknown_shapes) const525 int64_t CountBatchMatMulOperations(const OpInfo& op_info,
526 bool* found_unknown_shapes) const {
527 return estimator_.CountBatchMatMulOperations(op_info, found_unknown_shapes);
528 }
529
CountBatchMatMulOperations(const OpInfo & op_info,BatchMatMulDimensions * batch_mat_mul,bool * found_unknown_shapes) const530 int64_t CountBatchMatMulOperations(const OpInfo& op_info,
531 BatchMatMulDimensions* batch_mat_mul,
532 bool* found_unknown_shapes) const {
533 return estimator_.CountBatchMatMulOperations(op_info, batch_mat_mul,
534 found_unknown_shapes);
535 }
536
SetComputeMemoryOverlap(bool value)537 void SetComputeMemoryOverlap(bool value) {
538 estimator_.compute_memory_overlap_ = value;
539 }
540
ValidateOpDimensionsFromInputs(const int n,const int h,const int w,const int c,const int kx,const int ky,const int sx,const int sy,const string & data_format,const string & padding)541 void ValidateOpDimensionsFromInputs(const int n, const int h, const int w,
542 const int c, const int kx, const int ky,
543 const int sx, const int sy,
544 const string& data_format,
545 const string& padding) {
546 OpContext op_context;
547 int ho;
548 int wo;
549 if (data_format == "NHWC") {
550 op_context = DescribePoolingOp("MaxPool", {n, h, w, c}, {1, kx, ky, 1},
551 {1, sx, sy, 1}, "NHWC", padding);
552 ho = op_context.op_info.outputs(0).shape().dim(1).size();
553 wo = op_context.op_info.outputs(0).shape().dim(2).size();
554 } else {
555 op_context = DescribePoolingOp("MaxPool", {n, c, h, w}, {1, 1, kx, ky},
556 {1, 1, sx, sy}, "NCHW", padding);
557 ho = op_context.op_info.outputs(0).shape().dim(2).size();
558 wo = op_context.op_info.outputs(0).shape().dim(3).size();
559 }
560
561 bool found_unknown_shapes;
562 TF_ASSERT_OK_AND_ASSIGN(
563 auto dims, OpLevelCostEstimator::OpDimensionsFromInputs(
564 op_context.op_info.inputs(0).shape(), op_context.op_info,
565 &found_unknown_shapes));
566 Padding padding_enum;
567 if (padding == "VALID") {
568 padding_enum = Padding::VALID;
569 } else {
570 padding_enum = Padding::SAME;
571 }
572 EXPECT_EQ(n, dims.batch);
573 EXPECT_EQ(h, dims.ix);
574 EXPECT_EQ(w, dims.iy);
575 EXPECT_EQ(c, dims.iz);
576 EXPECT_EQ(kx, dims.kx);
577 EXPECT_EQ(ky, dims.ky);
578 EXPECT_EQ(sx, dims.sx);
579 EXPECT_EQ(sy, dims.sy);
580 EXPECT_EQ(ho, dims.ox);
581 EXPECT_EQ(wo, dims.oy);
582 EXPECT_EQ(c, dims.oz);
583 EXPECT_EQ(padding_enum, dims.padding);
584 }
585
586 StatusOr<OpLevelCostEstimator::ConvolutionDimensions>
CallOpDimensionsFromInputs(const int n,const int h,const int w,const int c,const int kx,const int ky,const int sx,const int sy,const string & data_format,const string & padding)587 CallOpDimensionsFromInputs(const int n, const int h, const int w, const int c,
588 const int kx, const int ky, const int sx,
589 const int sy, const string& data_format,
590 const string& padding) {
591 OpContext op_context;
592
593 const std::vector<int> x = {n, h, w, c};
594 const std::vector<int> ksize = {1, kx, ky, 1};
595 std::vector<int> strides;
596 if (data_format == "NHWC") {
597 strides = {1, sy, sx, 1};
598 } else {
599 strides = {1, 1, sy, sx};
600 }
601
602 auto& op_info = op_context.op_info;
603 SetCpuDevice(&op_info);
604 op_info.set_op("MaxPool");
605
606 DescribeTensor4D(x[0], x[1], x[2], x[3], op_info.add_inputs());
607 auto* attr = op_info.mutable_attr();
608 SetAttrValue(data_format, &(*attr)["data_format"]);
609 SetAttrValue(padding, &(*attr)["padding"]);
610 SetAttrValue(strides, &(*attr)["strides"]);
611 SetAttrValue(ksize, &(*attr)["ksize"]);
612 bool found_unknown_shapes;
613 return OpLevelCostEstimator::OpDimensionsFromInputs(
614 op_context.op_info.inputs(0).shape(), op_context.op_info,
615 &found_unknown_shapes);
616 }
617
618 OpLevelCostEstimator estimator_;
619 };
620
621 class OpLevelBatchMatMulCostEstimatorTest
622 : public OpLevelCostEstimatorTest,
623 public ::testing::WithParamInterface<const char*> {
624 protected:
625 // Returns an OpInfo for a BatchMatMul
DescribeBatchMatMul(const std::vector<int> & dims_a,const std::vector<int> & dims_b)626 OpContext DescribeBatchMatMul(const std::vector<int>& dims_a,
627 const std::vector<int>& dims_b) {
628 OpContext op_context;
629 SetCpuDevice(&op_context.op_info);
630 op_context.op_info.set_op(GetParam());
631
632 DescribeArbitraryRankInput(dims_a, DT_FLOAT, &op_context.op_info);
633 DescribeArbitraryRankInput(dims_b, DT_FLOAT, &op_context.op_info);
634 return op_context;
635 }
636
CountBatchMatMulOperations(const OpInfo & op_info,bool * found_unknown_shapes) const637 int64_t CountBatchMatMulOperations(const OpInfo& op_info,
638 bool* found_unknown_shapes) const {
639 return OpLevelCostEstimatorTest::CountBatchMatMulOperations(
640 op_info, found_unknown_shapes);
641 }
642
CountBatchMatMulDimProduct(const OpInfo & op_info,bool * found_unknown_shapes) const643 int64_t CountBatchMatMulDimProduct(const OpInfo& op_info,
644 bool* found_unknown_shapes) const {
645 BatchMatMulDimensions batch_mat_mul;
646
647 batch_mat_mul.matmul_dims.n = 0;
648 batch_mat_mul.matmul_dims.m = 0;
649 batch_mat_mul.matmul_dims.k = 0;
650
651 OpLevelCostEstimatorTest::CountBatchMatMulOperations(
652 op_info, &batch_mat_mul, found_unknown_shapes);
653 int dimension_product = 1;
654 for (auto dim : batch_mat_mul.batch_dims) dimension_product *= dim;
655
656 dimension_product *= batch_mat_mul.matmul_dims.n;
657 dimension_product *= batch_mat_mul.matmul_dims.m;
658 dimension_product *= batch_mat_mul.matmul_dims.k;
659
660 return dimension_product;
661 }
662 };
663
// Each persistent op is expected to get the minimum cost: zero memory time, a
// compute/execution time of 1, and no memory usage or inaccuracy flags.
TEST_F(OpLevelCostEstimatorTest, TestPersistentOpCosts) {
  OpContext op_context;
  auto& op_info = op_context.op_info;
  SetCpuDevice(&op_info);

  const std::unordered_set<string> persistent_ops = {
      "Const",       "Variable",       "VariableV2", "AutoReloadVariable",
      "VarHandleOp", "ReadVariableOp",
  };
  // Minimum cost for all persistent ops. The same context is reused; only the
  // op name changes between iterations.
  for (const string& op : persistent_ops) {
    op_info.set_op(op);
    const auto cost = PredictCosts(op_context);
    EXPECT_EQ(Costs::Duration(0), cost.memory_time);
    EXPECT_EQ(Costs::Duration(1), cost.compute_time);
    EXPECT_EQ(Costs::Duration(1), cost.execution_time);
    EXPECT_EQ(cost.num_ops_total, 1);
    EXPECT_FALSE(cost.inaccurate);
    EXPECT_EQ(cost.num_ops_with_unknown_shapes, 0);
    EXPECT_EQ(cost.temporary_memory, 0);
    EXPECT_EQ(cost.persistent_memory, 0);
  }
}
685
// All three gather variants are expected to produce the same costs for a
// (10000000, 10) float input, 16 int64 indices and a (16, 10) float output.
TEST_F(OpLevelCostEstimatorTest, TestGatherCosts) {
  const std::vector<std::string> gather_ops = {"Gather", "GatherNd",
                                               "GatherV2"};

  for (const std::string& op : gather_ops) {
    OpContext op_context;
    auto& op_info = op_context.op_info;
    SetCpuDevice(&op_info);
    op_info.set_op(op);

    // Huge first input shouldn't affect Gather execution and memory costs.
    DescribeArbitraryRankInput({10000000, 10}, DT_FLOAT, &op_info);
    DescribeArbitraryRankInput({16}, DT_INT64, &op_info);
    DescribeArbitraryRankOutput({16, 10}, DT_FLOAT, &op_info);

    const auto cost = PredictCosts(op_context);
    EXPECT_EQ(Costs::Duration(130), cost.memory_time);
    EXPECT_EQ(Costs::Duration(16), cost.compute_time);
    EXPECT_EQ(Costs::Duration(146), cost.execution_time);
    EXPECT_EQ(cost.num_ops_total, 1);
    EXPECT_FALSE(cost.inaccurate);
    EXPECT_EQ(cost.num_ops_with_unknown_shapes, 0);
    EXPECT_EQ(cost.temporary_memory, 0);
    EXPECT_EQ(cost.persistent_memory, 0);
  }
}
710
// A Gather described without any output is expected to yield all-zero times
// and be flagged inaccurate.
TEST_F(OpLevelCostEstimatorTest, TestGatherCostsWithoutOutput) {
  OpContext op_context;
  auto& op_info = op_context.op_info;
  SetCpuDevice(&op_info);
  op_info.set_op("Gather");

  // Huge first input shouldn't affect Gather execution and memory costs.
  DescribeArbitraryRankInput({10000000, 10}, DT_FLOAT, &op_info);
  DescribeArbitraryRankInput({16}, DT_INT64, &op_info);
  // Deliberately no output tensor.

  const auto cost = PredictCosts(op_context);
  EXPECT_EQ(Costs::Duration(0), cost.memory_time);
  EXPECT_EQ(Costs::Duration(0), cost.compute_time);
  EXPECT_EQ(Costs::Duration(0), cost.execution_time);
  EXPECT_EQ(1, cost.num_ops_total);
  EXPECT_TRUE(cost.inaccurate);
  EXPECT_EQ(cost.num_ops_with_unknown_shapes, 0);
  EXPECT_EQ(cost.temporary_memory, 0);
  EXPECT_EQ(cost.persistent_memory, 0);
}
730
// Checks the predicted costs for a Slice of a (10000000, 10) float input that
// produces a (10, 10) float output.
TEST_F(OpLevelCostEstimatorTest, TestSliceCosts) {
  OpContext op_context;
  auto& op_info = op_context.op_info;
  SetCpuDevice(&op_info);
  op_info.set_op("Slice");

  // Huge first input shouldn't affect Slice execution and memory costs.
  DescribeArbitraryRankInput({10000000, 10}, DT_FLOAT, &op_info);
  // Two int64 inputs of shape (2).
  for (int i = 0; i < 2; ++i) {
    DescribeArbitraryRankInput({2}, DT_INT64, &op_info);
  }
  DescribeArbitraryRankOutput({10, 10}, DT_FLOAT, &op_info);

  const auto cost = PredictCosts(op_context);
  EXPECT_EQ(Costs::Duration(81), cost.memory_time);
  EXPECT_EQ(Costs::Duration(10), cost.compute_time);
  EXPECT_EQ(Costs::Duration(91), cost.execution_time);
  EXPECT_EQ(cost.num_ops_total, 1);
  EXPECT_FALSE(cost.inaccurate);
  EXPECT_EQ(cost.num_ops_with_unknown_shapes, 0);
  EXPECT_EQ(cost.temporary_memory, 0);
  EXPECT_EQ(cost.persistent_memory, 0);
}
752
// Checks the predicted costs for a StridedSlice of a (10000000, 10) float
// input that produces a (10, 10) float output.
TEST_F(OpLevelCostEstimatorTest, TestStridedSliceCosts) {
  OpContext op_context;
  auto& op_info = op_context.op_info;
  SetCpuDevice(&op_info);
  op_info.set_op("StridedSlice");

  // Huge first input shouldn't affect StridedSlice execution and memory costs.
  DescribeArbitraryRankInput({10000000, 10}, DT_FLOAT, &op_info);
  // Three int64 inputs of shape (2).
  for (int i = 0; i < 3; ++i) {
    DescribeArbitraryRankInput({2}, DT_INT64, &op_info);
  }
  DescribeArbitraryRankOutput({10, 10}, DT_FLOAT, &op_info);

  const auto cost = PredictCosts(op_context);
  EXPECT_EQ(Costs::Duration(81), cost.memory_time);
  EXPECT_EQ(Costs::Duration(10), cost.compute_time);
  EXPECT_EQ(Costs::Duration(91), cost.execution_time);
  EXPECT_EQ(cost.num_ops_total, 1);
  EXPECT_FALSE(cost.inaccurate);
  EXPECT_EQ(cost.num_ops_with_unknown_shapes, 0);
  EXPECT_EQ(cost.temporary_memory, 0);
  EXPECT_EQ(cost.persistent_memory, 0);
}
775
// Runs the same two cost checks for every Scatter* op name listed below:
// once with full-shaped updates and INT64 indices, and once with scalar
// updates and INT32 indices.
TEST_F(OpLevelCostEstimatorTest, TestScatterOps) {
  const std::vector<string> scatter_op_names = {
      "ScatterAdd", "ScatterDiv", "ScatterMax",   "ScatterMin",
      "ScatterMul", "ScatterSub", "ScatterUpdate"};
  for (const auto& scatter_op : scatter_op_names) {
    // Case 1: updates.shape = indices.shape + ref.shape[1:]
    {
      OpContext scatter_context;
      auto* const scatter_info = &scatter_context.op_info;
      SetCpuDevice(scatter_info);
      scatter_info->set_op(scatter_op);
      // The huge leading dimension of the first input (and of the output)
      // must not affect the Scatter execution and memory costs.
      DescribeArbitraryRankInput({10000000, 10}, DT_FLOAT, scatter_info);
      DescribeArbitraryRankInput({16}, DT_INT64, scatter_info);
      DescribeArbitraryRankInput({16, 10}, DT_FLOAT, scatter_info);
      DescribeArbitraryRankOutput({10000000, 10}, DT_FLOAT, scatter_info);

      const auto predicted = estimator_.PredictCosts(scatter_context);
      EXPECT_EQ(predicted.memory_time, Costs::Duration(205));
      EXPECT_EQ(predicted.compute_time, Costs::Duration(16));
      EXPECT_EQ(predicted.execution_time, Costs::Duration(221));
      EXPECT_EQ(predicted.num_ops_total, 1);
      EXPECT_FALSE(predicted.inaccurate);
      EXPECT_EQ(predicted.num_ops_with_unknown_shapes, 0);
      EXPECT_EQ(predicted.temporary_memory, 0);
      EXPECT_EQ(predicted.persistent_memory, 0);
    }

    // Case 2: updates.shape = [] and INT32 indices
    {
      OpContext scatter_context;
      auto* const scatter_info = &scatter_context.op_info;
      SetCpuDevice(scatter_info);
      scatter_info->set_op(scatter_op);
      // The huge leading dimension of the first input (and of the output)
      // must not affect the Scatter execution and memory costs.
      DescribeArbitraryRankInput({10000000, 10}, DT_FLOAT, scatter_info);
      DescribeArbitraryRankInput({16}, DT_INT32, scatter_info);
      DescribeArbitraryRankInput({}, DT_FLOAT, scatter_info);
      DescribeArbitraryRankOutput({10000000, 10}, DT_FLOAT, scatter_info);

      const auto predicted = estimator_.PredictCosts(scatter_context);
      EXPECT_EQ(predicted.memory_time, Costs::Duration(135));
      EXPECT_EQ(predicted.compute_time, Costs::Duration(16));
      EXPECT_EQ(predicted.execution_time, Costs::Duration(151));
      EXPECT_EQ(predicted.num_ops_total, 1);
      EXPECT_FALSE(predicted.inaccurate);
      EXPECT_EQ(predicted.num_ops_with_unknown_shapes, 0);
    }
  }
}
828
// Pins the predicted costs for the op described by DescribeBiasAdd(1000, 10).
TEST_F(OpLevelCostEstimatorTest, BiasAddExecutionTime) {
  const auto bias_add_context = DescribeBiasAdd(1000, 10);
  const auto predicted = PredictCosts(bias_add_context);
  EXPECT_EQ(predicted.memory_time, Costs::Duration(8400));
  EXPECT_EQ(predicted.compute_time, Costs::Duration(1000));
  EXPECT_EQ(predicted.execution_time, Costs::Duration(9400));
  EXPECT_EQ(predicted.num_ops_total, 1);
  EXPECT_FALSE(predicted.inaccurate);
  EXPECT_EQ(predicted.num_ops_with_unknown_shapes, 0);
  EXPECT_EQ(predicted.temporary_memory, 0);
  EXPECT_EQ(predicted.persistent_memory, 0);
}
840
// Pins the predicted costs for the convolution described by
// DescribeConvolution(16, 19, 19, 48, 48, 5, 5, 256).
TEST_F(OpLevelCostEstimatorTest, Conv2DExecutionTime) {
  const auto conv_context = DescribeConvolution(16, 19, 19, 48, 48, 5, 5, 256);
  const auto predicted = PredictCosts(conv_context);
  EXPECT_EQ(predicted.memory_time, Costs::Duration(233780));
  EXPECT_EQ(predicted.compute_time, Costs::Duration(354877440));
  EXPECT_EQ(predicted.execution_time, Costs::Duration(355111220));
  EXPECT_EQ(predicted.num_ops_total, 1);
  EXPECT_FALSE(predicted.inaccurate);
  EXPECT_EQ(predicted.num_ops_with_unknown_shapes, 0);
  EXPECT_EQ(predicted.temporary_memory, 0);
  EXPECT_EQ(predicted.persistent_memory, 0);
}
852
// For every convolution-style op listed below, zeroes out each entry of an
// otherwise valid convolution config in turn and verifies that PredictCosts()
// reports zero durations and flags the estimate as inaccurate with an unknown
// shape.
TEST_F(OpLevelCostEstimatorTest, InvalidConv2DConfig) {
  // Convolution ops.
  const std::vector<std::string> conv_ops = {
      "Conv2D",
      "Conv2DBackpropFilter",
      "Conv2DBackpropInput",
      "DepthwiseConv2dNative",
      "DepthwiseConv2dNativeBackpropFilter",
      "DepthwiseConv2dNativeBackpropInput",
  };
  // A valid Conv2D config.
  const std::vector<int> valid_conv_config = {16, 19, 19, 48, 48, 5, 5, 256};
  for (const auto& op : conv_ops) {
    // Test with setting one value in conv config to zero.
    // PredictCosts() should return zero costs.
    // The index is a size_t so the comparison against size() does not mix
    // signed and unsigned operands.
    for (size_t i = 0; i < valid_conv_config.size(); ++i) {
      std::vector<int> conv_config(valid_conv_config);
      conv_config[i] = 0;
      auto op_context = DescribeConvolution(
          conv_config[0], conv_config[1], conv_config[2], conv_config[3],
          conv_config[4], conv_config[5], conv_config[6], conv_config[7]);
      // DescribeConvolution() output is reused for every op in conv_ops; only
      // the op name is overridden here.
      op_context.op_info.set_op(op);
      auto cost = PredictCosts(op_context);
      EXPECT_EQ(Costs::Duration(0), cost.memory_time);
      EXPECT_EQ(Costs::Duration(0), cost.compute_time);
      EXPECT_EQ(Costs::Duration(0), cost.execution_time);
      EXPECT_EQ(1, cost.num_ops_total);
      EXPECT_TRUE(cost.inaccurate);
      EXPECT_EQ(1, cost.num_ops_with_unknown_shapes);
    }
  }
}
885
// Pins the predicted costs for the op described by
// DescribeDepthwiseConv2dNative(16, 19, 19, 48, 48, 5, 5, 3).
TEST_F(OpLevelCostEstimatorTest, DepthwiseConv2dNativeExecutionTime) {
  const auto depthwise_context =
      DescribeDepthwiseConv2dNative(16, 19, 19, 48, 48, 5, 5, 3);
  const auto predicted = PredictCosts(depthwise_context);
  EXPECT_EQ(predicted.memory_time, Costs::Duration(112340));
  EXPECT_EQ(predicted.compute_time, Costs::Duration(4158720));
  EXPECT_EQ(predicted.execution_time, Costs::Duration(4271060));
  EXPECT_EQ(predicted.num_ops_total, 1);
  EXPECT_FALSE(predicted.inaccurate);
  EXPECT_EQ(predicted.num_ops_with_unknown_shapes, 0);
  EXPECT_EQ(predicted.temporary_memory, 0);
  EXPECT_EQ(predicted.persistent_memory, 0);
}
898
// Pins the predicted costs for a binary op named "Dummy": the estimate is
// expected to have zero compute time and to be flagged as inaccurate.
TEST_F(OpLevelCostEstimatorTest, DummyExecutionTime) {
  const auto dummy_context = DescribeBinaryOp("Dummy", 1000, 1);
  const auto predicted = PredictCosts(dummy_context);
  EXPECT_EQ(predicted.memory_time, Costs::Duration(2000));
  EXPECT_EQ(predicted.compute_time, Costs::Duration(0));
  EXPECT_EQ(predicted.execution_time, Costs::Duration(2000));
  EXPECT_EQ(predicted.num_ops_total, 1);
  EXPECT_TRUE(predicted.inaccurate);
  EXPECT_EQ(predicted.num_ops_with_unknown_shapes, 0);
  EXPECT_EQ(predicted.temporary_memory, 0);
  EXPECT_EQ(predicted.persistent_memory, 0);
}
910
// Pins the predicted costs for the "Dummy" binary op while compute/memory
// overlap is enabled. The overlap flag is switched on for the prediction and
// switched back off before the test returns.
TEST_F(OpLevelCostEstimatorTest, ExecutionTimeSumOrMax) {
  SetComputeMemoryOverlap(true);
  auto cost = PredictCosts(DescribeBinaryOp("Dummy", 1000, 1));
  EXPECT_EQ(Costs::Duration(2000), cost.memory_time);
  EXPECT_EQ(Costs::Duration(0), cost.compute_time);
  // With overlap enabled the execution time equals the larger of the two
  // components asserted above, i.e. max(2000, 0).
  EXPECT_EQ(Costs::Duration(2000), cost.execution_time);
  EXPECT_EQ(cost.num_ops_total, 1);
  EXPECT_TRUE(cost.inaccurate);
  EXPECT_EQ(cost.num_ops_with_unknown_shapes, 0);
  EXPECT_EQ(cost.temporary_memory, 0);
  EXPECT_EQ(cost.persistent_memory, 0);
  SetComputeMemoryOverlap(false);  // Set it back to default.
}
924
// Pins the predicted costs for a fused Conv2D + bias + activation op
// described with the "NCHW" / "HWIO" format strings and no side input.
TEST_F(OpLevelCostEstimatorTest,
       FusedConv2DBiasActivationNCHW_HWIO_NoSideInput) {
  const auto fused_context = DescribeFusedConv2DBiasActivation(
      16, 19, 19, 48, 48, 5, 5, 19, 19, 256, /* has_side_input = */ false,
      "NCHW", "HWIO");
  const auto predicted = PredictCosts(fused_context);
  EXPECT_EQ(predicted.memory_time, Costs::Duration(825345));
  EXPECT_EQ(predicted.compute_time, Costs::Duration(355321037));
  EXPECT_EQ(predicted.execution_time, Costs::Duration(356146382));
  EXPECT_EQ(predicted.num_ops_total, 1);
  EXPECT_FALSE(predicted.inaccurate);
  EXPECT_EQ(predicted.num_ops_with_unknown_shapes, 0);
  EXPECT_EQ(predicted.temporary_memory, 0);
  EXPECT_EQ(predicted.persistent_memory, 0);
}
939
// Pins the predicted costs for a fused Conv2D + bias + activation op
// described with the "NCHW" / "HWIO" format strings and a side input.
TEST_F(OpLevelCostEstimatorTest, FusedConv2DBiasActivationNCHW_HWIO) {
  const auto fused_context = DescribeFusedConv2DBiasActivation(
      16, 19, 19, 48, 48, 5, 5, 19, 19, 256, /* has_side_input = */ true,
      "NCHW", "HWIO");
  const auto predicted = PredictCosts(fused_context);
  EXPECT_EQ(predicted.memory_time, Costs::Duration(1416808));
  EXPECT_EQ(predicted.compute_time, Costs::Duration(355616768));
  EXPECT_EQ(predicted.execution_time, Costs::Duration(357033576));
  EXPECT_EQ(predicted.num_ops_total, 1);
  EXPECT_FALSE(predicted.inaccurate);
  EXPECT_EQ(predicted.num_ops_with_unknown_shapes, 0);
  EXPECT_EQ(predicted.temporary_memory, 0);
  EXPECT_EQ(predicted.persistent_memory, 0);
}
953
// Pins the predicted costs for a fused Conv2D + bias + activation op
// described with the "NCHW" / "OIHW" format strings and a side input.
TEST_F(OpLevelCostEstimatorTest, FusedConv2DBiasActivationNCHW_OIHW) {
  const auto fused_context = DescribeFusedConv2DBiasActivation(
      16, 19, 19, 48, 48, 5, 5, 19, 19, 256, /* has_side_input = */ true,
      "NCHW", "OIHW");
  const auto predicted = PredictCosts(fused_context);
  EXPECT_EQ(predicted.memory_time, Costs::Duration(1416808));
  EXPECT_EQ(predicted.compute_time, Costs::Duration(355616768));
  EXPECT_EQ(predicted.execution_time, Costs::Duration(357033576));
  EXPECT_EQ(predicted.num_ops_total, 1);
  EXPECT_FALSE(predicted.inaccurate);
  EXPECT_EQ(predicted.num_ops_with_unknown_shapes, 0);
  EXPECT_EQ(predicted.temporary_memory, 0);
  EXPECT_EQ(predicted.persistent_memory, 0);
}
967
// Pins the predicted costs for a fused Conv2D + bias + activation op
// described with the "NHWC" / "HWIO" format strings and a side input.
TEST_F(OpLevelCostEstimatorTest, FusedConv2DBiasActivationNHWC_HWIO) {
  const auto fused_context = DescribeFusedConv2DBiasActivation(
      16, 19, 19, 48, 48, 5, 5, 19, 19, 256, /* has_side_input = */ true,
      "NHWC", "HWIO");
  const auto predicted = PredictCosts(fused_context);
  EXPECT_EQ(predicted.memory_time, Costs::Duration(1416808));
  EXPECT_EQ(predicted.compute_time, Costs::Duration(355616768));
  EXPECT_EQ(predicted.execution_time, Costs::Duration(357033576));
  EXPECT_EQ(predicted.num_ops_total, 1);
  EXPECT_FALSE(predicted.inaccurate);
  EXPECT_EQ(predicted.num_ops_with_unknown_shapes, 0);
  EXPECT_EQ(predicted.temporary_memory, 0);
  EXPECT_EQ(predicted.persistent_memory, 0);
}
981
// Pins the predicted costs for a fused Conv2D + bias + activation op
// described with the "NHWC" / "OIHW" format strings and a side input.
TEST_F(OpLevelCostEstimatorTest, FusedConv2DBiasActivationNHWC_OIHW) {
  const auto fused_context = DescribeFusedConv2DBiasActivation(
      16, 19, 19, 48, 48, 5, 5, 19, 19, 256, /* has_side_input = */ true,
      "NHWC", "OIHW");
  const auto predicted = PredictCosts(fused_context);
  EXPECT_EQ(predicted.memory_time, Costs::Duration(1416808));
  EXPECT_EQ(predicted.compute_time, Costs::Duration(355616768));
  EXPECT_EQ(predicted.execution_time, Costs::Duration(357033576));
  EXPECT_EQ(predicted.num_ops_total, 1);
  EXPECT_FALSE(predicted.inaccurate);
  EXPECT_EQ(predicted.num_ops_with_unknown_shapes, 0);
  EXPECT_EQ(predicted.temporary_memory, 0);
  EXPECT_EQ(predicted.persistent_memory, 0);
}
995
// Pins the predicted costs for a fused Conv2D + bias + activation op
// described with the "NCHW_VECT_C" / "OIHW" format strings and a side input.
TEST_F(OpLevelCostEstimatorTest, FusedConv2DBiasActivationNCHW_VECT_C_OIHW) {
  const auto fused_context = DescribeFusedConv2DBiasActivation(
      16, 19, 19, 48, 48, 5, 5, 19, 19, 256, /* has_side_input = */ true,
      "NCHW_VECT_C", "OIHW");
  const auto predicted = PredictCosts(fused_context);
  EXPECT_EQ(predicted.memory_time, Costs::Duration(1416808));
  EXPECT_EQ(predicted.compute_time, Costs::Duration(355616768));
  EXPECT_EQ(predicted.execution_time, Costs::Duration(357033576));
  EXPECT_EQ(predicted.num_ops_total, 1);
  EXPECT_FALSE(predicted.inaccurate);
  EXPECT_EQ(predicted.num_ops_with_unknown_shapes, 0);
  EXPECT_EQ(predicted.temporary_memory, 0);
  EXPECT_EQ(predicted.persistent_memory, 0);
}
1009
// Pins the predicted costs for a fused Conv2D + bias + activation op
// described with the "NCHW" / "OIHW_VECT_I" format strings and a side input.
TEST_F(OpLevelCostEstimatorTest, FusedConv2DBiasActivationNCHW_OIHW_VECT_I) {
  const auto fused_context = DescribeFusedConv2DBiasActivation(
      16, 19, 19, 48, 48, 5, 5, 19, 19, 256, /* has_side_input = */ true,
      "NCHW", "OIHW_VECT_I");
  const auto predicted = PredictCosts(fused_context);
  EXPECT_EQ(predicted.memory_time, Costs::Duration(1416808));
  EXPECT_EQ(predicted.compute_time, Costs::Duration(355616768));
  EXPECT_EQ(predicted.execution_time, Costs::Duration(357033576));
  EXPECT_EQ(predicted.num_ops_total, 1);
  EXPECT_FALSE(predicted.inaccurate);
  EXPECT_EQ(predicted.num_ops_with_unknown_shapes, 0);
  EXPECT_EQ(predicted.temporary_memory, 0);
  EXPECT_EQ(predicted.persistent_memory, 0);
}
1023
// Pins the predicted costs for a fused Conv2D + bias + activation op
// described with the "NCHW_VECT_C" / "OIHW_VECT_I" format strings and a side
// input.
TEST_F(OpLevelCostEstimatorTest,
       FusedConv2DBiasActivationNCHW_VECT_C_OIHW_VECT_I) {
  const auto fused_context = DescribeFusedConv2DBiasActivation(
      16, 19, 19, 48, 48, 5, 5, 19, 19, 256, /* has_side_input = */ true,
      "NCHW_VECT_C", "OIHW_VECT_I");
  const auto predicted = PredictCosts(fused_context);
  EXPECT_EQ(predicted.memory_time, Costs::Duration(1416808));
  EXPECT_EQ(predicted.compute_time, Costs::Duration(355616768));
  EXPECT_EQ(predicted.execution_time, Costs::Duration(357033576));
  EXPECT_EQ(predicted.num_ops_total, 1);
  EXPECT_FALSE(predicted.inaccurate);
  EXPECT_EQ(predicted.num_ops_with_unknown_shapes, 0);
  EXPECT_EQ(predicted.temporary_memory, 0);
  EXPECT_EQ(predicted.persistent_memory, 0);
}
1038
// Pins the predicted costs for the op described by
// DescribeBinaryOp("Mul", 1000, 1).
TEST_F(OpLevelCostEstimatorTest, MulExecutionTime) {
  const auto mul_context = DescribeBinaryOp("Mul", 1000, 1);
  const auto predicted = PredictCosts(mul_context);
  EXPECT_EQ(predicted.memory_time, Costs::Duration(2000));
  EXPECT_EQ(predicted.compute_time, Costs::Duration(200));
  EXPECT_EQ(predicted.execution_time, Costs::Duration(2200));
  EXPECT_EQ(predicted.num_ops_total, 1);
  EXPECT_FALSE(predicted.inaccurate);
  EXPECT_EQ(predicted.num_ops_with_unknown_shapes, 0);
  EXPECT_EQ(predicted.temporary_memory, 0);
  EXPECT_EQ(predicted.persistent_memory, 0);
}
1050
// Pins the predicted costs for the op described by
// DescribeBinaryOp("Mul", 1000, 2).
TEST_F(OpLevelCostEstimatorTest, MulBroadcastExecutionTime) {
  const auto mul_context = DescribeBinaryOp("Mul", 1000, 2);
  const auto predicted = PredictCosts(mul_context);
  EXPECT_EQ(predicted.memory_time, Costs::Duration(3600));
  EXPECT_EQ(predicted.compute_time, Costs::Duration(400));
  EXPECT_EQ(predicted.execution_time, Costs::Duration(4000));
  EXPECT_EQ(predicted.num_ops_total, 1);
  EXPECT_FALSE(predicted.inaccurate);
  EXPECT_EQ(predicted.num_ops_with_unknown_shapes, 0);
  EXPECT_EQ(predicted.temporary_memory, 0);
  EXPECT_EQ(predicted.persistent_memory, 0);
}
1062
// Pins the predicted costs for the op described by
// DescribeBinaryOp("Mod", 1000, 1).
TEST_F(OpLevelCostEstimatorTest, ModExecutionTime) {
  const auto mod_context = DescribeBinaryOp("Mod", 1000, 1);
  const auto predicted = PredictCosts(mod_context);
  EXPECT_EQ(predicted.memory_time, Costs::Duration(2000));
  EXPECT_EQ(predicted.compute_time, Costs::Duration(1600));
  EXPECT_EQ(predicted.execution_time, Costs::Duration(3600));
  EXPECT_EQ(predicted.num_ops_total, 1);
  EXPECT_FALSE(predicted.inaccurate);
  EXPECT_EQ(predicted.num_ops_with_unknown_shapes, 0);
  EXPECT_EQ(predicted.temporary_memory, 0);
  EXPECT_EQ(predicted.persistent_memory, 0);
}
1074
// Pins the predicted costs for the op described by
// DescribeBinaryOp("SquaredDifference", 1000, 2).
TEST_F(OpLevelCostEstimatorTest, SquaredDifferenceExecutionTime) {
  const auto squared_diff_context =
      DescribeBinaryOp("SquaredDifference", 1000, 2);
  const auto predicted = PredictCosts(squared_diff_context);
  EXPECT_EQ(Costs::Duration(3600), predicted.memory_time);
  EXPECT_EQ(Costs::Duration(800), predicted.compute_time);
  EXPECT_EQ(Costs::Duration(4400), predicted.execution_time);
  EXPECT_EQ(1, predicted.num_ops_total);
  EXPECT_FALSE(predicted.inaccurate);
  EXPECT_EQ(0, predicted.num_ops_with_unknown_shapes);
  EXPECT_EQ(0, predicted.temporary_memory);
  EXPECT_EQ(0, predicted.persistent_memory);
}
1086
// Checks PredictCosts() for a table of single-input ops. Each table entry
// pairs an op name with the integer factor that, multiplied by the tensor
// size and divided by the device's gigaops, gives the expected compute time.
TEST_F(OpLevelCostEstimatorTest, UnaryOpExecutionTime) {
  const std::vector<std::pair<std::string, int>> unary_ops = {
      {"All", 1},      {"ArgMax", 1}, {"Cast", 1},   {"Max", 1},
      {"Min", 1},      {"Prod", 1},   {"Relu", 1},   {"Relu6", 1},
      {"Softmax", 40}, {"Sum", 1},    {"TopKV2", 1}};

  const int kTensorSize = 1000;
  // The same memory time is expected for every op in the table.
  const int kExpectedMemoryTime = 800;
  // Iterate by const reference: each entry holds a std::string, so iterating
  // by value would copy it on every pass.
  for (const auto& unary_op : unary_ops) {
    const OpContext op_context = DescribeUnaryOp(unary_op.first, kTensorSize);

    // std::ceil returns a floating-point value; the conversion to int is
    // spelled out rather than left implicit.
    const int expected_compute_time = static_cast<int>(std::ceil(
        unary_op.second * kTensorSize /
        estimator_.GetDeviceInfo(op_context.op_info.device()).gigaops));

    auto cost = PredictCosts(op_context);
    EXPECT_EQ(cost.memory_time, Costs::Duration(kExpectedMemoryTime));
    EXPECT_EQ(cost.compute_time, Costs::Duration(expected_compute_time))
        << unary_op.first;
    EXPECT_EQ(cost.execution_time,
              Costs::Duration(expected_compute_time + kExpectedMemoryTime));
    EXPECT_EQ(cost.num_ops_total, 1);
    EXPECT_EQ(cost.num_ops_with_unknown_shapes, 0);
    EXPECT_FALSE(cost.inaccurate);
    EXPECT_EQ(cost.temporary_memory, 0);
    EXPECT_EQ(cost.persistent_memory, 0);
  }
}
1115
// Checks PredictCosts() for a table of two-input ops. Each table entry pairs
// an op name with the integer factor used below to derive the expected
// compute time from the two tensor sizes and the device's gigaops.
TEST_F(OpLevelCostEstimatorTest, BinaryOpExecutionTime) {
  const std::vector<std::pair<std::string, int>> binary_ops = {
      {"Select", 1},
      {"SelectV2", 1},
      {"SquaredDifference", 2},
      {"Where", 1},
  };

  const int kTensorSize1 = 1000;
  const int kTensorSize2 = 2;
  // The same memory time is expected for every op in the table.
  const int kExpectedMemoryTime = 3600;
  // Iterate by const reference: each entry holds a std::string, so iterating
  // by value would copy it on every pass.
  for (const auto& binary_op : binary_ops) {
    const OpContext op_context =
        DescribeBinaryOp(binary_op.first, kTensorSize1, kTensorSize2);

    // std::ceil returns a floating-point value; the conversion to int is
    // spelled out rather than left implicit.
    const int expected_compute_time = static_cast<int>(std::ceil(
        binary_op.second * kTensorSize1 * kTensorSize2 * 2 /
        estimator_.GetDeviceInfo(op_context.op_info.device()).gigaops));

    auto cost = PredictCosts(op_context);
    EXPECT_EQ(Costs::Duration(kExpectedMemoryTime), cost.memory_time)
        << binary_op.first;
    EXPECT_EQ(Costs::Duration(expected_compute_time), cost.compute_time)
        << binary_op.first;
    EXPECT_EQ(Costs::Duration(expected_compute_time + kExpectedMemoryTime),
              cost.execution_time)
        << binary_op.first;
    EXPECT_EQ(cost.num_ops_total, 1);
    EXPECT_FALSE(cost.inaccurate);
    EXPECT_EQ(cost.num_ops_with_unknown_shapes, 0);
    EXPECT_EQ(cost.temporary_memory, 0);
    EXPECT_EQ(cost.persistent_memory, 0);
  }
}
1150
// Pins the predicted costs of an "Add" op on the CPU device whose two inputs
// have different ranks: a 1-D tensor of size 100 and a 4-D tensor described
// as (1, 10, 1, 1).
TEST_F(OpLevelCostEstimatorTest, BroadcastAddExecutionTime) {
  OpContext add_context;
  auto* const add_info = &add_context.op_info;
  SetCpuDevice(add_info);
  add_info->set_op("Add");

  DescribeTensor1D(100, add_info->add_inputs());
  DescribeTensor4D(1, 10, 1, 1, add_info->add_inputs());

  const auto predicted = PredictCosts(add_context);
  EXPECT_EQ(predicted.memory_time, Costs::Duration(44));
  EXPECT_EQ(predicted.compute_time, Costs::Duration(100));
  EXPECT_EQ(predicted.execution_time, Costs::Duration(144));
  EXPECT_EQ(predicted.num_ops_total, 1);
  EXPECT_FALSE(predicted.inaccurate);
  EXPECT_EQ(predicted.num_ops_with_unknown_shapes, 0);
  EXPECT_EQ(predicted.temporary_memory, 0);
  EXPECT_EQ(predicted.persistent_memory, 0);
}
1169
// Verifies how a -1 dimension in a MatMul or convolution description is
// reflected in the `inaccurate` flag and the unknown-shape op counter.
TEST_F(OpLevelCostEstimatorTest, UnknownOrPartialShape) {
  // MatMul, all dimensions given.
  {
    const auto predicted = PredictCosts(DescribeMatMul(2, 4, 7, 7));
    EXPECT_EQ(predicted.num_ops_total, 1);
    EXPECT_FALSE(predicted.inaccurate);
    EXPECT_EQ(predicted.num_ops_with_unknown_shapes, 0);
  }
  // MatMul, first argument is -1.
  {
    const auto predicted = PredictCosts(DescribeMatMul(-1, 4, 7, 7));
    EXPECT_EQ(predicted.num_ops_total, 1);
    EXPECT_TRUE(predicted.inaccurate);
    EXPECT_EQ(predicted.num_ops_with_unknown_shapes, 1);
  }
  // MatMul, third argument is -1.
  {
    const auto predicted = PredictCosts(DescribeMatMul(2, 4, -1, 7));
    EXPECT_EQ(predicted.num_ops_total, 1);
    EXPECT_TRUE(predicted.inaccurate);
    EXPECT_EQ(predicted.num_ops_with_unknown_shapes, 1);
  }
  // Convolution, all dimensions given.
  {
    const auto predicted =
        PredictCosts(DescribeConvolution(16, 19, 19, 48, 48, 5, 5, 256));
    EXPECT_EQ(predicted.num_ops_total, 1);
    EXPECT_FALSE(predicted.inaccurate);
    EXPECT_EQ(predicted.num_ops_with_unknown_shapes, 0);
  }
  // Convolution, second argument is -1.
  {
    const auto predicted =
        PredictCosts(DescribeConvolution(16, -1, 19, 48, 48, 5, 5, 256));
    EXPECT_EQ(predicted.num_ops_total, 1);
    EXPECT_TRUE(predicted.inaccurate);
    EXPECT_EQ(predicted.num_ops_with_unknown_shapes, 1);
  }
}
1204
// Parameterized checks for batch matmul cost estimation: the accuracy flags
// returned by PredictCosts() for several operand shapes, the operation count
// relative to a plain MatMul, and the dimension product helper.
TEST_P(OpLevelBatchMatMulCostEstimatorTest, TestBatchMatMul) {
  // Both dimension lists empty.
  {
    const auto predicted = PredictCosts(DescribeBatchMatMul({}, {}));
    EXPECT_EQ(predicted.num_ops_total, 1);
    EXPECT_TRUE(predicted.inaccurate);
    EXPECT_EQ(predicted.num_ops_with_unknown_shapes, 1);
  }
  // Second dimension list empty.
  {
    const auto predicted = PredictCosts(DescribeBatchMatMul({2, 4}, {}));
    EXPECT_EQ(predicted.num_ops_total, 1);
    EXPECT_TRUE(predicted.inaccurate);
    EXPECT_EQ(predicted.num_ops_with_unknown_shapes, 1);
  }
  // Two rank-2 operands.
  {
    const auto predicted = PredictCosts(DescribeBatchMatMul({2, 4}, {4, 2}));
    EXPECT_EQ(predicted.num_ops_total, 1);
    EXPECT_FALSE(predicted.inaccurate);
    EXPECT_EQ(predicted.num_ops_with_unknown_shapes, 0);
  }
  // Two rank-3 operands.
  {
    const auto predicted =
        PredictCosts(DescribeBatchMatMul({1, 2, 4}, {1, 4, 2}));
    EXPECT_EQ(predicted.num_ops_total, 1);
    EXPECT_FALSE(predicted.inaccurate);
    EXPECT_EQ(predicted.num_ops_with_unknown_shapes, 0);
  }
  // Operands of different rank.
  {
    const auto predicted =
        PredictCosts(DescribeBatchMatMul({2, 4}, {1, 3, 4, 2}));
    EXPECT_EQ(predicted.num_ops_total, 1);
    EXPECT_FALSE(predicted.inaccurate);
    EXPECT_EQ(predicted.num_ops_with_unknown_shapes, 0);
  }

  // Operation counts: a batch matmul is compared against multiples of the
  // count for the plain MatMul described by DescribeMatMul(2, 2, 4, 4).
  bool plain_inaccurate = false;
  bool batched_inaccurate = false;
  EXPECT_EQ(
      CountBatchMatMulOperations(DescribeBatchMatMul({2, 4}, {4, 2}).op_info,
                                 &batched_inaccurate),
      CountMatMulOperations(DescribeMatMul(2, 2, 4, 4).op_info,
                            &plain_inaccurate));
  EXPECT_EQ(batched_inaccurate, plain_inaccurate);
  // With a -1 dimension in the second operand the two flags must differ.
  EXPECT_EQ(CountBatchMatMulOperations(
                DescribeBatchMatMul({10, 2, 4}, {-1, 10, 4, 2}).op_info,
                &batched_inaccurate),
            10 * CountMatMulOperations(DescribeMatMul(2, 2, 4, 4).op_info,
                                       &plain_inaccurate));
  EXPECT_NE(batched_inaccurate, plain_inaccurate);
  EXPECT_EQ(CountBatchMatMulOperations(
                DescribeBatchMatMul({2, 10, 2, 4}, {-1, 10, 4, 2}).op_info,
                &batched_inaccurate),
            20 * CountMatMulOperations(DescribeMatMul(2, 2, 4, 4).op_info,
                                       &plain_inaccurate));
  EXPECT_NE(batched_inaccurate, plain_inaccurate);

  // Dimension product for operands of different rank.
  int dim_product = CountBatchMatMulDimProduct(
      DescribeBatchMatMul({2, 4}, {1, 3, 4, 2}).op_info, &batched_inaccurate);
  EXPECT_EQ(16, dim_product);
  EXPECT_FALSE(batched_inaccurate);

  // An op that is not named as a batch matmul: the product is 0 and the
  // result is flagged inaccurate.
  OpContext not_batch_matmul = DescribeBatchMatMul({2, 4}, {4, 2});
  not_batch_matmul.op_info.set_op("notBatchMatMul");
  dim_product = CountBatchMatMulDimProduct(not_batch_matmul.op_info,
                                           &batched_inaccurate);

  EXPECT_EQ(0, dim_product);
  EXPECT_TRUE(batched_inaccurate);

  // Both adj_x and adj_y attributes set to true.
  OpContext adjoint_batch = DescribeBatchMatMul({2, 4, 3, 1}, {4, 2});
  auto& adjoint_attrs = *adjoint_batch.op_info.mutable_attr();
  adjoint_attrs["adj_x"].set_b(true);
  adjoint_attrs["adj_y"].set_b(true);

  dim_product =
      CountBatchMatMulDimProduct(adjoint_batch.op_info, &batched_inaccurate);
  EXPECT_EQ(12, dim_product);
}
// Runs the parameterized suite once per listed op name.
INSTANTIATE_TEST_SUITE_P(TestBatchMatMul, OpLevelBatchMatMulCostEstimatorTest,
                         ::testing::Values("BatchMatMul", "BatchMatMulV2"));
1285
// Checks PredictCosts() for ops built by DescribeSparseTensorDenseMatMul():
// first with a -1 in one of the arguments (the estimate must be flagged
// inaccurate), then with fully specified arguments (exact durations pinned).
TEST_F(OpLevelCostEstimatorTest, SparseTensorDenseMatMul) {
  // Cases containing a -1 (unknown) value.
  {
    const auto predicted =
        PredictCosts(DescribeSparseTensorDenseMatMul(-1, {1, 1}, {1, 1}));
    EXPECT_EQ(predicted.num_ops_total, 1);
    EXPECT_TRUE(predicted.inaccurate);
    EXPECT_EQ(predicted.num_ops_with_unknown_shapes, 1);
  }
  {
    const auto predicted =
        PredictCosts(DescribeSparseTensorDenseMatMul(1, {-1, 1}, {1, 1}));
    EXPECT_EQ(predicted.num_ops_total, 1);
    EXPECT_TRUE(predicted.inaccurate);
    EXPECT_EQ(predicted.num_ops_with_unknown_shapes, 1);
  }
  {
    const auto predicted =
        PredictCosts(DescribeSparseTensorDenseMatMul(1, {1, -1}, {1, -1}));
    EXPECT_EQ(predicted.num_ops_total, 1);
    EXPECT_TRUE(predicted.inaccurate);
    EXPECT_EQ(predicted.num_ops_with_unknown_shapes, 1);
  }
  {
    const auto predicted =
        PredictCosts(DescribeSparseTensorDenseMatMul(1, {1, 1}, {-1, 1}));
    EXPECT_EQ(predicted.num_ops_total, 1);
    EXPECT_TRUE(predicted.inaccurate);
    EXPECT_EQ(predicted.num_ops_with_unknown_shapes, 1);
  }
  // Fully specified cases.
  {
    const auto predicted = PredictCosts(
        DescribeSparseTensorDenseMatMul(10, {1000, 100}, {50, 100}));
    EXPECT_EQ(predicted.num_ops_total, 1);
    EXPECT_FALSE(predicted.inaccurate);
    EXPECT_EQ(predicted.num_ops_with_unknown_shapes, 0);
    EXPECT_EQ(predicted.compute_time, Costs::Duration(200));
    EXPECT_EQ(predicted.memory_time, Costs::Duration(2422));
  }
  {
    // Expected to cost the same as the previous case: the cost does not
    // depend on k_dim.
    const auto predicted = PredictCosts(
        DescribeSparseTensorDenseMatMul(10, {100000, 100}, {50, 100}));
    EXPECT_EQ(predicted.num_ops_total, 1);
    EXPECT_FALSE(predicted.inaccurate);
    EXPECT_EQ(predicted.num_ops_with_unknown_shapes, 0);
    EXPECT_EQ(predicted.compute_time, Costs::Duration(200));
    EXPECT_EQ(predicted.memory_time, Costs::Duration(2422));
  }
}
1337
ExpectTensorShape(const std::vector<int64_t> & expected,const TensorShapeProto & tensor_shape_proto)1338 void ExpectTensorShape(const std::vector<int64_t>& expected,
1339 const TensorShapeProto& tensor_shape_proto) {
1340 TensorShape tensor_shape_expected(expected);
1341 TensorShape tensor_shape(tensor_shape_proto);
1342
1343 EXPECT_EQ(tensor_shape_expected, tensor_shape);
1344 }
1345
// Exercises GetTensorShapeProtoFromTensorProto(): three inputs that must be
// rejected, followed by four inputs whose values must come back as the
// resulting shape.
TEST_F(OpLevelCostEstimatorTest, GetTensorShapeProtoFromTensorProto) {
  TensorProto proto;
  TensorShapeProto shape_proto;

  // A proto that only carries a shape with a single dimension of 255 must be
  // rejected. Original note: the dimension is larger than the max value, so
  // converting to the Tensor class should fail.
  proto.mutable_tensor_shape()->add_dim()->set_size(255);
  EXPECT_FALSE(GetTensorShapeProtoFromTensorProto(proto, &shape_proto));

  proto.Clear();
  // A two-dimensional {1, 2} shape must be rejected: only 1D shapes are
  // expected.
  auto* const two_dim_shape = proto.mutable_tensor_shape();
  two_dim_shape->add_dim()->set_size(1);
  two_dim_shape->add_dim()->set_size(2);
  EXPECT_FALSE(GetTensorShapeProtoFromTensorProto(proto, &shape_proto));

  // A DT_FLOAT proto must be rejected: only integer data types are handled.
  GetTensorProto(DT_FLOAT, {}, {}, /*tensor_content=*/false, &proto);
  EXPECT_FALSE(GetTensorShapeProtoFromTensorProto(proto, &shape_proto));

  // Fills `proto` with a {4}-shaped tensor of `dtype` holding `dims`, then
  // checks that the extracted shape equals `dims`.
  const auto expect_shape_round_trip =
      [&](const DataType dtype, const std::vector<int64_t>& dims,
          const bool use_tensor_content) {
        GetTensorProto(dtype, {4}, dims, use_tensor_content, &proto);
        EXPECT_TRUE(GetTensorShapeProtoFromTensorProto(proto, &shape_proto));
        ExpectTensorShape(dims, shape_proto);
      };

  // Values supplied with tensor_content = false.
  expect_shape_round_trip(DT_INT32, {10, 20, 30, 40},
                          /*use_tensor_content=*/false);
  expect_shape_round_trip(DT_INT64, {40, 20, 90, 40},
                          /*use_tensor_content=*/false);

  // Values supplied with tensor_content = true.
  expect_shape_round_trip(DT_INT32, {10, 20, 30, 40},
                          /*use_tensor_content=*/true);
  expect_shape_round_trip(DT_INT64, {40, 20, 90, 40},
                          /*use_tensor_content=*/true);
}
1405
// Runs ValidateOpDimensionsFromInputs() on four dimension configurations for
// every combination of padding mode and data format.
TEST_F(OpLevelCostEstimatorTest, OpDimensionsFromInputs) {
  // Argument order of the validator: n, h, w, c, kx, ky, sx, sy.
  struct PoolDims {
    int n, h, w, c, kx, ky, sx, sy;
  };
  const PoolDims dim_cases[] = {
      {10, 20, 20, 100, 3, 3, 2, 2},
      {10, 20, 20, 100, 1, 1, 3, 3},
      {10, 200, 200, 100, 5, 5, 3, 3},
      {10, 14, 14, 3840, 3, 3, 2, 2},
  };
  const std::vector<string> padding_modes = {"VALID", "SAME"};
  const std::vector<string> data_formats = {"NHWC", "NCHW"};
  for (const auto& padding : padding_modes) {
    for (const auto& data_format : data_formats) {
      for (const PoolDims& d : dim_cases) {
        ValidateOpDimensionsFromInputs(d.n, d.h, d.w, d.c, d.kx, d.ky, d.sx,
                                       d.sy, data_format, padding);
      }
    }
  }
}
1419
// Verifies that CallOpDimensionsFromInputs() reports INVALID_ARGUMENT when
// either stride argument is zero, for every combination of padding mode and
// data format.
TEST_F(OpLevelCostEstimatorTest, OpDimensionsFromInputsError) {
  const std::vector<string> padding_modes = {"VALID", "SAME"};
  const std::vector<string> data_formats = {"NHWC", "NCHW"};
  for (const auto& padding : padding_modes) {
    for (const auto& data_format : data_formats) {
      // Argument order: n, h, w, c, kx, ky, sx, sy, data_format, padding.
      // Zero passed as sx.
      ASSERT_THAT(
          CallOpDimensionsFromInputs(10, 14, 14, 3840, 3, 3, 0, 2,
                                     data_format, padding),
          testing::StatusIs(
              error::INVALID_ARGUMENT,
              "Stride must be > 0 for Height and Width, but got (2, 0)"));
      // Zero passed as sy.
      ASSERT_THAT(
          CallOpDimensionsFromInputs(10, 14, 14, 3840, 3, 3, 2, 0,
                                     data_format, padding),
          testing::StatusIs(
              error::INVALID_ARGUMENT,
              "Stride must be > 0 for Height and Width, but got (0, 2)"));
    }
  }
}
1439
// Pins the predicted costs of "MaxPool" for three window/stride/padding
// combinations on an NHWC input with equal height and width.
TEST_F(OpLevelCostEstimatorTest, PredictMaxPool) {
  // Describes a MaxPool op with a square input, window and stride, and
  // returns the estimator's prediction for it.
  const auto max_pool_costs = [this](const int batch, const int side,
                                     const int channels, const int window,
                                     const int stride,
                                     const string& padding) -> Costs {
    OpContext pool_context = DescribePoolingOp(
        "MaxPool", {batch, side, side, channels}, {1, window, window, 1},
        {1, stride, stride, 1}, "NHWC", padding);
    return estimator_.PredictCosts(pool_context);
  };

  {
    // Typical 3x3 window with 2x2 stride.
    const auto predicted = max_pool_costs(10, 20, 384, 3, 2, "SAME");
    EXPECT_EQ(predicted.execution_time, Costs::Duration(1075200));
    EXPECT_EQ(predicted.compute_time, Costs::Duration(307200));
    EXPECT_EQ(predicted.memory_time, Costs::Duration(768000));
    EXPECT_EQ(predicted.num_ops_total, 1);
    EXPECT_FALSE(predicted.inaccurate);
    EXPECT_EQ(predicted.num_ops_with_unknown_shapes, 0);
    EXPECT_EQ(predicted.temporary_memory, 0);
    EXPECT_EQ(predicted.persistent_memory, 0);
  }
  {
    // 1x1 window with 2x2 stride: used for shortcut in resnet-50.
    const auto predicted = max_pool_costs(10, 20, 384, 1, 2, "SAME");
    EXPECT_EQ(predicted.execution_time, Costs::Duration(499200));
    EXPECT_EQ(predicted.compute_time, Costs::Duration(38400));
    EXPECT_EQ(predicted.memory_time, Costs::Duration(460800));
    EXPECT_EQ(predicted.num_ops_total, 1);
    EXPECT_FALSE(predicted.inaccurate);
    EXPECT_EQ(predicted.num_ops_with_unknown_shapes, 0);
  }
  {
    // 2x2 window with 3x3 stride.
    const auto predicted = max_pool_costs(10, 20, 384, 2, 3, "VALID");
    EXPECT_EQ(predicted.execution_time, Costs::Duration(561792));
    EXPECT_EQ(predicted.compute_time, Costs::Duration(56448));
    EXPECT_EQ(predicted.memory_time, Costs::Duration(505344));
    EXPECT_EQ(predicted.num_ops_total, 1);
    EXPECT_FALSE(predicted.inaccurate);
    EXPECT_EQ(predicted.num_ops_with_unknown_shapes, 0);
  }
}
1482
// Pins the predicted costs of "MaxPoolGrad" for three window/stride/padding
// combinations on an NHWC input with equal height and width.
TEST_F(OpLevelCostEstimatorTest, PredictMaxPoolGrad) {
  // Describes a MaxPoolGrad op with a square input, window and stride, and
  // returns the estimator's prediction for it.
  const auto max_pool_grad_costs = [this](const int batch, const int side,
                                          const int channels, const int window,
                                          const int stride,
                                          const string& padding) -> Costs {
    OpContext pool_context = DescribePoolingOp(
        "MaxPoolGrad", {batch, side, side, channels}, {1, window, window, 1},
        {1, stride, stride, 1}, "NHWC", padding);
    return estimator_.PredictCosts(pool_context);
  };

  {
    // Typical 3x3 window with 2x2 stride.
    const auto predicted = max_pool_grad_costs(10, 20, 384, 3, 2, "SAME");
    EXPECT_EQ(predicted.execution_time, Costs::Duration(1996800));
    EXPECT_EQ(predicted.compute_time, Costs::Duration(614400));
    EXPECT_EQ(predicted.memory_time, Costs::Duration(1382400));
    EXPECT_EQ(predicted.num_ops_total, 1);
    EXPECT_FALSE(predicted.inaccurate);
    EXPECT_EQ(predicted.num_ops_with_unknown_shapes, 0);
    EXPECT_EQ(predicted.temporary_memory, 0);
    EXPECT_EQ(predicted.persistent_memory, 0);
  }
  {
    // 1x1 window with 2x2 stride: used for shortcut in resnet-50.
    const auto predicted = max_pool_grad_costs(10, 20, 384, 1, 2, "SAME");
    EXPECT_EQ(predicted.execution_time, Costs::Duration(1536000));
    EXPECT_EQ(predicted.compute_time, Costs::Duration(153600));
    EXPECT_EQ(predicted.memory_time, Costs::Duration(1382400));
    EXPECT_EQ(predicted.num_ops_total, 1);
    EXPECT_FALSE(predicted.inaccurate);
    EXPECT_EQ(predicted.num_ops_with_unknown_shapes, 0);
  }
  {
    // 2x2 window with 3x3 stride.
    const auto predicted = max_pool_grad_costs(10, 20, 384, 2, 3, "VALID");
    EXPECT_EQ(predicted.execution_time, Costs::Duration(1514112));
    EXPECT_EQ(predicted.compute_time, Costs::Duration(210048));
    EXPECT_EQ(predicted.memory_time, Costs::Duration(1304064));
    EXPECT_EQ(predicted.num_ops_total, 1);
    EXPECT_FALSE(predicted.inaccurate);
    EXPECT_EQ(predicted.num_ops_with_unknown_shapes, 0);
  }
}
1526
// Verifies the predicted costs of AvgPool for several square window / stride /
// padding configurations.
TEST_F(OpLevelCostEstimatorTest, PredictAvgPool) {
  // Describes an "AvgPool" pooling op in NHWC format from
  // {batch, size, size, channels}, a square `window` and a square `stride`,
  // and returns the estimator's prediction for it.
  const auto avg_pool_costs =
      [this](const int batch, const int size, const int channels,
             const int window, const int stride,
             const string& padding) -> Costs {
    return estimator_.PredictCosts(DescribePoolingOp(
        "AvgPool", {batch, size, size, channels}, {1, window, window, 1},
        {1, stride, stride, 1}, "NHWC", padding));
  };

  // Typical 3x3 window with 2x2 stride.
  const Costs typical = avg_pool_costs(10, 20, 384, 3, 2, "SAME");
  EXPECT_EQ(typical.execution_time, Costs::Duration(1113600));
  EXPECT_EQ(typical.compute_time, Costs::Duration(345600));
  EXPECT_EQ(typical.memory_time, Costs::Duration(768000));
  EXPECT_EQ(typical.num_ops_total, 1);
  EXPECT_FALSE(typical.inaccurate);
  EXPECT_EQ(typical.num_ops_with_unknown_shapes, 0);
  EXPECT_EQ(typical.temporary_memory, 0);
  EXPECT_EQ(typical.persistent_memory, 0);

  // 1x1 window with 2x2 stride: used for shortcut in resnet-50.
  const Costs shortcut = avg_pool_costs(10, 20, 384, 1, 2, "SAME");
  EXPECT_EQ(shortcut.execution_time, Costs::Duration(499200));
  EXPECT_EQ(shortcut.compute_time, Costs::Duration(38400));
  EXPECT_EQ(shortcut.memory_time, Costs::Duration(460800));
  EXPECT_EQ(shortcut.num_ops_total, 1);
  EXPECT_FALSE(shortcut.inaccurate);
  EXPECT_EQ(shortcut.num_ops_with_unknown_shapes, 0);

  // 2x2 window with 3x3 stride.
  const Costs sparse = avg_pool_costs(10, 20, 384, 2, 3, "VALID");
  EXPECT_EQ(sparse.execution_time, Costs::Duration(580608));
  EXPECT_EQ(sparse.compute_time, Costs::Duration(75264));
  EXPECT_EQ(sparse.memory_time, Costs::Duration(505344));
  EXPECT_EQ(sparse.num_ops_total, 1);
  EXPECT_FALSE(sparse.inaccurate);
  EXPECT_EQ(sparse.num_ops_with_unknown_shapes, 0);
}
1569
// Verifies the predicted costs of AvgPoolGrad for several square window /
// stride / padding configurations.
TEST_F(OpLevelCostEstimatorTest, PredictAvgPoolGrad) {
  // Describes an "AvgPoolGrad" pooling op in NHWC format from
  // {batch, size, size, channels}, a square `window` and a square `stride`,
  // and returns the estimator's prediction for it.
  const auto avg_pool_grad_costs =
      [this](const int batch, const int size, const int channels,
             const int window, const int stride,
             const string& padding) -> Costs {
    return estimator_.PredictCosts(DescribePoolingOp(
        "AvgPoolGrad", {batch, size, size, channels}, {1, window, window, 1},
        {1, stride, stride, 1}, "NHWC", padding));
  };

  // Typical 3x3 window with 2x2 stride.
  const Costs typical = avg_pool_grad_costs(10, 20, 384, 3, 2, "SAME");
  EXPECT_EQ(typical.execution_time, Costs::Duration(1305602));
  EXPECT_EQ(typical.compute_time, Costs::Duration(537600));
  EXPECT_EQ(typical.memory_time, Costs::Duration(768002));
  EXPECT_EQ(typical.num_ops_total, 1);
  EXPECT_FALSE(typical.inaccurate);
  EXPECT_EQ(typical.num_ops_with_unknown_shapes, 0);
  EXPECT_EQ(typical.temporary_memory, 0);
  EXPECT_EQ(typical.persistent_memory, 0);

  // 1x1 window with 2x2 stride: used for shortcut in resnet-50.
  const Costs shortcut = avg_pool_grad_costs(10, 20, 384, 1, 2, "SAME");
  EXPECT_EQ(shortcut.execution_time, Costs::Duration(960002));
  EXPECT_EQ(shortcut.compute_time, Costs::Duration(192000));
  EXPECT_EQ(shortcut.memory_time, Costs::Duration(768002));
  EXPECT_EQ(shortcut.num_ops_total, 1);
  EXPECT_FALSE(shortcut.inaccurate);
  EXPECT_EQ(shortcut.num_ops_with_unknown_shapes, 0);

  // 2x2 window with 3x3 stride.
  const Costs sparse = avg_pool_grad_costs(10, 20, 384, 2, 3, "VALID");
  EXPECT_EQ(sparse.execution_time, Costs::Duration(862082));
  EXPECT_EQ(sparse.compute_time, Costs::Duration(172416));
  EXPECT_EQ(sparse.memory_time, Costs::Duration(689666));
  EXPECT_EQ(sparse.num_ops_total, 1);
  EXPECT_FALSE(sparse.inaccurate);
  EXPECT_EQ(sparse.num_ops_with_unknown_shapes, 0);
}
1613
// Verifies the predicted costs of the forward fused batch norm op, in both
// training and inference mode, for two channel counts.
TEST_F(OpLevelCostEstimatorTest, PredictFusedBatchNorm) {
  // Describes a non-gradient fused batch norm op in NHWC format from
  // {batch, size, size, channels} and returns the estimator's prediction.
  const auto fused_bn_costs = [this](const int batch, const int size,
                                     const int channels,
                                     const bool is_training) -> Costs {
    return estimator_.PredictCosts(
        DescribeFusedBatchNorm(is_training, /*is_grad=*/false,
                               {batch, size, size, channels}, "NHWC"));
  };

  const Costs training_96 = fused_bn_costs(10, 20, 96, /*is_training=*/true);
  EXPECT_EQ(training_96.execution_time, Costs::Duration(614737));
  EXPECT_EQ(training_96.compute_time, Costs::Duration(153706));
  EXPECT_EQ(training_96.memory_time, Costs::Duration(461031));
  EXPECT_EQ(training_96.num_ops_total, 1);
  EXPECT_FALSE(training_96.inaccurate);
  EXPECT_EQ(training_96.num_ops_with_unknown_shapes, 0);
  EXPECT_EQ(training_96.temporary_memory, 0);
  EXPECT_EQ(training_96.persistent_memory, 0);

  const Costs training_32 = fused_bn_costs(10, 20, 32, /*is_training=*/true);
  EXPECT_EQ(training_32.execution_time, Costs::Duration(204913));
  EXPECT_EQ(training_32.compute_time, Costs::Duration(51236));
  EXPECT_EQ(training_32.memory_time, Costs::Duration(153677));
  EXPECT_EQ(training_32.num_ops_total, 1);
  EXPECT_FALSE(training_32.inaccurate);
  EXPECT_EQ(training_32.num_ops_with_unknown_shapes, 0);

  const Costs inference_96 = fused_bn_costs(10, 20, 96, /*is_training=*/false);
  EXPECT_EQ(inference_96.execution_time, Costs::Duration(384154));
  EXPECT_EQ(inference_96.compute_time, Costs::Duration(76800));
  EXPECT_EQ(inference_96.memory_time, Costs::Duration(307354));
  EXPECT_EQ(inference_96.num_ops_total, 1);
  EXPECT_FALSE(inference_96.inaccurate);
  EXPECT_EQ(inference_96.num_ops_with_unknown_shapes, 0);

  const Costs inference_32 = fused_bn_costs(10, 20, 32, /*is_training=*/false);
  EXPECT_EQ(inference_32.execution_time, Costs::Duration(128052));
  EXPECT_EQ(inference_32.compute_time, Costs::Duration(25600));
  EXPECT_EQ(inference_32.memory_time, Costs::Duration(102452));
  EXPECT_FALSE(inference_32.inaccurate);
  EXPECT_EQ(inference_32.num_ops_total, 1);
  EXPECT_EQ(inference_32.num_ops_with_unknown_shapes, 0);
}
1664
// Verifies the predicted costs of the fused batch norm gradient op for two
// input shapes.
TEST_F(OpLevelCostEstimatorTest, PredictFusedBatchNormGrad) {
  // Describes a fused batch norm gradient op (is_training=false) in NHWC
  // format from {batch, size, size, channels} and returns the prediction.
  const auto fused_bn_grad_costs = [this](const int batch, const int size,
                                          const int channels) -> Costs {
    return estimator_.PredictCosts(DescribeFusedBatchNorm(
        /*is_training=*/false, /*is_grad=*/true,
        {batch, size, size, channels}, "NHWC"));
  };

  const Costs small = fused_bn_grad_costs(10, 20, 96);
  EXPECT_EQ(small.execution_time, Costs::Duration(1037050));
  EXPECT_EQ(small.compute_time, Costs::Duration(422496));
  EXPECT_EQ(small.memory_time, Costs::Duration(614554));
  EXPECT_EQ(small.num_ops_total, 1);
  EXPECT_FALSE(small.inaccurate);
  EXPECT_EQ(small.num_ops_with_unknown_shapes, 0);
  EXPECT_EQ(small.temporary_memory, 0);
  EXPECT_EQ(small.persistent_memory, 0);

  const Costs large = fused_bn_grad_costs(128, 7, 384);
  EXPECT_EQ(large.execution_time, Costs::Duration(6503809));
  EXPECT_EQ(large.compute_time, Costs::Duration(2649677));
  EXPECT_EQ(large.memory_time, Costs::Duration(3854132));
  EXPECT_EQ(large.num_ops_total, 1);
  EXPECT_FALSE(large.inaccurate);
  EXPECT_EQ(large.num_ops_with_unknown_shapes, 0);
}
1695
// Exercises MaybeGetMinimumShape with unknown-rank, empty, lower-rank,
// partially unknown and higher-rank shapes, checking both the returned shape
// and the unknown-shapes flag.
TEST_F(OpLevelCostEstimatorTest, MaybeGetMinimumShape) {
  // Builds a shape proto with unknown_rank explicitly set to false and one
  // dimension per entry of `dims` (possibly none).
  const auto known_rank_shape = [](const std::vector<int>& dims) {
    TensorShapeProto proto;
    proto.set_unknown_rank(false);
    for (const int dim : dims) {
      proto.add_dim()->set_size(dim);
    }
    return proto;
  };

  {
    // Unknown rank: flagged as unknown, every requested dimension is 1.
    TensorShapeProto unknown_rank;
    unknown_rank.set_unknown_rank(true);
    bool found_unknown = false;
    const TensorShapeProto result =
        MaybeGetMinimumShape(unknown_rank, 4, &found_unknown);
    EXPECT_TRUE(found_unknown);
    ExpectTensorShape({1, 1, 1, 1}, result);
  }

  {
    // No dimensions, rank 1 requested.
    const TensorShapeProto empty = known_rank_shape({});
    bool found_unknown = false;
    const TensorShapeProto result =
        MaybeGetMinimumShape(empty, 1, &found_unknown);
    EXPECT_FALSE(found_unknown);
    ExpectTensorShape({1}, result);
  }

  {
    // No dimensions, rank 2 requested.
    const TensorShapeProto empty = known_rank_shape({});
    bool found_unknown = false;
    const TensorShapeProto result =
        MaybeGetMinimumShape(empty, 2, &found_unknown);
    EXPECT_FALSE(found_unknown);
    ExpectTensorShape({1, 1}, result);
  }

  {
    const TensorShapeProto two_dims = known_rank_shape({10, 20});

    // Requested rank matches the shape: returned unchanged.
    bool found_unknown = false;
    const TensorShapeProto same_rank =
        MaybeGetMinimumShape(two_dims, 2, &found_unknown);
    EXPECT_FALSE(found_unknown);
    ExpectTensorShape({10, 20}, same_rank);

    // Requested rank is larger: flagged, trailing dimensions are 1.
    found_unknown = false;
    const TensorShapeProto higher_rank =
        MaybeGetMinimumShape(two_dims, 4, &found_unknown);
    EXPECT_TRUE(found_unknown);
    EXPECT_EQ(4, higher_rank.dim_size());
    ExpectTensorShape({10, 20, 1, 1}, higher_rank);
  }

  {
    // A -1 dimension is flagged and comes back as 1.
    const TensorShapeProto partially_unknown =
        known_rank_shape({10, 20, -1, 20});
    bool found_unknown = false;
    const TensorShapeProto result =
        MaybeGetMinimumShape(partially_unknown, 4, &found_unknown);
    EXPECT_TRUE(found_unknown);
    ExpectTensorShape({10, 20, 1, 20}, result);
  }

  {
    // Requested rank is smaller: flagged, only the leading dimensions remain.
    const TensorShapeProto four_dims = known_rank_shape({10, 20, 30, 20});
    bool found_unknown = false;
    const TensorShapeProto result =
        MaybeGetMinimumShape(four_dims, 2, &found_unknown);
    EXPECT_TRUE(found_unknown);
    ExpectTensorShape({10, 20}, result);
  }
}
1767
// Checks how execution time is derived from compute, memory and intermediate
// memory time for one convolution, with and without compute/memory overlap,
// under three different device configurations.
TEST_F(OpLevelCostEstimatorTest, IntermediateRdWrBandwidth) {
  TestOpLevelCostEstimator estimator;

  // Predicts the same convolution each time, using the estimator's current
  // device info and overlap setting.
  const auto predict_conv = [&]() -> Costs {
    return estimator.PredictCosts(
        DescribeConvolution(16, 19, 19, 48, 48, 5, 5, 256));
  };

  // Compute limited.
  estimator.SetDeviceInfo(DeviceInfo(/*gigaops=*/1,
                                     /*gb_per_sec=*/1));
  estimator.SetComputeMemoryOverlap(true);
  const Costs compute_bound_overlap = predict_conv();
  EXPECT_EQ(Costs::Duration(3548774400), compute_bound_overlap.execution_time);
  EXPECT_EQ(compute_bound_overlap.execution_time,
            compute_bound_overlap.compute_time);

  estimator.SetComputeMemoryOverlap(false);
  const Costs compute_bound_serial = predict_conv();
  EXPECT_EQ(Costs::Duration(3551112192), compute_bound_serial.execution_time);
  EXPECT_EQ(compute_bound_serial.execution_time,
            compute_bound_serial.compute_time +
                compute_bound_serial.memory_time +
                compute_bound_serial.intermediate_memory_time);

  // Memory limited.
  estimator.SetDeviceInfo(DeviceInfo(/*gigaops=*/99999,
                                     /*gb_per_sec=*/1));
  estimator.SetComputeMemoryOverlap(true);
  const Costs memory_bound_overlap = predict_conv();
  EXPECT_EQ(Costs::Duration(2337792), memory_bound_overlap.execution_time);
  EXPECT_EQ(memory_bound_overlap.execution_time,
            memory_bound_overlap.memory_time);

  estimator.SetComputeMemoryOverlap(false);
  const Costs memory_bound_serial = predict_conv();
  EXPECT_EQ(Costs::Duration(2373281), memory_bound_serial.execution_time);
  EXPECT_EQ(memory_bound_serial.execution_time,
            memory_bound_serial.compute_time +
                memory_bound_serial.memory_time +
                memory_bound_serial.intermediate_memory_time);

  // Intermediate memory bandwidth limited.
  estimator.SetDeviceInfo(DeviceInfo(/*gigaops=*/99999,
                                     /*gb_per_sec=*/9999,
                                     /*intermediate_read_gb_per_sec=*/1,
                                     /*intermediate_write_gb_per_sec=*/1));
  estimator.SetComputeMemoryOverlap(true);
  const Costs intermediate_bound_overlap = predict_conv();
  EXPECT_EQ(Costs::Duration(2337792),
            intermediate_bound_overlap.execution_time);
  EXPECT_EQ(intermediate_bound_overlap.execution_time,
            intermediate_bound_overlap.intermediate_memory_time);

  estimator.SetComputeMemoryOverlap(false);
  const Costs intermediate_bound_serial = predict_conv();
  EXPECT_EQ(Costs::Duration(2373515),
            intermediate_bound_serial.execution_time);
  EXPECT_EQ(intermediate_bound_serial.execution_time,
            intermediate_bound_serial.compute_time +
                intermediate_bound_serial.memory_time +
                intermediate_bound_serial.intermediate_memory_time);
}
1821
// Verifies Einsum cost predictions for supported equations (matmul-like
// contractions with batch, M, N and contracting dimensions), for equations
// that are reported as inaccurate (transpose, shape mismatch, ellipsis,
// repeated indices) and for unknown shapes. Each case also checks that the
// XlaEinsum description yields the same execution time as Einsum.
TEST_F(OpLevelCostEstimatorTest, Einsum) {
  {  // Test a simple matrix multiplication.
    const auto cost =
        PredictCosts(DescribeEinsum({100, 50}, {100, 50}, "ik,jk->ij"));
    EXPECT_EQ(cost.execution_time, Costs::Duration(104000));
    EXPECT_EQ(cost.compute_time,
              Costs::Duration(100 * 50 * 100 * 2 / (1000 * 10 * 1e-3)));
    EXPECT_EQ(cost.memory_time, Costs::Duration(4000));
    EXPECT_EQ(cost.num_ops_total, 1);
    EXPECT_FALSE(cost.inaccurate);
    EXPECT_EQ(cost.num_ops_with_unknown_shapes, 0);
    EXPECT_EQ(cost.temporary_memory, 0);
    EXPECT_EQ(cost.persistent_memory, 0);

    // Einsums and XlaEinsums should be estimated similarly.
    const auto einsum_cost =
        PredictCosts(DescribeEinsum({100, 50}, {100, 50}, "ik,jk->ij"));
    const auto xla_cost =
        PredictCosts(DescribeXlaEinsum({100, 50}, {100, 50}, "ik,jk->ij"));
    EXPECT_EQ(einsum_cost.execution_time, xla_cost.execution_time);
  }
  {  // Test a simple batch matrix multiplication.
    const auto cost = PredictCosts(
        DescribeEinsum({25, 100, 50}, {100, 50, 25}, "Bik,jkB->Bij"));
    EXPECT_EQ(cost.execution_time, Costs::Duration(25 * 104000));
    EXPECT_EQ(cost.compute_time,
              Costs::Duration(25 * 100 * 50 * 100 * 2 / (1000 * 10 * 1e-3)));
    EXPECT_EQ(cost.memory_time, Costs::Duration(25 * 4000));
    EXPECT_EQ(cost.num_ops_total, 1);
    EXPECT_FALSE(cost.inaccurate);
    EXPECT_EQ(cost.num_ops_with_unknown_shapes, 0);

    // Einsums and XlaEinsums should be estimated similarly.
    const auto einsum_cost = PredictCosts(
        DescribeEinsum({25, 100, 50}, {100, 50, 25}, "Bik,jkB->Bij"));
    const auto xla_cost = PredictCosts(
        DescribeXlaEinsum({25, 100, 50}, {100, 50, 25}, "Bik,jkB->Bij"));
    EXPECT_EQ(einsum_cost.execution_time, xla_cost.execution_time);
  }
  {  // Test multiple batch dimensions.
    const auto cost = PredictCosts(DescribeEinsum(
        {25, 16, 100, 50}, {16, 100, 50, 25}, "BNik,NjkB->BNij"));
    EXPECT_EQ(cost.execution_time, Costs::Duration(16 * 25 * 104000));
    EXPECT_EQ(
        cost.compute_time,
        Costs::Duration(16 * 25 * 100 * 50 * 100 * 2 / (1000 * 10 * 1e-3)));
    EXPECT_EQ(cost.memory_time, Costs::Duration(16 * 25 * 4000));
    EXPECT_EQ(cost.num_ops_total, 1);
    EXPECT_FALSE(cost.inaccurate);
    EXPECT_EQ(cost.num_ops_with_unknown_shapes, 0);

    // Einsums and XlaEinsums should be estimated similarly.
    const auto einsum_cost = PredictCosts(DescribeEinsum(
        {25, 16, 100, 50}, {16, 100, 50, 25}, "BNik,NjkB->BNij"));
    const auto xla_cost = PredictCosts(DescribeXlaEinsum(
        {25, 16, 100, 50}, {16, 100, 50, 25}, "BNik,NjkB->BNij"));
    EXPECT_EQ(einsum_cost.execution_time, xla_cost.execution_time);
  }
  {  // Test multiple M dimensions.
    const auto cost =
        PredictCosts(DescribeEinsum({25, 100, 50}, {100, 50}, "Aik,jk->Aij"));
    EXPECT_EQ(cost.execution_time, Costs::Duration(2552000));
    EXPECT_EQ(cost.compute_time,
              Costs::Duration(25 * 100 * 50 * 100 * 2 / (1000 * 10 * 1e-3)));
    EXPECT_EQ(cost.memory_time, Costs::Duration(52000));
    EXPECT_EQ(cost.num_ops_total, 1);
    EXPECT_FALSE(cost.inaccurate);
    EXPECT_EQ(cost.num_ops_with_unknown_shapes, 0);

    // Einsums and XlaEinsums should be estimated similarly.
    const auto einsum_cost =
        PredictCosts(DescribeEinsum({25, 100, 50}, {100, 50}, "Aik,jk->Aij"));
    const auto xla_cost = PredictCosts(
        DescribeXlaEinsum({25, 100, 50}, {100, 50}, "Aik,jk->Aij"));
    EXPECT_EQ(einsum_cost.execution_time, xla_cost.execution_time);
  }
  {  // Test multiple N dimensions.
    const auto cost =
        PredictCosts(DescribeEinsum({100, 50}, {25, 100, 50}, "ik,Bjk->ijB"));
    EXPECT_EQ(cost.execution_time, Costs::Duration(2552000));
    EXPECT_EQ(cost.compute_time,
              Costs::Duration(25 * 100 * 50 * 100 * 2 / (1000 * 10 * 1e-3)));
    EXPECT_EQ(cost.memory_time, Costs::Duration(52000));
    EXPECT_EQ(cost.num_ops_total, 1);
    EXPECT_FALSE(cost.inaccurate);
    EXPECT_EQ(cost.num_ops_with_unknown_shapes, 0);

    // Einsums and XlaEinsums should be estimated similarly.
    const auto einsum_cost =
        PredictCosts(DescribeEinsum({100, 50}, {25, 100, 50}, "ik,Bjk->ijB"));
    const auto xla_cost = PredictCosts(
        DescribeXlaEinsum({100, 50}, {25, 100, 50}, "ik,Bjk->ijB"));
    EXPECT_EQ(einsum_cost.execution_time, xla_cost.execution_time);
  }
  {  // Test multiple contracting dimensions.
    const auto cost = PredictCosts(
        DescribeEinsum({100, 50, 25}, {100, 50, 25}, "ikl,jkl->ij"));
    EXPECT_EQ(cost.execution_time, Costs::Duration(2600000));
    EXPECT_EQ(cost.compute_time,
              Costs::Duration(100 * 50 * 25 * 100 * 2 / (1000 * 10 * 1e-3)));
    EXPECT_EQ(cost.memory_time, Costs::Duration(100000));
    EXPECT_EQ(cost.num_ops_total, 1);
    EXPECT_FALSE(cost.inaccurate);
    EXPECT_EQ(cost.num_ops_with_unknown_shapes, 0);

    // Einsums and XlaEinsums should be estimated similarly.
    const auto einsum_cost = PredictCosts(
        DescribeEinsum({100, 50, 25}, {100, 50, 25}, "ikl,jkl->ij"));
    const auto xla_cost = PredictCosts(
        DescribeXlaEinsum({100, 50, 25}, {100, 50, 25}, "ikl,jkl->ij"));
    EXPECT_EQ(einsum_cost.execution_time, xla_cost.execution_time);
  }
  {  // Test a simple matrix transpose.
    const auto cost = PredictCosts(DescribeEinsum({100, 50}, {}, "ij->ji"));
    EXPECT_EQ(cost.execution_time, Costs::Duration(2000));
    EXPECT_EQ(cost.compute_time, Costs::Duration(0));
    EXPECT_EQ(cost.memory_time, Costs::Duration(2000));
    EXPECT_EQ(cost.num_ops_total, 1);
    EXPECT_TRUE(cost.inaccurate);
    EXPECT_EQ(cost.num_ops_with_unknown_shapes, 0);

    // Einsums and XlaEinsums should be estimated similarly.
    const auto einsum_cost =
        PredictCosts(DescribeEinsum({100, 50}, {}, "ij->ji"));
    const auto xla_cost =
        PredictCosts(DescribeXlaEinsum({100, 50}, {}, "ij->ji"));
    EXPECT_EQ(einsum_cost.execution_time, xla_cost.execution_time);
  }
  {  // Test a malformed Einsum equation: Mismatch between shapes and equation.
    // First operand has more dimensions than the equation names.
    const auto lhs_mismatch =
        PredictCosts(DescribeEinsum({100, 50, 25}, {50, 100}, "ik,kl->il"));
    EXPECT_EQ(lhs_mismatch.execution_time, Costs::Duration(52000));
    EXPECT_EQ(lhs_mismatch.compute_time, Costs::Duration(0));
    EXPECT_EQ(lhs_mismatch.memory_time, Costs::Duration(52000));
    EXPECT_EQ(lhs_mismatch.num_ops_total, 1);
    EXPECT_TRUE(lhs_mismatch.inaccurate);
    EXPECT_EQ(lhs_mismatch.num_ops_with_unknown_shapes, 0);

    // Einsums and XlaEinsums should be estimated similarly.
    const auto lhs_einsum_cost =
        PredictCosts(DescribeEinsum({100, 50, 25}, {50, 100}, "ik,kl->il"));
    const auto lhs_xla_cost =
        PredictCosts(DescribeXlaEinsum({100, 50, 25}, {50, 100}, "ik,kl->il"));
    EXPECT_EQ(lhs_einsum_cost.execution_time, lhs_xla_cost.execution_time);

    // Second operand has more dimensions than the equation names.
    const auto rhs_mismatch =
        PredictCosts(DescribeEinsum({100, 50}, {50, 100, 25}, "ik,kl->il"));
    EXPECT_EQ(rhs_mismatch.execution_time, Costs::Duration(52000));
    EXPECT_EQ(rhs_mismatch.compute_time, Costs::Duration(0));
    EXPECT_EQ(rhs_mismatch.memory_time, Costs::Duration(52000));
    EXPECT_EQ(rhs_mismatch.num_ops_total, 1);
    EXPECT_TRUE(rhs_mismatch.inaccurate);
    EXPECT_EQ(rhs_mismatch.num_ops_with_unknown_shapes, 0);

    // Einsums and XlaEinsums should be estimated similarly.
    const auto rhs_einsum_cost =
        PredictCosts(DescribeEinsum({100, 50}, {50, 100, 25}, "ik,kl->il"));
    const auto rhs_xla_cost =
        PredictCosts(DescribeXlaEinsum({100, 50}, {50, 100, 25}, "ik,kl->il"));
    EXPECT_EQ(rhs_einsum_cost.execution_time, rhs_xla_cost.execution_time);
  }
  {  // Test an unsupported Einsum: ellipsis
    const auto cost = PredictCosts(DescribeEinsum(
        {100, 50, 25, 16}, {50, 100, 32, 12}, "ik...,kl...->il..."));
    EXPECT_EQ(cost.execution_time, Costs::Duration(1568000));
    EXPECT_EQ(cost.compute_time, Costs::Duration(0));
    EXPECT_EQ(cost.memory_time, Costs::Duration(1568000));
    EXPECT_EQ(cost.num_ops_total, 1);
    EXPECT_TRUE(cost.inaccurate);
    EXPECT_EQ(cost.num_ops_with_unknown_shapes, 0);

    // Einsums and XlaEinsums should be estimated similarly.
    const auto einsum_cost = PredictCosts(DescribeEinsum(
        {100, 50, 25, 16}, {50, 100, 32, 12}, "ik...,kl...->il..."));
    const auto xla_cost = PredictCosts(DescribeXlaEinsum(
        {100, 50, 25, 16}, {50, 100, 32, 12}, "ik...,kl...->il..."));
    EXPECT_EQ(einsum_cost.execution_time, xla_cost.execution_time);
  }
  {  // Test a malformed/unsupported Einsum: repeated indices
    const auto cost =
        PredictCosts(DescribeEinsum({100, 100, 50}, {50, 100}, "iik,kl->il"));
    EXPECT_EQ(cost.execution_time, Costs::Duration(202000));
    EXPECT_EQ(cost.compute_time, Costs::Duration(0));
    EXPECT_EQ(cost.memory_time, Costs::Duration(202000));
    EXPECT_EQ(cost.num_ops_total, 1);
    EXPECT_TRUE(cost.inaccurate);
    EXPECT_EQ(cost.num_ops_with_unknown_shapes, 0);

    // Einsums and XlaEinsums should be estimated similarly.
    const auto einsum_cost =
        PredictCosts(DescribeEinsum({100, 100, 50}, {50, 100}, "iik,kl->il"));
    const auto xla_cost = PredictCosts(
        DescribeXlaEinsum({100, 100, 50}, {50, 100}, "iik,kl->il"));
    EXPECT_EQ(einsum_cost.execution_time, xla_cost.execution_time);
  }
  {  // Test missing shapes.
    const auto cost =
        PredictCosts(DescribeEinsum({-1, 50}, {100, 50}, "ik,jk->ij"));
    EXPECT_EQ(cost.execution_time, Costs::Duration(3020));
    EXPECT_EQ(cost.compute_time,
              Costs::Duration(1 * 50 * 100 * 2 / (1000 * 10 * 1e-3)));
    EXPECT_EQ(cost.memory_time, Costs::Duration(2020));
    EXPECT_EQ(cost.num_ops_total, 1);
    EXPECT_TRUE(cost.inaccurate);
    EXPECT_EQ(cost.num_ops_with_unknown_shapes, 1);

    // Einsums and XlaEinsums should be estimated similarly.
    const auto einsum_cost =
        PredictCosts(DescribeEinsum({-1, 50}, {100, 50}, "ik,jk->ij"));
    const auto xla_cost =
        PredictCosts(DescribeXlaEinsum({-1, 50}, {100, 50}, "ik,jk->ij"));
    EXPECT_EQ(einsum_cost.execution_time, xla_cost.execution_time);
  }
}
2036
// Verifies cost predictions for the resource variable ops AssignVariableOp
// and AssignSubVariableOp on a device with gigaops=1 and gb_per_sec=1.
TEST_F(OpLevelCostEstimatorTest, PredictResourceVariableOps) {
  TestOpLevelCostEstimator estimator;
  estimator.SetDeviceInfo(DeviceInfo(/*gigaops=*/1, /*gb_per_sec=*/1));

  // Builds an op named `op_name` with two inputs: a dummy tensor followed by
  // a 1-D tensor of 100 elements.
  const auto describe_variable_op = [&](const string& op_name) -> OpContext {
    OpContext context;
    context.op_info.set_op(op_name);
    DescribeDummyTensor(context.op_info.add_inputs());
    DescribeTensor1D(100, context.op_info.add_inputs());
    return context;
  };

  {
    // Plain assignment: no compute time is expected.
    const OpContext assign = describe_variable_op("AssignVariableOp");
    const auto assign_cost = estimator.PredictCosts(assign);
    EXPECT_EQ(assign_cost.memory_time, Costs::Duration(400));
    EXPECT_EQ(assign_cost.compute_time, Costs::Duration(0));
    EXPECT_EQ(assign_cost.execution_time, Costs::Duration(400));
    EXPECT_FALSE(assign_cost.inaccurate);
    EXPECT_EQ(assign_cost.temporary_memory, 0);
    EXPECT_EQ(assign_cost.persistent_memory, 0);
  }

  {
    // Subtract-and-assign: compute time is non-zero, while the expected
    // execution time still equals the memory time.
    const OpContext assign_sub = describe_variable_op("AssignSubVariableOp");
    const auto assign_sub_cost = estimator.PredictCosts(assign_sub);
    EXPECT_EQ(assign_sub_cost.memory_time, Costs::Duration(400));
    EXPECT_EQ(assign_sub_cost.compute_time, Costs::Duration(100));
    EXPECT_EQ(assign_sub_cost.execution_time, Costs::Duration(400));
    EXPECT_FALSE(assign_sub_cost.inaccurate);
  }
}
2067
// Verifies the predicted cost of an AddN op with three identically shaped
// 4-D inputs on the CPU device configured by SetCpuDevice().
TEST_F(OpLevelCostEstimatorTest, AddNExecutionTime) {
  OpContext add_n;
  SetCpuDevice(&add_n.op_info);
  add_n.op_info.set_op("AddN");

  // Three inputs, each described as a 1x10x10x10 tensor.
  constexpr int kNumInputs = 3;
  for (int input = 0; input < kNumInputs; ++input) {
    DescribeTensor4D(1, 10, 10, 10, add_n.op_info.add_inputs());
  }

  const auto cost = PredictCosts(add_n);
  EXPECT_EQ(cost.memory_time, Costs::Duration(1200));
  EXPECT_EQ(cost.compute_time, Costs::Duration(200));
  EXPECT_EQ(cost.execution_time, Costs::Duration(1400));
  EXPECT_EQ(cost.num_ops_total, 1);
  EXPECT_FALSE(cost.inaccurate);
  EXPECT_EQ(cost.num_ops_with_unknown_shapes, 0);
  EXPECT_EQ(cost.temporary_memory, 0);
  EXPECT_EQ(cost.persistent_memory, 0);
}
2087
// Verifies that every op in the identity-like list below is predicted with
// zero memory time and a compute time of 1, and that max_memory equals the
// tensor size times 4.
TEST_F(OpLevelCostEstimatorTest, IdentityOpExecutionTime) {
  const std::vector<std::string> identity_ops = {
      "_Recv",         "_Send",        "BitCast",         "Identity",
      "Enter",         "Exit",         "IdentityN",       "Merge",
      "NextIteration", "Placeholder",  "PreventGradient", "RefIdentity",
      "Reshape",       "StopGradient", "Switch"};

  // Expectations are identical for every op, so they live outside the loop.
  constexpr int kTensorSize = 1000;
  constexpr int kExpectedMemoryTime = 0;
  constexpr int kExpectedComputeTime = 1;

  // Iterate by const reference: the op name is only read, so copying a
  // std::string per iteration is unnecessary.
  for (const auto& identity_op : identity_ops) {
    // Attach the op name to any failure reported in this iteration.
    SCOPED_TRACE(identity_op);
    OpContext op_context = DescribeUnaryOp(identity_op, kTensorSize);

    auto cost = PredictCosts(op_context);
    EXPECT_EQ(Costs::Duration(kExpectedMemoryTime), cost.memory_time);
    EXPECT_EQ(Costs::Duration(kExpectedComputeTime), cost.compute_time);
    EXPECT_EQ(Costs::Duration(kExpectedComputeTime + kExpectedMemoryTime),
              cost.execution_time);
    EXPECT_EQ(cost.max_memory, kTensorSize * 4);
    EXPECT_EQ(cost.num_ops_total, 1);
    EXPECT_FALSE(cost.inaccurate);
    EXPECT_EQ(cost.num_ops_with_unknown_shapes, 0);
    EXPECT_EQ(cost.temporary_memory, 0);
    EXPECT_EQ(cost.persistent_memory, 0);
  }
}
2115
// Verifies that every op in the list below is predicted as a pure memory op:
// zero compute time, a memory time of 800, and max_memory equal to the tensor
// size times 4.
TEST_F(OpLevelCostEstimatorTest, PureMemoryOpExecutionTime) {
  const std::vector<std::string> reshape_ops = {
      "ConcatV2",     "DataFormatVecPermute",
      "DepthToSpace", "ExpandDims",
      "Fill",         "OneHot",
      "Pack",         "Range",
      "SpaceToDepth", "Split",
      "Squeeze",      "Transpose",
      "Tile",         "Unpack"};

  // Expectations are identical for every op, so they live outside the loop.
  constexpr int kTensorSize = 1000;
  constexpr int kExpectedMemoryTime = 800;
  constexpr int kExpectedComputeTime = 0;

  // Iterate by const reference: the op name is only read, so copying a
  // std::string per iteration is unnecessary.
  for (const auto& reshape_op : reshape_ops) {
    // Attach the op name to any failure reported in this iteration.
    SCOPED_TRACE(reshape_op);
    OpContext op_context = DescribeUnaryOp(reshape_op, kTensorSize);

    auto cost = PredictCosts(op_context);
    EXPECT_EQ(Costs::Duration(kExpectedMemoryTime), cost.memory_time);
    EXPECT_EQ(Costs::Duration(kExpectedComputeTime), cost.compute_time);
    EXPECT_EQ(Costs::Duration(kExpectedComputeTime + kExpectedMemoryTime),
              cost.execution_time);
    EXPECT_EQ(cost.max_memory, kTensorSize * 4);
    EXPECT_EQ(cost.num_ops_total, 1);
    EXPECT_FALSE(cost.inaccurate);
    EXPECT_EQ(cost.num_ops_with_unknown_shapes, 0);
    EXPECT_EQ(cost.temporary_memory, 0);
    EXPECT_EQ(cost.persistent_memory, 0);
  }
}
2146
// Verifies ResizeBilinear cost predictions: missing input/output, zero-sized
// output, a regular output with and without half_pixel_centers, and an output
// whose element count does not fit in 32 bits.
TEST_F(OpLevelCostEstimatorTest, ResizeBilinearExecutionTime) {
  const int kImageDim = 255;
  const int kChannelSize = 10;
  const int kComputeLerpCost = 9;

  // Sets (or overwrites) the boolean "half_pixel_centers" attr on `context`.
  const auto set_half_pixel_centers = [](OpContext* context,
                                         const bool value) {
    AttrValue attr;
    attr.set_b(value);
    (*context->op_info.mutable_attr())["half_pixel_centers"] = attr;
  };

  {
    OpContext incomplete_op;
    SetCpuDevice(&incomplete_op.op_info);
    incomplete_op.op_info.set_op("ResizeBilinear");
    DescribeTensor4D(1, kImageDim, kImageDim, kChannelSize,
                     incomplete_op.op_info.add_inputs());
    // Test with no output.
    const auto no_output_cost = PredictCosts(incomplete_op);
    ExpectZeroCost(no_output_cost);
    incomplete_op.op_info.clear_inputs();

    DescribeTensor4D(0, 0, 0, 0, incomplete_op.op_info.add_outputs());
    // Test with no input.
    const auto no_input_cost = PredictCosts(incomplete_op);
    ExpectZeroCost(no_input_cost);
  }
  {
    // Test with size 0 output.
    OpContext empty_output_op;
    SetCpuDevice(&empty_output_op.op_info);
    empty_output_op.op_info.set_op("ResizeBilinear");

    DescribeTensor4D(1, kImageDim, kImageDim, kChannelSize,
                     empty_output_op.op_info.add_inputs());
    const int kInputOnlyMemoryTime = kImageDim * kImageDim * 4;
    DescribeTensor4D(0, 0, 0, 0, empty_output_op.op_info.add_outputs());

    // As the half_pixel_centers attr was not set, cost should be inaccurate
    // with 0 compute time.
    const auto missing_attr_cost = PredictCosts(empty_output_op);
    EXPECT_EQ(missing_attr_cost.compute_time, Costs::Duration(0));
    EXPECT_EQ(missing_attr_cost.memory_time,
              Costs::Duration(kInputOnlyMemoryTime));
    EXPECT_EQ(missing_attr_cost.execution_time,
              Costs::Duration(kInputOnlyMemoryTime));
    EXPECT_TRUE(missing_attr_cost.inaccurate);
    EXPECT_EQ(missing_attr_cost.num_ops_with_unknown_shapes, 0);
    EXPECT_EQ(missing_attr_cost.temporary_memory, 0);
    EXPECT_EQ(missing_attr_cost.persistent_memory, 0);

    set_half_pixel_centers(&empty_output_op, false);
    const auto with_attr_cost = PredictCosts(empty_output_op);
    // Compute time depends only on output size, so compute time is 0.
    EXPECT_EQ(with_attr_cost.compute_time, Costs::Duration(0));
    EXPECT_EQ(with_attr_cost.memory_time,
              Costs::Duration(kInputOnlyMemoryTime));
    EXPECT_EQ(with_attr_cost.execution_time,
              Costs::Duration(kInputOnlyMemoryTime));
    EXPECT_FALSE(with_attr_cost.inaccurate);
    EXPECT_EQ(with_attr_cost.num_ops_with_unknown_shapes, 0);
  }

  // Test with non-zero output size.
  const int kOutputImageDim = 100;
  OpContext op_context;
  SetCpuDevice(&op_context.op_info);
  op_context.op_info.set_op("ResizeBilinear");
  DescribeTensor4D(1, kImageDim, kImageDim, kChannelSize,
                   op_context.op_info.add_inputs());
  DescribeTensor4D(1, kOutputImageDim, kOutputImageDim, kChannelSize,
                   op_context.op_info.add_outputs());
  const int kExpectedMemoryTime =
      (kImageDim * kImageDim + kOutputImageDim * kOutputImageDim) * 4;

  {
    // Cost of calculating weights without using half_pixel_centers.
    set_half_pixel_centers(&op_context, false);
    const int kInterpWeightCost = 10;
    const int weight_ops = kInterpWeightCost * (kOutputImageDim * 2);
    const int lerp_ops =
        kComputeLerpCost * (kOutputImageDim * kOutputImageDim * kChannelSize);
    const int total_ops = weight_ops + lerp_ops;
    const int expected_compute_time = std::ceil(
        total_ops /
        estimator_.GetDeviceInfo(op_context.op_info.device()).gigaops);

    const auto cost = PredictCosts(op_context);
    EXPECT_EQ(cost.compute_time, Costs::Duration(expected_compute_time));
    EXPECT_EQ(cost.memory_time, Costs::Duration(kExpectedMemoryTime));
    EXPECT_EQ(cost.execution_time,
              Costs::Duration(kExpectedMemoryTime + expected_compute_time));
    EXPECT_FALSE(cost.inaccurate);
    EXPECT_EQ(cost.num_ops_with_unknown_shapes, 0);
  }

  {
    // Cost of calculating weights using half_pixel_centers.
    set_half_pixel_centers(&op_context, true);
    const int kInterpWeightCost = 12;
    const int weight_ops = kInterpWeightCost * (kOutputImageDim * 2);
    const int lerp_ops =
        kComputeLerpCost * (kOutputImageDim * kOutputImageDim * kChannelSize);
    const int total_ops = weight_ops + lerp_ops;
    const int expected_compute_time = std::ceil(
        total_ops /
        estimator_.GetDeviceInfo(op_context.op_info.device()).gigaops);

    const auto cost = PredictCosts(op_context);
    EXPECT_EQ(cost.compute_time, Costs::Duration(expected_compute_time));
    EXPECT_EQ(cost.memory_time, Costs::Duration(kExpectedMemoryTime));
    EXPECT_EQ(cost.execution_time,
              Costs::Duration(kExpectedMemoryTime + expected_compute_time));
    EXPECT_FALSE(cost.inaccurate);
    EXPECT_EQ(cost.num_ops_with_unknown_shapes, 0);
  }

  {
    // Cost with very large tensor.
    op_context.op_info.clear_outputs();
    // Number of elements in tensor exceeds 2^32.
    constexpr int64_t kLargeOutputImageDim = 40000;
    DescribeTensor4D(1, kLargeOutputImageDim, kLargeOutputImageDim,
                     kChannelSize, op_context.op_info.add_outputs());
    const int64_t kInterpWeightCost = 12;
    // Using half_pixel_centers.
    set_half_pixel_centers(&op_context, true);

    // All arithmetic below is carried out in 64 bits because one operand of
    // each product is an int64_t.
    const int64_t weight_ops = kInterpWeightCost * (kLargeOutputImageDim * 2);
    const int64_t lerp_ops =
        kComputeLerpCost *
        (kLargeOutputImageDim * kLargeOutputImageDim * kChannelSize);
    const int64_t total_ops = weight_ops + lerp_ops;
    const int64_t expected_compute_time = std::ceil(
        total_ops /
        estimator_.GetDeviceInfo(op_context.op_info.device()).gigaops);

    const int64_t expected_memory_time =
        (kImageDim * kImageDim + kLargeOutputImageDim * kLargeOutputImageDim) *
        4;

    const auto cost = PredictCosts(op_context);
    EXPECT_EQ(cost.compute_time, Costs::Duration(expected_compute_time));
    EXPECT_EQ(cost.memory_time, Costs::Duration(expected_memory_time));
    EXPECT_EQ(cost.execution_time,
              Costs::Duration(expected_memory_time + expected_compute_time));
    EXPECT_FALSE(cost.inaccurate);
    EXPECT_EQ(cost.num_ops_with_unknown_shapes, 0);
  }
}
2295
// Checks the compute, memory and execution time predicted for a CropAndResize
// op on the CPU device, once with method "bilinear" and once with "nearest".
TEST_F(OpLevelCostEstimatorTest, CropAndResizeExecutionTime) {
  const int kImageDim = 255;
  const int kChannelSize = 10;
  const int kOutputImageDim = 100;
  const int kNumBoxes = 10;
  // Sub-products shared by the expected op counts and memory time below.
  const int kBoxRows = kNumBoxes * kOutputImageDim;
  const int kBoxPixels = kBoxRows * kOutputImageDim;
  const int kOutputElements = kBoxPixels * kChannelSize;

  OpContext op_context;
  SetCpuDevice(&op_context.op_info);
  op_context.op_info.set_op("CropAndResize");
  // Inputs: the image tensor followed by the (kNumBoxes x 4) int64 boxes.
  DescribeTensor4D(1, kImageDim, kImageDim, kChannelSize,
                   op_context.op_info.add_inputs());
  DescribeArbitraryRankInput({kNumBoxes, 4}, DT_INT64, &op_context.op_info);
  DescribeTensor4D(kNumBoxes, kOutputImageDim, kOutputImageDim, kChannelSize,
                   op_context.op_info.add_outputs());

  // The values below are times [ns, the default unit of Duration in Costs],
  // not bytes; the memory bandwidth from SetCpuDevice() is 10GB/s.
  // The input and output images both have a kChannelSize dim, which is 10, so
  // that factor cancels against the division by the bandwidth (10) and is
  // omitted from the two image terms.
  const int kInputImageTime = kImageDim * kImageDim * 4;  // image in float.
  const int kBoxesTime = kNumBoxes * 4 * 8 / 10;          // boxes in int64.
  const int kOutputImageTime = kBoxPixels * 4;            // output in float.
  const int kExpectedMemoryTime =
      kInputImageTime + kBoxesTime + kOutputImageTime;

  // Overwrites the "method" attr of the op under test.
  const auto set_method = [&](const char* name) {
    AttrValue method;
    method.set_s(name);
    (*op_context.op_info.mutable_attr())["method"] = method;
  };

  // Converts an op count into the expected compute time by dividing by the
  // gigaops of the op's device and rounding up.
  const auto compute_time_for = [&](int num_ops) -> int {
    return std::ceil(
        num_ops /
        estimator_.GetDeviceInfo(op_context.op_info.device()).gigaops);
  };

  {
    // Cost of CropAndResize with bilinear interpolation.
    set_method("bilinear");
    const int bilinear_ops = 28 * kNumBoxes + 4 * kBoxRows + 4 * kBoxPixels +
                             3 * kBoxRows + 3 * kBoxPixels +
                             13 * kOutputElements;
    const int expected_compute_time = compute_time_for(bilinear_ops);

    const auto cost = PredictCosts(op_context);
    EXPECT_EQ(cost.compute_time, Costs::Duration(expected_compute_time));
    EXPECT_EQ(cost.memory_time, Costs::Duration(kExpectedMemoryTime));
    EXPECT_EQ(cost.execution_time,
              Costs::Duration(kExpectedMemoryTime + expected_compute_time));
    EXPECT_FALSE(cost.inaccurate);
    EXPECT_EQ(cost.num_ops_with_unknown_shapes, 0);
  }

  {
    // Cost of CropAndResize when nearest pixel is taken.
    set_method("nearest");
    const int nearest_ops = 28 * kNumBoxes + 4 * kBoxRows + 4 * kBoxPixels +
                            2 * kBoxPixels + kOutputElements;
    const int expected_compute_time = compute_time_for(nearest_ops);

    const auto cost = PredictCosts(op_context);
    EXPECT_EQ(cost.compute_time, Costs::Duration(expected_compute_time));
    EXPECT_EQ(cost.memory_time, Costs::Duration(kExpectedMemoryTime));
    EXPECT_EQ(cost.execution_time,
              Costs::Duration(kExpectedMemoryTime + expected_compute_time));
    EXPECT_FALSE(cost.inaccurate);
    EXPECT_EQ(cost.num_ops_with_unknown_shapes, 0);
  }
}
2366
2367 } // end namespace grappler
2368 } // end namespace tensorflow
2369