xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/service/gpu/cublas_pad_for_gemms.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/gpu/cublas_pad_for_gemms.h"
17 
18 #include "tensorflow/compiler/xla/literal_util.h"
19 #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
20 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
21 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
22 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
23 #include "tensorflow/compiler/xla/util.h"
24 #include "tensorflow/compiler/xla/window_util.h"
25 
26 namespace xla {
27 namespace gpu {
28 
PadForGemm(HloDotInstruction * dot,PrimitiveType datatype,int pad_to_multiple_of)29 static StatusOr<bool> PadForGemm(HloDotInstruction* dot, PrimitiveType datatype,
30                                  int pad_to_multiple_of) {
31   auto* lhs = dot->mutable_operand(0);
32   auto* rhs = dot->mutable_operand(1);
33 
34   Shape lshape = lhs->shape();
35   Shape rshape = rhs->shape();
36   Shape result_shape = dot->shape();
37 
38   if (lshape.element_type() != datatype || rshape.element_type() != datatype) {
39     return false;
40   }
41 
42   auto pad_dim = [&](Shape& s, int dim) {
43     s.set_dimensions(dim,
44                      RoundUpTo<int64_t>(s.dimensions(dim), pad_to_multiple_of));
45   };
46 
47   auto pad_matrix_dims = [&pad_dim](Shape s) {
48     // Since the dot instruction is canonicalized, the last two dimensions for
49     // each operand represent non-batch dimensions, and the others are the same
50     // for both operands and correspond to batch dimensions.
51     pad_dim(s, s.rank() - 2);
52     pad_dim(s, s.rank() - 1);
53     return s;
54   };
55 
56   Shape new_lshape = pad_matrix_dims(lshape);
57   Shape new_rshape = pad_matrix_dims(rshape);
58   Shape new_result_shape = pad_matrix_dims(result_shape);
59 
60   if (new_lshape == lshape && new_rshape == rshape) {
61     return false;
62   }
63 
64   VLOG(3) << "old shape: " << lshape << " " << rshape << " " << result_shape;
65   VLOG(3) << "new shape: " << new_lshape << " " << new_rshape << " "
66           << new_result_shape;
67 
68   auto create_padding_config = [](Shape& shape, Shape& new_shape) {
69     PaddingConfig padding_config;
70     for (int i = 0; i < shape.rank(); ++i) {
71       auto dimension = padding_config.add_dimensions();
72       dimension->set_edge_padding_high(new_shape.dimensions()[i] -
73                                        shape.dimensions()[i]);
74       dimension->set_edge_padding_low(0);
75       dimension->set_interior_padding(0);
76     }
77     return padding_config;
78   };
79 
80   auto l_padding_config = create_padding_config(lshape, new_lshape);
81   auto r_padding_config = create_padding_config(rshape, new_rshape);
82 
83   HloComputation* parent = dot->parent();
84 
85   HloInstruction* zero_float = parent->AddInstruction(
86       HloInstruction::CreateConstant(LiteralUtil::Zero(datatype)));
87   zero_float->set_metadata(dot->metadata());
88 
89   HloInstruction* lpad = parent->AddInstruction(
90       HloInstruction::CreatePad(new_lshape, lhs, zero_float, l_padding_config));
91   lpad->set_metadata(dot->metadata());
92 
93   HloInstruction* rpad = parent->AddInstruction(
94       HloInstruction::CreatePad(new_rshape, rhs, zero_float, r_padding_config));
95   rpad->set_metadata(dot->metadata());
96 
97   HloInstruction* new_dot = parent->AddInstruction(
98       dot->CloneWithNewOperands(new_result_shape, {lpad, rpad}));
99 
100   std::vector<int64_t> start_indices(result_shape.rank(), 0);
101   std::vector<int64_t> strides(result_shape.rank(), 1);
102   HloInstruction* slice = parent->AddInstruction(
103       HloInstruction::CreateSlice(result_shape, new_dot, start_indices,
104                                   result_shape.dimensions(), strides));
105   slice->set_metadata(dot->metadata());
106 
107   bool is_root = dot->user_count() == 0;
108 
109   TF_CHECK_OK(parent->ReplaceInstruction(dot, slice));
110 
111   if (is_root) {
112     parent->set_root_instruction(slice);
113   }
114 
115   return true;
116 }
117 
118 namespace {
119 
120 // We need this check because PadForGemm works in the assumption that
121 // the dot instruction is canonicalized.
CheckCanonical(HloDotInstruction * dot)122 bool CheckCanonical(HloDotInstruction* dot) {
123   auto dimension_numbers = dot->dot_dimension_numbers();
124 
125   if (dimension_numbers.lhs_batch_dimensions_size() + 2 !=
126           dot->operand(0)->shape().rank() ||
127       dimension_numbers.rhs_batch_dimensions_size() + 2 !=
128           dot->operand(1)->shape().rank()) {
129     LOG(ERROR) << "Dot is not canonical: Expected all dimensions but 2 to be "
130                   "batch_dimensions.";
131     return false;
132   }
133 
134   std::vector<int64_t> canonical_batch_dims(
135       dimension_numbers.lhs_batch_dimensions_size());
136   absl::c_iota(canonical_batch_dims, 0);
137   if (!absl::c_equal(dimension_numbers.lhs_batch_dimensions(),
138                      canonical_batch_dims) ||
139       !absl::c_equal(dimension_numbers.rhs_batch_dimensions(),
140                      canonical_batch_dims)) {
141     LOG(ERROR) << "Dot is not canonical: Expected batch dimensions to be all "
142                   "dimensions except for the last 2 ones.";
143     return false;
144   }
145 
146   return true;
147 }
148 
149 }  // namespace
150 
GetRelevantDots(HloComputation * comp,PrimitiveType datatype)151 static std::vector<HloDotInstruction*> GetRelevantDots(HloComputation* comp,
152                                                        PrimitiveType datatype) {
153   std::vector<HloDotInstruction*> gemms;
154 
155   for (HloInstruction* instr : comp->instructions()) {
156     if (IsMatrixMultiplication(*instr)) {
157       HloDotInstruction* dot = Cast<HloDotInstruction>(instr);
158       if (instr->operand(0)->shape().element_type() == datatype &&
159           CheckCanonical(dot)) {
160         gemms.push_back(dot);
161       }
162     }
163   }
164   return gemms;
165 }
166 
Run(HloModule * module,const absl::flat_hash_set<absl::string_view> & execution_threads)167 StatusOr<bool> CublasPadForGemms::Run(
168     HloModule* module,
169     const absl::flat_hash_set<absl::string_view>& execution_threads) {
170   bool changed = false;
171   for (HloComputation* comp :
172        module->MakeNonfusionComputations(execution_threads)) {
173     for (HloDotInstruction* dot : GetRelevantDots(comp, datatype_)) {
174       TF_ASSIGN_OR_RETURN(bool result,
175                           PadForGemm(dot, datatype_, pad_to_multiple_of_));
176       changed |= result;
177     }
178   }
179   return changed;
180 }
181 
182 }  // namespace gpu
183 }  // namespace xla
184