1 // Copyright (c) 2019 Google LLC.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include <string>
16 #include <tuple>
17 
18 #include "gmock/gmock.h"
19 #include "test/unit_spirv.h"
20 #include "test/val/val_fixtures.h"
21 
22 namespace spvtools {
23 namespace val {
24 namespace {
25 
26 using ::testing::Combine;
27 using ::testing::HasSubstr;
28 using ::testing::Values;
29 
30 using ValidateFunctionCall = spvtest::ValidateBase<std::string>;
31 
// Produces a SPIR-V assembly module in which %caller hands the OpVariable %var
// (whose pointer type uses |storage_class|) straight to %callee.
// |capabilities| and |extensions| are inserted verbatim after the fixed
// capability list and the fixed extension, respectively.
std::string GenerateShader(const std::string& storage_class,
                           const std::string& capabilities,
                           const std::string& extensions) {
  // A Function-storage variable has to be declared inside the caller's body;
  // for any other storage class the same declaration goes at module scope.
  const bool declare_in_caller = storage_class == "Function";
  const std::string var_decl = "%var = OpVariable %ptr " + storage_class;

  std::string text =
      "\nOpCapability Shader\nOpCapability Linkage\nOpCapability AtomicStorage\n";
  text += capabilities;
  text += "\nOpExtension \"SPV_KHR_storage_buffer_storage_class\"\n";
  text += extensions;
  text += R"(
OpMemoryModel Logical GLSL450
OpName %var "var"
%void = OpTypeVoid
%int = OpTypeInt 32 0
%ptr = OpTypePointer )";
  text += storage_class;
  text += R"( %int
%caller_ty = OpTypeFunction %void
%callee_ty = OpTypeFunction %void %ptr
)";
  if (!declare_in_caller) {
    text += var_decl;
  }
  text += R"(
%caller = OpFunction %void None %caller_ty
%1 = OpLabel
)";
  if (declare_in_caller) {
    text += var_decl;
  }
  text += R"(
%call = OpFunctionCall %void %callee %var
OpReturn
OpFunctionEnd
%callee = OpFunction %void None %callee_ty
%param = OpFunctionParameter %ptr
%2 = OpLabel
OpReturn
OpFunctionEnd
)";
  return text;
}
78 
// Produces a SPIR-V assembly module in which the pointer passed to %callee is
// not a variable: it is %caller's own OpFunctionParameter %p, whose pointer
// type uses |storage_class|. Both functions share the type %func_ty.
// |capabilities| and |extensions| are inserted verbatim.
std::string GenerateShaderParameter(const std::string& storage_class,
                                    const std::string& capabilities,
                                    const std::string& extensions) {
  std::string module;
  module
      .append(
          "\nOpCapability Shader\nOpCapability Linkage\n"
          "OpCapability AtomicStorage\n")
      .append(capabilities)
      .append("\nOpExtension \"SPV_KHR_storage_buffer_storage_class\"\n")
      .append(extensions)
      .append(R"(
OpMemoryModel Logical GLSL450
OpName %p "p"
%void = OpTypeVoid
%int = OpTypeInt 32 0
%ptr = OpTypePointer )")
      .append(storage_class)
      .append(R"( %int
%func_ty = OpTypeFunction %void %ptr
%caller = OpFunction %void None %func_ty
%p = OpFunctionParameter %ptr
%1 = OpLabel
%call = OpFunctionCall %void %callee %p
OpReturn
OpFunctionEnd
%callee = OpFunction %void None %func_ty
%param = OpFunctionParameter %ptr
%2 = OpLabel
OpReturn
OpFunctionEnd
)");
  return module;
}
111 
// Produces a SPIR-V assembly module in which %caller passes %gep to %callee.
// %gep is an OpAccessChain into %var, so the argument is not itself a memory
// object declaration. %var points to a 2-component integer vector in
// |storage_class|, and %gep points to its first component.
// |capabilities| and |extensions| are inserted verbatim.
std::string GenerateShaderAccessChain(const std::string& storage_class,
                                      const std::string& capabilities,
                                      const std::string& extensions) {
  // Exactly one of the two slots below receives the %var declaration:
  // the caller body for Function storage, module scope otherwise.
  const bool inside_caller = storage_class == "Function";
  const std::string declaration = "%var = OpVariable %ptr " + storage_class;
  const std::string module_scope_var =
      inside_caller ? std::string() : declaration;
  const std::string caller_scope_var =
      inside_caller ? declaration : std::string();

  return "\nOpCapability Shader\nOpCapability Linkage\nOpCapability AtomicStorage\n" +
         capabilities +
         "\nOpExtension \"SPV_KHR_storage_buffer_storage_class\"\n" +
         extensions +
         "\nOpMemoryModel Logical GLSL450\n"
         "OpName %var \"var\"\n"
         "OpName %gep \"gep\"\n"
         "%void = OpTypeVoid\n"
         "%int = OpTypeInt 32 0\n"
         "%int2 = OpTypeVector %int 2\n"
         "%int_0 = OpConstant %int 0\n"
         "%ptr = OpTypePointer " +
         storage_class +
         " %int2\n"
         "%ptr2 = OpTypePointer " +
         storage_class +
         " %int\n"
         "%caller_ty = OpTypeFunction %void\n"
         "%callee_ty = OpTypeFunction %void %ptr2\n" +
         module_scope_var +
         "\n%caller = OpFunction %void None %caller_ty\n"
         "%1 = OpLabel\n" +
         caller_scope_var +
         "\n%gep = OpAccessChain %ptr2 %var %int_0\n"
         "%call = OpFunctionCall %void %callee %gep\n"
         "OpReturn\n"
         "OpFunctionEnd\n"
         "%callee = OpFunction %void None %callee_ty\n"
         "%param = OpFunctionParameter %ptr2\n"
         "%2 = OpLabel\n"
         "OpReturn\n"
         "OpFunctionEnd\n";
}
164 
TEST_P(ValidateFunctionCall, VariableNoVariablePointers) {
  const std::string storage_class = GetParam();
  const std::string spirv = GenerateShader(storage_class, "", "");

  // With no variable pointers capability, passing %var is expected to
  // validate only for these storage classes.
  const bool valid = storage_class == "UniformConstant" ||
                     storage_class == "Function" ||
                     storage_class == "Private" ||
                     storage_class == "Workgroup" ||
                     storage_class == "AtomicCounter";

  CompileSuccessfully(spirv);
  if (valid) {
    EXPECT_EQ(SPV_SUCCESS, ValidateInstructions());
    return;
  }

  EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
  // StorageBuffer gets a capability-specific diagnostic; every other rejected
  // storage class gets the generic one.
  const char* const expected_message =
      storage_class == "StorageBuffer"
          ? "StorageBuffer pointer operand '1[%var]' requires a "
            "variable pointers capability"
          : "Invalid storage class for pointer operand '1[%var]'";
  EXPECT_THAT(getDiagnosticString(), HasSubstr(expected_message));
}
193 
TEST_P(ValidateFunctionCall, VariableVariablePointersStorageClass) {
  const std::string storage_class = GetParam();

  // VariablePointersStorageBuffer additionally admits StorageBuffer pointers
  // as call arguments.
  static const char* const kAccepted[] = {"UniformConstant", "Function",
                                          "Private",         "Workgroup",
                                          "StorageBuffer",   "AtomicCounter"};
  bool accepted = false;
  for (const char* name : kAccepted) {
    if (storage_class == name) {
      accepted = true;
      break;
    }
  }

  CompileSuccessfully(GenerateShader(
      storage_class, "OpCapability VariablePointersStorageBuffer",
      "OpExtension \"SPV_KHR_variable_pointers\""));
  const spv_result_t expected_result =
      accepted ? SPV_SUCCESS : SPV_ERROR_INVALID_ID;
  EXPECT_EQ(expected_result, ValidateInstructions());
  if (!accepted) {
    EXPECT_THAT(
        getDiagnosticString(),
        HasSubstr("Invalid storage class for pointer operand '1[%var]'"));
  }
}
218 
TEST_P(ValidateFunctionCall, VariableVariablePointers) {
  const std::string storage_class = GetParam();
  const std::string module_text =
      GenerateShader(storage_class, "OpCapability VariablePointers",
                     "OpExtension \"SPV_KHR_variable_pointers\"");

  // With the full VariablePointers capability the accepted set is expected to
  // match the VariablePointersStorageBuffer case.
  const bool storage_class_ok =
      storage_class == "UniformConstant" || storage_class == "Function" ||
      storage_class == "Private" || storage_class == "Workgroup" ||
      storage_class == "StorageBuffer" || storage_class == "AtomicCounter";

  CompileSuccessfully(module_text);
  if (storage_class_ok) {
    EXPECT_EQ(SPV_SUCCESS, ValidateInstructions());
    return;
  }
  EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
  EXPECT_THAT(
      getDiagnosticString(),
      HasSubstr("Invalid storage class for pointer operand '1[%var]'"));
}
243 
TEST_P(ValidateFunctionCall, ParameterNoVariablePointers) {
  const std::string storage_class = GetParam();

  // The argument here is the caller's own parameter %p. Without a variable
  // pointers capability, these are the storage classes expected to pass.
  static const char* const kPermitted[] = {
      "UniformConstant", "Function", "Private", "Workgroup", "AtomicCounter"};
  bool permitted = false;
  for (const char* name : kPermitted) {
    permitted = permitted || storage_class == name;
  }

  CompileSuccessfully(GenerateShaderParameter(storage_class, "", ""));
  if (permitted) {
    EXPECT_EQ(SPV_SUCCESS, ValidateInstructions());
  } else {
    EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
    std::string expected_message;
    if (storage_class == "StorageBuffer") {
      expected_message =
          "StorageBuffer pointer operand '1[%p]' requires a "
          "variable pointers capability";
    } else {
      expected_message = "Invalid storage class for pointer operand '1[%p]'";
    }
    EXPECT_THAT(getDiagnosticString(), HasSubstr(expected_message));
  }
}
271 
TEST_P(ValidateFunctionCall, ParameterVariablePointersStorageBuffer) {
  const std::string storage_class = GetParam();
  const std::string text = GenerateShaderParameter(
      storage_class, "OpCapability VariablePointersStorageBuffer",
      "OpExtension \"SPV_KHR_variable_pointers\"");

  // Storage classes expected to validate once VariablePointersStorageBuffer
  // is declared.
  const bool is_allowed =
      storage_class == "UniformConstant" || storage_class == "Function" ||
      storage_class == "Private" || storage_class == "Workgroup" ||
      storage_class == "StorageBuffer" || storage_class == "AtomicCounter";

  CompileSuccessfully(text);
  if (is_allowed) {
    EXPECT_EQ(SPV_SUCCESS, ValidateInstructions());
    return;
  }
  EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
  EXPECT_THAT(getDiagnosticString(),
              HasSubstr("Invalid storage class for pointer operand '1[%p]'"));
}
295 
TEST_P(ValidateFunctionCall, ParameterVariablePointers) {
  const std::string storage_class = GetParam();

  // Storage classes expected to validate with the full VariablePointers
  // capability when the caller forwards its own pointer parameter.
  static const char* const kLegal[] = {"UniformConstant", "Function",
                                       "Private",         "Workgroup",
                                       "StorageBuffer",   "AtomicCounter"};
  bool legal = false;
  for (const char* candidate : kLegal) {
    if (storage_class == candidate) legal = true;
  }

  CompileSuccessfully(GenerateShaderParameter(
      storage_class, "OpCapability VariablePointers",
      "OpExtension \"SPV_KHR_variable_pointers\""));
  EXPECT_EQ(legal ? SPV_SUCCESS : SPV_ERROR_INVALID_ID, ValidateInstructions());
  if (!legal) {
    EXPECT_THAT(getDiagnosticString(),
                HasSubstr("Invalid storage class for pointer operand '1[%p]'"));
  }
}
319 
TEST_P(ValidateFunctionCall, NonMemoryObjectDeclarationNoVariablePointers) {
  const std::string storage_class = GetParam();
  CompileSuccessfully(GenerateShaderAccessChain(storage_class, "", ""));

  // Without variable pointers, only UniformConstant is expected to accept an
  // OpAccessChain result as the call argument; it produces no diagnostic.
  const bool is_uniform_constant = storage_class == "UniformConstant";
  const spv_result_t expected_result =
      is_uniform_constant ? SPV_SUCCESS : SPV_ERROR_INVALID_ID;
  EXPECT_EQ(expected_result, ValidateInstructions());
  if (is_uniform_constant) return;

  const char* expected_diagnostic = nullptr;
  if (storage_class == "Function" || storage_class == "Private" ||
      storage_class == "Workgroup" || storage_class == "AtomicCounter") {
    // The storage class itself is acceptable; the argument is not a
    // memory object declaration.
    expected_diagnostic =
        "Pointer operand '2[%gep]' must be a memory object declaration";
  } else if (storage_class == "StorageBuffer") {
    expected_diagnostic =
        "StorageBuffer pointer operand '2[%gep]' requires a "
        "variable pointers capability";
  } else {
    expected_diagnostic = "Invalid storage class for pointer operand '2[%gep]'";
  }
  EXPECT_THAT(getDiagnosticString(), HasSubstr(expected_diagnostic));
}
353 
TEST_P(ValidateFunctionCall,
       NonMemoryObjectDeclarationVariablePointersStorageBuffer) {
  const std::string storage_class = GetParam();
  CompileSuccessfully(GenerateShaderAccessChain(
      storage_class, "OpCapability VariablePointersStorageBuffer",
      "OpExtension \"SPV_KHR_variable_pointers\""));

  // With VariablePointersStorageBuffer, the access-chain argument is expected
  // to validate only for StorageBuffer and UniformConstant pointers.
  if (storage_class == "StorageBuffer" || storage_class == "UniformConstant") {
    EXPECT_EQ(SPV_SUCCESS, ValidateInstructions());
    return;
  }

  EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
  // For otherwise-acceptable storage classes, the failure should be the
  // missing memory object declaration rather than the storage class.
  const bool storage_class_allowed =
      storage_class == "Function" || storage_class == "Private" ||
      storage_class == "Workgroup" || storage_class == "StorageBuffer" ||
      storage_class == "AtomicCounter";
  EXPECT_THAT(
      getDiagnosticString(),
      HasSubstr(storage_class_allowed
                    ? "Pointer operand '2[%gep]' must be a memory object "
                      "declaration"
                    : "Invalid storage class for pointer operand '2[%gep]'"));
}
387 
TEST_P(ValidateFunctionCall, NonMemoryObjectDeclarationVariablePointers) {
  const std::string storage_class = GetParam();
  const std::string assembly =
      GenerateShaderAccessChain(storage_class, "OpCapability VariablePointers",
                                "OpExtension \"SPV_KHR_variable_pointers\"");
  CompileSuccessfully(assembly);

  // The full VariablePointers capability is expected to additionally accept
  // Workgroup access-chain arguments.
  const bool expect_success = storage_class == "StorageBuffer" ||
                              storage_class == "Workgroup" ||
                              storage_class == "UniformConstant";
  if (expect_success) {
    EXPECT_EQ(SPV_SUCCESS, ValidateInstructions());
    return;
  }

  EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
  std::string expected_diagnostic;
  if (storage_class == "Function" || storage_class == "Private" ||
      storage_class == "Workgroup" || storage_class == "StorageBuffer" ||
      storage_class == "AtomicCounter") {
    expected_diagnostic =
        "Pointer operand '2[%gep]' must be a memory object declaration";
  } else {
    expected_diagnostic = "Invalid storage class for pointer operand '2[%gep]'";
  }
  EXPECT_THAT(getDiagnosticString(), HasSubstr(expected_diagnostic));
}
421 
TEST_F(ValidateFunctionCall, LogicallyMatchingPointers) {
  // The argument points to Uniform %_struct_3 { int }, and the parameter points
  // to the distinct type Function %_struct_15 { int }. With the
  // before-HLSL-legalization option set, the call is expected to validate.
  const std::string assembly = R"(
               OpCapability Shader
               OpMemoryModel Logical GLSL450
               OpEntryPoint GLCompute %1 "main"
               OpExecutionMode %1 LocalSize 1 1 1
               OpSource HLSL 600
               OpDecorate %2 DescriptorSet 0
               OpDecorate %2 Binding 0
               OpMemberDecorate %_struct_3 0 Offset 0
               OpDecorate %_runtimearr__struct_3 ArrayStride 4
               OpMemberDecorate %_struct_5 0 Offset 0
               OpDecorate %_struct_5 BufferBlock
        %int = OpTypeInt 32 1
      %int_0 = OpConstant %int 0
       %uint = OpTypeInt 32 0
     %uint_0 = OpConstant %uint 0
  %_struct_3 = OpTypeStruct %int
%_runtimearr__struct_3 = OpTypeRuntimeArray %_struct_3
  %_struct_5 = OpTypeStruct %_runtimearr__struct_3
%_ptr_Uniform__struct_5 = OpTypePointer Uniform %_struct_5
       %void = OpTypeVoid
         %14 = OpTypeFunction %void
 %_struct_15 = OpTypeStruct %int
%_ptr_Function__struct_15 = OpTypePointer Function %_struct_15
%_ptr_Uniform__struct_3 = OpTypePointer Uniform %_struct_3
         %18 = OpTypeFunction %void %_ptr_Function__struct_15
          %2 = OpVariable %_ptr_Uniform__struct_5 Uniform
          %1 = OpFunction %void None %14
         %19 = OpLabel
         %20 = OpAccessChain %_ptr_Uniform__struct_3 %2 %int_0 %uint_0
         %21 = OpFunctionCall %void %22 %20
               OpReturn
               OpFunctionEnd
         %22 = OpFunction %void None %18
         %23 = OpFunctionParameter %_ptr_Function__struct_15
         %24 = OpLabel
               OpReturn
               OpFunctionEnd
)";
  CompileSuccessfully(assembly);
  spvValidatorOptionsSetBeforeHlslLegalization(getValidatorOptions(), true);
  EXPECT_EQ(SPV_SUCCESS, ValidateInstructions());
}
467 
TEST_F(ValidateFunctionCall, LogicallyMatchingPointersNestedStruct) {
  // Same as LogicallyMatchingPointers, but the pointee is a struct nested one
  // level deeper: Uniform %_struct_4 { %_struct_3 { int } } is passed where
  // Function %_struct_15 { %_struct_14 { int } } is expected.
  const std::string assembly = R"(
               OpCapability Shader
               OpMemoryModel Logical GLSL450
               OpEntryPoint GLCompute %1 "main"
               OpExecutionMode %1 LocalSize 1 1 1
               OpSource HLSL 600
               OpDecorate %2 DescriptorSet 0
               OpDecorate %2 Binding 0
               OpMemberDecorate %_struct_3 0 Offset 0
               OpMemberDecorate %_struct_4 0 Offset 0
               OpDecorate %_runtimearr__struct_4 ArrayStride 4
               OpMemberDecorate %_struct_6 0 Offset 0
               OpDecorate %_struct_6 BufferBlock
        %int = OpTypeInt 32 1
      %int_0 = OpConstant %int 0
       %uint = OpTypeInt 32 0
     %uint_0 = OpConstant %uint 0
  %_struct_3 = OpTypeStruct %int
  %_struct_4 = OpTypeStruct %_struct_3
%_runtimearr__struct_4 = OpTypeRuntimeArray %_struct_4
  %_struct_6 = OpTypeStruct %_runtimearr__struct_4
%_ptr_Uniform__struct_6 = OpTypePointer Uniform %_struct_6
       %void = OpTypeVoid
         %13 = OpTypeFunction %void
 %_struct_14 = OpTypeStruct %int
 %_struct_15 = OpTypeStruct %_struct_14
%_ptr_Function__struct_15 = OpTypePointer Function %_struct_15
%_ptr_Uniform__struct_4 = OpTypePointer Uniform %_struct_4
         %18 = OpTypeFunction %void %_ptr_Function__struct_15
          %2 = OpVariable %_ptr_Uniform__struct_6 Uniform
          %1 = OpFunction %void None %13
         %19 = OpLabel
         %20 = OpVariable %_ptr_Function__struct_15 Function
         %21 = OpAccessChain %_ptr_Uniform__struct_4 %2 %int_0 %uint_0
         %22 = OpFunctionCall %void %23 %21
               OpReturn
               OpFunctionEnd
         %23 = OpFunction %void None %18
         %24 = OpFunctionParameter %_ptr_Function__struct_15
         %25 = OpLabel
               OpReturn
               OpFunctionEnd
)";
  CompileSuccessfully(assembly);
  spvValidatorOptionsSetBeforeHlslLegalization(getValidatorOptions(), true);
  EXPECT_EQ(SPV_SUCCESS, ValidateInstructions());
}
518 
TEST_F(ValidateFunctionCall, LogicallyMatchingPointersNestedArray) {
  // The argument points to Uniform %_struct_4 { int[10] }, and the parameter
  // points to Function %_struct_17 { int[10] }. The two array types are
  // declared separately (%_arr_int_uint_10 vs %_arr_int_uint_10_0).
  const std::string assembly = R"(
              OpCapability Shader
               OpMemoryModel Logical GLSL450
               OpEntryPoint GLCompute %1 "main"
               OpExecutionMode %1 LocalSize 1 1 1
               OpSource HLSL 600
               OpDecorate %2 DescriptorSet 0
               OpDecorate %2 Binding 0
               OpDecorate %_arr_int_uint_10 ArrayStride 4
               OpMemberDecorate %_struct_4 0 Offset 0
               OpDecorate %_runtimearr__struct_4 ArrayStride 40
               OpMemberDecorate %_struct_6 0 Offset 0
               OpDecorate %_struct_6 BufferBlock
        %int = OpTypeInt 32 1
      %int_0 = OpConstant %int 0
       %uint = OpTypeInt 32 0
     %uint_0 = OpConstant %uint 0
    %uint_10 = OpConstant %uint 10
%_arr_int_uint_10 = OpTypeArray %int %uint_10
  %_struct_4 = OpTypeStruct %_arr_int_uint_10
%_runtimearr__struct_4 = OpTypeRuntimeArray %_struct_4
  %_struct_6 = OpTypeStruct %_runtimearr__struct_4
%_ptr_Uniform__struct_6 = OpTypePointer Uniform %_struct_6
       %void = OpTypeVoid
         %14 = OpTypeFunction %void
%_ptr_Uniform__struct_4 = OpTypePointer Uniform %_struct_4
%_arr_int_uint_10_0 = OpTypeArray %int %uint_10
 %_struct_17 = OpTypeStruct %_arr_int_uint_10_0
%_ptr_Function__struct_17 = OpTypePointer Function %_struct_17
         %19 = OpTypeFunction %void %_ptr_Function__struct_17
          %2 = OpVariable %_ptr_Uniform__struct_6 Uniform
          %1 = OpFunction %void None %14
         %20 = OpLabel
         %21 = OpAccessChain %_ptr_Uniform__struct_4 %2 %int_0 %uint_0
         %22 = OpFunctionCall %void %23 %21
               OpReturn
               OpFunctionEnd
         %23 = OpFunction %void None %19
         %24 = OpFunctionParameter %_ptr_Function__struct_17
         %25 = OpLabel
               OpReturn
               OpFunctionEnd
)";
  CompileSuccessfully(assembly);
  spvValidatorOptionsSetBeforeHlslLegalization(getValidatorOptions(), true);
  EXPECT_EQ(SPV_SUCCESS, ValidateInstructions());
}
569 
TEST_F(ValidateFunctionCall, LogicallyMismatchedPointersMissingMember) {
  // The parameter's pointee %_struct_15 has two int members, while the
  // argument's pointee %_struct_3 has only one. The call must be rejected
  // even with the before-HLSL-legalization option set.
  const std::string assembly = R"(
               OpCapability Shader
               OpMemoryModel Logical GLSL450
               OpEntryPoint GLCompute %1 "main"
               OpExecutionMode %1 LocalSize 1 1 1
               OpSource HLSL 600
               OpDecorate %2 DescriptorSet 0
               OpDecorate %2 Binding 0
               OpMemberDecorate %_struct_3 0 Offset 0
               OpDecorate %_runtimearr__struct_3 ArrayStride 4
               OpMemberDecorate %_struct_5 0 Offset 0
               OpDecorate %_struct_5 BufferBlock
        %int = OpTypeInt 32 1
      %int_0 = OpConstant %int 0
       %uint = OpTypeInt 32 0
     %uint_0 = OpConstant %uint 0
  %_struct_3 = OpTypeStruct %int
%_runtimearr__struct_3 = OpTypeRuntimeArray %_struct_3
  %_struct_5 = OpTypeStruct %_runtimearr__struct_3
%_ptr_Uniform__struct_5 = OpTypePointer Uniform %_struct_5
       %void = OpTypeVoid
         %14 = OpTypeFunction %void
 %_struct_15 = OpTypeStruct %int %int
%_ptr_Function__struct_15 = OpTypePointer Function %_struct_15
%_ptr_Uniform__struct_3 = OpTypePointer Uniform %_struct_3
         %18 = OpTypeFunction %void %_ptr_Function__struct_15
          %2 = OpVariable %_ptr_Uniform__struct_5 Uniform
          %1 = OpFunction %void None %14
         %19 = OpLabel
         %20 = OpAccessChain %_ptr_Uniform__struct_3 %2 %int_0 %uint_0
         %21 = OpFunctionCall %void %22 %20
               OpReturn
               OpFunctionEnd
         %22 = OpFunction %void None %18
         %23 = OpFunctionParameter %_ptr_Function__struct_15
         %24 = OpLabel
               OpReturn
               OpFunctionEnd
)";
  CompileSuccessfully(assembly);
  spvValidatorOptionsSetBeforeHlslLegalization(getValidatorOptions(), true);
  EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
  const std::string diagnostic = getDiagnosticString();
  EXPECT_THAT(diagnostic, HasSubstr("OpFunctionCall Argument <id>"));
  EXPECT_THAT(diagnostic, HasSubstr("type does not match Function <id>"));
}
621 
TEST_F(ValidateFunctionCall, LogicallyMismatchedPointersDifferentMemberType) {
  // The argument's pointee %_struct_3 holds an unsigned %uint member, while
  // the parameter's pointee %_struct_15 holds a signed %int. The member types
  // differ, so the call must be rejected.
  const std::string assembly = R"(
               OpCapability Shader
               OpMemoryModel Logical GLSL450
               OpEntryPoint GLCompute %1 "main"
               OpExecutionMode %1 LocalSize 1 1 1
               OpSource HLSL 600
               OpDecorate %2 DescriptorSet 0
               OpDecorate %2 Binding 0
               OpMemberDecorate %_struct_3 0 Offset 0
               OpDecorate %_runtimearr__struct_3 ArrayStride 4
               OpMemberDecorate %_struct_5 0 Offset 0
               OpDecorate %_struct_5 BufferBlock
        %int = OpTypeInt 32 1
      %int_0 = OpConstant %int 0
       %uint = OpTypeInt 32 0
     %uint_0 = OpConstant %uint 0
  %_struct_3 = OpTypeStruct %uint
%_runtimearr__struct_3 = OpTypeRuntimeArray %_struct_3
  %_struct_5 = OpTypeStruct %_runtimearr__struct_3
%_ptr_Uniform__struct_5 = OpTypePointer Uniform %_struct_5
       %void = OpTypeVoid
         %14 = OpTypeFunction %void
 %_struct_15 = OpTypeStruct %int
%_ptr_Function__struct_15 = OpTypePointer Function %_struct_15
%_ptr_Uniform__struct_3 = OpTypePointer Uniform %_struct_3
         %18 = OpTypeFunction %void %_ptr_Function__struct_15
          %2 = OpVariable %_ptr_Uniform__struct_5 Uniform
          %1 = OpFunction %void None %14
         %19 = OpLabel
         %20 = OpAccessChain %_ptr_Uniform__struct_3 %2 %int_0 %uint_0
         %21 = OpFunctionCall %void %22 %20
               OpReturn
               OpFunctionEnd
         %22 = OpFunction %void None %18
         %23 = OpFunctionParameter %_ptr_Function__struct_15
         %24 = OpLabel
               OpReturn
               OpFunctionEnd
)";
  CompileSuccessfully(assembly);
  spvValidatorOptionsSetBeforeHlslLegalization(getValidatorOptions(), true);
  EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
  const std::string diagnostic = getDiagnosticString();
  EXPECT_THAT(diagnostic, HasSubstr("OpFunctionCall Argument <id>"));
  EXPECT_THAT(diagnostic, HasSubstr("type does not match Function <id>"));
}
673 
TEST_F(ValidateFunctionCall,
       LogicallyMismatchedPointersIncompatableDecorations) {
  // The parameter's pointee %_struct_15 has a NonWritable member decoration
  // that the argument's pointee %_struct_3 lacks. The decorations are
  // incompatible, so the call must be rejected.
  const spv_target_env env = SPV_ENV_UNIVERSAL_1_4;
  const std::string assembly = R"(
               OpCapability Shader
               OpMemoryModel Logical GLSL450
               OpEntryPoint GLCompute %1 "main"
               OpExecutionMode %1 LocalSize 1 1 1
               OpSource HLSL 600
               OpDecorate %2 DescriptorSet 0
               OpDecorate %2 Binding 0
               OpMemberDecorate %_struct_3 0 Offset 0
               OpDecorate %_runtimearr__struct_3 ArrayStride 4
               OpMemberDecorate %_struct_5 0 Offset 0
               OpDecorate %_struct_5 Block
               OpMemberDecorate %_struct_15 0 NonWritable
        %int = OpTypeInt 32 1
      %int_0 = OpConstant %int 0
       %uint = OpTypeInt 32 0
     %uint_0 = OpConstant %uint 0
  %_struct_3 = OpTypeStruct %int
%_runtimearr__struct_3 = OpTypeRuntimeArray %_struct_3
  %_struct_5 = OpTypeStruct %_runtimearr__struct_3
%_ptr_StorageBuffer__struct_5 = OpTypePointer StorageBuffer %_struct_5
       %void = OpTypeVoid
         %14 = OpTypeFunction %void
 %_struct_15 = OpTypeStruct %int
%_ptr_Function__struct_15 = OpTypePointer Function %_struct_15
%_ptr_StorageBuffer__struct_3 = OpTypePointer StorageBuffer %_struct_3
         %18 = OpTypeFunction %void %_ptr_Function__struct_15
          %2 = OpVariable %_ptr_StorageBuffer__struct_5 StorageBuffer
          %1 = OpFunction %void None %14
         %19 = OpLabel
         %20 = OpAccessChain %_ptr_StorageBuffer__struct_3 %2 %int_0 %uint_0
         %21 = OpFunctionCall %void %22 %20
               OpReturn
               OpFunctionEnd
         %22 = OpFunction %void None %18
         %23 = OpFunctionParameter %_ptr_Function__struct_15
         %24 = OpLabel
               OpReturn
               OpFunctionEnd
)";
  CompileSuccessfully(assembly, env);
  spvValidatorOptionsSetBeforeHlslLegalization(getValidatorOptions(), true);
  EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions(env));
  const std::string diagnostic = getDiagnosticString();
  EXPECT_THAT(diagnostic, HasSubstr("OpFunctionCall Argument <id>"));
  EXPECT_THAT(diagnostic, HasSubstr("type does not match Function <id>"));
}
727 
TEST_F(ValidateFunctionCall,
       LogicallyMismatchedPointersIncompatableDecorations2) {
  // The argument and parameter pointer types both point to Uniform %_struct_3,
  // but one has ArrayStride 4 and the other ArrayStride 8. The strides
  // conflict, so the call must be rejected.
  const std::string assembly = R"(
               OpCapability Shader
               OpMemoryModel Logical GLSL450
               OpEntryPoint GLCompute %1 "main"
               OpExecutionMode %1 LocalSize 1 1 1
               OpSource HLSL 600
               OpDecorate %2 DescriptorSet 0
               OpDecorate %2 Binding 0
               OpMemberDecorate %_struct_3 0 Offset 0
               OpDecorate %_runtimearr__struct_3 ArrayStride 4
               OpMemberDecorate %_struct_5 0 Offset 0
               OpDecorate %_struct_5 BufferBlock
               OpDecorate %_ptr_Uniform__struct_3 ArrayStride 4
               OpDecorate %_ptr_Uniform__struct_3_0 ArrayStride 8
        %int = OpTypeInt 32 1
      %int_0 = OpConstant %int 0
       %uint = OpTypeInt 32 0
     %uint_0 = OpConstant %uint 0
  %_struct_3 = OpTypeStruct %int
%_runtimearr__struct_3 = OpTypeRuntimeArray %_struct_3
  %_struct_5 = OpTypeStruct %_runtimearr__struct_3
%_ptr_Uniform__struct_5 = OpTypePointer Uniform %_struct_5
       %void = OpTypeVoid
         %14 = OpTypeFunction %void
%_ptr_Uniform__struct_3 = OpTypePointer Uniform %_struct_3
%_ptr_Uniform__struct_3_0 = OpTypePointer Uniform %_struct_3
         %18 = OpTypeFunction %void %_ptr_Uniform__struct_3_0
          %2 = OpVariable %_ptr_Uniform__struct_5 Uniform
          %1 = OpFunction %void None %14
         %19 = OpLabel
         %20 = OpAccessChain %_ptr_Uniform__struct_3 %2 %int_0 %uint_0
         %21 = OpFunctionCall %void %22 %20
               OpReturn
               OpFunctionEnd
         %22 = OpFunction %void None %18
         %23 = OpFunctionParameter %_ptr_Uniform__struct_3_0
         %24 = OpLabel
               OpReturn
               OpFunctionEnd
)";
  CompileSuccessfully(assembly);
  spvValidatorOptionsSetBeforeHlslLegalization(getValidatorOptions(), true);
  EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());
  const std::string diagnostic = getDiagnosticString();
  EXPECT_THAT(diagnostic, HasSubstr("OpFunctionCall Argument <id>"));
  EXPECT_THAT(diagnostic, HasSubstr("type does not match Function <id>"));
}
781 
// Even with the before-HLSL-legalization relaxation turned on, this call must
// be rejected because the array lengths differ:
// - The argument %21 points to %_struct_4. Its only member is a 10-element
//   int array (%_arr_int_uint_10).
// - The formal parameter %24 points to %_struct_17. Its only member is a
//   5-element int array (%_arr_int_uint_5).
// NOTE(review): the formal parameter is also declared in the Function storage
// class, while the argument is a Uniform pointer. The test presumes the
// array-length difference is what produces the diagnostic; confirm that the
// storage-class difference alone would not.
TEST_F(ValidateFunctionCall, LogicallyMismatchedPointersArraySize) {
  const std::string assembly =
      R"(
               OpCapability Shader
               OpMemoryModel Logical GLSL450
               OpEntryPoint GLCompute %1 "main"
               OpExecutionMode %1 LocalSize 1 1 1
               OpSource HLSL 600
               OpDecorate %2 DescriptorSet 0
               OpDecorate %2 Binding 0
               OpDecorate %_arr_int_uint_10 ArrayStride 4
               OpMemberDecorate %_struct_4 0 Offset 0
               OpDecorate %_runtimearr__struct_4 ArrayStride 40
               OpMemberDecorate %_struct_6 0 Offset 0
               OpDecorate %_struct_6 BufferBlock
        %int = OpTypeInt 32 1
      %int_0 = OpConstant %int 0
       %uint = OpTypeInt 32 0
     %uint_0 = OpConstant %uint 0
    %uint_5 = OpConstant %uint 5
    %uint_10 = OpConstant %uint 10
%_arr_int_uint_10 = OpTypeArray %int %uint_10
  %_struct_4 = OpTypeStruct %_arr_int_uint_10
%_runtimearr__struct_4 = OpTypeRuntimeArray %_struct_4
  %_struct_6 = OpTypeStruct %_runtimearr__struct_4
%_ptr_Uniform__struct_6 = OpTypePointer Uniform %_struct_6
       %void = OpTypeVoid
         %14 = OpTypeFunction %void
%_ptr_Uniform__struct_4 = OpTypePointer Uniform %_struct_4
%_arr_int_uint_5 = OpTypeArray %int %uint_5
 %_struct_17 = OpTypeStruct %_arr_int_uint_5
%_ptr_Function__struct_17 = OpTypePointer Function %_struct_17
         %19 = OpTypeFunction %void %_ptr_Function__struct_17
          %2 = OpVariable %_ptr_Uniform__struct_6 Uniform
          %1 = OpFunction %void None %14
         %20 = OpLabel
         %21 = OpAccessChain %_ptr_Uniform__struct_4 %2 %int_0 %uint_0
         %22 = OpFunctionCall %void %23 %21
               OpReturn
               OpFunctionEnd
         %23 = OpFunction %void None %19
         %24 = OpFunctionParameter %_ptr_Function__struct_17
         %25 = OpLabel
               OpReturn
               OpFunctionEnd
)";

  CompileSuccessfully(assembly);
  spvValidatorOptionsSetBeforeHlslLegalization(getValidatorOptions(), true);
  EXPECT_EQ(SPV_ERROR_INVALID_ID, ValidateInstructions());

  // Capture the diagnostic once; both fragments must appear in it.
  const std::string diagnostic = getDiagnosticString();
  EXPECT_THAT(diagnostic, HasSubstr("OpFunctionCall Argument <id>"));
  EXPECT_THAT(diagnostic, HasSubstr("type does not match Function <id>"));
}
838 
839 INSTANTIATE_TEST_SUITE_P(StorageClass, ValidateFunctionCall,
840                          Values("UniformConstant", "Input", "Uniform", "Output",
841                                 "Workgroup", "Private", "Function",
842                                 "PushConstant", "Image", "StorageBuffer",
843                                 "AtomicCounter"));
844 }  // namespace
845 }  // namespace val
846 }  // namespace spvtools
847