#include <torch/csrc/jit/codegen/onednn/graph_helper.h>
#include <torch/csrc/jit/codegen/onednn/kernel.h>

#include <ATen/core/functional.h>
#include <torch/csrc/jit/jit_log.h>

namespace torch {
namespace jit {
namespace fuser {
namespace onednn {

using namespace dnnl::graph;
using data_type = dnnl::graph::logical_tensor::data_type;

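// LlgaKernel wraps a single oneDNN Graph (LLGA) partition created for this
// fusion node's subgraph during graph rewrite. The partition is recreated here
// from the subgraph, since LLGA does not yet expose a partition serialization
// API (see the TODO below).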
LlgaKernel::LlgaKernel(const Node* fusionNode)
    : fusionNode_(fusionNode),
      graph_(fusionNode->g(attr::Subgraph)),
      nGraphInputs_(graph_->inputs().size()),
      nOutputs_(graph_->outputs().size()),
      debugName_(genDebugName()) {
  // TODO: This is a workaround that recreates the partitions here.
  // The ideal way would be to use the partition serialization API (not yet
  // available in LLGA) to carry a serialized string representation from the
  // graph rewrite pass and deserialize it here.
  auto llgaGraphHelper = LlgaGraphHelper(graph_);
  auto partitions = llgaGraphHelper.getPartitions();
  tensorIdToValue_ = llgaGraphHelper.getTensorIdToValue();
  TORCH_CHECK(
      partitions.size() == 1,
      "LLGA subgraph should contain only one partition");
  partition_ = partitions[0];
  nPartitionInputs_ = partition_.get_input_ports().size();
#ifdef GRAPH_DEBUG_ENABLED
  GRAPH_DEBUG("Initialized ", debugName(), "\n", graph_->toString());
#endif
}

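// Whether the output at `offset` may keep LLGA's opaque layout, as recorded on
// the fusion node during graph rewrite (used below for outputs that feed
// another oneDNN Graph partition).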
bool LlgaKernel::useOpaqueLayout(size_t offset) const {
  return LlgaNodeWrapper(fusionNode_).useOpaqueLayout(offset);
}

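// Collect the partition inputs that are not graph inputs: they must be
// prim::Constant tensors, and are recorded so that their specs can be appended
// after the graph-input specs.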
void LlgaKernel::initializeConstantInputs() {
  for (auto& lt : partition_.get_input_ports()) {
    auto inputId = lt.get_id();
    if (initializedInputIds_.find(inputId) == initializedInputIds_.end()) {
      TORCH_CHECK(
          tensorIdToValue_.count(inputId) > 0,
          "input with id ",
          inputId,
          " is missing");
      auto* value = tensorIdToValue_[inputId];

      TORCH_CHECK(
          value->node()->kind() == prim::Constant &&
              value->type()->cast<TensorType>(),
          "input with id ",
          inputId,
          " should be a Constant tensor");
      constantValues_.emplace_back(value);

      auto const_tensor = toIValue(value)->toTensor();
      constantInputs_.emplace_back(const_tensor);
    }
  }
}

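// Count how many times each tensor id occurs among the partition's input
// ports, since a single graph input may feed several ops in the partition.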
std::map<size_t, int64_t> LlgaKernel::initializeTensorIdToOccurence() const {
  std::map<size_t, int64_t> tensorIdToOccurence;
  for (auto& lt : partition_.get_input_ports()) {
    auto inputId = lt.get_id();
    std::map<size_t, int64_t>::iterator it(tensorIdToOccurence.find(inputId));
    if (it != tensorIdToOccurence.end()) {
      it->second++;
    } else {
      tensorIdToOccurence[inputId] = 1;
    }
  }
  return tensorIdToOccurence;
}

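// Build the input ArgSpecs: one spec per graph input, duplicated once per
// occurrence among the partition's input ports, followed by specs for the
// constant inputs gathered above.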
ArgSpecs LlgaKernel::initializeInputSpecs(const TensorArgs& inputs) {
  ArgSpecs inputSpecs;
  inputSpecs.reserve(nPartitionInputs_);
  GRAPH_DEBUG("Initializing graph input logical tensors");
  std::map<size_t, int64_t> tensorIdToOccurence =
      initializeTensorIdToOccurence();
  for (const auto i : c10::irange(nGraphInputs_)) {
    auto spec = ArgSpec(graph_->inputs()[i]).supplementTensorInfo(inputs[i]);
    initializedInputIds_.insert(spec.tid());
    int64_t occurence = tensorIdToOccurence[spec.tid()];
    inputSpecs.insert(inputSpecs.end(), occurence, spec);
    runArgsIdx_.insert(runArgsIdx_.end(), occurence, i);
  }
  GRAPH_DEBUG("Initializing constant input tensors");
  initializeConstantInputs();

  TORCH_CHECK(
      inputSpecs.size() + constantValues_.size() ==
          static_cast<size_t>(nPartitionInputs_),
      "Partition inputs are missing");
  GRAPH_DEBUG(
      "Concatenating constant input logical tensors to graph input "
      "logical tensors");
  for (Value* constant_value : constantValues_) {
    ArgSpec constantInputSpec(constant_value);
    inputSpecs.emplace_back(constantInputSpec);
    constantLogicalTensors_.emplace_back(constantInputSpec.logical_tensor());
  }
  return inputSpecs;
}

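// Build the output ArgSpecs. Outputs that keep the opaque layout are marked
// with `any` so that the backend is free to choose the layout at compile time.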
ArgSpecs LlgaKernel::initializeOutputSpecs() const {
  ArgSpecs outputSpecs;
  outputSpecs.reserve(nOutputs_);
  for (const auto i : c10::irange(nOutputs_)) {
    auto spec = ArgSpec(graph_->outputs()[i]);
    if (useOpaqueLayout(i)) {
      spec = spec.any();
    }
    outputSpecs.emplace_back(spec);
  }
  return outputSpecs;
}

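// Assemble the runtime tensor arguments for execution. Inputs alias the
// caller's storage (graph inputs first, then constants). Each output is either
// an in-place alias of an input, an LLGA opaque tensor (when it feeds another
// partition), or a plain strided ATen tensor.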
std::tuple<RunArgs, RunArgs> LlgaKernel::prepareRunArgs(
    const TensorArgs& inputs,
    TensorArgs& outputs) const {
  RunArgs runInputs, runOutputs;
  auto numInputs = runArgsIdx_.size();
  for (const auto i : c10::irange(numInputs)) {
    auto spec = inputSpecs_[i];
    auto input = inputs[runArgsIdx_[i]];
    runInputs.push_back(
        {spec.logical_tensor(), Engine::getEngine(), input.data_ptr()});
  }
  auto numConstantInputs = constantInputs_.size();
  for (size_t i = 0; i < numConstantInputs; i++) {
    // constantInputSpecs are placed after graphInputSpecs
    auto constantInputSpecIdx = nGraphInputs_ + i;
    auto constantInputSpec = inputSpecs_[constantInputSpecIdx];
    runInputs.push_back(
        {constantLogicalTensors_[i],
         Engine::getEngine(),
         constantInputs_[i].data_ptr()});
  }

  for (const auto i : c10::irange(nOutputs_)) {
    auto spec = outputSpecs_[i];
    auto opt = c10::TensorOptions(spec.aten_scalar_type()).device(device_);

    if (spec.reuses_input_tensor()) {
#ifdef GRAPH_DEBUG_ENABLED
      GRAPH_DEBUG("inplace computation - input tensor would be reused");
#endif
      auto inputTensor = inputs[spec.get_input_tensor_index()];
      if (inputTensor.is_mkldnn()) {
        auto dataType = spec.dtype();
        if (C10_UNLIKELY(!useOpaqueLayout(i))) {
          // If the input tensor was between two partitions, it would've been
          // wrapped with LlgaTensorImpl. But if it's being reused as the
          // output tensor, which is not between two partitions, then we'd have
          // to re-wrap it with a sub-class of TensorImpl, as it'd be fed into
          // a PyTorch op.
#ifdef GRAPH_DEBUG_ENABLED
          GRAPH_DEBUG("rewrap tensors");
#endif
          auto llgaImpl =
              static_cast<LlgaTensorImpl*>(inputTensor.unsafeGetTensorImpl());
          switch (dataType) {
            case data_type::f32:
            case data_type::bf16:
              inputTensor = LlgaTensorImpl::llga_to_aten_tensor(llgaImpl);
              break;
            case data_type::s32:
            default:
              TORCH_CHECK(
                  false, "Invalid data type ", static_cast<size_t>(dataType));
          }
        }
        outputs.push_back(inputTensor);
        runOutputs.push_back(
            {spec.logical_tensor(),
             Engine::getEngine(),
             inputTensor.data_ptr()});
        // This output is handled; continue with the remaining outputs instead
        // of returning early, which would leave them unallocated.
        continue;
      }
    }
    if (useOpaqueLayout(i)) {
      // Wrap tensors between partitions with the LlgaTensorImpl wrapper, so
      // that we can bypass the guard check, as the strides would differ from
      // those expected.
#ifdef GRAPH_DEBUG_ENABLED
      GRAPH_DEBUG("Between two oneDNN Graph partitions");
#endif
      auto tensor = empty_llga(spec, opt);
      outputs.push_back(tensor);
      runOutputs.push_back(llga_from_aten_tensor(tensor));
    } else {
#ifdef GRAPH_DEBUG_ENABLED
      GRAPH_DEBUG("Neither opaque to PyTorch nor inplace computation");
#endif
      auto tensor = at::empty_strided(spec.sizes(), spec.strides(), opt);
      outputs.push_back(tensor);
      runOutputs.push_back(
          {spec.logical_tensor(), Engine::getEngine(), tensor.data_ptr()});
    }
  }

  return std::make_tuple(runInputs, runOutputs);
}

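// Compile the partition for the current input/output specs. Opaque output
// layouts are only known after compilation, so outputSpecs_ is updated from
// the compiled partition, and the in-place (input reuse) opportunities that
// LLGA reports are recorded on the corresponding output specs.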
compiled_partition LlgaKernel::compile(const partition& partition) {
  auto inputs = fmap(inputSpecs_, toLogicalTensor);
  auto outputs = fmap(outputSpecs_, toLogicalTensor);
  auto compilation = partition.compile(inputs, outputs, Engine::getEngine());

  // Since the layouts of opaque outputs are only known after compilation,
  // query them from the compiled partition and update outputSpecs_.
  for (const auto i : c10::irange(nOutputs_)) {
    auto tid = outputSpecs_[i].tid();
    outputSpecs_[i] = compilation.query_logical_tensor(tid);
  }

  // Build a static mapping from output id to input offset
  // in accordance with the available in-place options.
  for (auto&& option : compilation.get_inplace_ports()) {
    size_t inputId = option.first;
    size_t outputId = option.second;
    auto inputSpecIter =
        std::find_if(inputSpecs_.begin(), inputSpecs_.end(), [&](auto& spec) {
          return spec.tid() == inputId;
        });
    TORCH_CHECK(inputSpecIter != inputSpecs_.end(), "In-place input not found");
    auto inputOffset = inputSpecIter - inputSpecs_.begin();
    auto outputSpecIter =
        std::find_if(outputSpecs_.begin(), outputSpecs_.end(), [&](auto& spec) {
          return spec.tid() == outputId;
        });
    auto outputOffset = outputSpecIter - outputSpecs_.begin();
    outputSpecs_[outputOffset].set_compute_inplace();
    outputSpecs_[outputOffset].set_input_tensor_index(inputOffset);
  }

  return compilation;
}

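// Execute the fused partition. The first call lazily initializes the input and
// output specs and compiles the partition; subsequent calls reuse the cached
// compilation.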
void LlgaKernel::run(Stack& stack) {
#ifdef GRAPH_DEBUG_ENABLED
  GRAPH_DEBUG("In ", debugName(), "\n");
#endif

  // Grab the input values from the stack.
  auto stackInputs = last(stack, nGraphInputs_);
  auto inputs = fmap(stackInputs, [&](const IValue& v) {
    TORCH_CHECK(
        v.isTensor(), "Stack values for LLGA partition must be Tensor type");
    return v.toTensor();
  });

  // Even with concurrent threads, the kernel is initialized only once.
  // TODO: Try not using an atomic lock
  c10::call_once(
      initialized_flag,
      [&](const TensorArgs& inputs) {
        GRAPH_DEBUG("Initializing input logical tensors");
        inputSpecs_ = initializeInputSpecs(inputs);
        GRAPH_DEBUG("Initializing output logical tensors");
        outputSpecs_ = initializeOutputSpecs();
        GRAPH_DEBUG("Compiling partition");
        compilation_ = compile(partition_);
        is_initialized_ = true;
      },
      inputs);
#ifdef GRAPH_DEBUG_ENABLED
  GRAPH_DEBUG("Preparing runtime tensors");
#endif
  TensorArgs outputs;
  auto [runInputs, runOutputs] = prepareRunArgs(inputs, outputs);
#ifdef GRAPH_DEBUG_ENABLED
  GRAPH_DEBUG("Executing partition");
#endif
  compilation_.execute(Stream::getStream(), runInputs, runOutputs);
#ifdef GRAPH_DEBUG_ENABLED
  GRAPH_DEBUG("Partition executed");
#endif

  // Update the stack: drop the consumed inputs and push the outputs.
  drop(stack, nGraphInputs_);
  for (auto& o : outputs)
    push_one(stack, std::move(o));
#ifdef GRAPH_DEBUG_ENABLED
  GRAPH_DEBUG("Stack updated");
#endif
}

} // namespace onednn
} // namespace fuser
} // namespace jit
} // namespace torch