#include <torch/csrc/jit/codegen/onednn/graph_helper.h>
#include <torch/csrc/jit/codegen/onednn/kernel.h>

#include <ATen/core/functional.h>
#include <torch/csrc/jit/jit_log.h>

namespace torch {
namespace jit {
namespace fuser {
namespace onednn {

using namespace dnnl::graph;
using data_type = dnnl::graph::logical_tensor::data_type;

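// LlgaKernel wraps the single oneDNN Graph (LLGA) partition behind a fused
// JIT subgraph: the constructor re-derives that partition from the subgraph,
// and run() compiles it on first use and executes it on every call.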
LlgaKernel::LlgaKernel(const Node* fusionNode)
    : fusionNode_(fusionNode),
      graph_(fusionNode->g(attr::Subgraph)),
      nGraphInputs_(graph_->inputs().size()),
      nOutputs_(graph_->outputs().size()),
      debugName_(genDebugName()) {
  // TODO: This is a workaround that recreates the partitions here.
  // The ideal way is to use the partition serialization API (not yet
  // available in LLGA) to carry a serialized string representation over from
  // the graph rewrite and deserialize it here.
  auto llgaGraphHelper = LlgaGraphHelper(graph_);
  auto partitions = llgaGraphHelper.getPartitions();
  tensorIdToValue_ = llgaGraphHelper.getTensorIdToValue();
  TORCH_CHECK(
      partitions.size() == 1,
      "LLGA subgraph should contain only one partition");
  partition_ = partitions[0];
  nPartitionInputs_ = partition_.get_input_ports().size();
#ifdef GRAPH_DEBUG_ENABLED
  GRAPH_DEBUG("Initialized ", debugName(), "\n", graph_->toString());
#endif
}

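// Whether the output at the given offset should keep oneDNN Graph's opaque
// (blocked) layout, as recorded on the fusion node.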
bool LlgaKernel::useOpaqueLayout(size_t offset) const {
  return LlgaNodeWrapper(fusionNode_).useOpaqueLayout(offset);
}

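// Collects the partition inputs that are not graph inputs. These must be
// prim::Constant tensors captured by the subgraph; they are stashed in
// constantValues_ / constantInputs_ so they can be bound to the compiled
// partition at execution time.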
void LlgaKernel::initializeConstantInputs() {
  for (auto& lt : partition_.get_input_ports()) {
    auto inputId = lt.get_id();
    if (initializedInputIds_.find(inputId) == initializedInputIds_.end()) {
      TORCH_CHECK(
          tensorIdToValue_.count(inputId) > 0,
          "input with inputId ",
          inputId,
          " is missing");
      auto* value = tensorIdToValue_[inputId];

      TORCH_CHECK(
          value->node()->kind() == prim::Constant &&
              value->type()->cast<TensorType>(),
          "input with inputId ",
          inputId,
          " should be a Constant tensor");
      constantValues_.emplace_back(value);

      auto const_tensor = toIValue(value)->toTensor();
      constantInputs_.emplace_back(const_tensor);
    }
  }
}

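// Counts how many partition input ports refer to each tensor id: a single
// graph input may feed several ports, and its spec then has to be replicated
// that many times in initializeInputSpecs().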
std::map<size_t, int64_t> LlgaKernel::initializeTensorIdToOccurence() const {
  std::map<size_t, int64_t> tensorIdToOccurence;
  for (auto& lt : partition_.get_input_ports()) {
    auto inputId = lt.get_id();
    std::map<size_t, int64_t>::iterator it(tensorIdToOccurence.find(inputId));
    if (it != tensorIdToOccurence.end()) {
      it->second++;
    } else {
      tensorIdToOccurence[inputId] = 1;
    }
  }
  return tensorIdToOccurence;
}

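// Builds the input logical tensors used for compilation: graph inputs first
// (replicated per occurrence and supplemented with runtime tensor info),
// followed by the constant inputs.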
ArgSpecs LlgaKernel::initializeInputSpecs(const TensorArgs& inputs) {
  ArgSpecs inputSpecs;
  inputSpecs.reserve(nPartitionInputs_);
  GRAPH_DEBUG("Initializing graph input logical tensors");
  std::map<size_t, int64_t> tensorIdToOccurence =
      initializeTensorIdToOccurence();
  for (const auto i : c10::irange(nGraphInputs_)) {
    auto spec = ArgSpec(graph_->inputs()[i]).supplementTensorInfo(inputs[i]);
    initializedInputIds_.insert(spec.tid());
    int64_t occurence = tensorIdToOccurence[spec.tid()];
    inputSpecs.insert(inputSpecs.end(), occurence, spec);
    runArgsIdx_.insert(runArgsIdx_.end(), occurence, i);
  }
  GRAPH_DEBUG("Initializing constant input tensors");
  initializeConstantInputs();

  TORCH_CHECK(
      inputSpecs.size() + constantValues_.size() ==
          static_cast<size_t>(nPartitionInputs_),
      "Partition inputs are missing");
  GRAPH_DEBUG(
      "Concatenating constant input logical tensors to graph input "
      "logical tensors");
  for (Value* constant_value : constantValues_) {
    ArgSpec constantInputSpec(constant_value);
    inputSpecs.emplace_back(constantInputSpec);
    constantLogicalTensors_.emplace_back(constantInputSpec.logical_tensor());
  }
  return inputSpecs;
}

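// Builds the output logical tensors used for compilation. Outputs that stay
// opaque to PyTorch are given an "any" layout so that oneDNN Graph may pick
// the layout itself.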
ArgSpecs LlgaKernel::initializeOutputSpecs() const {
  ArgSpecs outputSpecs;
  outputSpecs.reserve(nOutputs_);
  for (const auto i : c10::irange(nOutputs_)) {
    auto spec = ArgSpec(graph_->outputs()[i]);
    if (useOpaqueLayout(i)) {
      spec = spec.any();
    }
    outputSpecs.emplace_back(spec);
  }
  return outputSpecs;
}

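// Binds runtime buffers to the partition's logical tensors: gathers the
// input tensors (including constants) and allocates or reuses an output
// tensor for each output spec (in-place reuse of an input, an opaque LLGA
// tensor between partitions, or a plain strided ATen tensor).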
std::tuple<RunArgs, RunArgs> LlgaKernel::prepareRunArgs(
    const TensorArgs& inputs,
    TensorArgs& outputs) const {
  RunArgs runInputs, runOutputs;
  auto numInputs = runArgsIdx_.size();
  for (const auto i : c10::irange(numInputs)) {
    auto spec = inputSpecs_[i];
    auto input = inputs[runArgsIdx_[i]];
    runInputs.push_back(
        {spec.logical_tensor(), Engine::getEngine(), input.data_ptr()});
  }
  auto numConstantInputs = constantInputs_.size();
  for (size_t i = 0; i < numConstantInputs; i++) {
    // constantInputSpecs are placed after graphInputSpecs
    auto constantInputSpecIdx = nGraphInputs_ + i;
    auto constantInputSpec = inputSpecs_[constantInputSpecIdx];
    runInputs.push_back(
        {constantLogicalTensors_[i],
         Engine::getEngine(),
         constantInputs_[i].data_ptr()});
  }

  for (const auto i : c10::irange(nOutputs_)) {
    auto spec = outputSpecs_[i];
    auto opt = c10::TensorOptions(spec.aten_scalar_type()).device(device_);

    if (spec.reuses_input_tensor()) {
#ifdef GRAPH_DEBUG_ENABLED
      GRAPH_DEBUG("inplace computation - input tensor would be reused");
#endif
      auto inputTensor = inputs[spec.get_input_tensor_index()];
      if (inputTensor.is_mkldnn()) {
        auto dataType = spec.dtype();
        if (C10_UNLIKELY(!useOpaqueLayout(i))) {
          // If the input tensor sits between two partitions, it has been
          // wrapped with LlgaTensorImpl. But if it is being reused as an
          // output tensor that is not between two partitions, it has to be
          // re-wrapped with a subclass of TensorImpl, since it will be fed
          // into a PyTorch op.
#ifdef GRAPH_DEBUG_ENABLED
          GRAPH_DEBUG("rewrap tensors");
#endif
          auto llgaImpl =
              static_cast<LlgaTensorImpl*>(inputTensor.unsafeGetTensorImpl());
          switch (dataType) {
            case data_type::f32:
            case data_type::bf16:
              inputTensor = LlgaTensorImpl::llga_to_aten_tensor(llgaImpl);
              break;
            case data_type::s32:
            default:
              TORCH_CHECK(
                  false, "Invalid data type ", static_cast<size_t>(dataType));
          }
        }
        outputs.push_back(inputTensor);
        runOutputs.push_back(
            {spec.logical_tensor(),
             Engine::getEngine(),
             inputTensor.data_ptr()});
        return std::make_tuple(runInputs, runOutputs);
      }
    }
    if (useOpaqueLayout(i)) {
      // Wrap tensors between partitions with the LlgaTensorImpl wrapper, so
      // that the guard check can be bypassed, as the strides would differ
      // from the expected ones.
#ifdef GRAPH_DEBUG_ENABLED
      GRAPH_DEBUG("Between two oneDNN Graph partitions");
#endif
      auto tensor = empty_llga(spec, opt);
      outputs.push_back(tensor);
      runOutputs.push_back(llga_from_aten_tensor(tensor));
    } else {
#ifdef GRAPH_DEBUG_ENABLED
      GRAPH_DEBUG("Neither opaque to PyTorch nor inplace-computation");
#endif
      auto tensor = at::empty_strided(spec.sizes(), spec.strides(), opt);
      outputs.push_back(tensor);
      runOutputs.push_back(
          {spec.logical_tensor(), Engine::getEngine(), tensor.data_ptr()});
    }
  }

  return std::make_tuple(runInputs, runOutputs);
}

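// Compiles the partition against the initialized input/output specs, then
// refreshes outputSpecs_ with the layouts chosen by the compilation and
// records which outputs may be computed in place over an input.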
compiled_partition LlgaKernel::compile(const partition& partition) {
  auto inputs = fmap(inputSpecs_, toLogicalTensor);
  auto outputs = fmap(outputSpecs_, toLogicalTensor);
  auto compilation = partition.compile(inputs, outputs, Engine::getEngine());

  // Since the layouts of opaque outputs are only known after compilation,
  // query them from the compiled partition and update outputSpecs_.
  for (const auto i : c10::irange(nOutputs_)) {
    auto tid = outputSpecs_[i].tid();
    outputSpecs_[i] = compilation.query_logical_tensor(tid);
  }

  // Build a static mapping from output id to input offset
  // in accordance with the available in-place options.
  for (auto&& option : compilation.get_inplace_ports()) {
    size_t inputId = option.first;
    size_t outputId = option.second;
    auto inputSpecIter =
        std::find_if(inputSpecs_.begin(), inputSpecs_.end(), [&](auto& spec) {
          return spec.tid() == inputId;
        });
    TORCH_CHECK(inputSpecIter != inputSpecs_.end(), "In-place input not found");
    auto inputOffset = inputSpecIter - inputSpecs_.begin();
    auto outputSpecIter =
        std::find_if(outputSpecs_.begin(), outputSpecs_.end(), [&](auto& spec) {
          return spec.tid() == outputId;
        });
    auto outputOffset = outputSpecIter - outputSpecs_.begin();
    outputSpecs_[outputOffset].set_compute_inplace();
    outputSpecs_[outputOffset].set_input_tensor_index(inputOffset);
  }

  return compilation;
}

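// Executes the kernel against the interpreter stack: pops the graph inputs,
// lazily initializes the specs and compiles the partition on the first call,
// runs the compiled partition, and pushes the outputs back onto the stack.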
void LlgaKernel::run(Stack& stack) {
#ifdef GRAPH_DEBUG_ENABLED
  GRAPH_DEBUG("In ", debugName(), "\n");
#endif

  // Grab the input values from the stack
  auto stackInputs = last(stack, nGraphInputs_);
  auto inputs = fmap(stackInputs, [&](const IValue& v) {
    TORCH_CHECK(
        v.isTensor(), "Stack values for LLGA partition must be Tensor type");
    return v.toTensor();
  });

  // Even with concurrent threads, the kernel is initialized only once.
  // TODO: Try not using an atomic lock
  c10::call_once(
      initialized_flag,
      [&](const TensorArgs& inputs) {
        GRAPH_DEBUG("Initializing input logical tensors");
        inputSpecs_ = initializeInputSpecs(inputs);
        GRAPH_DEBUG("Initializing output logical tensors");
        outputSpecs_ = initializeOutputSpecs();
        GRAPH_DEBUG("Compiling partition");
        compilation_ = compile(partition_);
        is_initialized_ = true;
      },
      inputs);
#ifdef GRAPH_DEBUG_ENABLED
  GRAPH_DEBUG("Preparing runtime tensors");
#endif
  TensorArgs outputs;
  auto [runInputs, runOutputs] = prepareRunArgs(inputs, outputs);
#ifdef GRAPH_DEBUG_ENABLED
  GRAPH_DEBUG("Executing partition");
#endif
  compilation_.execute(Stream::getStream(), runInputs, runOutputs);
#ifdef GRAPH_DEBUG_ENABLED
  GRAPH_DEBUG("Partition executed");
#endif

  // Update the stack.
  drop(stack, nGraphInputs_);
  for (auto& o : outputs)
    push_one(stack, std::move(o));
#ifdef GRAPH_DEBUG_ENABLED
  GRAPH_DEBUG("Stack updated");
#endif
}

} // namespace onednn
} // namespace fuser
} // namespace jit
} // namespace torch