1 // Copyright (c) 2024 NVIDIA Corporation
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 // Validate instructions that manipulate tensor layout and view objects
16
17 #include "source/opcode.h"
18 #include "source/spirv_target_env.h"
19 #include "source/val/instruction.h"
20 #include "source/val/validate.h"
21 #include "source/val/validation_state.h"
22
23 namespace spvtools {
24 namespace val {
25 namespace {
26
ValidateTensorLayoutResultTypeNV(ValidationState_t & _,const Instruction * inst)27 spv_result_t ValidateTensorLayoutResultTypeNV(ValidationState_t& _,
28 const Instruction* inst) {
29 const auto result_type_index = 0;
30 const auto result_type_id = inst->GetOperandAs<uint32_t>(result_type_index);
31 const auto result_type = _.FindDef(result_type_id);
32
33 if (!result_type || spv::Op::OpTypeTensorLayoutNV != result_type->opcode()) {
34 return _.diag(SPV_ERROR_INVALID_ID, inst)
35 << spvOpcodeString(inst->opcode()) << " Result Type <id> "
36 << _.getIdName(result_type_id) << " is not a tensor layout type.";
37 }
38 return SPV_SUCCESS;
39 }
40
ValidateTensorViewResultTypeNV(ValidationState_t & _,const Instruction * inst)41 spv_result_t ValidateTensorViewResultTypeNV(ValidationState_t& _,
42 const Instruction* inst) {
43 const auto result_type_index = 0;
44 const auto result_type_id = inst->GetOperandAs<uint32_t>(result_type_index);
45 const auto result_type = _.FindDef(result_type_id);
46
47 if (!result_type || spv::Op::OpTypeTensorViewNV != result_type->opcode()) {
48 return _.diag(SPV_ERROR_INVALID_ID, inst)
49 << spvOpcodeString(inst->opcode()) << " Result Type <id> "
50 << _.getIdName(result_type_id) << " is not a tensor view type.";
51 }
52 return SPV_SUCCESS;
53 }
54
ValidateCreateTensorLayoutNV(ValidationState_t & _,const Instruction * inst)55 spv_result_t ValidateCreateTensorLayoutNV(ValidationState_t& _,
56 const Instruction* inst) {
57 if (auto error = ValidateTensorLayoutResultTypeNV(_, inst)) return error;
58
59 return SPV_SUCCESS;
60 }
61
ValidateCreateTensorViewNV(ValidationState_t & _,const Instruction * inst)62 spv_result_t ValidateCreateTensorViewNV(ValidationState_t& _,
63 const Instruction* inst) {
64 if (auto error = ValidateTensorViewResultTypeNV(_, inst)) return error;
65
66 return SPV_SUCCESS;
67 }
68
// How many trailing value operands an instruction validated by
// ValidateTensorTypeWithDimValuesNV must carry. The DIM-based counts are
// relative to the Dim operand of the instruction's result type.
enum ExpectedNumValues {
  DIM = 0,    // exactly Dim values
  DIMx2 = 1,  // exactly 2 * Dim values
  ONE = 2,    // exactly one value
  FOUR = 3    // exactly four values
};
75
ValidateTensorTypeWithDimValuesNV(ValidationState_t & _,const Instruction * inst,ExpectedNumValues expected,bool is_view)76 spv_result_t ValidateTensorTypeWithDimValuesNV(ValidationState_t& _,
77 const Instruction* inst,
78 ExpectedNumValues expected,
79 bool is_view) {
80 std::string type_str;
81 if (is_view) {
82 if (auto error = ValidateTensorViewResultTypeNV(_, inst)) return error;
83 type_str = "TensorView";
84 } else {
85 if (auto error = ValidateTensorLayoutResultTypeNV(_, inst)) return error;
86 type_str = "TensorLayout";
87 }
88
89 const auto result_type_id = inst->GetOperandAs<uint32_t>(0);
90 const auto tensor_id = inst->GetOperandAs<uint32_t>(2);
91 const auto tensor = _.FindDef(tensor_id);
92 if (!tensor || result_type_id != tensor->type_id()) {
93 return _.diag(SPV_ERROR_INVALID_ID, inst)
94 << spvOpcodeString(inst->opcode()) << " Result Type <id> "
95 << _.getIdName(result_type_id) << " does not match " << type_str
96 << " type.";
97 }
98
99 const auto num_values = inst->operands().size() - 3;
100
101 const auto result_type = _.FindDef(result_type_id);
102 const auto dim_index = 1;
103 const auto dim_id = result_type->GetOperandAs<uint32_t>(dim_index);
104 uint64_t dim_value;
105 if (_.EvalConstantValUint64(dim_id, &dim_value)) {
106 uint64_t expected_num_values = 0;
107 switch (expected) {
108 case DIM:
109 expected_num_values = dim_value;
110 break;
111 case DIMx2:
112 expected_num_values = dim_value * 2;
113 break;
114 case ONE:
115 expected_num_values = 1;
116 break;
117 case FOUR:
118 expected_num_values = 4;
119 break;
120 }
121
122 if (num_values != expected_num_values) {
123 return _.diag(SPV_ERROR_INVALID_ID, inst)
124 << spvOpcodeString(inst->opcode())
125 << " unexpected number of operands.";
126 }
127 }
128
129 for (uint32_t i = 0; i < num_values; ++i) {
130 const auto val_id = inst->GetOperandAs<uint32_t>(i + 3);
131 const auto val = _.FindDef(val_id);
132 if (!val || !_.IsIntScalarType(val->type_id()) ||
133 _.GetBitWidth(val->type_id()) != 32) {
134 return _.diag(SPV_ERROR_INVALID_ID, inst)
135 << spvOpcodeString(inst->opcode()) << " operand <id> "
136 << _.getIdName(val_id) << " is not a 32-bit integer.";
137 }
138 }
139
140 return SPV_SUCCESS;
141 }
142
143 } // namespace
144
TensorLayoutPass(ValidationState_t & _,const Instruction * inst)145 spv_result_t TensorLayoutPass(ValidationState_t& _, const Instruction* inst) {
146 switch (inst->opcode()) {
147 case spv::Op::OpCreateTensorLayoutNV:
148 if (auto error = ValidateCreateTensorLayoutNV(_, inst)) return error;
149 break;
150 case spv::Op::OpCreateTensorViewNV:
151 if (auto error = ValidateCreateTensorViewNV(_, inst)) return error;
152 break;
153 case spv::Op::OpTensorLayoutSetBlockSizeNV:
154 case spv::Op::OpTensorLayoutSetDimensionNV:
155 case spv::Op::OpTensorLayoutSetStrideNV:
156 if (auto error = ValidateTensorTypeWithDimValuesNV(_, inst, DIM, false))
157 return error;
158 break;
159 case spv::Op::OpTensorLayoutSliceNV:
160 if (auto error = ValidateTensorTypeWithDimValuesNV(_, inst, DIMx2, false))
161 return error;
162 break;
163 case spv::Op::OpTensorLayoutSetClampValueNV:
164 if (auto error = ValidateTensorTypeWithDimValuesNV(_, inst, ONE, false))
165 return error;
166 break;
167 case spv::Op::OpTensorViewSetDimensionNV:
168 case spv::Op::OpTensorViewSetStrideNV:
169 if (auto error = ValidateTensorTypeWithDimValuesNV(_, inst, DIM, true))
170 return error;
171 break;
172 case spv::Op::OpTensorViewSetClipNV:
173 if (auto error = ValidateTensorTypeWithDimValuesNV(_, inst, FOUR, true))
174 return error;
175 break;
176 default:
177 break;
178 }
179
180 return SPV_SUCCESS;
181 }
182
183 } // namespace val
184 } // namespace spvtools
185