/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/gpu/cublas_pad_for_gemms.h"

#include <cstdint>
#include <vector>

#include "absl/algorithm/container.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/window_util.h"

namespace xla {
namespace gpu {

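// Pads the non-batch dimensions of `dot`'s operands (and its result) up to a
// multiple of `pad_to_multiple_of`, then slices the padded result back down
// to the original shape. Returns true if the dot was rewritten, and false if
// the operands are not of `datatype` or no padding was needed.
//
// Illustrative example (datatype = F16, pad_to_multiple_of = 8):
//
//   dot(f16[3,5], f16[5,7]) -> f16[3,7]
//
// becomes
//
//   lpad       = pad(f16[3,5] -> f16[8,8])
//   rpad       = pad(f16[5,7] -> f16[8,8])
//   padded_dot = dot(lpad, rpad) -> f16[8,8]
//   slice(padded_dot) -> f16[3,7]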
static StatusOr<bool> PadForGemm(HloDotInstruction* dot, PrimitiveType datatype,
                                 int pad_to_multiple_of) {
  auto* lhs = dot->mutable_operand(0);
  auto* rhs = dot->mutable_operand(1);

  Shape lshape = lhs->shape();
  Shape rshape = rhs->shape();
  Shape result_shape = dot->shape();

  if (lshape.element_type() != datatype || rshape.element_type() != datatype) {
    return false;
  }

  auto pad_dim = [&](Shape& s, int dim) {
    s.set_dimensions(dim,
                     RoundUpTo<int64_t>(s.dimensions(dim), pad_to_multiple_of));
  };

  auto pad_matrix_dims = [&pad_dim](Shape s) {
    // Since the dot instruction is canonicalized, the last two dimensions for
    // each operand represent non-batch dimensions, and the others are the same
    // for both operands and correspond to batch dimensions.
    pad_dim(s, s.rank() - 2);
    pad_dim(s, s.rank() - 1);
    return s;
  };

  Shape new_lshape = pad_matrix_dims(lshape);
  Shape new_rshape = pad_matrix_dims(rshape);
  Shape new_result_shape = pad_matrix_dims(result_shape);

  if (new_lshape == lshape && new_rshape == rshape) {
    return false;
  }

  VLOG(3) << "old shape: " << lshape << " " << rshape << " " << result_shape;
  VLOG(3) << "new shape: " << new_lshape << " " << new_rshape << " "
          << new_result_shape;

  auto create_padding_config = [](Shape& shape, Shape& new_shape) {
    PaddingConfig padding_config;
    for (int i = 0; i < shape.rank(); ++i) {
      auto dimension = padding_config.add_dimensions();
      dimension->set_edge_padding_high(new_shape.dimensions()[i] -
                                       shape.dimensions()[i]);
      dimension->set_edge_padding_low(0);
      dimension->set_interior_padding(0);
    }
    return padding_config;
  };

  auto l_padding_config = create_padding_config(lshape, new_lshape);
  auto r_padding_config = create_padding_config(rshape, new_rshape);

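  // Pad the LHS and RHS up to the rounded-up shapes with zeros of the dot's
  // element type, copying the original dot's metadata onto the new
  // instructions.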
  HloComputation* parent = dot->parent();

  HloInstruction* zero_float = parent->AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::Zero(datatype)));
  zero_float->set_metadata(dot->metadata());

  HloInstruction* lpad = parent->AddInstruction(
      HloInstruction::CreatePad(new_lshape, lhs, zero_float, l_padding_config));
  lpad->set_metadata(dot->metadata());

  HloInstruction* rpad = parent->AddInstruction(
      HloInstruction::CreatePad(new_rshape, rhs, zero_float, r_padding_config));
  rpad->set_metadata(dot->metadata());

  HloInstruction* new_dot = parent->AddInstruction(
      dot->CloneWithNewOperands(new_result_shape, {lpad, rpad}));

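  // Slice the padded dot's result back down to the original result shape so
  // that users of the dot see the dimensions they expect.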
  std::vector<int64_t> start_indices(result_shape.rank(), 0);
  std::vector<int64_t> strides(result_shape.rank(), 1);
  HloInstruction* slice = parent->AddInstruction(
      HloInstruction::CreateSlice(result_shape, new_dot, start_indices,
                                  result_shape.dimensions(), strides));
  slice->set_metadata(dot->metadata());

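  // If the dot has no users it was the computation's root, so the slice must
  // explicitly become the new root after the replacement.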
  bool is_root = dot->user_count() == 0;

  TF_CHECK_OK(parent->ReplaceInstruction(dot, slice));

  if (is_root) {
    parent->set_root_instruction(slice);
  }

  return true;
}

namespace {

// We need this check because PadForGemm works under the assumption that the
// dot instruction is canonicalized.
bool CheckCanonical(HloDotInstruction* dot) {
  auto dimension_numbers = dot->dot_dimension_numbers();

  if (dimension_numbers.lhs_batch_dimensions_size() + 2 !=
          dot->operand(0)->shape().rank() ||
      dimension_numbers.rhs_batch_dimensions_size() + 2 !=
          dot->operand(1)->shape().rank()) {
    LOG(ERROR) << "Dot is not canonical: Expected all dimensions but the last "
                  "2 to be batch dimensions.";
    return false;
  }

  std::vector<int64_t> canonical_batch_dims(
      dimension_numbers.lhs_batch_dimensions_size());
  absl::c_iota(canonical_batch_dims, 0);
  if (!absl::c_equal(dimension_numbers.lhs_batch_dimensions(),
                     canonical_batch_dims) ||
      !absl::c_equal(dimension_numbers.rhs_batch_dimensions(),
                     canonical_batch_dims)) {
    LOG(ERROR) << "Dot is not canonical: Expected batch dimensions to be all "
                  "dimensions except for the last 2.";
    return false;
  }

  return true;
}

}  // namespace

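// Collects the dots in `comp` that this pass may rewrite: canonical dot
// instructions that are recognized as matrix multiplications and whose
// operands have element type `datatype`.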
static std::vector<HloDotInstruction*> GetRelevantDots(HloComputation* comp,
                                                       PrimitiveType datatype) {
  std::vector<HloDotInstruction*> gemms;

  for (HloInstruction* instr : comp->instructions()) {
    if (IsMatrixMultiplication(*instr)) {
      HloDotInstruction* dot = Cast<HloDotInstruction>(instr);
      if (instr->operand(0)->shape().element_type() == datatype &&
          CheckCanonical(dot)) {
        gemms.push_back(dot);
      }
    }
  }
  return gemms;
}

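// Pads every relevant dot in all non-fusion computations of `module` and
// reports whether anything changed.
//
// A minimal usage sketch (the constructor arguments shown here are an
// assumption about what cublas_pad_for_gemms.h exposes, not taken from this
// file):
//
//   CublasPadForGemms pass(/*datatype=*/F16, /*pad_to_multiple_of=*/8);
//   TF_ASSIGN_OR_RETURN(bool changed,
//                       pass.Run(module, /*execution_threads=*/{}));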
StatusOr<bool> CublasPadForGemms::Run(
    HloModule* module,
    const absl::flat_hash_set<absl::string_view>& execution_threads) {
  bool changed = false;
  for (HloComputation* comp :
       module->MakeNonfusionComputations(execution_threads)) {
    for (HloDotInstruction* dot : GetRelevantDots(comp, datatype_)) {
      TF_ASSIGN_OR_RETURN(bool result,
                          PadForGemm(dot, datatype_, pad_to_multiple_of_));
      changed |= result;
    }
  }
  return changed;
}

}  // namespace gpu
}  // namespace xla