xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/service/convolution_group_converter.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/convolution_group_converter.h"
17 
18 #include <algorithm>
19 #include <memory>
20 #include <vector>
21 
22 #include "tensorflow/compiler/xla/literal.h"
23 #include "tensorflow/compiler/xla/literal_util.h"
24 #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
25 #include "tensorflow/compiler/xla/service/hlo_computation.h"
26 #include "tensorflow/compiler/xla/service/hlo_creation_utils.h"
27 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
28 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
29 #include "tensorflow/compiler/xla/shape_util.h"
30 #include "tensorflow/compiler/xla/status_macros.h"
31 #include "tensorflow/compiler/xla/types.h"
32 #include "tensorflow/compiler/xla/util.h"
33 #include "tensorflow/compiler/xla/xla_data.pb.h"
34 #include "tensorflow/core/lib/core/errors.h"
35 #include "tensorflow/core/lib/core/status.h"
36 #include "tensorflow/core/platform/logging.h"
37 
38 namespace xla {
39 
40 namespace {
41 
42 // ConvolutionVisitor traverses the HLO computation and rewrites Convolution
43 // operations with feature_group_count > 1 into convolutions with
44 // feature_group_count = 1.
45 class ConvolutionVisitor : public DfsHloVisitorWithDefault {
46  public:
47   // Default visitor action is to do nothing and return OK.
DefaultAction(HloInstruction *)48   Status DefaultAction(HloInstruction* /*hlo_instruction*/) override {
49     return OkStatus();
50   }
51 
52   Status HandleConvolution(HloInstruction* convolution) override;
53 
54   Status HandleBatchGroupCount(HloInstruction* convolution);
55 
56   // Runs the visitor on a computation.
57   static bool Run(HloComputation* computation,
58                   std::function<bool(HloInstruction*)> should_expand,
59                   std::function<bool(HloInstruction*)> is_cost_viable,
60                   bool convert_batch_groups_only, bool filter_expansion);
61 
62   // Returns whether any convolution ops were rewritten.
changed() const63   const bool changed() const { return changed_; }
64 
65   ~ConvolutionVisitor() override = default;
66 
67  private:
ConvolutionVisitor(HloComputation * computation,std::function<bool (HloInstruction *)> should_expand,std::function<bool (HloInstruction *)> is_cost_viable,bool convert_batch_groups_only,bool filter_expansion)68   explicit ConvolutionVisitor(
69       HloComputation* computation,
70       std::function<bool(HloInstruction*)> should_expand,
71       std::function<bool(HloInstruction*)> is_cost_viable,
72       bool convert_batch_groups_only, bool filter_expansion)
73       : computation_(computation),
74         filter_expansion_(filter_expansion),
75         convert_batch_groups_only_(convert_batch_groups_only),
76         should_expand_(should_expand),
77         is_cost_viable_(is_cost_viable) {}
78 
79   // Current HloComputation instance the ConvolutionVisitor is traversing.
80   HloComputation* computation_;
81 
82   // Whether rewrite has occurred.
83   bool changed_ = false;
84 
85   // Whether filter expansion is required.
86   bool filter_expansion_;
87 
88   // Decides whether to convert batch groups or feature groups.
89   bool convert_batch_groups_only_;
90 
91   std::function<bool(HloInstruction*)> should_expand_;
92   std::function<bool(HloInstruction*)> is_cost_viable_;
93 };
94 
Run(HloComputation * computation,std::function<bool (HloInstruction *)> should_expand,std::function<bool (HloInstruction *)> is_cost_viable,bool convert_batch_groups_only,bool filter_expansion)95 bool ConvolutionVisitor::Run(
96     HloComputation* computation,
97     std::function<bool(HloInstruction*)> should_expand,
98     std::function<bool(HloInstruction*)> is_cost_viable,
99     bool convert_batch_groups_only, bool filter_expansion) {
100   ConvolutionVisitor visitor(computation, should_expand, is_cost_viable,
101                              convert_batch_groups_only, filter_expansion);
102   TF_CHECK_OK(computation->Accept(&visitor));
103   return visitor.changed_;
104 }
105 
// Returns `shape` with its `input_feature_dim` dimension multiplied by
// `group_count`; all other dimensions are unchanged.
Shape ExpandedFilterShape(const Shape& shape, int64_t group_count,
                          int64_t input_feature_dim) {
  CHECK_GE(shape.dimensions_size(), 2);
  Shape expanded = shape;
  const int64_t scaled_dim = shape.dimensions(input_feature_dim) * group_count;
  expanded.set_dimensions(input_feature_dim, scaled_dim);
  return expanded;
}
115 
// Returns a vector with 'group_count' many groups, where the i-th group
// consists of 'group_size' copies of the value i.
std::vector<int32_t> GetMaskIds(int64_t group_size, int64_t group_count) {
  std::vector<int32_t> values;
  values.reserve(group_count * group_size);
  // Use an int64_t counter: the bound is int64_t, so an `int` counter could
  // truncate for counts beyond INT_MAX.
  for (int64_t i = 0; i < group_count; ++i) {
    // Append `group_size` copies of the group id in one fill-insert.
    values.insert(values.end(), group_size, static_cast<int32_t>(i));
  }
  return values;
}
128 
129 // Create a mask for grouped convolution that will make a normal convolution
130 // produce the same results as a grouped convolution. For a [2, 1, 6]
131 // filter this returns a [2, 3, 6] mask
132 //   1 1 0 0 0 0
133 //   0 0 1 1 0 0
134 //   0 0 0 0 1 1
135 //
136 //   1 1 0 0 0 0
137 //   0 0 1 1 0 0
138 //   0 0 0 0 1 1
139 //
140 // The first step is to create a rank 1 constant:
141 //   0 1 2
142 //
143 // This is broadcasted to
144 //   0 0 0 0 0 0
145 //   1 1 1 1 1 1
146 //   2 2 2 2 2 2
147 //
148 //   0 0 0 0 0 0
149 //   1 1 1 1 1 1
150 //   2 2 2 2 2 2
151 //
152 // Then we create another rank 1 constant
153 //   0 0 1 1 2 2
154 //
155 // This is broadcasted to
156 //   0 0 1 1 2 2
157 //   0 0 1 1 2 2
158 //   0 0 1 1 2 2
159 //
160 //   0 0 1 1 2 2
161 //   0 0 1 1 2 2
162 //   0 0 1 1 2 2
163 //
164 // Finally we use the Eq op of these two broadcasted constants and get the
165 // desired mask.
GetExpandedFilterMask(const Shape & filter_shape,int64_t kernel_input_feature_dim,int64_t kernel_output_feature_dim,int64_t group_count,const std::function<HloInstruction * (std::unique_ptr<HloInstruction>)> & add_instruction)166 HloInstruction* GetExpandedFilterMask(
167     const Shape& filter_shape, int64_t kernel_input_feature_dim,
168     int64_t kernel_output_feature_dim, int64_t group_count,
169     const std::function<HloInstruction*(std::unique_ptr<HloInstruction>)>&
170         add_instruction) {
171   Shape expanded_filter_shape =
172       ExpandedFilterShape(filter_shape, group_count, kernel_input_feature_dim);
173   Shape mask_shape =
174       ShapeUtil::MakeShape(S32, expanded_filter_shape.dimensions());
175   int64_t output_feature = filter_shape.dimensions(kernel_output_feature_dim);
176   int64_t group_size = filter_shape.dimensions(kernel_input_feature_dim);
177 
178   // Create a 'input_feature' sized linspace and 'output_feature' sized linspace
179   // that will be broadcasted into perpendicular dimensions and compared.
180   const std::vector<int32_t> input_feature_filter_mask =
181       GetMaskIds(group_size, group_count);
182   const std::vector<int32_t> output_feature_filter_mask =
183       GetMaskIds(output_feature / group_count, group_count);
184   auto mask1 = add_instruction(HloInstruction::CreateConstant(
185       LiteralUtil::CreateR1<int32_t>(input_feature_filter_mask)));
186   auto broadcasted_mask1 = add_instruction(HloInstruction::CreateBroadcast(
187       mask_shape, mask1, {kernel_input_feature_dim}));
188   auto mask2 = add_instruction(HloInstruction::CreateConstant(
189       LiteralUtil::CreateR1<int32_t>(output_feature_filter_mask)));
190   auto broadcasted_mask2 = add_instruction(HloInstruction::CreateBroadcast(
191       mask_shape, mask2, {kernel_output_feature_dim}));
192 
193   // Compare the broadcasted output feature linspace to the input feature
194   // linspace to create a diagonal predicate.
195   Shape predicate_shape =
196       ShapeUtil::MakeShape(PRED, expanded_filter_shape.dimensions());
197   return add_instruction(HloInstruction::CreateCompare(
198       predicate_shape, broadcasted_mask1, broadcasted_mask2,
199       ComparisonDirection::kEq));
200 }
201 
202 // This function handles batch_group_counts which are relevant only for
203 // depthwise backprop filter convolutions.
HandleBatchGroupCount(HloInstruction * convolution)204 Status ConvolutionVisitor::HandleBatchGroupCount(HloInstruction* convolution) {
205   auto dim_numbers = convolution->convolution_dimension_numbers();
206   auto activation = convolution->mutable_operand(0);
207   auto filter = convolution->mutable_operand(1);
208   int64_t batch_group_count = convolution->batch_group_count();
209 
  // Nothing to do if there is no batch grouping, or if the caller-provided
  // filter declines expansion of this convolution.
210   if (batch_group_count == 1 ||
211       (should_expand_ && !should_expand_(convolution))) {
212     return OkStatus();
213   }
214 
215   VLOG(2) << "Dealing with batch_group_count " << batch_group_count
216           << " for convolution " << convolution->ToString() << "\n";
217 
218   auto add = [&](std::unique_ptr<HloInstruction> inst) {
219     return computation_->AddInstruction(std::move(inst));
220   };
221 
222   int64_t input_batch_dimension = dim_numbers.input_batch_dimension();
223   const int64_t input_feature_dimension = dim_numbers.input_feature_dimension();
224 
225   int64_t output_batch_dimension = dim_numbers.output_batch_dimension();
226   int64_t output_feature_dimension = dim_numbers.output_feature_dimension();
227 
228   const int64_t kernel_input_feature_dimension =
229       dim_numbers.kernel_input_feature_dimension();
230   const int64_t kernel_output_feature_dimension =
231       dim_numbers.kernel_output_feature_dimension();
232 
233   const int64_t input_batch =
234       activation->shape().dimensions(input_batch_dimension);
235   const int64_t output_feature =
236       filter->shape().dimensions(kernel_output_feature_dimension);
237 
  // If input batch and kernel output feature are not both exactly equal to
  // batch_group_count, emulate the batch groups with an extra spatial
  // dimension (plus a stride/dilation trick) instead of filter masking.
238   if (output_feature != batch_group_count || input_batch != batch_group_count) {
239     // Insert a spatial dimension to the activation before the input batch
240     // dimension to represent the batch group.
241     std::vector<int64_t> input_sizes(activation->shape().dimensions().begin(),
242                                      activation->shape().dimensions().end());
243     input_sizes[input_batch_dimension] /= batch_group_count;
244     input_sizes.insert(input_sizes.begin() + input_batch_dimension,
245                        batch_group_count);
246     activation = MakeReshapeHlo(input_sizes, activation).ValueOrDie();
247     for (auto& d : *dim_numbers.mutable_input_spatial_dimensions()) {
248       if (d > input_batch_dimension) {
249         ++d;
250       }
251     }
252     dim_numbers.add_input_spatial_dimensions(input_batch_dimension);
253     dim_numbers.set_input_batch_dimension(input_batch_dimension + 1);
254     if (input_feature_dimension > input_batch_dimension) {
255       dim_numbers.set_input_feature_dimension(input_feature_dimension + 1);
256     }
257 
258     // Insert a spatial dimension to the kernel before the output feature
259     // dimension to represent the batch group.
260     std::vector<int64_t> kernel_sizes(filter->shape().dimensions().begin(),
261                                       filter->shape().dimensions().end());
262     kernel_sizes[kernel_output_feature_dimension] /= batch_group_count;
263     kernel_sizes.insert(kernel_sizes.begin() + kernel_output_feature_dimension,
264                         batch_group_count);
265     filter = MakeReshapeHlo(kernel_sizes, filter).ValueOrDie();
266     for (auto& d : *dim_numbers.mutable_kernel_spatial_dimensions()) {
267       if (d > kernel_output_feature_dimension) {
268         ++d;
269       }
270     }
271     dim_numbers.add_kernel_spatial_dimensions(kernel_output_feature_dimension);
272     dim_numbers.set_kernel_output_feature_dimension(
273         kernel_output_feature_dimension + 1);
274     if (kernel_input_feature_dimension > kernel_output_feature_dimension) {
275       dim_numbers.set_kernel_input_feature_dimension(
276           kernel_input_feature_dimension + 1);
277     }
278 
279     // Insert a spatial dimension to the output before the output feature
280     // dimension to represent the batch group.
281     for (auto& d : *dim_numbers.mutable_output_spatial_dimensions()) {
282       if (d > output_feature_dimension) {
283         ++d;
284       }
285     }
286     dim_numbers.add_output_spatial_dimensions(output_feature_dimension);
287     dim_numbers.set_output_feature_dimension(output_feature_dimension + 1);
288     if (output_batch_dimension > output_feature_dimension) {
289       dim_numbers.set_output_batch_dimension(output_batch_dimension + 1);
290     }
291 
292     // To represent a batch group count of 3 you can slide a 3 wide window
293     // [X Y Z]
294     // across [A 0 0 B 0 0 C] with stride 2 to produce
295     // [AX+0Y+0Z 0X+BY+0Z 0X+0Y+CZ] -> [AX BY CZ] which will behave the same as
296     // a batch group count.
297     Window window = convolution->window();
298     auto window_dim = window.add_dimensions();
299     window_dim->set_base_dilation(batch_group_count);
300     window_dim->set_size(batch_group_count);
301     window_dim->set_stride(batch_group_count - 1);
302     window_dim->set_padding_low(0);
303     window_dim->set_padding_high(0);
304     window_dim->set_window_reversal(false);
305     window_dim->set_window_dilation(1);
306     HloInstruction* new_convolution =
307         MakeConvolveHlo(
308             activation, filter, convolution->feature_group_count(),
309             /*batch_group_count=*/1, window, dim_numbers,
310             convolution->precision_config(),
311             /*preferred_element_type=*/convolution->shape().element_type())
312             .ValueOrDie();
313     convolution->SetupDerivedInstruction(new_convolution);
314     TF_CHECK_OK(computation_->ReplaceInstruction(
315         convolution,
316         MakeReshapeHlo(convolution->shape(), new_convolution).ValueOrDie()));
317     changed_ = true;
318     return OkStatus();
319   }
320 
321   VLOG(2) << "is_cost_viable_ " << is_cost_viable_(convolution);
322   const bool cost_too_high = !is_cost_viable_(convolution);
323   if (cost_too_high || filter_expansion_) {
324     // We first obtain the expanded filter (which is the convolution
325     // output). The batch dimension is the expanded one (which originally
326     // represents kernel input feature dimension). We mask the filter to zero
327     // out the expanded regions. Next we reduce the filter in the batch
328     // dimension to obtain the original filter size.
329 
330     HloInstruction* filter_mask =
331         GetExpandedFilterMask(convolution->shape(), output_batch_dimension,
332                               output_feature_dimension, batch_group_count, add);
333     auto expanded_filter_shape = ExpandedFilterShape(
334         convolution->shape(), batch_group_count, output_batch_dimension);
335 
336     VLOG(2) << "output_batch_dimension " << output_batch_dimension;
337     VLOG(2) << "New output shape of convolution "
338             << expanded_filter_shape.ToString();
339 
340     auto new_convolution = add(HloInstruction::CreateConvolve(
341         expanded_filter_shape, activation, filter,
342         /*feature_group_count=*/1, /*batch_group_count=*/1,
343         convolution->window(), dim_numbers, convolution->precision_config()));
344 
345     VLOG(2) << "Expanded convolution " << new_convolution->ToString();
346 
347     auto zero = add(HloInstruction::CreateConstant(
348         LiteralUtil::Zero(expanded_filter_shape.element_type())));
349     auto zero_filter =
350         add(HloInstruction::CreateBroadcast(expanded_filter_shape, zero, {}));
351 
352     auto new_filter = add(HloInstruction::CreateTernary(
353         expanded_filter_shape, HloOpcode::kSelect, filter_mask, new_convolution,
354         zero_filter));
355 
356     PrimitiveType reduce_type = new_filter->shape().element_type();
357     auto reduce_window_shape = new_convolution->shape();
358     reduce_window_shape.set_dimensions(output_batch_dimension, 1);
359 
360     // Ensure that data input to reduce window uses at least 32 bits.
361     if (primitive_util::BitWidth(reduce_type) < primitive_util::BitWidth(F32)) {
362       reduce_type = F32;
363       reduce_window_shape.set_element_type(F32);
364       Shape convert_shape = new_filter->shape();
365       convert_shape.set_element_type(F32);
366       new_filter =
367           add(HloInstruction::CreateConvert(convert_shape, new_filter));
368     }
369 
370     auto zero_literal = LiteralUtil::Zero(reduce_type);
371     auto zero_scalar =
372         add(HloInstruction::CreateConstant(std::move(zero_literal)));
373 
    // Scalar `add` computation used by the reduce-window to sum over the
    // expanded batch dimension.
374     auto reduce_function = [&]() -> HloComputation* {
375       HloComputation::Builder b("add_computation");
376       Shape shape = ShapeUtil::MakeShape(reduce_type, {});
377       auto lhs =
378           b.AddInstruction(HloInstruction::CreateParameter(0, shape, "lhs"));
379       auto rhs =
380           b.AddInstruction(HloInstruction::CreateParameter(1, shape, "rhs"));
381       auto scalar_op = b.AddInstruction(
382           HloInstruction::CreateBinary(shape, HloOpcode::kAdd, lhs, rhs));
383       return computation_->parent()->AddEmbeddedComputation(b.Build(scalar_op));
384     };
385 
386     // Create the reduce window.
387     Window window;
388     for (int64_t i = 0; i < new_convolution->shape().dimensions_size(); ++i) {
389       auto* dim = window.add_dimensions();
390       dim->set_padding_low(0);
391       dim->set_padding_high(0);
392       dim->set_window_dilation(1);
393       dim->set_base_dilation(1);
394       if (i == output_batch_dimension) {
395         dim->set_stride(batch_group_count);
396         dim->set_size(batch_group_count);
397       } else {
398         dim->set_stride(1);
399         dim->set_size(1);
400       }
401     }
402     auto reduce_window = add(HloInstruction::CreateReduceWindow(
403         reduce_window_shape, new_filter, zero_scalar, window,
404         reduce_function()));
405 
406     Shape convert_back_shape = reduce_window->shape();
407     convert_back_shape.set_element_type(activation->shape().element_type());
408 
409     // Convert reduced data back to the original data type.
410     auto reduce_window_converted =
411         HloInstruction::CreateConvert(convert_back_shape, reduce_window);
412 
413     TF_CHECK_OK(computation_->ReplaceWithNewInstruction(
414         convolution, std::move(reduce_window_converted)));
415     changed_ = true;
416   }
417 
418   return OkStatus();
419 }
420 
HandleConvolution(HloInstruction * convolution)421 Status ConvolutionVisitor::HandleConvolution(HloInstruction* convolution) {
  // This visitor converts either feature groups or batch groups, never both;
  // dispatch to the batch-group handler when configured for that mode.
422   if (convert_batch_groups_only_) {
423     return HandleBatchGroupCount(convolution);
424   }
425 
426   auto add = [&](std::unique_ptr<HloInstruction> inst) {
427     return computation_->AddInstruction(std::move(inst));
428   };
429 
430   int64_t group_count = convolution->feature_group_count();
431   if (group_count == 1 || (should_expand_ && !should_expand_(convolution))) {
432     return OkStatus();
433   }
434 
  // Tentatively mark as changed; reset below if we decide not to rewrite.
435   changed_ = true;
436   ConvolutionDimensionNumbers dim_numbers =
437       convolution->convolution_dimension_numbers();
438   auto filter = convolution->mutable_operand(1);
439   int64_t kernel_input_feature_dim =
440       dim_numbers.kernel_input_feature_dimension();
441   int64_t group_size = filter->shape().dimensions(kernel_input_feature_dim);
442   int64_t kernel_output_feature_dim =
443       dim_numbers.kernel_output_feature_dimension();
444   auto expanded_filter_shape = ExpandedFilterShape(filter->shape(), group_count,
445                                                    kernel_input_feature_dim);
446   HloInstruction* filter_mask =
447       GetExpandedFilterMask(filter->shape(), kernel_input_feature_dim,
448                             kernel_output_feature_dim, group_count, add);
449   HloInstruction* expanded_filter;
450 
  // Depthwise case: each feature group has exactly one input feature.
451   if (group_size == 1) {
452     bool depthwise_separable =
453         (group_count == filter->shape().dimensions(kernel_output_feature_dim));
454     // If the code generator handles depthwise separable convolutions
455     // inherently, then no filter expansion is needed.
456     if (!filter_expansion_ && depthwise_separable) {
457       changed_ = false;
458       return OkStatus();
459     }
460     VLOG(2) << "is_cost_viable_ " << is_cost_viable_(convolution);
461     // We want to repeat 'filter' in the 'input_feature_dim' dimension
462     // 'group_count' times.
463     if (!is_cost_viable_(convolution) || filter_expansion_) {
464       Shape reshaped_filter_shape =
465           ShapeUtil::DeleteDimension(kernel_input_feature_dim, filter->shape());
466       auto reshaped_filter =
467           add(HloInstruction::CreateReshape(reshaped_filter_shape, filter));
468       std::vector<int64_t> broadcast_dims;
469       for (int64_t i = 0; i < filter->shape().dimensions_size(); ++i) {
470         if (i == kernel_input_feature_dim) {
471           continue;
472         }
473         broadcast_dims.push_back(i);
474       }
475       expanded_filter = add(HloInstruction::CreateBroadcast(
476           expanded_filter_shape, reshaped_filter, broadcast_dims));
477 
478       auto zero = add(HloInstruction::CreateConstant(
479           LiteralUtil::Zero(expanded_filter_shape.element_type())));
480       auto zero_filter =
481           add(HloInstruction::CreateBroadcast(expanded_filter_shape, zero, {}));
482       auto new_filter = add(HloInstruction::CreateTernary(
483           expanded_filter_shape, HloOpcode::kSelect, filter_mask,
484           expanded_filter, zero_filter));
485 
486       auto new_convolution = HloInstruction::CreateConvolve(
487           convolution->shape(), convolution->mutable_operand(0), new_filter,
488           /*feature_group_count=*/1, /*batch_group_count=*/1,
489           convolution->window(), dim_numbers, convolution->precision_config());
490       return computation_->ReplaceWithNewInstruction(
491           convolution, std::move(new_convolution));
492     }
493     // Add a spatial dimension to emulate a larger output feature dimension
494     // to avoid creating a convolution with group_count = 1.
495     std::vector<int64_t> new_filter_dimension;
496     new_filter_dimension.reserve(filter->shape().rank() + 1);
497     const int64_t depthwise_multiplier =
498         filter->shape().dimensions(kernel_output_feature_dim) / group_count;
499     // Split the kernel output feature dimension into group count and
500     // depthwise multiplier.
501     for (int64_t i = 0; i < filter->shape().rank(); ++i) {
502       if (i == kernel_output_feature_dim) {
503         new_filter_dimension.push_back(group_count);
504         new_filter_dimension.push_back(depthwise_multiplier);
505       } else {
506         new_filter_dimension.push_back(filter->shape().dimensions(i));
507       }
508     }
509     if (kernel_input_feature_dim > kernel_output_feature_dim) {
510       dim_numbers.set_kernel_input_feature_dimension(kernel_input_feature_dim +
511                                                      1);
512     }
513     for (auto& dim : *dim_numbers.mutable_kernel_spatial_dimensions()) {
514       if (dim > kernel_output_feature_dim) {
515         ++dim;
516       }
517     }
518     dim_numbers.add_kernel_spatial_dimensions(kernel_output_feature_dim + 1);
519     HloInstruction* new_filter =
520         computation_->AddInstruction(HloInstruction::CreateReshape(
521             ShapeUtil::MakeShape(filter->shape().element_type(),
522                                  new_filter_dimension),
523             filter));
524 
525     auto new_activation_shape = convolution->operand(0)->shape();
526     dim_numbers.add_input_spatial_dimensions(new_activation_shape.rank());
527 
528     // Create an activation spatial dimension of size 1 with a reversed
529     // window and high and low padding equal to the depthwise_multiplier -1.
530     // This emulates a larger output feature dimension with an extra spatial
531     // dimension.
532     ShapeUtil::AppendMajorDimension(1, &new_activation_shape);
533     HloInstruction* new_activation =
534         computation_->AddInstruction(HloInstruction::CreateReshape(
535             new_activation_shape, convolution->mutable_operand(0)));
536     auto new_window = convolution->window();
537     auto new_dim = new_window.add_dimensions();
538     new_dim->set_size(depthwise_multiplier);
539     new_dim->set_window_reversal(true);
540     new_dim->set_padding_low(depthwise_multiplier - 1);
541     new_dim->set_padding_high(depthwise_multiplier - 1);
542     new_dim->set_stride(1);
543     new_dim->set_window_dilation(1);
544     new_dim->set_base_dilation(1);
545 
546     // Split the output feature dimension into an output feature of group
547     // count and depthwise multiplier as an output spatial dimension.
548     std::vector<int64_t> new_output_dimension;
549     new_output_dimension.reserve(convolution->shape().rank() + 1);
550     for (int64_t i = 0; i < convolution->shape().rank(); ++i) {
551       if (i == dim_numbers.output_feature_dimension()) {
552         new_output_dimension.push_back(group_count);
553         new_output_dimension.push_back(depthwise_multiplier);
554       } else {
555         new_output_dimension.push_back(convolution->shape().dimensions(i));
556       }
557     }
558     if (dim_numbers.output_batch_dimension() >
559         dim_numbers.output_feature_dimension()) {
560       dim_numbers.set_output_batch_dimension(
561           dim_numbers.output_batch_dimension() + 1);
562     }
563     for (auto& dim : *dim_numbers.mutable_output_spatial_dimensions()) {
564       if (dim > dim_numbers.output_feature_dimension()) {
565         ++dim;
566       }
567     }
568     dim_numbers.add_output_spatial_dimensions(
569         dim_numbers.output_feature_dimension() + 1);
570     auto new_convolution_output_shape = ShapeUtil::MakeShape(
571         convolution->shape().element_type(), new_output_dimension);
572     HloInstruction* new_convolution =
573         computation_->AddInstruction(HloInstruction::CreateConvolve(
574             new_convolution_output_shape, new_activation, new_filter,
575             /*feature_group_count=*/group_count, /*batch_group_count=*/1,
576             new_window, dim_numbers, convolution->precision_config()));
577     return computation_->ReplaceWithNewInstruction(
578         convolution,
579         HloInstruction::CreateReshape(convolution->shape(), new_convolution));
580   }
581 
582   // Implement general grouped convolution using an extra spatial dimension to
583   // represent the feature group count.
584   //
585   // Insert a spatial dimension to the input before the input feature
586   // dimension to represent the feature group.
587   HloInstruction* activation = convolution->mutable_operand(0);
588   std::vector<int64_t> input_sizes(activation->shape().dimensions().begin(),
589                                    activation->shape().dimensions().end());
590   const int64_t input_feature_dimension = dim_numbers.input_feature_dimension();
591   input_sizes[input_feature_dimension] /= group_count;
592   input_sizes.insert(input_sizes.begin() + input_feature_dimension,
593                      group_count);
594   activation = MakeReshapeHlo(input_sizes, activation).ValueOrDie();
595   for (auto& d : *dim_numbers.mutable_input_spatial_dimensions()) {
596     if (d > input_feature_dimension) {
597       ++d;
598     }
599   }
600   dim_numbers.add_input_spatial_dimensions(input_feature_dimension);
601   dim_numbers.set_input_feature_dimension(input_feature_dimension + 1);
602   if (dim_numbers.input_batch_dimension() > input_feature_dimension) {
603     dim_numbers.set_input_batch_dimension(dim_numbers.input_batch_dimension() +
604                                           1);
605   }
606 
607   // Insert a spatial dimension to the kernel before the output feature
608   // dimension to represent the feature group.
609   std::vector<int64_t> kernel_sizes(filter->shape().dimensions().begin(),
610                                     filter->shape().dimensions().end());
611   const int64_t kernel_output_feature_dimension =
612       dim_numbers.kernel_output_feature_dimension();
613   kernel_sizes[kernel_output_feature_dimension] /= group_count;
614   kernel_sizes.insert(kernel_sizes.begin() + kernel_output_feature_dimension,
615                       group_count);
616   filter = MakeReshapeHlo(kernel_sizes, filter).ValueOrDie();
617   for (auto& d : *dim_numbers.mutable_kernel_spatial_dimensions()) {
618     if (d > kernel_output_feature_dimension) {
619       ++d;
620     }
621   }
622   dim_numbers.add_kernel_spatial_dimensions(kernel_output_feature_dimension);
623   dim_numbers.set_kernel_output_feature_dimension(
624       kernel_output_feature_dimension + 1);
625   if (dim_numbers.kernel_input_feature_dimension() >
626       kernel_output_feature_dimension) {
627     dim_numbers.set_kernel_input_feature_dimension(
628         dim_numbers.kernel_input_feature_dimension() + 1);
629   }
630 
631   // Insert a spatial dimension to the output before the output feature
632   // dimension to represent the feature group.
633   const int64_t output_feature_dimension =
634       dim_numbers.output_feature_dimension();
635   for (auto& d : *dim_numbers.mutable_output_spatial_dimensions()) {
636     if (d > output_feature_dimension) {
637       ++d;
638     }
639   }
640   dim_numbers.add_output_spatial_dimensions(output_feature_dimension);
641   dim_numbers.set_output_feature_dimension(output_feature_dimension + 1);
642   if (dim_numbers.output_batch_dimension() > output_feature_dimension) {
643     dim_numbers.set_output_batch_dimension(
644         dim_numbers.output_batch_dimension() + 1);
645   }
646 
647   // To represent a feature group count of 3 you can slide a 3 wide window
648   // [X Y Z]
649   // across [A 0 0 B 0 0 C] with stride 2 to produce
650   // [AX+0Y+0Z 0X+BY+0Z 0X+0Y+CZ] -> [AX BY CZ] which will behave the same as
651   // a feature group count.
652   Window window = convolution->window();
653   auto window_dim = window.add_dimensions();
654   window_dim->set_base_dilation(group_count);
655   window_dim->set_size(group_count);
656   window_dim->set_stride(group_count - 1);
657   window_dim->set_padding_low(0);
658   window_dim->set_padding_high(0);
659   window_dim->set_window_reversal(false);
660   window_dim->set_window_dilation(1);
661   HloInstruction* new_convolution =
662       MakeConvolveHlo(
663           activation, filter, /*feature_group_count=*/1,
664           /*batch_group_count=*/1, window, dim_numbers,
665           convolution->precision_config(),
666           /*preferred_element_type=*/convolution->shape().element_type())
667           .ValueOrDie();
668   convolution->SetupDerivedInstruction(new_convolution);
669   changed_ = true;
670   return computation_->ReplaceInstruction(
671       convolution,
672       MakeReshapeHlo(convolution->shape(), new_convolution).ValueOrDie());
673 }
674 
675 }  // namespace
676 
Run(HloModule * module,const absl::flat_hash_set<absl::string_view> & execution_threads)677 StatusOr<bool> ConvolutionGroupConverter::Run(
678     HloModule* module,
679     const absl::flat_hash_set<absl::string_view>& execution_threads) {
680   XLA_VLOG_LINES(
681       2, "ConvolutionGroupConverter::Run(), before:\n" + module->ToString());
682   bool changed = false;
683   for (auto* comp : module->MakeNonfusionComputations(execution_threads)) {
684     if (ConvolutionVisitor::Run(comp, should_expand_, is_cost_viable_,
685                                 convert_batch_groups_only_,
686                                 filter_expansion_)) {
687       changed = true;
688     }
689   }
690   XLA_VLOG_LINES(
691       2, "ConvolutionGroupConverter::Run(), after:\n" + module->ToString());
692   return changed;
693 }
694 
695 }  // namespace xla
696