#include <c10/util/irange.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <torch/csrc/jit/passes/onnx/helper.h>
#include <torch/csrc/jit/passes/onnx/scalar_type_analysis.h>

namespace torch::jit {

namespace onnx {
using namespace ::c10::onnx;
}

namespace {
const int ONNX_OPSET_14 = 14;

static const std::unordered_map<c10::ScalarType, int, ScalarTypeHashFunction>
    scalarTypeToONNXTypeMap = {
        {c10::kFloat, 1},
        {c10::kByte, 2},
        {c10::kChar, 3},
        {c10::kShort, 5},
        {c10::kInt, 6},
        {c10::kLong, 7},
        {c10::kBool, 9},
        {c10::kHalf, 10},
        {c10::kDouble, 11},
        {c10::kQInt8, 12},
        {c10::kQUInt8, 13},
        {c10::kQInt32, 14},
        {c10::kBFloat16, 15},
        {c10::kFloat8_e4m3fn, 16},
        {c10::kFloat8_e5m2, 17},
        {c10::kFloat8_e4m3fnuz, 18},
        {c10::kFloat8_e5m2fnuz, 19},
};

static int64_t ScalarTypeToONNXType(const c10::ScalarType& st) {
  int64_t onnx_type = -1;
  const auto it = scalarTypeToONNXTypeMap.find(st);
  if (it != scalarTypeToONNXTypeMap.end()) {
    onnx_type = it->second;
  }
  return onnx_type;
}

// For these operators, all inputs and outputs share the same scalar type.
// There is no operator-wise special case handling needed.
static const std::unordered_set<NodeKind> standardOps = {
    onnx::Add,
    onnx::Concat,
    onnx::Div,
    onnx::Gemm,
    onnx::Min,
    onnx::Max,
    onnx::Mod,
    onnx::Mul,
    onnx::Pow,
    onnx::Sub,
    onnx::MatMul,
    onnx::Conv,
};

// For these operators, all inputs share the same scalar type.
// The output scalar type is always Bool.
static const std::unordered_set<NodeKind> comparisonOps = {
    onnx::Greater,
    onnx::Less,
    onnx::Equal,
    onnx::GreaterOrEqual,
    onnx::LessOrEqual,
};

static const std::unordered_set<NodeKind> selectorOps = {onnx::Where};

static bool IsStandardOp(const NodeKind& nkind) {
  return standardOps.find(nkind) != standardOps.end();
}

static bool IsComparisonOp(const NodeKind& nkind) {
  return comparisonOps.find(nkind) != comparisonOps.end();
}

static bool IsSelectorOp(const NodeKind& nkind) {
  return selectorOps.find(nkind) != selectorOps.end();
}

static TensorTypePtr CreateProfiledTensorTypeWithScalarType(
    const TensorTypePtr& typePtr,
    const c10::ScalarType& scalar_type) {
  TORCH_INTERNAL_ASSERT(typePtr != nullptr);
  return typePtr->withScalarType({scalar_type});
}

static bool IsImplicitCastSupported(const NodeKind& nodeKind) {
  return IsStandardOp(nodeKind) || IsComparisonOp(nodeKind) ||
      IsSelectorOp(nodeKind);
}

static std::optional<c10::ScalarType> PromoteScalarTypes(
    const std::vector<c10::ScalarType>& types) {
  if (types.empty()) {
    return std::nullopt;
  }
  auto st = types[0];
  for (const auto i : c10::irange(1, types.size())) {
    st = c10::promoteTypes(st, types[i]);
  }
  return st;
}

// Type promotion between scalars and tensors
// per logic here
// https://pytorch.org/docs/main/tensor_attributes.html#tensor-attributes
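// Illustrative examples of the category rule below (assuming the default
// dtype is float32):
//   - a kLong tensor combined with a wrapped Python float promotes to kFloat,
//     because the scalar's floating-point category outranks the tensor's
//     integral category;
//   - a kHalf tensor combined with a wrapped Python int stays kHalf, because
//     the tensor's category is not outranked by the scalar's.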
static std::optional<c10::ScalarType> PromoteScalarTypesWithCategory(
    const std::vector<c10::ScalarType>& typesFromTensors,
    const std::vector<c10::ScalarType>& typesFromScalars) {
  auto typeFromTensor = PromoteScalarTypes(typesFromTensors);
  auto typeFromScalar = PromoteScalarTypes(typesFromScalars);

  auto getTypeCategory = [](c10::ScalarType t) {
    if (c10::kBool == t) {
      return 1;
    }
    if (c10::isIntegralType(t, /*includeBool=*/false)) {
      return 2;
    }
    if (c10::isFloatingType(t)) {
      return 3;
    }
    return 0;
  };

  if (std::nullopt == typeFromScalar) {
    return typeFromTensor;
  } else if (std::nullopt == typeFromTensor) {
    return typeFromScalar;
  }

  auto typeCategoryFromTensor = getTypeCategory(typeFromTensor.value());
  auto typeCategoryFromScalar = getTypeCategory(typeFromScalar.value());

  if (typeCategoryFromScalar > typeCategoryFromTensor) {
    return typeFromScalar;
  }
  return typeFromTensor;
}

static std::optional<c10::ScalarType> InferExpectedScalarType(const Node* n) {
  std::vector<c10::ScalarType> typesFromTensors;
  std::vector<c10::ScalarType> typesFromScalars;

  auto get_scalar_type =
      [](const Value* input) -> std::optional<at::ScalarType> {
    if (auto* tensor_type = input->type()->castRaw<TensorType>()) {
      return tensor_type->scalarType();
    }
    return std::nullopt;
  };
  auto emplace_type_from_scalar =
      [&typesFromTensors, &typesFromScalars](at::ScalarType scalar_type) {
        // Mimic PyTorch scalar type promotion logic
        // from https://github.com/pytorch/pytorch/issues/9515
        // Quoting:
        //    A Tensor is considered a "wrapped number" if it is
        //    auto-wrapped from a C++ or Python number type. Integer types are
        //    wrapped as 0-dim int64 tensors and floating-point types are
        //    wrapped as 0-dim double tensors.
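        // For example (assuming the default dtype is float32): `x + 2.0`
        // wraps 2.0 as a 0-dim double tensor, but for promotion purposes it
        // contributes the default float32; `x + 2` wraps 2 as a 0-dim int64
        // tensor and contributes kLong; a wrapped bool contributes kBool.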
        auto default_scalar_type =
            at::typeMetaToScalarType(at::get_default_dtype());
        switch (scalar_type) {
          case at::kDouble:
          case at::kFloat:
            // Floating-point numbers wrapped as float32/float64 tensors are
            // considered to have the default dtype, instead of double.
            typesFromScalars.emplace_back(default_scalar_type);
            break;
          case at::kLong:
          case at::kBool:
            // Bool and integer numbers keep their original type.
            typesFromScalars.emplace_back(scalar_type);
            break;
          default:
            // Other types do not come from wrapped numbers,
            // so track them as types from tensors.
            typesFromTensors.emplace_back(scalar_type);
            break;
        }
      };

  size_t input_idx = 0;
  std::for_each(
      n->inputs().begin(), n->inputs().end(), [&](const Value* input) {
        // We skip the 'condition' input (i.e., the first input) in the case
        // of the onnx::Where operator.
        if (IsSelectorOp(n->kind()) && input_idx == 0) {
          input_idx++;
          return;
        }

        auto nkind = input->node()->kind();
        if (nkind == onnx::Gather &&
            input->node()->input(0)->node()->kind() == onnx::Shape) {
          // This is a special pattern generated by code like `dim_size =
          // x.size(0)`. It gets converted to the below ONNX IR graph
          //    %1 : Long() = onnx::Constant[value={0}]()
          //    %2 : Tensor = onnx::Shape(%x)
          //    %dim_size : Long() = onnx::Gather(%2, %1)
          // `dim_size` is treated in PyTorch as Scalar.
          // However, in the ONNX IR graph, it is an output of onnx::Gather,
          // which is by default considered as a tensor.
          typesFromScalars.emplace_back(c10::kLong);
        } else if (nkind == onnx::Constant) {
          auto tensor = input->node()->t(attr::value);
          auto rank = tensor.dim();
          auto scalar_type = tensor.scalar_type();

          if (rank == 0) {
            emplace_type_from_scalar(scalar_type);
          } else {
            typesFromTensors.emplace_back(scalar_type);
          }
        } else if (auto scalar_type = get_scalar_type(input)) {
          auto tensor_type = input->type()->castRaw<TensorType>();
          // A non-null return from get_scalar_type already guarantees
          // that the input has a valid tensor_type.
          TORCH_INTERNAL_ASSERT(nullptr != tensor_type);
          // ONNX models track shape-related computations that were done with
          // Python numbers in PyTorch as tensor computations. This is the
          // only way for ONNX to track them properly, since ONNX only has a
          // tensor type; otherwise the computation result would be baked in
          // statically as a constant, and the model would not work for
          // another input that differs in shape.

          // For type promotion, however, scalars must be treated differently
          // from tensors (see the comment at `emplace_type_from_scalar`).
          // Here we filter out rank 0 tensors and run them through
          // `emplace_type_from_scalar` to decide whether they should be
          // treated as scalars for type promotion.

          // NOTE: this may introduce a regression where a genuine 0-rank
          // tensor is now recognized as a scalar. The downside is that the
          // model may lose accuracy in such cases, because certain
          // computations will happen in lower-precision data types.
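          // For example (illustrative, assuming the default dtype is
          // float32): a genuine 0-dim kDouble input is fed through
          // `emplace_type_from_scalar` and therefore contributes the default
          // kFloat, rather than kDouble, to the promotion; per the documented
          // promotion rules, eager PyTorch would typically have kept kDouble
          // when the other operand is an integral tensor.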
          auto rank = tensor_type->dim();
          if (rank && rank.value() == 0) {
            emplace_type_from_scalar(scalar_type.value());
          } else {
            typesFromTensors.emplace_back(scalar_type.value());
          }

          input_idx++;
        }
      });

  std::optional<c10::ScalarType> st = std::nullopt;
  const auto output_st = get_scalar_type(n->output());

  if (IsComparisonOp(n->kind())) {
    // For comparison ops, always promote the scalar type to the highest among
    // the inputs, regardless of whether each input is a tensor or a scalar.
    typesFromScalars.insert(
        typesFromScalars.end(),
        typesFromTensors.begin(),
        typesFromTensors.end());
    st = PromoteScalarTypes(typesFromScalars);
  } else {
    if (output_st) {
      // If the output scalar type is available, use that.
      st = output_st;
    } else {
      // PyTorch now does implicit type promotion regardless of whether the
      // inputs are tensors or scalars. (Previously only scalars supported
      // implicit casting.)
      // Per the logic here:
      // https://pytorch.org/docs/main/tensor_attributes.html#tensor-attributes
      st = PromoteScalarTypesWithCategory(typesFromTensors, typesFromScalars);
    }
  }

  return st;
}

static std::optional<c10::ScalarType> LowPrecisionCastForStandardOps(
    const Node* n,
    const c10::ScalarType& scalar_type) {
  // Some of the standard ops do not support uint8/int8/int16 types for ONNX
  // opset versions < 14.
  // Fixed by this ONNX PR:
  // https://github.com/onnx/onnx/pull/3334
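  // For example, when exporting with opset 13, an onnx::Add whose expected
  // scalar type is kByte is mapped to kLong here; the add is then computed in
  // int64 and the result is cast back to kByte afterwards (see
  // LowPrecisionCastNodeForStandardOps below).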
  if (n->kind() != onnx::Gemm && IsStandardOp(n->kind()) &&
      (scalar_type == c10::kByte || scalar_type == c10::kChar ||
       scalar_type == c10::kShort)) {
    return c10::kLong;
  }
  return scalar_type;
}

static void UpdateScalarTypeForInputs(
    Node* n,
    const c10::ScalarType& scalar_type) {
  const int64_t onnx_type = ScalarTypeToONNXType(scalar_type);
  if (onnx_type < 0) {
    TORCH_WARN(
        "ONNX Scalar Type Analysis - Scalar type: ",
        c10::toString(scalar_type),
        " of input tensor in operator: ",
        n->kind().toDisplayString(),
        " not supported in ONNX. ");
    return;
  }

  size_t input_idx = 0;
  for (auto input : n->inputs()) {
    auto input_tensor_type = input->type()->cast<TensorType>();
    auto input_scalar_type =
        input_tensor_type ? input_tensor_type->scalarType() : std::nullopt;

    // We skip the 'condition' input (i.e., the first input) in the case of
    // the onnx::Where operator.
    if (IsSelectorOp(n->kind()) && input_idx == 0) {
      input_idx++;
      continue;
    }

    if ((input->node()->kind() == onnx::Constant) ||
        (input_scalar_type && (*input_scalar_type != scalar_type))) {
      if (input->node()->kind() == onnx::Constant) {
        // Fix up the scalar directly instead of inserting a cast operator.
        // TODO: Keep only the else branch once constant_folding is enabled by
        // default.
        at::Tensor val = input->node()->t(attr::value);
        at::Tensor new_val = val.to(scalar_type);
        Node* const_node = n->owningGraph()->create(onnx::Constant);
        const_node->t_(attr::value, new_val);
        const_node->insertBefore(n);
        const_node->output()->setType(TensorType::create(new_val));
        const_node->copyMetadata(n);
        n->replaceInputWith(input, const_node->output());
      } else {
        Node* cast_node = n->owningGraph()->create(onnx::Cast);
        cast_node->addInput(input);
        cast_node->i_(attr::to, onnx_type);
        cast_node->insertBefore(n);
        cast_node->output()->setType(CreateProfiledTensorTypeWithScalarType(
            input_tensor_type, scalar_type));
        cast_node->copyMetadata(n);
        n->replaceInputWith(input, cast_node->output());
      }
    }

    input_idx++;
  }
}

static void UpdateScalarTypeForOutput(
    Node* n,
    const c10::ScalarType& scalar_type) {
  if (auto output_tensor_type = n->output()->type()->cast<TensorType>()) {
    n->output()->setType(CreateProfiledTensorTypeWithScalarType(
        output_tensor_type, scalar_type));
  }
}

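// Append an onnx::Cast after `out` so that downstream consumers see the
// original scalar type again. For instance, if an onnx::Add was computed in
// kLong as a low-precision workaround but the original output type was kByte,
// the result looks roughly like this (illustrative IR sketch):
//   %y_long : Long(*) = onnx::Add(%a, %b)
//   %y : Byte(*) = onnx::Cast[to=2](%y_long)
// All uses of the original output after the Cast node are redirected to the
// Cast's output.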
static void RecoverScalarTypeForOutput(
    Value* out,
    const c10::ScalarType& scalar_type) {
  Node* n = out->node();
  TORCH_INTERNAL_ASSERT(nullptr != n);
  const int64_t onnx_type = ScalarTypeToONNXType(scalar_type);
  Node* cast_node = n->owningGraph()->create(onnx::Cast, 1);
  cast_node->addInput(out);
  cast_node->i_(attr::to, onnx_type);
  cast_node->insertAfter(n);
  cast_node->copyMetadata(n);
  out->replaceAllUsesAfterNodeWith(cast_node, cast_node->output());
}

// This example error was found when exporting the transfo_xl model, which
// uses the add op on uint8 tensors, as below:
//    if self.same_length:
//        all_ones = word_emb.new_ones((qlen, klen), dtype=torch.uint8)
//        mask_len = klen - self.mem_len
//        if mask_len > 0:
//            mask_shift_len = qlen - mask_len
//        else:
//            mask_shift_len = qlen
//        dec_attn_mask = (torch.triu(all_ones, 1 + mlen) + torch.tril(all_ones,
//            -mask_shift_len))[:, :, None]  # -1
//
// `all_ones` is a uint8 tensor, but the calculation of `dec_attn_mask` uses
// the add (+) op and expects a uint8 result. Reference link:
// https://github.com/huggingface/transformers/blob/b020a736c374460af1b34267283f957988350630/src/transformers/models/transfo_xl/modeling_transfo_xl.py#L936
static void LowPrecisionCastNodeForStandardOps(Node* n, int opset_version) {
  TORCH_INTERNAL_ASSERT(n->outputs().size() == 1);
  if (n->output()->type()->cast<TensorType>() == nullptr ||
      n->output()->type()->cast<TensorType>()->scalarType() == std::nullopt) {
    // Skip the low-precision cast if the output scalar type is unknown.
    return;
  }
  auto output_scalar_type =
      n->output()->type()->cast<TensorType>()->scalarType().value();
  for (size_t i = 0; i < n->inputs().size(); ++i) {
    if (n->input(i)->type()->cast<TensorType>() == nullptr ||
        n->input(i)->type()->cast<TensorType>()->scalarType() == std::nullopt) {
      // Skip the low-precision cast if any input's scalar type is unknown.
      return;
    }
    auto input_tensor_type =
        n->input(i)->type()->cast<TensorType>()->scalarType().value();
    TORCH_INTERNAL_ASSERT(output_scalar_type == input_tensor_type);
  }

  // The low-precision limitation is fixed starting with ONNX opset 14, so the
  // workaround is only applied to earlier opsets.
  if (opset_version < ONNX_OPSET_14) {
    auto expected_scalar_type_cast =
        LowPrecisionCastForStandardOps(n, output_scalar_type);
    UpdateScalarTypeForInputs(n, *expected_scalar_type_cast);
    if (output_scalar_type != *expected_scalar_type_cast) {
      // If the input type was changed, cast the output back to the original
      // type.
      RecoverScalarTypeForOutput(n->output(), output_scalar_type);
    }
  }
}

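// Insert explicit onnx::Cast nodes so that all (non-condition) inputs of a
// supported op share one scalar type, and update the output type accordingly.
// Illustrative sketch (names and shapes are made up):
//   Before: %z : Tensor = onnx::Add(%x_float, %y_long)
//   After:  %y_cast : Float(*) = onnx::Cast[to=1](%y_long)
//           %z : Float(*) = onnx::Add(%x_float, %y_cast)
// For comparison ops the inputs are still unified, but the (Bool) output type
// is left untouched.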
static void ImplicitCastNodeForONNX(Node* n) {
  if (IsImplicitCastSupported(n->kind())) {
    auto expected_scalar_type = InferExpectedScalarType(n);
    if (expected_scalar_type) {
      UpdateScalarTypeForInputs(n, *expected_scalar_type);
      if (!IsComparisonOp(n->kind())) {
        UpdateScalarTypeForOutput(n, *expected_scalar_type);
      }
    }
  }
}

static void ImplicitCastForONNX(Block* block) {
  for (auto it = block->nodes().begin(); it != block->nodes().end(); ++it) {
    for (auto sub : it->blocks()) {
      ImplicitCastForONNX(sub);
    }

    ImplicitCastNodeForONNX(*it);
  }
  EliminateDeadCode(
      block, true, DCESideEffectPolicy::ALLOW_DELETING_NODES_WITH_SIDE_EFFECTS);
}

static void LowPrecisionCastForStandardOpsONNX(
    Block* block,
    int opset_version) {
  for (auto it = block->nodes().begin(); it != block->nodes().end(); ++it) {
    for (auto sub : it->blocks()) {
      LowPrecisionCastForStandardOpsONNX(sub, opset_version);
    }

    if (IsStandardOp(it->kind())) {
      LowPrecisionCastNodeForStandardOps(*it, opset_version);
    }
  }
  EliminateDeadCode(
      block, true, DCESideEffectPolicy::ALLOW_DELETING_NODES_WITH_SIDE_EFFECTS);
}
} // anonymous namespace

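// Entry point invoked by the ONNX exporter after the graph has been lowered
// to onnx::* ops. A minimal usage sketch (the graph is assumed to already
// contain ONNX nodes):
//   std::shared_ptr<Graph> graph = ...;
//   ScalarTypeAnalysisForONNX(
//       graph, /*lowprecision_cast=*/true, /*opset_version=*/13);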
void ScalarTypeAnalysisForONNX(
    const std::shared_ptr<Graph>& graph,
    bool lowprecision_cast,
    int opset_version) {
  GRAPH_DUMP("Before ScalarTypeAnalysisForONNX: ", graph);
  ImplicitCastForONNX(graph->block());
  if (lowprecision_cast) {
    LowPrecisionCastForStandardOpsONNX(graph->block(), opset_version);
  }
  GRAPH_DUMP("After ScalarTypeAnalysisForONNX: ", graph);
}

void ScalarTypeAnalysisNodeForONNX(Node* n) {
  ImplicitCastNodeForONNX(n);
}

} // namespace torch::jit