xref: /aosp_15_r20/external/pytorch/torch/csrc/inductor/aoti_eager/kernel_meta_info.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #if !defined(C10_MOBILE) && !defined(ANDROID)
2 #include <torch/csrc/inductor/aoti_eager/kernel_meta_info.h>
3 #include <iostream>
4 #include <utility>
5 
6 namespace torch::inductor {
7 
TensorMetadata(const at::Tensor & src_tensor)8 TensorMetadata::TensorMetadata(const at::Tensor& src_tensor)
9     : is_symbolic_(false),
10       dtype_(src_tensor.scalar_type()),
11       device_(src_tensor.device()),
12       dispatch_key_set_(src_tensor.key_set()),
13       sizes_(src_tensor.sizes().vec()),
14       strides_(src_tensor.strides().vec()),
15       requires_grad_(src_tensor.requires_grad()) {}
16 
TensorMetadata(bool is_symbolic,c10::ScalarType dtype,c10::Device device,c10::DispatchKeySet dispatch_key_set,std::vector<int64_t> sizes,std::vector<int64_t> strides,bool requires_grad)17 TensorMetadata::TensorMetadata(
18     bool is_symbolic,
19     c10::ScalarType dtype,
20     c10::Device device,
21     c10::DispatchKeySet dispatch_key_set,
22     std::vector<int64_t> sizes,
23     std::vector<int64_t> strides,
24     bool requires_grad)
25     : is_symbolic_(is_symbolic),
26       dtype_(dtype),
27       device_(device),
28       dispatch_key_set_(dispatch_key_set),
29       sizes_(std::move(sizes)),
30       strides_(std::move(strides)),
31       requires_grad_(requires_grad) {
32   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
33       !is_symbolic_, "Not support symbolic shape now");
34 }
35 
build_guard(const torch::dynamo::LocalState & local_state)36 void TensorMetadata::build_guard(const torch::dynamo::LocalState& local_state) {
37   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
38       !is_symbolic_, "Not support symbolic shape now");
39   std::vector<std::optional<c10::SymInt>> sym_sizes;
40   std::vector<std::optional<c10::SymInt>> sym_strides;
41   std::transform(
42       sizes_.begin(),
43       sizes_.end(),
44       std::back_inserter(sym_sizes),
45       [](int64_t size) { return std::optional<c10::SymInt>(size); });
46   std::transform(
47       strides_.begin(),
48       strides_.end(),
49       std::back_inserter(sym_strides),
50       [](int64_t stride) { return std::optional<c10::SymInt>(stride); });
51   tensor_check_ = torch::dynamo::TensorCheck(
52       local_state,
53       nullptr,
54       dispatch_key_set_,
55       dtype_,
56       device_.index(),
57       requires_grad_,
58       sym_sizes,
59       sym_strides);
60 }
61 
operator ==(const TensorMetadata & other) const62 bool TensorMetadata::operator==(const TensorMetadata& other) const {
63   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
64       !is_symbolic_, "Not support symbolic shape now");
65 
66   if (tensor_check_.has_value()) {
67     auto sizes_ = c10::IntArrayRef(other.sizes_);
68     auto strides_ = c10::IntArrayRef(other.strides_);
69     auto sym_sizes = c10::SymIntArrayRef(
70         reinterpret_cast<const c10::SymInt*>(sizes_.data()), sizes_.size());
71     auto sym_strides = c10::SymIntArrayRef(
72         reinterpret_cast<const c10::SymInt*>(strides_.data()), strides_.size());
73 
74     torch::dynamo::LocalState local_state;
75     local_state.overrideDispatchKeySet(dispatch_key_set_);
76     auto _tensor_check = tensor_check_.value();
77     auto res = _tensor_check.check(
78         local_state,
79         other.dispatch_key_set_,
80         other.dtype_,
81         other.device_,
82         sym_sizes,
83         sym_strides,
84         other.requires_grad_ /* Should we need to care about grad requirement?*/);
85     return res;
86   } else {
87     return this->is_symbolic_ == other.is_symbolic_ &&
88         this->dtype_ == other.dtype_ && this->device_ == other.device_ &&
89         this->dispatch_key_set_ == other.dispatch_key_set_ &&
90         this->requires_grad_ == other.requires_grad_ &&
91         this->sizes_ == other.sizes_ && this->strides_ == other.strides_;
92   }
93 }
94 
operator <<(std::ostream & stream,const TensorMetadata & tensor_metadata)95 std::ostream& operator<<(
96     std::ostream& stream,
97     const TensorMetadata& tensor_metadata) {
98   stream << "is_symbolic_: " << tensor_metadata.is_symbolic_ << '\n';
99   stream << "dtype_: " << tensor_metadata.dtype_ << '\n';
100   stream << "device_: " << tensor_metadata.device_ << '\n';
101   stream << "sizes_: ";
102   for (const auto& size : tensor_metadata.sizes_) {
103     stream << size << " ";
104   }
105   stream << '\n';
106   stream << "strides_: ";
107   for (const auto& stride : tensor_metadata.strides_) {
108     stream << stride << " ";
109   }
110 
111   stream << "requires_grad_: " << tensor_metadata.requires_grad_ << '\n';
112   stream << "dispatch_key_set_: " << tensor_metadata.dispatch_key_set_ << '\n';
113   stream << "tensor_check_: " << tensor_metadata.tensor_check_.has_value()
114          << '\n';
115   stream << '\n';
116   return stream;
117 }
118 
ParameterMetadata(TensorMetadata tensor_metadata,uint64_t input_order)119 ParameterMetadata::ParameterMetadata(
120     TensorMetadata tensor_metadata,
121     uint64_t input_order)
122     : tag_(TENSOR), value_(tensor_metadata), order_(input_order) {}
123 
ParameterMetadata(const at::Tensor & tensor,uint64_t input_order)124 ParameterMetadata::ParameterMetadata(
125     const at::Tensor& tensor,
126     uint64_t input_order)
127     : tag_(TENSOR), order_(input_order) {
128   value_ = TensorMetadata(tensor);
129 }
130 
ParameterMetadata(const std::vector<TensorMetadata> & tensor_metadata_list,uint64_t input_order)131 ParameterMetadata::ParameterMetadata(
132     const std::vector<TensorMetadata>& tensor_metadata_list,
133     uint64_t input_order)
134     : tag_(TENSOR_LIST), value_(tensor_metadata_list), order_(input_order) {}
135 
ParameterMetadata(const std::vector<at::Tensor> & tensor_list,uint64_t input_order)136 ParameterMetadata::ParameterMetadata(
137     const std::vector<at::Tensor>& tensor_list,
138     uint64_t input_order)
139     : tag_(TENSOR_LIST), order_(input_order) {
140   std::vector<TensorMetadata> tensor_metadata_list;
141   tensor_metadata_list.reserve(tensor_list.size());
142   for (const auto& tensor : tensor_list) {
143     tensor_metadata_list.emplace_back(tensor);
144   }
145   value_ = tensor_metadata_list;
146 }
147 
ParameterMetadata(const c10::Scalar & scalar,uint64_t input_order)148 ParameterMetadata::ParameterMetadata(
149     const c10::Scalar& scalar,
150     uint64_t input_order)
151     : tag_(SCALAR), value_(scalar), order_(input_order) {}
152 
ParameterMetadata(const std::string & str,uint64_t input_order)153 ParameterMetadata::ParameterMetadata(
154     const std::string& str,
155     uint64_t input_order)
156     : tag_(STRING), value_(str), order_(input_order) {}
157 
ParameterMetadata(const c10::Device & device,uint64_t input_order)158 ParameterMetadata::ParameterMetadata(
159     const c10::Device& device,
160     uint64_t input_order)
161     : tag_(DEVICE), value_(device), order_(input_order) {}
162 
operator ==(const ParameterMetadata & other) const163 bool ParameterMetadata::operator==(const ParameterMetadata& other) const {
164   // Same type
165   if (tag_ != other.tag_) {
166     return false;
167   }
168 
169   // Same order of the input parameters
170   if (order_ != other.order_) {
171     return false;
172   }
173 
174   switch (tag_) {
175     case TENSOR:
176       return std::get<TensorMetadata>(value_) ==
177           std::get<TensorMetadata>(other.value_);
178     case TENSOR_LIST:
179       return std::get<std::vector<TensorMetadata>>(value_) ==
180           std::get<std::vector<TensorMetadata>>(other.value_);
181     case SCALAR:
182       TORCH_INTERNAL_ASSERT(
183           std::get<c10::Scalar>(other.value_).isFloatingPoint() ||
184           std::get<c10::Scalar>(other.value_).isIntegral(true /*includeBool*/));
185       return equal_to(std::get<c10::Scalar>(other.value_));
186     case STRING:
187       return std::get<std::string>(value_) ==
188           std::get<std::string>(other.value_);
189     case DEVICE:
190       return std::get<c10::Device>(value_) ==
191           std::get<c10::Device>(other.value_);
192     default:
193       return false;
194   }
195 }
196 
equal_to(const c10::Scalar & scalar) const197 bool ParameterMetadata::equal_to(const c10::Scalar& scalar) const {
198   TORCH_INTERNAL_ASSERT(scalar.isFloatingPoint() || scalar.isIntegral(true));
199   if (tag_ != SCALAR) {
200     return false;
201   }
202 
203   auto self_scalar = std::get<c10::Scalar>(value_);
204   if (scalar.isFloatingPoint() && self_scalar.isFloatingPoint()) {
205     return self_scalar.toDouble() == scalar.toDouble();
206   } else if (scalar.isIntegral(true) && self_scalar.isIntegral(true)) {
207     return self_scalar.toInt() == scalar.toInt();
208   }
209 
210   return false;
211 }
212 
213 } // namespace torch::inductor
214 #endif
215