1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3  Licensed under the Apache License, Version 2.0 (the "License");
4  you may not use this file except in compliance with the License.
5  You may obtain a copy of the License at
6 
7      http://www.apache.org/licenses/LICENSE-2.0
8 
9  Unless required by applicable law or agreed to in writing, software
10  distributed under the License is distributed on an "AS IS" BASIS,
11  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  See the License for the specific language governing permissions and
13  limitations under the License.
14  ==============================================================================*/
15 #include <algorithm>
16 #include <string>
17 
18 #include "tensorflow/lite/toco/graph_transformations/graph_transformations.h"
19 #include "tensorflow/lite/toco/model.h"
20 #include "tensorflow/lite/toco/tooling_util.h"
21 
22 namespace toco {
23 
24 namespace {
25 
IsTailOfShape(const Shape & tail,const Shape & shape)26 bool IsTailOfShape(const Shape& tail, const Shape& shape) {
27   // Return true if 'tail' dimensions are the same as the ending dimensions of
28   // 'shape'.
29 
30   int shape_end = shape.dimensions_count() - 1;
31   int tail_end = tail.dimensions_count() - 1;
32 
33   if (tail_end > shape_end) {
34     // tail cannot be longer than shape.
35     return false;
36   }
37 
38   // Walk dimensions back to front and compare
39   for (int i = 0; i <= tail_end; i++) {
40     if (shape.dims(shape_end - i) != tail.dims(tail_end - i)) {
41       return false;
42     }
43   }
44   return true;
45 }
46 
47 }  // namespace
48 
49 // If a binary operator is doing a broadcast operation from a constant array,
50 // and the constant array shape is the tail of both the other input shape, and a
51 // subsequent reshape op's output shape, we can swap their order. Since we
52 // prefer to have reshape ops after mathematic ops, this can allow for the
53 // collapsing of some reshapes. The WaveNet model in particular benefits from
54 // this transformation.
55 //
56 // Note we are testing for one particular case of a broader set of possible
57 // binary-reshape op transformations. This transformation could be generalized.
Run(Model * model,std::size_t op_index,bool * modified)58 ::tensorflow::Status MoveBinaryOperatorBeforeReshape::Run(Model* model,
59                                                           std::size_t op_index,
60                                                           bool* modified) {
61   *modified = false;
62   const auto binary_it = model->operators.begin() + op_index;
63   Operator* binary_op = binary_it->get();
64   if (binary_op->type != OperatorType::kAdd &&
65       binary_op->type != OperatorType::kMul &&
66       binary_op->type != OperatorType::kSub &&
67       binary_op->type != OperatorType::kDiv &&
68       binary_op->type != OperatorType::kFloorDiv &&
69       binary_op->type != OperatorType::kFloorMod &&
70       binary_op->type != OperatorType::kMinimum &&
71       binary_op->type != OperatorType::kMaximum &&
72       binary_op->type != OperatorType::kLess &&
73       binary_op->type != OperatorType::kLessEqual &&
74       binary_op->type != OperatorType::kGreater &&
75       binary_op->type != OperatorType::kGreaterEqual) {
76     return ::tensorflow::OkStatus();
77   }
78 
79   // BINARY OP INPUT CHECKS
80   CHECK_EQ(binary_op->inputs.size(), 2);
81   const bool input_is_const[2] = {
82       IsConstantParameterArray(*model, binary_op->inputs[0]),
83       IsConstantParameterArray(*model, binary_op->inputs[1]),
84   };
85   if (!input_is_const[0] && !input_is_const[1]) {
86     // To limit our scope, we require one constant input. Though there's no
87     // reason this transformation wouldn't work with all variable inputs.
88     return ::tensorflow::OkStatus();
89   }
90   if (input_is_const[0] && input_is_const[1]) {
91     // Both inputs are constants. Leave this for constants propagation.
92     return ::tensorflow::OkStatus();
93   }
94   const int constant_input_idx = input_is_const[0] ? 0 : 1;
95   const int variable_input_idx = input_is_const[0] ? 1 : 0;
96   CHECK(input_is_const[constant_input_idx]);
97   CHECK(!input_is_const[variable_input_idx]);
98 
99   const auto& variable_input_array =
100       model->GetArray(binary_op->inputs[variable_input_idx]);
101   if (!variable_input_array.has_shape()) {
102     AddMessageF(
103         "Not moving %s because it's non-constant input shape is not resolved.",
104         LogName(*binary_op));
105     return ::tensorflow::OkStatus();
106   }
107   if (!IsTailOfShape(
108           model->GetArray(binary_op->inputs[constant_input_idx]).shape(),
109           model->GetArray(binary_op->inputs[variable_input_idx]).shape())) {
110     // Constant array shape must be the latter part of the variable shape.
111     return ::tensorflow::OkStatus();
112   }
113 
114   // RESHAPE OP CHECKS
115   auto reshape_it =
116       FindOpWithOutput(*model, binary_op->inputs[variable_input_idx]);
117   if (reshape_it == model->operators.end()) {
118     AddMessageF("Not moving %s because it's variable input is not connected.",
119                 LogName(*binary_op));
120     return ::tensorflow::OkStatus();
121   }
122   Operator* reshape_op = reshape_it->get();
123   if (reshape_op->type != OperatorType::kReshape) {
124     AddMessageF("Not moving %s because the preceding %s is not a reshape op",
125                 LogName(*binary_op), LogName(*reshape_op));
126     return ::tensorflow::OkStatus();
127   }
128   const auto& reshape_input_array = model->GetArray(reshape_op->inputs[0]);
129   if (!reshape_input_array.has_shape()) {
130     AddMessageF(
131         "Not moving %s because it's non-constant input shape is not resolved "
132         "yet",
133         LogName(*binary_op));
134     return ::tensorflow::OkStatus();
135   }
136   if (!IsTailOfShape(
137           model->GetArray(binary_op->inputs[constant_input_idx]).shape(),
138           model->GetArray(reshape_op->outputs[0]).shape())) {
139     // Constant array shape must be the latter part of the binary op output
140     // shape.
141     return ::tensorflow::OkStatus();
142   }
143 
144   // EXTRA CHECKS ON CONNECTING ARRAY
145   for (const std::string& output_array : model->flags.output_arrays()) {
146     if (binary_op->inputs[variable_input_idx] == output_array) {
147       AddMessageF(
148           "Not moving %s because the output of reshape op %s is an output op.",
149           LogName(*binary_op), LogName(*reshape_op));
150       return ::tensorflow::OkStatus();
151     }
152   }
153   int count_ops_consuming_output =
154       CountOpsWithInput(*model, binary_op->inputs[variable_input_idx]);
155   DCHECK_GE(count_ops_consuming_output, 1);
156   if (count_ops_consuming_output > 1) {
157     AddMessageF(
158         "Not moving %s because the output of reshape op %s is consumed by "
159         "another op",
160         LogName(*binary_op), LogName(*reshape_op));
161     return ::tensorflow::OkStatus();
162   }
163 
164   // SWAP ORDER OF BINARY AND RESHAPE OPS
165   AddMessageF("Moving op %s before reshape op %s", LogName(*binary_op),
166               LogName(*reshape_op));
167 
168   // Swap op input and outputs
169   std::iter_swap(reshape_op->inputs.begin(),
170                  binary_op->inputs.begin() + variable_input_idx);
171   std::iter_swap(reshape_op->outputs.begin(), binary_op->outputs.begin());
172 
173   // Swap operator ordering
174   std::iter_swap(binary_it, reshape_it);
175 
176   // Clear binary output shape so it will be re-propagated
177   model->GetArray(binary_op->outputs[0]).clear_shape();
178 
179   *modified = true;
180   return ::tensorflow::OkStatus();
181 }
182 
183 }  // namespace toco
184