/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/grappler/costs/op_level_cost_estimator.h"

#include <unordered_set>

#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/attr_value_util.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/platform/status_matchers.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/protobuf/device_properties.pb.h"

namespace tensorflow {
namespace grappler {

namespace {

// TODO(dyoon): Consider using this Test class for all the test cases, and then
// remove the friend declaration in the OpLevelCostEstimator class header.
class TestOpLevelCostEstimator : public OpLevelCostEstimator {
 public:
  TestOpLevelCostEstimator() {
    compute_memory_overlap_ = true;
    device_info_ = DeviceInfo();
  }
  ~TestOpLevelCostEstimator() override {}

  void SetDeviceInfo(const DeviceInfo& device_info) {
    device_info_ = device_info;
  }

  void SetComputeMemoryOverlap(bool value) { compute_memory_overlap_ = value; }

 protected:
  DeviceInfo GetDeviceInfo(const DeviceProperties& device) const override {
    return device_info_;
  }

  DeviceInfo device_info_;
};

void ExpectZeroCost(const Costs& cost) {
  EXPECT_TRUE(cost.inaccurate);
  EXPECT_EQ(cost.compute_time, Costs::Duration::zero());
  EXPECT_EQ(cost.execution_time, Costs::Duration::zero());
  EXPECT_EQ(cost.memory_time, Costs::Duration::zero());
}

// Wrangles the minimum number of proto fields to set up a matrix.
void DescribeMatrix(int rows, int columns, OpInfo* op_info) {
  auto input = op_info->add_inputs();
  auto shape = input->mutable_shape();
  auto shape_rows = shape->add_dim();
  shape_rows->set_size(rows);
  auto shape_columns = shape->add_dim();
  shape_columns->set_size(columns);
  input->set_dtype(DT_FLOAT);
}

void SetCpuDevice(OpInfo* op_info) {
  auto device = op_info->mutable_device();
  device->set_type("CPU");
  device->set_num_cores(10);
  device->set_bandwidth(10000000);  // 10000000 KB/s = 10 GB/s
  device->set_frequency(1000);      // 1000 MHz = 1 GHz
}

// Returns an OpContext for MatMul with the minimum set of fields set up.
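// For example, DescribeMatMul(2, 4, 7, 7) describes a 2x7 matrix multiplied
// by a 7x4 matrix.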
OpContext DescribeMatMul(int m, int n, int l, int k) {
  OpContext op_context;
  SetCpuDevice(&op_context.op_info);
  op_context.op_info.set_op("MatMul");

  DescribeMatrix(m, l, &op_context.op_info);
  DescribeMatrix(k, n, &op_context.op_info);
  return op_context;
}

// Wrangles the minimum number of proto fields to set up an input of
// arbitrary rank and type.
void DescribeArbitraryRankInput(const std::vector<int>& dims, DataType dtype,
                                OpInfo* op_info) {
  auto input = op_info->add_inputs();
  input->set_dtype(dtype);
  auto shape = input->mutable_shape();
  for (auto d : dims) {
    shape->add_dim()->set_size(d);
  }
}

// Wrangles the minimum number of proto fields to set up an output of
// arbitrary rank and type.
void DescribeArbitraryRankOutput(const std::vector<int>& dims, DataType dtype,
                                 OpInfo* op_info) {
  auto output = op_info->add_outputs();
  output->set_dtype(dtype);
  auto shape = output->mutable_shape();
  for (auto d : dims) {
    shape->add_dim()->set_size(d);
  }
}

// Returns an OpContext for a SparseTensorDenseMatMul.
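// For example, DescribeSparseTensorDenseMatMul(10, {50, 100}, {20, 100})
// describes a sparse matrix A with 10 nonzero entries (indices of shape
// {10, 2}, values of shape {10}, and a rank-2 dense shape) multiplied by a
// dense 50x100 matrix B, producing a 20x100 output.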
OpContext DescribeSparseTensorDenseMatMul(const int nnz_a,
                                          const std::vector<int>& dims_b,
                                          const std::vector<int>& dims_out) {
  OpContext op_context;
  SetCpuDevice(&op_context.op_info);
  op_context.op_info.set_op("SparseTensorDenseMatMul");

  DescribeArbitraryRankInput({nnz_a, 2}, DT_INT64, &op_context.op_info);
  DescribeArbitraryRankInput({nnz_a}, DT_FLOAT, &op_context.op_info);
  DescribeArbitraryRankInput({2}, DT_INT64, &op_context.op_info);
  DescribeArbitraryRankInput(dims_b, DT_FLOAT, &op_context.op_info);
  DescribeArbitraryRankOutput(dims_out, DT_FLOAT, &op_context.op_info);
  return op_context;
}

// Returns an OpContext for an XlaEinsum.
OpContext DescribeXlaEinsum(const std::vector<int>& dims_a,
                            const std::vector<int>& dims_b,
                            const string& equation) {
  OpContext op_context;
  SetCpuDevice(&op_context.op_info);
  op_context.op_info.set_op("XlaEinsum");
  AttrValue equation_attribute;
  equation_attribute.set_s(equation);
  (*op_context.op_info.mutable_attr())["equation"] = equation_attribute;
  if (!dims_a.empty())
    DescribeArbitraryRankInput(dims_a, DT_FLOAT, &op_context.op_info);
  if (!dims_b.empty())
    DescribeArbitraryRankInput(dims_b, DT_FLOAT, &op_context.op_info);
  return op_context;
}

// Returns an OpContext for an Einsum.
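// For example, DescribeEinsum({20, 30}, {30, 40}, "ik,kj->ij") describes a
// plain matrix multiplication of a 20x30 operand and a 30x40 operand.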
OpContext DescribeEinsum(const std::vector<int>& dims_a,
                         const std::vector<int>& dims_b,
                         const string& equation) {
  OpContext op_context = DescribeXlaEinsum(dims_a, dims_b, equation);
  op_context.op_info.set_op("Einsum");
  return op_context;
}

void DescribeDummyTensor(OpInfo::TensorProperties* tensor) {
  // Intentionally leave the tensor shape and type information missing.
}

// Wrangles the minimum number of proto fields to set up a 1D Tensor for cost
// estimation purposes.
void DescribeTensor1D(int dim0, OpInfo::TensorProperties* tensor) {
  auto shape = tensor->mutable_shape();
  shape->add_dim()->set_size(dim0);
  tensor->set_dtype(DT_FLOAT);
}

// Wrangles the minimum number of proto fields to set up a 4D Tensor for cost
// estimation purposes.
void DescribeTensor4D(int dim0, int dim1, int dim2, int dim3,
                      OpInfo::TensorProperties* tensor) {
  auto shape = tensor->mutable_shape();
  shape->add_dim()->set_size(dim0);
  shape->add_dim()->set_size(dim1);
  shape->add_dim()->set_size(dim2);
  shape->add_dim()->set_size(dim3);
  tensor->set_dtype(DT_FLOAT);
}

// Wrangles the minimum number of proto fields to set up a 5D Tensor for cost
// estimation purposes.
void DescribeTensor5D(int dim0, int dim1, int dim2, int dim3, int dim4,
                      OpInfo::TensorProperties* tensor) {
  auto shape = tensor->mutable_shape();
  shape->add_dim()->set_size(dim0);
  shape->add_dim()->set_size(dim1);
  shape->add_dim()->set_size(dim2);
  shape->add_dim()->set_size(dim3);
  shape->add_dim()->set_size(dim4);
  tensor->set_dtype(DT_FLOAT);
}

// DescribeConvolution constructs an OpContext for a Conv2D applied to an input
// tensor with shape (batch, ix, iy, iz1) and a kernel tensor with shape
// (kx, ky, iz2, oz).
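// For example, DescribeConvolution(16, 19, 19, 48, 48, 5, 5, 256) describes a
// 16x19x19x48 input convolved with a 5x5x48x256 filter.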
OpContext DescribeConvolution(int batch, int ix, int iy, int iz1, int iz2,
                              int kx, int ky, int oz) {
  OpContext op_context;
  SetCpuDevice(&op_context.op_info);
  op_context.op_info.set_op("Conv2D");

  DescribeTensor4D(batch, ix, iy, iz1, op_context.op_info.add_inputs());
  DescribeTensor4D(kx, ky, iz2, oz, op_context.op_info.add_inputs());

  return op_context;
}

// DescribeDepthwiseConv2dNative constructs an OpContext for a
// DescribeDepthwiseConv2dNative applied to an input tensor with shape
// (batch, ix, iy, iz1) and a kernel tensor with shape (kx, ky, iz2, cm),
// where cm is the channel multiplier.
OpContext DescribeDepthwiseConv2dNative(int batch, int ix, int iy, int iz1,
                                        int iz2, int kx, int ky, int cm) {
  OpContext op_context;
  SetCpuDevice(&op_context.op_info);
  op_context.op_info.set_op("DepthwiseConv2dNative");

  DescribeTensor4D(batch, ix, iy, iz1, op_context.op_info.add_inputs());
  DescribeTensor4D(kx, ky, iz2, cm, op_context.op_info.add_inputs());

  return op_context;
}

// DescribeFusedConv2DBiasActivation constructs an OpContext for a
// FusedConv2DBiasActivation applied to a convolution input tensor with shape
// (batch, ix, iy, iz1), a kernel tensor with shape (kx, ky, iz2, oz), a
// bias tensor with shape (oz), a side input tensor with shape
// (batch, ox, oy, oz) if has_side_input is set, and two scaling tensors with
// shape (1). If a vectorized channel format is chosen (NCHW_VECT_C, e.g.) we'll
// default to 4 (the vector size most often used with this format on NVIDIA
// platforms) for the major channel size, and divide the input channel size by
// that amount.
//
// Note that this assumes the NHWC data format.
OpContext DescribeFusedConv2DBiasActivation(int batch, int ix, int iy, int iz1,
                                            int iz2, int kx, int ky, int ox,
                                            int oy, int oz, bool has_side_input,
                                            const string& data_format,
                                            const string& filter_format) {
  const int kVecWidth = 4;
  OpContext op_context;
  SetCpuDevice(&op_context.op_info);
  op_context.op_info.set_op("FusedConv2DBiasActivation");
  auto* attr_data_format = op_context.op_info.mutable_attr();
  SetAttrValue(data_format, &(*attr_data_format)["data_format"]);
  auto* attr_filter_format = op_context.op_info.mutable_attr();
  SetAttrValue(filter_format, &(*attr_filter_format)["filter_format"]);
  if (data_format == "NHWC") {
    DescribeTensor4D(batch, ix, iy, iz1, op_context.op_info.add_inputs());
  } else if (data_format == "NCHW") {
    DescribeTensor4D(batch, iz1, ix, iy, op_context.op_info.add_inputs());
  } else {
    // Use the NCHW_VECT_C format.
    EXPECT_EQ(data_format, "NCHW_VECT_C");
    EXPECT_EQ(iz1 % kVecWidth, 0);
    DescribeTensor5D(batch, iz1 / kVecWidth, ix, iy, kVecWidth,
                     op_context.op_info.add_inputs());
  }
  if (filter_format == "HWIO") {
    DescribeTensor4D(kx, ky, iz2, oz, op_context.op_info.add_inputs());
  } else if (filter_format == "OIHW") {
    DescribeTensor4D(oz, iz2, kx, ky, op_context.op_info.add_inputs());
  } else {
    EXPECT_EQ(filter_format, "OIHW_VECT_I");
    EXPECT_EQ(iz2 % kVecWidth, 0);
    // Use the OIHW_VECT_I format.
    DescribeTensor5D(oz, iz2 / kVecWidth, kx, ky, kVecWidth,
                     op_context.op_info.add_inputs());
  }
  DescribeTensor1D(oz, op_context.op_info.add_inputs());

  // Add the side_input, if any.
  auto side_input = op_context.op_info.add_inputs();
  if (has_side_input) {
    if (data_format == "NHWC") {
      DescribeTensor4D(batch, ox, oy, oz, side_input);
    } else if (data_format == "NCHW") {
      DescribeTensor4D(batch, oz, ox, oy, side_input);
    } else {
      // Use the NCHW_VECT_C format.
      EXPECT_EQ(data_format, "NCHW_VECT_C");
      EXPECT_EQ(oz % kVecWidth, 0);
      DescribeTensor5D(batch, oz / kVecWidth, ox, oy, kVecWidth, side_input);
    }
  }

  // Add the scaling tensors.
  DescribeTensor1D(1, op_context.op_info.add_inputs());
  DescribeTensor1D(1, op_context.op_info.add_inputs());

  return op_context;
}

// DescribeUnaryOp constructs an OpContext for the given operation applied to
// a 4-tensor with shape (size1, 1, 1, 1).
OpContext DescribeUnaryOp(const string& op, int size1) {
  OpContext op_context;
  SetCpuDevice(&op_context.op_info);
  op_context.op_info.set_op(op);

  DescribeTensor4D(size1, 1, 1, 1, op_context.op_info.add_inputs());
  DescribeTensor4D(size1, 1, 1, 1, op_context.op_info.add_outputs());

  return op_context;
}

// DescribeBinaryOp constructs an OpContext for the given operation applied to
// a 4-tensor with dimensions (size1, 1, 1, 1) and a 4-tensor with dimensions
// (2 * size1, size2, 1, 1).
//
// The choice of dimension here is arbitrary, and is used strictly to test the
// cost model for applying elementwise operations to tensors with unequal
// dimension values.
OpContext DescribeBinaryOp(const string& op, int size1, int size2) {
  OpContext op_context;
  SetCpuDevice(&op_context.op_info);
  op_context.op_info.set_op(op);

  DescribeTensor4D(size1, 1, 1, 1, op_context.op_info.add_inputs());
  DescribeTensor4D(2 * size1, size2, 1, 1, op_context.op_info.add_inputs());
  DescribeTensor4D(2 * size1, size2, 1, 1, op_context.op_info.add_outputs());

  return op_context;
}

// DescribeBiasAdd constructs an OpContext for a BiasAdd applied to a 4-tensor
// with dimensions (1, 1, size2, size1) and a bias with dimension (size1),
// according to the constraint that the bias must be 1D with size equal to that
// of the last dimension of the input value.
OpContext DescribeBiasAdd(int size1, int size2) {
  OpContext op_context;
  SetCpuDevice(&op_context.op_info);
  op_context.op_info.set_op("BiasAdd");

  DescribeTensor4D(1, 1, size2, size1, op_context.op_info.add_inputs());
  DescribeTensor1D(size1, op_context.op_info.add_inputs());
  DescribeTensor4D(1, 1, size2, size1, op_context.op_info.add_outputs());

  return op_context;
}

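// Computes the output dimension of a pooling/convolution window along one
// axis from the input size x, kernel size k, stride s, and padding type.
// For example, with x = 19, k = 5, s = 2: "SAME" padding gives
// (19 + 2 - 1) / 2 = 10 and "VALID" padding gives (19 - 5 + 2) / 2 = 8.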
int GetOutputSize(const int x, const int k, const int s,
                  const string& padding) {
  if (padding == "SAME") {
    return (x + s - 1) / s;
  } else {
    return (x - k + s) / s;
  }
}

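// Computes the expected output shape of a pooling op given its input shape,
// kernel sizes, strides, data format, and padding. For example, an input of
// {1, 10, 10, 3} with ksize {1, 2, 2, 1}, strides {1, 2, 2, 1}, "NHWC", and
// "VALID" padding yields {1, 5, 5, 3}.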
std::vector<int> GetPoolingOutputSize(const std::vector<int>& input,
                                      const std::vector<int>& ksize,
                                      const std::vector<int>& strides,
                                      const string& data_format,
                                      const string& padding) {
  // h, w, and c indices: default with NHWC.
  int h_index = 1;
  int w_index = 2;
  int c_index = 3;
  if (data_format == "NCHW") {
    h_index = 2;
    w_index = 3;
    c_index = 1;
  }
  // Extract parameters.
  int n = input[0];
  int h = input[h_index];
  int w = input[w_index];
  int c = input[c_index];
  int sx = strides[h_index];
  int sy = strides[w_index];
  int kx = ksize[h_index];
  int ky = ksize[w_index];

  // Output activation size: default with VALID padding.
  int ho = GetOutputSize(h, kx, sx, padding);
  int wo = GetOutputSize(w, ky, sy, padding);

  std::vector<int> output;
  if (data_format == "NHWC") {
    output = {n, ho, wo, c};
  } else {
    output = {n, c, ho, wo};
  }
  return output;
}

// Helper functions for testing GetTensorShapeProtoFromTensorProto().
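// GetTensorProto builds a TensorProto holding the given dtype, shape, and
// values, serialized either via the packed tensor_content field (when
// tensor_content is true) or via the per-dtype repeated value fields.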
void GetTensorProto(const DataType dtype, const std::vector<int64_t>& shape,
                    const std::vector<int64_t> values,
                    const bool tensor_content, TensorProto* tensor_proto) {
  tensor_proto->Clear();
  TensorProto temp_tensor_proto;
  temp_tensor_proto.set_dtype(dtype);
  for (const auto& x : shape) {
    temp_tensor_proto.mutable_tensor_shape()->add_dim()->set_size(x);
  }
  for (const auto& x : values) {
    if (dtype == DT_INT64) {
      temp_tensor_proto.add_int64_val(x);
    } else if (dtype == DT_INT32 || dtype == DT_INT16 || dtype == DT_INT8 ||
               dtype == DT_UINT8) {
      temp_tensor_proto.add_int_val(x);
    } else if (dtype == DT_UINT32) {
      temp_tensor_proto.add_uint32_val(x);
    } else if (dtype == DT_UINT64) {
      temp_tensor_proto.add_uint64_val(x);
    } else {
      CHECK(false) << "Unsupported dtype: " << dtype;
    }
  }
  Tensor tensor(dtype);
  CHECK(tensor.FromProto(temp_tensor_proto));
  if (tensor_content) {
    tensor.AsProtoTensorContent(tensor_proto);
  } else {
    tensor.AsProtoField(tensor_proto);
  }
}

OpContext DescribePoolingOp(const string& op_name, const std::vector<int>& x,
                            const std::vector<int>& ksize,
                            const std::vector<int>& strides,
                            const string& data_format, const string& padding) {
  OpContext op_context;
  auto& op_info = op_context.op_info;
  SetCpuDevice(&op_info);
  op_info.set_op(op_name);

  const std::vector<int> y =
      GetPoolingOutputSize(x, ksize, strides, data_format, padding);
  if (op_name == "AvgPool" || op_name == "MaxPool") {
    // input: x, output: y.
    DescribeTensor4D(x[0], x[1], x[2], x[3], op_info.add_inputs());
    DescribeTensor4D(y[0], y[1], y[2], y[3], op_info.add_outputs());
  } else if (op_name == "AvgPoolGrad") {
    // input: x's shape, y_grad, output: x_grad.
    DescribeArbitraryRankInput({4}, DT_INT32, &op_info);
    auto* tensor_proto = op_info.mutable_inputs(0)->mutable_value();
    GetTensorProto(DT_INT32, {4}, {x[0], x[1], x[2], x[3]},
                   /*tensor_content=*/false, tensor_proto);
    DescribeTensor4D(y[0], y[1], y[2], y[3], op_info.add_inputs());
    DescribeTensor4D(x[0], x[1], x[2], x[3], op_info.add_outputs());
  } else if (op_name == "MaxPoolGrad") {
    // input: x, y, y_grad, output: x_grad.
    DescribeTensor4D(x[0], x[1], x[2], x[3], op_info.add_inputs());
    DescribeTensor4D(y[0], y[1], y[2], y[3], op_info.add_inputs());
    DescribeTensor4D(y[0], y[1], y[2], y[3], op_info.add_inputs());
    DescribeTensor4D(x[0], x[1], x[2], x[3], op_info.add_outputs());
  }
  auto* attr = op_info.mutable_attr();
  SetAttrValue(data_format, &(*attr)["data_format"]);
  SetAttrValue(padding, &(*attr)["padding"]);
  SetAttrValue(strides, &(*attr)["strides"]);
  SetAttrValue(ksize, &(*attr)["ksize"]);
  return op_context;
}

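// DescribeFusedBatchNorm constructs an OpContext for FusedBatchNorm (or
// FusedBatchNormGrad when is_grad is true) on an input with shape x in the
// given data_format, with additional 1D input and output tensors of size x[3].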
OpContext DescribeFusedBatchNorm(const bool is_training, const bool is_grad,
                                 const std::vector<int>& x,
                                 const string& data_format) {
  // First, get MaxPool op info with unit stride and unit window.
  OpContext op_context = DescribePoolingOp("MaxPool", x, {1, 1, 1, 1},
                                           {1, 1, 1, 1}, data_format, "SAME");
  auto& op_info = op_context.op_info;
  // Override op name.
  if (is_grad) {
    op_info.set_op("FusedBatchNormGrad");
  } else {
    op_info.set_op("FusedBatchNorm");
  }

  // Add additional input output tensors.
  if (is_grad) {
    DescribeTensor4D(x[0], x[1], x[2], x[3], op_info.add_inputs());
  }
  int num_1d_inputs = is_grad ? 3 : 4;
  for (int i = 0; i < num_1d_inputs; i++) {
    auto* tensor = op_info.add_inputs();
    auto* shape = tensor->mutable_shape();
    shape->add_dim()->set_size(x[3]);
    tensor->set_dtype(DT_FLOAT);
  }
  for (int i = 0; i < 4; i++) {
    auto* tensor = op_info.add_outputs();
    auto* shape = tensor->mutable_shape();
    shape->add_dim()->set_size(x[3]);
    tensor->set_dtype(DT_FLOAT);
  }

  // Delete unnecessary attr.
  auto* attr = op_context.op_info.mutable_attr();
  attr->erase("ksize");
  attr->erase("strides");
  attr->erase("padding");

  // Additional attrs for FusedBatchNorm.
  SetAttrValue(is_training, &(*attr)["is_training"]);

  return op_context;
}
}  // namespace

class OpLevelCostEstimatorTest : public ::testing::Test {
 protected:
  using BatchMatMulDimensions = OpLevelCostEstimator::BatchMatMulDimensions;

  Costs PredictCosts(const OpContext& op_context) const {
    return estimator_.PredictCosts(op_context);
  }

  int64_t CountMatMulOperations(const OpInfo& op_info,
                                bool* found_unknown_shapes) const {
    return estimator_.CountMatMulOperations(op_info, found_unknown_shapes);
  }

  int64_t CountBatchMatMulOperations(const OpInfo& op_info,
                                     bool* found_unknown_shapes) const {
    return estimator_.CountBatchMatMulOperations(op_info, found_unknown_shapes);
  }

  int64_t CountBatchMatMulOperations(const OpInfo& op_info,
                                     BatchMatMulDimensions* batch_mat_mul,
                                     bool* found_unknown_shapes) const {
    return estimator_.CountBatchMatMulOperations(op_info, batch_mat_mul,
                                                 found_unknown_shapes);
  }

  void SetComputeMemoryOverlap(bool value) {
    estimator_.compute_memory_overlap_ = value;
  }

  void ValidateOpDimensionsFromInputs(const int n, const int h, const int w,
                                      const int c, const int kx, const int ky,
                                      const int sx, const int sy,
                                      const string& data_format,
                                      const string& padding) {
    OpContext op_context;
    int ho;
    int wo;
    if (data_format == "NHWC") {
      op_context = DescribePoolingOp("MaxPool", {n, h, w, c}, {1, kx, ky, 1},
                                     {1, sx, sy, 1}, "NHWC", padding);
      ho = op_context.op_info.outputs(0).shape().dim(1).size();
      wo = op_context.op_info.outputs(0).shape().dim(2).size();
    } else {
      op_context = DescribePoolingOp("MaxPool", {n, c, h, w}, {1, 1, kx, ky},
                                     {1, 1, sx, sy}, "NCHW", padding);
      ho = op_context.op_info.outputs(0).shape().dim(2).size();
      wo = op_context.op_info.outputs(0).shape().dim(3).size();
    }

    bool found_unknown_shapes;
    TF_ASSERT_OK_AND_ASSIGN(
        auto dims, OpLevelCostEstimator::OpDimensionsFromInputs(
                       op_context.op_info.inputs(0).shape(), op_context.op_info,
                       &found_unknown_shapes));
    Padding padding_enum;
    if (padding == "VALID") {
      padding_enum = Padding::VALID;
    } else {
      padding_enum = Padding::SAME;
    }
    EXPECT_EQ(n, dims.batch);
    EXPECT_EQ(h, dims.ix);
    EXPECT_EQ(w, dims.iy);
    EXPECT_EQ(c, dims.iz);
    EXPECT_EQ(kx, dims.kx);
    EXPECT_EQ(ky, dims.ky);
    EXPECT_EQ(sx, dims.sx);
    EXPECT_EQ(sy, dims.sy);
    EXPECT_EQ(ho, dims.ox);
    EXPECT_EQ(wo, dims.oy);
    EXPECT_EQ(c, dims.oz);
    EXPECT_EQ(padding_enum, dims.padding);
  }

  StatusOr<OpLevelCostEstimator::ConvolutionDimensions>
  CallOpDimensionsFromInputs(const int n, const int h, const int w, const int c,
                             const int kx, const int ky, const int sx,
                             const int sy, const string& data_format,
                             const string& padding) {
    OpContext op_context;

    const std::vector<int> x = {n, h, w, c};
    const std::vector<int> ksize = {1, kx, ky, 1};
    std::vector<int> strides;
    if (data_format == "NHWC") {
      strides = {1, sy, sx, 1};
    } else {
      strides = {1, 1, sy, sx};
    }

    auto& op_info = op_context.op_info;
    SetCpuDevice(&op_info);
    op_info.set_op("MaxPool");

    DescribeTensor4D(x[0], x[1], x[2], x[3], op_info.add_inputs());
    auto* attr = op_info.mutable_attr();
    SetAttrValue(data_format, &(*attr)["data_format"]);
    SetAttrValue(padding, &(*attr)["padding"]);
    SetAttrValue(strides, &(*attr)["strides"]);
    SetAttrValue(ksize, &(*attr)["ksize"]);
    bool found_unknown_shapes;
    return OpLevelCostEstimator::OpDimensionsFromInputs(
        op_context.op_info.inputs(0).shape(), op_context.op_info,
        &found_unknown_shapes);
  }

  OpLevelCostEstimator estimator_;
};

class OpLevelBatchMatMulCostEstimatorTest
    : public OpLevelCostEstimatorTest,
      public ::testing::WithParamInterface<const char*> {
 protected:
  // Returns an OpContext for a BatchMatMul.
  OpContext DescribeBatchMatMul(const std::vector<int>& dims_a,
                                const std::vector<int>& dims_b) {
    OpContext op_context;
    SetCpuDevice(&op_context.op_info);
    op_context.op_info.set_op(GetParam());

    DescribeArbitraryRankInput(dims_a, DT_FLOAT, &op_context.op_info);
    DescribeArbitraryRankInput(dims_b, DT_FLOAT, &op_context.op_info);
    return op_context;
  }

  int64_t CountBatchMatMulOperations(const OpInfo& op_info,
                                     bool* found_unknown_shapes) const {
    return OpLevelCostEstimatorTest::CountBatchMatMulOperations(
        op_info, found_unknown_shapes);
  }

  int64_t CountBatchMatMulDimProduct(const OpInfo& op_info,
                                     bool* found_unknown_shapes) const {
    BatchMatMulDimensions batch_mat_mul;

    batch_mat_mul.matmul_dims.n = 0;
    batch_mat_mul.matmul_dims.m = 0;
    batch_mat_mul.matmul_dims.k = 0;

    OpLevelCostEstimatorTest::CountBatchMatMulOperations(
        op_info, &batch_mat_mul, found_unknown_shapes);
    int dimension_product = 1;
    for (auto dim : batch_mat_mul.batch_dims) dimension_product *= dim;

    dimension_product *= batch_mat_mul.matmul_dims.n;
    dimension_product *= batch_mat_mul.matmul_dims.m;
    dimension_product *= batch_mat_mul.matmul_dims.k;

    return dimension_product;
  }
};

TEST_F(OpLevelCostEstimatorTest, TestPersistentOpCosts) {
  OpContext op_context;
  SetCpuDevice(&op_context.op_info);
  std::unordered_set<string> persistent_ops = {
      "Const",       "Variable",       "VariableV2", "AutoReloadVariable",
      "VarHandleOp", "ReadVariableOp",
  };
  // Minimum cost for all persistent ops.
  for (const auto& op : persistent_ops) {
    op_context.op_info.set_op(op);
    auto cost = estimator_.PredictCosts(op_context);
    EXPECT_EQ(Costs::Duration(0), cost.memory_time);
    EXPECT_EQ(Costs::Duration(1), cost.compute_time);
    EXPECT_EQ(Costs::Duration(1), cost.execution_time);
    EXPECT_EQ(cost.num_ops_total, 1);
    EXPECT_FALSE(cost.inaccurate);
    EXPECT_EQ(cost.num_ops_with_unknown_shapes, 0);
    EXPECT_EQ(cost.temporary_memory, 0);
    EXPECT_EQ(cost.persistent_memory, 0);
  }
}

TEST_F(OpLevelCostEstimatorTest, TestGatherCosts) {
  std::vector<std::string> gather_ops = {"Gather", "GatherNd", "GatherV2"};

  for (const auto& op : gather_ops) {
    OpContext op_context;
    SetCpuDevice(&op_context.op_info);
    op_context.op_info.set_op(op);

    // Huge first input shouldn't affect Gather execution and memory costs.
    DescribeArbitraryRankInput({10000000, 10}, DT_FLOAT, &op_context.op_info);
    DescribeArbitraryRankInput({16}, DT_INT64, &op_context.op_info);
    DescribeArbitraryRankOutput({16, 10}, DT_FLOAT, &op_context.op_info);

    auto cost = estimator_.PredictCosts(op_context);
    EXPECT_EQ(Costs::Duration(130), cost.memory_time);
    EXPECT_EQ(Costs::Duration(16), cost.compute_time);
    EXPECT_EQ(Costs::Duration(146), cost.execution_time);
    EXPECT_EQ(cost.num_ops_total, 1);
    EXPECT_FALSE(cost.inaccurate);
    EXPECT_EQ(cost.num_ops_with_unknown_shapes, 0);
    EXPECT_EQ(cost.temporary_memory, 0);
    EXPECT_EQ(cost.persistent_memory, 0);
  }
}

TEST_F(OpLevelCostEstimatorTest, TestGatherCostsWithoutOutput) {
  OpContext op_context;
  SetCpuDevice(&op_context.op_info);
  op_context.op_info.set_op("Gather");

  // Huge first input shouldn't affect Gather execution and memory costs.
  DescribeArbitraryRankInput({10000000, 10}, DT_FLOAT, &op_context.op_info);
  DescribeArbitraryRankInput({16}, DT_INT64, &op_context.op_info);

  auto cost = estimator_.PredictCosts(op_context);
  EXPECT_EQ(Costs::Duration(0), cost.memory_time);
  EXPECT_EQ(Costs::Duration(0), cost.compute_time);
  EXPECT_EQ(Costs::Duration(0), cost.execution_time);
  EXPECT_EQ(1, cost.num_ops_total);
  EXPECT_TRUE(cost.inaccurate);
  EXPECT_EQ(cost.num_ops_with_unknown_shapes, 0);
  EXPECT_EQ(cost.temporary_memory, 0);
  EXPECT_EQ(cost.persistent_memory, 0);
}

TEST_F(OpLevelCostEstimatorTest, TestSliceCosts) {
  OpContext op_context;
  SetCpuDevice(&op_context.op_info);
  op_context.op_info.set_op("Slice");

  // Huge first input shouldn't affect Slice execution and memory costs.
  DescribeArbitraryRankInput({10000000, 10}, DT_FLOAT, &op_context.op_info);
  DescribeArbitraryRankInput({2}, DT_INT64, &op_context.op_info);
  DescribeArbitraryRankInput({2}, DT_INT64, &op_context.op_info);
  DescribeArbitraryRankOutput({10, 10}, DT_FLOAT, &op_context.op_info);

  auto cost = estimator_.PredictCosts(op_context);
  EXPECT_EQ(Costs::Duration(81), cost.memory_time);
  EXPECT_EQ(Costs::Duration(10), cost.compute_time);
  EXPECT_EQ(Costs::Duration(91), cost.execution_time);
  EXPECT_EQ(cost.num_ops_total, 1);
  EXPECT_FALSE(cost.inaccurate);
  EXPECT_EQ(cost.num_ops_with_unknown_shapes, 0);
  EXPECT_EQ(cost.temporary_memory, 0);
  EXPECT_EQ(cost.persistent_memory, 0);
}

TEST_F(OpLevelCostEstimatorTest, TestStridedSliceCosts) {
  OpContext op_context;
  SetCpuDevice(&op_context.op_info);
  op_context.op_info.set_op("StridedSlice");

  // Huge first input shouldn't affect StridedSlice execution and memory costs.
  DescribeArbitraryRankInput({10000000, 10}, DT_FLOAT, &op_context.op_info);
  DescribeArbitraryRankInput({2}, DT_INT64, &op_context.op_info);
  DescribeArbitraryRankInput({2}, DT_INT64, &op_context.op_info);
  DescribeArbitraryRankInput({2}, DT_INT64, &op_context.op_info);
  DescribeArbitraryRankOutput({10, 10}, DT_FLOAT, &op_context.op_info);

  auto cost = estimator_.PredictCosts(op_context);
  EXPECT_EQ(Costs::Duration(81), cost.memory_time);
  EXPECT_EQ(Costs::Duration(10), cost.compute_time);
  EXPECT_EQ(Costs::Duration(91), cost.execution_time);
  EXPECT_EQ(cost.num_ops_total, 1);
  EXPECT_FALSE(cost.inaccurate);
  EXPECT_EQ(cost.num_ops_with_unknown_shapes, 0);
  EXPECT_EQ(cost.temporary_memory, 0);
  EXPECT_EQ(cost.persistent_memory, 0);
}

TEST_F(OpLevelCostEstimatorTest, TestScatterOps) {
  std::vector<string> scatter_ops = {"ScatterAdd",   "ScatterDiv", "ScatterMax",
                                     "ScatterMin",   "ScatterMul", "ScatterSub",
                                     "ScatterUpdate"};
  for (const auto& op : scatter_ops) {
    // Test updates.shape = indices.shape + ref.shape[1:]
    {
      OpContext op_context;
      SetCpuDevice(&op_context.op_info);
      op_context.op_info.set_op(op);
      // Huge first dimension in input shouldn't affect Scatter execution and
      // memory costs.
      DescribeArbitraryRankInput({10000000, 10}, DT_FLOAT, &op_context.op_info);
      DescribeArbitraryRankInput({16}, DT_INT64, &op_context.op_info);
      DescribeArbitraryRankInput({16, 10}, DT_FLOAT, &op_context.op_info);
      DescribeArbitraryRankOutput({10000000, 10}, DT_FLOAT,
                                  &op_context.op_info);

      auto cost = estimator_.PredictCosts(op_context);
      EXPECT_EQ(Costs::Duration(205), cost.memory_time);
      EXPECT_EQ(Costs::Duration(16), cost.compute_time);
      EXPECT_EQ(Costs::Duration(221), cost.execution_time);
      EXPECT_EQ(cost.num_ops_total, 1);
      EXPECT_FALSE(cost.inaccurate);
      EXPECT_EQ(cost.num_ops_with_unknown_shapes, 0);
      EXPECT_EQ(cost.temporary_memory, 0);
      EXPECT_EQ(cost.persistent_memory, 0);
    }

    // Test updates.shape = [] and INT32 indices
    {
      OpContext op_context;
      SetCpuDevice(&op_context.op_info);
      op_context.op_info.set_op(op);
      // Huge first dimension in input shouldn't affect Scatter execution and
      // memory costs.
      DescribeArbitraryRankInput({10000000, 10}, DT_FLOAT, &op_context.op_info);
      DescribeArbitraryRankInput({16}, DT_INT32, &op_context.op_info);
      DescribeArbitraryRankInput({}, DT_FLOAT, &op_context.op_info);
      DescribeArbitraryRankOutput({10000000, 10}, DT_FLOAT,
                                  &op_context.op_info);

      auto cost = estimator_.PredictCosts(op_context);
      EXPECT_EQ(Costs::Duration(135), cost.memory_time);
      EXPECT_EQ(Costs::Duration(16), cost.compute_time);
      EXPECT_EQ(Costs::Duration(151), cost.execution_time);
      EXPECT_EQ(1, cost.num_ops_total);
      EXPECT_FALSE(cost.inaccurate);
      EXPECT_EQ(0, cost.num_ops_with_unknown_shapes);
    }
  }
}

TEST_F(OpLevelCostEstimatorTest, BiasAddExecutionTime) {
  auto cost = PredictCosts(DescribeBiasAdd(1000, 10));
  EXPECT_EQ(Costs::Duration(8400), cost.memory_time);
  EXPECT_EQ(Costs::Duration(1000), cost.compute_time);
  EXPECT_EQ(Costs::Duration(9400), cost.execution_time);
  EXPECT_EQ(cost.num_ops_total, 1);
  EXPECT_FALSE(cost.inaccurate);
  EXPECT_EQ(cost.num_ops_with_unknown_shapes, 0);
  EXPECT_EQ(cost.temporary_memory, 0);
  EXPECT_EQ(cost.persistent_memory, 0);
}

TEST_F(OpLevelCostEstimatorTest, Conv2DExecutionTime) {
  auto cost = PredictCosts(DescribeConvolution(16, 19, 19, 48, 48, 5, 5, 256));
  EXPECT_EQ(Costs::Duration(233780), cost.memory_time);
  EXPECT_EQ(Costs::Duration(354877440), cost.compute_time);
  EXPECT_EQ(Costs::Duration(355111220), cost.execution_time);
  EXPECT_EQ(cost.num_ops_total, 1);
  EXPECT_FALSE(cost.inaccurate);
  EXPECT_EQ(cost.num_ops_with_unknown_shapes, 0);
  EXPECT_EQ(cost.temporary_memory, 0);
  EXPECT_EQ(cost.persistent_memory, 0);
}

TEST_F(OpLevelCostEstimatorTest, InvalidConv2DConfig) {
  // Convolution ops.
  const std::vector<std::string> conv_ops = {
      "Conv2D",
      "Conv2DBackpropFilter",
      "Conv2DBackpropInput",
      "DepthwiseConv2dNative",
      "DepthwiseConv2dNativeBackpropFilter",
      "DepthwiseConv2dNativeBackpropInput",
  };
  // A valid Conv2D config.
  const std::vector<int> valid_conv_config = {16, 19, 19, 48, 48, 5, 5, 256};
  for (const auto& op : conv_ops) {
    // Test with setting one value in conv config to zero.
    // PredictCosts() should return zero costs.
    for (int i = 0; i < valid_conv_config.size(); ++i) {
      std::vector<int> conv_config(valid_conv_config);
      conv_config[i] = 0;
      auto op_context = DescribeConvolution(
          conv_config[0], conv_config[1], conv_config[2], conv_config[3],
          conv_config[4], conv_config[5], conv_config[6], conv_config[7]);
      op_context.op_info.set_op(op);
      auto cost = PredictCosts(op_context);
      EXPECT_EQ(Costs::Duration(0), cost.memory_time);
      EXPECT_EQ(Costs::Duration(0), cost.compute_time);
      EXPECT_EQ(Costs::Duration(0), cost.execution_time);
      EXPECT_EQ(1, cost.num_ops_total);
      EXPECT_TRUE(cost.inaccurate);
      EXPECT_EQ(1, cost.num_ops_with_unknown_shapes);
    }
  }
}

TEST_F(OpLevelCostEstimatorTest, DepthwiseConv2dNativeExecutionTime) {
  auto cost =
      PredictCosts(DescribeDepthwiseConv2dNative(16, 19, 19, 48, 48, 5, 5, 3));
  EXPECT_EQ(Costs::Duration(112340), cost.memory_time);
  EXPECT_EQ(Costs::Duration(4158720), cost.compute_time);
  EXPECT_EQ(Costs::Duration(4271060), cost.execution_time);
  EXPECT_EQ(cost.num_ops_total, 1);
  EXPECT_FALSE(cost.inaccurate);
  EXPECT_EQ(cost.num_ops_with_unknown_shapes, 0);
  EXPECT_EQ(cost.temporary_memory, 0);
  EXPECT_EQ(cost.persistent_memory, 0);
}

TEST_F(OpLevelCostEstimatorTest, DummyExecutionTime) {
  auto cost = PredictCosts(DescribeBinaryOp("Dummy", 1000, 1));
  EXPECT_EQ(Costs::Duration(2000), cost.memory_time);
  EXPECT_EQ(Costs::Duration(0), cost.compute_time);
  EXPECT_EQ(Costs::Duration(2000), cost.execution_time);
  EXPECT_EQ(cost.num_ops_total, 1);
  EXPECT_TRUE(cost.inaccurate);
  EXPECT_EQ(cost.num_ops_with_unknown_shapes, 0);
  EXPECT_EQ(cost.temporary_memory, 0);
  EXPECT_EQ(cost.persistent_memory, 0);
}

TEST_F(OpLevelCostEstimatorTest, ExecutionTimeSumOrMax) {
  SetComputeMemoryOverlap(true);
  auto cost = PredictCosts(DescribeBinaryOp("Dummy", 1000, 1));
  EXPECT_EQ(Costs::Duration(2000), cost.memory_time);
  EXPECT_EQ(Costs::Duration(0), cost.compute_time);
  EXPECT_EQ(Costs::Duration(2000), cost.execution_time);  // max(2000, 200)
  EXPECT_EQ(cost.num_ops_total, 1);
  EXPECT_TRUE(cost.inaccurate);
  EXPECT_EQ(cost.num_ops_with_unknown_shapes, 0);
  EXPECT_EQ(cost.temporary_memory, 0);
  EXPECT_EQ(cost.persistent_memory, 0);
  SetComputeMemoryOverlap(false);  // Set it back to default.
}

TEST_F(OpLevelCostEstimatorTest,
       FusedConv2DBiasActivationNCHW_HWIO_NoSideInput) {
  auto cost = PredictCosts(DescribeFusedConv2DBiasActivation(
      16, 19, 19, 48, 48, 5, 5, 19, 19, 256, /* has_side_input = */ false,
      "NCHW", "HWIO"));
  EXPECT_EQ(Costs::Duration(825345), cost.memory_time);
  EXPECT_EQ(Costs::Duration(355321037), cost.compute_time);
  EXPECT_EQ(Costs::Duration(356146382), cost.execution_time);
  EXPECT_EQ(cost.num_ops_total, 1);
  EXPECT_FALSE(cost.inaccurate);
  EXPECT_EQ(cost.num_ops_with_unknown_shapes, 0);
  EXPECT_EQ(cost.temporary_memory, 0);
  EXPECT_EQ(cost.persistent_memory, 0);
}

TEST_F(OpLevelCostEstimatorTest, FusedConv2DBiasActivationNCHW_HWIO) {
  auto cost = PredictCosts(DescribeFusedConv2DBiasActivation(
      16, 19, 19, 48, 48, 5, 5, 19, 19, 256, /* has_side_input = */ true,
      "NCHW", "HWIO"));
  EXPECT_EQ(Costs::Duration(1416808), cost.memory_time);
  EXPECT_EQ(Costs::Duration(355616768), cost.compute_time);
  EXPECT_EQ(Costs::Duration(357033576), cost.execution_time);
  EXPECT_EQ(cost.num_ops_total, 1);
  EXPECT_FALSE(cost.inaccurate);
  EXPECT_EQ(cost.num_ops_with_unknown_shapes, 0);
  EXPECT_EQ(cost.temporary_memory, 0);
  EXPECT_EQ(cost.persistent_memory, 0);
}

TEST_F(OpLevelCostEstimatorTest, FusedConv2DBiasActivationNCHW_OIHW) {
  auto cost = PredictCosts(DescribeFusedConv2DBiasActivation(
      16, 19, 19, 48, 48, 5, 5, 19, 19, 256, /* has_side_input = */ true,
      "NCHW", "OIHW"));
  EXPECT_EQ(Costs::Duration(1416808), cost.memory_time);
  EXPECT_EQ(Costs::Duration(355616768), cost.compute_time);
  EXPECT_EQ(Costs::Duration(357033576), cost.execution_time);
  EXPECT_EQ(cost.num_ops_total, 1);
  EXPECT_FALSE(cost.inaccurate);
  EXPECT_EQ(cost.num_ops_with_unknown_shapes, 0);
  EXPECT_EQ(cost.temporary_memory, 0);
  EXPECT_EQ(cost.persistent_memory, 0);
}

TEST_F(OpLevelCostEstimatorTest, FusedConv2DBiasActivationNHWC_HWIO) {
  auto cost = PredictCosts(DescribeFusedConv2DBiasActivation(
      16, 19, 19, 48, 48, 5, 5, 19, 19, 256, /* has_side_input = */ true,
      "NHWC", "HWIO"));
  EXPECT_EQ(Costs::Duration(1416808), cost.memory_time);
  EXPECT_EQ(Costs::Duration(355616768), cost.compute_time);
  EXPECT_EQ(Costs::Duration(357033576), cost.execution_time);
  EXPECT_EQ(cost.num_ops_total, 1);
  EXPECT_FALSE(cost.inaccurate);
  EXPECT_EQ(cost.num_ops_with_unknown_shapes, 0);
  EXPECT_EQ(cost.temporary_memory, 0);
  EXPECT_EQ(cost.persistent_memory, 0);
}

TEST_F(OpLevelCostEstimatorTest, FusedConv2DBiasActivationNHWC_OIHW) {
  auto cost = PredictCosts(DescribeFusedConv2DBiasActivation(
      16, 19, 19, 48, 48, 5, 5, 19, 19, 256, /* has_side_input = */ true,
      "NHWC", "OIHW"));
  EXPECT_EQ(Costs::Duration(1416808), cost.memory_time);
  EXPECT_EQ(Costs::Duration(355616768), cost.compute_time);
  EXPECT_EQ(Costs::Duration(357033576), cost.execution_time);
  EXPECT_EQ(cost.num_ops_total, 1);
  EXPECT_FALSE(cost.inaccurate);
  EXPECT_EQ(cost.num_ops_with_unknown_shapes, 0);
  EXPECT_EQ(cost.temporary_memory, 0);
  EXPECT_EQ(cost.persistent_memory, 0);
}

TEST_F(OpLevelCostEstimatorTest, FusedConv2DBiasActivationNCHW_VECT_C_OIHW) {
  auto cost = PredictCosts(DescribeFusedConv2DBiasActivation(
      16, 19, 19, 48, 48, 5, 5, 19, 19, 256, /* has_side_input = */ true,
      "NCHW_VECT_C", "OIHW"));
  EXPECT_EQ(Costs::Duration(1416808), cost.memory_time);
  EXPECT_EQ(Costs::Duration(355616768), cost.compute_time);
  EXPECT_EQ(Costs::Duration(357033576), cost.execution_time);
  EXPECT_EQ(cost.num_ops_total, 1);
  EXPECT_FALSE(cost.inaccurate);
  EXPECT_EQ(cost.num_ops_with_unknown_shapes, 0);
  EXPECT_EQ(cost.temporary_memory, 0);
  EXPECT_EQ(cost.persistent_memory, 0);
}

TEST_F(OpLevelCostEstimatorTest, FusedConv2DBiasActivationNCHW_OIHW_VECT_I) {
  auto cost = PredictCosts(DescribeFusedConv2DBiasActivation(
      16, 19, 19, 48, 48, 5, 5, 19, 19, 256, /* has_side_input = */ true,
      "NCHW", "OIHW_VECT_I"));
  EXPECT_EQ(Costs::Duration(1416808), cost.memory_time);
  EXPECT_EQ(Costs::Duration(355616768), cost.compute_time);
  EXPECT_EQ(Costs::Duration(357033576), cost.execution_time);
  EXPECT_EQ(cost.num_ops_total, 1);
  EXPECT_FALSE(cost.inaccurate);
  EXPECT_EQ(cost.num_ops_with_unknown_shapes, 0);
  EXPECT_EQ(cost.temporary_memory, 0);
  EXPECT_EQ(cost.persistent_memory, 0);
}

TEST_F(OpLevelCostEstimatorTest,FusedConv2DBiasActivationNCHW_VECT_C_OIHW_VECT_I)1024 TEST_F(OpLevelCostEstimatorTest,
1025        FusedConv2DBiasActivationNCHW_VECT_C_OIHW_VECT_I) {
1026   auto cost = PredictCosts(DescribeFusedConv2DBiasActivation(
1027       16, 19, 19, 48, 48, 5, 5, 19, 19, 256, /* has_side_input = */ true,
1028       "NCHW_VECT_C", "OIHW_VECT_I"));
1029   EXPECT_EQ(Costs::Duration(1416808), cost.memory_time);
1030   EXPECT_EQ(Costs::Duration(355616768), cost.compute_time);
1031   EXPECT_EQ(Costs::Duration(357033576), cost.execution_time);
1032   EXPECT_EQ(cost.num_ops_total, 1);
1033   EXPECT_FALSE(cost.inaccurate);
1034   EXPECT_EQ(cost.num_ops_with_unknown_shapes, 0);
1035   EXPECT_EQ(cost.temporary_memory, 0);
1036   EXPECT_EQ(cost.persistent_memory, 0);
1037 }
1038 
TEST_F(OpLevelCostEstimatorTest,MulExecutionTime)1039 TEST_F(OpLevelCostEstimatorTest, MulExecutionTime) {
1040   auto cost = PredictCosts(DescribeBinaryOp("Mul", 1000, 1));
1041   EXPECT_EQ(Costs::Duration(2000), cost.memory_time);
1042   EXPECT_EQ(Costs::Duration(200), cost.compute_time);
1043   EXPECT_EQ(Costs::Duration(2200), cost.execution_time);
1044   EXPECT_EQ(cost.num_ops_total, 1);
1045   EXPECT_FALSE(cost.inaccurate);
1046   EXPECT_EQ(cost.num_ops_with_unknown_shapes, 0);
1047   EXPECT_EQ(cost.temporary_memory, 0);
1048   EXPECT_EQ(cost.persistent_memory, 0);
1049 }
1050 
TEST_F(OpLevelCostEstimatorTest,MulBroadcastExecutionTime)1051 TEST_F(OpLevelCostEstimatorTest, MulBroadcastExecutionTime) {
1052   auto cost = PredictCosts(DescribeBinaryOp("Mul", 1000, 2));
1053   EXPECT_EQ(Costs::Duration(3600), cost.memory_time);
1054   EXPECT_EQ(Costs::Duration(400), cost.compute_time);
1055   EXPECT_EQ(Costs::Duration(4000), cost.execution_time);
1056   EXPECT_EQ(cost.num_ops_total, 1);
1057   EXPECT_FALSE(cost.inaccurate);
1058   EXPECT_EQ(cost.num_ops_with_unknown_shapes, 0);
1059   EXPECT_EQ(cost.temporary_memory, 0);
1060   EXPECT_EQ(cost.persistent_memory, 0);
1061 }
1062 
TEST_F(OpLevelCostEstimatorTest,ModExecutionTime)1063 TEST_F(OpLevelCostEstimatorTest, ModExecutionTime) {
1064   auto cost = PredictCosts(DescribeBinaryOp("Mod", 1000, 1));
1065   EXPECT_EQ(Costs::Duration(2000), cost.memory_time);
1066   EXPECT_EQ(Costs::Duration(1600), cost.compute_time);
1067   EXPECT_EQ(Costs::Duration(3600), cost.execution_time);
1068   EXPECT_EQ(cost.num_ops_total, 1);
1069   EXPECT_FALSE(cost.inaccurate);
1070   EXPECT_EQ(cost.num_ops_with_unknown_shapes, 0);
1071   EXPECT_EQ(cost.temporary_memory, 0);
1072   EXPECT_EQ(cost.persistent_memory, 0);
1073 }
1074 
TEST_F(OpLevelCostEstimatorTest,SquaredDifferenceExecutionTime)1075 TEST_F(OpLevelCostEstimatorTest, SquaredDifferenceExecutionTime) {
1076   auto cost = PredictCosts(DescribeBinaryOp("SquaredDifference", 1000, 2));
1077   EXPECT_EQ(cost.memory_time, Costs::Duration(3600));
1078   EXPECT_EQ(cost.compute_time, Costs::Duration(800));
1079   EXPECT_EQ(cost.execution_time, Costs::Duration(4400));
1080   EXPECT_EQ(cost.num_ops_total, 1);
1081   EXPECT_FALSE(cost.inaccurate);
1082   EXPECT_EQ(cost.num_ops_with_unknown_shapes, 0);
1083   EXPECT_EQ(cost.temporary_memory, 0);
1084   EXPECT_EQ(cost.persistent_memory, 0);
1085 }
1086 
TEST_F(OpLevelCostEstimatorTest,UnaryOpExecutionTime)1087 TEST_F(OpLevelCostEstimatorTest, UnaryOpExecutionTime) {
1088   std::vector<std::pair<std::string, int>> unary_ops = {
1089       {"All", 1},      {"ArgMax", 1}, {"Cast", 1},  {"Max", 1},
1090       {"Min", 1},      {"Prod", 1},   {"Relu", 1},  {"Relu6", 1},
1091       {"Softmax", 40}, {"Sum", 1},    {"TopKV2", 1}};
1092 
1093   const int kTensorSize = 1000;
1094   for (auto unary_op : unary_ops) {
1095     OpContext op_context = DescribeUnaryOp(unary_op.first, kTensorSize);
1096 
1097     const int kExpectedMemoryTime = 800;
1098     int expected_compute_time = std::ceil(
1099         unary_op.second * kTensorSize /
1100         estimator_.GetDeviceInfo(op_context.op_info.device()).gigaops);
1101 
1102     auto cost = PredictCosts(op_context);
1103     EXPECT_EQ(cost.memory_time, Costs::Duration(kExpectedMemoryTime));
1104     EXPECT_EQ(cost.compute_time, Costs::Duration(expected_compute_time))
1105         << unary_op.first;
1106     EXPECT_EQ(cost.execution_time,
1107               Costs::Duration(expected_compute_time + kExpectedMemoryTime));
1108     EXPECT_EQ(cost.num_ops_total, 1);
1109     EXPECT_EQ(cost.num_ops_with_unknown_shapes, 0);
1110     EXPECT_FALSE(cost.inaccurate);
1111     EXPECT_EQ(cost.temporary_memory, 0);
1112     EXPECT_EQ(cost.persistent_memory, 0);
1113   }
1114 }
1115 
TEST_F(OpLevelCostEstimatorTest,BinaryOpExecutionTime)1116 TEST_F(OpLevelCostEstimatorTest, BinaryOpExecutionTime) {
1117   std::vector<std::pair<std::string, int>> binary_ops = {
1118       {"Select", 1},
1119       {"SelectV2", 1},
1120       {"SquaredDifference", 2},
1121       {"Where", 1},
1122   };
1123 
1124   const int kTensorSize1 = 1000;
1125   const int kTensorSize2 = 2;
1126   for (auto binary_op : binary_ops) {
1127     OpContext op_context =
1128         DescribeBinaryOp(binary_op.first, kTensorSize1, kTensorSize2);
1129 
1130     const int kExpectedMemoryTime = 3600;
1131     int expected_compute_time = std::ceil(
1132         binary_op.second * kTensorSize1 * kTensorSize2 * 2 /
1133         estimator_.GetDeviceInfo(op_context.op_info.device()).gigaops);
1134 
1135     auto cost = PredictCosts(op_context);
1136     EXPECT_EQ(Costs::Duration(kExpectedMemoryTime), cost.memory_time)
1137         << binary_op.first;
1138     EXPECT_EQ(Costs::Duration(expected_compute_time), cost.compute_time)
1139         << binary_op.first;
1140     EXPECT_EQ(Costs::Duration(expected_compute_time + kExpectedMemoryTime),
1141               cost.execution_time)
1142         << binary_op.first;
1143     EXPECT_EQ(cost.num_ops_total, 1);
1144     EXPECT_FALSE(cost.inaccurate);
1145     EXPECT_EQ(cost.num_ops_with_unknown_shapes, 0);
1146     EXPECT_EQ(cost.temporary_memory, 0);
1147     EXPECT_EQ(cost.persistent_memory, 0);
1148   }
1149 }
1150 
TEST_F(OpLevelCostEstimatorTest,BroadcastAddExecutionTime)1151 TEST_F(OpLevelCostEstimatorTest, BroadcastAddExecutionTime) {
1152   OpContext op_context;
1153   SetCpuDevice(&op_context.op_info);
1154   op_context.op_info.set_op("Add");
1155 
1156   DescribeTensor1D(100, op_context.op_info.add_inputs());
1157   DescribeTensor4D(1, 10, 1, 1, op_context.op_info.add_inputs());
1158 
1159   auto cost = PredictCosts(op_context);
1160   EXPECT_EQ(Costs::Duration(44), cost.memory_time);
1161   EXPECT_EQ(Costs::Duration(100), cost.compute_time);
1162   EXPECT_EQ(Costs::Duration(144), cost.execution_time);
1163   EXPECT_EQ(cost.num_ops_total, 1);
1164   EXPECT_FALSE(cost.inaccurate);
1165   EXPECT_EQ(cost.num_ops_with_unknown_shapes, 0);
1166   EXPECT_EQ(cost.temporary_memory, 0);
1167   EXPECT_EQ(cost.persistent_memory, 0);
1168 }
1169 
TEST_F(OpLevelCostEstimatorTest,UnknownOrPartialShape)1170 TEST_F(OpLevelCostEstimatorTest, UnknownOrPartialShape) {
1171   {
1172     auto cost = PredictCosts(DescribeMatMul(2, 4, 7, 7));
1173     EXPECT_EQ(1, cost.num_ops_total);
1174     EXPECT_FALSE(cost.inaccurate);
1175     EXPECT_EQ(0, cost.num_ops_with_unknown_shapes);
1176   }
1177   {
1178     auto cost = PredictCosts(DescribeMatMul(-1, 4, 7, 7));
1179     EXPECT_EQ(1, cost.num_ops_total);
1180     EXPECT_TRUE(cost.inaccurate);
1181     EXPECT_EQ(1, cost.num_ops_with_unknown_shapes);
1182   }
1183   {
1184     auto cost = PredictCosts(DescribeMatMul(2, 4, -1, 7));
1185     EXPECT_EQ(1, cost.num_ops_total);
1186     EXPECT_TRUE(cost.inaccurate);
1187     EXPECT_EQ(1, cost.num_ops_with_unknown_shapes);
1188   }
1189   {
1190     auto cost =
1191         PredictCosts(DescribeConvolution(16, 19, 19, 48, 48, 5, 5, 256));
1192     EXPECT_EQ(1, cost.num_ops_total);
1193     EXPECT_FALSE(cost.inaccurate);
1194     EXPECT_EQ(0, cost.num_ops_with_unknown_shapes);
1195   }
1196   {
1197     auto cost =
1198         PredictCosts(DescribeConvolution(16, -1, 19, 48, 48, 5, 5, 256));
1199     EXPECT_EQ(1, cost.num_ops_total);
1200     EXPECT_TRUE(cost.inaccurate);
1201     EXPECT_EQ(1, cost.num_ops_with_unknown_shapes);
1202   }
1203 }
1204 
TEST_P(OpLevelBatchMatMulCostEstimatorTest,TestBatchMatMul)1205 TEST_P(OpLevelBatchMatMulCostEstimatorTest, TestBatchMatMul) {
1206   {
1207     auto cost = PredictCosts(DescribeBatchMatMul({}, {}));
1208     EXPECT_EQ(1, cost.num_ops_total);
1209     EXPECT_TRUE(cost.inaccurate);
1210     EXPECT_EQ(1, cost.num_ops_with_unknown_shapes);
1211   }
1212   {
1213     auto cost = PredictCosts(DescribeBatchMatMul({2, 4}, {}));
1214     EXPECT_EQ(1, cost.num_ops_total);
1215     EXPECT_TRUE(cost.inaccurate);
1216     EXPECT_EQ(1, cost.num_ops_with_unknown_shapes);
1217   }
1218   {
1219     auto cost = PredictCosts(DescribeBatchMatMul({2, 4}, {4, 2}));
1220     EXPECT_EQ(1, cost.num_ops_total);
1221     EXPECT_FALSE(cost.inaccurate);
1222     EXPECT_EQ(0, cost.num_ops_with_unknown_shapes);
1223   }
1224   {
1225     auto cost = PredictCosts(DescribeBatchMatMul({1, 2, 4}, {1, 4, 2}));
1226     EXPECT_EQ(1, cost.num_ops_total);
1227     EXPECT_FALSE(cost.inaccurate);
1228     EXPECT_EQ(0, cost.num_ops_with_unknown_shapes);
1229   }
1230   {
1231     auto cost = PredictCosts(DescribeBatchMatMul({2, 4}, {1, 3, 4, 2}));
1232     EXPECT_EQ(1, cost.num_ops_total);
1233     EXPECT_FALSE(cost.inaccurate);
1234     EXPECT_EQ(0, cost.num_ops_with_unknown_shapes);
1235   }
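  // The op count of a batched matmul should scale with the total batch size
  // relative to the equivalent single MatMul; a -1 batch dimension is
  // effectively treated as 1 here and flips the inaccurate flag.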
1236   bool matmul_inaccurate = false;
1237   bool batch_matmul_inaccurate = false;
1238   EXPECT_EQ(
1239       CountMatMulOperations(DescribeMatMul(2, 2, 4, 4).op_info,
1240                             &matmul_inaccurate),
1241       CountBatchMatMulOperations(DescribeBatchMatMul({2, 4}, {4, 2}).op_info,
1242                                  &batch_matmul_inaccurate));
1243   EXPECT_EQ(matmul_inaccurate, batch_matmul_inaccurate);
1244   EXPECT_EQ(10 * CountMatMulOperations(DescribeMatMul(2, 2, 4, 4).op_info,
1245                                        &matmul_inaccurate),
1246             CountBatchMatMulOperations(
1247                 DescribeBatchMatMul({10, 2, 4}, {-1, 10, 4, 2}).op_info,
1248                 &batch_matmul_inaccurate));
1249   EXPECT_NE(matmul_inaccurate, batch_matmul_inaccurate);
1250   EXPECT_EQ(20 * CountMatMulOperations(DescribeMatMul(2, 2, 4, 4).op_info,
1251                                        &matmul_inaccurate),
1252             CountBatchMatMulOperations(
1253                 DescribeBatchMatMul({2, 10, 2, 4}, {-1, 10, 4, 2}).op_info,
1254                 &batch_matmul_inaccurate));
1255   EXPECT_NE(matmul_inaccurate, batch_matmul_inaccurate);
1256 
1257   // Test that CountBatchMatMulDimProduct extracts the dimensions correctly.
1258   int prod = CountBatchMatMulDimProduct(
1259       DescribeBatchMatMul({2, 4}, {1, 3, 4, 2}).op_info,
1260       &batch_matmul_inaccurate);
1261   EXPECT_EQ(prod, 16);
1262   EXPECT_FALSE(batch_matmul_inaccurate);
1263 
1264   // Exercise the bad cases of a BatchMatMul.
1265   OpContext bad_batch = DescribeBatchMatMul({2, 4}, {4, 2});
1266   bad_batch.op_info.set_op("notBatchMatMul");
1267   prod =
1268       CountBatchMatMulDimProduct(bad_batch.op_info, &batch_matmul_inaccurate);
1269 
1270   EXPECT_EQ(prod, 0);
1271   EXPECT_TRUE(batch_matmul_inaccurate);
1272 
1273   // Exercise a transposed (adj_x/adj_y) case of a BatchMatMul.
1274   OpContext transpose_batch = DescribeBatchMatMul({2, 4, 3, 1}, {4, 2});
1275   auto attr = transpose_batch.op_info.mutable_attr();
1276   (*attr)["adj_x"].set_b(true);
1277   (*attr)["adj_y"].set_b(true);
1278 
1279   prod = CountBatchMatMulDimProduct(transpose_batch.op_info,
1280                                     &batch_matmul_inaccurate);
1281   EXPECT_EQ(prod, 12);
1282 }
1283 INSTANTIATE_TEST_SUITE_P(TestBatchMatMul, OpLevelBatchMatMulCostEstimatorTest,
1284                          ::testing::Values("BatchMatMul", "BatchMatMulV2"));
1285 
1286 TEST_F(OpLevelCostEstimatorTest, SparseTensorDenseMatMul) {
1287   // Unknown shape cases
1288   {
1289     auto cost =
1290         PredictCosts(DescribeSparseTensorDenseMatMul(-1, {1, 1}, {1, 1}));
1291     EXPECT_EQ(1, cost.num_ops_total);
1292     EXPECT_TRUE(cost.inaccurate);
1293     EXPECT_EQ(1, cost.num_ops_with_unknown_shapes);
1294   }
1295   {
1296     auto cost =
1297         PredictCosts(DescribeSparseTensorDenseMatMul(1, {-1, 1}, {1, 1}));
1298     EXPECT_EQ(1, cost.num_ops_total);
1299     EXPECT_TRUE(cost.inaccurate);
1300     EXPECT_EQ(1, cost.num_ops_with_unknown_shapes);
1301   }
1302   {
1303     auto cost =
1304         PredictCosts(DescribeSparseTensorDenseMatMul(1, {1, -1}, {1, -1}));
1305     EXPECT_EQ(1, cost.num_ops_total);
1306     EXPECT_TRUE(cost.inaccurate);
1307     EXPECT_EQ(1, cost.num_ops_with_unknown_shapes);
1308   }
1309   {
1310     auto cost =
1311         PredictCosts(DescribeSparseTensorDenseMatMul(1, {1, 1}, {-1, 1}));
1312     EXPECT_EQ(1, cost.num_ops_total);
1313     EXPECT_TRUE(cost.inaccurate);
1314     EXPECT_EQ(1, cost.num_ops_with_unknown_shapes);
1315   }
1316   // Known shape cases
1317   {
1318     auto cost = PredictCosts(
1319         DescribeSparseTensorDenseMatMul(10, {1000, 100}, {50, 100}));
1320     EXPECT_EQ(1, cost.num_ops_total);
1321     EXPECT_FALSE(cost.inaccurate);
1322     EXPECT_EQ(0, cost.num_ops_with_unknown_shapes);
1323     EXPECT_EQ(Costs::Duration(200), cost.compute_time);
1324     EXPECT_EQ(Costs::Duration(2422), cost.memory_time);
1325   }
1326   {
1327     // Same cost as the case above because the cost does not depend on k_dim.
1328     auto cost = PredictCosts(
1329         DescribeSparseTensorDenseMatMul(10, {100000, 100}, {50, 100}));
1330     EXPECT_EQ(1, cost.num_ops_total);
1331     EXPECT_FALSE(cost.inaccurate);
1332     EXPECT_EQ(0, cost.num_ops_with_unknown_shapes);
1333     EXPECT_EQ(Costs::Duration(200), cost.compute_time);
1334     EXPECT_EQ(Costs::Duration(2422), cost.memory_time);
1335   }
1336 }
1337 
1338 void ExpectTensorShape(const std::vector<int64_t>& expected,
1339                        const TensorShapeProto& tensor_shape_proto) {
1340   TensorShape tensor_shape_expected(expected);
1341   TensorShape tensor_shape(tensor_shape_proto);
1342 
1343   EXPECT_EQ(tensor_shape_expected, tensor_shape);
1344 }
1345 
1346 TEST_F(OpLevelCostEstimatorTest, GetTensorShapeProtoFromTensorProto) {
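  // GetTensorShapeProtoFromTensorProto() interprets a 1D integer TensorProto
  // as a shape vector; non-1D shapes and non-integer dtypes should be
  // rejected in the cases below.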
1347   TensorProto tensor_proto;
1348   TensorShapeProto tensor_shape_proto;
1349 
1350   // Dimension larger than the maximum allowed value; conversion to the
1351   // Tensor class should fail.
1352   tensor_proto.mutable_tensor_shape()->add_dim()->set_size(255);
1353   EXPECT_FALSE(
1354       GetTensorShapeProtoFromTensorProto(tensor_proto, &tensor_shape_proto));
1355 
1356   tensor_proto.Clear();
1357   // Only a 1D shape is expected; a 2D shape should be rejected.
1358   tensor_proto.mutable_tensor_shape()->add_dim()->set_size(1);
1359   tensor_proto.mutable_tensor_shape()->add_dim()->set_size(2);
1360   EXPECT_FALSE(
1361       GetTensorShapeProtoFromTensorProto(tensor_proto, &tensor_shape_proto));
1362 
1363   // Only integer data types should be handled.
1364   GetTensorProto(DT_FLOAT, {}, {}, /*tensor_content=*/false, &tensor_proto);
1365   EXPECT_FALSE(
1366       GetTensorShapeProtoFromTensorProto(tensor_proto, &tensor_shape_proto));
1367 
1368   // Check GetTensorShapeProtoFromTensorProto() returns correct values.
1369   {
1370     std::vector<int64_t> shape_expected = {10, 20, 30, 40};
1371     GetTensorProto(DT_INT32, {4}, shape_expected,
1372                    /*tensor_content=*/false, &tensor_proto);
1373     EXPECT_TRUE(
1374         GetTensorShapeProtoFromTensorProto(tensor_proto, &tensor_shape_proto));
1375     ExpectTensorShape(shape_expected, tensor_shape_proto);
1376   }
1377 
1378   {
1379     std::vector<int64_t> shape_expected = {40, 20, 90, 40};
1380     GetTensorProto(DT_INT64, {4}, shape_expected,
1381                    /*tensor_content=*/false, &tensor_proto);
1382     EXPECT_TRUE(
1383         GetTensorShapeProtoFromTensorProto(tensor_proto, &tensor_shape_proto));
1384     ExpectTensorShape(shape_expected, tensor_shape_proto);
1385   }
1386 
1387   {
1388     std::vector<int64_t> shape_expected = {10, 20, 30, 40};
1389     GetTensorProto(DT_INT32, {4}, shape_expected,
1390                    /*tensor_content=*/true, &tensor_proto);
1391     EXPECT_TRUE(
1392         GetTensorShapeProtoFromTensorProto(tensor_proto, &tensor_shape_proto));
1393     ExpectTensorShape(shape_expected, tensor_shape_proto);
1394   }
1395 
1396   {
1397     std::vector<int64_t> shape_expected = {40, 20, 90, 40};
1398     GetTensorProto(DT_INT64, {4}, shape_expected,
1399                    /*tensor_content=*/true, &tensor_proto);
1400     EXPECT_TRUE(
1401         GetTensorShapeProtoFromTensorProto(tensor_proto, &tensor_shape_proto));
1402     ExpectTensorShape(shape_expected, tensor_shape_proto);
1403   }
1404 }
1405 
1406 TEST_F(OpLevelCostEstimatorTest, OpDimensionsFromInputs) {
1407   std::vector<string> paddings = {"VALID", "SAME"};
1408   std::vector<string> formats = {"NHWC", "NCHW"};
1409   for (const auto& p : paddings) {
1410     for (const auto& f : formats) {
1411       // n, h, w, c, kx, ky, sx, sy, data_format, padding.
1412       ValidateOpDimensionsFromInputs(10, 20, 20, 100, 3, 3, 2, 2, f, p);
1413       ValidateOpDimensionsFromInputs(10, 20, 20, 100, 1, 1, 3, 3, f, p);
1414       ValidateOpDimensionsFromInputs(10, 200, 200, 100, 5, 5, 3, 3, f, p);
1415       ValidateOpDimensionsFromInputs(10, 14, 14, 3840, 3, 3, 2, 2, f, p);
1416     }
1417   }
1418 }
1419 
1420 TEST_F(OpLevelCostEstimatorTest, OpDimensionsFromInputsError) {
1421   std::vector<string> paddings = {"VALID", "SAME"};
1422   std::vector<string> formats = {"NHWC", "NCHW"};
1423   for (const auto& p : paddings) {
1424     for (const auto& f : formats) {
1425       // n, h, w, c, kx, ky, sx, sy, data_format, padding.
1426       ASSERT_THAT(
1427           CallOpDimensionsFromInputs(10, 14, 14, 3840, 3, 3, 0, 2, f, p),
1428           testing::StatusIs(
1429               error::INVALID_ARGUMENT,
1430               "Stride must be > 0 for Height and Width, but got (2, 0)"));
1431       ASSERT_THAT(
1432           CallOpDimensionsFromInputs(10, 14, 14, 3840, 3, 3, 2, 0, f, p),
1433           testing::StatusIs(
1434               error::INVALID_ARGUMENT,
1435               "Stride must be > 0 for Height and Width, but got (0, 2)"));
1436     }
1437   }
1438 }
1439 
1440 TEST_F(OpLevelCostEstimatorTest, PredictMaxPool) {
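  // Helper parameters: n = batch, in = input height/width, c = channels,
  // k = square window size, s = stride; all cases below use NHWC.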
1441   auto predict_max_pool = [this](const int n, const int in, const int c,
1442                                  const int k, const int s,
1443                                  const string& padding) -> Costs {
1444     OpContext op_context = DescribePoolingOp(
1445         "MaxPool", {n, in, in, c}, {1, k, k, 1}, {1, s, s, 1}, "NHWC", padding);
1446     return estimator_.PredictCosts(op_context);
1447   };
1448 
1449   {
1450     // Typical 3x3 window with 2x2 stride.
1451     auto costs = predict_max_pool(10, 20, 384, 3, 2, "SAME");
1452     EXPECT_EQ(Costs::Duration(1075200), costs.execution_time);
1453     EXPECT_EQ(Costs::Duration(307200), costs.compute_time);
1454     EXPECT_EQ(Costs::Duration(768000), costs.memory_time);
1455     EXPECT_EQ(costs.num_ops_total, 1);
1456     EXPECT_FALSE(costs.inaccurate);
1457     EXPECT_EQ(costs.num_ops_with_unknown_shapes, 0);
1458     EXPECT_EQ(costs.temporary_memory, 0);
1459     EXPECT_EQ(costs.persistent_memory, 0);
1460   }
1461   {
1462     // 1x1 window with 2x2 stride: used for shortcut in resnet-50.
1463     auto costs = predict_max_pool(10, 20, 384, 1, 2, "SAME");
1464     EXPECT_EQ(Costs::Duration(499200), costs.execution_time);
1465     EXPECT_EQ(Costs::Duration(38400), costs.compute_time);
1466     EXPECT_EQ(Costs::Duration(460800), costs.memory_time);
1467     EXPECT_EQ(1, costs.num_ops_total);
1468     EXPECT_FALSE(costs.inaccurate);
1469     EXPECT_EQ(0, costs.num_ops_with_unknown_shapes);
1470   }
1471   {
1472     // 2x2 window with 3x3 stride.
1473     auto costs = predict_max_pool(10, 20, 384, 2, 3, "VALID");
1474     EXPECT_EQ(Costs::Duration(561792), costs.execution_time);
1475     EXPECT_EQ(Costs::Duration(56448), costs.compute_time);
1476     EXPECT_EQ(Costs::Duration(505344), costs.memory_time);
1477     EXPECT_EQ(1, costs.num_ops_total);
1478     EXPECT_FALSE(costs.inaccurate);
1479     EXPECT_EQ(0, costs.num_ops_with_unknown_shapes);
1480   }
1481 }
1482 
1483 TEST_F(OpLevelCostEstimatorTest, PredictMaxPoolGrad) {
1484   auto predict_max_pool_grad = [this](const int n, const int in, const int c,
1485                                       const int k, const int s,
1486                                       const string& padding) -> Costs {
1487     OpContext op_context =
1488         DescribePoolingOp("MaxPoolGrad", {n, in, in, c}, {1, k, k, 1},
1489                           {1, s, s, 1}, "NHWC", padding);
1490     return estimator_.PredictCosts(op_context);
1491   };
1492 
1493   {
1494     // Typical 3x3 window with 2x2 stride.
1495     auto costs = predict_max_pool_grad(10, 20, 384, 3, 2, "SAME");
1496     EXPECT_EQ(Costs::Duration(1996800), costs.execution_time);
1497     EXPECT_EQ(Costs::Duration(614400), costs.compute_time);
1498     EXPECT_EQ(Costs::Duration(1382400), costs.memory_time);
1499     EXPECT_EQ(costs.num_ops_total, 1);
1500     EXPECT_FALSE(costs.inaccurate);
1501     EXPECT_EQ(costs.num_ops_with_unknown_shapes, 0);
1502     EXPECT_EQ(costs.temporary_memory, 0);
1503     EXPECT_EQ(costs.persistent_memory, 0);
1504   }
1505   {
1506     // 1x1 window with 2x2 stride: used for shortcut in resnet-50.
1507     auto costs = predict_max_pool_grad(10, 20, 384, 1, 2, "SAME");
1508     EXPECT_EQ(Costs::Duration(1536000), costs.execution_time);
1509     EXPECT_EQ(Costs::Duration(153600), costs.compute_time);
1510     EXPECT_EQ(Costs::Duration(1382400), costs.memory_time);
1511     EXPECT_EQ(1, costs.num_ops_total);
1512     EXPECT_FALSE(costs.inaccurate);
1513     EXPECT_EQ(0, costs.num_ops_with_unknown_shapes);
1514   }
1515   {
1516     // 2x2 window with 3x3 stride.
1517     auto costs = predict_max_pool_grad(10, 20, 384, 2, 3, "VALID");
1518     EXPECT_EQ(Costs::Duration(1514112), costs.execution_time);
1519     EXPECT_EQ(Costs::Duration(210048), costs.compute_time);
1520     EXPECT_EQ(Costs::Duration(1304064), costs.memory_time);
1521     EXPECT_EQ(1, costs.num_ops_total);
1522     EXPECT_FALSE(costs.inaccurate);
1523     EXPECT_EQ(0, costs.num_ops_with_unknown_shapes);
1524   }
1525 }
1526 
1527 TEST_F(OpLevelCostEstimatorTest, PredictAvgPool) {
1528   auto predict_avg_pool = [this](const int n, const int in, const int c,
1529                                  const int k, const int s,
1530                                  const string& padding) -> Costs {
1531     OpContext op_context = DescribePoolingOp(
1532         "AvgPool", {n, in, in, c}, {1, k, k, 1}, {1, s, s, 1}, "NHWC", padding);
1533     return estimator_.PredictCosts(op_context);
1534   };
1535 
1536   {
1537     // Typical 3x3 window with 2x2 stride.
1538     auto costs = predict_avg_pool(10, 20, 384, 3, 2, "SAME");
1539     EXPECT_EQ(Costs::Duration(1113600), costs.execution_time);
1540     EXPECT_EQ(Costs::Duration(345600), costs.compute_time);
1541     EXPECT_EQ(Costs::Duration(768000), costs.memory_time);
1542     EXPECT_EQ(costs.num_ops_total, 1);
1543     EXPECT_FALSE(costs.inaccurate);
1544     EXPECT_EQ(costs.num_ops_with_unknown_shapes, 0);
1545     EXPECT_EQ(costs.temporary_memory, 0);
1546     EXPECT_EQ(costs.persistent_memory, 0);
1547   }
1548   {
1549     // 1x1 window with 2x2 stride: used for shortcut in resnet-50.
1550     auto costs = predict_avg_pool(10, 20, 384, 1, 2, "SAME");
1551     EXPECT_EQ(Costs::Duration(499200), costs.execution_time);
1552     EXPECT_EQ(Costs::Duration(38400), costs.compute_time);
1553     EXPECT_EQ(Costs::Duration(460800), costs.memory_time);
1554     EXPECT_EQ(1, costs.num_ops_total);
1555     EXPECT_FALSE(costs.inaccurate);
1556     EXPECT_EQ(0, costs.num_ops_with_unknown_shapes);
1557   }
1558   {
1559     // 2x2 window with 3x3 stride.
1560     auto costs = predict_avg_pool(10, 20, 384, 2, 3, "VALID");
1561     EXPECT_EQ(Costs::Duration(580608), costs.execution_time);
1562     EXPECT_EQ(Costs::Duration(75264), costs.compute_time);
1563     EXPECT_EQ(Costs::Duration(505344), costs.memory_time);
1564     EXPECT_EQ(1, costs.num_ops_total);
1565     EXPECT_FALSE(costs.inaccurate);
1566     EXPECT_EQ(0, costs.num_ops_with_unknown_shapes);
1567   }
1568 }
1569 
1570 TEST_F(OpLevelCostEstimatorTest, PredictAvgPoolGrad) {
1571   auto predict_avg_pool_grad = [this](const int n, const int in, const int c,
1572                                       const int k, const int s,
1573                                       const string& padding) -> Costs {
1574     OpContext op_context =
1575         DescribePoolingOp("AvgPoolGrad", {n, in, in, c}, {1, k, k, 1},
1576                           {1, s, s, 1}, "NHWC", padding);
1577     return estimator_.PredictCosts(op_context);
1578   };
1579 
1580   {
1581     // Typical 3x3 window with 2x2 stride.
1582     auto costs = predict_avg_pool_grad(10, 20, 384, 3, 2, "SAME");
1583     EXPECT_EQ(Costs::Duration(1305602), costs.execution_time);
1584     EXPECT_EQ(Costs::Duration(537600), costs.compute_time);
1585     EXPECT_EQ(Costs::Duration(768002), costs.memory_time);
1586     EXPECT_EQ(costs.num_ops_total, 1);
1587     EXPECT_FALSE(costs.inaccurate);
1588     EXPECT_EQ(costs.num_ops_with_unknown_shapes, 0);
1589     EXPECT_EQ(costs.temporary_memory, 0);
1590     EXPECT_EQ(costs.persistent_memory, 0);
1591   }
1592   {
1593     // 1x1 window with 2x2 stride: used for shortcut in resnet-50.
1594     auto costs = predict_avg_pool_grad(10, 20, 384, 1, 2, "SAME");
1595     EXPECT_EQ(Costs::Duration(960002), costs.execution_time);
1596     EXPECT_EQ(Costs::Duration(192000), costs.compute_time);
1597     EXPECT_EQ(Costs::Duration(768002), costs.memory_time);
1598     EXPECT_EQ(1, costs.num_ops_total);
1599     EXPECT_FALSE(costs.inaccurate);
1600     EXPECT_EQ(0, costs.num_ops_with_unknown_shapes);
1601   }
1602   {
1603     // 2x2 window with 3x3 stride.
1604     auto costs = predict_avg_pool_grad(10, 20, 384, 2, 3, "VALID");
1605     EXPECT_EQ(Costs::Duration(862082), costs.execution_time);
1606     EXPECT_EQ(Costs::Duration(172416), costs.compute_time);
1607     EXPECT_EQ(Costs::Duration(689666), costs.memory_time);
1608     EXPECT_EQ(1, costs.num_ops_total);
1609     EXPECT_FALSE(costs.inaccurate);
1610     EXPECT_EQ(0, costs.num_ops_with_unknown_shapes);
1611   }
1612 }
1613 
1614 TEST_F(OpLevelCostEstimatorTest, PredictFusedBatchNorm) {
1615   auto predict_fused_bn = [this](const int n, const int in, const int c,
1616                                  const bool is_training) -> Costs {
1617     OpContext op_context = DescribeFusedBatchNorm(
1618         is_training, /*is_grad=*/false, {n, in, in, c}, "NHWC");
1619     return estimator_.PredictCosts(op_context);
1620   };
1621 
1622   {
1623     auto costs = predict_fused_bn(10, 20, 96, /*is_training=*/true);
1624     EXPECT_EQ(Costs::Duration(614737), costs.execution_time);
1625     EXPECT_EQ(Costs::Duration(153706), costs.compute_time);
1626     EXPECT_EQ(Costs::Duration(461031), costs.memory_time);
1627     EXPECT_EQ(costs.num_ops_total, 1);
1628     EXPECT_FALSE(costs.inaccurate);
1629     EXPECT_EQ(costs.num_ops_with_unknown_shapes, 0);
1630     EXPECT_EQ(costs.temporary_memory, 0);
1631     EXPECT_EQ(costs.persistent_memory, 0);
1632   }
1633 
1634   {
1635     auto costs = predict_fused_bn(10, 20, 32, /*is_training=*/true);
1636     EXPECT_EQ(Costs::Duration(204913), costs.execution_time);
1637     EXPECT_EQ(Costs::Duration(51236), costs.compute_time);
1638     EXPECT_EQ(Costs::Duration(153677), costs.memory_time);
1639     EXPECT_EQ(1, costs.num_ops_total);
1640     EXPECT_FALSE(costs.inaccurate);
1641     EXPECT_EQ(0, costs.num_ops_with_unknown_shapes);
1642   }
1643 
1644   {
1645     auto costs = predict_fused_bn(10, 20, 96, /*is_training=*/false);
1646     EXPECT_EQ(Costs::Duration(384154), costs.execution_time);
1647     EXPECT_EQ(Costs::Duration(76800), costs.compute_time);
1648     EXPECT_EQ(Costs::Duration(307354), costs.memory_time);
1649     EXPECT_EQ(1, costs.num_ops_total);
1650     EXPECT_FALSE(costs.inaccurate);
1651     EXPECT_EQ(0, costs.num_ops_with_unknown_shapes);
1652   }
1653 
1654   {
1655     auto costs = predict_fused_bn(10, 20, 32, /*is_training=*/false);
1656     EXPECT_EQ(Costs::Duration(128052), costs.execution_time);
1657     EXPECT_EQ(Costs::Duration(25600), costs.compute_time);
1658     EXPECT_EQ(Costs::Duration(102452), costs.memory_time);
1659     EXPECT_FALSE(costs.inaccurate);
1660     EXPECT_EQ(1, costs.num_ops_total);
1661     EXPECT_EQ(0, costs.num_ops_with_unknown_shapes);
1662   }
1663 }
1664 
1665 TEST_F(OpLevelCostEstimatorTest, PredictFusedBatchNormGrad) {
1666   auto predict_fused_bn_grad = [this](const int n, const int in,
1667                                       const int c) -> Costs {
1668     OpContext op_context = DescribeFusedBatchNorm(
1669         /*is_training=*/false, /*is_grad=*/true, {n, in, in, c}, "NHWC");
1670     return estimator_.PredictCosts(op_context);
1671   };
1672 
1673   {
1674     auto costs = predict_fused_bn_grad(10, 20, 96);
1675     EXPECT_EQ(Costs::Duration(1037050), costs.execution_time);
1676     EXPECT_EQ(Costs::Duration(422496), costs.compute_time);
1677     EXPECT_EQ(Costs::Duration(614554), costs.memory_time);
1678     EXPECT_EQ(costs.num_ops_total, 1);
1679     EXPECT_FALSE(costs.inaccurate);
1680     EXPECT_EQ(costs.num_ops_with_unknown_shapes, 0);
1681     EXPECT_EQ(costs.temporary_memory, 0);
1682     EXPECT_EQ(costs.persistent_memory, 0);
1683   }
1684 
1685   {
1686     auto costs = predict_fused_bn_grad(128, 7, 384);
1687     EXPECT_EQ(Costs::Duration(6503809), costs.execution_time);
1688     EXPECT_EQ(Costs::Duration(2649677), costs.compute_time);
1689     EXPECT_EQ(Costs::Duration(3854132), costs.memory_time);
1690     EXPECT_EQ(1, costs.num_ops_total);
1691     EXPECT_FALSE(costs.inaccurate);
1692     EXPECT_EQ(0, costs.num_ops_with_unknown_shapes);
1693   }
1694 }
1695 
1696 TEST_F(OpLevelCostEstimatorTest, MaybeGetMinimumShape) {
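  // MaybeGetMinimumShape returns a shape of the requested rank, filling
  // unknown or missing dimensions with 1; the cases below also check when
  // unknown_shapes gets set.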
1697   {
1698     TensorShapeProto x;
1699     x.set_unknown_rank(true);
1700     bool unknown_shapes = false;
1701     TensorShapeProto y = MaybeGetMinimumShape(x, 4, &unknown_shapes);
1702     EXPECT_TRUE(unknown_shapes);
1703     ExpectTensorShape({1, 1, 1, 1}, y);
1704   }
1705 
1706   {
1707     TensorShapeProto x;
1708     x.set_unknown_rank(false);
1709     bool unknown_shapes = false;
1710     TensorShapeProto y = MaybeGetMinimumShape(x, 1, &unknown_shapes);
1711     EXPECT_FALSE(unknown_shapes);
1712     ExpectTensorShape({1}, y);
1713   }
1714 
1715   {
1716     TensorShapeProto x;
1717     x.set_unknown_rank(false);
1718     bool unknown_shapes = false;
1719     TensorShapeProto y = MaybeGetMinimumShape(x, 2, &unknown_shapes);
1720     EXPECT_FALSE(unknown_shapes);
1721     ExpectTensorShape({1, 1}, y);
1722   }
1723 
1724   {
1725     TensorShapeProto x;
1726     x.set_unknown_rank(false);
1727     x.add_dim()->set_size(10);
1728     x.add_dim()->set_size(20);
1729     bool unknown_shapes = false;
1730     TensorShapeProto y = MaybeGetMinimumShape(x, 2, &unknown_shapes);
1731     EXPECT_FALSE(unknown_shapes);
1732     ExpectTensorShape({10, 20}, y);
1733 
1734     unknown_shapes = false;
1735     TensorShapeProto z = MaybeGetMinimumShape(x, 4, &unknown_shapes);
1736     EXPECT_TRUE(unknown_shapes);
1737     EXPECT_EQ(4, z.dim_size());
1738     ExpectTensorShape({10, 20, 1, 1}, z);
1739   }
1740 
1741   {
1742     TensorShapeProto x;
1743     x.set_unknown_rank(false);
1744     x.add_dim()->set_size(10);
1745     x.add_dim()->set_size(20);
1746     x.add_dim()->set_size(-1);
1747     x.add_dim()->set_size(20);
1748     bool unknown_shapes = false;
1749     TensorShapeProto y = MaybeGetMinimumShape(x, 4, &unknown_shapes);
1750     EXPECT_TRUE(unknown_shapes);
1751     ExpectTensorShape({10, 20, 1, 20}, y);
1752   }
1753 
1754   {
1755     TensorShapeProto x;
1756     x.set_unknown_rank(false);
1757     x.add_dim()->set_size(10);
1758     x.add_dim()->set_size(20);
1759     x.add_dim()->set_size(30);
1760     x.add_dim()->set_size(20);
1761     bool unknown_shapes = false;
1762     TensorShapeProto y = MaybeGetMinimumShape(x, 2, &unknown_shapes);
1763     EXPECT_TRUE(unknown_shapes);
1764     ExpectTensorShape({10, 20}, y);
1765   }
1766 }
1767 
1768 TEST_F(OpLevelCostEstimatorTest, IntermediateRdWrBandwidth) {
1769   TestOpLevelCostEstimator estimator;
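  // When compute/memory overlap is enabled, the execution time should collapse
  // to the dominant component (compute, memory, or intermediate memory); with
  // overlap disabled it should be the plain sum of all three.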
1770 
1771   // Compute limited.
1772   estimator.SetDeviceInfo(DeviceInfo(/*gigaops=*/1,
1773                                      /*gb_per_sec=*/1));
1774   estimator.SetComputeMemoryOverlap(true);
1775   auto cost = estimator.PredictCosts(
1776       DescribeConvolution(16, 19, 19, 48, 48, 5, 5, 256));
1777   EXPECT_EQ(Costs::Duration(3548774400), cost.execution_time);
1778   EXPECT_EQ(cost.execution_time, cost.compute_time);
1779 
1780   estimator.SetComputeMemoryOverlap(false);
1781   cost = estimator.PredictCosts(
1782       DescribeConvolution(16, 19, 19, 48, 48, 5, 5, 256));
1783   EXPECT_EQ(Costs::Duration(3551112192), cost.execution_time);
1784   EXPECT_EQ(cost.execution_time, cost.compute_time + cost.memory_time +
1785                                      cost.intermediate_memory_time);
1786 
1787   // Memory limited.
1788   estimator.SetDeviceInfo(DeviceInfo(/*gigaops=*/99999,
1789                                      /*gb_per_sec=*/1));
1790   estimator.SetComputeMemoryOverlap(true);
1791   cost = estimator.PredictCosts(
1792       DescribeConvolution(16, 19, 19, 48, 48, 5, 5, 256));
1793   EXPECT_EQ(Costs::Duration(2337792), cost.execution_time);
1794   EXPECT_EQ(cost.execution_time, cost.memory_time);
1795 
1796   estimator.SetComputeMemoryOverlap(false);
1797   cost = estimator.PredictCosts(
1798       DescribeConvolution(16, 19, 19, 48, 48, 5, 5, 256));
1799   EXPECT_EQ(Costs::Duration(2373281), cost.execution_time);
1800   EXPECT_EQ(cost.execution_time, cost.compute_time + cost.memory_time +
1801                                      cost.intermediate_memory_time);
1802 
1803   // Intermediate memory bandwidth limited.
1804   estimator.SetDeviceInfo(DeviceInfo(/*gigaops=*/99999,
1805                                      /*gb_per_sec=*/9999,
1806                                      /*intermediate_read_gb_per_sec=*/1,
1807                                      /*intermediate_write_gb_per_sec=*/1));
1808   estimator.SetComputeMemoryOverlap(true);
1809   cost = estimator.PredictCosts(
1810       DescribeConvolution(16, 19, 19, 48, 48, 5, 5, 256));
1811   EXPECT_EQ(Costs::Duration(2337792), cost.execution_time);
1812   EXPECT_EQ(cost.execution_time, cost.intermediate_memory_time);
1813 
1814   estimator.SetComputeMemoryOverlap(false);
1815   cost = estimator.PredictCosts(
1816       DescribeConvolution(16, 19, 19, 48, 48, 5, 5, 256));
1817   EXPECT_EQ(Costs::Duration(2373515), cost.execution_time);
1818   EXPECT_EQ(cost.execution_time, cost.compute_time + cost.memory_time +
1819                                      cost.intermediate_memory_time);
1820 }
1821 
1822 TEST_F(OpLevelCostEstimatorTest, Einsum) {
1823   {  // Test a simple matrix multiplication.
1824     auto cost = PredictCosts(DescribeEinsum({100, 50}, {100, 50}, "ik,jk->ij"));
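    // Sanity check of the expected numbers: the matmul produces a 100x100
    // output with a contraction over 50 (x2 for multiply-add), i.e. ~1M ops at
    // the 10 gigaops implied by SetCpuDevice() (10 cores at 1000 MHz); the
    // 4000 ns of memory time matches the two 100x50 float inputs (40 KB) at
    // 10 GB/s.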
1825     EXPECT_EQ(Costs::Duration(104000), cost.execution_time);
1826     EXPECT_EQ(Costs::Duration(100 * 50 * 100 * 2 / (1000 * 10 * 1e-3)),
1827               cost.compute_time);
1828     EXPECT_EQ(Costs::Duration(4000), cost.memory_time);
1829     EXPECT_EQ(cost.num_ops_total, 1);
1830     EXPECT_FALSE(cost.inaccurate);
1831     EXPECT_EQ(cost.num_ops_with_unknown_shapes, 0);
1832     EXPECT_EQ(cost.temporary_memory, 0);
1833     EXPECT_EQ(cost.persistent_memory, 0);
1834 
1835     // Einsums and XlaEinsums should be estimated similarly.
1836     EXPECT_EQ(PredictCosts(DescribeEinsum({100, 50}, {100, 50}, "ik,jk->ij"))
1837                   .execution_time,
1838               PredictCosts(DescribeXlaEinsum({100, 50}, {100, 50}, "ik,jk->ij"))
1839                   .execution_time);
1840   }
1841   {  // Test a simple batch matrix multiplication.
1842     auto cost = PredictCosts(
1843         DescribeEinsum({25, 100, 50}, {100, 50, 25}, "Bik,jkB->Bij"));
1844     EXPECT_EQ(Costs::Duration(25 * 104000), cost.execution_time);
1845     EXPECT_EQ(Costs::Duration(25 * 100 * 50 * 100 * 2 / (1000 * 10 * 1e-3)),
1846               cost.compute_time);
1847     EXPECT_EQ(Costs::Duration(25 * 4000), cost.memory_time);
1848     EXPECT_EQ(1, cost.num_ops_total);
1849     EXPECT_FALSE(cost.inaccurate);
1850     EXPECT_EQ(0, cost.num_ops_with_unknown_shapes);
1851 
1852     // Einsums and XlaEinsums should be estimated similarly.
1853     EXPECT_EQ(PredictCosts(
1854                   DescribeEinsum({25, 100, 50}, {100, 50, 25}, "Bik,jkB->Bij"))
1855                   .execution_time,
1856               PredictCosts(DescribeXlaEinsum({25, 100, 50}, {100, 50, 25},
1857                                              "Bik,jkB->Bij"))
1858                   .execution_time);
1859   }
1860   {  // Test multiple batch dimensions.
1861     auto cost = PredictCosts(DescribeEinsum(
1862         {25, 16, 100, 50}, {16, 100, 50, 25}, "BNik,NjkB->BNij"));
1863     EXPECT_EQ(Costs::Duration(16 * 25 * 104000), cost.execution_time);
1864     EXPECT_EQ(
1865         Costs::Duration(16 * 25 * 100 * 50 * 100 * 2 / (1000 * 10 * 1e-3)),
1866         cost.compute_time);
1867     EXPECT_EQ(Costs::Duration(16 * 25 * 4000), cost.memory_time);
1868     EXPECT_EQ(1, cost.num_ops_total);
1869     EXPECT_FALSE(cost.inaccurate);
1870     EXPECT_EQ(0, cost.num_ops_with_unknown_shapes);
1871 
1872     // Einsums and XlaEinsums should be estimated similarly.
1873     EXPECT_EQ(
1874         PredictCosts(DescribeEinsum({25, 16, 100, 50}, {16, 100, 50, 25},
1875                                     "BNik,NjkB->BNij"))
1876             .execution_time,
1877         PredictCosts(DescribeXlaEinsum({25, 16, 100, 50}, {16, 100, 50, 25},
1878                                        "BNik,NjkB->BNij"))
1879             .execution_time);
1880   }
1881   {  // Test multiple M dimensions.
1882     auto cost =
1883         PredictCosts(DescribeEinsum({25, 100, 50}, {100, 50}, "Aik,jk->Aij"));
1884     EXPECT_EQ(Costs::Duration(2552000), cost.execution_time);
1885     EXPECT_EQ(Costs::Duration(25 * 100 * 50 * 100 * 2 / (1000 * 10 * 1e-3)),
1886               cost.compute_time);
1887     EXPECT_EQ(Costs::Duration(52000), cost.memory_time);
1888     EXPECT_EQ(1, cost.num_ops_total);
1889     EXPECT_FALSE(cost.inaccurate);
1890     EXPECT_EQ(0, cost.num_ops_with_unknown_shapes);
1891 
1892     // Einsums and XlaEinsums should be estimated similarly.
1893     EXPECT_EQ(
1894         PredictCosts(DescribeEinsum({25, 100, 50}, {100, 50}, "Aik,jk->Aij"))
1895             .execution_time,
1896         PredictCosts(DescribeXlaEinsum({25, 100, 50}, {100, 50}, "Aik,jk->Aij"))
1897             .execution_time);
1898   }
1899   {  // Test multiple N dimensions.
1900     auto cost =
1901         PredictCosts(DescribeEinsum({100, 50}, {25, 100, 50}, "ik,Bjk->ijB"));
1902     EXPECT_EQ(Costs::Duration(2552000), cost.execution_time);
1903     EXPECT_EQ(Costs::Duration(25 * 100 * 50 * 100 * 2 / (1000 * 10 * 1e-3)),
1904               cost.compute_time);
1905     EXPECT_EQ(Costs::Duration(52000), cost.memory_time);
1906     EXPECT_EQ(1, cost.num_ops_total);
1907     EXPECT_FALSE(cost.inaccurate);
1908     EXPECT_EQ(0, cost.num_ops_with_unknown_shapes);
1909 
1910     // Einsums and XlaEinsums should be estimated similarly.
1911     EXPECT_EQ(
1912         PredictCosts(DescribeEinsum({100, 50}, {25, 100, 50}, "ik,Bjk->ijB"))
1913             .execution_time,
1914         PredictCosts(DescribeXlaEinsum({100, 50}, {25, 100, 50}, "ik,Bjk->ijB"))
1915             .execution_time);
1916   }
1917   {  // Test multiple contracting dimensions.
1918     auto cost = PredictCosts(
1919         DescribeEinsum({100, 50, 25}, {100, 50, 25}, "ikl,jkl->ij"));
1920     EXPECT_EQ(Costs::Duration(2600000), cost.execution_time);
1921     EXPECT_EQ(Costs::Duration(100 * 50 * 25 * 100 * 2 / (1000 * 10 * 1e-3)),
1922               cost.compute_time);
1923     EXPECT_EQ(Costs::Duration(100000), cost.memory_time);
1924     EXPECT_EQ(1, cost.num_ops_total);
1925     EXPECT_FALSE(cost.inaccurate);
1926     EXPECT_EQ(0, cost.num_ops_with_unknown_shapes);
1927 
1928     // Einsums and XlaEinsums should be estimated similarly.
1929     EXPECT_EQ(PredictCosts(
1930                   DescribeEinsum({100, 50, 25}, {100, 50, 25}, "ikl,jkl->ij"))
1931                   .execution_time,
1932               PredictCosts(DescribeXlaEinsum({100, 50, 25}, {100, 50, 25},
1933                                              "ikl,jkl->ij"))
1934                   .execution_time);
1935   }
1936   {  // Test a simple matrix transpose.
1937     auto cost = PredictCosts(DescribeEinsum({100, 50}, {}, "ij->ji"));
1938     EXPECT_EQ(Costs::Duration(2000), cost.execution_time);
1939     EXPECT_EQ(Costs::Duration(0), cost.compute_time);
1940     EXPECT_EQ(Costs::Duration(2000), cost.memory_time);
1941     EXPECT_EQ(1, cost.num_ops_total);
1942     EXPECT_TRUE(cost.inaccurate);
1943     EXPECT_EQ(0, cost.num_ops_with_unknown_shapes);
1944 
1945     // Einsums and XlaEinsums should be estimated similarly.
1946     EXPECT_EQ(
1947         PredictCosts(DescribeEinsum({100, 50}, {}, "ij->ji")).execution_time,
1948         PredictCosts(DescribeXlaEinsum({100, 50}, {}, "ij->ji"))
1949             .execution_time);
1950   }
1951   {  // Test a malformed Einsum equation: Mismatch between shapes and equation.
1952     auto cost =
1953         PredictCosts(DescribeEinsum({100, 50, 25}, {50, 100}, "ik,kl->il"));
1954     EXPECT_EQ(Costs::Duration(52000), cost.execution_time);
1955     EXPECT_EQ(Costs::Duration(0), cost.compute_time);
1956     EXPECT_EQ(Costs::Duration(52000), cost.memory_time);
1957     EXPECT_EQ(1, cost.num_ops_total);
1958     EXPECT_TRUE(cost.inaccurate);
1959     EXPECT_EQ(0, cost.num_ops_with_unknown_shapes);
1960 
1961     // Einsums and XlaEinsums should be estimated similarly.
1962     EXPECT_EQ(
1963         PredictCosts(DescribeEinsum({100, 50, 25}, {50, 100}, "ik,kl->il"))
1964             .execution_time,
1965         PredictCosts(DescribeXlaEinsum({100, 50, 25}, {50, 100}, "ik,kl->il"))
1966             .execution_time);
1967 
1968     cost = PredictCosts(DescribeEinsum({100, 50}, {50, 100, 25}, "ik,kl->il"));
1969     EXPECT_EQ(Costs::Duration(52000), cost.execution_time);
1970     EXPECT_EQ(Costs::Duration(0), cost.compute_time);
1971     EXPECT_EQ(Costs::Duration(52000), cost.memory_time);
1972     EXPECT_EQ(1, cost.num_ops_total);
1973     EXPECT_TRUE(cost.inaccurate);
1974     EXPECT_EQ(0, cost.num_ops_with_unknown_shapes);
1975 
1976     // Einsums and XlaEinsums should be estimated similarly.
1977     EXPECT_EQ(
1978         PredictCosts(DescribeEinsum({100, 50}, {50, 100, 25}, "ik,kl->il"))
1979             .execution_time,
1980         PredictCosts(DescribeXlaEinsum({100, 50}, {50, 100, 25}, "ik,kl->il"))
1981             .execution_time);
1982   }
1983   {  // Test an unsupported Einsum: ellipsis
1984     auto cost = PredictCosts(DescribeEinsum(
1985         {100, 50, 25, 16}, {50, 100, 32, 12}, "ik...,kl...->il..."));
1986     EXPECT_EQ(Costs::Duration(1568000), cost.execution_time);
1987     EXPECT_EQ(Costs::Duration(0), cost.compute_time);
1988     EXPECT_EQ(Costs::Duration(1568000), cost.memory_time);
1989     EXPECT_EQ(1, cost.num_ops_total);
1990     EXPECT_TRUE(cost.inaccurate);
1991     EXPECT_EQ(0, cost.num_ops_with_unknown_shapes);
1992 
1993     // Einsums and XlaEinsums should be estimated similarly.
1994     EXPECT_EQ(
1995         PredictCosts(DescribeEinsum({100, 50, 25, 16}, {50, 100, 32, 12},
1996                                     "ik...,kl...->il..."))
1997             .execution_time,
1998         PredictCosts(DescribeXlaEinsum({100, 50, 25, 16}, {50, 100, 32, 12},
1999                                        "ik...,kl...->il..."))
2000             .execution_time);
2001   }
2002   {  // Test a malformed/unsupported Einsum: repeated indices
2003     auto cost =
2004         PredictCosts(DescribeEinsum({100, 100, 50}, {50, 100}, "iik,kl->il"));
2005     EXPECT_EQ(Costs::Duration(202000), cost.execution_time);
2006     EXPECT_EQ(Costs::Duration(0), cost.compute_time);
2007     EXPECT_EQ(Costs::Duration(202000), cost.memory_time);
2008     EXPECT_EQ(1, cost.num_ops_total);
2009     EXPECT_TRUE(cost.inaccurate);
2010     EXPECT_EQ(0, cost.num_ops_with_unknown_shapes);
2011 
2012     // Einsums and XlaEinsums should be estimated similarly.
2013     EXPECT_EQ(
2014         PredictCosts(DescribeEinsum({100, 100, 50}, {50, 100}, "iik,kl->il"))
2015             .execution_time,
2016         PredictCosts(DescribeXlaEinsum({100, 100, 50}, {50, 100}, "iik,kl->il"))
2017             .execution_time);
2018   }
2019   {  // Test missing shapes.
2020     auto cost = PredictCosts(DescribeEinsum({-1, 50}, {100, 50}, "ik,jk->ij"));
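    // The unknown (-1) dimension is effectively clamped to 1, so the expected
    // op count below uses 1 in place of the missing row dimension, and the
    // estimate is flagged inaccurate.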
2021     EXPECT_EQ(Costs::Duration(3020), cost.execution_time);
2022     EXPECT_EQ(Costs::Duration(1 * 50 * 100 * 2 / (1000 * 10 * 1e-3)),
2023               cost.compute_time);
2024     EXPECT_EQ(Costs::Duration(2020), cost.memory_time);
2025     EXPECT_EQ(1, cost.num_ops_total);
2026     EXPECT_TRUE(cost.inaccurate);
2027     EXPECT_EQ(1, cost.num_ops_with_unknown_shapes);
2028 
2029     // Einsums and XlaEinsums should be estimated similarly.
2030     EXPECT_EQ(PredictCosts(DescribeEinsum({-1, 50}, {100, 50}, "ik,jk->ij"))
2031                   .execution_time,
2032               PredictCosts(DescribeXlaEinsum({-1, 50}, {100, 50}, "ik,jk->ij"))
2033                   .execution_time);
2034   }
2035 }
2036 
2037 TEST_F(OpLevelCostEstimatorTest, PredictResourceVariableOps) {
2038   TestOpLevelCostEstimator estimator;
2039   estimator.SetDeviceInfo(DeviceInfo(/*gigaops=*/1, /*gb_per_sec=*/1));
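  // With 1 GB/s of bandwidth, moving the 100-element float tensor (400 bytes)
  // costs 400 ns. AssignVariableOp is treated as pure data movement, while
  // AssignSubVariableOp also charges roughly one op per element (100 ns at
  // 1 gigaop), which this test estimator overlaps with the memory time.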
2040 
2041   {
2042     OpContext op_context;
2043     op_context.op_info.set_op("AssignVariableOp");
2044     DescribeDummyTensor(op_context.op_info.add_inputs());
2045     DescribeTensor1D(100, op_context.op_info.add_inputs());
2046     auto cost = estimator.PredictCosts(op_context);
2047     EXPECT_EQ(Costs::Duration(400), cost.memory_time);
2048     EXPECT_EQ(Costs::Duration(0), cost.compute_time);
2049     EXPECT_EQ(Costs::Duration(400), cost.execution_time);
2050     EXPECT_FALSE(cost.inaccurate);
2051     EXPECT_EQ(cost.temporary_memory, 0);
2052     EXPECT_EQ(cost.persistent_memory, 0);
2053   }
2054 
2055   {
2056     OpContext op_context;
2057     op_context.op_info.set_op("AssignSubVariableOp");
2058     DescribeDummyTensor(op_context.op_info.add_inputs());
2059     DescribeTensor1D(100, op_context.op_info.add_inputs());
2060     auto cost = estimator.PredictCosts(op_context);
2061     EXPECT_EQ(Costs::Duration(400), cost.memory_time);
2062     EXPECT_EQ(Costs::Duration(100), cost.compute_time);
2063     EXPECT_EQ(Costs::Duration(400), cost.execution_time);
2064     EXPECT_FALSE(cost.inaccurate);
2065   }
2066 }
2067 
2068 TEST_F(OpLevelCostEstimatorTest, AddNExecutionTime) {
2069   OpContext op_context;
2070   SetCpuDevice(&op_context.op_info);
2071   op_context.op_info.set_op("AddN");
2072 
2073   DescribeTensor4D(1, 10, 10, 10, op_context.op_info.add_inputs());
2074   DescribeTensor4D(1, 10, 10, 10, op_context.op_info.add_inputs());
2075   DescribeTensor4D(1, 10, 10, 10, op_context.op_info.add_inputs());
2076 
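  // Sanity check (hedged reading of the expected values): reading the three
  // 1000-element float inputs (~12 KB) at 10 GB/s gives 1200 ns of memory
  // time; roughly two additions per output element (2000 ops at 10 gigaops)
  // gives 200 ns of compute time.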
2077   auto cost = PredictCosts(op_context);
2078   EXPECT_EQ(Costs::Duration(1200), cost.memory_time);
2079   EXPECT_EQ(Costs::Duration(200), cost.compute_time);
2080   EXPECT_EQ(Costs::Duration(1400), cost.execution_time);
2081   EXPECT_EQ(cost.num_ops_total, 1);
2082   EXPECT_FALSE(cost.inaccurate);
2083   EXPECT_EQ(cost.num_ops_with_unknown_shapes, 0);
2084   EXPECT_EQ(cost.temporary_memory, 0);
2085   EXPECT_EQ(cost.persistent_memory, 0);
2086 }
2087 
2088 TEST_F(OpLevelCostEstimatorTest, IdentityOpExecutionTime) {
2089   std::vector<std::string> identity_ops = {
2090       "_Recv",         "_Send",        "BitCast",         "Identity",
2091       "Enter",         "Exit",         "IdentityN",       "Merge",
2092       "NextIteration", "Placeholder",  "PreventGradient", "RefIdentity",
2093       "Reshape",       "StopGradient", "Switch"};
2094 
2095   const int kTensorSize = 1000;
2096   for (auto identity_op : identity_ops) {
2097     OpContext op_context = DescribeUnaryOp(identity_op, kTensorSize);
2098 
2099     const int kExpectedMemoryTime = 0;
2100     const int kExpectedComputeTime = 1;
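    // Identity-like ops are modeled as essentially free: no bytes are charged
    // to memory time and only a nominal 1 ns of compute, although max_memory
    // below still reflects the 4-byte-per-element output.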
2101 
2102     auto cost = PredictCosts(op_context);
2103     EXPECT_EQ(Costs::Duration(kExpectedMemoryTime), cost.memory_time);
2104     EXPECT_EQ(Costs::Duration(kExpectedComputeTime), cost.compute_time);
2105     EXPECT_EQ(Costs::Duration(kExpectedComputeTime + kExpectedMemoryTime),
2106               cost.execution_time);
2107     EXPECT_EQ(cost.max_memory, kTensorSize * 4);
2108     EXPECT_EQ(cost.num_ops_total, 1);
2109     EXPECT_FALSE(cost.inaccurate);
2110     EXPECT_EQ(cost.num_ops_with_unknown_shapes, 0);
2111     EXPECT_EQ(cost.temporary_memory, 0);
2112     EXPECT_EQ(cost.persistent_memory, 0);
2113   }
2114 }
2115 
2116 TEST_F(OpLevelCostEstimatorTest, PureMemoryOpExecutionTime) {
2117   std::vector<std::string> reshape_ops = {
2118       "ConcatV2",     "DataFormatVecPermute",
2119       "DepthToSpace", "ExpandDims",
2120       "Fill",         "OneHot",
2121       "Pack",         "Range",
2122       "SpaceToDepth", "Split",
2123       "Squeeze",      "Transpose",
2124       "Tile",         "Unpack"};
2125 
2126   const int kTensorSize = 1000;
2127   for (auto reshape_op : reshape_ops) {
2128     OpContext op_context = DescribeUnaryOp(reshape_op, kTensorSize);
2129 
2130     const int kExpectedMemoryTime = 800;
2131     const int kExpectedComputeTime = 0;
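    // 800 ns corresponds to ~8 KB moved at 10 GB/s, i.e. the 1000-element
    // float input plus an equally sized output, with no compute charged.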
2132 
2133     auto cost = PredictCosts(op_context);
2134     EXPECT_EQ(Costs::Duration(kExpectedMemoryTime), cost.memory_time);
2135     EXPECT_EQ(Costs::Duration(kExpectedComputeTime), cost.compute_time);
2136     EXPECT_EQ(Costs::Duration(kExpectedComputeTime + kExpectedMemoryTime),
2137               cost.execution_time);
2138     EXPECT_EQ(cost.max_memory, kTensorSize * 4);
2139     EXPECT_EQ(cost.num_ops_total, 1);
2140     EXPECT_FALSE(cost.inaccurate);
2141     EXPECT_EQ(cost.num_ops_with_unknown_shapes, 0);
2142     EXPECT_EQ(cost.temporary_memory, 0);
2143     EXPECT_EQ(cost.persistent_memory, 0);
2144   }
2145 }
2146 
2147 TEST_F(OpLevelCostEstimatorTest, ResizeBilinearExecutionTime) {
2148   const int kImageDim = 255;
2149   const int kChannelSize = 10;
2150   const int kComputeLerpCost = 9;
2151   {
2152     OpContext op_context;
2153     SetCpuDevice(&op_context.op_info);
2154     op_context.op_info.set_op("ResizeBilinear");
2155     DescribeTensor4D(1, kImageDim, kImageDim, kChannelSize,
2156                      op_context.op_info.add_inputs());
2157     // Test with no output.
2158     auto cost = PredictCosts(op_context);
2159     ExpectZeroCost(cost);
2160     op_context.op_info.clear_inputs();
2161 
2162     DescribeTensor4D(0, 0, 0, 0, op_context.op_info.add_outputs());
2163     // Test with no input.
2164     cost = PredictCosts(op_context);
2165     ExpectZeroCost(cost);
2166   }
2167   {
2168     // Test with size 0 output.
2169     OpContext op_context;
2170     SetCpuDevice(&op_context.op_info);
2171     op_context.op_info.set_op("ResizeBilinear");
2172 
2173     DescribeTensor4D(1, kImageDim, kImageDim, kChannelSize,
2174                      op_context.op_info.add_inputs());
2175     const int kExpectedMemoryTime = kImageDim * kImageDim * 4;
2176     DescribeTensor4D(0, 0, 0, 0, op_context.op_info.add_outputs());
2177 
2178     // As the half_pixel_centers attr was not set, cost should be inaccurate
2179     // with 0 compute time.
2180     auto cost = PredictCosts(op_context);
2181     EXPECT_EQ(cost.compute_time, Costs::Duration(0));
2182     EXPECT_EQ(cost.memory_time, Costs::Duration(kExpectedMemoryTime));
2183     EXPECT_EQ(cost.execution_time, Costs::Duration(kExpectedMemoryTime));
2184     EXPECT_TRUE(cost.inaccurate);
2185     EXPECT_EQ(cost.num_ops_with_unknown_shapes, 0);
2186     EXPECT_EQ(cost.temporary_memory, 0);
2187     EXPECT_EQ(cost.persistent_memory, 0);
2188 
2189     AttrValue half_pixel_centers;
2190     half_pixel_centers.set_b(false);
2191     (*op_context.op_info.mutable_attr())["half_pixel_centers"] =
2192         half_pixel_centers;
2193     cost = PredictCosts(op_context);
2194     // Compute time depends only on output size, so compute time is 0.
2195     EXPECT_EQ(cost.compute_time, Costs::Duration(0));
2196     EXPECT_EQ(cost.memory_time, Costs::Duration(kExpectedMemoryTime));
2197     EXPECT_EQ(cost.execution_time, Costs::Duration(kExpectedMemoryTime));
2198     EXPECT_FALSE(cost.inaccurate);
2199     EXPECT_EQ(cost.num_ops_with_unknown_shapes, 0);
2200   }
2201 
2202   // Test with non-zero output size.
2203   const int kOutputImageDim = 100;
2204   OpContext op_context;
2205   SetCpuDevice(&op_context.op_info);
2206   op_context.op_info.set_op("ResizeBilinear");
2207   DescribeTensor4D(1, kImageDim, kImageDim, kChannelSize,
2208                    op_context.op_info.add_inputs());
2209   DescribeTensor4D(1, kOutputImageDim, kOutputImageDim, kChannelSize,
2210                    op_context.op_info.add_outputs());
2211   const int kExpectedMemoryTime =
2212       (kImageDim * kImageDim + kOutputImageDim * kOutputImageDim) * 4;
2213 
2214   {
2215     // Cost of calculating weights without using half_pixel_centers.
2216     AttrValue half_pixel_centers;
2217     half_pixel_centers.set_b(false);
2218     (*op_context.op_info.mutable_attr())["half_pixel_centers"] =
2219         half_pixel_centers;
2220     const int kInterpWeightCost = 10;
2221     const int num_ops =
2222         kInterpWeightCost * (kOutputImageDim * 2) +
2223         kComputeLerpCost * (kOutputImageDim * kOutputImageDim * kChannelSize);
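    // Interpolation weights are computed once per output row and once per
    // output column (hence kOutputImageDim * 2); the lerp cost then applies
    // to every output element across all channels.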
2224     const int expected_compute_time = std::ceil(
2225         num_ops /
2226         estimator_.GetDeviceInfo(op_context.op_info.device()).gigaops);
2227 
2228     const auto cost = PredictCosts(op_context);
2229     EXPECT_EQ(cost.compute_time, Costs::Duration(expected_compute_time));
2230     EXPECT_EQ(cost.memory_time, Costs::Duration(kExpectedMemoryTime));
2231     EXPECT_EQ(cost.execution_time,
2232               Costs::Duration(kExpectedMemoryTime + expected_compute_time));
2233     EXPECT_FALSE(cost.inaccurate);
2234     EXPECT_EQ(cost.num_ops_with_unknown_shapes, 0);
2235   }
2236 
2237   {
2238     // Cost of calculating weights using half_pixel_centers.
2239     AttrValue half_pixel_centers;
2240     half_pixel_centers.set_b(true);
2241     (*op_context.op_info.mutable_attr())["half_pixel_centers"] =
2242         half_pixel_centers;
2243     const int kInterpWeightCost = 12;
2244     const int num_ops =
2245         kInterpWeightCost * (kOutputImageDim * 2) +
2246         kComputeLerpCost * (kOutputImageDim * kOutputImageDim * kChannelSize);
2247     const int expected_compute_time = std::ceil(
2248         num_ops /
2249         estimator_.GetDeviceInfo(op_context.op_info.device()).gigaops);
2250 
2251     const auto cost = PredictCosts(op_context);
2252     EXPECT_EQ(cost.compute_time, Costs::Duration(expected_compute_time));
2253     EXPECT_EQ(cost.memory_time, Costs::Duration(kExpectedMemoryTime));
2254     EXPECT_EQ(cost.execution_time,
2255               Costs::Duration(kExpectedMemoryTime + expected_compute_time));
2256     EXPECT_FALSE(cost.inaccurate);
2257     EXPECT_EQ(cost.num_ops_with_unknown_shapes, 0);
2258   }
2259 
2260   {
2261     // Cost with very large tensor.
2262     op_context.op_info.clear_outputs();
2263     // Number of elements in tensor exceeds 2^32.
2264     constexpr int64_t kLargeOutputImageDim = 40000;
2265     DescribeTensor4D(1, kLargeOutputImageDim, kLargeOutputImageDim,
2266                      kChannelSize, op_context.op_info.add_outputs());
2267     const int64_t kInterpWeightCost = 12;
2268     // Using half_pixel_centers.
2269     AttrValue half_pixel_centers;
2270     half_pixel_centers.set_b(true);
2271     (*op_context.op_info.mutable_attr())["half_pixel_centers"] =
2272         half_pixel_centers;
2273 
2274     const int64_t num_ops =
2275         kInterpWeightCost * (kLargeOutputImageDim * 2) +
2276         kComputeLerpCost *
2277             (kLargeOutputImageDim * kLargeOutputImageDim * kChannelSize);
2278     const int64_t expected_compute_time = std::ceil(
2279         num_ops /
2280         estimator_.GetDeviceInfo(op_context.op_info.device()).gigaops);
2281 
2282     const int64_t expected_memory_time =
2283         (kImageDim * kImageDim + kLargeOutputImageDim * kLargeOutputImageDim) *
2284         4;
2285 
2286     const auto cost = PredictCosts(op_context);
2287     EXPECT_EQ(cost.compute_time, Costs::Duration(expected_compute_time));
2288     EXPECT_EQ(cost.memory_time, Costs::Duration(expected_memory_time));
2289     EXPECT_EQ(cost.execution_time,
2290               Costs::Duration(expected_memory_time + expected_compute_time));
2291     EXPECT_FALSE(cost.inaccurate);
2292     EXPECT_EQ(cost.num_ops_with_unknown_shapes, 0);
2293   }
2294 }
2295 
2296 TEST_F(OpLevelCostEstimatorTest, CropAndResizeExecutionTime) {
2297   const int kImageDim = 255;
2298   const int kChannelSize = 10;
2299   const int kOutputImageDim = 100;
2300   const int kNumBoxes = 10;
2301   const int kOutputElements =
2302       kNumBoxes * kOutputImageDim * kOutputImageDim * kChannelSize;
2303   OpContext op_context;
2304   SetCpuDevice(&op_context.op_info);
2305   op_context.op_info.set_op("CropAndResize");
2306   DescribeTensor4D(1, kImageDim, kImageDim, kChannelSize,
2307                    op_context.op_info.add_inputs());
2308   DescribeArbitraryRankInput({kNumBoxes, 4}, DT_INT64, &op_context.op_info);
2309   DescribeTensor4D(kNumBoxes, kOutputImageDim, kOutputImageDim, kChannelSize,
2310                    op_context.op_info.add_outputs());
2311 
2312   // Note these values are time [ns, the default unit of Costs::Duration],
2313   // not bytes; the memory bandwidth from SetCpuDevice() is 10 GB/s.
2314   const int kExpectedMemoryTime =
2315       (kImageDim * kImageDim * 4 +  // input image in float.
2316        kNumBoxes * 4 * 8 / 10 +     // boxes (kNumBoxes x 4) in int64.
2317        kNumBoxes * kOutputImageDim * kOutputImageDim * 4);  // output in float.
2318   // Note that the input and output images have kChannelSize (= 10) channels,
2319   // which cancels against the 10 GB/s bandwidth, so those terms omit both factors.
2320 
2321   {
2322     // Cost of CropAndResize with bilinear interpolation.
2323     AttrValue method;
2324     method.set_s("bilinear");
2325     (*op_context.op_info.mutable_attr())["method"] = method;
2326     int num_ops = 28 * kNumBoxes + 4 * kNumBoxes * kOutputImageDim +
2327                   4 * kNumBoxes * kOutputImageDim * kOutputImageDim +
2328                   3 * kNumBoxes * kOutputImageDim +
2329                   3 * kNumBoxes * kOutputImageDim * kOutputImageDim +
2330                   13 * kOutputElements;
2331     const int expected_compute_time = std::ceil(
2332         num_ops /
2333         estimator_.GetDeviceInfo(op_context.op_info.device()).gigaops);
2334 
2335     const auto cost = PredictCosts(op_context);
2336     EXPECT_EQ(cost.compute_time, Costs::Duration(expected_compute_time));
2337     EXPECT_EQ(cost.memory_time, Costs::Duration(kExpectedMemoryTime));
2338     EXPECT_EQ(cost.execution_time,
2339               Costs::Duration(kExpectedMemoryTime + expected_compute_time));
2340     EXPECT_FALSE(cost.inaccurate);
2341     EXPECT_EQ(cost.num_ops_with_unknown_shapes, 0);
2342   }
2343 
2344   {
2345     // Cost of CropAndResize when nearest pixel is taken.
2346     AttrValue method;
2347     method.set_s("nearest");
2348     (*op_context.op_info.mutable_attr())["method"] = method;
2349     int num_ops = 28 * kNumBoxes + 4 * kNumBoxes * kOutputImageDim +
2350                   4 * kNumBoxes * kOutputImageDim * kOutputImageDim +
2351                   2 * kNumBoxes * kOutputImageDim * kOutputImageDim +
2352                   kOutputElements;
2353     const int expected_compute_time = std::ceil(
2354         num_ops /
2355         estimator_.GetDeviceInfo(op_context.op_info.device()).gigaops);
2356 
2357     const auto cost = PredictCosts(op_context);
2358     EXPECT_EQ(cost.compute_time, Costs::Duration(expected_compute_time));
2359     EXPECT_EQ(cost.memory_time, Costs::Duration(kExpectedMemoryTime));
2360     EXPECT_EQ(cost.execution_time,
2361               Costs::Duration(kExpectedMemoryTime + expected_compute_time));
2362     EXPECT_FALSE(cost.inaccurate);
2363     EXPECT_EQ(cost.num_ops_with_unknown_shapes, 0);
2364   }
2365 }
2366 
2367 }  // end namespace grappler
2368 }  // end namespace tensorflow
2369