//===- TensorSpec.cpp - tensor type abstraction ---------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Implementation file for the abstraction of a tensor type, and JSON loading
// utils.
//
//===----------------------------------------------------------------------===//
#include "llvm/Config/config.h"

#include "llvm/ADT/Twine.h"
#include "llvm/Analysis/TensorSpec.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/JSON.h"
#include "llvm/Support/ManagedStatic.h"
#include "llvm/Support/raw_ostream.h"
#include <array>
#include <cassert>
#include <numeric>

using namespace llvm;

namespace llvm {

#define TFUTILS_GETDATATYPE_IMPL(T, E)                                         \
  template <> TensorType TensorSpec::getDataType<T>() { return TensorType::E; }

SUPPORTED_TENSOR_TYPES(TFUTILS_GETDATATYPE_IMPL)

#undef TFUTILS_GETDATATYPE_IMPL
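
// The X-macro above stamps out one getDataType specialization per supported
// element type. For illustration only, assuming SUPPORTED_TENSOR_TYPES
// contains an entry such as M(float, Float), one expansion would be:
//
//   template <> TensorType TensorSpec::getDataType<float>() {
//     return TensorType::Float;
//   }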

static std::array<std::string, static_cast<size_t>(TensorType::Total)>
    TensorTypeNames{"INVALID",
#define TFUTILS_GETNAME_IMPL(T, _) #T,
                    SUPPORTED_TENSOR_TYPES(TFUTILS_GETNAME_IMPL)
#undef TFUTILS_GETNAME_IMPL
};

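// toString relies on TensorTypeNames mirroring the TensorType enum: index 0 is
// "INVALID" and the remaining entries are the stringized C type names produced
// by SUPPORTED_TENSOR_TYPES, so the lookup is a direct index by enum value.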
StringRef toString(TensorType TT) {
  return TensorTypeNames[static_cast<size_t>(TT)];
}

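// For reference, toJSON below emits an object with "name", "type", "port" and
// "shape" attributes; the values in this sketch are illustrative only:
//
//   {"name": "foo", "type": "float", "port": 0, "shape": [2, 3]}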
void TensorSpec::toJSON(json::OStream &OS) const {
  OS.object([&]() {
    OS.attribute("name", name());
    OS.attribute("type", toString(type()));
    OS.attribute("port", port());
    OS.attributeArray("shape", [&]() {
      for (size_t D : shape())
        OS.value(static_cast<int64_t>(D));
    });
  });
}

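// ElementCount below is the product of the shape dimensions, computed via
// std::accumulate: a shape of {2, 3} yields 6 elements, and an empty shape
// yields 1 (a scalar).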
TensorSpec::TensorSpec(const std::string &Name, int Port, TensorType Type,
                       size_t ElementSize, const std::vector<int64_t> &Shape)
    : Name(Name), Port(Port), Type(Type), Shape(Shape),
      ElementCount(std::accumulate(Shape.begin(), Shape.end(), 1,
                                   std::multiplies<int64_t>())),
      ElementSize(ElementSize) {}

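// A minimal usage sketch for getTensorSpecFromJSON. The JSON literal and the
// surrounding error handling are illustrative, not prescriptive; the "type"
// string must be one of the stringized C type names from
// SUPPORTED_TENSOR_TYPES (here "int64_t" is assumed to be one of them).
//
//   LLVMContext Ctx;
//   Expected<json::Value> Parsed = json::parse(
//       R"({"name": "in", "port": 0, "type": "int64_t", "shape": [2, 3]})");
//   if (!Parsed) {
//     consumeError(Parsed.takeError());
//     return;
//   }
//   if (std::optional<TensorSpec> Spec = getTensorSpecFromJSON(Ctx, *Parsed))
//     assert(Spec->getElementCount() == 6);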
std::optional<TensorSpec> getTensorSpecFromJSON(LLVMContext &Ctx,
                                                const json::Value &Value) {
  auto EmitError =
      [&](const llvm::Twine &Message) -> std::optional<TensorSpec> {
    std::string S;
    llvm::raw_string_ostream OS(S);
    OS << Value;
    Ctx.emitError("Unable to parse JSON Value as spec (" + Message + "): " + S);
    return std::nullopt;
  };
  // FIXME: accept a Path as a parameter, and use it for error reporting.
  json::Path::Root Root("tensor_spec");
  json::ObjectMapper Mapper(Value, Root);
  if (!Mapper)
    return EmitError("Value is not a dict");

  std::string TensorName;
  int TensorPort = -1;
  std::string TensorType;
  std::vector<int64_t> TensorShape;

  if (!Mapper.map<std::string>("name", TensorName))
    return EmitError("'name' property not present or not a string");
  if (!Mapper.map<std::string>("type", TensorType))
    return EmitError("'type' property not present or not a string");
  if (!Mapper.map<int>("port", TensorPort))
    return EmitError("'port' property not present or not an int");
  if (!Mapper.map<std::vector<int64_t>>("shape", TensorShape))
    return EmitError("'shape' property not present or not an int array");

#define PARSE_TYPE(T, E)                                                       \
  if (TensorType == #T)                                                        \
    return TensorSpec::createSpec<T>(TensorName, TensorShape, TensorPort);
  SUPPORTED_TENSOR_TYPES(PARSE_TYPE)
#undef PARSE_TYPE
  // The type name did not match any supported element type; give up without
  // emitting a diagnostic.
  return std::nullopt;
}

} // namespace llvm