/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This is an explorative prototype emitter for convolution using MLIR.
// This prototype is still under construction.
// TODO(timshen): Fix the documentation once it's implemented.
//
// Goals:
// * Autotune-able tiling.
// * Autotune-able memory accesses.
// * Autotune-able lowering logic (from a portable program to thread-oriented
//   CUDA program).
// * Use mlir::AffineExpr to analyze all accesses. It aims to algorithmically
//   find memory access strategies for given input layouts and tiling configs.

#include "tensorflow/compiler/mlir/xla/experimental/conv_emitter/conv_emitter.h"

#include "absl/types/span.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"  // from @llvm-project
#include "mlir/Dialect/Affine/LoopUtils.h"  // from @llvm-project
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"  // from @llvm-project
#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
#include "mlir/Dialect/MemRef/IR/MemRef.h"  // from @llvm-project
#include "mlir/IR/AffineExpr.h"  // from @llvm-project
#include "mlir/IR/AffineMap.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/Transforms/RegionUtils.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/xla/experimental/conv_emitter/conv_emitter_transforms.h"
#include "tensorflow/compiler/xla/permutation_util.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_type_conversion_util.h"
#include "tensorflow/compiler/xla/window_util.h"

namespace xla {
namespace experimental {
namespace {

using mlir::OpBuilder;

// Various pieces of information extracted from an input shape.
struct ShapeInfo {
  // Buffer dimensions in NCHW order.
  std::vector<int64_t> nchw_dimensions;

  // Buffer dimensions in major-to-minor (physical) order.
  std::vector<int64_t> physical_dimensions;

  // The affine map that takes NCHW indices and maps them to indices in the
  // physical order.
  mlir::AffineMap affine_map;

  mlir::Type element_type;
};

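// For illustration only (a hypothetical NHWC-laid-out input, not tied to any
// particular convolution): with logical dimensions N=0, H=1, W=2, C=3 and
// minor_to_major = {3, 2, 1, 0} (physical order N, H, W, C), GetShapeInfo
// below yields
//   nchw_dimensions     = {N, C, H, W}  (extents in NCHW order)
//   physical_dimensions = {N, H, W, C}  (extents in major-to-minor order)
//   affine_map          = (n, c, h, w) -> (n, h, w, c)
// i.e. the map permutes NCHW indices into physical (major-to-minor) indices.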
ShapeInfo GetShapeInfo(
    const Shape& shape, int64_t n_dim, int64_t c_dim,
    absl::Span<const tensorflow::protobuf_int64> spatial_dims,
    mlir::Builder builder) {
  ShapeInfo shape_info;

  std::vector<int64_t> physical_to_logical(
      shape.layout().minor_to_major().rbegin(),
      shape.layout().minor_to_major().rend());

  std::vector<int64_t> nchw_to_logical;

  nchw_to_logical.push_back(n_dim);
  nchw_to_logical.push_back(c_dim);
  for (int64_t dim : spatial_dims) {
    nchw_to_logical.push_back(dim);
  }

  for (int64_t dim : nchw_to_logical) {
    shape_info.nchw_dimensions.push_back(shape.dimensions(dim));
  }

  for (int64_t dim : physical_to_logical) {
    shape_info.physical_dimensions.push_back(shape.dimensions(dim));
  }

  std::vector<mlir::AffineExpr> affine_exprs;
  // Build the permutation from physical (major-to-minor) positions to NCHW
  // positions, so that feeding NCHW indices into the resulting affine map
  // produces indices in physical order.
  for (int64_t dim : ComposePermutations(InversePermutation(nchw_to_logical),
                                         physical_to_logical)) {
    affine_exprs.push_back(builder.getAffineDimExpr(dim));
  }

  shape_info.affine_map = mlir::AffineMap::get(
      /*dimCount=*/2 + spatial_dims.size(), /*symbolCount=*/0, affine_exprs,
      builder.getContext());

  shape_info.element_type = [&] {
    switch (shape.element_type()) {
      case xla::F16:
        return builder.getF16Type();
      case xla::F32:
        return builder.getF32Type();
      default:
        break;
    }
    CHECK(false);
  }();

  return shape_info;
}

void SetMemRef(mlir::Operation* op, mlir::Value memref) {
  if (auto load = mlir::dyn_cast<mlir::AffineLoadOp>(op)) {
    load.setMemRef(memref);
  } else if (auto store = mlir::dyn_cast<mlir::AffineStoreOp>(op)) {
    store.setMemRef(memref);
  } else {
    CHECK(false);
  }
}

// Hoists operations out of `where`. [begin_op, end_op) must be the first
// operations of their parent loop, and `where` must be an ancestor of that
// parent loop.
//
// The hoist always preserves the semantics of the program; to do so it may
// modify the hoisted operations or add extra loops at the destination.
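//
// For example, hoisting `%acc = alloc() : memref<f32>` out of two simple
// loops with trip counts 32 and 16 turns it into
// `%acc = alloc() : memref<32x16xf32>`, and every affine load/store of that
// buffer gains the two induction variables as leading indices.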
mlir::Operation* HoistAndFix(llvm::iplist<mlir::Operation>::iterator begin_op,
                             llvm::iplist<mlir::Operation>::iterator end_op,
                             mlir::AffineForOp where) {
  // All loops to hoist through.
  llvm::SmallVector<mlir::AffineForOp, 4> ancestors;
  getPerfectlyNestedLoops(ancestors, where);
  {
    int i;
    for (i = 0; i < ancestors.size(); i++) {
      if (&ancestors[i].getBody()->front() == &*begin_op) {
        break;
      }
    }
    CHECK(i < ancestors.size());
    ancestors.resize(i + 1);
  }

  std::vector<int64_t> ancestor_dimensions;
  for (auto ancestor : ancestors) {
    CHECK(IsSimpleLoop(ancestor));
    ancestor_dimensions.push_back(
        ancestor.getUpperBoundMap().getSingleConstantResult());
  }

  if (auto alloc = mlir::dyn_cast<mlir::memref::AllocOp>(begin_op)) {
    CHECK(std::next(begin_op) == end_op)
        << "alloc() needs to be hoisted on its own";

    OpBuilder builder(where);
    mlir::MemRefType type = alloc.getType();
    CHECK(type.getLayout().isIdentity());
    ancestor_dimensions.insert(ancestor_dimensions.end(),
                               type.getShape().begin(), type.getShape().end());
    mlir::MemRefType new_type =
        mlir::MemRefType::get(ancestor_dimensions, type.getElementType());
    auto new_alloc = builder.create<mlir::memref::AllocOp>(
        builder.getUnknownLoc(), new_type);

    std::vector<mlir::Value> indvars;
    for (auto ancestor : ancestors) {
      indvars.push_back(ancestor.getInductionVar());
    }
    for (auto& use : llvm::make_early_inc_range(alloc.getResult().getUses())) {
      mlir::Operation* owner = use.getOwner();
      BoundAffineMap affine_map = GetBoundAffineMapFrom(owner);
      affine_map.operands.insert(affine_map.operands.begin(), indvars.begin(),
                                 indvars.end());
      CHECK(affine_map.affine_map.isIdentity());
      affine_map.affine_map = mlir::AffineMap::getMultiDimIdentityMap(
          affine_map.operands.size(), builder.getContext());

      mlir::Operation* new_op =
          CloneWithNewAffineMap(owner, affine_map, OpBuilder(owner));
      SetMemRef(new_op, new_alloc);
      owner->replaceAllUsesWith(new_op);
      owner->erase();
    }
    alloc.erase();
    return new_alloc;
  }

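  // If the ops being hoisted are not a standalone alloc(), check whether any
  // of them is loop-variant (a nested loop or a store). If so, semantics can
  // only be preserved by re-creating the hoisted-through loop nest at the
  // destination and moving the ops into it.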
  const bool any_op_is_loop_variant = [&] {
    for (mlir::Operation& op : llvm::make_range(begin_op, end_op)) {
      if (mlir::isa<mlir::AffineForOp, mlir::AffineStoreOp>(op)) {
        return true;
      }
    }
    return false;
  }();

  if (any_op_is_loop_variant) {
    auto builder = OpBuilder(where);
    std::vector<mlir::AffineForOp> new_loops;
    for (auto dim : ancestor_dimensions) {
      auto where =
          builder.create<mlir::AffineForOp>(builder.getUnknownLoc(), 0, dim);
      new_loops.push_back(where);
      builder = OpBuilder::atBlockTerminator(where.getBody());
    }
    for (mlir::Operation& op :
         llvm::make_early_inc_range(llvm::make_range(begin_op, end_op))) {
      op.moveBefore(&new_loops.back().getBody()->back());
    }
    CHECK_EQ(ancestors.size(), new_loops.size());
    for (int i = 0; i < ancestors.size(); i++) {
      replaceAllUsesInRegionWith(ancestors[i].getInductionVar(),
                                 new_loops[i].getInductionVar(),
                                 new_loops.back().getRegion());
    }
    return new_loops.front();
  }
  CHECK(false);
}

mlir::Operation* HoistAndFix(mlir::Operation* op, mlir::AffineForOp where) {
  return HoistAndFix(op->getIterator(), std::next(op->getIterator()), where);
}

struct InitialMlirConvAnchors {
  std::vector<mlir::AffineForOp> cartesian_product_loops;
  std::vector<mlir::AffineForOp> reduction_loops;
  mlir::memref::AllocOp output_acc;
};

// Creates the following IR and returns anchors pointing at the corresponding
// operations:
// for (cartesian loops...) {
//   %output_acc = alloc() : memref(f32)
//   output_acc[] = 0
//   for (reduction loops...) {
//     output_acc[] += input[...] * filter[...]
//   }
//   output[...] = output_acc[]
// }
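//
// Note that while the input, filter, and output buffers are f16, the
// accumulation is carried out in f32: the loads are extended to f32 and the
// final store truncates back to f16.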
StatusOr<InitialMlirConvAnchors> CreateNaiveMlirConv(
    mlir::Value input, mlir::Value filter, mlir::Value output,
    const ShapeInfo& input_shape_info, const ShapeInfo& filter_shape_info,
    const ShapeInfo& output_shape_info, const Window& window,
    OpBuilder builder) {
  CHECK(input_shape_info.element_type == builder.getF16Type());
  CHECK(filter_shape_info.element_type == builder.getF16Type());
  CHECK(output_shape_info.element_type == builder.getF16Type());

  auto location = mlir::UnknownLoc::get(builder.getContext());

  std::vector<mlir::AffineForOp> cartesian_product_loops =
      CreateNestedSimpleLoops(output_shape_info.nchw_dimensions, builder);

  builder =
      OpBuilder::atBlockTerminator(cartesian_product_loops.back().getBody());

  auto output_acc = builder.create<mlir::memref::AllocOp>(
      location, mlir::MemRefType::get({}, builder.getF32Type()));

  builder.create<mlir::AffineStoreOp>(
      location,
      builder.create<mlir::arith::ConstantOp>(
          location, mlir::FloatAttr::get(builder.getF32Type(), 0)),
      output_acc, llvm::ArrayRef<mlir::Value>());

  std::vector<mlir::AffineForOp> reduction_loops;
  reduction_loops = CreateNestedSimpleLoops(
      absl::MakeSpan(filter_shape_info.nchw_dimensions).subspan(1), builder);

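  // Naming convention: loop_n iterates the batch dimension, loop_o the output
  // feature dimension, and loop_c the input feature (reduction) dimension.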
  mlir::AffineForOp loop_n = cartesian_product_loops[0];
  mlir::AffineForOp loop_o = cartesian_product_loops[1];
  mlir::AffineForOp loop_c = reduction_loops[0];

  std::vector<mlir::Value> output_spatial_indvars;
  for (auto loop : absl::MakeSpan(cartesian_product_loops).subspan(2)) {
    output_spatial_indvars.push_back(loop.getInductionVar());
  }
  std::vector<mlir::Value> filter_spatial_indvars;
  for (auto loop : absl::MakeSpan(reduction_loops).subspan(1)) {
    filter_spatial_indvars.push_back(loop.getInductionVar());
  }
  int num_spatial_dims = output_spatial_indvars.size();
  CHECK_EQ(num_spatial_dims, filter_spatial_indvars.size());

  builder = OpBuilder::atBlockTerminator(reduction_loops.back().getBody());

  mlir::Value loaded_input = [&] {
    std::vector<mlir::AffineExpr> input_indices;
    input_indices.push_back(builder.getAffineDimExpr(0));
    input_indices.push_back(builder.getAffineDimExpr(1));

    // For spatial dimensions, generate input_index * stride + filter_index -
    // left_pad
    //
    // TODO(timshen): guard out-of-bound loads and stores brought by padding.
    for (int i = 0; i < num_spatial_dims; i++) {
      const WindowDimension& window_dim = window.dimensions(i);
      input_indices.push_back(
          builder.getAffineDimExpr(i + 2) * window_dim.stride() +
          builder.getAffineDimExpr(2 + num_spatial_dims + i) -
          window_dim.padding_low());
    }
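    // For example, with stride 1 and padding_low 1 the spatial input index is
    // output_index + filter_index - 1, which is -1 at the low border; hence
    // the TODO above about guarding out-of-bound accesses.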
    std::vector<mlir::Value> input_vars;
    input_vars.push_back(loop_n.getInductionVar());
    input_vars.push_back(loop_c.getInductionVar());
    input_vars.insert(input_vars.end(), output_spatial_indvars.begin(),
                      output_spatial_indvars.end());
    input_vars.insert(input_vars.end(), filter_spatial_indvars.begin(),
                      filter_spatial_indvars.end());

    return builder.create<mlir::arith::ExtFOp>(
        location, builder.getF32Type(),
        builder.createOrFold<mlir::AffineLoadOp>(
            location, input,
            mlir::AffineMap(input_shape_info.affine_map)
                .compose(mlir::AffineMap::get(
                    /*dimCount=*/2 + num_spatial_dims * 2,
                    /*symbolCount=*/0, input_indices, builder.getContext())),
            input_vars));
  }();

  mlir::Value loaded_filter = [&] {
    std::vector<mlir::Value> filter_vars;
    filter_vars.push_back(loop_o.getInductionVar());
    filter_vars.push_back(loop_c.getInductionVar());
    filter_vars.insert(filter_vars.end(), filter_spatial_indvars.begin(),
                       filter_spatial_indvars.end());

    return builder.create<mlir::arith::ExtFOp>(
        location, builder.getF32Type(),
        builder.createOrFold<mlir::AffineLoadOp>(
            location, filter, filter_shape_info.affine_map, filter_vars));
  }();

  auto accum_load_op =
      builder.createOrFold<mlir::AffineLoadOp>(location, output_acc);
  builder.createOrFold<mlir::AffineStoreOp>(
      location,
      builder.create<mlir::arith::AddFOp>(
          location, accum_load_op,
          builder.create<mlir::arith::MulFOp>(location, loaded_input,
                                              loaded_filter)),
      output_acc, llvm::ArrayRef<mlir::Value>());

  builder.setInsertionPointAfter(reduction_loops[0]);
  {
    std::vector<mlir::Value> output_vars;
    output_vars.push_back(loop_n.getInductionVar());
    output_vars.push_back(loop_o.getInductionVar());
    output_vars.insert(output_vars.end(), output_spatial_indvars.begin(),
                       output_spatial_indvars.end());
    builder.createOrFold<mlir::AffineStoreOp>(
        location,
        builder.create<mlir::arith::TruncFOp>(
            location, builder.getF16Type(),
            builder.createOrFold<mlir::AffineLoadOp>(location, output_acc)),
        output, output_shape_info.affine_map, output_vars);
  }

  return InitialMlirConvAnchors{cartesian_product_loops, reduction_loops,
                                output_acc};
}

// Contains the following pattern with anchors:
// for (cartesian loops...) {
//   %output_acc = alloc() : memref(..., f32)
//   for (reduction loops...) {
//     for (tiled cartesian loops...) {
//       output_acc[...] = 0
//     }
//     for (tiled cartesian loops...) {
//       for (reduction loops...) {
//         output_acc[] += input[...] * filter[...]
//       }
//     }
//     for (tiled cartesian loops...) {
//       output[...] = output_acc[...]
//     }
//   }
// }
struct TransformedMlirConvAnchors {
  std::vector<mlir::AffineForOp> cartesian_product_loops;
  std::vector<mlir::AffineForOp> reduction_loops;
};

StatusOr<TransformedMlirConvAnchors> TransformMlirConv(
    InitialMlirConvAnchors anchors) {
  std::vector<mlir::AffineForOp> cartesian_product_loops =
      anchors.cartesian_product_loops;
  std::vector<mlir::AffineForOp> reduction_loops = anchors.reduction_loops;
  mlir::memref::AllocOp output_acc = anchors.output_acc;

  // TODO(timshen): consider using pattern matchers for transformations
  //
  // Initial form:
  // for (cartesian loops...) {
  //   %output_acc = alloc() : memref(f32)
  //   output_acc[] = 0
  //   for (reduction loops...) {
  //     output_acc[] += input[...] * filter[...]
  //   }
  //   output[...] = output_acc[]
  // }

  // Tile cartesian loops to:
  // for (cartesian loops...) {
  //   for (tiled cartesian loops...) {
  //     %output_acc = alloc() : memref(f32)
  //     output_acc[] = 0
  //     for (reduction loops...) {
  //       output_acc[] += input[...] * filter[...]
  //     }
  //     output[...] = output_acc[]
  //   }
  // }
  TileLoop(reduction_loops[0], 4, reduction_loops.back());

  std::vector<mlir::AffineForOp> tiled_cartesian_loops;
  tiled_cartesian_loops.push_back(
      TileLoop(cartesian_product_loops[1], 32, cartesian_product_loops.back()));

  tiled_cartesian_loops.push_back(TileLoop(cartesian_product_loops.back(), 16,
                                           tiled_cartesian_loops.back()));
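  // The tile sizes used above (4 for the reduction loop, 32 and 16 for the
  // cartesian loops) are hard-coded for now; given the autotuning goals
  // stated at the top of this file, they are presumably meant to become
  // tunable parameters.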

  // Two hoists to separate the allocation, the computation, and the writeback
  // to output_acc.
  // After the first hoist:
  // for (cartesian loops...) {
  //   %output_acc = alloc() : memref(..., f32)
  //   for (tiled cartesian loops...) {
  //     output_acc[...] = 0
  //     for (reduction loops...) {
  //       output_acc[...] += input[...] * filter[...]
  //     }
  //     output[...] = output_acc[...]
  //   }
  // }
  output_acc = llvm::cast<mlir::memref::AllocOp>(
      HoistAndFix(output_acc, tiled_cartesian_loops.front()));

  // Hoist everything before the reduction loops (i.e. the zero initialization
  // of output_acc):
  // for (cartesian loops...) {
  //   %output_acc = alloc() : memref(..., f32)
  //   for (tiled cartesian loops...) {
  //     output_acc[...] = 0
  //   }
  //   for (tiled cartesian loops...) {
  //     for (reduction loops...) {
  //       output_acc[...] += input[...] * filter[...]
  //     }
  //     output[...] = output_acc[...]
  //   }
  // }
  HoistAndFix(tiled_cartesian_loops.back().getBody()->begin(),
              reduction_loops.front().getOperation()->getIterator(),
              tiled_cartesian_loops.front());

  // Now hoist all reduction loops outside of the tiled cartesian loops.
  // Note that HoistAndFix automatically adds a new set of tiled cartesian
  // loops around the hoisted reduction loops to keep the semantics correct.
  //
  // After the second hoist:
  // for (cartesian loops...) {
  //   %output_acc = alloc() : memref(..., f32)
  //   for (tiled cartesian loops...) {
  //     output_acc[...] = 0
  //   }
  //   for (tiled cartesian loops...) {
  //     for (reduction loops...) {
  //       output_acc[] += input[...] * filter[...]
  //     }
  //   }  // compute loop
  //   for (tiled cartesian loops...) {
  //     output[...] = output_acc[...]
  //   }
  // }
  {
    auto compute_loop = llvm::cast<mlir::AffineForOp>(
        HoistAndFix(reduction_loops.front(), tiled_cartesian_loops[0]));

    // Fix tiled_cartesian_loops to make them point to the tiled compute
    // loops, not the writeback loops to the output buffer.
    llvm::SmallVector<mlir::AffineForOp, 4> all_loops;
    getPerfectlyNestedLoops(all_loops, compute_loop);
    absl::c_copy_n(all_loops, tiled_cartesian_loops.size(),
                   tiled_cartesian_loops.data());
  }

  // After exchanging the tiled cartesian compute loops with the reduction
  // loops:
  // for (cartesian loops...) {
  //   %output_acc = alloc() : memref(..., f32)
  //   for (tiled cartesian loops...) {
  //     output_acc[...] = 0
  //   }
  //   for (reduction loops...) {
  //     for (tiled cartesian loops...) {
  //       output_acc[] += input[...] * filter[...]
  //     }
  //   }
  //   for (tiled cartesian loops...) {
  //     output[...] = output_acc[...]
  //   }
  // }
  //
  // ...so that the tiled cartesian loops (which now contain the computation)
  // can later be replaced by CUDA MMA instructions.
  {
    std::vector<mlir::AffineForOp> loops;
    loops.insert(loops.end(), tiled_cartesian_loops.begin(),
                 tiled_cartesian_loops.end());
    loops.insert(loops.end(), reduction_loops.begin(), reduction_loops.end());
    SinkPerfectlyNestedLoops(loops, tiled_cartesian_loops.size());
  }
  return TransformedMlirConvAnchors{cartesian_product_loops, reduction_loops};
}

}  // namespace

StatusOr<mlir::func::FuncOp> EmitConvolutionForwardAsMlir(
    HloInstruction* conv, absl::string_view function_name,
    mlir::MLIRContext* context) {
  OpBuilder builder(context);

  const auto& dim_nums = conv->convolution_dimension_numbers();
  ShapeInfo input_shape_info =
      GetShapeInfo(conv->operand(0)->shape(), dim_nums.input_batch_dimension(),
                   dim_nums.input_feature_dimension(),
                   dim_nums.input_spatial_dimensions(), builder);

  ShapeInfo filter_shape_info = GetShapeInfo(
      conv->operand(1)->shape(), dim_nums.kernel_output_feature_dimension(),
      dim_nums.kernel_input_feature_dimension(),
      dim_nums.kernel_spatial_dimensions(), builder);

  ShapeInfo output_shape_info = GetShapeInfo(
      conv->shape().tuple_shapes(0), dim_nums.output_batch_dimension(),
      dim_nums.output_feature_dimension(), dim_nums.output_spatial_dimensions(),
      builder);

  auto function = mlir::func::FuncOp::create(
      mlir::UnknownLoc::get(builder.getContext()),
      llvm_ir::AsStringRef(function_name),
      builder.getFunctionType(
          {mlir::MemRefType::get(output_shape_info.physical_dimensions,
                                 output_shape_info.element_type,
                                 mlir::AffineMap()),
           mlir::MemRefType::get(input_shape_info.physical_dimensions,
                                 input_shape_info.element_type,
                                 mlir::AffineMap()),
           mlir::MemRefType::get(filter_shape_info.physical_dimensions,
                                 filter_shape_info.element_type,
                                 mlir::AffineMap())},
          {}));

  auto* entry_block = function.addEntryBlock();
  builder.setInsertionPointToStart(entry_block);
  builder.create<mlir::func::ReturnOp>(builder.getUnknownLoc());
  builder.setInsertionPointToStart(entry_block);

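  // Per the function type constructed above, entry block argument 0 is the
  // output buffer, argument 1 the input, and argument 2 the filter.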
  mlir::Value input = entry_block->getArgument(1);
  mlir::Value filter = entry_block->getArgument(2);
  mlir::Value output = entry_block->getArgument(0);

  TF_RETURN_IF_ERROR(ConvIsImplemented(conv));

  TF_ASSIGN_OR_RETURN(
      InitialMlirConvAnchors initial_anchors,
      CreateNaiveMlirConv(input, filter, output, input_shape_info,
                          filter_shape_info, output_shape_info, conv->window(),
                          builder));

  TF_ASSIGN_OR_RETURN(TransformedMlirConvAnchors transformed_anchors,
                      TransformMlirConv(initial_anchors));

  // TODO(timshen): Implement a transformation that collects loads to a given
  // buffer, creates a local alloc() for the accessed part, redirects all loads
  // and stores to that local alloc(), and creates code to initialize /
  // write back the local alloc() if needed.

  // TODO(timshen): Implement CUDA-specific lowering.

  return function;
}


Status ConvIsImplemented(const HloInstruction* conv) {
  if (conv->feature_group_count() != 1 || conv->batch_group_count() != 1) {
    return Unimplemented("group count is not implemented.");
  }
  if (window_util::HasWindowReversal(conv->window())) {
    return Unimplemented("Window reversal is not implemented.");
  }
  if (window_util::HasDilation(conv->window())) {
    return Unimplemented("Dilation is not implemented.");
  }
  return ::tensorflow::OkStatus();
}

}  // namespace experimental
}  // namespace xla