/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#version 450 core

#define PRECISION highp

layout(std430) buffer;

${layout_declare_tensor(0, "w", "t_out", "float", "buffer")}
${layout_declare_tensor(1, "r", "t_mat1", "float", "buffer")}
${layout_declare_tensor(2, "r", "t_mat2", "float", "buffer")}
${layout_declare_ubo(3, "ivec4", "out_sizes")}
${layout_declare_ubo(4, "ivec4", "out_strides")}
${layout_declare_ubo(5, "ivec4", "mat1_sizes")}
${layout_declare_ubo(6, "ivec4", "mat1_strides")}
${layout_declare_ubo(7, "ivec4", "mat2_sizes")}
${layout_declare_ubo(8, "ivec4", "mat2_strides")}

layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

/*
 * Naive buffer-backed matrix multiply. Each invocation produces exactly one
 * element of t_out:
 *
 *   t_out(col, row) = sum over k in [0, mat1_sizes.x) of
 *                     t_mat1(k, row) * t_mat2(col, k)
 *
 * Positions are written as (x, y). Component .x of every sizes/strides
 * vector is the column index, and .y is the row index. Element offsets are
 * computed purely from the strides UBOs, so any strided layout is accepted.
 *
 * Only the .x/.y components are read. NOTE(review): batch dimensions
 * (.z/.w) are not handled here; confirm that callers only dispatch 2-D
 * problems.
 */
void main() {
  const int col = int(gl_GlobalInvocationID.x);
  const int row = int(gl_GlobalInvocationID.y);

  // The dispatch may be rounded up past the output extent; such
  // invocations have no element to compute.
  if (col >= out_sizes.x || row >= out_sizes.y) {
    return;
  }

  // Start at mat1(0, row) and mat2(col, 0), then advance both along the
  // shared reduction dimension: mat1 moves across columns, mat2 down rows.
  int a_offset = row * mat1_strides.y;
  int b_offset = col * mat2_strides.x;

  float acc = 0.0;
  int k = 0;
  while (k < mat1_sizes.x) {
    acc += t_mat1[a_offset] * t_mat2[b_offset];
    a_offset += mat1_strides.x;
    b_offset += mat2_strides.y;
    ++k;
  }

  t_out[col * out_strides.x + row * out_strides.y] = acc;
}