/*
 * Copyright © 2018 Intel Corporation
 *
 * Permission is hereby granted, free of charge, to any person obtaining a
 * copy of this software and associated documentation files (the "Software"),
 * to deal in the Software without restriction, including without limitation
 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
 * and/or sell copies of the Software, and to permit persons to whom the
 * Software is furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice (including the next
 * paragraph) shall be included in all copies or substantial portions of the
 * Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
 * IN THE SOFTWARE.
 */

#include "nir.h"
#include "nir_builder.h"
#include "nir_deref.h"

struct match_node {
   /* Note: these fields are only valid for leaf nodes */

   unsigned next_array_idx;
   int src_wildcard_idx;
   nir_deref_path first_src_path;

   /* The index of the first read of the source path that's part of the copy
    * we're matching. If the last write to the source path is after this, we
    * would get a different result from reading it at the end and we can't
    * emit the copy.
    */
   unsigned first_src_read;

   /* The last time there was a write to this node. */
   unsigned last_overwritten;

   /* The last time there was a write to this node which successfully advanced
    * next_array_idx. This helps us catch any intervening aliased writes.
    */
   unsigned last_successful_write;

   unsigned num_children;
   struct match_node *children[];
};

struct match_state {
   /* Map from nir_variable * -> match_node */
   struct hash_table *var_nodes;
   /* Map from cast nir_deref_instr * -> match_node */
   struct hash_table *cast_nodes;

   unsigned cur_instr;

   nir_builder builder;

   void *dead_ctx;
};

static struct match_node *
create_match_node(const struct glsl_type *type, struct match_state *state)
{
   unsigned num_children = 0;
   if (glsl_type_is_array_or_matrix(type)) {
      /* One for wildcards */
      num_children = glsl_get_length(type) + 1;
   } else if (glsl_type_is_struct_or_ifc(type)) {
      num_children = glsl_get_length(type);
   }

   struct match_node *node = rzalloc_size(state->dead_ctx,
                                          sizeof(struct match_node) +
                                          num_children * sizeof(struct match_node *));
   node->num_children = num_children;
   node->src_wildcard_idx = -1;
   node->first_src_read = UINT32_MAX;
   return node;
}

static struct match_node *
node_for_deref(nir_deref_instr *instr, struct match_node *parent,
               struct match_state *state)
{
   unsigned idx;
   switch (instr->deref_type) {
   case nir_deref_type_var: {
      struct hash_entry *entry =
         _mesa_hash_table_search(state->var_nodes, instr->var);
      if (entry) {
         return entry->data;
      } else {
         struct match_node *node = create_match_node(instr->type, state);
         _mesa_hash_table_insert(state->var_nodes, instr->var, node);
         return node;
      }
   }

   case nir_deref_type_cast: {
      struct hash_entry *entry =
         _mesa_hash_table_search(state->cast_nodes, instr);
      if (entry) {
         return entry->data;
      } else {
         struct match_node *node = create_match_node(instr->type, state);
         _mesa_hash_table_insert(state->cast_nodes, instr, node);
         return node;
      }
   }

   case nir_deref_type_array_wildcard:
      idx = parent->num_children - 1;
      break;

   case nir_deref_type_array:
      if (nir_src_is_const(instr->arr.index)) {
         idx = nir_src_as_uint(instr->arr.index);
         assert(idx < parent->num_children - 1);
      } else {
         idx = parent->num_children - 1;
      }
      break;

   case nir_deref_type_struct:
      idx = instr->strct.index;
      break;

   default:
      unreachable("bad deref type");
   }

   assert(idx < parent->num_children);
   if (parent->children[idx]) {
      return parent->children[idx];
   } else {
      struct match_node *node = create_match_node(instr->type, state);
      parent->children[idx] = node;
      return node;
   }
}

static struct match_node *
node_for_wildcard(const struct glsl_type *type, struct match_node *parent,
                  struct match_state *state)
{
   assert(glsl_type_is_array_or_matrix(type));
   unsigned idx = glsl_get_length(type);

   if (parent->children[idx]) {
      return parent->children[idx];
   } else {
      struct match_node *node =
         create_match_node(glsl_get_array_element(type), state);
      parent->children[idx] = node;
      return node;
   }
}

static struct match_node *
node_for_path(nir_deref_path *path, struct match_state *state)
{
   struct match_node *node = NULL;
   for (nir_deref_instr **instr = path->path; *instr; instr++)
      node = node_for_deref(*instr, node, state);

   return node;
}

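/* Like node_for_path(), but the path entry at wildcard_idx is looked up as
 * the wildcard child of its parent array type instead of by its actual
 * index. For example (an illustrative deref, not taken from the code), given
 * the path for a[3].f with wildcard_idx pointing at the a[3] entry, this
 * returns (creating it if necessary) the node for a[*].f.
 */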
static struct match_node *
node_for_path_with_wildcard(nir_deref_path *path, unsigned wildcard_idx,
                            struct match_state *state)
{
   struct match_node *node = NULL;
   unsigned idx = 0;
   for (nir_deref_instr **instr = path->path; *instr; instr++, idx++) {
      if (idx == wildcard_idx)
         node = node_for_wildcard((*(instr - 1))->type, node, state);
      else
         node = node_for_deref(*instr, node, state);
   }

   return node;
}

typedef void (*match_cb)(struct match_node *, struct match_state *);

static void
_foreach_child(match_cb cb, struct match_node *node, struct match_state *state)
{
   if (node->num_children == 0) {
      cb(node, state);
   } else {
      for (unsigned i = 0; i < node->num_children; i++) {
         if (node->children[i])
            _foreach_child(cb, node->children[i], state);
      }
   }
}

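/* Walk the match tree along the tail of a deref path, calling cb on every
 * node the access could touch: the node where the path ends, every child of
 * an array node when the index is a wildcard or non-constant, and every
 * leaf below the current node when a cast is encountered.
 */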
static void
_foreach_aliasing(nir_deref_instr **deref, match_cb cb,
                  struct match_node *node, struct match_state *state)
{
   if (*deref == NULL) {
      cb(node, state);
      return;
   }

   switch ((*deref)->deref_type) {
   case nir_deref_type_struct: {
      struct match_node *child = node->children[(*deref)->strct.index];
      if (child)
         _foreach_aliasing(deref + 1, cb, child, state);
      return;
   }

   case nir_deref_type_array:
   case nir_deref_type_array_wildcard: {
      if ((*deref)->deref_type == nir_deref_type_array_wildcard ||
          !nir_src_is_const((*deref)->arr.index)) {
         /* This access may touch any index, so we have to visit all of
          * them.
          */
         for (unsigned i = 0; i < node->num_children; i++) {
            if (node->children[i])
               _foreach_aliasing(deref + 1, cb, node->children[i], state);
         }
      } else {
         /* Visit the wildcard entry if any */
         if (node->children[node->num_children - 1]) {
            _foreach_aliasing(deref + 1, cb,
                              node->children[node->num_children - 1], state);
         }

         unsigned index = nir_src_as_uint((*deref)->arr.index);
         /* Check that the index is in-bounds */
         if (index < node->num_children - 1 && node->children[index])
            _foreach_aliasing(deref + 1, cb, node->children[index], state);
      }
      return;
   }

   case nir_deref_type_cast:
      _foreach_child(cb, node, state);
      return;

   default:
      unreachable("bad deref type");
   }
}

/* Given a deref path, find all the leaf deref nodes that alias it. */

static void
foreach_aliasing_node(nir_deref_path *path,
                      match_cb cb,
                      struct match_state *state)
{
   if (path->path[0]->deref_type == nir_deref_type_var) {
      struct hash_entry *entry = _mesa_hash_table_search(state->var_nodes,
                                                         path->path[0]->var);
      if (entry)
         _foreach_aliasing(&path->path[1], cb, entry->data, state);

      hash_table_foreach(state->cast_nodes, entry)
         _foreach_child(cb, entry->data, state);
   } else {
      /* Casts automatically alias anything that isn't a cast */
      assert(path->path[0]->deref_type == nir_deref_type_cast);
      hash_table_foreach(state->var_nodes, entry)
         _foreach_child(cb, entry->data, state);

      /* Casts alias other casts if the casts are different or if they're the
       * same and the path from the cast may alias as per the usual rules.
       */
      hash_table_foreach(state->cast_nodes, entry) {
         const nir_deref_instr *cast = entry->key;
         assert(cast->deref_type == nir_deref_type_cast);
         if (cast == path->path[0])
            _foreach_aliasing(&path->path[1], cb, entry->data, state);
         else
            _foreach_child(cb, entry->data, state);
      }
   }
}

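/* Rebuild the tail of a deref path with the array index at wildcard_idx
 * replaced by an array wildcard. For example (an illustrative deref), for
 * the path of a[0].f[2].g with wildcard_idx pointing at the [2] entry, this
 * emits derefs equivalent to a[0].f[*].g and returns the final one.
 */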
static nir_deref_instr *
build_wildcard_deref(nir_builder *b, nir_deref_path *path,
                     unsigned wildcard_idx)
{
   assert(path->path[wildcard_idx]->deref_type == nir_deref_type_array);

   nir_deref_instr *tail =
      nir_build_deref_array_wildcard(b, path->path[wildcard_idx - 1]);

   for (unsigned i = wildcard_idx + 1; path->path[i]; i++)
      tail = nir_build_deref_follower(b, tail, path->path[i]);

   return tail;
}

static void
clobber(struct match_node *node, struct match_state *state)
{
   node->last_overwritten = state->cur_instr;
}

static bool
try_match_deref(nir_deref_path *base_path, int *path_array_idx,
                nir_deref_path *deref_path, int arr_idx,
                nir_deref_instr *dst)
{
   for (int i = 0;; i++) {
      nir_deref_instr *b = base_path->path[i];
      nir_deref_instr *d = deref_path->path[i];
      /* They have to be the same length */
      if ((b == NULL) != (d == NULL))
         return false;

      if (b == NULL)
         break;

      /* This can happen if one is a deref_array and the other a wildcard */
      if (b->deref_type != d->deref_type)
         return false;

      switch (b->deref_type) {
      case nir_deref_type_var:
         if (b->var != d->var)
            return false;
         continue;

      case nir_deref_type_array: {
         const bool const_b_idx = nir_src_is_const(b->arr.index);
         const bool const_d_idx = nir_src_is_const(d->arr.index);
         const unsigned b_idx = const_b_idx ? nir_src_as_uint(b->arr.index) : 0;
         const unsigned d_idx = const_d_idx ? nir_src_as_uint(d->arr.index) : 0;

         /* If we don't have an index into the path yet or if this entry in
          * the path is at the array index, see if this is a candidate. We're
          * looking for an index which is zero in the base deref and arr_idx
          * in the search deref and has a matching array size.
          */
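         /* For example (indices chosen only for illustration): if the first
          * matched element recorded src[0].f as first_src_path and we are now
          * looking at src[2].f with arr_idx == 2, the outer array entry is
          * such a candidate and *path_array_idx records its position in the
          * path.
          */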
         if ((*path_array_idx < 0 || *path_array_idx == i) &&
             const_b_idx && b_idx == 0 &&
             const_d_idx && d_idx == arr_idx &&
             glsl_get_length(nir_deref_instr_parent(b)->type) ==
             glsl_get_length(nir_deref_instr_parent(dst)->type)) {
            *path_array_idx = i;
            continue;
         }

         /* We're at the array index but not a candidate */
         if (*path_array_idx == i)
            return false;

         /* If we're not at the path array index, we must match exactly. We
          * could probably just compare SSA values and trust in copy
          * propagation but doing it ourselves means this pass can run a bit
          * earlier.
          */
         if (b->arr.index.ssa == d->arr.index.ssa ||
             (const_b_idx && const_d_idx && b_idx == d_idx))
            continue;

         return false;
      }

      case nir_deref_type_array_wildcard:
         continue;

      case nir_deref_type_struct:
         if (b->strct.index != d->strct.index)
            return false;
         continue;

      default:
         unreachable("Invalid deref type in a path");
      }
   }

   /* If we got here without failing, we've matched. However, it isn't an
    * array match unless we found an altered array index.
    */
   return *path_array_idx > 0;
}

static void
handle_read(nir_deref_instr *src, struct match_state *state)
{
   /* We only need to create an entry for sources that might be used to form
    * an array copy. Hence no indirects or indexing into a vector.
    */
   if (nir_deref_instr_has_indirect(src) ||
       nir_deref_instr_is_known_out_of_bounds(src) ||
       (src->deref_type == nir_deref_type_array &&
        glsl_type_is_vector(nir_src_as_deref(src->parent)->type)))
      return;

   nir_deref_path src_path;
   nir_deref_path_init(&src_path, src, state->dead_ctx);

   /* Create a node for this source if it doesn't exist. The point of this is
    * to know which nodes aliasing a given store we actually need to care
    * about, to avoid creating an excessive amount of nodes.
    */
   node_for_path(&src_path, state);
}

/* The core implementation, which is used for both copies and writes. Return
 * true if a copy is created.
 */
static bool
handle_write(nir_deref_instr *dst, nir_deref_instr *src,
             unsigned write_index, unsigned read_index,
             struct match_state *state)
{
   nir_builder *b = &state->builder;

   nir_deref_path dst_path;
   nir_deref_path_init(&dst_path, dst, state->dead_ctx);

   unsigned idx = 0;
   for (nir_deref_instr **instr = dst_path.path; *instr; instr++, idx++) {
      if ((*instr)->deref_type != nir_deref_type_array)
         continue;

      /* Get the entry where the index is replaced by a wildcard, so that we
       * hopefully can keep matching an array copy.
       */
      struct match_node *dst_node =
         node_for_path_with_wildcard(&dst_path, idx, state);

      if (!src)
         goto reset;

      if (nir_src_as_uint((*instr)->arr.index) != dst_node->next_array_idx)
         goto reset;

      if (dst_node->next_array_idx == 0) {
         /* At this point there may be multiple source indices which are zero,
          * so we can't pin down the actual source index. Just store it and
          * move on.
          */
         nir_deref_path_init(&dst_node->first_src_path, src, state->dead_ctx);
      } else {
         nir_deref_path src_path;
         nir_deref_path_init(&src_path, src, state->dead_ctx);
         bool result = try_match_deref(&dst_node->first_src_path,
                                       &dst_node->src_wildcard_idx,
                                       &src_path, dst_node->next_array_idx,
                                       *instr);
         nir_deref_path_finish(&src_path);
         if (!result)
            goto reset;
      }

      /* Check if an aliasing write clobbered the array after the last normal
       * write. For example, with a sequence like this:
       *
       *    dst[0][*] = src[0][*];
       *    dst[0][0] = 0; // invalidates the array copy dst[*][*] = src[*][*]
       *    dst[1][*] = src[1][*];
       *
       * Note that the second write wouldn't reset the entry for dst[*][*]
       * by itself, but it'll be caught by this check when processing the
       * third copy.
       */
      if (dst_node->last_successful_write < dst_node->last_overwritten)
         goto reset;

      dst_node->last_successful_write = write_index;

      /* In this case we've successfully processed an array element. Check if
       * this is the last, so that we can emit an array copy.
       */
      dst_node->next_array_idx++;
      dst_node->first_src_read = MIN2(dst_node->first_src_read, read_index);
      if (dst_node->next_array_idx > 1 &&
          dst_node->next_array_idx == glsl_get_length((*(instr - 1))->type)) {
         /* Make sure that nothing was overwritten. */
         struct match_node *src_node =
            node_for_path_with_wildcard(&dst_node->first_src_path,
                                        dst_node->src_wildcard_idx,
                                        state);

         if (src_node->last_overwritten <= dst_node->first_src_read) {
            nir_copy_deref(b, build_wildcard_deref(b, &dst_path, idx),
                           build_wildcard_deref(b, &dst_node->first_src_path,
                                                dst_node->src_wildcard_idx));
            foreach_aliasing_node(&dst_path, clobber, state);
            return true;
         }
      } else {
         continue;
      }

   reset:
      dst_node->next_array_idx = 0;
      dst_node->src_wildcard_idx = -1;
      dst_node->last_successful_write = 0;
      dst_node->first_src_read = UINT32_MAX;
   }

   /* Mark everything aliasing dst_path as clobbered. This needs to happen
    * last since in the loop above we need to know what last clobbered
    * dst_node and this overwrites that.
    */
   foreach_aliasing_node(&dst_path, clobber, state);

   return false;
}

static bool
opt_find_array_copies_block(nir_builder *b, nir_block *block,
                            struct match_state *state)
{
   bool progress = false;

   unsigned next_index = 0;

   _mesa_hash_table_clear(state->var_nodes, NULL);
   _mesa_hash_table_clear(state->cast_nodes, NULL);

   nir_foreach_instr(instr, block) {
      if (instr->type != nir_instr_type_intrinsic)
         continue;

      /* Index the instructions before we do anything else. */
      instr->index = next_index++;

      /* Save the index of this instruction */
      state->cur_instr = instr->index;

      nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);

      if (intrin->intrinsic == nir_intrinsic_load_deref) {
         handle_read(nir_src_as_deref(intrin->src[0]), state);
         continue;
      }

      if (intrin->intrinsic != nir_intrinsic_copy_deref &&
          intrin->intrinsic != nir_intrinsic_store_deref)
         continue;

      nir_deref_instr *dst_deref = nir_src_as_deref(intrin->src[0]);

      /* The destination must be local. If we see a non-local store, we
       * continue on because it won't affect local stores or read-only
       * variables.
       */
      if (!nir_deref_mode_may_be(dst_deref, nir_var_function_temp))
         continue;

      if (!nir_deref_mode_must_be(dst_deref, nir_var_function_temp)) {
         /* This only happens if we have something that might be a local store
          * but we don't know. In this case, clear everything.
          */
         nir_deref_path dst_path;
         nir_deref_path_init(&dst_path, dst_deref, state->dead_ctx);
         foreach_aliasing_node(&dst_path, clobber, state);
         continue;
      }

      /* If there are any known out-of-bounds writes, then we can just skip
       * this write as it's undefined and won't contribute to building up an
       * array copy anyways.
       */
      if (nir_deref_instr_is_known_out_of_bounds(dst_deref))
         continue;

      nir_deref_instr *src_deref;
      unsigned load_index = 0;
      if (intrin->intrinsic == nir_intrinsic_copy_deref) {
         src_deref = nir_src_as_deref(intrin->src[1]);
         load_index = intrin->instr.index;
      } else {
         assert(intrin->intrinsic == nir_intrinsic_store_deref);
         nir_intrinsic_instr *load = nir_src_as_intrinsic(intrin->src[1]);
         if (load == NULL || load->intrinsic != nir_intrinsic_load_deref) {
            src_deref = NULL;
         } else {
            src_deref = nir_src_as_deref(load->src[0]);
            load_index = load->instr.index;
         }

         if (nir_intrinsic_write_mask(intrin) !=
             (1 << glsl_get_components(dst_deref->type)) - 1) {
            src_deref = NULL;
         }
      }

      /* The source must be either local or something that's guaranteed to be
       * read-only.
       */
      if (src_deref &&
          !nir_deref_mode_must_be(src_deref, nir_var_function_temp |
                                             nir_var_read_only_modes)) {
         src_deref = NULL;
      }

      /* There must be no indirects in the source or destination and no known
       * out-of-bounds accesses in the source, and the copy must be fully
       * qualified, or else we can't build up the array copy. We handled
       * out-of-bounds accesses to the dest above. The types must match, since
       * copy_deref currently can't bitcast mismatched deref types.
       */
      if (src_deref &&
          (nir_deref_instr_has_indirect(src_deref) ||
           nir_deref_instr_is_known_out_of_bounds(src_deref) ||
           nir_deref_instr_has_indirect(dst_deref) ||
           !glsl_type_is_vector_or_scalar(src_deref->type) ||
           glsl_get_bare_type(src_deref->type) !=
           glsl_get_bare_type(dst_deref->type))) {
         src_deref = NULL;
      }

      state->builder.cursor = nir_after_instr(instr);
      progress |= handle_write(dst_deref, src_deref, instr->index,
                               load_index, state);
   }

   return progress;
}

static bool
opt_find_array_copies_impl(nir_function_impl *impl)
{
   nir_builder b = nir_builder_create(impl);

   bool progress = false;

   struct match_state s;
   s.dead_ctx = ralloc_context(NULL);
   s.var_nodes = _mesa_pointer_hash_table_create(s.dead_ctx);
   s.cast_nodes = _mesa_pointer_hash_table_create(s.dead_ctx);
   s.builder = nir_builder_create(impl);

   nir_foreach_block(block, impl) {
      if (opt_find_array_copies_block(&b, block, &s))
         progress = true;
   }

   ralloc_free(s.dead_ctx);

   if (progress) {
      nir_metadata_preserve(impl, nir_metadata_control_flow);
   } else {
      nir_metadata_preserve(impl, nir_metadata_all);
   }

   return progress;
}

/**
 * This peephole optimization looks for a series of load/store_deref or
 * copy_deref instructions that copy an array from one variable to another and
 * turns it into a copy_deref that copies the entire array. The pattern it
 * looks for is extremely specific but it's good enough to pick up on the
 * input array copies in DXVK and should also be able to pick up the sequence
 * generated by spirv_to_nir for an OpLoad of a large composite followed by
 * OpStore.
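 *
 * For example (a sketch in GLSL-like notation; the pass itself works on the
 * corresponding NIR deref chains), a sequence such as
 *
 *    dst[0] = src[0];
 *    dst[1] = src[1];
 *    dst[2] = src[2];
 *    dst[3] = src[3];
 *
 * gets a single wildcard copy, dst[*] = src[*], emitted as a copy_deref
 * right after the final element store. The per-element stores themselves are
 * left in place and are expected to be cleaned up by later passes such as
 * nir_opt_copy_prop_vars and nir_opt_dead_write_vars.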
 *
 * TODO: Support out-of-order copies.
 */
bool
nir_opt_find_array_copies(nir_shader *shader)
{
   bool progress = false;

   nir_foreach_function_impl(impl, shader) {
      if (opt_find_array_copies_impl(impl))
         progress = true;
   }

   return progress;
}
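
/* A rough usage sketch (the pass ordering here is illustrative, not
 * prescribed by this file): since this pass only adds the wildcard
 * copy_deref and leaves the matched per-element stores behind, it is
 * typically run in a driver's optimization loop alongside copy propagation
 * and dead-write elimination, e.g.:
 *
 *    bool progress;
 *    do {
 *       progress = false;
 *       NIR_PASS(progress, shader, nir_opt_find_array_copies);
 *       NIR_PASS(progress, shader, nir_opt_copy_prop_vars);
 *       NIR_PASS(progress, shader, nir_opt_dead_write_vars);
 *    } while (progress);
 */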