#include <c10/util/irange.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <torch/csrc/jit/passes/onnx/helper.h>
#include <torch/csrc/jit/passes/onnx/scalar_type_analysis.h>

namespace torch::jit {

namespace onnx {
using namespace ::c10::onnx;
}

namespace {
const int ONNX_OPSET_14 = 14;

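// Note: for the non-quantized dtypes up through kDouble, these integer codes
// match ONNX's TensorProto_DataType enum values (e.g. FLOAT = 1, UINT8 = 2,
// INT64 = 7, BOOL = 9, DOUBLE = 11).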
static const std::unordered_map<c10::ScalarType, int, ScalarTypeHashFunction>
    scalarTypeToONNXTypeMap = {
        {c10::kFloat, 1},
        {c10::kByte, 2},
        {c10::kChar, 3},
        {c10::kShort, 5},
        {c10::kInt, 6},
        {c10::kLong, 7},
        {c10::kBool, 9},
        {c10::kHalf, 10},
        {c10::kDouble, 11},
        {c10::kQInt8, 12},
        {c10::kQUInt8, 13},
        {c10::kQInt32, 14},
        {c10::kBFloat16, 15},
        {c10::kFloat8_e4m3fn, 16},
        {c10::kFloat8_e5m2, 17},
        {c10::kFloat8_e4m3fnuz, 18},
        {c10::kFloat8_e5m2fnuz, 19},
};

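// Returns the ONNX data type code for `st`, or -1 if the scalar type has no
// entry in the map above. Callers such as UpdateScalarTypeForInputs treat a
// negative code as "unsupported in ONNX" and warn instead of inserting a Cast.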
static int64_t ScalarTypeToONNXType(const c10::ScalarType& st) {
  int64_t onnx_type = -1;
  const auto it = scalarTypeToONNXTypeMap.find(st);
  if (it != scalarTypeToONNXTypeMap.end()) {
    onnx_type = it->second;
  }
  return onnx_type;
}

// For these operators, all inputs and outputs share the same scalar type.
// There is no operator-wise special case handling needed.
static const std::unordered_set<NodeKind> standardOps = {
    onnx::Add,
    onnx::Concat,
    onnx::Div,
    onnx::Gemm,
    onnx::Min,
    onnx::Max,
    onnx::Mod,
    onnx::Mul,
    onnx::Pow,
    onnx::Sub,
    onnx::MatMul,
    onnx::Conv,
};

// For these operators, all inputs share the same scalar type.
// The output scalar type is always Bool.
static const std::unordered_set<NodeKind> comparisonOps = {
    onnx::Greater,
    onnx::Less,
    onnx::Equal,
    onnx::GreaterOrEqual,
    onnx::LessOrEqual,
};

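// For these operators, the first input is a Bool condition that keeps its own
// type; only the remaining inputs take part in scalar type promotion.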
static const std::unordered_set<NodeKind> selectorOps = {onnx::Where};

static bool IsStandardOp(const NodeKind& nkind) {
  return standardOps.find(nkind) != standardOps.end();
}

static bool IsComparisonOp(const NodeKind& nkind) {
  return comparisonOps.find(nkind) != comparisonOps.end();
}

static bool IsSelectorOp(const NodeKind& nkind) {
  return selectorOps.find(nkind) != selectorOps.end();
}

static TensorTypePtr CreateProfiledTensorTypeWithScalarType(
    const TensorTypePtr& typePtr,
    const c10::ScalarType& scalar_type) {
  TORCH_INTERNAL_ASSERT(typePtr != nullptr);
  return typePtr->withScalarType({scalar_type});
}

static bool IsImplicitCastSupported(const NodeKind& nodeKind) {
  return IsStandardOp(nodeKind) || IsComparisonOp(nodeKind) ||
      IsSelectorOp(nodeKind);
}

static std::optional<c10::ScalarType> PromoteScalarTypes(
    const std::vector<c10::ScalarType>& types) {
  if (types.empty()) {
    return std::nullopt;
  }
  auto st = types[0];
  for (const auto i : c10::irange(1, types.size())) {
    st = c10::promoteTypes(st, types[i]);
  }
  return st;
}

// Type promotion between scalars and tensors
// per logic here
// https://pytorch.org/docs/main/tensor_attributes.html#tensor-attributes
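// For example, an int64 tensor combined with a Python float scalar promotes to
// the default floating point dtype, whereas a float tensor combined with an
// int scalar keeps the tensor's dtype: the scalar's category only wins when it
// is strictly higher (bool < integral < floating point).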
static std::optional<c10::ScalarType> PromoteScalarTypesWithCategory(
    const std::vector<c10::ScalarType>& typesFromTensors,
    const std::vector<c10::ScalarType>& typesFromScalars) {
  auto typeFromTensor = PromoteScalarTypes(typesFromTensors);
  auto typeFromScalar = PromoteScalarTypes(typesFromScalars);

  auto getTypeCategory = [](c10::ScalarType t) {
    if (c10::kBool == t) {
      return 1;
    }
    if (c10::isIntegralType(t, /*includeBool=*/false)) {
      return 2;
    }
    if (c10::isFloatingType(t)) {
      return 3;
    }
    return 0;
  };

  if (std::nullopt == typeFromScalar) {
    return typeFromTensor;
  } else if (std::nullopt == typeFromTensor) {
    return typeFromScalar;
  }

  auto typeCategoryFromTensor = getTypeCategory(typeFromTensor.value());
  auto typeCategoryFromScalar = getTypeCategory(typeFromScalar.value());

  if (typeCategoryFromScalar > typeCategoryFromTensor) {
    return typeFromScalar;
  }
  return typeFromTensor;
}

static std::optional<c10::ScalarType> InferExpectedScalarType(const Node* n) {
  std::vector<c10::ScalarType> typesFromTensors;
  std::vector<c10::ScalarType> typesFromScalars;

  auto get_scalar_type =
      [](const Value* input) -> std::optional<at::ScalarType> {
    if (auto* tensor_type = input->type()->castRaw<TensorType>()) {
      return tensor_type->scalarType();
    }
    return std::nullopt;
  };
  auto emplace_type_from_scalar =
      [&typesFromTensors, &typesFromScalars](at::ScalarType scalar_type) {
        // Mimic PyTorch scalar type promotion logic
        // from https://github.com/pytorch/pytorch/issues/9515
        // Quoting:
        //    A Tensor is considered a "wrapped number" if it is
        //    auto-wrapped from a C++ or Python number type. Integer types are
        //    wrapped as 0-dim int64 tensors and floating-point types are
        //    wrapped as 0-dim double tensors.
        auto default_scalar_type =
            at::typeMetaToScalarType(at::get_default_dtype());
        switch (scalar_type) {
          case at::kDouble:
          case at::kFloat:
            // Floating-point numbers wrapped as float32/float64 tensors are
            // considered to have the default dtype, not double.
            typesFromScalars.emplace_back(default_scalar_type);
            break;
          case at::kLong:
          case at::kBool:
            // Bool and integer numbers keep their own type.
            typesFromScalars.emplace_back(scalar_type);
            break;
          default:
            // Other types do not come from wrapped numbers,
            // so track them as types from tensors.
            typesFromTensors.emplace_back(scalar_type);
            break;
        }
      };
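  // For instance, the scalar in the Python expression `x + 1` typically
  // reaches ONNX export as a 0-dim int64 Constant and is recorded in
  // typesFromScalars as kLong, while the scalar in `x + 1.0` is recorded as
  // the default dtype (usually float32) rather than double.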

  size_t input_idx = 0;
  std::for_each(
      n->inputs().begin(), n->inputs().end(), [&](const Value* input) {
        // We skip the 'condition' input (i.e., the first input) in case of
        // the onnx::Where operator.
        if (IsSelectorOp(n->kind()) && input_idx == 0) {
          input_idx++;
          return;
        }

        auto nkind = input->node()->kind();
        if (nkind == onnx::Gather &&
            input->node()->input(0)->node()->kind() == onnx::Shape) {
          // This is a special pattern generated by code like `dim_size =
          // x.size(0)`. It gets converted to the ONNX IR graph below:
          //    %1 : Long() = onnx::Constant[value={0}]()
          //    %2 : Tensor = onnx::Shape(%x)
          //    %dim_size : Long() = onnx::Gather(%2, %1)
          // `dim_size` is treated in PyTorch as a Scalar.
          // However, in the ONNX IR graph, it is an output of onnx::Gather,
          // which is by default considered a tensor.
          typesFromScalars.emplace_back(c10::kLong);
        } else if (nkind == onnx::Constant) {
          auto tensor = input->node()->t(attr::value);
          auto rank = tensor.dim();
          auto scalar_type = tensor.scalar_type();

          if (rank == 0) {
            emplace_type_from_scalar(scalar_type);
          } else {
            typesFromTensors.emplace_back(scalar_type);
          }
        } else if (auto scalar_type = get_scalar_type(input)) {
          auto tensor_type = input->type()->castRaw<TensorType>();
          // get_scalar_type returning a non-null value already guarantees
          // that the input has a valid tensor_type.
          TORCH_INTERNAL_ASSERT(nullptr != tensor_type);
          // ONNX models track shape-related computations that PyTorch performs
          // on Python numbers as tensor computations. This is the only way for
          // ONNX to track them properly, since ONNX only has a tensor type;
          // otherwise the computation result would be baked in statically as a
          // constant, and the model would not work for another input that
          // differs in shape.

          // For type promotion, however, scalars should be treated differently
          // from tensors. More info regarding type promotion logic is in the
          // comment at `emplace_type_from_scalar`. Here we filter out rank 0
          // tensors and run them through `emplace_type_from_scalar` to decide
          // whether they should be considered scalars for type promotion.

          // NOTE that this might introduce a regression where a REAL 0-rank
          // tensor is now recognized as a scalar. The downside is that the
          // model may lose accuracy for these cases, as certain computations
          // will happen in lower-precision data types.
          auto rank = tensor_type->dim();
          if (rank && rank.value() == 0) {
            emplace_type_from_scalar(scalar_type.value());
          } else {
            typesFromTensors.emplace_back(scalar_type.value());
          }

          input_idx++;
        }
      });

  std::optional<c10::ScalarType> st = std::nullopt;
  const auto output_st = get_scalar_type(n->output());

  if (IsComparisonOp(n->kind())) {
    // For comparison ops, always promote the scalar type to the highest among
    // the inputs, regardless of whether an input is a tensor or a scalar.
    typesFromScalars.insert(
        typesFromScalars.end(),
        typesFromTensors.begin(),
        typesFromTensors.end());
    st = PromoteScalarTypes(typesFromScalars);
  } else {
    if (output_st) {
      // If the output scalar type is available, use that.
      st = output_st;
    } else {
      // PyTorch now does implicit type promotion regardless of whether the
      // inputs are tensors or scalars (previously only scalars supported
      // implicit casting).
      // Per the logic here:
      // https://pytorch.org/docs/main/tensor_attributes.html#tensor-attributes
      st = PromoteScalarTypesWithCategory(typesFromTensors, typesFromScalars);
    }
  }

  return st;
}

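// For example, at opset 13 an onnx::Add over two uint8 tensors maps to int64
// below; LowPrecisionCastNodeForStandardOps then casts the inputs up and
// RecoverScalarTypeForOutput casts the result back to uint8.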
static std::optional<c10::ScalarType> LowPrecisionCastForStandardOps(
    const Node* n,
    const c10::ScalarType& scalar_type) {
  // Some standardOps do not support uint8/int8/int16 types for ONNX
  // opset versions < 14.
  // Fixed by this ONNX PR:
  // https://github.com/onnx/onnx/pull/3334
  if (n->kind() != onnx::Gemm && IsStandardOp(n->kind()) &&
      (scalar_type == c10::kByte || scalar_type == c10::kChar ||
       scalar_type == c10::kShort)) {
    return c10::kLong;
  }
  return scalar_type;
}

static void UpdateScalarTypeForInputs(
    Node* n,
    const c10::ScalarType& scalar_type) {
  const int64_t onnx_type = ScalarTypeToONNXType(scalar_type);
  if (onnx_type < 0) {
    TORCH_WARN(
        "ONNX Scalar Type Analysis - Scalar type: ",
        c10::toString(scalar_type),
        " of input tensor in operator: ",
        n->kind().toDisplayString(),
        " not supported in ONNX. ");
    return;
  }

  size_t input_idx = 0;
  for (auto input : n->inputs()) {
    auto input_tensor_type = input->type()->cast<TensorType>();
    auto input_scalar_type =
        input_tensor_type ? input_tensor_type->scalarType() : std::nullopt;

    // We skip the 'condition' input (i.e., the first input) in case of
    // the onnx::Where operator.
    if (IsSelectorOp(n->kind()) && input_idx == 0) {
      input_idx++;
      continue;
    }

    if ((input->node()->kind() == onnx::Constant) ||
        (input_scalar_type && (*input_scalar_type != scalar_type))) {
      if (input->node()->kind() == onnx::Constant) {
        // Fix up the scalar directly instead of inserting a cast operator.
        // TODO: Keep only the else branch once constant_folding is enabled by
        // default.
        at::Tensor val = input->node()->t(attr::value);
        at::Tensor new_val = val.to(scalar_type);
        Node* const_node = n->owningGraph()->create(onnx::Constant);
        const_node->t_(attr::value, new_val);
        const_node->insertBefore(n);
        const_node->output()->setType(TensorType::create(new_val));
        const_node->copyMetadata(n);
        n->replaceInputWith(input, const_node->output());
      } else {
        Node* cast_node = n->owningGraph()->create(onnx::Cast);
        cast_node->addInput(input);
        cast_node->i_(attr::to, onnx_type);
        cast_node->insertBefore(n);
        cast_node->output()->setType(CreateProfiledTensorTypeWithScalarType(
            input_tensor_type, scalar_type));
        cast_node->copyMetadata(n);
        n->replaceInputWith(input, cast_node->output());
      }
    }

    input_idx++;
  }
}

static void UpdateScalarTypeForOutput(
    Node* n,
    const c10::ScalarType& scalar_type) {
  if (auto output_tensor_type = n->output()->type()->cast<TensorType>()) {
    n->output()->setType(CreateProfiledTensorTypeWithScalarType(
        output_tensor_type, scalar_type));
  }
}

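// Inserts an onnx::Cast back to `scalar_type` right after the node that
// produces `out`, and reroutes all downstream uses of `out` to the cast
// output.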
static void RecoverScalarTypeForOutput(
    Value* out,
    const c10::ScalarType& scalar_type) {
  Node* n = out->node();
  TORCH_INTERNAL_ASSERT(nullptr != n);
  const int64_t onnx_type = ScalarTypeToONNXType(scalar_type);
  Node* cast_node = n->owningGraph()->create(onnx::Cast, 1);
  cast_node->addInput(out);
  cast_node->i_(attr::to, onnx_type);
  cast_node->insertAfter(n);
  cast_node->copyMetadata(n);
  out->replaceAllUsesAfterNodeWith(cast_node, cast_node->output());
}

// This example error was found when exporting the transfo_xl model, which uses
// the add op on uint8 tensors, as below:
// if self.same_length:
//     all_ones = word_emb.new_ones((qlen, klen), dtype=torch.uint8)
//     mask_len = klen - self.mem_len
//     if mask_len > 0:
//         mask_shift_len = qlen - mask_len
//     else:
//         mask_shift_len = qlen
//     dec_attn_mask = (torch.triu(all_ones, 1 + mlen) + torch.tril(all_ones,
//     -mask_shift_len))[:, :, None]  # -1
//
// `all_ones` is a uint8 tensor, and the computation of `dec_attn_mask` uses
// the add(+) op to get a uint8 result. Reference link:
// https://github.com/huggingface/transformers/blob/b020a736c374460af1b34267283f957988350630/src/transformers/models/transfo_xl/modeling_transfo_xl.py#L936
static void LowPrecisionCastNodeForStandardOps(Node* n, int opset_version) {
  TORCH_INTERNAL_ASSERT(n->outputs().size() == 1);
  if (n->output()->type()->cast<TensorType>() == nullptr ||
      n->output()->type()->cast<TensorType>()->scalarType() == std::nullopt) {
    // Skip LowPrecisionCast if the op output type is null.
    return;
  }
  auto output_scalar_type =
      n->output()->type()->cast<TensorType>()->scalarType().value();
  for (size_t i = 0; i < n->inputs().size(); ++i) {
    if (n->input(i)->type()->cast<TensorType>() == nullptr ||
        n->input(i)->type()->cast<TensorType>()->scalarType() == std::nullopt) {
      // Skip LowPrecisionCast if any op input type is null.
      return;
    }
    auto input_tensor_type =
        n->input(i)->type()->cast<TensorType>()->scalarType().value();
    TORCH_INTERNAL_ASSERT(output_scalar_type == input_tensor_type);
  }

  // The low-precision problem is fixed in ONNX opset 14.
  if (opset_version < ONNX_OPSET_14) {
    auto expected_scalar_type_cast =
        LowPrecisionCastForStandardOps(n, output_scalar_type);
    UpdateScalarTypeForInputs(n, *expected_scalar_type_cast);
    if (output_scalar_type != *expected_scalar_type_cast) {
      // If the input type was changed, convert the output back to the original
      // type.
      RecoverScalarTypeForOutput(n->output(), output_scalar_type);
    }
  }
}

static void ImplicitCastNodeForONNX(Node* n) {
  if (IsImplicitCastSupported(n->kind())) {
    auto expected_scalar_type = InferExpectedScalarType(n);
    if (expected_scalar_type) {
      UpdateScalarTypeForInputs(n, *expected_scalar_type);
      if (!IsComparisonOp(n->kind())) {
        UpdateScalarTypeForOutput(n, *expected_scalar_type);
      }
    }
  }
}

static void ImplicitCastForONNX(Block* block) {
  for (auto it = block->nodes().begin(); it != block->nodes().end(); ++it) {
    for (auto sub : it->blocks()) {
      ImplicitCastForONNX(sub);
    }

    ImplicitCastNodeForONNX(*it);
  }
  EliminateDeadCode(
      block, true, DCESideEffectPolicy::ALLOW_DELETING_NODES_WITH_SIDE_EFFECTS);
}

static void LowPrecisionCastForStandardOpsONNX(
    Block* block,
    int opset_version) {
  for (auto it = block->nodes().begin(); it != block->nodes().end(); ++it) {
    for (auto sub : it->blocks()) {
      LowPrecisionCastForStandardOpsONNX(sub, opset_version);
    }

    if (IsStandardOp(it->kind())) {
      LowPrecisionCastNodeForStandardOps(*it, opset_version);
    }
  }
  EliminateDeadCode(
      block, true, DCESideEffectPolicy::ALLOW_DELETING_NODES_WITH_SIDE_EFFECTS);
}
} // anonymous namespace

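// Entry point for the ONNX export pipeline; it is also exposed to Python
// (presumably as torch._C._jit_pass_onnx_scalar_type_analysis).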
void ScalarTypeAnalysisForONNX(
    const std::shared_ptr<Graph>& graph,
    bool lowprecision_cast,
    int opset_version) {
  GRAPH_DUMP("Before ScalarTypeAnalysisForONNX: ", graph);
  ImplicitCastForONNX(graph->block());
  if (lowprecision_cast) {
    LowPrecisionCastForStandardOpsONNX(graph->block(), opset_version);
  }
  GRAPH_DUMP("After ScalarTypeAnalysisForONNX: ", graph);
}

void ScalarTypeAnalysisNodeForONNX(Node* n) {
  ImplicitCastNodeForONNX(n);
}

} // namespace torch::jit