1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/convolution_group_converter.h"
17
18 #include <algorithm>
19 #include <memory>
20 #include <vector>
21
22 #include "tensorflow/compiler/xla/literal.h"
23 #include "tensorflow/compiler/xla/literal_util.h"
24 #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
25 #include "tensorflow/compiler/xla/service/hlo_computation.h"
26 #include "tensorflow/compiler/xla/service/hlo_creation_utils.h"
27 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
28 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
29 #include "tensorflow/compiler/xla/shape_util.h"
30 #include "tensorflow/compiler/xla/status_macros.h"
31 #include "tensorflow/compiler/xla/types.h"
32 #include "tensorflow/compiler/xla/util.h"
33 #include "tensorflow/compiler/xla/xla_data.pb.h"
34 #include "tensorflow/core/lib/core/errors.h"
35 #include "tensorflow/core/lib/core/status.h"
36 #include "tensorflow/core/platform/logging.h"
37
38 namespace xla {
39
40 namespace {
41
42 // ConvolutionVisitor traverses the HLO computation and rewrites Convolution
43 // operations with feature_group_count > 1 into convolutions with
44 // feature_group_count = 1.
45 class ConvolutionVisitor : public DfsHloVisitorWithDefault {
46 public:
47 // Default visitor action is to do nothing and return OK.
DefaultAction(HloInstruction *)48 Status DefaultAction(HloInstruction* /*hlo_instruction*/) override {
49 return OkStatus();
50 }
51
52 Status HandleConvolution(HloInstruction* convolution) override;
53
54 Status HandleBatchGroupCount(HloInstruction* convolution);
55
56 // Runs the visitor on a computation.
57 static bool Run(HloComputation* computation,
58 std::function<bool(HloInstruction*)> should_expand,
59 std::function<bool(HloInstruction*)> is_cost_viable,
60 bool convert_batch_groups_only, bool filter_expansion);
61
62 // Returns whether any convolution ops were rewritten.
changed() const63 const bool changed() const { return changed_; }
64
65 ~ConvolutionVisitor() override = default;
66
67 private:
ConvolutionVisitor(HloComputation * computation,std::function<bool (HloInstruction *)> should_expand,std::function<bool (HloInstruction *)> is_cost_viable,bool convert_batch_groups_only,bool filter_expansion)68 explicit ConvolutionVisitor(
69 HloComputation* computation,
70 std::function<bool(HloInstruction*)> should_expand,
71 std::function<bool(HloInstruction*)> is_cost_viable,
72 bool convert_batch_groups_only, bool filter_expansion)
73 : computation_(computation),
74 filter_expansion_(filter_expansion),
75 convert_batch_groups_only_(convert_batch_groups_only),
76 should_expand_(should_expand),
77 is_cost_viable_(is_cost_viable) {}
78
79 // Current HloComputation instance the ConvolutionVisitor is traversing.
80 HloComputation* computation_;
81
82 // Whether rewrite has occurred.
83 bool changed_ = false;
84
85 // Whether filter expansion is required.
86 bool filter_expansion_;
87
88 // Decides whether to convert batch groups or feature groups.
89 bool convert_batch_groups_only_;
90
91 std::function<bool(HloInstruction*)> should_expand_;
92 std::function<bool(HloInstruction*)> is_cost_viable_;
93 };
94
Run(HloComputation * computation,std::function<bool (HloInstruction *)> should_expand,std::function<bool (HloInstruction *)> is_cost_viable,bool convert_batch_groups_only,bool filter_expansion)95 bool ConvolutionVisitor::Run(
96 HloComputation* computation,
97 std::function<bool(HloInstruction*)> should_expand,
98 std::function<bool(HloInstruction*)> is_cost_viable,
99 bool convert_batch_groups_only, bool filter_expansion) {
100 ConvolutionVisitor visitor(computation, should_expand, is_cost_viable,
101 convert_batch_groups_only, filter_expansion);
102 TF_CHECK_OK(computation->Accept(&visitor));
103 return visitor.changed_;
104 }
105
// Returns `shape` with its `input_feature_dim` dimension scaled up by
// `group_count`; every other dimension is left untouched.
Shape ExpandedFilterShape(const Shape& shape, int64_t group_count,
                          int64_t input_feature_dim) {
  CHECK_GE(shape.dimensions_size(), 2);
  const int64_t expanded_feature_size =
      shape.dimensions(input_feature_dim) * group_count;
  Shape expanded = shape;
  expanded.set_dimensions(input_feature_dim, expanded_feature_size);
  return expanded;
}
115
// Returns a vector with 'group_count' many groups, where the i-th group
// consists of 'group_size' times the value i.
std::vector<int32_t> GetMaskIds(int64_t group_size, int64_t group_count) {
  std::vector<int32_t> values;
  values.reserve(group_count * group_size);
  // Use 64-bit loop counters: the bounds are int64_t, and `int` counters
  // would silently truncate counts beyond the range of int.
  for (int64_t i = 0; i < group_count; ++i) {
    for (int64_t j = 0; j < group_size; ++j) {
      values.push_back(static_cast<int32_t>(i));
    }
  }
  return values;
}
128
129 // Create a mask for grouped convolution that will make a normal convolution
130 // produce the same results as a grouped convolution. For a [2, 1, 6]
131 // filter this returns a [2, 3, 6] mask
132 // 1 1 0 0 0 0
133 // 0 0 1 1 0 0
134 // 0 0 0 0 1 1
135 //
136 // 1 1 0 0 0 0
137 // 0 0 1 1 0 0
138 // 0 0 0 0 1 1
139 //
140 // The first step is to create a rank 1 constant:
141 // 0 1 2
142 //
143 // This is broadcasted to
144 // 0 0 0 0 0 0
145 // 1 1 1 1 1 1
146 // 2 2 2 2 2 2
147 //
148 // 0 0 0 0 0 0
149 // 1 1 1 1 1 1
150 // 2 2 2 2 2 2
151 //
152 // Then we create another rank 1 constant
153 // 0 0 1 1 2 2
154 //
155 // This is broadcasted to
156 // 0 0 1 1 2 2
157 // 0 0 1 1 2 2
158 // 0 0 1 1 2 2
159 //
160 // 0 0 1 1 2 2
161 // 0 0 1 1 2 2
162 // 0 0 1 1 2 2
163 //
164 // Finally we use the Eq op of these two broadcasted constants and get the
165 // desired mask.
GetExpandedFilterMask(const Shape & filter_shape,int64_t kernel_input_feature_dim,int64_t kernel_output_feature_dim,int64_t group_count,const std::function<HloInstruction * (std::unique_ptr<HloInstruction>)> & add_instruction)166 HloInstruction* GetExpandedFilterMask(
167 const Shape& filter_shape, int64_t kernel_input_feature_dim,
168 int64_t kernel_output_feature_dim, int64_t group_count,
169 const std::function<HloInstruction*(std::unique_ptr<HloInstruction>)>&
170 add_instruction) {
171 Shape expanded_filter_shape =
172 ExpandedFilterShape(filter_shape, group_count, kernel_input_feature_dim);
173 Shape mask_shape =
174 ShapeUtil::MakeShape(S32, expanded_filter_shape.dimensions());
175 int64_t output_feature = filter_shape.dimensions(kernel_output_feature_dim);
176 int64_t group_size = filter_shape.dimensions(kernel_input_feature_dim);
177
178 // Create a 'input_feature' sized linspace and 'output_feature' sized linspace
179 // that will be broadcasted into perpendicular dimensions and compared.
180 const std::vector<int32_t> input_feature_filter_mask =
181 GetMaskIds(group_size, group_count);
182 const std::vector<int32_t> output_feature_filter_mask =
183 GetMaskIds(output_feature / group_count, group_count);
184 auto mask1 = add_instruction(HloInstruction::CreateConstant(
185 LiteralUtil::CreateR1<int32_t>(input_feature_filter_mask)));
186 auto broadcasted_mask1 = add_instruction(HloInstruction::CreateBroadcast(
187 mask_shape, mask1, {kernel_input_feature_dim}));
188 auto mask2 = add_instruction(HloInstruction::CreateConstant(
189 LiteralUtil::CreateR1<int32_t>(output_feature_filter_mask)));
190 auto broadcasted_mask2 = add_instruction(HloInstruction::CreateBroadcast(
191 mask_shape, mask2, {kernel_output_feature_dim}));
192
193 // Compare the broadcasted output feature linspace to the input feature
194 // linspace to create a diagonal predicate.
195 Shape predicate_shape =
196 ShapeUtil::MakeShape(PRED, expanded_filter_shape.dimensions());
197 return add_instruction(HloInstruction::CreateCompare(
198 predicate_shape, broadcasted_mask1, broadcasted_mask2,
199 ComparisonDirection::kEq));
200 }
201
202 // This function handles batch_group_counts which are relevant only for
203 // depthwise backprop filter convolutions.
Status ConvolutionVisitor::HandleBatchGroupCount(HloInstruction* convolution) {
  auto dim_numbers = convolution->convolution_dimension_numbers();
  auto activation = convolution->mutable_operand(0);
  auto filter = convolution->mutable_operand(1);
  int64_t batch_group_count = convolution->batch_group_count();

  // Nothing to do when there are no batch groups, or when the
  // caller-supplied predicate declines expanding this convolution.
  if (batch_group_count == 1 ||
      (should_expand_ && !should_expand_(convolution))) {
    return OkStatus();
  }

  VLOG(2) << "Dealing with batch_group_count " << batch_group_count
          << " for convolution " << convolution->ToString() << "\n";

  // Adds a freshly created instruction to the computation being rewritten.
  auto add = [&](std::unique_ptr<HloInstruction> inst) {
    return computation_->AddInstruction(std::move(inst));
  };

  int64_t input_batch_dimension = dim_numbers.input_batch_dimension();
  const int64_t input_feature_dimension = dim_numbers.input_feature_dimension();

  int64_t output_batch_dimension = dim_numbers.output_batch_dimension();
  int64_t output_feature_dimension = dim_numbers.output_feature_dimension();

  const int64_t kernel_input_feature_dimension =
      dim_numbers.kernel_input_feature_dimension();
  const int64_t kernel_output_feature_dimension =
      dim_numbers.kernel_output_feature_dimension();

  const int64_t input_batch =
      activation->shape().dimensions(input_batch_dimension);
  const int64_t output_feature =
      filter->shape().dimensions(kernel_output_feature_dimension);

  // General case: the convolution is not a "pure" depthwise backprop-filter
  // shape (one batch element and one output feature per group). Fold the
  // batch groups into an extra spatial dimension on the activation, kernel,
  // and output, and lower to a convolution with batch_group_count == 1.
  if (output_feature != batch_group_count || input_batch != batch_group_count) {
    // Insert a spatial dimension to the activation before the input batch
    // dimension to represent the batch group.
    std::vector<int64_t> input_sizes(activation->shape().dimensions().begin(),
                                     activation->shape().dimensions().end());
    input_sizes[input_batch_dimension] /= batch_group_count;
    input_sizes.insert(input_sizes.begin() + input_batch_dimension,
                       batch_group_count);
    activation = MakeReshapeHlo(input_sizes, activation).ValueOrDie();
    // Every dimension number at or after the insertion point shifts by one.
    for (auto& d : *dim_numbers.mutable_input_spatial_dimensions()) {
      if (d > input_batch_dimension) {
        ++d;
      }
    }
    dim_numbers.add_input_spatial_dimensions(input_batch_dimension);
    dim_numbers.set_input_batch_dimension(input_batch_dimension + 1);
    if (input_feature_dimension > input_batch_dimension) {
      dim_numbers.set_input_feature_dimension(input_feature_dimension + 1);
    }

    // Insert a spatial dimension to the kernel before the output feature
    // dimension to represent the batch group.
    std::vector<int64_t> kernel_sizes(filter->shape().dimensions().begin(),
                                      filter->shape().dimensions().end());
    kernel_sizes[kernel_output_feature_dimension] /= batch_group_count;
    kernel_sizes.insert(kernel_sizes.begin() + kernel_output_feature_dimension,
                        batch_group_count);
    filter = MakeReshapeHlo(kernel_sizes, filter).ValueOrDie();
    for (auto& d : *dim_numbers.mutable_kernel_spatial_dimensions()) {
      if (d > kernel_output_feature_dimension) {
        ++d;
      }
    }
    dim_numbers.add_kernel_spatial_dimensions(kernel_output_feature_dimension);
    dim_numbers.set_kernel_output_feature_dimension(
        kernel_output_feature_dimension + 1);
    if (kernel_input_feature_dimension > kernel_output_feature_dimension) {
      dim_numbers.set_kernel_input_feature_dimension(
          kernel_input_feature_dimension + 1);
    }

    // Insert a spatial dimension to the output before the output feature
    // dimension to represent the batch group.
    for (auto& d : *dim_numbers.mutable_output_spatial_dimensions()) {
      if (d > output_feature_dimension) {
        ++d;
      }
    }
    dim_numbers.add_output_spatial_dimensions(output_feature_dimension);
    dim_numbers.set_output_feature_dimension(output_feature_dimension + 1);
    if (output_batch_dimension > output_feature_dimension) {
      dim_numbers.set_output_batch_dimension(output_batch_dimension + 1);
    }

    // To represent a batch group count of 3 you can slide a 3 wide window
    // [X Y Z]
    // across [A 0 0 B 0 0 C] with stride 2 to produce
    // [AX+0Y+0Z 0X+BY+0Z 0X+0Y+CZ] -> [AX BY CZ] which will behave the same as
    // a batch group count.
    Window window = convolution->window();
    auto window_dim = window.add_dimensions();
    window_dim->set_base_dilation(batch_group_count);
    window_dim->set_size(batch_group_count);
    window_dim->set_stride(batch_group_count - 1);
    window_dim->set_padding_low(0);
    window_dim->set_padding_high(0);
    window_dim->set_window_reversal(false);
    window_dim->set_window_dilation(1);
    HloInstruction* new_convolution =
        MakeConvolveHlo(
            activation, filter, convolution->feature_group_count(),
            /*batch_group_count=*/1, window, dim_numbers,
            convolution->precision_config(),
            /*preferred_element_type=*/convolution->shape().element_type())
            .ValueOrDie();
    convolution->SetupDerivedInstruction(new_convolution);
    // Reshape back to the caller-visible output shape and swap it in.
    TF_CHECK_OK(computation_->ReplaceInstruction(
        convolution,
        MakeReshapeHlo(convolution->shape(), new_convolution).ValueOrDie()));
    changed_ = true;
    return OkStatus();
  }

  // Depthwise backprop-filter case. Only rewrite when the cost model rejects
  // the grouped form or when filter expansion is explicitly requested;
  // otherwise the convolution is left for the backend to handle natively.
  VLOG(2) << "is_cost_viable_ " << is_cost_viable_(convolution);
  const bool cost_too_high = !is_cost_viable_(convolution);
  if (cost_too_high || filter_expansion_) {
    // We first obtain the expanded filter (which is the convolution
    // output). The batch dimension is the expanded one (which originally
    // represents kernel input feature dimension). We mask the filter to zero
    // out the expanded regions. Next we reduce the filter in the batch
    // dimension to obtain the original filter size.

    HloInstruction* filter_mask =
        GetExpandedFilterMask(convolution->shape(), output_batch_dimension,
                              output_feature_dimension, batch_group_count, add);
    auto expanded_filter_shape = ExpandedFilterShape(
        convolution->shape(), batch_group_count, output_batch_dimension);

    VLOG(2) << "output_batch_dimension " << output_batch_dimension;
    VLOG(2) << "New output shape of convolution "
            << expanded_filter_shape.ToString();

    // Run the convolution ungrouped; it produces the expanded filter.
    auto new_convolution = add(HloInstruction::CreateConvolve(
        expanded_filter_shape, activation, filter,
        /*feature_group_count=*/1, /*batch_group_count=*/1,
        convolution->window(), dim_numbers, convolution->precision_config()));

    VLOG(2) << "Expanded convolution " << new_convolution->ToString();

    auto zero = add(HloInstruction::CreateConstant(
        LiteralUtil::Zero(expanded_filter_shape.element_type())));
    auto zero_filter =
        add(HloInstruction::CreateBroadcast(expanded_filter_shape, zero, {}));

    // Keep only the entries selected by the block-diagonal mask; all other
    // (cross-group) entries are replaced with zero.
    auto new_filter = add(HloInstruction::CreateTernary(
        expanded_filter_shape, HloOpcode::kSelect, filter_mask, new_convolution,
        zero_filter));

    PrimitiveType reduce_type = new_filter->shape().element_type();
    auto reduce_window_shape = new_convolution->shape();
    reduce_window_shape.set_dimensions(output_batch_dimension, 1);

    // Ensure that data input to reduce window uses at least 32 bits.
    if (primitive_util::BitWidth(reduce_type) < primitive_util::BitWidth(F32)) {
      reduce_type = F32;
      reduce_window_shape.set_element_type(F32);
      Shape convert_shape = new_filter->shape();
      convert_shape.set_element_type(F32);
      new_filter =
          add(HloInstruction::CreateConvert(convert_shape, new_filter));
    }

    auto zero_literal = LiteralUtil::Zero(reduce_type);
    auto zero_scalar =
        add(HloInstruction::CreateConstant(std::move(zero_literal)));

    // Builds the scalar addition computation used as the reduce-window
    // reducer.
    auto reduce_function = [&]() -> HloComputation* {
      HloComputation::Builder b("add_computation");
      Shape shape = ShapeUtil::MakeShape(reduce_type, {});
      auto lhs =
          b.AddInstruction(HloInstruction::CreateParameter(0, shape, "lhs"));
      auto rhs =
          b.AddInstruction(HloInstruction::CreateParameter(1, shape, "rhs"));
      auto scalar_op = b.AddInstruction(
          HloInstruction::CreateBinary(shape, HloOpcode::kAdd, lhs, rhs));
      return computation_->parent()->AddEmbeddedComputation(b.Build(scalar_op));
    };

    // Create the reduce window. Only the output batch dimension is reduced
    // (window size == batch_group_count); all other dimensions pass through.
    Window window;
    for (int64_t i = 0; i < new_convolution->shape().dimensions_size(); ++i) {
      auto* dim = window.add_dimensions();
      dim->set_padding_low(0);
      dim->set_padding_high(0);
      dim->set_window_dilation(1);
      dim->set_base_dilation(1);
      if (i == output_batch_dimension) {
        dim->set_stride(batch_group_count);
        dim->set_size(batch_group_count);
      } else {
        dim->set_stride(1);
        dim->set_size(1);
      }
    }
    auto reduce_window = add(HloInstruction::CreateReduceWindow(
        reduce_window_shape, new_filter, zero_scalar, window,
        reduce_function()));

    Shape convert_back_shape = reduce_window->shape();
    convert_back_shape.set_element_type(activation->shape().element_type());

    // Convert reduced data back to the original data type.
    auto reduce_window_converted =
        HloInstruction::CreateConvert(convert_back_shape, reduce_window);

    TF_CHECK_OK(computation_->ReplaceWithNewInstruction(
        convolution, std::move(reduce_window_converted)));
    changed_ = true;
  }

  return OkStatus();
}
420
// Rewrites a convolution with feature_group_count > 1. Depending on group
// size, cost viability, and the filter_expansion_ flag, this either expands
// the filter into a masked ungrouped convolution or folds the groups into an
// extra spatial dimension.
Status ConvolutionVisitor::HandleConvolution(HloInstruction* convolution) {
  // In batch-group-only mode, delegate entirely and leave feature groups
  // untouched.
  if (convert_batch_groups_only_) {
    return HandleBatchGroupCount(convolution);
  }

  // Adds a freshly created instruction to the computation being rewritten.
  auto add = [&](std::unique_ptr<HloInstruction> inst) {
    return computation_->AddInstruction(std::move(inst));
  };

  int64_t group_count = convolution->feature_group_count();
  // Nothing to do when there are no feature groups, or when the
  // caller-supplied predicate declines expanding this convolution.
  if (group_count == 1 || (should_expand_ && !should_expand_(convolution))) {
    return OkStatus();
  }

  changed_ = true;
  ConvolutionDimensionNumbers dim_numbers =
      convolution->convolution_dimension_numbers();
  auto filter = convolution->mutable_operand(1);
  int64_t kernel_input_feature_dim =
      dim_numbers.kernel_input_feature_dimension();
  int64_t group_size = filter->shape().dimensions(kernel_input_feature_dim);
  int64_t kernel_output_feature_dim =
      dim_numbers.kernel_output_feature_dimension();
  auto expanded_filter_shape = ExpandedFilterShape(filter->shape(), group_count,
                                                   kernel_input_feature_dim);
  HloInstruction* filter_mask =
      GetExpandedFilterMask(filter->shape(), kernel_input_feature_dim,
                            kernel_output_feature_dim, group_count, add);
  HloInstruction* expanded_filter;

  // Depthwise case: each group sees exactly one input feature.
  if (group_size == 1) {
    bool depthwise_separable =
        (group_count == filter->shape().dimensions(kernel_output_feature_dim));
    // If the code generator handles depthwise separable convolutions
    // inherently, then no filter expansion is needed.
    if (!filter_expansion_ && depthwise_separable) {
      changed_ = false;
      return OkStatus();
    }
    VLOG(2) << "is_cost_viable_ " << is_cost_viable_(convolution);
    // We want to repeat 'filter' in the 'input_feature_dim' dimension
    // 'group_count' times.
    if (!is_cost_viable_(convolution) || filter_expansion_) {
      // Broadcast the filter across the expanded input feature dimension,
      // then mask out entries outside each group's diagonal block.
      Shape reshaped_filter_shape =
          ShapeUtil::DeleteDimension(kernel_input_feature_dim, filter->shape());
      auto reshaped_filter =
          add(HloInstruction::CreateReshape(reshaped_filter_shape, filter));
      std::vector<int64_t> broadcast_dims;
      for (int64_t i = 0; i < filter->shape().dimensions_size(); ++i) {
        if (i == kernel_input_feature_dim) {
          continue;
        }
        broadcast_dims.push_back(i);
      }
      expanded_filter = add(HloInstruction::CreateBroadcast(
          expanded_filter_shape, reshaped_filter, broadcast_dims));

      auto zero = add(HloInstruction::CreateConstant(
          LiteralUtil::Zero(expanded_filter_shape.element_type())));
      auto zero_filter =
          add(HloInstruction::CreateBroadcast(expanded_filter_shape, zero, {}));
      auto new_filter = add(HloInstruction::CreateTernary(
          expanded_filter_shape, HloOpcode::kSelect, filter_mask,
          expanded_filter, zero_filter));

      // The masked filter makes an ungrouped convolution equivalent to the
      // original grouped one.
      auto new_convolution = HloInstruction::CreateConvolve(
          convolution->shape(), convolution->mutable_operand(0), new_filter,
          /*feature_group_count=*/1, /*batch_group_count=*/1,
          convolution->window(), dim_numbers, convolution->precision_config());
      return computation_->ReplaceWithNewInstruction(
          convolution, std::move(new_convolution));
    }
    // Add a spatial dimension to emulate a larger output feature dimension
    // to avoid creating a convolution with group_count = 1.
    std::vector<int64_t> new_filter_dimension;
    new_filter_dimension.reserve(filter->shape().rank() + 1);
    const int64_t depthwise_multiplier =
        filter->shape().dimensions(kernel_output_feature_dim) / group_count;
    // Split the kernel output feature dimension into group count and
    // depthwise multiplier.
    for (int64_t i = 0; i < filter->shape().rank(); ++i) {
      if (i == kernel_output_feature_dim) {
        new_filter_dimension.push_back(group_count);
        new_filter_dimension.push_back(depthwise_multiplier);
      } else {
        new_filter_dimension.push_back(filter->shape().dimensions(i));
      }
    }
    // Dimension numbers after the split point shift by one.
    if (kernel_input_feature_dim > kernel_output_feature_dim) {
      dim_numbers.set_kernel_input_feature_dimension(kernel_input_feature_dim +
                                                     1);
    }
    for (auto& dim : *dim_numbers.mutable_kernel_spatial_dimensions()) {
      if (dim > kernel_output_feature_dim) {
        ++dim;
      }
    }
    dim_numbers.add_kernel_spatial_dimensions(kernel_output_feature_dim + 1);
    HloInstruction* new_filter =
        computation_->AddInstruction(HloInstruction::CreateReshape(
            ShapeUtil::MakeShape(filter->shape().element_type(),
                                 new_filter_dimension),
            filter));

    auto new_activation_shape = convolution->operand(0)->shape();
    dim_numbers.add_input_spatial_dimensions(new_activation_shape.rank());

    // Create an activation spatial dimension of size 1 with a reversed
    // window and high and low padding equal to the depthwise_multiplier - 1.
    // This emulates a larger output feature dimension with an extra spatial
    // dimension.
    ShapeUtil::AppendMajorDimension(1, &new_activation_shape);
    HloInstruction* new_activation =
        computation_->AddInstruction(HloInstruction::CreateReshape(
            new_activation_shape, convolution->mutable_operand(0)));
    auto new_window = convolution->window();
    auto new_dim = new_window.add_dimensions();
    new_dim->set_size(depthwise_multiplier);
    new_dim->set_window_reversal(true);
    new_dim->set_padding_low(depthwise_multiplier - 1);
    new_dim->set_padding_high(depthwise_multiplier - 1);
    new_dim->set_stride(1);
    new_dim->set_window_dilation(1);
    new_dim->set_base_dilation(1);

    // Split the output feature dimension into an output feature of group
    // count and a depthwise multiplier as an output spatial dimension.
    std::vector<int64_t> new_output_dimension;
    new_output_dimension.reserve(convolution->shape().rank() + 1);
    for (int64_t i = 0; i < convolution->shape().rank(); ++i) {
      if (i == dim_numbers.output_feature_dimension()) {
        new_output_dimension.push_back(group_count);
        new_output_dimension.push_back(depthwise_multiplier);
      } else {
        new_output_dimension.push_back(convolution->shape().dimensions(i));
      }
    }
    if (dim_numbers.output_batch_dimension() >
        dim_numbers.output_feature_dimension()) {
      dim_numbers.set_output_batch_dimension(
          dim_numbers.output_batch_dimension() + 1);
    }
    for (auto& dim : *dim_numbers.mutable_output_spatial_dimensions()) {
      if (dim > dim_numbers.output_feature_dimension()) {
        ++dim;
      }
    }
    dim_numbers.add_output_spatial_dimensions(
        dim_numbers.output_feature_dimension() + 1);
    auto new_convolution_output_shape = ShapeUtil::MakeShape(
        convolution->shape().element_type(), new_output_dimension);
    HloInstruction* new_convolution =
        computation_->AddInstruction(HloInstruction::CreateConvolve(
            new_convolution_output_shape, new_activation, new_filter,
            /*feature_group_count=*/group_count, /*batch_group_count=*/1,
            new_window, dim_numbers, convolution->precision_config()));
    // Reshape back to the caller-visible output shape.
    return computation_->ReplaceWithNewInstruction(
        convolution,
        HloInstruction::CreateReshape(convolution->shape(), new_convolution));
  }

  // Implement general grouped convolution using an extra spatial dimension to
  // represent the feature group count.
  //
  // Insert a spatial dimension to the input before the input feature
  // dimension to represent the feature group.
  HloInstruction* activation = convolution->mutable_operand(0);
  std::vector<int64_t> input_sizes(activation->shape().dimensions().begin(),
                                   activation->shape().dimensions().end());
  const int64_t input_feature_dimension = dim_numbers.input_feature_dimension();
  input_sizes[input_feature_dimension] /= group_count;
  input_sizes.insert(input_sizes.begin() + input_feature_dimension,
                     group_count);
  activation = MakeReshapeHlo(input_sizes, activation).ValueOrDie();
  // Every dimension number at or after the insertion point shifts by one.
  for (auto& d : *dim_numbers.mutable_input_spatial_dimensions()) {
    if (d > input_feature_dimension) {
      ++d;
    }
  }
  dim_numbers.add_input_spatial_dimensions(input_feature_dimension);
  dim_numbers.set_input_feature_dimension(input_feature_dimension + 1);
  if (dim_numbers.input_batch_dimension() > input_feature_dimension) {
    dim_numbers.set_input_batch_dimension(dim_numbers.input_batch_dimension() +
                                          1);
  }

  // Insert a spatial dimension to the kernel before the output feature
  // dimension to represent the feature group.
  std::vector<int64_t> kernel_sizes(filter->shape().dimensions().begin(),
                                    filter->shape().dimensions().end());
  const int64_t kernel_output_feature_dimension =
      dim_numbers.kernel_output_feature_dimension();
  kernel_sizes[kernel_output_feature_dimension] /= group_count;
  kernel_sizes.insert(kernel_sizes.begin() + kernel_output_feature_dimension,
                      group_count);
  filter = MakeReshapeHlo(kernel_sizes, filter).ValueOrDie();
  for (auto& d : *dim_numbers.mutable_kernel_spatial_dimensions()) {
    if (d > kernel_output_feature_dimension) {
      ++d;
    }
  }
  dim_numbers.add_kernel_spatial_dimensions(kernel_output_feature_dimension);
  dim_numbers.set_kernel_output_feature_dimension(
      kernel_output_feature_dimension + 1);
  if (dim_numbers.kernel_input_feature_dimension() >
      kernel_output_feature_dimension) {
    dim_numbers.set_kernel_input_feature_dimension(
        dim_numbers.kernel_input_feature_dimension() + 1);
  }

  // Insert a spatial dimension to the output before the output feature
  // dimension to represent the feature group.
  const int64_t output_feature_dimension =
      dim_numbers.output_feature_dimension();
  for (auto& d : *dim_numbers.mutable_output_spatial_dimensions()) {
    if (d > output_feature_dimension) {
      ++d;
    }
  }
  dim_numbers.add_output_spatial_dimensions(output_feature_dimension);
  dim_numbers.set_output_feature_dimension(output_feature_dimension + 1);
  if (dim_numbers.output_batch_dimension() > output_feature_dimension) {
    dim_numbers.set_output_batch_dimension(
        dim_numbers.output_batch_dimension() + 1);
  }

  // To represent a feature group count of 3 you can slide a 3 wide window
  // [X Y Z]
  // across [A 0 0 B 0 0 C] with stride 2 to produce
  // [AX+0Y+0Z 0X+BY+0Z 0X+0Y+CZ] -> [AX BY CZ] which will behave the same as
  // a batch group count.
  Window window = convolution->window();
  auto window_dim = window.add_dimensions();
  window_dim->set_base_dilation(group_count);
  window_dim->set_size(group_count);
  window_dim->set_stride(group_count - 1);
  window_dim->set_padding_low(0);
  window_dim->set_padding_high(0);
  window_dim->set_window_reversal(false);
  window_dim->set_window_dilation(1);
  HloInstruction* new_convolution =
      MakeConvolveHlo(
          activation, filter, /*feature_group_count=*/1,
          /*batch_group_count=*/1, window, dim_numbers,
          convolution->precision_config(),
          /*preferred_element_type=*/convolution->shape().element_type())
          .ValueOrDie();
  convolution->SetupDerivedInstruction(new_convolution);
  changed_ = true;
  // Reshape back to the caller-visible output shape and swap it in.
  return computation_->ReplaceInstruction(
      convolution,
      MakeReshapeHlo(convolution->shape(), new_convolution).ValueOrDie());
}
674
675 } // namespace
676
Run(HloModule * module,const absl::flat_hash_set<absl::string_view> & execution_threads)677 StatusOr<bool> ConvolutionGroupConverter::Run(
678 HloModule* module,
679 const absl::flat_hash_set<absl::string_view>& execution_threads) {
680 XLA_VLOG_LINES(
681 2, "ConvolutionGroupConverter::Run(), before:\n" + module->ToString());
682 bool changed = false;
683 for (auto* comp : module->MakeNonfusionComputations(execution_threads)) {
684 if (ConvolutionVisitor::Run(comp, should_expand_, is_cost_viable_,
685 convert_batch_groups_only_,
686 filter_expansion_)) {
687 changed = true;
688 }
689 }
690 XLA_VLOG_LINES(
691 2, "ConvolutionGroupConverter::Run(), after:\n" + module->ToString());
692 return changed;
693 }
694
695 } // namespace xla
696