1 #if !defined(C10_MOBILE) && !defined(ANDROID)
2 #include <torch/csrc/inductor/aoti_eager/kernel_meta_info.h>
3 #include <iostream>
4 #include <utility>
5
6 namespace torch::inductor {
7
// Snapshot the static metadata of a concrete tensor: scalar type, device,
// dispatch key set, sizes, strides and grad requirement. A real tensor
// always has concrete sizes, so the symbolic flag is hard-wired to false.
TensorMetadata::TensorMetadata(const at::Tensor& src_tensor)
    : is_symbolic_(false),
      dtype_(src_tensor.scalar_type()),
      device_(src_tensor.device()),
      dispatch_key_set_(src_tensor.key_set()),
      sizes_(src_tensor.sizes().vec()),
      strides_(src_tensor.strides().vec()),
      requires_grad_(src_tensor.requires_grad()) {}
16
// Construct metadata from explicit fields. `sizes` and `strides` are taken
// by value and moved into the members. Symbolic shapes are not supported
// yet, so `is_symbolic` must be false (debug-asserted below).
TensorMetadata::TensorMetadata(
    bool is_symbolic,
    c10::ScalarType dtype,
    c10::Device device,
    c10::DispatchKeySet dispatch_key_set,
    std::vector<int64_t> sizes,
    std::vector<int64_t> strides,
    bool requires_grad)
    : is_symbolic_(is_symbolic),
      dtype_(dtype),
      device_(device),
      dispatch_key_set_(dispatch_key_set),
      sizes_(std::move(sizes)),
      strides_(std::move(strides)),
      requires_grad_(requires_grad) {
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
      !is_symbolic_, "Not support symbolic shape now");
}
35
build_guard(const torch::dynamo::LocalState & local_state)36 void TensorMetadata::build_guard(const torch::dynamo::LocalState& local_state) {
37 TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
38 !is_symbolic_, "Not support symbolic shape now");
39 std::vector<std::optional<c10::SymInt>> sym_sizes;
40 std::vector<std::optional<c10::SymInt>> sym_strides;
41 std::transform(
42 sizes_.begin(),
43 sizes_.end(),
44 std::back_inserter(sym_sizes),
45 [](int64_t size) { return std::optional<c10::SymInt>(size); });
46 std::transform(
47 strides_.begin(),
48 strides_.end(),
49 std::back_inserter(sym_strides),
50 [](int64_t stride) { return std::optional<c10::SymInt>(stride); });
51 tensor_check_ = torch::dynamo::TensorCheck(
52 local_state,
53 nullptr,
54 dispatch_key_set_,
55 dtype_,
56 device_.index(),
57 requires_grad_,
58 sym_sizes,
59 sym_strides);
60 }
61
operator ==(const TensorMetadata & other) const62 bool TensorMetadata::operator==(const TensorMetadata& other) const {
63 TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
64 !is_symbolic_, "Not support symbolic shape now");
65
66 if (tensor_check_.has_value()) {
67 auto sizes_ = c10::IntArrayRef(other.sizes_);
68 auto strides_ = c10::IntArrayRef(other.strides_);
69 auto sym_sizes = c10::SymIntArrayRef(
70 reinterpret_cast<const c10::SymInt*>(sizes_.data()), sizes_.size());
71 auto sym_strides = c10::SymIntArrayRef(
72 reinterpret_cast<const c10::SymInt*>(strides_.data()), strides_.size());
73
74 torch::dynamo::LocalState local_state;
75 local_state.overrideDispatchKeySet(dispatch_key_set_);
76 auto _tensor_check = tensor_check_.value();
77 auto res = _tensor_check.check(
78 local_state,
79 other.dispatch_key_set_,
80 other.dtype_,
81 other.device_,
82 sym_sizes,
83 sym_strides,
84 other.requires_grad_ /* Should we need to care about grad requirement?*/);
85 return res;
86 } else {
87 return this->is_symbolic_ == other.is_symbolic_ &&
88 this->dtype_ == other.dtype_ && this->device_ == other.device_ &&
89 this->dispatch_key_set_ == other.dispatch_key_set_ &&
90 this->requires_grad_ == other.requires_grad_ &&
91 this->sizes_ == other.sizes_ && this->strides_ == other.strides_;
92 }
93 }
94
operator <<(std::ostream & stream,const TensorMetadata & tensor_metadata)95 std::ostream& operator<<(
96 std::ostream& stream,
97 const TensorMetadata& tensor_metadata) {
98 stream << "is_symbolic_: " << tensor_metadata.is_symbolic_ << '\n';
99 stream << "dtype_: " << tensor_metadata.dtype_ << '\n';
100 stream << "device_: " << tensor_metadata.device_ << '\n';
101 stream << "sizes_: ";
102 for (const auto& size : tensor_metadata.sizes_) {
103 stream << size << " ";
104 }
105 stream << '\n';
106 stream << "strides_: ";
107 for (const auto& stride : tensor_metadata.strides_) {
108 stream << stride << " ";
109 }
110
111 stream << "requires_grad_: " << tensor_metadata.requires_grad_ << '\n';
112 stream << "dispatch_key_set_: " << tensor_metadata.dispatch_key_set_ << '\n';
113 stream << "tensor_check_: " << tensor_metadata.tensor_check_.has_value()
114 << '\n';
115 stream << '\n';
116 return stream;
117 }
118
ParameterMetadata(TensorMetadata tensor_metadata,uint64_t input_order)119 ParameterMetadata::ParameterMetadata(
120 TensorMetadata tensor_metadata,
121 uint64_t input_order)
122 : tag_(TENSOR), value_(tensor_metadata), order_(input_order) {}
123
ParameterMetadata(const at::Tensor & tensor,uint64_t input_order)124 ParameterMetadata::ParameterMetadata(
125 const at::Tensor& tensor,
126 uint64_t input_order)
127 : tag_(TENSOR), order_(input_order) {
128 value_ = TensorMetadata(tensor);
129 }
130
// Tensor-list parameter from pre-built per-tensor metadata; the vector is
// copied into the variant.
ParameterMetadata::ParameterMetadata(
    const std::vector<TensorMetadata>& tensor_metadata_list,
    uint64_t input_order)
    : tag_(TENSOR_LIST), value_(tensor_metadata_list), order_(input_order) {}
