1 #include <c10/core/ScalarType.h> 2 #import <nlohmann/json.hpp> 3 4 #include <string> 5 6 namespace torch { 7 namespace jit { 8 namespace mobile { 9 namespace coreml { 10 11 struct TensorSpec { 12 std::string name = ""; 13 c10::ScalarType dtype = c10::ScalarType::Float; 14 }; 15 scalar_type(const std::string & type_string)16static inline c10::ScalarType scalar_type(const std::string& type_string) { 17 if (type_string == "0") { 18 return c10::ScalarType::Float; 19 } else if (type_string == "1") { 20 return c10::ScalarType::Double; 21 } else if (type_string == "2") { 22 return c10::ScalarType::Int; 23 } else if (type_string == "3") { 24 return c10::ScalarType::Long; 25 } 26 return c10::ScalarType::Undefined; 27 } 28 29 } // namespace coreml 30 } // namespace mobile 31 } // namespace jit 32 } // namespace torch 33