1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include <algorithm>
16 #include <cmath>
17 #include <iterator>
18 #include <memory>
19 #include <numeric>
20 #include <string>
21 #include <unordered_map>
22 #include <vector>
23
24 #include "absl/strings/str_join.h"
25 #include "tensorflow/core/platform/logging.h"
26 #include "tensorflow/lite/kernels/internal/strided_slice_logic.h"
27 #include "tensorflow/lite/toco/graph_transformations/graph_transformations.h"
28 #include "tensorflow/lite/toco/model.h"
29 #include "tensorflow/lite/toco/tooling_util.h"
30
31 namespace toco {
32
33 namespace {
34
ComputeConvSizes(const Shape & input_shape,int output_depth,int kwidth,int kheight,int stride_width,int stride_height,int dilation_width_factor,int dilation_height_factor,PaddingType padding_type,Shape * output_shape,FixedPadding * fixed_padding)35 void ComputeConvSizes(const Shape& input_shape, int output_depth, int kwidth,
36 int kheight, int stride_width, int stride_height,
37 int dilation_width_factor, int dilation_height_factor,
38 PaddingType padding_type, Shape* output_shape,
39 FixedPadding* fixed_padding) {
40 const int input_width = input_shape.dims(2);
41 const int input_height = input_shape.dims(1);
42 const int batch = input_shape.dims(0);
43
44 CHECK_GE(input_width, 1);
45 CHECK_GE(input_height, 1);
46 CHECK_GE(batch, 1);
47 CHECK_GE(kwidth, 1);
48 CHECK_GE(kheight, 1);
49 CHECK_GE(stride_width, 1);
50 CHECK_GE(stride_height, 1);
51 CHECK_GE(dilation_width_factor, 1);
52 CHECK_GE(dilation_height_factor, 1);
53
54 int dilated_kwidth = dilation_width_factor * (kwidth - 1) + 1;
55 int dilated_kheight = dilation_height_factor * (kheight - 1) + 1;
56
57 int output_height = 0;
58 int output_width = 0;
59 if (padding_type == PaddingType::kValid) {
60 output_height =
61 (input_height + stride_height - dilated_kheight) / stride_height;
62 output_width = (input_width + stride_width - dilated_kwidth) / stride_width;
63 } else if (padding_type == PaddingType::kSame) {
64 output_height = (input_height + stride_height - 1) / stride_height;
65 output_width = (input_width + stride_width - 1) / stride_width;
66 } else {
67 LOG(FATAL) << "Only supporting SAME or VALID padding";
68 }
69
70 fixed_padding->height = std::max(0, ((output_height - 1) * stride_height +
71 dilated_kheight - input_height) /
72 2);
73 fixed_padding->width = std::max(
74 0,
75 ((output_width - 1) * stride_width + dilated_kwidth - input_width) / 2);
76
77 // Actually had to debug a situation where those were negative due to bad
78 // propagation of placeholder -1 sizes in TensorFlowReshape.
79 CHECK_GT(output_width, 0);
80 CHECK_GT(output_height, 0);
81 output_shape->ReplaceDims({batch, output_height, output_width, output_depth});
82 }
83
ComputeBinaryOperatorOutputSize(const Shape & input_shape_x,const Shape & input_shape_y,Array * output_array)84 void ComputeBinaryOperatorOutputSize(const Shape& input_shape_x,
85 const Shape& input_shape_y,
86 Array* output_array) {
87 // This matches the code in BroadcastBinaryOpShapeFn from tensorflow.
88 // It zips together the two input shapes and pads with 1 to make them the
89 // same length. For each dimension we broadcast if either dimension is 1 and
90 // otherwise expect them to match.
91 int rank_x = input_shape_x.dimensions_count();
92 int rank_y = input_shape_y.dimensions_count();
93 int rank_out = std::max(rank_x, rank_y);
94 std::vector<int>* dims_out = output_array->mutable_shape()->mutable_dims();
95 dims_out->clear();
96 dims_out->reserve(rank_out);
97 for (int i = 0; i < rank_out; ++i) {
98 int dim_x = i < (rank_out - rank_x)
99 ? 1
100 : input_shape_x.dims(i - (rank_out - rank_x));
101 bool dim_y_is_one = i < (rank_out - rank_y);
102 int dim_y = dim_y_is_one ? 1 : input_shape_y.dims(i - (rank_out - rank_y));
103 if (dim_x == -1 || dim_y == -1) {
104 // One or both dimensions is unknown.
105 QCHECK(false) << "Shapes must be specified";
106 } else if (dim_x == 1 || dim_y == 1) {
107 // Broadcast one dimension to the other that is 1.
108 if (dim_x == 1 && !dim_y_is_one) {
109 // Broadcast dim_y to dim_x (1).
110 dims_out->push_back(dim_y);
111 } else {
112 // Broadcast dim_x to dim_y (1).
113 DCHECK_EQ(dim_y, 1);
114 dims_out->push_back(dim_x);
115 }
116 } else {
117 // Expect the dimensions to match.
118 CHECK_EQ(dim_x, dim_y) << "Dimensions must match";
119 dims_out->push_back(dim_x);
120 }
121 }
122 CHECK(output_array->has_shape());
123 }
124
ProcessConvOperator(Model * model,ConvOperator * op)125 void ProcessConvOperator(Model* model, ConvOperator* op) {
126 const auto& input_array = model->GetArray(op->inputs[0]);
127 // Yield until input dims have been resolved.
128 if (!input_array.has_shape()) {
129 return;
130 }
131 const auto& input_shape = input_array.shape();
132 CHECK(input_shape.dimensions_count() == 4)
133 << "Conv ops require 4D inputs. Input array \"" << op->inputs[0]
134 << "\" is " << input_shape.dimensions_count() << "D.";
135
136 const auto& weights_array = model->GetArray(op->inputs[1]);
137 // Yield until weights dims have been resolved.
138 if (!weights_array.has_shape()) {
139 return;
140 }
141 const auto& weights_shape = weights_array.shape();
142 CHECK_EQ(weights_shape.dimensions_count(), 4);
143
144 auto& output_array = model->GetArray(op->outputs[0]);
145 const int output_depth = weights_shape.dims(0);
146 const int kheight = weights_shape.dims(1);
147 const int kwidth = weights_shape.dims(2);
148 ComputeConvSizes(input_shape, output_depth, kwidth, kheight, op->stride_width,
149 op->stride_height, op->dilation_width_factor,
150 op->dilation_height_factor, op->padding.type,
151 output_array.mutable_shape(),
152 &op->padding.GetOrCreateFixedPadding());
153 CHECK_EQ(output_array.shape().dimensions_count(), 4);
154
155 // Set im2col array dimensions if there is one.
156 if (op->outputs.size() == 2) {
157 const auto& output_shape = output_array.shape();
158 const int input_depth = weights_shape.dims(3);
159 auto& im2col_array = model->GetArray(op->outputs[1]);
160 im2col_array.copy_shape(Shape{output_shape.dims(0), output_shape.dims(1),
161 output_shape.dims(2),
162 input_depth * kheight * kwidth});
163 }
164 }
165
ProcessTransposeConvOperator(Model * model,TransposeConvOperator * op)166 void ProcessTransposeConvOperator(Model* model, TransposeConvOperator* op) {
167 // TransposeConv is unique in that it is specifically given the output shape
168 // as a 1D array on it's 1st input. Resolving the output shape is as easy
169 // as waiting for this input to be resolved. However, we also have to
170 // calculate the padding which requires the weights shape.
171
172 // SPECIFIED OUTPUT SHAPE
173 // The below is the specified, or prescribed output shape, _given_ to the
174 // operator as an input.
175 auto& specified_output_shape_array =
176 model->GetArray(op->inputs[TransposeConvOperator::OUTPUT_SHAPE]);
177 if (!specified_output_shape_array.has_shape() ||
178 !specified_output_shape_array.buffer) {
179 // Yield until the specified output shape is resolved as a constant
180 return;
181 }
182
183 CHECK(specified_output_shape_array.data_type == ArrayDataType::kInt32)
184 << "TransposeConv output_shape must be int32";
185
186 CHECK(specified_output_shape_array.shape().dimensions_count() == 1 &&
187 specified_output_shape_array.shape().dims(0) == 4)
188 << "TransposeConv requires a 1D, 4 element array on it's 0th input "
189 "specifying the output shape. \""
190 << op->inputs[TransposeConvOperator::OUTPUT_SHAPE] << "\" had shape "
191 << toco::ShapeToString(specified_output_shape_array.shape());
192
193 // COMPUTE PADDING
194 // We require the weights shape to calculate padding.
195 const auto& weights_array =
196 model->GetArray(op->inputs[TransposeConvOperator::WEIGHTS]);
197 if (!weights_array.has_shape()) {
198 // Yield until weights dims have been resolved.
199 return;
200 }
201 const auto& weights_shape = weights_array.shape();
202 CHECK_EQ(weights_shape.dimensions_count(), 4)
203 << "TransposeConv weights must have 4 input dimensions. Input weights \""
204 << op->inputs[TransposeConvOperator::WEIGHTS] << "\" had shape "
205 << toco::ShapeToString(weights_shape) << ".";
206
207 // Compute padding
208 const int kheight = weights_shape.dims(1);
209 const int kwidth = weights_shape.dims(2);
210 op->padding.GetOrCreateFixedPadding();
211 if (op->padding.type == PaddingType::kValid) {
212 op->padding.fixed->height = 0;
213 op->padding.fixed->width = 0;
214 } else if (op->padding.type == PaddingType::kSame) {
215 op->padding.fixed->height = (kheight - 1) / 2;
216 op->padding.fixed->width = (kwidth - 1) / 2;
217 } else {
218 LOG(FATAL) << "TransposeConv only supports SAME or VALID padding";
219 }
220
221 // VALIDATE some dimensions and set the output shape.
222 const auto& input_array =
223 model->GetArray(op->inputs[TransposeConvOperator::DATA_INPUT]);
224 if (!input_array.has_shape()) {
225 // Yield until input dims have been resolved.
226 return;
227 }
228 const auto& input_shape = input_array.shape();
229 CHECK_EQ(input_shape.dimensions_count(), 4)
230 << "TransposeConv input shape must have 4 dimensions. Input \""
231 << op->inputs[TransposeConvOperator::WEIGHTS] << "\" had shape "
232 << toco::ShapeToString(weights_shape) << ".";
233 CHECK_EQ(input_shape.dims(3), weights_shape.dims(3))
234 << "Input shape depth and weight depth do not agree";
235
236 // Set the output shape according to the specified output shape.
237 std::vector<int32> const& specified_output_shape =
238 specified_output_shape_array.GetBuffer<ArrayDataType::kInt32>().data;
239 auto& output_array = model->GetArray(op->outputs[0]);
240 *(output_array.mutable_shape()->mutable_dims()) = specified_output_shape;
241
242 // Set im2col array dimensions if there is one.
243 if (op->outputs.size() == 2) {
244 const int input_depth = weights_shape.dims(3);
245 auto& im2col_array = model->GetArray(op->outputs[1]);
246 im2col_array.copy_shape(
247 Shape{specified_output_shape[0], specified_output_shape[1],
248 specified_output_shape[2], input_depth * kheight * kwidth});
249 }
250 }
251
ProcessDepthwiseConvOperator(Model * model,DepthwiseConvOperator * op)252 void ProcessDepthwiseConvOperator(Model* model, DepthwiseConvOperator* op) {
253 const auto& input_array = model->GetArray(op->inputs[0]);
254 // Yield until input dims have been resolved.
255 if (!input_array.has_shape()) {
256 return;
257 }
258 const auto& input_shape = input_array.shape();
259 CHECK_EQ(input_shape.dimensions_count(), 4);
260
261 const auto& weights_array = model->GetArray(op->inputs[1]);
262 // Yield until weights dims have been resolved.
263 if (!weights_array.has_shape()) {
264 return;
265 }
266 const auto& weights_shape = weights_array.shape();
267 CHECK_EQ(weights_shape.dimensions_count(), 4);
268
269 const std::string& output_name = op->outputs[0];
270 const int input_depth = input_shape.dims(3);
271 const int output_depth = weights_shape.dims(3);
272 // TensorFlow doesn't define the depth_multiplier value on DepthwiseConv ops,
273 // instead it has to be inferred from the weights dims. However, once we are
274 // here, weights dims have already been converted to our own internal format,
275 // where the multiplier is no longer readily apparent. So instead we get it
276 // as the quotient of output and input depths. We only want to do that when
277 // depth_multiplier had the zero value: any other value should be checked
278 // as done by the next if() below.
279 if (!op->depth_multiplier) {
280 op->depth_multiplier = output_depth / input_depth;
281 }
282 CHECK_EQ(output_depth, input_depth * op->depth_multiplier)
283 << "input/output depths and depth_multiplier don't match";
284
285 const int kheight = weights_shape.dims(1);
286 const int kwidth = weights_shape.dims(2);
287 ComputeConvSizes(input_shape, output_depth, kwidth, kheight, op->stride_width,
288 op->stride_height, op->dilation_width_factor,
289 op->dilation_height_factor, op->padding.type,
290 model->GetArray(output_name).mutable_shape(),
291 &op->padding.GetOrCreateFixedPadding());
292 }
293
ProcessDepthToSpaceOperator(Model * model,DepthToSpaceOperator * op)294 void ProcessDepthToSpaceOperator(Model* model, DepthToSpaceOperator* op) {
295 const auto& input_array = model->GetArray(op->inputs[0]);
296 // Yield until input dims have been resolved.
297 if (!input_array.has_shape()) {
298 return;
299 }
300 const auto& input_shape = input_array.shape();
301 CHECK_EQ(input_shape.dimensions_count(), 4);
302
303 const std::string& output_name = op->outputs[0];
304 const int block_size = op->block_size;
305 CHECK_NE(block_size, 0) << "Invalid block_size in " << output_name;
306 const int batch = input_shape.dims(0);
307 const int height = input_shape.dims(1);
308 const int width = input_shape.dims(2);
309 const int depth = input_shape.dims(3);
310 QCHECK_EQ(depth % (block_size * block_size), 0);
311
312 model->GetArray(output_name)
313 .copy_shape(Shape({batch, height * block_size, width * block_size,
314 depth / block_size / block_size}));
315 }
316
ProcessSpaceToDepthOperator(Model * model,SpaceToDepthOperator * op)317 void ProcessSpaceToDepthOperator(Model* model, SpaceToDepthOperator* op) {
318 const auto& input_array = model->GetArray(op->inputs[0]);
319 // Yield until input dims have been resolved.
320 if (!input_array.has_shape()) {
321 return;
322 }
323 const auto& input_shape = input_array.shape();
324 CHECK_EQ(input_shape.dimensions_count(), 4);
325
326 const std::string& output_name = op->outputs[0];
327 const int block_size = op->block_size;
328 CHECK_NE(block_size, 0) << "Invalid block_size in " << output_name;
329 const int batch = input_shape.dims(0);
330 const int height = input_shape.dims(1);
331 const int width = input_shape.dims(2);
332 const int depth = input_shape.dims(3);
333 QCHECK_EQ(width % block_size, 0);
334 QCHECK_EQ(height % block_size, 0);
335
336 model->GetArray(output_name)
337 .copy_shape(Shape({batch, height / block_size, width / block_size,
338 depth * block_size * block_size}));
339 }
340
ProcessOpWithShapeInput(Model * model,Operator * op)341 void ProcessOpWithShapeInput(Model* model, Operator* op) {
342 CHECK_EQ(op->outputs.size(), 1);
343 auto& output_array = model->GetArray(op->outputs[0]);
344 if (output_array.has_shape()) {
345 // We have already run
346 return;
347 }
348
349 auto& dims_array = model->GetArray(op->inputs[0]);
350 if (!dims_array.has_shape()) {
351 // Yield until dims shape been resolved.
352 return;
353 }
354 if (!dims_array.buffer) {
355 // Yield until the dims are constant
356 return;
357 }
358 CHECK(dims_array.data_type == ArrayDataType::kInt32) << "dims must be int32";
359 CHECK_LE(RequiredBufferSizeForShape(dims_array.shape()), 6)
360 << "dims vector can be no larger than 6 values";
361
362 std::vector<int32> const& dims =
363 dims_array.GetBuffer<ArrayDataType::kInt32>().data;
364 *(output_array.mutable_shape()->mutable_dims()) = dims;
365 }
366
ProcessFullyConnectedOperator(Model * model,FullyConnectedOperator * op)367 void ProcessFullyConnectedOperator(Model* model, FullyConnectedOperator* op) {
368 const auto& input_array = model->GetArray(op->inputs[0]);
369 // Yield until input dims have been resolved.
370 if (!input_array.has_shape()) {
371 return;
372 }
373 const auto& input_shape = input_array.shape();
374 if (input_shape.dimensions_count() < 1) {
375 return;
376 }
377
378 const auto& weights_array = model->GetArray(op->inputs[1]);
379 // Yield until weights dims have been resolved.
380 if (!weights_array.has_shape()) {
381 return;
382 }
383 const auto& weights_shape = weights_array.shape();
384
385 const int weights_output_depth = weights_shape.dims(0);
386 CHECK_EQ(weights_shape.dimensions_count(), 2);
387
388 const int input_overall_size = RequiredBufferSizeForShape(input_shape);
389 const int matmul_repeats = input_overall_size / weights_shape.dims(1);
390 CHECK_EQ(matmul_repeats * weights_shape.dims(1), input_overall_size);
391
392 auto& output_array = model->GetArray(op->outputs[0]);
393 output_array.copy_shape(Shape({matmul_repeats, weights_output_depth}));
394 }
395
ProcessTensorFlowReshapeOperator(Model * model,TensorFlowReshapeOperator * op)396 void ProcessTensorFlowReshapeOperator(Model* model,
397 TensorFlowReshapeOperator* op) {
398 auto& output_array = model->GetArray(op->outputs[0]);
399 if (output_array.has_shape()) {
400 // We have already run
401 return;
402 }
403
404 const auto& input_array = model->GetArray(op->inputs[0]);
405 if (!input_array.has_shape()) {
406 // Yield until input dims have been resolved.
407 return;
408 }
409 const auto& input_shape = input_array.shape();
410
411 auto& shape_array = model->GetArray(op->inputs[1]);
412 if (!shape_array.has_shape()) {
413 // Yield until target_shape shape been resolved.
414 return;
415 }
416 if (!shape_array.buffer) {
417 // Yield until the target_shape is constant
418 return;
419 }
420 CHECK(shape_array.data_type == ArrayDataType::kInt32)
421 << "Reshape dims must be int32";
422
423 // shape_data is the raw array of ints describing the shape
424 // in the TensorFlow node. We intentionally make a copy here, rather than
425 // modify wildcards in-place below, because in some graphs, the same shape
426 // array with a wildcard may be referenced from multiple Reshape nodes, where
427 // the wildcard needs to resolved to distinct values.
428 std::vector<int32> shape_data =
429 shape_array.GetBuffer<ArrayDataType::kInt32>().data;
430 // The Reshape shape may have a wildcard dim, encoded as -1.
431 bool has_wildcard = false;
432 int wildcard_index = 0;
433 int product_non_wildcard_dims = 1;
434 for (size_t i = 0; i < shape_data.size(); i++) {
435 if (shape_data[i] == -1) {
436 CHECK(!has_wildcard);
437 has_wildcard = true;
438 wildcard_index = i;
439 } else {
440 product_non_wildcard_dims *= shape_data[i];
441 }
442 }
443
444 const int input_flat_size = RequiredBufferSizeForShape(input_shape);
445 if (has_wildcard) {
446 CHECK_GE(input_flat_size, product_non_wildcard_dims)
447 << "Array not large enough to fill the requested dimensions for "
448 "Reshape op with output \""
449 << op->outputs[0] << "\". Are your input shapes correct?";
450 shape_data[wildcard_index] = input_flat_size / product_non_wildcard_dims;
451 }
452
453 if (shape_data.size() == 1 && shape_data[0] == 0) {
454 // We have reshaped a scalar, so preserve as a scalar.
455 shape_data.clear();
456 }
457
458 auto& output_shape = *output_array.mutable_shape();
459 *output_shape.mutable_dims() = shape_data;
460 CHECK_EQ(input_flat_size, RequiredBufferSizeForShape(output_shape))
461 << "Input cannot be reshaped to requested dimensions for Reshape op with "
462 "output \""
463 << op->outputs[0] << "\". Are your input shapes correct?";
464 }
465
ProcessSimpleOperator(Model * model,Operator * op,int input_index)466 void ProcessSimpleOperator(Model* model, Operator* op, int input_index) {
467 const auto& input_array = model->GetArray(op->inputs[input_index]);
468 // Yield until input dims have been resolved.
469 if (!input_array.has_shape()) {
470 return;
471 }
472
473 const std::string& output_name = op->outputs[0];
474 auto& output_array = model->GetArray(output_name);
475 if (output_array.has_shape()) {
476 return;
477 }
478
479 output_array.copy_shape(input_array.shape());
480 }
481
ProcessSimpleBinaryOperator(Model * model,Operator * op)482 void ProcessSimpleBinaryOperator(Model* model, Operator* op) {
483 CHECK_EQ(op->inputs.size(), 2);
484 const auto& input0_array = model->GetArray(op->inputs[0]);
485 const auto& input1_array = model->GetArray(op->inputs[1]);
486 // Yield until input dims have been resolved.
487 if (!input0_array.has_shape() || !input1_array.has_shape()) {
488 return;
489 }
490 const std::string& output_name = op->outputs[0];
491 auto& output_array = model->GetArray(output_name);
492 ComputeBinaryOperatorOutputSize(input0_array.shape(), input1_array.shape(),
493 &output_array);
494 }
495
ProcessSelectOperator(Model * model,SelectOperator * op)496 void ProcessSelectOperator(Model* model, SelectOperator* op) {
497 // Yield until all input dims have been resolved.
498 for (const auto& input : op->inputs) {
499 const auto& input_array = model->GetArray(input);
500 if (!input_array.has_shape()) {
501 return;
502 }
503 }
504
505 // Select's output matches the second and third output.
506 const auto& input1_array = model->GetArray(op->inputs[1]);
507 auto& output_array = model->GetArray(op->outputs[0]);
508 output_array.copy_shape(input1_array.shape());
509 }
510
ProcessAddNOperator(Model * model,Operator * op)511 void ProcessAddNOperator(Model* model, Operator* op) {
512 // Yield until all input dims have been resolved.
513 //
514 // TODO(myenik): Since AddN does not support broadcasting, maybe we could
515 // actually use this to improve shape propagation by propagating the shape of
516 // one input to all other inputs once it is resolved instead of just the
517 // output, since all inputs must be the same size and shape for a well-formed
518 // graph.
519 for (const auto& input : op->inputs) {
520 const auto& input_array = model->GetArray(input);
521 if (!input_array.has_shape()) {
522 return;
523 }
524 }
525
526 // AddN does not support broadcasting, all inputs must be the same shape, so
527 // we just take the first input shape and apply it to the output.
528 const auto& input0_array = model->GetArray(op->inputs[0]);
529 auto& output_array = model->GetArray(op->outputs[0]);
530 output_array.copy_shape(input0_array.shape());
531 }
532
KeepDims(const Operator & op)533 bool KeepDims(const Operator& op) {
534 switch (op.type) {
535 case OperatorType::kReduceMin: // Reduction Min
536 return static_cast<const TensorFlowMinOperator&>(op).keep_dims;
537 case OperatorType::kReduceMax: // Reduction Max
538 return static_cast<const TensorFlowMaxOperator&>(op).keep_dims;
539 case OperatorType::kSum:
540 return static_cast<const TensorFlowSumOperator&>(op).keep_dims;
541 case OperatorType::kReduceProd:
542 return static_cast<const TensorFlowProdOperator&>(op).keep_dims;
543 case OperatorType::kMean:
544 return static_cast<const MeanOperator&>(op).keep_dims;
545 case OperatorType::kAny:
546 return static_cast<const TensorFlowAnyOperator&>(op).keep_dims;
547 default:
548 LOG(FATAL) << "Not a reduction operator!";
549 return false;
550 }
551 }
552
ProcessTensorFlowReductionOperator(Model * model,Operator * op)553 void ProcessTensorFlowReductionOperator(Model* model, Operator* op) {
554 CHECK_LE(op->inputs.size(), 2);
555 auto& output_array = model->GetArray(op->outputs[0]);
556 if (output_array.has_shape()) {
557 return;
558 }
559 const auto& input_array = model->GetArray(op->inputs[0]);
560 if (!input_array.has_shape()) {
561 return;
562 }
563 const auto& input_shape = input_array.shape();
564 const bool keep_dims = KeepDims(*op);
565 if (op->inputs.size() == 2) {
566 // There is a reduction_indices input.
567 const auto& reduction_indices_array = model->GetArray(op->inputs[1]);
568 if (!reduction_indices_array.buffer) {
569 return;
570 }
571 CHECK(reduction_indices_array.buffer->type == ArrayDataType::kInt32);
572
573 int input_rank = input_shape.dimensions_count();
574 std::set<int32> true_indices;
575 const auto& reduction_indices =
576 reduction_indices_array.GetBuffer<ArrayDataType::kInt32>().data;
577 for (size_t i = 0; i < reduction_indices.size(); ++i) {
578 const int32_t reduction_index = reduction_indices[i];
579 if (reduction_index < -input_rank || reduction_index >= input_rank) {
580 CHECK(false) << "Invalid reduction dimension " << reduction_index
581 << " for input with " << input_rank << " dimensions";
582 }
583 int32_t wrapped_index = reduction_index;
584 if (wrapped_index < 0) {
585 wrapped_index += input_rank;
586 }
587 true_indices.insert(wrapped_index);
588 }
589
590 auto* mutable_dims = output_array.mutable_shape()->mutable_dims();
591 mutable_dims->clear();
592 for (int i = 0; i < input_rank; ++i) {
593 if (true_indices.count(i) > 0) {
594 if (keep_dims) {
595 mutable_dims->emplace_back(1);
596 }
597 } else {
598 mutable_dims->emplace_back(input_shape.dims(i));
599 }
600 }
601 } else {
602 // No reduction_indices means complete reduction to a single scalar.
603 if (keep_dims) {
604 output_array.copy_shape(input_shape);
605 } else {
606 output_array.copy_shape(Shape({}));
607 }
608 }
609 }
610
ProcessSliceOperator(Model * model,SliceOperator * op)611 void ProcessSliceOperator(Model* model, SliceOperator* op) {
612 CHECK_EQ(op->inputs.size(), 3);
613 CHECK_EQ(op->outputs.size(), 1);
614
615 // Yield until the Slice params have been resolved.
616 if (op->begin.empty()) return;
617
618 // Yield until input dims have been resolved.
619 const auto& input_array = model->GetArray(op->inputs[0]);
620 if (!input_array.has_shape()) return;
621 const Shape& input_shape = input_array.shape();
622
623 auto& output_array = model->GetArray(op->outputs[0]);
624 if (output_array.has_shape()) return;
625
626 CHECK_EQ(input_shape.dims().size(), op->size.size());
627 CHECK_EQ(op->begin.size(), op->size.size());
628
629 std::vector<int> output_dims;
630 for (size_t i = 0; i < op->begin.size(); ++i) {
631 int size = op->size[i];
632 if (size == -1) {
633 size = input_array.shape().dims(i) - op->begin[i];
634 }
635 output_dims.push_back(size);
636 }
637
638 *output_array.mutable_shape()->mutable_dims() = output_dims;
639 }
640
ProcessReorderAxesOperator(Model * model,ReorderAxesOperator * op)641 void ProcessReorderAxesOperator(Model* model, ReorderAxesOperator* op) {
642 const std::string& input_name = op->inputs[0];
643 const auto& input_array = model->GetArray(input_name);
644 // Yield until input dims have been resolved.
645 if (!input_array.has_shape()) {
646 return;
647 }
648 const auto& input_shape = input_array.shape();
649 const std::string& output_name = op->outputs[0];
650 Shape* output_shape = model->GetArray(output_name).mutable_shape();
651 ShuffleDims(input_shape, op->input_axes_order, op->output_axes_order,
652 output_shape);
653 }
654
ProcessConcatenationOperator(Model * model,ConcatenationOperator * op)655 void ProcessConcatenationOperator(Model* model, ConcatenationOperator* op) {
656 // Yield until input dims have been resolved.
657 for (const auto& input_name : op->inputs) {
658 auto& input_array = model->GetArray(input_name);
659 if (!input_array.has_shape()) {
660 return;
661 }
662 }
663 auto& output_array = model->GetArray(op->outputs[0]);
664 // Use first non-empty input as basis for output dimensions.
665 for (const auto& input_name : op->inputs) {
666 const auto& input_array = model->GetArray(input_name);
667 if (input_array.shape().dimensions_count() > 0) {
668 output_array.copy_shape(input_array.shape());
669 // Negative axis means the count starts at the back of the dims().
670 if (op->axis < 0) op->axis += input_array.shape().dims().size();
671 break;
672 }
673 }
674 // Determine the concat size, and enforce that all inputs have
675 // the same dimensions count.
676 int concat_size = 0;
677 for (const auto& input_name : op->inputs) {
678 auto& input_array = model->GetArray(input_name);
679 CHECK(input_array.has_shape());
680 if (input_array.shape().dimensions_count() == 0) {
681 continue;
682 }
683 CHECK_EQ(input_array.shape().dimensions_count(),
684 output_array.shape().dimensions_count());
685 const std::vector<int>& input_dims = input_array.shape().dims();
686 CHECK_LT(op->axis, input_dims.size());
687 concat_size += input_dims[op->axis];
688 }
689 // Write out the concat_size on the output array shape.
690 auto& output_shape = *output_array.mutable_shape();
691 auto& output_dims = *output_shape.mutable_dims();
692 CHECK_LT(op->axis, output_shape.dimensions_count());
693 output_dims[op->axis] = concat_size;
694 }
695
ProcessRangeOperator(Model * model,RangeOperator * op)696 void ProcessRangeOperator(Model* model, RangeOperator* op) {
697 CHECK_EQ(op->inputs.size(), 3);
698 const auto& start_array = model->GetArray(op->inputs[0]);
699 if (!start_array.has_shape()) {
700 // Yield until input dims have been resolved.
701 return;
702 }
703 const auto& limit_array = model->GetArray(op->inputs[1]);
704 if (!limit_array.has_shape()) {
705 return;
706 }
707 const auto& delta_array = model->GetArray(op->inputs[2]);
708 if (!delta_array.has_shape()) {
709 return;
710 }
711
712 if (!IsConstantParameterArray(*model, op->inputs[0])) {
713 // Yield until inputs are constant.
714 return;
715 }
716 if (!IsConstantParameterArray(*model, op->inputs[1])) {
717 return;
718 }
719 if (!IsConstantParameterArray(*model, op->inputs[2])) {
720 return;
721 }
722
723 const ArrayDataType& start_dtype = start_array.data_type;
724 CHECK(start_dtype == ArrayDataType::kInt32 ||
725 start_dtype == ArrayDataType::kFloat)
726 << "Range op inputs must be int32 or float.";
727 CHECK(limit_array.data_type == start_dtype)
728 << "In Range op, limit tensor must have the same data type as start "
729 "tensor.";
730 CHECK(delta_array.data_type == start_dtype)
731 << "In Range op, delta tensor must have the same data type as start "
732 "tensor.";
733 CHECK_EQ(RequiredBufferSizeForShape(start_array.shape()), 1)
734 << "Range op inputs must be scalar.";
735 CHECK_EQ(RequiredBufferSizeForShape(limit_array.shape()), 1)
736 << "Range op inputs must be scalar.";
737 CHECK_EQ(RequiredBufferSizeForShape(delta_array.shape()), 1)
738 << "Range op inputs must be scalar.";
739
740 int size = 0;
741 if (start_dtype == ArrayDataType::kInt32) {
742 size = std::floor((limit_array.GetBuffer<ArrayDataType::kInt32>().data[0] -
743 start_array.GetBuffer<ArrayDataType::kInt32>().data[0]) /
744 delta_array.GetBuffer<ArrayDataType::kInt32>().data[0]);
745 } else if (start_dtype == ArrayDataType::kFloat) {
746 size = std::floor((limit_array.GetBuffer<ArrayDataType::kFloat>().data[0] -
747 start_array.GetBuffer<ArrayDataType::kFloat>().data[0]) /
748 delta_array.GetBuffer<ArrayDataType::kFloat>().data[0]);
749 }
750
751 // Only set the output shape. Contents are set by ResolveConstantRange.
752 CHECK_EQ(op->outputs.size(), 1);
753 auto& output_array = model->GetArray(op->outputs[0]);
754 Shape* output_shape = output_array.mutable_shape();
755 output_shape->ReplaceDims({size});
756 }
757
ProcessTensorFlowSplitOperator(Model * model,TensorFlowSplitOperator * op)758 void ProcessTensorFlowSplitOperator(Model* model, TensorFlowSplitOperator* op) {
759 CHECK_EQ(op->inputs.size(), 2);
760 const std::string& input_name = op->inputs[1];
761 const auto& input_array = model->GetArray(input_name);
762 // Yield until input dims have been resolved.
763 if (!input_array.has_shape()) {
764 return;
765 }
766 const Shape& input_shape = input_array.shape();
767
768 // Yield until axis is constant.
769 if (!IsConstantParameterArray(*model, op->inputs[0])) {
770 return;
771 }
772
773 const auto& axis_array = model->GetArray(op->inputs[0]);
774
775 // Yield until axis dims have been resolved.
776 if (!axis_array.has_shape()) {
777 return;
778 }
779
780 CHECK(axis_array.data_type == ArrayDataType::kInt32)
781 << "Axis array must be int32.";
782 CHECK_EQ(RequiredBufferSizeForShape(axis_array.shape()), 1)
783 << "Axis array must be scalar.";
784
785 int axis = axis_array.GetBuffer<ArrayDataType::kInt32>().data[0];
786 if (axis < 0) {
787 axis += input_shape.dimensions_count();
788 }
789
790 const int split_dim = input_shape.dims(axis);
791 CHECK_EQ(split_dim % op->num_split, 0);
792 const int split_depth = split_dim / op->num_split;
793
794 Shape output_shape = input_shape;
795 (*output_shape.mutable_dims())[axis] = split_depth;
796
797 CHECK_EQ(op->outputs.size(), op->num_split);
798 for (const auto& output : op->outputs) {
799 model->GetArray(output).copy_shape(output_shape);
800 }
801 }
802
ProcessTensorFlowSplitVOperator(Model * model,TensorFlowSplitVOperator * op)803 void ProcessTensorFlowSplitVOperator(Model* model,
804 TensorFlowSplitVOperator* op) {
805 CHECK_EQ(op->inputs.size(), 3);
806
807 const auto& input_array = model->GetArray(op->inputs[0]);
808 // Yield until input dims have been resolved.
809 if (!input_array.has_shape()) {
810 return;
811 }
812 const Shape& input_shape = input_array.shape();
813
814 // Yield until size_splits is constant.
815 if (!IsConstantParameterArray(*model, op->inputs[1])) {
816 return;
817 }
818 const auto& size_array = model->GetArray(op->inputs[1]);
819 // Yield until size_splits dims have been resolved.
820 if (!size_array.has_shape()) {
821 return;
822 }
823 const Shape& size_shape = size_array.shape();
824
825 CHECK(size_array.data_type == ArrayDataType::kInt32 ||
826 size_array.data_type == ArrayDataType::kInt64)
827 << "size_splits must be int32, int64";
828 CHECK_EQ(size_shape.dimensions_count(), 1) << "size_splits must be 1-D";
829
830 std::vector<int64_t> size_splits_vector;
831 if (size_array.data_type == ArrayDataType::kInt32) {
832 for (const auto each_size :
833 size_array.GetBuffer<ArrayDataType::kInt32>().data) {
834 size_splits_vector.push_back(each_size);
835 }
836 } else {
837 size_splits_vector = size_array.GetBuffer<ArrayDataType::kInt64>().data;
838 }
839
840 // Yield until axis is constant.
841 if (!IsConstantParameterArray(*model, op->inputs[2])) {
842 return;
843 }
844 const auto& axis_array = model->GetArray(op->inputs[2]);
845 // Yield until axis dims have been resolved.
846 if (!axis_array.has_shape()) {
847 return;
848 }
849
850 CHECK(axis_array.data_type == ArrayDataType::kInt32)
851 << "Axis array must be int32.";
852 CHECK_EQ(RequiredBufferSizeForShape(axis_array.shape()), 1)
853 << "Axis array must be scalar.";
854
855 int axis = axis_array.GetBuffer<ArrayDataType::kInt32>().data[0];
856 if (axis < 0) {
857 axis += input_shape.dimensions_count();
858 }
859
860 CHECK_EQ(op->num_split, size_splits_vector.size());
861
862 int64_t minus_one_count = 0, size_splits_sum = 0;
863 for (auto size : size_splits_vector) {
864 if (size == -1) {
865 ++minus_one_count;
866 } else {
867 size_splits_sum += size;
868 }
869 }
870
871 const int input_size = input_shape.dims(axis);
872
873 CHECK_LE(minus_one_count, 1) << "size_splits can contain at most one -1.";
874
875 if (minus_one_count == 1) {
876 CHECK_LE(size_splits_sum, input_size);
877 auto iter =
878 std::find(size_splits_vector.begin(), size_splits_vector.end(), -1);
879 *iter = input_size - size_splits_sum;
880 } else {
881 CHECK_EQ(size_splits_sum, input_size);
882 }
883
884 CHECK_EQ(op->outputs.size(), op->num_split);
885
886 for (size_t i = 0; i < op->outputs.size(); ++i) {
887 const auto& output = op->outputs[i];
888 Shape output_shape = input_shape;
889 (*output_shape.mutable_dims())[axis] = size_splits_vector.at(i);
890 model->GetArray(output).copy_shape(output_shape);
891 }
892 }
893
ProcessAveragePoolOperator(Model * model,AveragePoolOperator * op)894 void ProcessAveragePoolOperator(Model* model, AveragePoolOperator* op) {
895 const std::string& input_name = op->inputs[0];
896 const auto& input_array = model->GetArray(input_name);
897 // Yield until input dims have been resolved.
898 if (!input_array.has_shape()) {
899 return;
900 }
901 const auto& input_shape = input_array.shape();
902 CHECK_EQ(input_shape.dimensions_count(), 4);
903 const std::string& output_name = op->outputs[0];
904 const int output_depth = input_shape.dims(3);
905 ComputeConvSizes(input_shape, output_depth, op->kwidth, op->kheight,
906 op->stride_width, op->stride_height, 1, 1, op->padding.type,
907 model->GetArray(output_name).mutable_shape(),
908 &op->padding.GetOrCreateFixedPadding());
909 }
910
ProcessMaxPoolOperator(Model * model,MaxPoolOperator * op)911 void ProcessMaxPoolOperator(Model* model, MaxPoolOperator* op) {
912 const std::string& input_name = op->inputs[0];
913 const auto& input_array = model->GetArray(input_name);
914 // Yield until input dims have been resolved.
915 if (!input_array.has_shape()) {
916 return;
917 }
918 const auto& input_shape = input_array.shape();
919 CHECK_EQ(input_shape.dimensions_count(), 4);
920 const std::string& output_name = op->outputs[0];
921 const int output_depth = input_shape.dims(3);
922 ComputeConvSizes(input_shape, output_depth, op->kwidth, op->kheight,
923 op->stride_width, op->stride_height, 1, 1, op->padding.type,
924 model->GetArray(output_name).mutable_shape(),
925 &op->padding.GetOrCreateFixedPadding());
926 }
927
ProcessL2PoolOperator(Model * model,L2PoolOperator * op)928 void ProcessL2PoolOperator(Model* model, L2PoolOperator* op) {
929 const std::string& input_name = op->inputs[0];
930 const auto& input_array = model->GetArray(input_name);
931 // Yield until input dims have been resolved.
932 if (!input_array.has_shape()) {
933 return;
934 }
935 const auto& input_shape = input_array.shape();
936 if (input_shape.dimensions_count() < 4) {
937 LOG(FATAL) << "missing dimensions for " << input_name;
938 }
939 const std::string& output_name = op->outputs[0];
940 const int output_depth = input_shape.dims(3);
941 ComputeConvSizes(input_shape, output_depth, op->kwidth, op->kheight,
942 op->stride_width, op->stride_height, 1, 1, op->padding.type,
943 model->GetArray(output_name).mutable_shape(),
944 &op->padding.GetOrCreateFixedPadding());
945 }
946
ProcessResizeBilinearOperator(Model * model,ResizeBilinearOperator * op)947 void ProcessResizeBilinearOperator(Model* model, ResizeBilinearOperator* op) {
948 CHECK_EQ(op->inputs.size(), 2);
949 CHECK_EQ(op->outputs.size(), 1);
950
951 if (!model->GetArray(op->inputs[0]).has_shape() ||
952 !model->GetArray(op->inputs[1]).has_shape()) {
953 return;
954 }
955 const auto& input_data_shape = model->GetArray(op->inputs[0]).shape();
956
957 const std::string& output_size_name = op->inputs[1];
958 const auto& output_size_array = model->GetArray(output_size_name);
959 CHECK(output_size_array.data_type == ArrayDataType::kInt32);
960 CHECK(output_size_array.has_shape());
961 const auto& output_size_shape = output_size_array.shape();
962 CHECK_EQ(output_size_shape.dimensions_count(), 1);
963 CHECK_EQ(output_size_shape.dims(0), 2);
964 if (!output_size_array.buffer) {
965 return;
966 }
967 std::vector<int32> output_shape =
968 output_size_array.GetBuffer<ArrayDataType::kInt32>().data;
969 model->GetArray(op->outputs[0])
970 .copy_shape(Shape({input_data_shape.dims(0), output_shape[0],
971 output_shape[1], input_data_shape.dims(3)}));
972 }
973
ProcessResizeNearestNeighborOperator(Model * model,ResizeNearestNeighborOperator * op)974 void ProcessResizeNearestNeighborOperator(Model* model,
975 ResizeNearestNeighborOperator* op) {
976 CHECK_EQ(op->inputs.size(), 2);
977 CHECK_EQ(op->outputs.size(), 1);
978
979 if (!model->GetArray(op->inputs[0]).has_shape() ||
980 !model->GetArray(op->inputs[1]).has_shape()) {
981 return;
982 }
983 const auto& input_data_shape = model->GetArray(op->inputs[0]).shape();
984
985 const std::string& output_size_name = op->inputs[1];
986 const auto& output_size_array = model->GetArray(output_size_name);
987 CHECK(output_size_array.data_type == ArrayDataType::kInt32);
988 CHECK(output_size_array.has_shape());
989 const auto& output_size_shape = output_size_array.shape();
990 CHECK_EQ(output_size_shape.dimensions_count(), 1);
991 CHECK_EQ(output_size_shape.dims(0), 2);
992 if (!output_size_array.buffer) {
993 return;
994 }
995 std::vector<int32> output_shape =
996 output_size_array.GetBuffer<ArrayDataType::kInt32>().data;
997 model->GetArray(op->outputs[0])
998 .copy_shape(Shape({input_data_shape.dims(0), output_shape[0],
999 output_shape[1], input_data_shape.dims(3)}));
1000 }
1001
ProcessLstmCellOperator(Model * model,LstmCellOperator * op)1002 void ProcessLstmCellOperator(Model* model, LstmCellOperator* op) {
1003 // Only required for compact LstmCell with default NUM_INPUTS of inputs.
1004 if (op->inputs.size() != LstmCellOperator::NUM_INPUTS) return;
1005
1006 const auto& input_array =
1007 model->GetArray(op->inputs[LstmCellOperator::DATA_INPUT]);
1008 // Yield until all input dims have been resolved.
1009 if (!input_array.has_shape()) {
1010 return;
1011 }
1012 const auto& input_shape = input_array.shape();
1013 CHECK_GE(input_shape.dimensions_count(), 2);
1014
1015 const auto& prev_activ_array =
1016 model->GetArray(op->inputs[LstmCellOperator::PREV_ACTIV_INPUT]);
1017 // Yield until all input dims have been resolved.
1018 if (!prev_activ_array.has_shape()) {
1019 return;
1020 }
1021 const auto& prev_activ_shape = prev_activ_array.shape();
1022 CHECK_GE(prev_activ_shape.dimensions_count(), 2);
1023
1024 const auto& weights_array =
1025 model->GetArray(op->inputs[LstmCellOperator::WEIGHTS_INPUT]);
1026 // Yield until weights dims have been resolved.
1027 if (!weights_array.has_shape()) {
1028 return;
1029 }
1030 const auto& weights_shape = weights_array.shape();
1031 CHECK_EQ(weights_shape.dimensions_count(), 2);
1032
1033 const auto& bias_array =
1034 model->GetArray(op->inputs[LstmCellOperator::BIASES_INPUT]);
1035 // Yield until bias dims have been resolved.
1036 if (!bias_array.has_shape()) {
1037 return;
1038 }
1039 const auto& bias_shape = bias_array.shape();
1040 CHECK_GE(bias_shape.dimensions_count(), 1);
1041
1042 const auto& prev_state_array =
1043 model->GetArray(op->inputs[LstmCellOperator::PREV_STATE_INPUT]);
1044 // Yield until all input dims have been resolved.
1045 if (!prev_state_array.has_shape()) {
1046 return;
1047 }
1048 const auto& prev_state_shape = prev_state_array.shape();
1049 CHECK_GE(prev_state_shape.dimensions_count(), 2);
1050
1051 const int fc_output_depth = weights_shape.dims(0);
1052 CHECK_EQ(fc_output_depth, bias_shape.dims(0));
1053 CHECK_EQ(fc_output_depth % 4, 0);
1054 const int depth = fc_output_depth / 4;
1055
1056 const int input_depth = input_shape.dims(input_shape.dimensions_count() - 1);
1057 const int fc_input_depth = weights_shape.dims(1);
1058 CHECK_EQ(input_depth + depth, fc_input_depth);
1059 Shape output_shape(input_shape);
1060 (*output_shape.mutable_dims())[output_shape.dimensions_count() - 1] = depth;
1061
1062 // Set output dimensions
1063 model->GetArray(op->outputs[LstmCellOperator::STATE_OUTPUT])
1064 .copy_shape(output_shape);
1065 model->GetArray(op->outputs[LstmCellOperator::ACTIV_OUTPUT])
1066 .copy_shape(output_shape);
1067
1068 Shape concat_temp_shape(input_shape);
1069 (*concat_temp_shape
1070 .mutable_dims())[concat_temp_shape.dimensions_count() - 1] =
1071 fc_input_depth;
1072 model->GetArray(op->outputs[LstmCellOperator::CONCAT_TEMP])
1073 .copy_shape(concat_temp_shape);
1074
1075 Shape activ_temp_shape(input_shape);
1076 (*activ_temp_shape.mutable_dims())[activ_temp_shape.dimensions_count() - 1] =
1077 fc_output_depth;
1078 model->GetArray(op->outputs[LstmCellOperator::ACTIV_TEMP])
1079 .copy_shape(activ_temp_shape);
1080 }
1081
ProcessUnidirectionalSequenceLstmOperator(Model * model,UnidirectionalSequenceLstmOperator * op)1082 void ProcessUnidirectionalSequenceLstmOperator(
1083 Model* model, UnidirectionalSequenceLstmOperator* op) {
1084 auto& output_array = model->GetArray(op->outputs[0]);
1085 if (output_array.has_shape()) {
1086 // Shape already propagated
1087 return;
1088 }
1089
1090 if (output_array.data_type == ArrayDataType::kNone) {
1091 // Yield until the output type has been set by PropagateArrayDataTypes
1092 return;
1093 }
1094
1095 // TODO(renjieliu): check the inputs, as well as all kinds of weights.
1096 const auto& input_array = model->GetArray(op->inputs[0]);
1097
1098 constexpr int kInputActivationStateTensor = 18;
1099 constexpr int kInputCellStateTensor = 19;
1100
1101 // TFlite interpreter does not support array which is variable and contains a
1102 // buffer (see b/115961645 for more discussion).
1103 // The follow block remove buffer from the array to work around the
1104 // restriction, as a consequence, downstream applications should not
1105 // read lstm state as input to other operations.
1106 model->GetArray(op->inputs[kInputActivationStateTensor]).buffer.reset();
1107 model->GetArray(op->inputs[kInputCellStateTensor]).buffer.reset();
1108
1109 // Yield until input dims have been resolved.
1110 if (!input_array.has_shape()) {
1111 return;
1112 }
1113 const auto& input_shape = input_array.shape();
1114 const int batch_size = input_shape.dims(1);
1115 const int timestamp = input_shape.dims(0);
1116
1117 const auto& recurrent_to_output_weights_array =
1118 model->GetArray(op->inputs[8]);
1119 // Yield until input dims have been resolved.
1120 if (!recurrent_to_output_weights_array.has_shape()) {
1121 return;
1122 }
1123
1124 const auto& output_weights_shape = recurrent_to_output_weights_array.shape();
1125 const int output_size = output_weights_shape.dims(1);
1126
1127 Shape* output_shape = output_array.mutable_shape();
1128 output_shape->ReplaceDims({timestamp, batch_size, output_size});
1129 }
1130
ProcessUnidirectionalSequenceRnnOperator(Model * model,UnidirectionalSequenceRnnOperator * op)1131 void ProcessUnidirectionalSequenceRnnOperator(
1132 Model* model, UnidirectionalSequenceRnnOperator* op) {
1133 auto& output_array = model->GetArray(op->outputs[0]);
1134 if (output_array.has_shape()) {
1135 // Shape already propagated.
1136 return;
1137 }
1138
1139 if (output_array.data_type == ArrayDataType::kNone) {
1140 // Yield until the output type has been set by PropagateArrayDataTypes
1141 return;
1142 }
1143
1144 constexpr int kHiddenStateTensor = 4;
1145 // TFlite interpreter does not support array which is variable and contains a
1146 // buffer (see b/115961645 for more discussion).
1147 // The follow block remove buffer from the array to work around the
1148 // restriction, as a consequence, downstream applications should not
1149 // read lstm state as input to other operations.
1150 model->GetArray(op->inputs[kHiddenStateTensor]).buffer.reset();
1151
1152 // TODO(renjieliu): check the inputs, as well as all kinds of weights.
1153 const auto& input_array = model->GetArray(op->inputs[0]);
1154 // Yield until input dims have been resolved.
1155 if (!input_array.has_shape()) {
1156 return;
1157 }
1158 const auto& input_shape = input_array.shape();
1159 const int batch_size = input_shape.dims(1);
1160 const int timestamp = input_shape.dims(0);
1161
1162 const auto& bias_array = model->GetArray(op->inputs[3]);
1163 // Yield until input dims have been resolved.
1164 if (!bias_array.has_shape()) {
1165 return;
1166 }
1167
1168 const auto& bias_shape = bias_array.shape();
1169 const int output_size = bias_shape.dims(0);
1170
1171 Shape* output_shape = output_array.mutable_shape();
1172 output_shape->ReplaceDims({timestamp, batch_size, output_size});
1173 }
1174
ProcessBidirectionalSequenceLstmOperator(Model * model,BidirectionalSequenceLstmOperator * op)1175 void ProcessBidirectionalSequenceLstmOperator(
1176 Model* model, BidirectionalSequenceLstmOperator* op) {
1177 // We assume time major.
1178 auto& fw_output_array = model->GetArray(op->outputs[0]);
1179 auto& bw_output_array = model->GetArray(op->outputs[1]);
1180 if (fw_output_array.has_shape()) {
1181 // Shape already propagated
1182 return;
1183 }
1184
1185 if (fw_output_array.data_type == ArrayDataType::kNone) {
1186 // Yield until the output type has been set by PropagateArrayDataTypes
1187 return;
1188 }
1189
1190 // TODO(renjieliu): check the inputs, as well as all kinds of weights.
1191 const auto& input_array = model->GetArray(op->inputs[0]);
1192 // Yield until input dims have been resolved.
1193 if (!input_array.has_shape()) {
1194 return;
1195 }
1196 const auto& input_shape = input_array.shape();
1197 const int batch_size = input_shape.dims(1);
1198 const int timestamp = input_shape.dims(0);
1199
1200 constexpr int kBwRecurrentToOutputWeightsTensor = 25;
1201 const auto& recurrent_to_output_weights_array =
1202 model->GetArray(op->inputs[kBwRecurrentToOutputWeightsTensor]);
1203 // Yield until input dims have been resolved.
1204 if (!recurrent_to_output_weights_array.has_shape()) {
1205 return;
1206 }
1207
1208 constexpr int kFwInputActivationStateTensor = 35;
1209 constexpr int kFwInputCellStateTensor = 36;
1210 constexpr int kBwInputActivationStateTensor = 37;
1211 constexpr int kBwInputCellStateTensor = 38;
1212 // b(115961645): This is a hack to work around.
1213 model->GetArray(op->inputs[kFwInputActivationStateTensor]).buffer.reset();
1214 model->GetArray(op->inputs[kFwInputCellStateTensor]).buffer.reset();
1215 model->GetArray(op->inputs[kBwInputActivationStateTensor]).buffer.reset();
1216 model->GetArray(op->inputs[kBwInputCellStateTensor]).buffer.reset();
1217
1218 const auto& output_weights_shape = recurrent_to_output_weights_array.shape();
1219 const int output_size = output_weights_shape.dims(1);
1220
1221 Shape* fw_output_shape = fw_output_array.mutable_shape();
1222 if (op->merge_outputs) {
1223 fw_output_shape->ReplaceDims({timestamp, batch_size, 2 * output_size});
1224 } else {
1225 fw_output_shape->ReplaceDims({timestamp, batch_size, output_size});
1226 Shape* bw_output_shape = bw_output_array.mutable_shape();
1227 bw_output_shape->ReplaceDims({timestamp, batch_size, output_size});
1228 }
1229 }
1230
ProcessBidirectionalSequenceRnnOperator(Model * model,BidirectionalSequenceRnnOperator * op)1231 void ProcessBidirectionalSequenceRnnOperator(
1232 Model* model, BidirectionalSequenceRnnOperator* op) {
1233 // We assume time major.
1234 auto& fw_output_array = model->GetArray(op->outputs[0]);
1235 auto& bw_output_array = model->GetArray(op->outputs[1]);
1236 if (fw_output_array.has_shape()) {
1237 // Shape already propagated
1238 return;
1239 }
1240
1241 if (fw_output_array.data_type == ArrayDataType::kNone) {
1242 // Yield until the output type has been set by PropagateArrayDataTypes
1243 return;
1244 }
1245
1246 // TODO(renjieliu): check the inputs, as well as all kinds of weights.
1247 const auto& input_array = model->GetArray(op->inputs[0]);
1248 // Yield until input dims have been resolved.
1249 if (!input_array.has_shape()) {
1250 return;
1251 }
1252 const auto& input_shape = input_array.shape();
1253 const int batch_size = input_shape.dims(1);
1254 const int timestamp = input_shape.dims(0);
1255
1256 constexpr int kFwWeightsTensor = 1;
1257 const auto& forward_weights_array =
1258 model->GetArray(op->inputs[kFwWeightsTensor]);
1259 // Yield until input dims have been resolved.
1260 if (!forward_weights_array.has_shape()) {
1261 return;
1262 }
1263
1264 constexpr int kFwHiddenStateTensor = 4;
1265 constexpr int kBwHiddenStateTensor = 8;
1266 // b(115961645): This is a hack to work around.
1267 model->GetArray(op->inputs[kFwHiddenStateTensor]).buffer.reset();
1268 model->GetArray(op->inputs[kBwHiddenStateTensor]).buffer.reset();
1269
1270 const auto& output_weights_shape = forward_weights_array.shape();
1271 const int output_size = output_weights_shape.dims(0);
1272
1273 Shape* fw_output_shape = fw_output_array.mutable_shape();
1274 if (op->merge_outputs) {
1275 fw_output_shape->ReplaceDims({timestamp, batch_size, 2 * output_size});
1276 } else {
1277 fw_output_shape->ReplaceDims({timestamp, batch_size, output_size});
1278 Shape* bw_output_shape = bw_output_array.mutable_shape();
1279 bw_output_shape->ReplaceDims({timestamp, batch_size, output_size});
1280 }
1281 }
1282
ProcessSpaceToBatchNDOperator(Model * model,SpaceToBatchNDOperator * op)1283 void ProcessSpaceToBatchNDOperator(Model* model, SpaceToBatchNDOperator* op) {
1284 const auto& input_array = model->GetArray(op->inputs[0]);
1285 // Yield until input dims have been resolved.
1286 if (!input_array.has_shape()) {
1287 return;
1288 }
1289 const auto& input_shape = input_array.shape();
1290 // This method only handles input dimensions of 3 or 4.
1291 if (input_shape.dimensions_count() != 3 &&
1292 input_shape.dimensions_count() != 4) {
1293 return;
1294 }
1295
1296 const auto& block_shape_array = model->GetArray(op->inputs[1]);
1297 const auto& paddings_array = model->GetArray(op->inputs[2]);
1298 const auto& block_shape_array_shape = block_shape_array.shape();
1299 const auto& paddings_array_shape = paddings_array.shape();
1300 QCHECK_EQ(block_shape_array_shape.dimensions_count(), 1);
1301 QCHECK_EQ(paddings_array_shape.dimensions_count(), 2);
1302
1303 int spatial_dims_num = input_shape.dimensions_count() - 2;
1304 QCHECK_EQ(block_shape_array_shape.dims(0), spatial_dims_num);
1305 if (!block_shape_array.buffer) {
1306 return;
1307 }
1308 QCHECK(block_shape_array.data_type == ArrayDataType::kInt32);
1309 const auto& block_shape_data =
1310 block_shape_array.GetBuffer<ArrayDataType::kInt32>().data;
1311
1312 QCHECK_EQ(paddings_array_shape.dims(0), spatial_dims_num);
1313 QCHECK_EQ(paddings_array_shape.dims(1), 2); // Two parameters per dimension.
1314 if (!paddings_array.buffer) {
1315 return;
1316 }
1317 QCHECK(paddings_array.data_type == ArrayDataType::kInt32);
1318 const auto& paddings_data =
1319 paddings_array.GetBuffer<ArrayDataType::kInt32>().data;
1320
1321 Shape output_shape(input_shape);
1322 std::vector<int>* output_shape_data = output_shape.mutable_dims();
1323 int output_batch_size = input_shape.dims(0);
1324 for (int dim = 0; dim < spatial_dims_num; ++dim) {
1325 int final_dim_size = (input_shape.dims(dim + 1) + paddings_data[dim * 2] +
1326 paddings_data[dim * 2 + 1]);
1327 QCHECK_EQ(final_dim_size % block_shape_data[dim], 0);
1328 output_shape_data->at(dim + 1) = final_dim_size / block_shape_data[dim];
1329 output_batch_size *= block_shape_data[dim];
1330 }
1331
1332 output_shape_data->at(0) = output_batch_size;
1333 output_shape_data->at(input_shape.dimensions_count() - 1) =
1334 input_shape.dims(input_shape.dimensions_count() - 1);
1335
1336 model->GetArray(op->outputs[0]).copy_shape(output_shape);
1337 }
1338
ProcessBatchToSpaceNDOperator(Model * model,BatchToSpaceNDOperator * op)1339 void ProcessBatchToSpaceNDOperator(Model* model, BatchToSpaceNDOperator* op) {
1340 const auto& input_array = model->GetArray(op->inputs[0]);
1341 // Yield until input dims have been resolved.
1342 if (!input_array.has_shape()) {
1343 return;
1344 }
1345 const auto& input_shape = input_array.shape();
1346 CHECK_GE(input_shape.dimensions_count(), 3);
1347 CHECK_LE(input_shape.dimensions_count(), 4);
1348 int spatial_dims_num = input_shape.dimensions_count() - 2;
1349
1350 const auto& block_shape_array = model->GetArray(op->inputs[1]);
1351 const auto& crops_array = model->GetArray(op->inputs[2]);
1352 const auto& block_shape_array_shape = block_shape_array.shape();
1353 const auto& crops_array_shape = crops_array.shape();
1354 QCHECK_EQ(block_shape_array_shape.dimensions_count(), 1);
1355 QCHECK_EQ(crops_array_shape.dimensions_count(), 2);
1356
1357 // We only support two dimensions.
1358 QCHECK_EQ(block_shape_array_shape.dims(0), spatial_dims_num);
1359 if (!block_shape_array.buffer) {
1360 return;
1361 }
1362 QCHECK(block_shape_array.data_type == ArrayDataType::kInt32);
1363 const auto& block_shape_data =
1364 block_shape_array.GetBuffer<ArrayDataType::kInt32>().data;
1365
1366 QCHECK_EQ(crops_array_shape.dims(0), spatial_dims_num);
1367 QCHECK_EQ(crops_array_shape.dims(1), 2); // Two parameters per dimension.
1368 if (!crops_array.buffer) {
1369 return;
1370 }
1371 QCHECK(crops_array.data_type == ArrayDataType::kInt32);
1372 const auto& crops_data = crops_array.GetBuffer<ArrayDataType::kInt32>().data;
1373
1374 Shape output_shape(input_shape);
1375 std::vector<int>* output_shape_data = output_shape.mutable_dims();
1376 int output_batch_size = input_shape.dims(0);
1377 for (int dim = 0; dim < spatial_dims_num; ++dim) {
1378 // Number of batch must be multiple of (block_shape[dim]).
1379 QCHECK_EQ(output_batch_size % block_shape_data[dim], 0);
1380 output_batch_size = output_batch_size / block_shape_data[dim];
1381 output_shape_data->at(dim + 1) =
1382 input_shape.dims(dim + 1) * block_shape_data[dim] -
1383 crops_data[dim * 2] - crops_data[dim * 2 + 1];
1384 }
1385 output_shape_data->at(0) = output_batch_size;
1386 output_shape_data->at(input_shape.dimensions_count() - 1) =
1387 input_shape.dims(input_shape.dimensions_count() - 1);
1388
1389 model->GetArray(op->outputs[0]).copy_shape(output_shape);
1390 }
1391
ProcessGatherOperator(Model * model,GatherOperator * op)1392 void ProcessGatherOperator(Model* model, GatherOperator* op) {
1393 const auto& input_array = model->GetArray(op->inputs[0]);
1394 const auto& indices_array = model->GetArray(op->inputs[1]);
1395 auto& output_array = model->GetArray(op->outputs[0]);
1396
1397 // Bail if we already know the output shape.
1398 if (output_array.has_shape()) {
1399 return;
1400 }
1401
1402 // Yield until input dims have been resolved.
1403 if (!input_array.has_shape() || !indices_array.has_shape()) {
1404 return;
1405 }
1406
1407 // Yield until the axis has been resolved.
1408 if (!op->axis) {
1409 return;
1410 }
1411 int axis = op->axis.value();
1412
1413 const auto& input_shape = input_array.shape();
1414 const auto& indices_shape = indices_array.shape();
1415 QCHECK_GE(input_shape.dimensions_count(), 1);
1416 op->input_rank = input_shape.dimensions_count();
1417 QCHECK_LT(axis, op->input_rank);
1418
1419 // Copy the input dimensions to the output except for the axis dimensions
1420 // where the dimension of indices_shape is used.
1421 auto output_dims = output_array.mutable_shape()->mutable_dims();
1422 for (int dim = 0; dim < axis; ++dim) {
1423 output_dims->push_back(input_shape.dims(dim));
1424 }
1425 for (int dim = 0; dim < indices_shape.dimensions_count(); ++dim) {
1426 output_dims->push_back(indices_shape.dims(dim));
1427 }
1428 for (int dim = axis + 1; dim < input_shape.dimensions_count(); ++dim) {
1429 output_dims->push_back(input_shape.dims(dim));
1430 }
1431 }
1432
ProcessGatherNdOperator(Model * model,GatherNdOperator * op)1433 void ProcessGatherNdOperator(Model* model, GatherNdOperator* op) {
1434 const auto& input_array = model->GetArray(op->inputs[0]);
1435 const auto& indices_array = model->GetArray(op->inputs[1]);
1436 auto& output_array = model->GetArray(op->outputs[0]);
1437
1438 // Bail if we already know the output shape.
1439 if (output_array.has_shape()) {
1440 return;
1441 }
1442
1443 // Yield until input dims have been resolved.
1444 if (!input_array.has_shape() || !indices_array.has_shape()) {
1445 return;
1446 }
1447
1448 const auto& input_shape = input_array.shape();
1449 const auto& indices_shape = indices_array.shape();
1450 QCHECK_GE(input_shape.dimensions_count(), 1);
1451 QCHECK_GE(indices_shape.dimensions_count(), 1);
1452 const int indices_nd =
1453 indices_shape.dims(indices_shape.dimensions_count() - 1);
1454 QCHECK_LE(indices_nd, input_shape.dimensions_count());
1455
1456 auto output_dims = output_array.mutable_shape()->mutable_dims();
1457 for (int dim = 0; dim < indices_shape.dimensions_count() - 1; ++dim) {
1458 output_dims->push_back(indices_shape.dims(dim));
1459 }
1460 for (int dim = indices_nd; dim < input_shape.dimensions_count(); ++dim) {
1461 output_dims->push_back(input_shape.dims(dim));
1462 }
1463 }
1464
ProcessTopkV2Operator(Model * model,TopKV2Operator * op)1465 void ProcessTopkV2Operator(Model* model, TopKV2Operator* op) {
1466 const auto& input_values = model->GetArray(op->inputs[0]);
1467 const auto& input_k = model->GetArray(op->inputs[1]);
1468 auto& output_values = model->GetArray(op->outputs[0]);
1469 auto& output_indexes = model->GetArray(op->outputs[1]);
1470
1471 // Bail if we already know the output shape.
1472 if (output_indexes.has_shape()) {
1473 QCHECK(output_values.has_shape());
1474 return;
1475 }
1476
1477 // Yield until input dims have been resolved.
1478 if (!input_values.has_shape() || !input_k.has_shape()) {
1479 return;
1480 }
1481
1482 // If the value is initialized, we can specify the last dimension, otherwise
1483 // unknown.
1484 if (input_k.buffer) {
1485 const auto& input_values_shape = input_values.shape();
1486 auto output_indexes_dims = output_indexes.mutable_shape()->mutable_dims();
1487 auto output_values_dims = output_values.mutable_shape()->mutable_dims();
1488 for (int dim = 0; dim < input_values_shape.dimensions_count() - 1; dim++) {
1489 output_indexes_dims->push_back(input_values_shape.dims(dim));
1490 output_values_dims->push_back(input_values_shape.dims(dim));
1491 }
1492 const int32_t k_value = input_k.GetBuffer<ArrayDataType::kInt32>().data[0];
1493 output_indexes_dims->push_back(k_value);
1494 output_values_dims->push_back(k_value);
1495 }
1496 }
1497
ProcessPadOperator(Model * model,PadOperator * op)1498 void ProcessPadOperator(Model* model, PadOperator* op) {
1499 CHECK_EQ(op->inputs.size(), 2);
1500 CHECK_EQ(op->outputs.size(), 1);
1501
1502 const auto& input_array = model->GetArray(op->inputs[0]);
1503
1504 // Yield until input dims have been resolved.
1505 if (!input_array.has_shape()) return;
1506
1507 if (op->left_padding.empty()) return;
1508 CHECK_EQ(op->left_padding.size(), op->right_padding.size());
1509
1510 auto& output_array = model->GetArray(op->outputs[0]);
1511 if (output_array.has_shape()) return;
1512
1513 Shape output_shape = input_array.shape();
1514 std::vector<int>& dims = *output_shape.mutable_dims();
1515 CHECK_EQ(op->left_padding.size(), dims.size());
1516
1517 for (size_t i = 0; i < op->left_padding.size(); ++i) {
1518 dims[i] += op->left_padding[i] + op->right_padding[i];
1519 }
1520
1521 output_array.copy_shape(output_shape);
1522 }
1523
ProcessPadV2Operator(Model * model,PadV2Operator * op)1524 void ProcessPadV2Operator(Model* model, PadV2Operator* op) {
1525 CHECK_EQ(op->inputs.size(), 3);
1526 CHECK_EQ(op->outputs.size(), 1);
1527
1528 const auto& input_array = model->GetArray(op->inputs[0]);
1529
1530 // Yield until input dims have been resolved.
1531 if (!input_array.has_shape()) return;
1532
1533 if (op->left_padding.empty()) return;
1534 CHECK_EQ(op->left_padding.size(), op->right_padding.size());
1535
1536 auto& output_array = model->GetArray(op->outputs[0]);
1537 if (output_array.has_shape()) return;
1538
1539 Shape output_shape = input_array.shape();
1540 std::vector<int>& dims = *output_shape.mutable_dims();
1541 CHECK_EQ(op->left_padding.size(), dims.size());
1542
1543 for (size_t i = 0; i < op->left_padding.size(); ++i) {
1544 dims[i] += op->left_padding[i] + op->right_padding[i];
1545 }
1546
1547 output_array.copy_shape(output_shape);
1548 }
1549
ProcessRankOperator(Model * model,TensorFlowRankOperator * op)1550 void ProcessRankOperator(Model* model, TensorFlowRankOperator* op) {
1551 CHECK_GE(op->inputs.size(), 1);
1552 CHECK_EQ(op->outputs.size(), 1);
1553 auto& output_array = model->GetArray(op->outputs[0]);
1554 if (output_array.has_shape()) {
1555 // Shape already propagated
1556 return;
1557 }
1558
1559 if (output_array.data_type == ArrayDataType::kNone) {
1560 // Yield until the output type has been set by PropagateArrayDataTypes
1561 return;
1562 }
1563
1564 const auto& input_array = model->GetArray(op->inputs[0]);
1565 if (!input_array.has_shape()) {
1566 // Yield until input dims have been resolved.
1567 return;
1568 }
1569
1570 // Only set the output shape. Array contents are set by
1571 // ResolveConstantShapeOrRank.
1572 Shape* output_shape = output_array.mutable_shape();
1573 output_shape->ReplaceDims({});
1574 }
1575
ProcessShapeOperator(Model * model,TensorFlowShapeOperator * op)1576 void ProcessShapeOperator(Model* model, TensorFlowShapeOperator* op) {
1577 CHECK_GE(op->inputs.size(), 1);
1578 CHECK_EQ(op->outputs.size(), 1);
1579 auto& output_array = model->GetArray(op->outputs[0]);
1580 if (output_array.has_shape()) {
1581 // Shape already propagated
1582 return;
1583 }
1584
1585 if (output_array.data_type == ArrayDataType::kNone) {
1586 // Yield until the output type has been set by PropagateArrayDataTypes
1587 return;
1588 }
1589
1590 const auto& input_array = model->GetArray(op->inputs[0]);
1591 if (!input_array.has_shape()) {
1592 // Yield until input dims have been resolved.
1593 return;
1594 }
1595
1596 // Only set the output shape. Array contents are set by
1597 // ResolveConstantShapeOrRank.
1598 Shape* output_shape = output_array.mutable_shape();
1599 output_shape->ReplaceDims({input_array.shape().dimensions_count()});
1600 }
1601
ProcessPackOperator(Model * model,PackOperator * op)1602 void ProcessPackOperator(Model* model, PackOperator* op) {
1603 CHECK_GE(op->inputs.size(), 1);
1604 CHECK_EQ(op->outputs.size(), 1);
1605 auto& output_array = model->GetArray(op->outputs[0]);
1606 if (output_array.has_shape()) {
1607 // Shape already propagated
1608 return;
1609 }
1610
1611 std::unique_ptr<Shape> packed_shape;
1612 for (const auto& input : op->inputs) {
1613 const auto& input_array = model->GetArray(input);
1614 if (!input_array.has_shape()) {
1615 // Yield until all input dims have been resolved.
1616 return;
1617 }
1618
1619 Shape shape = input_array.shape();
1620 if (!packed_shape) {
1621 packed_shape = std::make_unique<Shape>(shape);
1622 } else {
1623 CHECK(*packed_shape == shape) << "All input arrays to Pack operators "
1624 "must have the same shape. Input \""
1625 << input << "\" is different.";
1626 }
1627 }
1628
1629 int axis = op->axis;
1630 if (axis < 0) {
1631 // Handle negative axis
1632 axis += packed_shape->dims().size() + 1;
1633 }
1634 packed_shape->mutable_dims()->insert(
1635 packed_shape->mutable_dims()->begin() + axis, op->inputs.size());
1636 output_array.copy_shape(*packed_shape);
1637 }
1638
ProcessStridedSliceOperator(Model * model,StridedSliceOperator * op)1639 void ProcessStridedSliceOperator(Model* model, StridedSliceOperator* op) {
1640 CHECK_GE(op->inputs.size(), 1);
1641 CHECK_EQ(op->outputs.size(), 1);
1642 auto& output_array = model->GetArray(op->outputs[0]);
1643 if (output_array.has_shape()) {
1644 // Shape already propagated
1645 return;
1646 }
1647
1648 if (op->start_indices.empty() || op->stop_indices.empty() ||
1649 op->strides.empty()) {
1650 // ResolveStridedSliceAttributes has not run yet.
1651 return;
1652 }
1653
1654 const auto& input_array = model->GetArray(op->inputs[0]);
1655 if (!input_array.has_shape()) {
1656 // Yield until input dims have been resolved.
1657 return;
1658 }
1659
1660 if (op->ellipsis_mask != 0) {
1661 // Something like LOG_FIRST_N(WARNING, 10) would be preferable to reduce
1662 // log noise. However, the TensorFlow logging library does not appear to
1663 // support this.
1664 LOG(WARNING) << "Skipping StridedSlice op with output \"" << op->outputs[0]
1665 << "\". ellipsis_mask is not supported (mask="
1666 << op->ellipsis_mask << ")";
1667 return;
1668 }
1669 if (op->new_axis_mask != 0) {
1670 LOG(WARNING) << "Skipping StridedSlice op with output \"" << op->outputs[0]
1671 << "\". new_axis_mask is not supported (mask="
1672 << op->new_axis_mask << ")";
1673 return;
1674 }
1675
1676 int num_input_axes = input_array.shape().dimensions_count();
1677 CHECK_LE(op->start_indices.size(), num_input_axes)
1678 << "StridedSlice op with output \"" << op->outputs[0]
1679 << "\", requires no more than " << num_input_axes << " start indices";
1680 CHECK_LE(op->stop_indices.size(), num_input_axes)
1681 << "StridedSlice op with output \"" << op->outputs[0]
1682 << "\", requires no more than " << num_input_axes << " stop indices";
1683 CHECK_LE(op->strides.size(), num_input_axes)
1684 << "StridedSlice op with output \"" << op->outputs[0]
1685 << "\", requires no more than " << num_input_axes << " strides";
1686 for (size_t i = 0; i < op->strides.size(); i++) {
1687 CHECK_NE(op->strides[i], 0) << "Strides must be non-zero. Axis " << i
1688 << " has stride=" << op->strides[i] << ".";
1689 }
1690
1691 // Create output shape
1692 std::vector<int>* dims = output_array.mutable_shape()->mutable_dims();
1693
1694 // Compute output shape
1695 for (int axis = 0; axis < num_input_axes; ++axis) {
1696 const auto strided_slice_params =
1697 tflite::strided_slice::BuildStridedSliceParams(
1698 op->begin_mask, op->end_mask, op->shrink_axis_mask,
1699 op->start_indices, op->stop_indices, op->strides);
1700 int start_index = tflite::strided_slice::StartForAxis(
1701 strided_slice_params, ToRuntimeShape(input_array.shape()), axis);
1702 int stop_index = tflite::strided_slice::StopForAxis(
1703 strided_slice_params, ToRuntimeShape(input_array.shape()), axis,
1704 start_index);
1705
1706 int dim_size = std::ceil(static_cast<float>(stop_index - start_index) /
1707 op->strides[axis]);
1708
1709 CHECK_GT(dim_size, 0)
1710 << "Output size for an axis must be greater than 0. Axis " << axis
1711 << " computes to size " << dim_size
1712 << " for StridedSlice op with output \"" << op->outputs[0] << "\".";
1713 if (op->shrink_axis_mask & (1 << axis)) {
1714 CHECK_EQ(dim_size, 1)
1715 << "Output size for an axis must compute to 1 when shrinking an "
1716 "axis. Axis "
1717 << axis << " computes to size " << dim_size
1718 << " for StridedSlice op with output \"" << op->outputs[0] << "\".";
1719 } else {
1720 dims->push_back(dim_size);
1721 }
1722 }
1723 }
1724
ProcessSqueezeOperator(Model * model,SqueezeOperator * op)1725 void ProcessSqueezeOperator(Model* model, SqueezeOperator* op) {
1726 CHECK_EQ(op->inputs.size(), 1);
1727 CHECK_EQ(op->outputs.size(), 1);
1728
1729 const auto& input_array = model->GetArray(op->inputs[0]);
1730
1731 // Yield until input dims have been resolved.
1732 if (!input_array.has_shape()) return;
1733
1734 auto& output_array = model->GetArray(op->outputs[0]);
1735 if (output_array.has_shape()) return;
1736
1737 const std::vector<int>& input_dims = input_array.shape().dims();
1738 std::vector<int> output_dims;
1739
1740 std::vector<int> squeeze_dims;
1741 const int input_num_dims = input_dims.size();
1742 squeeze_dims.reserve(op->squeeze_dims.size());
1743 for (int i : op->squeeze_dims) {
1744 squeeze_dims.push_back(i < 0 ? i + input_num_dims : i);
1745 }
1746 for (int i = 0; i < input_num_dims; ++i) {
1747 if (input_dims[i] != 1 ||
1748 (!squeeze_dims.empty() &&
1749 std::find(squeeze_dims.begin(), squeeze_dims.end(), i) ==
1750 squeeze_dims.end())) {
1751 output_dims.push_back(input_dims[i]);
1752 }
1753 }
1754 *output_array.mutable_shape()->mutable_dims() = output_dims;
1755 }
1756
ProcessSvdfOperator(Model * model,SvdfOperator * op)1757 void ProcessSvdfOperator(Model* model, SvdfOperator* op) {
1758 CHECK(op->inputs.size() == 4 || op->inputs.size() == 5);
1759 const auto& input_array = model->GetArray(op->inputs[0]);
1760 if (!input_array.has_shape()) return;
1761
1762 auto& weights_feature_array = model->GetArray(op->inputs[1]);
1763 if (!weights_feature_array.has_shape()) return;
1764
1765 const auto& weights_time_array = model->GetArray(op->inputs[2]);
1766 if (!weights_time_array.has_shape()) return;
1767
1768 const bool has_bias = (op->inputs.size() == 5);
1769 if (has_bias) {
1770 const auto& bias_array = model->GetArray(op->inputs[3]);
1771 if (!bias_array.has_shape()) return;
1772 }
1773
1774 const int batch_size = input_array.shape().dims()[0];
1775 const int num_units = weights_feature_array.shape().dims()[0];
1776 const int memory_size = weights_time_array.shape().dims()[1];
1777
1778 auto& state_array = model->GetArray(op->outputs[0]);
1779 state_array.mutable_shape()->ReplaceDims(
1780 {batch_size, memory_size * num_units});
1781
1782 auto& output_array = model->GetArray(op->outputs[1]);
1783 output_array.mutable_shape()->ReplaceDims({batch_size, num_units});
1784 }
1785
ProcessTransposeOperator(Model * model,TransposeOperator * op)1786 void ProcessTransposeOperator(Model* model, TransposeOperator* op) {
1787 auto& output_array = model->GetArray(op->outputs[0]);
1788 if (output_array.has_shape()) {
1789 // We have already run
1790 return;
1791 }
1792
1793 const auto& input_array = model->GetArray(op->inputs[0]);
1794 if (!input_array.has_shape()) {
1795 // Yield until input dims have been resolved.
1796 return;
1797 }
1798 const auto& input_shape = input_array.shape();
1799
1800 auto& perm_array = model->GetArray(op->inputs[1]);
1801 if (!perm_array.has_shape()) {
1802 // Yield until permutation shape been resolved.
1803 return;
1804 }
1805 if (!perm_array.buffer) {
1806 // Yield until the permutation is constant
1807 return;
1808 }
1809 CHECK(perm_array.data_type == ArrayDataType::kInt32)
1810 << "Transpose permutation input must be int32";
1811
1812 std::vector<int32> const& perm =
1813 perm_array.GetBuffer<ArrayDataType::kInt32>().data;
1814 CHECK_EQ(perm.size(), input_shape.dimensions_count())
1815 << "Transpose permutation input " << op->inputs[1]
1816 << " must be same length as input dimensions";
1817 std::vector<int>* output_dims = output_array.mutable_shape()->mutable_dims();
1818 for (size_t i = 0; i < perm.size(); i++) {
1819 int axis = perm[i];
1820 CHECK_GE(axis, 0);
1821 CHECK_LT(axis, input_shape.dimensions_count());
1822 output_dims->push_back(input_shape.dims(axis));
1823 }
1824 }
1825
1826 template <typename Op>
ProcessArgMinMaxOperator(Model * model,Op * op)1827 void ProcessArgMinMaxOperator(Model* model, Op* op) {
1828 CHECK_EQ(op->inputs.size(), 2);
1829 const auto& input_array = model->GetArray(op->inputs[0]);
1830 // Yield until input dims have been resolved.
1831 if (!input_array.has_shape()) {
1832 return;
1833 }
1834
1835 const Array& axis_array = model->GetArray(op->inputs[1]);
1836 // Yield until input axis array shape has been resolved.
1837 if (!axis_array.has_shape()) {
1838 return;
1839 }
1840
1841 const std::vector<int>& input_dims = input_array.shape().dims();
1842
1843 CHECK(axis_array.data_type == ArrayDataType::kInt32 ||
1844 axis_array.data_type == ArrayDataType::kInt64)
1845 << "axis_array must be int32, int64";
1846
1847 CHECK_EQ(RequiredBufferSizeForShape(axis_array.shape()), 1)
1848 << "Axis array must be scalar.";
1849
1850 int64_t axis;
1851 if (axis_array.data_type == ArrayDataType::kInt32) {
1852 axis = axis_array.GetBuffer<ArrayDataType::kInt32>().data[0];
1853 } else {
1854 axis = axis_array.GetBuffer<ArrayDataType::kInt64>().data[0];
1855 }
1856
1857 std::vector<int> output_dims;
1858
1859 output_dims.reserve(input_dims.size() - 1);
1860 for (size_t i = 0; i < input_dims.size(); ++i) {
1861 if (static_cast<int>(i) != axis) {
1862 output_dims.push_back(input_dims[i]);
1863 }
1864 }
1865
1866 const std::string& output_name = op->outputs[0];
1867 auto& output_array = model->GetArray(output_name);
1868 if (output_array.has_shape()) {
1869 return;
1870 }
1871 *output_array.mutable_shape()->mutable_dims() = output_dims;
1872 }
1873
ProcessSparseToDenseOperator(Model * model,SparseToDenseOperator * op)1874 void ProcessSparseToDenseOperator(Model* model, SparseToDenseOperator* op) {
1875 CHECK_EQ(op->inputs.size(), 4);
1876
1877 const Array& output_shape_array = model->GetArray(op->inputs[1]);
1878 if (!output_shape_array.has_shape()) return;
1879 CHECK_EQ(output_shape_array.shape().dimensions_count(), 1);
1880
1881 // Output should not go over four dimensions.
1882 CHECK_LE(output_shape_array.shape().dims(0), 4);
1883
1884 const std::string& output_name = op->outputs[0];
1885 Array& output_array = model->GetArray(output_name);
1886 if (output_array.has_shape()) return;
1887
1888 CHECK(output_shape_array.data_type == ArrayDataType::kInt32 ||
1889 output_shape_array.data_type == ArrayDataType::kInt64);
1890 if (output_shape_array.data_type == ArrayDataType::kInt32) {
1891 *output_array.mutable_shape()->mutable_dims() =
1892 output_shape_array.GetBuffer<ArrayDataType::kInt32>().data;
1893 } else {
1894 const std::vector<int64_t>& output_shape_data =
1895 output_shape_array.GetBuffer<ArrayDataType::kInt64>().data;
1896 // explicitly cast elements to int in order to avoid MSVC warnings about
1897 // narrowing conversion.
1898 std::transform(
1899 output_shape_data.begin(), output_shape_data.end(),
1900 std::back_inserter(*output_array.mutable_shape()->mutable_dims()),
1901 [](const int64_t dim) { return static_cast<int>(dim); });
1902 }
1903 }
1904
ProcessTileOperator(Model * model,TensorFlowTileOperator * op)1905 void ProcessTileOperator(Model* model, TensorFlowTileOperator* op) {
1906 CHECK_EQ(op->inputs.size(), 2);
1907 CHECK_EQ(op->outputs.size(), 1);
1908
1909 auto& output_array = model->GetArray(op->outputs[0]);
1910 if (output_array.has_shape()) {
1911 // We have already run.
1912 return;
1913 }
1914
1915 const auto& input_array = model->GetArray(op->inputs[0]);
1916 if (!input_array.has_shape()) {
1917 // Yield until input dims have been resolved.
1918 return;
1919 }
1920 const auto& input_shape = input_array.shape();
1921
1922 auto& multiples_array = model->GetArray(op->inputs[1]);
1923 if (!multiples_array.has_shape()) {
1924 // Yield until multiples shape been resolved.
1925 return;
1926 }
1927 if (!multiples_array.buffer) {
1928 // Yield until the multiples is constant.
1929 return;
1930 }
1931 CHECK(multiples_array.data_type == ArrayDataType::kInt32)
1932 << "Tile multiples input must be int32";
1933
1934 std::vector<int32> const& multiples =
1935 multiples_array.GetBuffer<ArrayDataType::kInt32>().data;
1936 CHECK_EQ(multiples.size(), input_shape.dimensions_count())
1937 << "Tile multiples input " << op->inputs[1]
1938 << " must be same length as input dimensions";
1939
1940 auto* mutable_dims = output_array.mutable_shape()->mutable_dims();
1941 mutable_dims->resize(multiples.size());
1942 for (size_t i = 0; i < mutable_dims->size(); ++i) {
1943 (*mutable_dims)[i] = input_shape.dims(i) * multiples[i];
1944 }
1945 }
1946
ProcessOneHotOperator(Model * model,OneHotOperator * op)1947 void ProcessOneHotOperator(Model* model, OneHotOperator* op) {
1948 CHECK_EQ(op->inputs.size(), 4);
1949 CHECK_EQ(op->outputs.size(), 1);
1950 auto& output_array = model->GetArray(op->outputs[0]);
1951 if (output_array.has_shape()) {
1952 // Shape already propagated
1953 return;
1954 }
1955
1956 // Yield until indices dims have been resolved.
1957 const auto& indices_array =
1958 model->GetArray(op->inputs[OneHotOperator::INDICES_INPUT]);
1959 if (!indices_array.has_shape()) {
1960 return;
1961 }
1962
1963 // Yield until depth is constant and dims have been resolved.
1964 if (!IsConstantParameterArray(*model,
1965 op->inputs[OneHotOperator::DEPTH_INPUT])) {
1966 return;
1967 }
1968 const auto& depth_array =
1969 model->GetArray(op->inputs[OneHotOperator::DEPTH_INPUT]);
1970 if (!depth_array.has_shape()) {
1971 return;
1972 }
1973
1974 CHECK(depth_array.data_type == ArrayDataType::kInt32)
1975 << "Depth array must be int32.";
1976 CHECK_EQ(RequiredBufferSizeForShape(depth_array.shape()), 1)
1977 << "Depth array must be scalar.";
1978
1979 const int depth = depth_array.GetBuffer<ArrayDataType::kInt32>().data[0];
1980 CHECK_GE(depth, 0) << "Depth must be non-negative.";
1981
1982 const int indices_dims = indices_array.shape().dimensions_count();
1983 const int output_dims = indices_dims + 1;
1984 const int axis = op->axis == -1 ? indices_dims : op->axis;
1985 CHECK_GE(axis, 0) << "Resolved axis must be non-negative.";
1986
1987 auto* mutable_dims = output_array.mutable_shape()->mutable_dims();
1988 mutable_dims->resize(output_dims);
1989 for (int i = 0; i < output_dims; ++i) {
1990 int dim = 0;
1991 if (i < axis) {
1992 dim = indices_array.shape().dims(i);
1993 } else if (i == axis) {
1994 dim = depth;
1995 } else {
1996 dim = indices_array.shape().dims(i - 1);
1997 }
1998 (*mutable_dims)[i] = dim;
1999 }
2000 }
2001
ProcessUnpackOperator(Model * model,UnpackOperator * op)2002 void ProcessUnpackOperator(Model* model, UnpackOperator* op) {
2003 CHECK_EQ(op->inputs.size(), 1);
2004 const auto& input_array = model->GetArray(op->inputs[0]);
2005 // Yield until input dims have been resolved.
2006 if (!input_array.has_shape()) {
2007 return;
2008 }
2009
2010 const std::vector<int>& input_dims = input_array.shape().dims();
2011 std::vector<int> output_dims;
2012
2013 output_dims.reserve(input_dims.size() - 1);
2014 for (size_t i = 0; i < input_dims.size(); ++i) {
2015 if (static_cast<int>(i) != op->axis) {
2016 output_dims.push_back(input_dims[i]);
2017 }
2018 }
2019 for (const std::string& output_name : op->outputs) {
2020 auto& output_array = model->GetArray(output_name);
2021 if (output_array.has_shape()) {
2022 return;
2023 }
2024 *output_array.mutable_shape()->mutable_dims() = output_dims;
2025 }
2026 }
2027
ProcessMirrorPadOperator(Model * model,MirrorPadOperator * op)2028 void ProcessMirrorPadOperator(Model* model, MirrorPadOperator* op) {
2029 CHECK_EQ(op->inputs.size(), 2);
2030 const auto& input_array = model->GetArray(op->inputs[0]);
2031 const auto& padding_matrix = model->GetArray(op->inputs[1]);
2032
2033 // Yield until input dims have been resolved.
2034 if (!input_array.has_shape()) {
2035 return;
2036 }
2037
2038 auto& output_array = model->GetArray(op->outputs[0]);
2039 // If output already computed or padding matrix is non
2040 // const then return.
2041 if (output_array.has_shape() ||
2042 !IsConstantParameterArray(*model, op->inputs[1])) {
2043 return;
2044 }
2045 Shape output_shape = input_array.shape();
2046 std::vector<int>& dims = *output_shape.mutable_dims();
2047
2048 std::vector<int64_t> padding;
2049 if (padding_matrix.data_type == ArrayDataType::kInt32) {
2050 const auto& data = padding_matrix.GetBuffer<ArrayDataType::kInt32>().data;
2051 for (auto elem : data) {
2052 padding.push_back(static_cast<int64_t>(elem));
2053 }
2054 } else if (padding_matrix.data_type == ArrayDataType::kInt64) {
2055 const auto& data = padding_matrix.GetBuffer<ArrayDataType::kInt64>().data;
2056 for (auto elem : data) {
2057 padding.push_back(elem);
2058 }
2059 } else {
2060 CHECK(padding_matrix.data_type == ArrayDataType::kInt64 ||
2061 padding_matrix.data_type == ArrayDataType::kInt32);
2062 }
2063 CHECK_EQ(padding_matrix.shape().dimensions_count(), 2);
2064 CHECK_EQ(input_array.shape().dimensions_count(),
2065 padding_matrix.shape().dims(0));
2066 for (int i = 0; i < input_array.shape().dimensions_count(); ++i) {
2067 dims[i] += padding[i * 2] + padding[i * 2 + 1];
2068 }
2069
2070 output_array.copy_shape(output_shape);
2071 }
2072
ProcessUniqueOperator(Model * model,UniqueOperator * op)2073 void ProcessUniqueOperator(Model* model, UniqueOperator* op) {
2074 const auto& input_array = model->GetArray(op->inputs[0]);
2075 // We have 2 outputs, the shape of the index tensor, is the same size
2076 // as the input array. The unique values tensor, is unknown until runtime.
2077 CHECK_EQ(op->outputs.size(), 2);
2078 auto& idx_output_array = model->GetArray(op->outputs[1]);
2079
2080 // Yield until input dims have been resolved, or output already computed
2081 if (!input_array.has_shape() || idx_output_array.has_shape()) {
2082 return;
2083 }
2084 idx_output_array.copy_shape(input_array.shape());
2085 }
2086
ProcessMatrixDiagOperator(Model * model,MatrixDiagOperator * op)2087 void ProcessMatrixDiagOperator(Model* model, MatrixDiagOperator* op) {
2088 CHECK_EQ(op->inputs.size(), 1);
2089 CHECK_EQ(op->outputs.size(), 1);
2090 auto& input_array = model->GetArray(op->inputs[0]);
2091 auto& output_array = model->GetArray(op->outputs[0]);
2092 // The input array must have a shape in order to proceed. Also,
2093 // bail out if the output shape has already been calculated.
2094 if (!input_array.has_shape() || output_array.has_shape()) {
2095 // We have already run
2096 return;
2097 }
2098 // Get the input_shape
2099 Shape* mutable_shape = input_array.mutable_shape();
2100 std::vector<int>* dims = mutable_shape->mutable_dims();
2101 int dims_size = dims->size();
2102 // Scalars are not allowed.
2103 CHECK_GT(dims_size, 0);
2104 int last_dim = (*dims)[dims_size - 1];
2105 dims->push_back(last_dim);
2106 output_array.copy_shape(*mutable_shape);
2107 }
2108
ProcessMatrixSetDiagOperator(Model * model,MatrixSetDiagOperator * op)2109 void ProcessMatrixSetDiagOperator(Model* model, MatrixSetDiagOperator* op) {
2110 CHECK_EQ(op->inputs.size(), 2);
2111 CHECK_EQ(op->outputs.size(), 1);
2112 auto& input_array = model->GetArray(op->inputs[0]);
2113 auto& output_array = model->GetArray(op->outputs[0]);
2114 // The shape of the input array must be known because that will
2115 // be the shape of the output array.
2116 if (!input_array.has_shape() || !output_array.has_shape()) {
2117 // We have already run
2118 return;
2119 }
2120
2121 output_array.copy_shape(input_array.shape());
2122 }
2123
ProcessScatterNdOperator(Model * model,ScatterNdOperator * op)2124 void ProcessScatterNdOperator(Model* model, ScatterNdOperator* op) {
2125 CHECK_EQ(op->inputs.size(), 3);
2126 CHECK_EQ(op->outputs.size(), 1);
2127 auto& shape_array = model->GetArray(op->inputs[2]);
2128 auto& output_array = model->GetArray(op->outputs[0]);
2129
2130 if (!shape_array.has_shape()) {
2131 // Yield until dims shape been resolved.
2132 return;
2133 }
2134 if (!shape_array.buffer) {
2135 // Yield until the dims are constant
2136 return;
2137 }
2138 CHECK(shape_array.data_type == ArrayDataType::kInt32) << "dims must be int32";
2139
2140 std::vector<int32> const& dims =
2141 shape_array.GetBuffer<ArrayDataType::kInt32>().data;
2142 *(output_array.mutable_shape()->mutable_dims()) = dims;
2143 }
2144
2145 } // namespace
2146
Run(Model * model,std::size_t op_index,bool * modified)2147 ::tensorflow::Status PropagateFixedSizes::Run(Model* model,
2148 std::size_t op_index,
2149 bool* modified) {
2150 *modified = false;
2151 auto it = model->operators.begin() + op_index;
2152 auto* op = it->get();
2153 std::unordered_map<std::string, std::vector<int>> old_output_dims;
2154 for (const auto& output : op->outputs) {
2155 if (model->GetArray(output).has_shape()) {
2156 old_output_dims[output] = model->GetArray(output).shape().dims();
2157 }
2158 }
2159
2160 switch (op->type) {
2161 case OperatorType::kAbs:
2162 case OperatorType::kBatchNormalization:
2163 case OperatorType::kL2Normalization:
2164 case OperatorType::kDequantize:
2165 case OperatorType::kElu:
2166 case OperatorType::kHardSwish:
2167 case OperatorType::kRelu:
2168 case OperatorType::kRelu1:
2169 case OperatorType::kRelu6:
2170 case OperatorType::kPRelu:
2171 case OperatorType::kLeakyRelu:
2172 case OperatorType::kSoftmax:
2173 case OperatorType::kLogSoftmax:
2174 case OperatorType::kLog:
2175 case OperatorType::kLogistic:
2176 case OperatorType::kTanh:
2177 case OperatorType::kLocalResponseNormalization:
2178 case OperatorType::kIdentity:
2179 case OperatorType::kFakeQuant:
2180 case OperatorType::kNeg:
2181 case OperatorType::kRsqrt:
2182 case OperatorType::kSqrt:
2183 case OperatorType::kSquare:
2184 case OperatorType::kAll:
2185 case OperatorType::kAssert:
2186 case OperatorType::kCast:
2187 case OperatorType::kFloor:
2188 case OperatorType::kCeil:
2189 case OperatorType::kRound:
2190 case OperatorType::kExp:
2191 case OperatorType::kSin:
2192 case OperatorType::kCos:
2193 case OperatorType::kLogicalAnd:
2194 case OperatorType::kLogicalNot:
2195 case OperatorType::kLogicalOr:
2196 case OperatorType::kZerosLike:
2197 case OperatorType::kReverseV2:
2198 case OperatorType::kReverseSequence:
2199 ProcessSimpleOperator(model, op, 0);
2200 break;
2201 case OperatorType::kGather:
2202 ProcessGatherOperator(model, static_cast<GatherOperator*>(op));
2203 break;
2204 case OperatorType::kGatherNd:
2205 ProcessGatherNdOperator(model, static_cast<GatherNdOperator*>(op));
2206 break;
2207 case OperatorType::kTopK_V2:
2208 ProcessTopkV2Operator(model, static_cast<TopKV2Operator*>(op));
2209 break;
2210 case OperatorType::kAdd:
2211 case OperatorType::kSub:
2212 case OperatorType::kMul:
2213 case OperatorType::kDiv:
2214 case OperatorType::kFloorDiv:
2215 case OperatorType::kFloorMod:
2216 case OperatorType::kLess:
2217 case OperatorType::kLessEqual:
2218 case OperatorType::kGreater:
2219 case OperatorType::kMaximum: // Element-wise Maximum
2220 case OperatorType::kMinimum: // Element-wise Minimum
2221 case OperatorType::kGreaterEqual:
2222 case OperatorType::kEqual:
2223 case OperatorType::kNotEqual:
2224 case OperatorType::kPow:
2225 case OperatorType::kSquaredDifference:
2226 ProcessSimpleBinaryOperator(model, op);
2227 break;
2228 case OperatorType::kAddN:
2229 ProcessAddNOperator(model, op);
2230 break;
2231 case OperatorType::kConv:
2232 ProcessConvOperator(model, static_cast<ConvOperator*>(op));
2233 break;
2234 case OperatorType::kTransposeConv:
2235 ProcessTransposeConvOperator(model,
2236 static_cast<TransposeConvOperator*>(op));
2237 break;
2238 case OperatorType::kDepthwiseConv:
2239 ProcessDepthwiseConvOperator(model,
2240 static_cast<DepthwiseConvOperator*>(op));
2241 break;
2242 case OperatorType::kDepthToSpace:
2243 ProcessDepthToSpaceOperator(model,
2244 static_cast<DepthToSpaceOperator*>(op));
2245 break;
2246 case OperatorType::kSpaceToDepth:
2247 ProcessSpaceToDepthOperator(model,
2248 static_cast<SpaceToDepthOperator*>(op));
2249 break;
2250 case OperatorType::kFill:
2251 CHECK_EQ(op->inputs.size(), 2);
2252 ProcessOpWithShapeInput(model, op);
2253 break;
2254 case OperatorType::kFullyConnected:
2255 ProcessFullyConnectedOperator(model,
2256 static_cast<FullyConnectedOperator*>(op));
2257 break;
2258 case OperatorType::kReshape:
2259 ProcessTensorFlowReshapeOperator(
2260 model, static_cast<TensorFlowReshapeOperator*>(op));
2261 break;
2262 case OperatorType::kAveragePool:
2263 ProcessAveragePoolOperator(model, static_cast<AveragePoolOperator*>(op));
2264 break;
2265 case OperatorType::kMaxPool:
2266 ProcessMaxPoolOperator(model, static_cast<MaxPoolOperator*>(op));
2267 break;
2268 case OperatorType::kL2Pool:
2269 ProcessL2PoolOperator(model, static_cast<L2PoolOperator*>(op));
2270 break;
2271 case OperatorType::kReduceMin: // Reduction Min
2272 case OperatorType::kReduceMax: // Reduction Max
2273 case OperatorType::kSum:
2274 case OperatorType::kReduceProd:
2275 case OperatorType::kMean:
2276 case OperatorType::kAny:
2277 ProcessTensorFlowReductionOperator(model, op);
2278 break;
2279 case OperatorType::kSelect:
2280 ProcessSelectOperator(model, static_cast<SelectOperator*>(op));
2281 break;
2282 case OperatorType::kSlice:
2283 ProcessSliceOperator(model, static_cast<SliceOperator*>(op));
2284 break;
2285
2286 case OperatorType::kSwitch:
2287 // We can't know the sizes of the outputs until we have resolved the
2288 // predicate, and once we have resolved the predicate, the whole
2289 // Switch node will get resolved away.
2290 // See ResolveTensorFlowSwitch.
2291 break;
2292 case OperatorType::kMerge:
2293 // No need to bother resolving TensorFlow Merge ops: other graph
2294 // transformations will remove them anyway.
2295 // See ResolveTensorFlowMerge.
2296 break;
2297 case OperatorType::kSplit:
2298 ProcessTensorFlowSplitOperator(model,
2299 static_cast<TensorFlowSplitOperator*>(op));
2300 break;
2301 case OperatorType::kSplitV:
2302 ProcessTensorFlowSplitVOperator(
2303 model, static_cast<TensorFlowSplitVOperator*>(op));
2304 break;
2305 case OperatorType::kSqueeze:
2306 ProcessSqueezeOperator(model, static_cast<SqueezeOperator*>(op));
2307 break;
2308 case OperatorType::kConcat:
2309 case OperatorType::kConcatV2:
2310 // Unimplemented, hopefully another graph transformation will
2311 // drop it or rewrite it. Concretely, either ResolveTensorFlowConcat
2312 // will resolve this node to a DepthConcatenation, or else we have
2313 // a more general non-depth concatenation that will hopefully be dropped,
2314 // or else at the moment we will abort.
2315 break;
2316 case OperatorType::kExpandDims:
2317 // Yield until ExpandDims is converted to Reshape
2318 break;
2319 case OperatorType::kRange:
2320 ProcessRangeOperator(model, static_cast<RangeOperator*>(op));
2321 break;
2322 case OperatorType::kRank:
2323 ProcessRankOperator(model, static_cast<TensorFlowRankOperator*>(op));
2324 break;
2325 case OperatorType::kShape:
2326 ProcessShapeOperator(model, static_cast<TensorFlowShapeOperator*>(op));
2327 break;
2328 case OperatorType::kPack:
2329 ProcessPackOperator(model, static_cast<PackOperator*>(op));
2330 break;
2331 case OperatorType::kReorderAxes:
2332 ProcessReorderAxesOperator(model, static_cast<ReorderAxesOperator*>(op));
2333 break;
2334 case OperatorType::kConcatenation:
2335 ProcessConcatenationOperator(model,
2336 static_cast<ConcatenationOperator*>(op));
2337 break;
2338 case OperatorType::kResizeBilinear:
2339 ProcessResizeBilinearOperator(model,
2340 static_cast<ResizeBilinearOperator*>(op));
2341 break;
2342 case OperatorType::kResizeNearestNeighbor:
2343 ProcessResizeNearestNeighborOperator(
2344 model, static_cast<ResizeNearestNeighborOperator*>(op));
2345 break;
2346 case OperatorType::kUnidirectionalSequenceLstm:
2347 ProcessUnidirectionalSequenceLstmOperator(
2348 model, static_cast<UnidirectionalSequenceLstmOperator*>(op));
2349 break;
2350 case OperatorType::kUnidirectionalSequenceRnn:
2351 ProcessUnidirectionalSequenceRnnOperator(
2352 model, static_cast<UnidirectionalSequenceRnnOperator*>(op));
2353 break;
2354 case OperatorType::kBidirectionalSequenceLstm:
2355 ProcessBidirectionalSequenceLstmOperator(
2356 model, static_cast<BidirectionalSequenceLstmOperator*>(op));
2357 break;
2358 case OperatorType::kBidirectionalSequenceRnn:
2359 ProcessBidirectionalSequenceRnnOperator(
2360 model, static_cast<BidirectionalSequenceRnnOperator*>(op));
2361 break;
2362 case OperatorType::kLstmCell:
2363 ProcessLstmCellOperator(model, static_cast<LstmCellOperator*>(op));
2364 break;
2365 case OperatorType::kBatchMatMul:
2366 case OperatorType::kMatMul:
2367 // MatMul operators are converted to FullyConnected, after which their
2368 // shapes are propagated.
2369 break;
2370 case OperatorType::kSpaceToBatchND:
2371 ProcessSpaceToBatchNDOperator(model,
2372 static_cast<SpaceToBatchNDOperator*>(op));
2373 break;
2374 case OperatorType::kBatchToSpaceND:
2375 ProcessBatchToSpaceNDOperator(model,
2376 static_cast<BatchToSpaceNDOperator*>(op));
2377 break;
2378 case OperatorType::kPad:
2379 ProcessPadOperator(model, static_cast<PadOperator*>(op));
2380 break;
2381 case OperatorType::kPadV2:
2382 ProcessPadV2Operator(model, static_cast<PadV2Operator*>(op));
2383 break;
2384 case OperatorType::kStridedSlice:
2385 ProcessStridedSliceOperator(model,
2386 static_cast<StridedSliceOperator*>(op));
2387 break;
2388 case OperatorType::kArgMax:
2389 ProcessArgMinMaxOperator<ArgMaxOperator>(
2390 model, static_cast<ArgMaxOperator*>(op));
2391 break;
2392 case OperatorType::kArgMin:
2393 ProcessArgMinMaxOperator<ArgMinOperator>(
2394 model, static_cast<ArgMinOperator*>(op));
2395 break;
2396 case OperatorType::kUnsupported: {
2397 const auto* unsupported_op =
2398 static_cast<TensorFlowUnsupportedOperator*>(op);
2399 // Attribute can be not specified, ignore it.
2400 if (unsupported_op->output_shapes.size() < op->outputs.size()) {
2401 return ::tensorflow::OkStatus();
2402 }
2403 for (size_t i = 0; i < op->outputs.size(); ++i) {
2404 const std::string& output = op->outputs[i];
2405 model->GetArray(output).copy_shape(unsupported_op->output_shapes.at(i));
2406 }
2407 break;
2408 }
2409 case OperatorType::kSvdf:
2410 ProcessSvdfOperator(model, static_cast<SvdfOperator*>(op));
2411 break;
2412 case OperatorType::kTranspose:
2413 ProcessTransposeOperator(model, static_cast<TransposeOperator*>(op));
2414 break;
2415 case OperatorType::kDynamicPartition:
2416 case OperatorType::kDynamicStitch:
2417 // DynamicPartition/DynamicStitch are currently only supported for
2418 // transforms that remove them, so we avoid propagating shapes through
2419 // them and let things settle once they've been removed.
2420 break;
2421 case OperatorType::kRandomUniform:
2422 CHECK_EQ(op->inputs.size(), 1);
2423 ProcessOpWithShapeInput(model, op);
2424 break;
2425 case OperatorType::kSparseToDense:
2426 ProcessSparseToDenseOperator(model,
2427 static_cast<SparseToDenseOperator*>(op));
2428 break;
2429 case OperatorType::kTile:
2430 ProcessTileOperator(model, static_cast<TensorFlowTileOperator*>(op));
2431 break;
2432 break;
2433 case OperatorType::kOneHot:
2434 ProcessOneHotOperator(model, static_cast<OneHotOperator*>(op));
2435 break;
2436 case OperatorType::kUnpack:
2437 ProcessUnpackOperator(model, static_cast<UnpackOperator*>(op));
2438 break;
2439 case OperatorType::kMirrorPad:
2440 ProcessMirrorPadOperator(model, static_cast<MirrorPadOperator*>(op));
2441 break;
2442 case OperatorType::kUnique:
2443 ProcessUniqueOperator(model, static_cast<UniqueOperator*>(op));
2444 break;
2445 case OperatorType::kWhere:
2446 // The size of the output can only be known after evaluating the cond
2447 // tensor. Ignore shape propagation here and defer that to the
2448 // interpreter.
2449 break;
2450 case OperatorType::kMatrixDiag:
2451 ProcessMatrixDiagOperator(model, static_cast<MatrixDiagOperator*>(op));
2452 break;
2453 case OperatorType::kMatrixSetDiag:
2454 ProcessMatrixSetDiagOperator(model,
2455 static_cast<MatrixSetDiagOperator*>(op));
2456 break;
2457 case OperatorType::kCTCBeamSearchDecoder:
2458 // The sizes of the outputs are only known in runtime based on the input.
2459 // Ignore shape propagation here and defer that to the interpreter.
2460 break;
2461 case OperatorType::kMatrixSetDiagV2:
2462 // MatrixSetDiagV2 operators are converted to MatrixSetDiag,
2463 // after which their shapes are propagated.
2464 break;
2465 case OperatorType::kMatrixDiagV2:
2466 // MatrixDiagV2 operators are converted to MatrixDiag, after which their
2467 // shapes are propagated.
2468 break;
2469 case OperatorType::kMatrixDiagV3:
2470 // MatrixDiagV3 operators are converted to MatrixDiag, after which their
2471 // shapes are propagated.
2472 break;
2473 case OperatorType::kMatrixSetDiagV3:
2474 // MatrixSetDiagV3 operators are converted to MatrixSetDiag, after which
2475 // their shapes are propagated.
2476 break;
2477 case OperatorType::kSegmentSum:
2478 break;
2479 case OperatorType::kScatterNd:
2480 ProcessScatterNdOperator(model, static_cast<ScatterNdOperator*>(op));
2481 break;
2482 default:
2483 // Unimplemented, another graph transformation should drop it.
2484 LOG(FATAL) << "Unhandled operator type " << OperatorTypeName(op->type);
2485 }
2486
2487 // Return true if any output dim changed, false if none changed.
2488 // Assumption: no transformation clears an output shape, they only add shapes.
2489 for (const auto& output : op->outputs) {
2490 if (model->GetArray(output).has_shape() &&
2491 (old_output_dims[output] != model->GetArray(output).shape().dims())) {
2492 AddMessageF("Set shape of %s to [%s]", output,
2493 absl::StrJoin(model->GetArray(output).shape().dims(), ","));
2494 *modified = true;
2495 return ::tensorflow::OkStatus();
2496 }
2497 }
2498 return ::tensorflow::OkStatus();
2499 }
2500
2501 } // namespace toco
2502