135
ParameterMetadata(const std::vector<at::Tensor> & tensor_list,uint64_t input_order)136 ParameterMetadata::ParameterMetadata(
137 const std::vector<at::Tensor>& tensor_list,
138 uint64_t input_order)
139 : tag_(TENSOR_LIST), order_(input_order) {
140 std::vector<TensorMetadata> tensor_metadata_list;
141 tensor_metadata_list.reserve(tensor_list.size());
142 for (const auto& tensor : tensor_list) {
143 tensor_metadata_list.emplace_back(tensor);
144 }
145 value_ = tensor_metadata_list;
146 }
147
// Scalar parameter: tag SCALAR, scalar copied into the variant.
ParameterMetadata::ParameterMetadata(
    const c10::Scalar& scalar,
    uint64_t input_order)
    : tag_(SCALAR), value_(scalar), order_(input_order) {}
152
// String parameter: tag STRING, string copied into the variant.
ParameterMetadata::ParameterMetadata(
    const std::string& str,
    uint64_t input_order)
    : tag_(STRING), value_(str), order_(input_order) {}
157
// Device parameter: tag DEVICE, device copied into the variant.
ParameterMetadata::ParameterMetadata(
    const c10::Device& device,
    uint64_t input_order)
    : tag_(DEVICE), value_(device), order_(input_order) {}
162
operator ==(const ParameterMetadata & other) const163 bool ParameterMetadata::operator==(const ParameterMetadata& other) const {
164 // Same type
165 if (tag_ != other.tag_) {
166 return false;
167 }
168
169 // Same order of the input parameters
170 if (order_ != other.order_) {
171 return false;
172 }
173
174 switch (tag_) {
175 case TENSOR:
176 return std::get<TensorMetadata>(value_) ==
177 std::get<TensorMetadata>(other.value_);
178 case TENSOR_LIST:
179 return std::get<std::vector<TensorMetadata>>(value_) ==
180 std::get<std::vector<TensorMetadata>>(other.value_);
181 case SCALAR:
182 TORCH_INTERNAL_ASSERT(
183 std::get<c10::Scalar>(other.value_).isFloatingPoint() ||
184 std::get<c10::Scalar>(other.value_).isIntegral(true /*includeBool*/));
185 return equal_to(std::get<c10::Scalar>(other.value_));
186 case STRING:
187 return std::get<std::string>(value_) ==
188 std::get<std::string>(other.value_);
189 case DEVICE:
190 return std::get<c10::Device>(value_) ==
191 std::get<c10::Device>(other.value_);
192 default:
193 return false;
194 }
195 }
196
equal_to(const c10::Scalar & scalar) const197 bool ParameterMetadata::equal_to(const c10::Scalar& scalar) const {
198 TORCH_INTERNAL_ASSERT(scalar.isFloatingPoint() || scalar.isIntegral(true));
199 if (tag_ != SCALAR) {
200 return false;
201 }
202
203 auto self_scalar = std::get<c10::Scalar>(value_);
204 if (scalar.isFloatingPoint() && self_scalar.isFloatingPoint()) {
205 return self_scalar.toDouble() == scalar.toDouble();
206 } else if (scalar.isIntegral(true) && self_scalar.isIntegral(true)) {
207 return self_scalar.toInt() == scalar.toInt();
208 }
209
210 return false;
211 }
212
213 } // namespace torch::inductor
214 #endif
215