/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/gpu/cudnn_vectorize_convolutions.h"

#include <optional>
#include <vector>

#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/service/gpu/cudnn_support_utils.h"
#include "tensorflow/compiler/xla/service/gpu/stream_executor_util.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"

namespace xla {
namespace gpu {

// Finds convolutions that this pass may be able to transform, namely int8_t
// cudnn forward or forward-bias-activation convolutions.
//
// cudnn as of v8.2 supports the following data type combinations for forward
// and forward-bias-activation convolutions.  We have to make sure we only
// vectorize to one of these supported configs.
//
//   in       out
//   int8x1   int8x1
//   int8x1   float
//   int8x1   int32_t
//
//   int8x4   int8x4
//   int8x4   float
//
//   int8x32  int8x32
//
// https://docs.nvidia.com/deeplearning/cudnn/api/index.html#cudnnConvolutionForward
//
// For now we restrict ourselves to only the int8xN -> int8xN cases.  We could
// allow the int8x4 -> float case in the future if desirable.
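//
// For illustration (hypothetical shapes, simplified HLO), a custom-call of the
// following form is relevant, since its input and output element types are
// both s8:
//
//   %conv = (s8[10,20,30,64], u8[0]) custom-call(
//               s8[10,20,30,64] %input, s8[3,3,64,64] %filter),
//             custom_call_target="__cudnn$convForward",
//             dim_labels=b01f_01io->b01f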
static std::vector<HloCustomCallInstruction*> GetRelevantConvs(
    HloComputation* comp) {
  std::vector<HloCustomCallInstruction*> convs;
  for (HloInstruction* instr : comp->instructions()) {
    if (instr->opcode() != HloOpcode::kCustomCall ||
        (instr->custom_call_target() != kCudnnConvForwardCallTarget &&
         instr->custom_call_target() !=
             kCudnnConvBiasActivationForwardCallTarget) ||
        instr->operand_count() < 2) {
      continue;
    }

    PrimitiveType input_ty = instr->operand(0)->shape().element_type();
    PrimitiveType output_ty = instr->shape().tuple_shapes(0).element_type();
    if (input_ty == output_ty && (input_ty == S8 || input_ty == U8)) {
      convs.push_back(Cast<HloCustomCallInstruction>(instr));
    }
  }
  return convs;
}

// Converts an XlaBuilder into an HloComputation in the same module as
// `sibling_computation`.
//
// Yes, we serialize/deserialize as a proto.  :)
static StatusOr<HloComputation*> BuilderToHloComputation(
    XlaBuilder& b, XlaOp root, HloComputation* sibling_computation) {
  TF_ASSIGN_OR_RETURN(XlaComputation comp, b.Build(root));
  TF_ASSIGN_OR_RETURN(ProgramShape program_shape, comp.GetProgramShape());
  HloModuleConfig config(program_shape);
  TF_ASSIGN_OR_RETURN(auto new_module,
                      HloModule::CreateFromProto(comp.proto(), config));

  HloModule* dest_module = sibling_computation->parent();
  HloCloneContext context(dest_module);
  return dest_module->DeepCloneComputation(new_module->entry_computation(),
                                           &context);
}

// Reshapes `instr` so that it has an extra dimension of size `vect_size` right
// after `dim`.
static XlaOp SplitAtDim(XlaOp instr, int64_t dim, int64_t vect_size) {
  XlaBuilder& b = *instr.builder();
  Shape shape = b.GetShape(instr).ValueOrDie();
  DimensionVector new_dims(shape.dimensions().begin(),
                           shape.dimensions().end());
  CHECK_EQ(new_dims[dim] % vect_size, 0);
  new_dims[dim] /= vect_size;
  new_dims.insert(new_dims.begin() + dim + 1, vect_size);
  return Reshape(instr, new_dims);
}

// Reshapes `shape` so that there's an extra dimension of size `vect_size` right
// after `dim`.
//
// For example given shape=s8[10, 32, 20], dim=1, vect_size=4, returns
// s8[10, 8, 4, 20].
static Shape SplitShapeAtDim(Shape shape, int64_t dim, int64_t vect_size) {
  DimensionVector new_dims(shape.dimensions().begin(),
                           shape.dimensions().end());
  CHECK_EQ(new_dims[dim] % vect_size, 0);
  new_dims[dim] /= vect_size;
  new_dims.insert(new_dims.begin() + dim + 1, vect_size);
  return ShapeUtil::MakeShape(shape.element_type(), new_dims);
}

// Transposes dimension `src` to right before `dst`.
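//
// For example, with rank 5, MoveDim(instr, /*src=*/1, /*dst=*/4) applies the
// permutation {0, 2, 3, 1, 4}, transposing s8[A, B, C, D, E] to
// s8[A, C, D, B, E]; the old dimension 1 ends up immediately before the old
// dimension 4.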
static XlaOp MoveDim(XlaOp instr, int64_t src, int64_t dst) {
  XlaBuilder& b = *instr.builder();
  int64_t rank = b.GetShape(instr)->dimensions_size();

  DimensionVector idxs(rank);
  absl::c_iota(idxs, 0);
  if (src < dst) {
    idxs.insert(idxs.begin() + dst, src);
    idxs.erase(idxs.begin() + src);
  } else {
    idxs.erase(idxs.begin() + src);
    idxs.insert(idxs.begin() + dst, src);
  }
  return Transpose(instr, idxs);
}

// Reshapes instr so that dimension `vect_dim` has size `vect_size`, by stealing
// elements from `dim`.
//
// Requires that this is possible without merging and re-splitting the two
// dimensions.  I.e. there should be some amount of dim that we can "split off"
// and add to vect_dim to get it to have size vect_size.
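//
// For example, RevectorizeInstr(instr, /*dim=*/1, /*vect_dim=*/4,
// /*vect_size=*/32) transforms an s8[10,16,20,30,4] input as follows:
//
//   s8[10,16,20,30,4]      (existing vect size 4)
//   s8[10,2,8,20,30,4]     (split dim 1 into [2, 8])
//   s8[10,2,20,30,8,4]     (move the 8 to just before the vect dim)
//   s8[10,2,20,30,32]      (collapse [8, 4] into a vect dim of size 32)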
static XlaOp RevectorizeInstr(XlaOp instr, int64_t dim, int64_t vect_dim,
                              int64_t vect_size) {
  XlaBuilder& b = *instr.builder();
  Shape shape = b.GetShape(instr).ValueOrDie();
  auto size = [&](int64_t d) { return shape.dimensions(d); };

  CHECK_LE(size(vect_dim), vect_size);
  CHECK_EQ(vect_size % size(vect_dim), 0);

  int64_t split_factor = vect_size / size(vect_dim);
  CHECK_EQ(size(dim) % split_factor, 0);

  // Split dim into [C, split_factor].
  instr = SplitAtDim(instr, dim, split_factor);

  // SplitAtDim may have added a dimension before vect_dim.
  if (vect_dim > dim) {
    vect_dim++;
  }

  // Move the split_factor dimension to right before vect_dim.
  instr = MoveDim(instr, dim + 1, vect_dim);

  // Moving the split_factor dimension may have *removed* a dimension before
  // vect_dim.
  if (vect_dim > dim) {
    vect_dim--;
  }

  // Collapse the split_factor dimension into vect_dim.
  return Collapse(instr, {vect_dim, vect_dim + 1});
}

// Inverse of RevectorizeInstr.  Reshapes instr so that dimension `vect_dim` has
// size `orig_vect_size`, moving excess elements into `dim`.
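//
// For example, UnrevectorizeInstr(instr, /*dim=*/1, /*vect_dim=*/4,
// /*orig_vect_size=*/4) transforms an s8[10,2,20,30,32] input back into
// s8[10,16,20,30,4]:
//
//   s8[10,2,20,30,32]
//   s8[10,2,20,30,8,4]     (split the vect dim into [8, 4])
//   s8[10,2,8,20,30,4]     (move the 8 to just after `dim`)
//   s8[10,16,20,30,4]      (collapse [2, 8] back into dim 1)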
static XlaOp UnrevectorizeInstr(XlaOp instr, int64_t dim, int64_t vect_dim,
                                int64_t orig_vect_size) {
  XlaBuilder& b = *instr.builder();
  Shape shape = b.GetShape(instr).ValueOrDie();
  auto size = [&](int64_t d) { return shape.dimensions(d); };

  CHECK_GE(size(vect_dim), orig_vect_size);
  CHECK_EQ(size(vect_dim) % orig_vect_size, 0);

  // Split vect_dim into [C, orig_vect_size].
  instr = SplitAtDim(instr, vect_dim, orig_vect_size);

  // SplitAtDim may have added a dimension before dim.
  if (dim > vect_dim) {
    dim++;
  }

  // Move the `C` dimension to right after `dim`.  Take into account that
  // SplitAtDim may have added a dimension before dim.
  instr = MoveDim(instr, vect_dim, dim + 1);

  // MoveDim may have *removed* a dimension before dim.
  if (dim > vect_dim) {
    dim--;
  }

  // Collapse the `C` and `dim` dimensions.
  return Collapse(instr, {dim, dim + 1});
}

// Adds a vectorized-feature dimension to dnums right after the current feature
// dimension.
//
// ConvolutionDimensionNumbers doesn't represent the vectorized-feature
// dimension explicitly, because the whole concept of a vectorized-feature
// dimension is specific to cudnn.  Rather, the vectorized-feature dimension is
// implicit; it's the first dimension that *doesn't* appear in the dnums.
//
// This function "makes room" in dnums for the new vectorized dimension by
// incrementing any dimensions which appear after the feature dim.  The implicit
// vector dim is then in this "empty" spot.
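//
// For example, with NCHW-style input dnums (batch=0, feature=1,
// spatial={2,3}), the spatial dims become {3,4} while batch and feature stay
// put, leaving dimension 2 free to act as the implicit vectorized-feature
// dimension.  The kernel and output dnums are adjusted the same way.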
static ConvolutionDimensionNumbers VectorizeDnums(
    ConvolutionDimensionNumbers dnums) {
  int64_t input_vect_dim = dnums.input_feature_dimension();
  if (dnums.input_batch_dimension() > input_vect_dim) {
    dnums.set_input_batch_dimension(dnums.input_batch_dimension() + 1);
  }
  for (int64_t& d : *dnums.mutable_input_spatial_dimensions()) {
    if (d > input_vect_dim) {
      ++d;
    }
  }

  int64_t kernel_vect_dim = dnums.kernel_input_feature_dimension();
  if (dnums.kernel_output_feature_dimension() > kernel_vect_dim) {
    dnums.set_kernel_output_feature_dimension(
        dnums.kernel_output_feature_dimension() + 1);
  }
  for (int64_t& d : *dnums.mutable_kernel_spatial_dimensions()) {
    if (d > kernel_vect_dim) {
      ++d;
    }
  }

  int64_t output_vect_dim = dnums.output_feature_dimension();
  if (dnums.output_batch_dimension() > output_vect_dim) {
    dnums.set_output_batch_dimension(dnums.output_batch_dimension() + 1);
  }
  for (int64_t& d : *dnums.mutable_output_spatial_dimensions()) {
    if (d > output_vect_dim) {
      ++d;
    }
  }

  return dnums;
}

// Tries to vectorize an already-vectorized convolution.
//
// That is, given a convolution of shape [N, C/k, H, W, k], changes it to have
// shape [N, C/vect_size, H, W, vect_size].  Similarly changes the filter from
// [H, W, I/k, k, O] to [H, W, I/vect_size, vect_size, O].
//
// (The dimensions can appear in any order; which dimension is N/C/etc. is
// determined by the convolution's dnums.)
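//
// For example, with vect_size=32 and the layouts above, an input of shape
// s8[10,16,20,30,4] (C=64, k=4) becomes s8[10,2,20,30,32], and a filter of
// shape s8[3,3,16,4,64] becomes s8[3,3,2,32,64].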
static StatusOr<bool> TryRevectorizeConv(
    const se::CudaComputeCapability& compute_capability,
    HloCustomCallInstruction* conv, int vect_size) {
  const Shape& input_shape = conv->operand(0)->shape();
  const Shape& kernel_shape = conv->operand(1)->shape();
  const Shape& output_shape = conv->shape().tuple_shapes(0);
  const auto& dnums = conv->convolution_dimension_numbers();

  // Find the vectorized-features dim in the input/kernel/output.
  std::optional<int64_t> input_vect_dim;
  std::optional<int64_t> kernel_vect_dim;
  std::optional<int64_t> output_vect_dim;
  std::tie(input_vect_dim, kernel_vect_dim, output_vect_dim) =
      FindVectorizedFeatureDims(dnums, input_shape, kernel_shape, output_shape);

  if (!input_vect_dim.has_value() || !kernel_vect_dim.has_value() ||
      !output_vect_dim.has_value()) {
    return false;
  }

  int64_t input_feat_size =
      input_shape.dimensions(dnums.input_feature_dimension());
  int64_t output_feat_size =
      output_shape.dimensions(dnums.output_feature_dimension());
  int64_t input_vect_size = input_shape.dimensions(*input_vect_dim);
  int64_t output_vect_size = output_shape.dimensions(*output_vect_dim);
  if (vect_size % input_vect_size != 0 || vect_size % output_vect_size != 0 ||
      input_feat_size % (vect_size / input_vect_size) != 0 ||
      output_feat_size % (vect_size / output_vect_size) != 0) {
    return false;
  }

  // If this is an integer convolution check that we only vectorize when cuDNN
  // supports the vectorized implementation.
  if (primitive_util::IsIntegralType(input_shape.element_type())) {
    TF_ASSIGN_OR_RETURN(bool supported_target_vectorization,
                        CudnnSupportsOptimizedIntegerConvolution(
                            compute_capability, *conv, vect_size));
    if (!supported_target_vectorization) {
      VLOG(3) << "Skipping re-vectorization of conv to vector size: "
              << vect_size << ": " << conv->ToString();
      return false;
    }
  }

  VLOG(1) << "Re-vectorizing conv channels from "
          << input_shape.dimensions(*input_vect_dim) << " to " << vect_size
          << ": " << conv->ToString();

  // We use XlaBuilder because it's a lot easier to get these tricky
  // reshape/transposes correct using that API.
  XlaBuilder b(absl::StrCat(conv->name(), ".revectorized"));
  b.SetOpMetadata(conv->metadata());

  absl::InlinedVector<XlaOp, 4> new_operands = {
      RevectorizeInstr(Parameter(&b, 0, conv->operand(0)->shape(), "input"),
                       dnums.input_feature_dimension(), *input_vect_dim,
                       vect_size),
      RevectorizeInstr(Parameter(&b, 1, conv->operand(1)->shape(), "filter"),
                       dnums.kernel_input_feature_dimension(), *kernel_vect_dim,
                       vect_size),
  };
  if (conv->operand_count() > 2) {
    // Bias, if present.  This is passed through unmodified.
    new_operands.push_back(Parameter(&b, 2, conv->operand(2)->shape(), "bias"));
  }
  if (conv->operand_count() > 3) {
    new_operands.push_back(RevectorizeInstr(
        Parameter(&b, 3, conv->operand(3)->shape(), "side_input"),
        dnums.input_feature_dimension(), *input_vect_dim, vect_size));
  }

  if (conv->operand_count() > 4) {
    return InvalidArgument(
        "Don't understand a conv with more than 4 arguments: %s",
        conv->ToString());
  }

  // The custom-call returns a tuple (new_output_shape, u8[0]), where the second
  // value in the tuple represents the convolution's scratch memory.
  DimensionVector new_output_dims(output_shape.dimensions().begin(),
                                  output_shape.dimensions().end());
  new_output_dims[dnums.output_feature_dimension()] /=
      (vect_size / output_vect_size);
  new_output_dims[*output_vect_dim] = vect_size;
  XlaOp new_conv = CustomCallWithConvDnums(
      &b, conv->custom_call_target(), new_operands,
      ShapeUtil::MakeTupleShape(
          {ShapeUtil::MakeShape(output_shape.element_type(), new_output_dims),
           ShapeUtil::MakeShape(U8, {0})}),
      /*operand_shapes_with_layout=*/{},
      /*opaque=*/conv->raw_backend_config_string(), /*has_side_effect=*/false,
      /*output_operand_aliasing=*/{}, /*literal=*/nullptr,
      /*window=*/conv->window(),
      /*dnums=*/conv->convolution_dimension_numbers());

  XlaOp new_conv_result = GetTupleElement(new_conv, 0);
  XlaOp new_conv_scratch = GetTupleElement(new_conv, 1);

  XlaOp new_conv_result_unrevectorized = UnrevectorizeInstr(
      new_conv_result, dnums.output_feature_dimension(), *output_vect_dim,
      /*orig_vect_size=*/output_shape.dimensions(*output_vect_dim));

  TF_ASSIGN_OR_RETURN(
      HloComputation * new_conv_comp,
      BuilderToHloComputation(
          b, Tuple(&b, {new_conv_result_unrevectorized, new_conv_scratch}),
          conv->parent()));

  // Set the name on the new conv.  This is purely cosmetic, but we attempt to
  // preserve e.g. "cudnn-conv.42" instead of "custom-call.42".
  auto new_conv_comp_instrs = new_conv_comp->instructions();
  auto new_conv_it =
      absl::c_find_if(new_conv_comp_instrs, [](HloInstruction* instr) {
        return instr->opcode() == HloOpcode::kCustomCall;
      });
  if (new_conv_it != new_conv_comp_instrs.end()) {
    new_conv_comp->parent()->SetAndUniquifyInstrName(*new_conv_it,
                                                     conv->name());
  }

  // Replace the old conv with a call to the computation we just created.
  VLOG(1) << "Re-vectorized conv to " << new_conv_comp->ToString();
  TF_RETURN_IF_ERROR(conv->parent()->ReplaceWithNewInstruction(
      conv, HloInstruction::CreateCall(conv->shape(), conv->operands(),
                                       new_conv_comp)));

  return true;
}

// Tries to vectorize a convolution.
//
// Given a convolution of dimensions [N, C, H, W], tries to convert it to have
// shape [N, C/vect_size, vect_size, H, W], i.e. to insert a new dimension of
// size vect_size immediately after the feature dimension.  Similarly, given a
// kernel of shape [H, W, I, O], tries to convert it to
// [H, W, I/vect_size, vect_size, O].
//
// This requires that the input and output channel counts be multiples of
// vect_size.  CudnnPadForConvolutions can add padding to make this true.
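//
// For example, with vect_size=32, an NCHW input of shape s8[10,64,20,30]
// becomes s8[10,2,32,20,30], and an HWIO kernel of shape s8[3,3,64,128]
// becomes s8[3,3,2,32,128].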
static StatusOr<bool> TryVectorizeConv(
    const se::CudaComputeCapability& compute_capability,
    HloCustomCallInstruction* conv, int64_t vect_size) {
  const Shape& input_shape = conv->operand(0)->shape();
  const Shape& output_shape = conv->shape().tuple_shapes(0);
  const auto& dnums = conv->convolution_dimension_numbers();
  int64_t in_channels = input_shape.dimensions(dnums.input_feature_dimension());
  int64_t out_channels =
      output_shape.dimensions(dnums.output_feature_dimension());

  if (in_channels % vect_size != 0 || out_channels % vect_size != 0) {
    return false;
  }

  if (input_shape.dimensions_size() >
      2 + dnums.input_spatial_dimensions_size()) {
    // Conv already has an extra dimension, which we assume is the vectorized
    // features dim.
    return false;
  }

  // If this is an integer convolution check that we only vectorize when cuDNN
  // supports the vectorized implementation.
  if (primitive_util::IsIntegralType(input_shape.element_type())) {
    TF_ASSIGN_OR_RETURN(bool supported_target_vectorization,
                        CudnnSupportsOptimizedIntegerConvolution(
                            compute_capability, *conv, vect_size));
    if (!supported_target_vectorization) {
      VLOG(3) << "Skipping vectorization of conv to vector size: " << vect_size
              << ": " << conv->ToString();
      return false;
    }
  }

  VLOG(1) << "Vectorizing conv channels by " << vect_size << ": "
          << conv->ToString();

  // We use XlaBuilder because it's a lot easier to get these tricky
  // reshape/transposes correct using that API.
  XlaBuilder b(absl::StrCat(conv->name(), ".revectorized"));
  b.SetOpMetadata(conv->metadata());

  absl::InlinedVector<XlaOp, 4> new_operands = {
      SplitAtDim(Parameter(&b, 0, conv->operand(0)->shape(), "input"),
                 dnums.input_feature_dimension(), vect_size),
      SplitAtDim(Parameter(&b, 1, conv->operand(1)->shape(), "filter"),
                 dnums.kernel_input_feature_dimension(), vect_size),
  };
  if (conv->operand_count() > 2) {
    // Bias, if present.  This is passed through unmodified.
    new_operands.push_back(Parameter(&b, 2, conv->operand(2)->shape(), "bias"));
  }
  if (conv->operand_count() > 3) {
    // Handle side input, which has same shape as the input.
    new_operands.push_back(
        SplitAtDim(Parameter(&b, 3, conv->operand(3)->shape(), "side_input"),
                   dnums.input_feature_dimension(), vect_size));
  }
  if (conv->operand_count() > 4) {
    return InvalidArgument(
        "Don't understand a conv with more than 4 arguments: %s",
        conv->ToString());
  }

  // The custom-call returns a tuple (new_output_shape, u8[0]), where the second
  // value in the tuple represents the convolution's scratch memory.
  Shape new_output_shape = SplitShapeAtDim(
      output_shape, dnums.output_feature_dimension(), vect_size);
  XlaOp new_conv = CustomCallWithConvDnums(
      &b, conv->custom_call_target(), new_operands,
      ShapeUtil::MakeTupleShape(
          {new_output_shape, ShapeUtil::MakeShape(U8, {0})}),
      /*operand_shapes_with_layout=*/{},
      /*opaque=*/conv->raw_backend_config_string(), /*has_side_effect=*/false,
      /*output_operand_aliasing=*/{}, /*literal=*/nullptr,
      /*window=*/conv->window(),
      /*dnums=*/VectorizeDnums(dnums));

  XlaOp new_conv_result = GetTupleElement(new_conv, 0);
  XlaOp new_conv_scratch = GetTupleElement(new_conv, 1);

  // Reshape back to the original shape.
  XlaOp conv_result_collapsed = Collapse(
      new_conv_result,
      {dnums.output_feature_dimension(), dnums.output_feature_dimension() + 1});

  TF_ASSIGN_OR_RETURN(
      HloComputation * new_conv_comp,
      BuilderToHloComputation(
          b, Tuple(&b, {conv_result_collapsed, new_conv_scratch}),
          conv->parent()));

  // Create a tuple and replace the old conv with it!
  VLOG(1) << "Vectorized conv to: " << new_conv_comp->ToString();
  TF_RETURN_IF_ERROR(conv->parent()->ReplaceWithNewInstruction(
      conv, HloInstruction::CreateCall(conv->shape(), conv->operands(),
                                       new_conv_comp)));
  return true;
}

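// For each relevant conv in each non-fusion computation: on sm75+ GPUs this
// first tries to (re)vectorize to int8x32 and falls back to int8x4 if that
// isn't possible; on older GPUs only int8x4 vectorization is attempted.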
StatusOr<bool> CudnnVectorizeConvolutions::Run(
    HloModule* module,
    const absl::flat_hash_set<absl::string_view>& execution_threads) {
  bool changed = false;
  for (HloComputation* comp :
       module->MakeNonfusionComputations(execution_threads)) {
    for (HloCustomCallInstruction* conv : GetRelevantConvs(comp)) {
      // Try to (re)vectorize to int8x32 if this is an sm75+ GPU.  If we can't,
      // fall back to int8x4.
      bool local_changed = false;
      if (compute_capability_.IsAtLeast(7, 5)) {
        TF_ASSIGN_OR_RETURN(local_changed,
                            TryRevectorizeConv(compute_capability_, conv, 32));
        if (!local_changed) {
          TF_ASSIGN_OR_RETURN(local_changed,
                              TryVectorizeConv(compute_capability_, conv, 32));
        }
      }
      if (!local_changed) {
        TF_ASSIGN_OR_RETURN(local_changed,
                            TryVectorizeConv(compute_capability_, conv, 4));
      }
      changed |= local_changed;
    }
  }
  return changed;
}

}  // namespace gpu
}  // namespace xla