# Convolution Emitter

## Context

This document describes a set of patches that are still under review.
TODO(timshen): Change once all patches are checked in.

The convolution emitter is a prototype with the following goals:

*   The top priority is performance.
*   It supports arbitrarily sophisticated layouts.
*   It supports platform-specific high-performance instructions.
*   It is as portable as possible.
*   It enables fusion support in the future.


## Current Design

### Overview

The prototype consists of the following components:

*   An MLIR-based emitter, currently focused on the NVIDIA Volta architecture
    and the N(C/4)HW4 layout. It takes a set of tuning parameters and a
    convolution configuration, then produces an NVVM device function.
*   An autotuner, which generates tuning parameters given a convolution
    configuration.
*   A test framework, which executes the generated device function with random
    inputs and compares the result against cuDNN.

### The Emitter - Naive Implementation

The emitter starts with a hand-built, naive implementation of the following
ResNet first-layer convolution (pseudocode):

```mlir
func @Conv(%input : memref<128x1x224x224xvector<4xf16>>,
           %filter : memref<64x1x7x7xvector<4xf16>>,
           %output : memref<128x64x112x112xf16>) {
  affine.parallel (%n, %o, %oh, %ow) = 0 to 128, 0 to 64, 0 to 112, 0 to 112 {
    %acc = alloc() : memref<f32>
    affine.store 0, %acc[]
    affine.for (%c, %fh, %fw) = 0 to 1, 0 to 7, 0 to 7 {
      %a = affine.padded.load %input[%n, %c, %oh * 2 + %fh - 3, %ow * 2 + %fw - 3]
      %b = affine.load %filter[%o, %c, %fh, %fw]
      %old = affine.load %acc[]
      %d = std.fpext %a to vector<4xf32>
      %e = std.fpext %b to vector<4xf32>
      %f = std.multiply %d, %e
      %g = "reduce" %f
      %new = %g + %old
      affine.store %new, %acc[]
    }
    %v = affine.load %acc[]
    affine.store %v, %output[%n, %o, %oh, %ow]
  }
}
```

A few extensions are used in the example above:

*   `affine.padded.load` allows out-of-bounds access, in which case the result
    is always 0.
*   The "reduce" operation produces the sum of all elements in a vector.
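
Both extensions are pseudo-ops. The following Python sketch pins down their
intended semantics (illustrative only; it models a memref of vector<4xf16> as
an array with a trailing length-4 axis, and only the spatial dimensions are
bounds-checked, since that is where padding occurs):

```python
import numpy as np

def padded_load(buf, n, c, h, w):
    """affine.padded.load: out-of-bounds accesses yield an all-zero element."""
    if 0 <= h < buf.shape[2] and 0 <= w < buf.shape[3]:
        return buf[n, c, h, w]  # one vector<4xf16> element
    return np.zeros(buf.shape[-1], dtype=buf.dtype)

def reduce(vec):
    """The "reduce" op: horizontal sum of the vector's elements."""
    return vec.sum()
```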

Also notice that the input element type is vector<4xf16> only because the
current implementation happens to use it. A MemRef of <...x4xf16> should work
equally well, provided its alignment is at least 8 bytes (usually 16).
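
For concreteness, the naive implementation corresponds to the following
executable NumPy reference (a sketch for exposition only; the stride-2,
padding-3 parameters come from the index expressions above):

```python
import numpy as np

def conv_reference(inp, filt):
    """Naive ResNet first-layer convolution: stride 2, padding 3.

    inp:  (128, 1, 224, 224, 4) f16 -- N(C/4)HW4; last axis models vector<4xf16>
    filt: (64, 1, 7, 7, 4) f16
    out:  (128, 64, 112, 112) f16
    """
    N, C4, H, W, _ = inp.shape
    O, _, FH, FW, _ = filt.shape
    OH, OW = (H + 2 * 3 - FH) // 2 + 1, (W + 2 * 3 - FW) // 2 + 1  # 112, 112
    out = np.zeros((N, O, OH, OW), dtype=np.float16)
    for n, o, oh, ow in np.ndindex(N, O, OH, OW):
        acc = np.float32(0)
        for c, fh, fw in np.ndindex(C4, FH, FW):
            ih, iw = oh * 2 + fh - 3, ow * 2 + fw - 3
            if 0 <= ih < H and 0 <= iw < W:  # affine.padded.load semantics
                a = inp[n, c, ih, iw].astype(np.float32)   # std.fpext
                b = filt[o, c, fh, fw].astype(np.float32)  # std.fpext
                acc += (a * b).sum()                       # multiply + "reduce"
        out[n, o, oh, ow] = np.float16(acc)
    return out
```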

The emitter then applies a series of semantics-preserving transformations to
shape the code towards PTX's structure.

### The Emitter - Tiling

The following is the naive code after loop tiling:

```mlir
func @Conv(%input : memref<128x1x224x224xvector<4xf16>>,
           %filter : memref<64x1x7x7xvector<4xf16>>,
           %output : memref<128x64x112x112xf16>) {
  affine.parallel (%n0, %o0, %oh0, %ow0) = 0 to 128, 0 to 1, 0 to 7, 0 to 7 {
    affine.parallel (%n1, %o1, %oh1, %ow1) = 0 to 1, 0 to 64, 0 to 16, 0 to 16 {
      %acc = alloc() : memref<f32>
      affine.store 0, %acc[]
      affine.for (%c0, %fh0, %fw0) = 0 to 1, 0 to 1, 0 to 1 {
        affine.for (%c1, %fh1, %fw1) = 0 to 1, 0 to 7, 0 to 7 {
          %a = affine.padded.load %input[
              %n0 * 1 + %n1,
              %c0 * 1 + %c1,
              (%oh0 * 16 + %oh1) * 2 + %fh0 * 7 + %fh1 - 3,
              (%ow0 * 16 + %ow1) * 2 + %fw0 * 7 + %fw1 - 3]
          %b = affine.load %filter[
              %o0 * 64 + %o1,
              %c0 * 1 + %c1,
              %fh0 * 7 + %fh1,
              %fw0 * 7 + %fw1]
          %old = affine.load %acc[]
          %d = std.fpext %a to vector<4xf32>
          %e = std.fpext %b to vector<4xf32>
          %f = std.multiply %d, %e
          %g = "reduce" %f
          %new = %g + %old
          affine.store %new, %acc[]
        }
      }
      %v = affine.load %acc[]
      affine.store %v, %output[
          %n0 * 1 + %n1,
          %o0 * 64 + %o1,
          %oh0 * 16 + %oh1,
          %ow0 * 16 + %ow1]
    } { ptx_block }
  } { ptx_grid }
}
```

The motivation is straightforward: we need to decide which loops are
parallelized over which compute units of the PTX architecture. The `ptx_grid`
and `ptx_block` attributes direct that a loop be parallelized over a grid or a
block, respectively.

Also notice that, to keep the code pattern clean, tiling is implemented in a
restricted form. Defining a "simple loop" as a loop with lower bound 0 and
step 1, the tiling:

*   only takes simple loops.
*   only produces simple loops.
*   generates no extra operations; all altered index calculations are folded
    into each user's AffineMap (sketched below).

The contracting dimensions (%c, %fh, %fw) are also tiled once. The significance
of this will become clear in the shared memory promotion step.
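
A small Python sketch of this discipline (illustrative only): both resulting
loops are simple, and the original index is reconstructed purely in the index
arithmetic at each use site.

```python
def tiled(extent0, extent1):
    """Tile a simple loop over [0, extent0 * extent1) into two simple loops."""
    for i0 in range(extent0):      # simple loop: lower bound 0, step 1
        for i1 in range(extent1):  # simple loop: lower bound 0, step 1
            # No new operation is emitted for this; the expression is folded
            # into the AffineMap of each user, e.g. %oh0 * 16 + %oh1.
            yield i0 * extent1 + i1

# %oh = %oh0 * 16 + %oh1, with %oh0 in [0, 7) and %oh1 in [0, 16):
assert list(tiled(7, 16)) == list(range(112))
```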

### The Emitter - Splitting

This step splits the body of the (%n1, %o1, %oh1, %ow1) loop into several
parts:

*   The code that sets the accumulators to 0.
*   The actual convolution computation code.
*   The code that writes the accumulators back to the %output buffer.

This transformation also "vectorizes" the accumulator, from memref<f32> to
memref<1x64x16x16xf32>, as the `alloc()` gets hoisted out of the
`affine.parallel` op.

After splitting:

```mlir
func @Conv(%input : memref<128x1x224x224xvector<4xf16>>,
           %filter : memref<64x1x7x7xvector<4xf16>>,
           %output : memref<128x64x112x112xf16>) {
  affine.parallel (%n0, %o0, %oh0, %ow0) = 0 to 128, 0 to 1, 0 to 7, 0 to 7 {
    %acc = alloc() : memref<1x64x16x16xf32>
    affine.parallel (%n1, %o1, %oh1, %ow1) = 0 to 1, 0 to 64, 0 to 16, 0 to 16 {
      affine.store 0, %acc[%n1, %o1, %oh1, %ow1]
    } { ptx_block }
    affine.parallel (%n1, %o1, %oh1, %ow1) = 0 to 1, 0 to 64, 0 to 16, 0 to 16 {
      affine.for (%c0, %fh0, %fw0) = 0 to 1, 0 to 1, 0 to 1 {
        affine.for (%c1, %fh1, %fw1) = 0 to 1, 0 to 7, 0 to 7 {
          %a = affine.padded.load %input[
              %n0 * 1 + %n1,
              %c0 * 1 + %c1,
              (%oh0 * 16 + %oh1) * 2 + %fh0 * 7 + %fh1 - 3,
              (%ow0 * 16 + %ow1) * 2 + %fw0 * 7 + %fw1 - 3]
          %b = affine.load %filter[
              %o0 * 64 + %o1,
              %c0 * 1 + %c1,
              %fh0 * 7 + %fh1,
              %fw0 * 7 + %fw1]
          %old = affine.load %acc[%n1, %o1, %oh1, %ow1]
          %d = std.fpext %a to vector<4xf32>
          %e = std.fpext %b to vector<4xf32>
          %f = std.multiply %d, %e
          %g = "reduce" %f
          %new = %g + %old
          affine.store %new, %acc[%n1, %o1, %oh1, %ow1]
        }
      }
    } { ptx_block }
    affine.parallel (%n1, %o1, %oh1, %ow1) = 0 to 1, 0 to 64, 0 to 16, 0 to 16 {
      %v = affine.load %acc[%n1, %o1, %oh1, %ow1]
      affine.store %v, %output[
          %n0 * 1 + %n1,
          %o0 * 64 + %o1,
          %oh0 * 16 + %oh1,
          %ow0 * 16 + %ow1]
    } { ptx_block }
  } { ptx_grid }
}
```

To prepare for the next transformations, we'd also like to sink the (%n1, %o1,
%oh1, %ow1) `affine.parallel` loop below the (%c0, %fh0, %fw0) and (%c1, %fh1,
%fw1) loops, since the reduction loops themselves are not interesting to
parallelize:

```
affine.parallel (%n1, %o1, %oh1, %ow1) = 0 to 1, 0 to 64, 0 to 16, 0 to 16 {
  affine.for (%c0, %fh0, %fw0) = 0 to 1, 0 to 1, 0 to 1 {
    affine.for (%c1, %fh1, %fw1) = 0 to 1, 0 to 7, 0 to 7 {
      ...
    }
  }
} { ptx_block }

=>

affine.for (%c0, %fh0, %fw0) = 0 to 1, 0 to 1, 0 to 1 {
  affine.for (%c1, %fh1, %fw1) = 0 to 1, 0 to 7, 0 to 7 {
    affine.parallel (%n1, %o1, %oh1, %ow1) = 0 to 1, 0 to 64, 0 to 16, 0 to 16 {
      ...
    } { ptx_block }
  }
}
```
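
The sinking is legal because each (%n1, %o1, %oh1, %ow1) point owns its own
accumulator slot, so the parallel loop commutes with the reduction loops. A
quick Python check of that reordering (illustrative; `contribution` stands in
for one multiply-accumulate):

```python
import itertools

points = list(itertools.product(range(1), range(64), range(16), range(16)))
steps = list(itertools.product(range(1), range(7), range(7)))

def contribution(p, s):  # stand-in for one multiply-accumulate
    return (p[1] + s[1]) * (p[2] + s[2] + 1)

# Parallel loop outermost, reduction loops inside (before sinking)...
before = {p: sum(contribution(p, s) for s in steps) for p in points}
# ...versus reduction loops outermost, parallel loop innermost (after sinking).
after = {p: 0 for p in points}
for s in steps:
    for p in points:
        after[p] += contribution(p, s)
assert before == after
```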

### The Emitter - Shared Memory Promotion

This transformation is done by `affineDataCopyGenerate`, which precisely
computes how much memory each load operation transfers.

After calculating the sizes of the shared memory buffers (`%promoted_input` and
`%promoted_filter`), the transformation also creates loads and stores that
prefetch data from global memory (`%input`, `%filter`) into the promoted shared
memory.

```mlir
// Before
affine.for (%c1, %fh1, %fw1) = 0 to 1, 0 to 7, 0 to 7 {
  affine.parallel (%n1, %o1, %oh1, %ow1) = 0 to 1, 0 to 64, 0 to 16, 0 to 16 {
    %a = affine.padded.load %input[
        %n0 * 1 + %n1,
        %c0 * 1 + %c1,
        (%oh0 * 16 + %oh1) * 2 + %fh0 * 7 + %fh1 - 3,
        (%ow0 * 16 + %ow1) * 2 + %fw0 * 7 + %fw1 - 3]
    %b = affine.load %filter[
        %o0 * 64 + %o1,
        %c0 * 1 + %c1,
        %fh0 * 7 + %fh1,
        %fw0 * 7 + %fw1]
    %old = affine.load %acc[%n1, %o1, %oh1, %ow1]
    %d = std.fpext %a to vector<4xf32>
    %e = std.fpext %b to vector<4xf32>
    %f = std.multiply %d, %e
    %g = "reduce" %f
    %new = %g + %old
    affine.store %new, %acc[%n1, %o1, %oh1, %ow1]
  } { ptx_block }
}
```

```mlir
// After

%promoted_input = alloc() : memref<1x1x37x37xvector<4xf16>, memory_space = 3>
%promoted_filter = alloc() : memref<64x1x7x7xvector<4xf16>, memory_space = 3>
affine.parallel (%i0, %i1, %i2, %i3) = 0 to 1, 0 to 1, 0 to 37, 0 to 37 {
  %v = affine.padded.load %input[
      %n0 * 1 + %i0,
      %c0 * 1 + %i1,
      (%oh0 * 16) * 2 + %fh0 * 7 + %i2 - 3,
      (%ow0 * 16) * 2 + %fw0 * 7 + %i3 - 3]
  affine.store %v, %promoted_input[%i0, %i1, %i2, %i3]
} { ptx_block }
affine.parallel (%i0, %i1, %i2, %i3) = 0 to 64, 0 to 1, 0 to 7, 0 to 7 {
  %v = affine.load %filter[
      %o0 * 64 + %i0,
      %c0 * 1 + %i1,
      %fh0 * 7 + %i2,
      %fw0 * 7 + %i3]
  affine.store %v, %promoted_filter[%i0, %i1, %i2, %i3]
} { ptx_block }
affine.for (%c1, %fh1, %fw1) = 0 to 1, 0 to 7, 0 to 7 {
  affine.parallel (%n1, %o1, %oh1, %ow1) = 0 to 1, 0 to 64, 0 to 16, 0 to 16 {
    %a = affine.load %promoted_input[%n1, %c1, %oh1 * 2 + %fh1, %ow1 * 2 + %fw1]
    %b = affine.load %promoted_filter[%o1, %c1, %fh1, %fw1]
    %old = affine.load %acc[%n1, %o1, %oh1, %ow1]
    %d = std.fpext %a to vector<4xf32>
    %e = std.fpext %b to vector<4xf32>
    %f = std.multiply %d, %e
    %g = "reduce" %f
    %new = %g + %old
    affine.store %new, %acc[%n1, %o1, %oh1, %ow1]
  } { ptx_block }
}
```
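
The 37-element spatial extents of `%promoted_input` fall out of exactly this
footprint calculation: within one (%c0, %fh0, %fw0) step, the promoted loads
index the height dimension with %oh1 * 2 + %fh1 (and the width likewise). A
Python sketch of the range analysis (`affineDataCopyGenerate` performs the
equivalent computation symbolically):

```python
import itertools

def footprint(extents, index_fn):
    """Extent of an affine index expression over simple-loop ranges."""
    values = [index_fn(*ivs)
              for ivs in itertools.product(*(range(e) for e in extents))]
    return min(values), max(values), max(values) - min(values) + 1

# %oh1 in [0, 16), %fh1 in [0, 7): the index %oh1 * 2 + %fh1 spans [0, 36],
# hence 37 elements along each promoted spatial dimension.
assert footprint((16, 7), lambda oh1, fh1: oh1 * 2 + fh1) == (0, 36, 37)
```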

### The Emitter - Volta MMA Instruction

This transformation turns the inner loop:

```mlir
affine.parallel (%n1, %o1, %oh1, %ow1) = 0 to 1, 0 to 64, 0 to 16, 0 to 16 {
  %a = affine.load %promoted_input[%n1, %c1, %oh1 * 2 + %fh1, %ow1 * 2 + %fw1]
  %b = affine.load %promoted_filter[%o1, %c1, %fh1, %fw1]
  %old = affine.load %acc[%n1, %o1, %oh1, %ow1]
  %d = std.fpext %a to vector<4xf32>
  %e = std.fpext %b to vector<4xf32>
  %f = std.multiply %d, %e
  %g = "reduce" %f
  %new = %g + %old
  affine.store %new, %acc[%n1, %o1, %oh1, %ow1]
} { ptx_block }
```

into multiple Volta mma.sync instructions. The result is not shown here,
because the prototype currently hand-crafts this step just enough to meet the
benchmarking goals.

### The Autotuner

As shown above, many parameters dictate how the naive implementation is
transformed. For now, the parameters are all tile sizes. On top of the emitter,
the prototype includes a simple autotuner that enumerates all promising
combinations of tile sizes and invokes the emitter with each combination. With
the assistance of in-process benchmarking, the autotuner picks the best set of
parameters.
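
A minimal Python sketch of that loop, assuming a hypothetical
`emit_and_benchmark` hook that invokes the emitter with the given tile sizes
and times the generated kernel in-process:

```python
import itertools

def autotune(emit_and_benchmark, conv_config, tile_candidates):
    """Enumerate tile-size combinations and keep the fastest one."""
    best_tiles, best_time = None, float("inf")
    for combo in itertools.product(*tile_candidates.values()):
        tiles = dict(zip(tile_candidates, combo))
        seconds = emit_and_benchmark(conv_config, tiles)
        if seconds < best_time:
            best_tiles, best_time = tiles, seconds
    return best_tiles

# For example: tile_candidates = {"%o1": [32, 64], "%oh1": [8, 16], "%ow1": [8, 16]}
```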

## Future Improvements

*   Explore Linalg/Vector for a higher-level naive implementation. MMA
    instruction handling would be much easier with high-level functional
    constructs.
*   Explore other layouts. The current layout corresponds to NVIDIA's
    `CUDNN_TENSOR_NCHW_VECT_C`, but for fp16.
*   Iron out GPU-dialect-related lowering. Annotations like `ptx_grid` and
    `ptx_block` should be generalized to more architectures.
*   Speed up autotuning through more pruning.
*   Support dynamic shapes.