/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This is an exploratory prototype emitter for convolution using MLIR.
// This prototype is still under construction.
// TODO(timshen): Fix the documentation once it's implemented.
//
// Goals:
// * Autotune-able tiling.
// * Autotune-able memory accesses.
// * Autotune-able lowering logic (from a portable program to thread-oriented
//   CUDA program).
// * Use mlir::AffineExpr to analyze all accesses. It aims to algorithmically
//   find memory access strategies for given input layouts and tiling configs.

#include "tensorflow/compiler/mlir/xla/experimental/conv_emitter/conv_emitter.h"

#include "absl/types/span.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"  // from @llvm-project
#include "mlir/Dialect/Affine/LoopUtils.h"  // from @llvm-project
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"  // from @llvm-project
#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
#include "mlir/Dialect/MemRef/IR/MemRef.h"  // from @llvm-project
#include "mlir/IR/AffineExpr.h"  // from @llvm-project
#include "mlir/IR/AffineMap.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/Transforms/RegionUtils.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/xla/experimental/conv_emitter/conv_emitter_transforms.h"
#include "tensorflow/compiler/xla/permutation_util.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_type_conversion_util.h"
#include "tensorflow/compiler/xla/window_util.h"

namespace xla {
namespace experimental {
namespace {

using mlir::OpBuilder;

// Information extracted from an input shape.
struct ShapeInfo {
  // Buffer dimensions in NCHW order.
  std::vector<int64_t> nchw_dimensions;

  // Buffer dimensions in major-to-minor order.
  std::vector<int64_t> physical_dimensions;

  // The affine map that takes NCHW indices and maps them to the physical
  // (major-to-minor) order.
  mlir::AffineMap affine_map;

  mlir::Type element_type;
};

ShapeInfo GetShapeInfo(
    const Shape& shape, int64_t n_dim, int64_t c_dim,
    absl::Span<const tensorflow::protobuf_int64> spatial_dims,
    mlir::Builder builder) {
  ShapeInfo shape_info;

  std::vector<int64_t> physical_to_logical(
      shape.layout().minor_to_major().rbegin(),
      shape.layout().minor_to_major().rend());

  std::vector<int64_t> nchw_to_logical;

  nchw_to_logical.push_back(n_dim);
  nchw_to_logical.push_back(c_dim);
  for (int64_t dim : spatial_dims) {
    nchw_to_logical.push_back(dim);
  }

  for (int64_t dim : nchw_to_logical) {
    shape_info.nchw_dimensions.push_back(shape.dimensions(dim));
  }

  for (int64_t dim : physical_to_logical) {
    shape_info.physical_dimensions.push_back(shape.dimensions(dim));
  }

  std::vector<mlir::AffineExpr> affine_exprs;
  // For each physical (major-to-minor) position, record the NCHW dimension
  // that lands there, so that the resulting map takes NCHW indices to
  // physical indices.
  for (int64_t dim : ComposePermutations(InversePermutation(nchw_to_logical),
                                         physical_to_logical)) {
    affine_exprs.push_back(builder.getAffineDimExpr(dim));
  }

  shape_info.affine_map = mlir::AffineMap::get(
      /*dimCount=*/2 + spatial_dims.size(), /*symbolCount=*/0, affine_exprs,
      builder.getContext());

  shape_info.element_type = [&] {
    switch (shape.element_type()) {
      case xla::F16:
        return builder.getF16Type();
      case xla::F32:
        return builder.getF32Type();
      default:
        break;
    }
    CHECK(false);
  }();

  return shape_info;
}

void SetMemRef(mlir::Operation* op, mlir::Value memref) {
  if (auto load = mlir::dyn_cast<mlir::AffineLoadOp>(op)) {
    load.setMemRef(memref);
  } else if (auto store = mlir::dyn_cast<mlir::AffineStoreOp>(op)) {
    store.setMemRef(memref);
  } else {
    CHECK(false);
  }
}

// Hoists operations out of `where`. [begin_op, end_op) must be the first
// operations of their parent loop, and `where` must be an ancestor of that
// parent loop.
//
// The transformation always preserves program semantics; to do so, it may
// modify the hoisted operations or add extra loops at the insertion point.
mlir::Operation* HoistAndFix(llvm::iplist<mlir::Operation>::iterator begin_op,
                             llvm::iplist<mlir::Operation>::iterator end_op,
                             mlir::AffineForOp where) {
  // All loops to hoist through.
  llvm::SmallVector<mlir::AffineForOp, 4> ancestors;
  getPerfectlyNestedLoops(ancestors, where);
  {
    int i;
    for (i = 0; i < ancestors.size(); i++) {
      if (&ancestors[i].getBody()->front() == &*begin_op) {
        break;
      }
    }
    CHECK(i < ancestors.size());
    ancestors.resize(i + 1);
  }

  std::vector<int64_t> ancestor_dimensions;
  for (auto ancestor : ancestors) {
    CHECK(IsSimpleLoop(ancestor));
    ancestor_dimensions.push_back(
        ancestor.getUpperBoundMap().getSingleConstantResult());
  }

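  // Case 1: hoisting an alloc(). Give the hoisted buffer one extra leading
  // dimension per ancestor loop (sized by that loop's trip count), and rewrite
  // every access to the old buffer to prepend the corresponding induction
  // variables as indices.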
  if (auto alloc = mlir::dyn_cast<mlir::memref::AllocOp>(begin_op)) {
    CHECK(std::next(begin_op) == end_op)
        << "alloc() needs to be hoisted on its own";

    OpBuilder builder(where);
    mlir::MemRefType type = alloc.getType();
    CHECK(type.getLayout().isIdentity());
    ancestor_dimensions.insert(ancestor_dimensions.end(),
                               type.getShape().begin(), type.getShape().end());
    mlir::MemRefType new_type =
        mlir::MemRefType::get(ancestor_dimensions, type.getElementType());
    auto new_alloc = builder.create<mlir::memref::AllocOp>(
        builder.getUnknownLoc(), new_type);

    std::vector<mlir::Value> indvars;
    for (auto ancestor : ancestors) {
      indvars.push_back(ancestor.getInductionVar());
    }
    for (auto& use : llvm::make_early_inc_range(alloc.getResult().getUses())) {
      mlir::Operation* owner = use.getOwner();
      BoundAffineMap affine_map = GetBoundAffineMapFrom(owner);
      affine_map.operands.insert(affine_map.operands.begin(), indvars.begin(),
                                 indvars.end());
      CHECK(affine_map.affine_map.isIdentity());
      affine_map.affine_map = mlir::AffineMap::getMultiDimIdentityMap(
          affine_map.operands.size(), builder.getContext());

      mlir::Operation* new_op =
          CloneWithNewAffineMap(owner, affine_map, OpBuilder(owner));
      SetMemRef(new_op, new_alloc);
      owner->replaceAllUsesWith(new_op);
      owner->erase();
    }
    alloc.erase();
    return new_alloc;
  }

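  // Case 2: hoisting loop-variant operations (the range contains an
  // affine.for or an affine.store). Recreate the ancestor loops at the
  // insertion point, move the operations into the new innermost loop, and
  // redirect uses of the old induction variables to the new ones.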
  const bool any_op_is_loop_variant = [&] {
    for (mlir::Operation& op : llvm::make_range(begin_op, end_op)) {
      if (mlir::isa<mlir::AffineForOp, mlir::AffineStoreOp>(op)) {
        return true;
      }
    }
    return false;
  }();

  if (any_op_is_loop_variant) {
    auto builder = OpBuilder(where);
    std::vector<mlir::AffineForOp> new_loops;
    for (auto dim : ancestor_dimensions) {
      auto where =
          builder.create<mlir::AffineForOp>(builder.getUnknownLoc(), 0, dim);
      new_loops.push_back(where);
      builder = OpBuilder::atBlockTerminator(where.getBody());
    }
    for (mlir::Operation& op :
         llvm::make_early_inc_range(llvm::make_range(begin_op, end_op))) {
      op.moveBefore(&new_loops.back().getBody()->back());
    }
    CHECK_EQ(ancestors.size(), new_loops.size());
    for (int i = 0; i < ancestors.size(); i++) {
      replaceAllUsesInRegionWith(ancestors[i].getInductionVar(),
                                 new_loops[i].getInductionVar(),
                                 new_loops.back().getRegion());
    }
    return new_loops.front();
  }
  CHECK(false);
}

mlir::Operation* HoistAndFix(mlir::Operation* op, mlir::AffineForOp where) {
  return HoistAndFix(op->getIterator(), std::next(op->getIterator()), where);
}

struct InitialMlirConvAnchors {
  std::vector<mlir::AffineForOp> cartesian_product_loops;
  std::vector<mlir::AffineForOp> reduction_loops;
  mlir::memref::AllocOp output_acc;
};

// Returns the following IR with the anchors set to the corresponding
// operations.
//   for (cartesian loops...) {
//     %output_acc = alloc() : memref(f32)
//     output_acc[] = 0
//     for (reduction loops...) {
//       output_acc[] += input[...] * filter[...]
//     }
//     output[...] = output_acc[]
//   }
StatusOr<InitialMlirConvAnchors> CreateNaiveMlirConv(
    mlir::Value input, mlir::Value filter, mlir::Value output,
    const ShapeInfo& input_shape_info, const ShapeInfo& filter_shape_info,
    const ShapeInfo& output_shape_info, const Window& window,
    OpBuilder builder) {
  CHECK(input_shape_info.element_type == builder.getF16Type());
  CHECK(filter_shape_info.element_type == builder.getF16Type());
  CHECK(output_shape_info.element_type == builder.getF16Type());

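  // Input, filter, and output buffers are f16, but the accumulation is done
  // in f32: operands are extended on load and the accumulated result is
  // truncated back to f16 on the final store.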
  auto location = mlir::UnknownLoc::get(builder.getContext());

  std::vector<mlir::AffineForOp> cartesian_product_loops =
      CreateNestedSimpleLoops(output_shape_info.nchw_dimensions, builder);

  builder =
      OpBuilder::atBlockTerminator(cartesian_product_loops.back().getBody());

  auto output_acc = builder.create<mlir::memref::AllocOp>(
      location, mlir::MemRefType::get({}, builder.getF32Type()));

  builder.create<mlir::AffineStoreOp>(
      location,
      builder.create<mlir::arith::ConstantOp>(
          location, mlir::FloatAttr::get(builder.getF32Type(), 0)),
      output_acc, llvm::ArrayRef<mlir::Value>());

  std::vector<mlir::AffineForOp> reduction_loops;
  reduction_loops = CreateNestedSimpleLoops(
      absl::MakeSpan(filter_shape_info.nchw_dimensions).subspan(1), builder);

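  // The cartesian loops iterate over the output in NCHW order (batch, output
  // feature, then the output spatial dims); the reduction loops iterate over
  // the input feature dimension and the filter spatial dims.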
  mlir::AffineForOp loop_n = cartesian_product_loops[0];
  mlir::AffineForOp loop_o = cartesian_product_loops[1];
  mlir::AffineForOp loop_c = reduction_loops[0];

  std::vector<mlir::Value> output_spatial_indvars;
  for (auto loop : absl::MakeSpan(cartesian_product_loops).subspan(2)) {
    output_spatial_indvars.push_back(loop.getInductionVar());
  }
  std::vector<mlir::Value> filter_spatial_indvars;
  for (auto loop : absl::MakeSpan(reduction_loops).subspan(1)) {
    filter_spatial_indvars.push_back(loop.getInductionVar());
  }
  int num_spatial_dims = output_spatial_indvars.size();
  CHECK_EQ(num_spatial_dims, filter_spatial_indvars.size());

  builder = OpBuilder::atBlockTerminator(reduction_loops.back().getBody());

  mlir::Value loaded_input = [&] {
    std::vector<mlir::AffineExpr> input_indices;
    input_indices.push_back(builder.getAffineDimExpr(0));
    input_indices.push_back(builder.getAffineDimExpr(1));

    // For each spatial dimension, the input index is
    //   output_index * stride + filter_index - padding_low.
    //
    // TODO(timshen): guard out-of-bound loads and stores brought by padding.
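    // For example, with stride 2 and padding_low 1, output index 3 and
    // filter index 0 read input index 3 * 2 + 0 - 1 = 5.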
    for (int i = 0; i < num_spatial_dims; i++) {
      const WindowDimension& window_dim = window.dimensions(i);
      input_indices.push_back(
          builder.getAffineDimExpr(i + 2) * window_dim.stride() +
          builder.getAffineDimExpr(2 + num_spatial_dims + i) -
          window_dim.padding_low());
    }
    std::vector<mlir::Value> input_vars;
    input_vars.push_back(loop_n.getInductionVar());
    input_vars.push_back(loop_c.getInductionVar());
    input_vars.insert(input_vars.end(), output_spatial_indvars.begin(),
                      output_spatial_indvars.end());
    input_vars.insert(input_vars.end(), filter_spatial_indvars.begin(),
                      filter_spatial_indvars.end());

    return builder.create<mlir::arith::ExtFOp>(
        location, builder.getF32Type(),
        builder.createOrFold<mlir::AffineLoadOp>(
            location, input,
            mlir::AffineMap(input_shape_info.affine_map)
                .compose(mlir::AffineMap::get(
                    /*dimCount=*/2 + num_spatial_dims * 2,
                    /*symbolCount=*/0, input_indices, builder.getContext())),
            input_vars));
  }();

  mlir::Value loaded_filter = [&] {
    std::vector<mlir::Value> filter_vars;
    filter_vars.push_back(loop_o.getInductionVar());
    filter_vars.push_back(loop_c.getInductionVar());
    filter_vars.insert(filter_vars.end(), filter_spatial_indvars.begin(),
                       filter_spatial_indvars.end());

    return builder.create<mlir::arith::ExtFOp>(
        location, builder.getF32Type(),
        builder.createOrFold<mlir::AffineLoadOp>(
            location, filter, filter_shape_info.affine_map, filter_vars));
  }();

  auto accum_load_op =
      builder.createOrFold<mlir::AffineLoadOp>(location, output_acc);
  builder.createOrFold<mlir::AffineStoreOp>(
      location,
      builder.create<mlir::arith::AddFOp>(
          location, accum_load_op,
          builder.create<mlir::arith::MulFOp>(location, loaded_input,
                                              loaded_filter)),
      output_acc, llvm::ArrayRef<mlir::Value>());

  builder.setInsertionPointAfter(reduction_loops[0]);
  {
    std::vector<mlir::Value> output_vars;
    output_vars.push_back(loop_n.getInductionVar());
    output_vars.push_back(loop_o.getInductionVar());
    output_vars.insert(output_vars.end(), output_spatial_indvars.begin(),
                       output_spatial_indvars.end());
    builder.createOrFold<mlir::AffineStoreOp>(
        location,
        builder.create<mlir::arith::TruncFOp>(
            location, builder.getF16Type(),
            builder.createOrFold<mlir::AffineLoadOp>(location, output_acc)),
        output, output_shape_info.affine_map, output_vars);
  }

  return InitialMlirConvAnchors{cartesian_product_loops, reduction_loops,
                                output_acc};
}

// Contains the following pattern with anchors:
//   for (cartesian loops...) {
//     %output_acc = alloc() : memref(..., f32)
//     for (reduction loops...) {
//       for (tiled cartesian loops...) {
//         output_acc[...] = 0
//       }
//       for (tiled cartesian loops...) {
//         for (reduction loops...) {
//           output_acc[] += input[...] * filter[...]
//         }
//       }
//       for (tiled cartesian loops...) {
//         output[...] = output_acc[...]
//       }
//     }
//   }
struct TransformedMlirConvAnchors {
  std::vector<mlir::AffineForOp> cartesian_product_loops;
  std::vector<mlir::AffineForOp> reduction_loops;
};

StatusOr<TransformedMlirConvAnchors> TransformMlirConv(
    InitialMlirConvAnchors anchors) {
  std::vector<mlir::AffineForOp> cartesian_product_loops =
      anchors.cartesian_product_loops;
  std::vector<mlir::AffineForOp> reduction_loops = anchors.reduction_loops;
  mlir::memref::AllocOp output_acc = anchors.output_acc;

  // TODO(timshen): consider using pattern matchers for transformations
  //
  // Initial form:
  //   for (cartesian loops...) {
  //     %output_acc = alloc() : memref(f32)
  //     output_acc[] = 0
  //     for (reduction loops...) {
  //       output_acc[] += input[...] * filter[...]
  //     }
  //     output[...] = output_acc[]
  //   }

  // Tile cartesian loops to:
  //   for (cartesian loops...) {
  //     for (tiled cartesian loops...) {
  //       %output_acc = alloc() : memref(f32)
  //       output_acc[] = 0
  //       for (reduction loops...) {
  //         output_acc[] += input[...] * filter[...]
  //       }
  //       output[...] = output_acc[]
  //     }
  //   }
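  // Tile the reduction (input feature) loop by 4 before tiling the cartesian
  // loops; the size-4 point loop produced by TileLoop is sunk into
  // reduction_loops.back() (see TileLoop in conv_emitter_transforms.h), making
  // it the innermost reduction loop.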
  TileLoop(reduction_loops[0], 4, reduction_loops.back());

  std::vector<mlir::AffineForOp> tiled_cartesian_loops;
  tiled_cartesian_loops.push_back(
      TileLoop(cartesian_product_loops[1], 32, cartesian_product_loops.back()));

  tiled_cartesian_loops.push_back(TileLoop(cartesian_product_loops.back(), 16,
                                           tiled_cartesian_loops.back()));

  // Two hoist operations to interleave the allocation, computation, and
  // writebacks to output_acc:
  // After first hoist:
  //   for (cartesian loops...) {
  //     %output_acc = alloc() : memref(..., f32)
  //     for (tiled cartesian loops...) {
  //       output_acc[...] = 0
  //       for (reduction loops...) {
  //         output_acc[...] += input[...] * filter[...]
  //       }
  //       output[...] = output_acc[...]
  //     }
  //   }
  output_acc = llvm::cast<mlir::memref::AllocOp>(
      HoistAndFix(output_acc, tiled_cartesian_loops.front()));

  // Hoist everything before the reduction loops (i.e., the zero initialization
  // of output_acc):
  //   for (cartesian loops...) {
  //     %output_acc = alloc() : memref(..., f32)
  //     for (tiled cartesian loops...) {
  //       output_acc[...] = 0
  //     }
  //     for (tiled cartesian loops...) {
  //       for (reduction loops...) {
  //         output_acc[...] += input[...] * filter[...]
  //       }
  //       output[...] = output_acc[...]
  //     }
  //   }
  HoistAndFix(tiled_cartesian_loops.back().getBody()->begin(),
              reduction_loops.front().getOperation()->getIterator(),
              tiled_cartesian_loops.front());

  // Now hoist all reduction loops outside of the tiled cartesian loops.
  // Notice that HoistAndFix automatically adds a new set of tiled cartesian
  // loops for the hoisted reduction loops to keep the semantics correct.
  //
  // After the second hoist:
  //   for (cartesian loops...) {
  //     %output_acc = alloc() : memref(..., f32)
  //     for (tiled cartesian loops...) {
  //       output_acc[...] = 0
  //     }
  //     for (tiled cartesian loops...) {
  //       for (reduction loops...) {
  //         output_acc[] += input[...] * filter[...]
  //       }
  //     }  // compute loop
  //     for (tiled cartesian loops...) {
  //       output[...] = output_acc[...]
  //     }
  //   }
  {
    auto compute_loop = llvm::cast<mlir::AffineForOp>(
        HoistAndFix(reduction_loops.front(), tiled_cartesian_loops[0]));

    // Update tiled_cartesian_loops to point at the tiled compute loops, not
    // the writeback loops to the output buffer.
    llvm::SmallVector<mlir::AffineForOp, 4> all_loops;
    getPerfectlyNestedLoops(all_loops, compute_loop);
    absl::c_copy_n(all_loops, tiled_cartesian_loops.size(),
                   tiled_cartesian_loops.data());
  }

  // After exchanging the tiled cartesian compute loops with the reduction
  // loops:
  //   for (cartesian loops...) {
  //     %output_acc = alloc() : memref(..., f32)
  //     for (tiled cartesian loops...) {
  //       output_acc[...] = 0
  //     }
  //     for (reduction loops...) {
  //       for (tiled cartesian loops...) {
  //         output_acc[] += input[...] * filter[...]
  //       }
  //     }
  //     for (tiled cartesian loops...) {
  //       output[...] = output_acc[...]
  //     }
  //   }
  //
  // ...so that the later tiled cartesian loops (with computations in them) can
  // be replaced by CUDA MMA instructions.
  {
    std::vector<mlir::AffineForOp> loops;
    loops.insert(loops.end(), tiled_cartesian_loops.begin(),
                 tiled_cartesian_loops.end());
    loops.insert(loops.end(), reduction_loops.begin(), reduction_loops.end());
    SinkPerfectlyNestedLoops(loops, tiled_cartesian_loops.size());
  }
  return TransformedMlirConvAnchors{cartesian_product_loops, reduction_loops};
}

}  // namespace

StatusOr<mlir::func::FuncOp> EmitConvolutionForwardAsMlir(
    HloInstruction* conv, absl::string_view function_name,
    mlir::MLIRContext* context) {
  OpBuilder builder(context);

  const auto& dim_nums = conv->convolution_dimension_numbers();
  ShapeInfo input_shape_info =
      GetShapeInfo(conv->operand(0)->shape(), dim_nums.input_batch_dimension(),
                   dim_nums.input_feature_dimension(),
                   dim_nums.input_spatial_dimensions(), builder);

  ShapeInfo filter_shape_info = GetShapeInfo(
      conv->operand(1)->shape(), dim_nums.kernel_output_feature_dimension(),
      dim_nums.kernel_input_feature_dimension(),
      dim_nums.kernel_spatial_dimensions(), builder);

  ShapeInfo output_shape_info = GetShapeInfo(
      conv->shape().tuple_shapes(0), dim_nums.output_batch_dimension(),
      dim_nums.output_feature_dimension(), dim_nums.output_spatial_dimensions(),
      builder);

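  // The generated function takes (output, input, filter) buffers in their
  // physical (major-to-minor) shapes with the default identity layout; the
  // NCHW-to-physical index mapping happens through each ShapeInfo's affine_map
  // at load/store time.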
  auto function = mlir::func::FuncOp::create(
      mlir::UnknownLoc::get(builder.getContext()),
      llvm_ir::AsStringRef(function_name),
      builder.getFunctionType(
          {mlir::MemRefType::get(output_shape_info.physical_dimensions,
                                 output_shape_info.element_type,
                                 mlir::AffineMap()),
           mlir::MemRefType::get(input_shape_info.physical_dimensions,
                                 input_shape_info.element_type,
                                 mlir::AffineMap()),
           mlir::MemRefType::get(filter_shape_info.physical_dimensions,
                                 filter_shape_info.element_type,
                                 mlir::AffineMap())},
          {}));

  auto* entry_block = function.addEntryBlock();
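  // Emit the function terminator first, then move the insertion point back to
  // the start of the block so the convolution body is generated before the
  // return.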
  builder.setInsertionPointToStart(entry_block);
  builder.create<mlir::func::ReturnOp>(builder.getUnknownLoc());
  builder.setInsertionPointToStart(entry_block);

  mlir::Value input = entry_block->getArgument(1);
  mlir::Value filter = entry_block->getArgument(2);
  mlir::Value output = entry_block->getArgument(0);

  TF_RETURN_IF_ERROR(ConvIsImplemented(conv));

  TF_ASSIGN_OR_RETURN(
      InitialMlirConvAnchors initial_anchors,
      CreateNaiveMlirConv(input, filter, output, input_shape_info,
                          filter_shape_info, output_shape_info, conv->window(),
                          builder));

  TF_ASSIGN_OR_RETURN(TransformedMlirConvAnchors transformed_anchors,
                      TransformMlirConv(initial_anchors));

  // TODO(timshen): Implement a transformation that collects loads to a given
  // buffer, creates a local alloc() for the accessed part, redirects all loads
  // and stores to that local alloc(), and creates code to initialize /
  // write back the local alloc() if needed.

  // TODO(timshen): Implement CUDA-specific lowering.

  return function;
}

Status ConvIsImplemented(const HloInstruction* conv) {
  if (conv->feature_group_count() != 1 || conv->batch_group_count() != 1) {
    return Unimplemented("Group count is not implemented.");
  }
  if (window_util::HasWindowReversal(conv->window())) {
    return Unimplemented("Window reversal is not implemented.");
  }
  if (window_util::HasDilation(conv->window())) {
    return Unimplemented("Dilation is not implemented.");
  }
  return ::tensorflow::OkStatus();
}

}  // namespace experimental
}  // namespace xla