#pragma once

#include <ATen/ATen.h>
#include <ATen/Config.h>

#include <oneapi/dnnl/dnnl_graph.hpp>
#include <torch/csrc/jit/ir/ir.h>

namespace torch {
namespace jit {
namespace fuser {
namespace onednn {

// Engine represents a device and its context. From the device kind, the engine
// knows how to generate code for the target device and what kind of device
// object to expect. The device id ensures that a unique engine is created for
// each device. The device handle passed from PyTorch lets the oneDNN Graph
// implementation work on the device PyTorch specifies; currently that is only
// the CPU, so we have a single engine.
// Ref: https://spec.oneapi.io/onednn-graph/latest/programming_model.html#engine
struct Engine {
  // CPU engine singleton
  static dnnl::engine& getEngine();
  Engine(const Engine&) = delete;
  void operator=(const Engine&) = delete;
};
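
// A minimal usage sketch, assuming a partition and its logical tensors were
// obtained elsewhere in the fuser (the names below are placeholders, not part
// of this header): the engine singleton is what a oneDNN Graph partition is
// compiled against.
//
//   dnnl::graph::partition partition = /* from the fused JIT subgraph */;
//   std::vector<dnnl::graph::logical_tensor> inputs = /* ... */;
//   std::vector<dnnl::graph::logical_tensor> outputs = /* ... */;
//   dnnl::graph::compiled_partition compiled =
//       partition.compile(inputs, outputs, Engine::getEngine());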

// Stream is the logical abstraction for execution units. It is created on top
// of the oneDNN Graph engine. A compiled oneDNN Graph partition is submitted
// to a stream for execution.
struct Stream {
  // CPU stream singleton
  static dnnl::stream& getStream();
  Stream(const Stream&) = delete;
  void operator=(const Stream&) = delete;
};
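
// A minimal usage sketch, assuming the compiled partition and tensor lists
// are built elsewhere (placeholder names): execution is always submitted to
// the stream singleton.
//
//   dnnl::graph::compiled_partition compiled = /* see Engine above */;
//   std::vector<dnnl::graph::tensor> ins = /* input tensors */;
//   std::vector<dnnl::graph::tensor> outs = /* output tensors */;
//   compiled.execute(Stream::getStream(), ins, outs);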

struct LlgaTensorDesc {
  using desc = dnnl::graph::logical_tensor;

  // Build a strided descriptor from explicit sizes and strides.
  LlgaTensorDesc(
      size_t tid,
      std::vector<int64_t> sizes,
      std::vector<int64_t> strides,
      desc::data_type dtype,
      desc::property_type property_type)
      : tid_(tid),
        sizes_(sizes),
        strides_(strides),
        dtype_(dtype),
        property_type_(property_type),
        layout_type_(desc::layout_type::strided),
        layout_id_(-1) {}

  // Build a descriptor from an existing oneDNN Graph logical tensor.
  LlgaTensorDesc(const desc& t)
      : tid_(t.get_id()),
        sizes_(t.get_dims()),
        strides_({-1}),
        dtype_(t.get_data_type()),
        property_type_(t.get_property_type()),
        layout_type_(t.get_layout_type()),
        layout_id_(-1) {
    if (is_opaque()) {
      layout_id_ = t.get_layout_id();
    }
    if (is_strided()) {
      strides_ = t.get_strides();
    }
  }

  // Build a descriptor from a JIT IR value; dtype, sizes, and strides are
  // filled in only as far as the value's TensorType knows them.
  LlgaTensorDesc(const torch::jit::Value* v)
      : LlgaTensorDesc(
            v->unique(),
            {},
            {},
            desc::data_type::f32,
            get_property_type(v)) {
    if (v->type()->isSubtypeOf(TensorType::get())) {
      auto tt = v->type()->cast<TensorType>();

      if (tt->scalarType()) {
        dtype_ = getLlgaDataType(tt->scalarType().value());
      }

      auto sizes = tt->sizes();
      if (sizes.sizes()) {
        for (auto d : *sizes.sizes()) {
          sizes_.push_back(d.value_or(DNNL_GRAPH_UNKNOWN_DIM));
        }
      }

      auto strides = tt->strides();
      if (strides.sizes()) {
        for (auto d : *strides.sizes()) {
          strides_.push_back(d.value_or(DNNL_GRAPH_UNKNOWN_DIM));
        }
      }
    }
  }

  LlgaTensorDesc supplementTensorInfo(const at::Tensor& t) const;

  desc::data_type getLlgaDataType(at::ScalarType dt) const;

  at::ScalarType aten_scalar_type() const;

  const std::vector<int64_t>& sizes() const {
    return sizes_;
  }

  const std::vector<int64_t>& strides() const {
    TORCH_CHECK(!is_opaque(), "Cannot get strides on opaque layout");
    return strides_;
  }

  size_t tid() const {
    return tid_;
  }

  LlgaTensorDesc tid(uint64_t new_id) const {
    auto ret = *this;
    ret.tid_ = new_id;
    return ret;
  }

  desc::data_type dtype() const {
    return dtype_;
  }

  LlgaTensorDesc dtype(desc::data_type new_dtype) const {
    return LlgaTensorDesc(tid_, sizes_, strides_, new_dtype, property_type_);
  }

  desc::layout_type layout_type() const {
    return layout_type_;
  }

  LlgaTensorDesc layout_type(desc::layout_type new_layout_type) {
    auto ret = *this;
    ret.layout_type_ = new_layout_type;
    return ret;
  }

  desc::property_type get_property_type(const torch::jit::Value* v) {
    switch (v->node()->kind()) {
      case prim::Constant:
        return desc::property_type::constant;
      default:
        return desc::property_type::variable;
    }
  }

  // Copy of this descriptor with the layout relaxed to "any", letting the
  // library pick a layout when the partition is compiled.
  LlgaTensorDesc any() {
    return layout_type(desc::layout_type::any);
  }

  size_t storage_size() const {
    return logical_tensor().get_mem_size();
  }

  desc logical_tensor() const {
    if (is_dimensionality_unknown()) {
      // Rank unknown: only id, dtype, and layout kind can be conveyed.
      return desc(
          tid_, dtype_, DNNL_GRAPH_UNKNOWN_NDIMS, layout_type_, property_type_);
    } else if (is_opaque()) {
      // Opaque layout is identified by the library-assigned layout id.
      return desc(tid_, dtype_, sizes_, layout_id_, property_type_);
    } else if (is_any()) {
      return desc(tid_, dtype_, sizes_, layout_type_, property_type_);
    } else {
      // Strided layout carries explicit strides.
      return desc(tid_, dtype_, sizes_, strides_, property_type_);
    }
  }

  bool is_strided() const {
    return layout_type_ == desc::layout_type::strided;
  }

  bool is_any() const {
    return layout_type_ == desc::layout_type::any;
  }

  bool is_opaque() const {
    return layout_type_ == desc::layout_type::opaque;
  }

  bool operator==(const LlgaTensorDesc& desc) const {
    return tid_ == desc.tid_ && sizes_ == desc.sizes_ &&
        dtype_ == desc.dtype_ && layout_type_ == desc.layout_type_ &&
        ((is_opaque() && layout_id_ == desc.layout_id_) ||
         strides_ == desc.strides_);
  }

  bool operator!=(const LlgaTensorDesc& desc) const {
    return !(*this == desc);
  }

  static size_t hash(const LlgaTensorDesc& desc) {
    return c10::get_hash(
        desc.tid_,
        desc.sizes_,
        desc.dtype_,
        desc.layout_type_,
        desc.layout_id_);
  }

  void set_compute_inplace() {
    compute_inplace_ = true;
  }

  void set_input_tensor_index(size_t index) {
    input_tensor_index_ = index;
  }

  bool reuses_input_tensor() {
    return compute_inplace_;
  }

  size_t get_input_tensor_index() {
    return input_tensor_index_;
  }

 private:
  bool is_dimensionality_unknown() const {
    return sizes_.size() == 0;
  }

  size_t tid_;
  std::vector<int64_t> sizes_;
  std::vector<int64_t> strides_;
  desc::data_type dtype_;
  desc::property_type property_type_;
  desc::layout_type layout_type_;
  size_t layout_id_;
  // If this descriptor is for an output tensor and querying the compiled
  // partition determines that the output reuses an input tensor's buffer,
  // then compute_inplace_ is true and input_tensor_index_ is the index of
  // that input in inputSpecs_ of the LlgaKernel object.
  bool compute_inplace_ = false;
  size_t input_tensor_index_;
};
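
// A hedged sketch of how a kernel could act on the inplace metadata above
// when preparing output buffers (spec, inputs, outputs, and opts are
// placeholders; the real LlgaKernel logic lives elsewhere):
//
//   if (spec.reuses_input_tensor()) {
//     // Alias the aten input instead of allocating fresh storage.
//     outputs.push_back(inputs[spec.get_input_tensor_index()]);
//   } else {
//     outputs.push_back(empty_llga(spec, opts)); // declared below
//   }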

// oneDNN Graph initially used blocked (opaque) layouts for tensors between
// partitions, and the LlgaTensorImpl wrapper helped us bypass guard checks.
// oneDNN Graph has since switched to strided tensors between partitions, but
// the wrapper still helps bypass guard checks, because the strides of tensors
// between partitions may differ from the ones the guard would otherwise
// expect.
struct TORCH_API LlgaTensorImpl : public c10::TensorImpl {
  LlgaTensorImpl(
      at::Storage&& storage,
      const caffe2::TypeMeta& data_type,
      const LlgaTensorDesc& desc);

  const LlgaTensorDesc& desc() const {
    return desc_;
  }

  static at::Tensor llga_to_aten_tensor(LlgaTensorImpl* llgaImpl);

 private:
  LlgaTensorDesc desc_;
};
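
// A hedged sketch of converting a tensor backed by LlgaTensorImpl back into a
// plain strided aten tensor (the static_cast assumes the caller already knows
// the tensor came from this fuser; fused_out is a placeholder):
//
//   auto* impl =
//       static_cast<LlgaTensorImpl*>(fused_out.unsafeGetTensorImpl());
//   at::Tensor plain = LlgaTensorImpl::llga_to_aten_tensor(impl);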

// Allocate an aten tensor whose storage satisfies desc (e.g. as an output
// buffer for a compiled partition).
at::Tensor empty_llga(
    const LlgaTensorDesc& desc,
    const c10::TensorOptions& options);

// Wrap an aten tensor's memory into a oneDNN Graph tensor handle.
dnnl::graph::tensor llga_from_aten_tensor(const at::Tensor& tensor);
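
// Putting the pieces together, a minimal end-to-end sketch under the
// assumptions above (cp, outSpec, input, and opts are placeholders):
//
//   at::Tensor out = empty_llga(outSpec, opts);
//   std::vector<dnnl::graph::tensor> ins{llga_from_aten_tensor(input)};
//   std::vector<dnnl::graph::tensor> outs{llga_from_aten_tensor(out)};
//   cp.execute(Stream::getStream(), ins, outs);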

} // namespace onednn
} // namespace fuser
} // namespace jit
} // namespace torch