1 #pragma once 2 3 #include <ATen/core/Tensor.h> 4 #include <ATen/core/IListRef.h> 5 6 namespace at::native { 7 8 struct ResultTypeState { 9 c10::ScalarType dimResult = ScalarType::Undefined; 10 c10::ScalarType wrappedResult = ScalarType::Undefined; 11 c10::ScalarType zeroResult = ScalarType::Undefined; 12 }; 13 14 TORCH_API ResultTypeState update_result_type_state(const Tensor& tensor, const ResultTypeState& in_state); 15 TORCH_API ResultTypeState update_result_type_state(const Scalar& scalar, const ResultTypeState& in_state); 16 TORCH_API ScalarType result_type(const ResultTypeState& state); 17 18 TORCH_API ScalarType result_type(ITensorListRef tensors); 19 20 } // namespace at::native 21