xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/backends/coreml/objc/PTMCoreMLTensorSpec.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <c10/core/ScalarType.h>
2 #import <nlohmann/json.hpp>
3 
4 #include <string>
5 
6 namespace torch {
7 namespace jit {
8 namespace mobile {
9 namespace coreml {
10 
11 struct TensorSpec {
12   std::string name = "";
13   c10::ScalarType dtype = c10::ScalarType::Float;
14 };
15 
scalar_type(const std::string & type_string)16 static inline c10::ScalarType scalar_type(const std::string& type_string) {
17   if (type_string == "0") {
18     return c10::ScalarType::Float;
19   } else if (type_string == "1") {
20     return c10::ScalarType::Double;
21   } else if (type_string == "2") {
22     return c10::ScalarType::Int;
23   } else if (type_string == "3") {
24     return c10::ScalarType::Long;
25   }
26   return c10::ScalarType::Undefined;
27 }
28 
29 } // namespace coreml
30 } // namespace mobile
31 } // namespace jit
32 } // namespace torch
33