1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include <algorithm>
16 #include <string>
17
18 #include "tensorflow/lite/toco/graph_transformations/graph_transformations.h"
19 #include "tensorflow/lite/toco/model.h"
20 #include "tensorflow/lite/toco/tooling_util.h"
21
22 namespace toco {
23
24 namespace {
25
IsTailOfShape(const Shape & tail,const Shape & shape)26 bool IsTailOfShape(const Shape& tail, const Shape& shape) {
27 // Return true if 'tail' dimensions are the same as the ending dimensions of
28 // 'shape'.
29
30 int shape_end = shape.dimensions_count() - 1;
31 int tail_end = tail.dimensions_count() - 1;
32
33 if (tail_end > shape_end) {
34 // tail cannot be longer than shape.
35 return false;
36 }
37
38 // Walk dimensions back to front and compare
39 for (int i = 0; i <= tail_end; i++) {
40 if (shape.dims(shape_end - i) != tail.dims(tail_end - i)) {
41 return false;
42 }
43 }
44 return true;
45 }
46
47 } // namespace
48
// If a binary operator is doing a broadcast operation from a constant array,
// and the constant array shape is the tail of both the other input shape, and
// a subsequent reshape op's output shape, we can swap their order. Since we
// prefer to have reshape ops after mathematical ops, this can allow for the
// collapsing of some reshapes. The WaveNet model in particular benefits from
// this transformation.
//
// Note we are testing for one particular case of a broader set of possible
// binary-reshape op transformations. This transformation could be generalized.
Run(Model * model,std::size_t op_index,bool * modified)58 ::tensorflow::Status MoveBinaryOperatorBeforeReshape::Run(Model* model,
59 std::size_t op_index,
60 bool* modified) {
61 *modified = false;
62 const auto binary_it = model->operators.begin() + op_index;
63 Operator* binary_op = binary_it->get();
64 if (binary_op->type != OperatorType::kAdd &&
65 binary_op->type != OperatorType::kMul &&
66 binary_op->type != OperatorType::kSub &&
67 binary_op->type != OperatorType::kDiv &&
68 binary_op->type != OperatorType::kFloorDiv &&
69 binary_op->type != OperatorType::kFloorMod &&
70 binary_op->type != OperatorType::kMinimum &&
71 binary_op->type != OperatorType::kMaximum &&
72 binary_op->type != OperatorType::kLess &&
73 binary_op->type != OperatorType::kLessEqual &&
74 binary_op->type != OperatorType::kGreater &&
75 binary_op->type != OperatorType::kGreaterEqual) {
76 return ::tensorflow::OkStatus();
77 }
78
79 // BINARY OP INPUT CHECKS
80 CHECK_EQ(binary_op->inputs.size(), 2);
81 const bool input_is_const[2] = {
82 IsConstantParameterArray(*model, binary_op->inputs[0]),
83 IsConstantParameterArray(*model, binary_op->inputs[1]),
84 };
85 if (!input_is_const[0] && !input_is_const[1]) {
86 // To limit our scope, we require one constant input. Though there's no
87 // reason this transformation wouldn't work with all variable inputs.
88 return ::tensorflow::OkStatus();
89 }
90 if (input_is_const[0] && input_is_const[1]) {
91 // Both inputs are constants. Leave this for constants propagation.
92 return ::tensorflow::OkStatus();
93 }
94 const int constant_input_idx = input_is_const[0] ? 0 : 1;
95 const int variable_input_idx = input_is_const[0] ? 1 : 0;
96 CHECK(input_is_const[constant_input_idx]);
97 CHECK(!input_is_const[variable_input_idx]);
98
99 const auto& variable_input_array =
100 model->GetArray(binary_op->inputs[variable_input_idx]);
101 if (!variable_input_array.has_shape()) {
102 AddMessageF(
103 "Not moving %s because it's non-constant input shape is not resolved.",
104 LogName(*binary_op));
105 return ::tensorflow::OkStatus();
106 }
107 if (!IsTailOfShape(
108 model->GetArray(binary_op->inputs[constant_input_idx]).shape(),
109 model->GetArray(binary_op->inputs[variable_input_idx]).shape())) {
110 // Constant array shape must be the latter part of the variable shape.
111 return ::tensorflow::OkStatus();
112 }
113
114 // RESHAPE OP CHECKS
115 auto reshape_it =
116 FindOpWithOutput(*model, binary_op->inputs[variable_input_idx]);
117 if (reshape_it == model->operators.end()) {
118 AddMessageF("Not moving %s because it's variable input is not connected.",
119 LogName(*binary_op));
120 return ::tensorflow::OkStatus();
121 }
122 Operator* reshape_op = reshape_it->get();
123 if (reshape_op->type != OperatorType::kReshape) {
124 AddMessageF("Not moving %s because the preceding %s is not a reshape op",
125 LogName(*binary_op), LogName(*reshape_op));
126 return ::tensorflow::OkStatus();
127 }
128 const auto& reshape_input_array = model->GetArray(reshape_op->inputs[0]);
129 if (!reshape_input_array.has_shape()) {
130 AddMessageF(
131 "Not moving %s because it's non-constant input shape is not resolved "
132 "yet",
133 LogName(*binary_op));
134 return ::tensorflow::OkStatus();
135 }
136 if (!IsTailOfShape(
137 model->GetArray(binary_op->inputs[constant_input_idx]).shape(),
138 model->GetArray(reshape_op->outputs[0]).shape())) {
139 // Constant array shape must be the latter part of the binary op output
140 // shape.
141 return ::tensorflow::OkStatus();
142 }
143
144 // EXTRA CHECKS ON CONNECTING ARRAY
145 for (const std::string& output_array : model->flags.output_arrays()) {
146 if (binary_op->inputs[variable_input_idx] == output_array) {
147 AddMessageF(
148 "Not moving %s because the output of reshape op %s is an output op.",
149 LogName(*binary_op), LogName(*reshape_op));
150 return ::tensorflow::OkStatus();
151 }
152 }
153 int count_ops_consuming_output =
154 CountOpsWithInput(*model, binary_op->inputs[variable_input_idx]);
155 DCHECK_GE(count_ops_consuming_output, 1);
156 if (count_ops_consuming_output > 1) {
157 AddMessageF(
158 "Not moving %s because the output of reshape op %s is consumed by "
159 "another op",
160 LogName(*binary_op), LogName(*reshape_op));
161 return ::tensorflow::OkStatus();
162 }
163
164 // SWAP ORDER OF BINARY AND RESHAPE OPS
165 AddMessageF("Moving op %s before reshape op %s", LogName(*binary_op),
166 LogName(*reshape_op));
167
168 // Swap op input and outputs
169 std::iter_swap(reshape_op->inputs.begin(),
170 binary_op->inputs.begin() + variable_input_idx);
171 std::iter_swap(reshape_op->outputs.begin(), binary_op->outputs.begin());
172
173 // Swap operator ordering
174 std::iter_swap(binary_it, reshape_it);
175
176 // Clear binary output shape so it will be re-propagated
177 model->GetArray(binary_op->outputs[0]).clear_shape();
178
179 *modified = true;
180 return ::tensorflow::OkStatus();
181 }
182
183 } // namespace toco
184