// xref: /aosp_15_r20/external/angle/third_party/spirv-tools/src/source/val/validate_tensor_layout.cpp (revision 8975f5c5ed3d1c378011245431ada316dfb6f244)
// Copyright (c) 2024 NVIDIA Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14 
// Validate instructions that manipulate tensor layout and view objects.
16 
17 #include "source/opcode.h"
18 #include "source/spirv_target_env.h"
19 #include "source/val/instruction.h"
20 #include "source/val/validate.h"
21 #include "source/val/validation_state.h"
22 
23 namespace spvtools {
24 namespace val {
25 namespace {
26 
ValidateTensorLayoutResultTypeNV(ValidationState_t & _,const Instruction * inst)27 spv_result_t ValidateTensorLayoutResultTypeNV(ValidationState_t& _,
28                                               const Instruction* inst) {
29   const auto result_type_index = 0;
30   const auto result_type_id = inst->GetOperandAs<uint32_t>(result_type_index);
31   const auto result_type = _.FindDef(result_type_id);
32 
33   if (!result_type || spv::Op::OpTypeTensorLayoutNV != result_type->opcode()) {
34     return _.diag(SPV_ERROR_INVALID_ID, inst)
35            << spvOpcodeString(inst->opcode()) << " Result Type <id> "
36            << _.getIdName(result_type_id) << " is not a tensor layout type.";
37   }
38   return SPV_SUCCESS;
39 }
40 
ValidateTensorViewResultTypeNV(ValidationState_t & _,const Instruction * inst)41 spv_result_t ValidateTensorViewResultTypeNV(ValidationState_t& _,
42                                             const Instruction* inst) {
43   const auto result_type_index = 0;
44   const auto result_type_id = inst->GetOperandAs<uint32_t>(result_type_index);
45   const auto result_type = _.FindDef(result_type_id);
46 
47   if (!result_type || spv::Op::OpTypeTensorViewNV != result_type->opcode()) {
48     return _.diag(SPV_ERROR_INVALID_ID, inst)
49            << spvOpcodeString(inst->opcode()) << " Result Type <id> "
50            << _.getIdName(result_type_id) << " is not a tensor view type.";
51   }
52   return SPV_SUCCESS;
53 }
54 
ValidateCreateTensorLayoutNV(ValidationState_t & _,const Instruction * inst)55 spv_result_t ValidateCreateTensorLayoutNV(ValidationState_t& _,
56                                           const Instruction* inst) {
57   if (auto error = ValidateTensorLayoutResultTypeNV(_, inst)) return error;
58 
59   return SPV_SUCCESS;
60 }
61 
ValidateCreateTensorViewNV(ValidationState_t & _,const Instruction * inst)62 spv_result_t ValidateCreateTensorViewNV(ValidationState_t& _,
63                                         const Instruction* inst) {
64   if (auto error = ValidateTensorViewResultTypeNV(_, inst)) return error;
65 
66   return SPV_SUCCESS;
67 }
68 
// How many trailing value operands an instruction validated by
// ValidateTensorTypeWithDimValuesNV must carry. Counts are expressed relative
// to the Dim operand of the instruction's Result Type where applicable.
enum ExpectedNumValues {
  DIM = 0,    // Exactly Dim values.
  DIMx2 = 1,  // Exactly 2 * Dim values.
  ONE = 2,    // Exactly one value, regardless of Dim.
  FOUR = 3,   // Exactly four values, regardless of Dim.
};
75 
ValidateTensorTypeWithDimValuesNV(ValidationState_t & _,const Instruction * inst,ExpectedNumValues expected,bool is_view)76 spv_result_t ValidateTensorTypeWithDimValuesNV(ValidationState_t& _,
77                                                const Instruction* inst,
78                                                ExpectedNumValues expected,
79                                                bool is_view) {
80   std::string type_str;
81   if (is_view) {
82     if (auto error = ValidateTensorViewResultTypeNV(_, inst)) return error;
83     type_str = "TensorView";
84   } else {
85     if (auto error = ValidateTensorLayoutResultTypeNV(_, inst)) return error;
86     type_str = "TensorLayout";
87   }
88 
89   const auto result_type_id = inst->GetOperandAs<uint32_t>(0);
90   const auto tensor_id = inst->GetOperandAs<uint32_t>(2);
91   const auto tensor = _.FindDef(tensor_id);
92   if (!tensor || result_type_id != tensor->type_id()) {
93     return _.diag(SPV_ERROR_INVALID_ID, inst)
94            << spvOpcodeString(inst->opcode()) << " Result Type <id> "
95            << _.getIdName(result_type_id) << " does not match " << type_str
96            << " type.";
97   }
98 
99   const auto num_values = inst->operands().size() - 3;
100 
101   const auto result_type = _.FindDef(result_type_id);
102   const auto dim_index = 1;
103   const auto dim_id = result_type->GetOperandAs<uint32_t>(dim_index);
104   uint64_t dim_value;
105   if (_.EvalConstantValUint64(dim_id, &dim_value)) {
106     uint64_t expected_num_values = 0;
107     switch (expected) {
108       case DIM:
109         expected_num_values = dim_value;
110         break;
111       case DIMx2:
112         expected_num_values = dim_value * 2;
113         break;
114       case ONE:
115         expected_num_values = 1;
116         break;
117       case FOUR:
118         expected_num_values = 4;
119         break;
120     }
121 
122     if (num_values != expected_num_values) {
123       return _.diag(SPV_ERROR_INVALID_ID, inst)
124              << spvOpcodeString(inst->opcode())
125              << " unexpected number of operands.";
126     }
127   }
128 
129   for (uint32_t i = 0; i < num_values; ++i) {
130     const auto val_id = inst->GetOperandAs<uint32_t>(i + 3);
131     const auto val = _.FindDef(val_id);
132     if (!val || !_.IsIntScalarType(val->type_id()) ||
133         _.GetBitWidth(val->type_id()) != 32) {
134       return _.diag(SPV_ERROR_INVALID_ID, inst)
135              << spvOpcodeString(inst->opcode()) << " operand <id> "
136              << _.getIdName(val_id) << " is not a 32-bit integer.";
137     }
138   }
139 
140   return SPV_SUCCESS;
141 }
142 
143 }  // namespace
144 
TensorLayoutPass(ValidationState_t & _,const Instruction * inst)145 spv_result_t TensorLayoutPass(ValidationState_t& _, const Instruction* inst) {
146   switch (inst->opcode()) {
147     case spv::Op::OpCreateTensorLayoutNV:
148       if (auto error = ValidateCreateTensorLayoutNV(_, inst)) return error;
149       break;
150     case spv::Op::OpCreateTensorViewNV:
151       if (auto error = ValidateCreateTensorViewNV(_, inst)) return error;
152       break;
153     case spv::Op::OpTensorLayoutSetBlockSizeNV:
154     case spv::Op::OpTensorLayoutSetDimensionNV:
155     case spv::Op::OpTensorLayoutSetStrideNV:
156       if (auto error = ValidateTensorTypeWithDimValuesNV(_, inst, DIM, false))
157         return error;
158       break;
159     case spv::Op::OpTensorLayoutSliceNV:
160       if (auto error = ValidateTensorTypeWithDimValuesNV(_, inst, DIMx2, false))
161         return error;
162       break;
163     case spv::Op::OpTensorLayoutSetClampValueNV:
164       if (auto error = ValidateTensorTypeWithDimValuesNV(_, inst, ONE, false))
165         return error;
166       break;
167     case spv::Op::OpTensorViewSetDimensionNV:
168     case spv::Op::OpTensorViewSetStrideNV:
169       if (auto error = ValidateTensorTypeWithDimValuesNV(_, inst, DIM, true))
170         return error;
171       break;
172     case spv::Op::OpTensorViewSetClipNV:
173       if (auto error = ValidateTensorTypeWithDimValuesNV(_, inst, FOUR, true))
174         return error;
175       break;
176     default:
177       break;
178   }
179 
180   return SPV_SUCCESS;
181 }
182 
183 }  // namespace val
184 }  // namespace spvtools
185