1 #pragma once 2 3 #include <ATen/ATen.h> 4 #include <ATen/Config.h> 5 6 #include <oneapi/dnnl/dnnl_graph.hpp> 7 #include <torch/csrc/jit/ir/ir.h> 8 9 namespace torch { 10 namespace jit { 11 namespace fuser { 12 namespace onednn { 13 14 // Engine represents a device and its context. From the device kind, the engine 15 // knows how to generate code for the target device and what kind of device 16 // object to be expected. The device id ensures that there is a unique engine 17 // being created for each device. The device handle passed from PyTorch allows 18 // oneDNN Graph implementation to work on the device specified by PyTorch, which 19 // is currently CPU, so we only have one engine. 20 // Ref: https://spec.oneapi.io/onednn-graph/latest/programming_model.html#engine 21 struct Engine { 22 // CPU engine singleton 23 static dnnl::engine& getEngine(); 24 Engine(const Engine&) = delete; 25 void operator=(const Engine&) = delete; 26 }; 27 28 // Stream is the logical abstraction for execution units. It is created on top 29 // of oneDNN Graph engine. A compiled oneDNN Graph partition is submitted to a 30 // stream for execution. 31 struct Stream { 32 // CPU stream singleton 33 static dnnl::stream& getStream(); 34 Stream(const Stream&) = delete; 35 void operator=(const Stream&) = delete; 36 }; 37 38 struct LlgaTensorDesc { 39 using desc = dnnl::graph::logical_tensor; 40 LlgaTensorDescLlgaTensorDesc41 LlgaTensorDesc( 42 size_t tid, 43 std::vector<int64_t> sizes, 44 std::vector<int64_t> strides, 45 desc::data_type dtype, 46 desc::property_type property_type) 47 : tid_(tid), 48 sizes_(sizes), 49 strides_(strides), 50 dtype_(dtype), 51 property_type_(property_type), 52 layout_type_(desc::layout_type::strided), 53 layout_id_(-1) {} 54 LlgaTensorDescLlgaTensorDesc55 LlgaTensorDesc(const desc& t) 56 : tid_(t.get_id()), 57 sizes_(t.get_dims()), 58 strides_({-1}), 59 dtype_(t.get_data_type()), 60 property_type_(t.get_property_type()), 61 layout_type_(t.get_layout_type()), 62 layout_id_(-1) { 63 if (is_opaque()) { 64 layout_id_ = t.get_layout_id(); 65 } 66 if (is_strided()) { 67 strides_ = t.get_strides(); 68 } 69 } 70 LlgaTensorDescLlgaTensorDesc71 LlgaTensorDesc(const torch::jit::Value* v) 72 : LlgaTensorDesc( 73 v->unique(), 74 {}, 75 {}, 76 desc::data_type::f32, 77 get_property_type(v)) { 78 if (v->type()->isSubtypeOf(TensorType::get())) { 79 auto tt = v->type()->cast<TensorType>(); 80 81 if (tt->scalarType()) { 82 dtype_ = getLlgaDataType(tt->scalarType().value()); 83 } 84 85 auto sizes = tt->sizes(); 86 if (sizes.sizes()) { 87 for (auto d : *sizes.sizes()) { 88 sizes_.push_back(d.value_or(DNNL_GRAPH_UNKNOWN_DIM)); 89 } 90 } 91 92 auto strides = tt->strides(); 93 if (strides.sizes()) { 94 for (auto d : *strides.sizes()) { 95 strides_.push_back(d.value_or(DNNL_GRAPH_UNKNOWN_DIM)); 96 } 97 } 98 } 99 } 100 101 LlgaTensorDesc supplementTensorInfo(const at::Tensor& t) const; 102 103 desc::data_type getLlgaDataType(at::ScalarType dt) const; 104 105 at::ScalarType aten_scalar_type() const; 106 sizesLlgaTensorDesc107 const std::vector<int64_t>& sizes() const { 108 return sizes_; 109 } 110 stridesLlgaTensorDesc111 const std::vector<int64_t>& strides() const { 112 TORCH_CHECK(!is_opaque(), "Cannot get strides on opaque layout"); 113 return strides_; 114 } 115 tidLlgaTensorDesc116 size_t tid() const { 117 return tid_; 118 } 119 tidLlgaTensorDesc120 LlgaTensorDesc tid(uint64_t new_id) const { 121 auto ret = *this; 122 ret.tid_ = new_id; 123 return ret; 124 } 125 dtypeLlgaTensorDesc126 desc::data_type dtype() const { 127 return dtype_; 128 } 129 dtypeLlgaTensorDesc130 LlgaTensorDesc dtype(desc::data_type new_dtype) const { 131 return LlgaTensorDesc(tid_, sizes_, strides_, new_dtype, property_type_); 132 } 133 layout_typeLlgaTensorDesc134 desc::layout_type layout_type() const { 135 return layout_type_; 136 } 137 layout_typeLlgaTensorDesc138 LlgaTensorDesc layout_type(desc::layout_type new_layout_type) { 139 auto ret = *this; 140 ret.layout_type_ = new_layout_type; 141 return ret; 142 } 143 get_property_typeLlgaTensorDesc144 desc::property_type get_property_type(const torch::jit::Value* v) { 145 switch (v->node()->kind()) { 146 case prim::Constant: 147 return desc::property_type::constant; 148 default: 149 return desc::property_type::variable; 150 } 151 } 152 anyLlgaTensorDesc153 LlgaTensorDesc any() { 154 return layout_type(desc::layout_type::any); 155 } 156 storage_sizeLlgaTensorDesc157 size_t storage_size() const { 158 return logical_tensor().get_mem_size(); 159 } 160 logical_tensorLlgaTensorDesc161 desc logical_tensor() const { 162 if (is_dimensionality_unknown()) { 163 return desc( 164 tid_, dtype_, DNNL_GRAPH_UNKNOWN_NDIMS, layout_type_, property_type_); 165 } else if (is_opaque()) { 166 return desc(tid_, dtype_, sizes_, layout_id_, property_type_); 167 } else if (is_any()) { 168 return desc(tid_, dtype_, sizes_, layout_type_, property_type_); 169 } else { 170 return desc(tid_, dtype_, sizes_, strides_, property_type_); 171 } 172 } 173 is_stridedLlgaTensorDesc174 bool is_strided() const { 175 return layout_type_ == desc::layout_type::strided; 176 } 177 is_anyLlgaTensorDesc178 bool is_any() const { 179 return layout_type_ == desc::layout_type::any; 180 } 181 is_opaqueLlgaTensorDesc182 bool is_opaque() const { 183 return layout_type_ == desc::layout_type::opaque; 184 } 185 186 bool operator==(const LlgaTensorDesc& desc) const { 187 return tid_ == desc.tid_ && sizes_ == desc.sizes_ && 188 dtype_ == desc.dtype_ && layout_type_ == desc.layout_type_ && 189 ((is_opaque() && layout_id_ == desc.layout_id_) || 190 strides_ == desc.strides_); 191 } 192 193 bool operator!=(const LlgaTensorDesc& desc) const { 194 return (tid_ != desc.tid_) || (sizes_ != desc.sizes_) || 195 (dtype_ != desc.dtype_) || (layout_type_ != desc.layout_type_) || 196 !((is_opaque() && (layout_id_ == desc.layout_id_)) || 197 (strides_ == desc.strides_)); 198 } 199 hashLlgaTensorDesc200 static size_t hash(const LlgaTensorDesc& desc) { 201 return c10::get_hash( 202 desc.tid_, 203 desc.sizes_, 204 desc.dtype_, 205 desc.layout_type_, 206 desc.layout_id_); 207 } 208 set_compute_inplaceLlgaTensorDesc209 void set_compute_inplace() { 210 compute_inplace_ = true; 211 } 212 set_input_tensor_indexLlgaTensorDesc213 void set_input_tensor_index(size_t index) { 214 input_tensor_index_ = index; 215 } 216 reuses_input_tensorLlgaTensorDesc217 bool reuses_input_tensor() { 218 return compute_inplace_; 219 } 220 get_input_tensor_indexLlgaTensorDesc221 size_t get_input_tensor_index() { 222 return input_tensor_index_; 223 } 224 225 private: is_dimensionality_unknownLlgaTensorDesc226 bool is_dimensionality_unknown() const { 227 return sizes_.size() == 0; 228 } 229 230 size_t tid_; 231 std::vector<int64_t> sizes_; 232 std::vector<int64_t> strides_; 233 desc::data_type dtype_; 234 desc::property_type property_type_; 235 desc::layout_type layout_type_; 236 size_t layout_id_; 237 // If this is an output tensor, and querying the compiled partition would 238 // determine that this tensor would reuse its input tensor, then 239 // compute_inplace would be true, and input_tensor_index would be the index of 240 // the corresponding input tensor in inputSpecs_ of the LlgaKernel object. 241 bool compute_inplace_ = false; 242 size_t input_tensor_index_; 243 }; 244 245 // Initially, oneDNN Graph also used to have blocked layout for tensors between 246 // partitions, and the LlgaTensorImpl wrapper helped us bypass guard checks. 247 // oneDNN Graph has switched over to using strided tensors between partitions, 248 // but this wrapper still helps us bypass guard checks because the strides of 249 // tensors between partitions would be different from the ones the guard is 250 // otherwise expecting. 251 struct TORCH_API LlgaTensorImpl : public c10::TensorImpl { 252 LlgaTensorImpl( 253 at::Storage&& storage, 254 const caffe2::TypeMeta& data_type, 255 const LlgaTensorDesc& desc); 256 descLlgaTensorImpl257 const LlgaTensorDesc& desc() const { 258 return desc_; 259 } 260 261 static at::Tensor llga_to_aten_tensor(LlgaTensorImpl* llgaImpl); 262 263 private: 264 LlgaTensorDesc desc_; 265 }; 266 267 at::Tensor empty_llga( 268 const LlgaTensorDesc& desc, 269 const c10::TensorOptions& options); 270 271 dnnl::graph::tensor llga_from_aten_tensor(const at::Tensor& tensor); 272 273 } // namespace onednn 274 } // namespace fuser 275 } // namespace jit 276 } // namespace torch 277