xref: /aosp_15_r20/external/executorch/backends/vulkan/test/glsl/reference_matmul.glsl (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9#version 450 core
10
// NOTE(review): PRECISION is not referenced by any code visible in this file;
// presumably it is consumed by the text the ${...} placeholders expand to.
11#define PRECISION highp
12
// Make std430 the default memory layout for buffer blocks declared below.
13layout(std430) buffer;
14
// The ${...} lines are template placeholders that are not valid GLSL on their
// own; they must be expanded by a shader code generator before compilation.
// NOTE(review): the first argument looks like the binding index and "w"/"r"
// like the access mode -- confirm against the generator's definition.
//
// Storage buffers: main() writes t_out[] and reads t_mat1[] and t_mat2[].
15${layout_declare_tensor(0, "w", "t_out", "float", "buffer")}
16${layout_declare_tensor(1, "r", "t_mat1", "float", "buffer")}
17${layout_declare_tensor(2, "r", "t_mat2", "float", "buffer")}
// ivec4 metadata for each tensor. main() reads only these components:
//   out_sizes.xy, out_strides.xy, mat1_sizes.x, mat1_strides.xy,
//   mat2_strides.xy. mat2_sizes is declared but not read by main().
18${layout_declare_ubo(3, "ivec4", "out_sizes")}
19${layout_declare_ubo(4, "ivec4", "out_strides")}
20${layout_declare_ubo(5, "ivec4", "mat1_sizes")}
21${layout_declare_ubo(6, "ivec4", "mat1_strides")}
22${layout_declare_ubo(7, "ivec4", "mat2_sizes")}
23${layout_declare_ubo(8, "ivec4", "mat2_strides")}
24
// Workgroup dimensions are taken from specialization constants 0, 1 and 2
// (x, y and z respectively), so they are chosen at pipeline-creation time.
25layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
26
/*
 * Reference matmul kernel: each invocation produces one element of t_out.
 *
 * The invocation's global x/y ID selects the output element. Invocations
 * whose x/y fall outside out_sizes.xy exit without writing. The z component
 * of the invocation ID is not consulted.
 *
 * The result is the sum of mat1_sizes.x products t_mat1[a] * t_mat2[b], where
 *   a starts at y * mat1_strides.y and advances by mat1_strides.x per step,
 *   b starts at x * mat2_strides.x and advances by mat2_strides.y per step.
 * It is stored at t_out[x * out_strides.x + y * out_strides.y].
 */
void main() {
  const int x = int(gl_GlobalInvocationID.x);
  const int y = int(gl_GlobalInvocationID.y);

  // Dispatches may be rounded up past the output extent; drop the excess.
  if (x >= out_sizes.x || y >= out_sizes.y) {
    return;
  }

  // Loop-invariant values read once from the uniform blocks.
  const int reduce_len = mat1_sizes.x;
  const int mat1_step = mat1_strides.x;
  const int mat2_step = mat2_strides.y;

  // Buffer offsets of the first operand pair: mat1 at (0, y), mat2 at (x, 0).
  int mat1_pos = y * mat1_strides.y;
  int mat2_pos = x * mat2_strides.x;

  // Accumulate in ascending order so the floating-point result is unchanged.
  float acc = 0.0;
  for (int k = 0; k < reduce_len;
       ++k, mat1_pos += mat1_step, mat2_pos += mat2_step) {
    acc += t_mat1[mat1_pos] * t_mat2[mat2_pos];
  }

  t_out[x * out_strides.x + y * out_strides.y] = acc;
}
49