/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/array_ops.cc.

#ifdef INTEL_MKL

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "dnnl.hpp"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/prefetch.h"
#include "tensorflow/core/util/mkl_util.h"
#ifdef DNNL_AARCH64_USE_ACL
#include "tensorflow/core/platform/mutex.h"
#endif

using dnnl::stream;

namespace tensorflow {

namespace {

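// Copies the values of a 1-D begin/size tensor, which must be int32 or
// int64, into an int64 vector.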
gtl::InlinedVector<int64, 4> IntTensorToInt64Vec(const Tensor& tensor) {
  gtl::InlinedVector<int64, 4> out;
  if (tensor.dtype() == DT_INT32) {
    for (int64 i = 0; i < tensor.NumElements(); ++i) {
      out.push_back(tensor.flat<int32>()(i));
    }
  } else if (tensor.dtype() == DT_INT64) {
    for (int64 i = 0; i < tensor.NumElements(); ++i) {
      out.push_back(tensor.flat<int64_t>()(i));
    }
  } else {
    // tensor must be either int32 or int64
    DCHECK(false);
  }
  return out;
}

}  // namespace

typedef Eigen::ThreadPoolDevice CPUDevice;

// A version of SharedValidation (slice_op.h) written for input that is in
// either Mkl layout or TensorFlow layout. It validates the input shapes and
// checks for the identity case independently of the type T, which avoids
// duplicating this code for every T (float, double, int32, etc.).
static void ValidateMklInputs(OpKernelContext* context, bool* is_identity,
                              gtl::InlinedVector<int64, 4>* begin,
                              gtl::InlinedVector<int64, 4>* size) {
  const int kInputTensorIndex = 0;
  const int kInputBeginIndex = 1;
  const int kInputSizeIndex = 2;
  const Tensor& input = MklGetInput(context, kInputTensorIndex);
  const Tensor& begin_tensor = MklGetInput(context, kInputBeginIndex);
  const Tensor& size_tensor = MklGetInput(context, kInputSizeIndex);

  MklDnnShape input_mkl_shape, begin_mkl_shape, size_mkl_shape;
  GetMklShape(context, kInputTensorIndex, &input_mkl_shape);
  GetMklShape(context, kInputBeginIndex, &begin_mkl_shape);
  GetMklShape(context, kInputSizeIndex, &size_mkl_shape);

  // Begin and size tensors cannot be in MklDnn layout.
  DCHECK_EQ(begin_mkl_shape.IsMklTensor(), false);
  DCHECK_EQ(size_mkl_shape.IsMklTensor(), false);

  TensorShape input_tf_shape = input_mkl_shape.IsMklTensor()
                                   ? input_mkl_shape.GetTfShape()
                                   : input.shape();
  const int input_dims = input_tf_shape.dims();

  OP_REQUIRES(
      context,
      TensorShapeUtils::IsVector(begin_tensor.shape()) &&
          TensorShapeUtils::IsVector(size_tensor.shape()) &&
          begin_tensor.NumElements() == input_dims &&
          size_tensor.NumElements() == input_dims,
      errors::InvalidArgument(
          "Expected begin and size arguments to be 1-D tensors of size ",
          input_dims, ", but got shapes ", begin_tensor.shape().DebugString(),
          " and ", size_tensor.shape().DebugString(), " instead."));

  *begin = IntTensorToInt64Vec(begin_tensor);
  *size = IntTensorToInt64Vec(size_tensor);
  for (int i = 0; i < input_dims; ++i) {
    if ((*size)[i] == -1) {
      // A size[i] of -1 means "all elements from begin[i] to dim_size(i)".
      (*size)[i] = input_tf_shape.dim_size(i) - (*begin)[i];
    }
  }

  *is_identity = true;
  for (int i = 0; i < input_dims; ++i) {
    int64 b = (*begin)[i];
    int64 s = (*size)[i];
    if (input_tf_shape.dim_size(i) == 0) {
      OP_REQUIRES(
          context, b == 0 && s == 0,
          errors::InvalidArgument("Expected begin[", i, "] == 0 (got ", b,
                                  ") and size[", i, "] == 0 ", "(got ", s,
                                  ") when ", "input.dim_size(", i, ") == 0"));
    } else {
      OP_REQUIRES(context, 0 <= b && b <= input_tf_shape.dim_size(i),
                  errors::InvalidArgument("Expected begin[", i, "] in [0, ",
                                          input_tf_shape.dim_size(i),
                                          "], but got ", b));
      OP_REQUIRES(context, 0 <= s && b + s <= input_tf_shape.dim_size(i),
                  errors::InvalidArgument("Expected size[", i, "] in [0, ",
                                          input_tf_shape.dim_size(i) - b,
                                          "], but ", "got ", s));
    }
    const bool take_all = (b == 0) && (s == input_tf_shape.dim_size(i));
    (*is_identity) &= take_all;
  }
}

// A version of the SharedSliceCommonCases function written for an input
// tensor that may be in MklDnn layout or in TensorFlow layout.
template <typename T>
static void CheckCommonCasesForMklInputs(OpKernelContext* context,
                                         gtl::InlinedVector<int64, 4>* begin,
                                         gtl::InlinedVector<int64, 4>* size,
                                         bool* done) {
  bool is_identity = true;
  *done = false;

  ValidateMklInputs(context, &is_identity, begin, size);
  if (!context->status().ok()) return;

  const Tensor& input = MklGetInput(context, 0);
  MklDnnShape input_mkl_shape;
  GetMklShape(context, 0, &input_mkl_shape);

  if (is_identity) {
    VLOG(1) << "Slice identity";
    context->set_output(0, input);
    // The Mkl metadata tensor in this case can simply be forwarded from
    // input to output.
    AllocateOutputSetMklShape(context, 0, input_mkl_shape);
    *done = true;
  }
}

// This structure aggregates multiple inputs to Slice methods.
struct MklSliceParams {
  // `from` and `to` are the source and destination memories for the reorder.
  const memory* from;
  const memory* to;

  // `begin_dims` and `size_dims` are the offsets and sizes passed to the
  // view (sub-memory) descriptor.
  memory::dims begin_dims;
  memory::dims size_dims;

  MklSliceParams(const memory* from, const memory* to, memory::dims begin_dims,
                 memory::dims size_dims)
      : from(from), to(to), begin_dims(begin_dims), size_dims(size_dims) {}
};

// This implements the shared interface of Slice reorders.
template <typename T>
class MklSlicePrimitive : public MklPrimitive {
 public:
  explicit MklSlicePrimitive(const MklSliceParams& sliceParams)
      : MklPrimitive(engine(engine::kind::cpu, 0)) {
    Setup(sliceParams);
  }

  ~MklSlicePrimitive() {}

  void Execute(const MklSliceParams& sliceParams,
               std::shared_ptr<stream> slice_stream) {
#ifdef DNNL_AARCH64_USE_ACL
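    // Serialize execution of the cached primitive when oneDNN is built with
    // the Arm Compute Library backend.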
    mutex_lock lock(primitive_execution_mu_);
#endif
#ifndef ENABLE_ONEDNN_OPENMP
    context_.src_mem->set_data_handle(sliceParams.from->get_data_handle(),
                                      *slice_stream);
    context_.dst_mem->set_data_handle(sliceParams.to->get_data_handle(),
                                      *slice_stream);
#else
    context_.src_mem->set_data_handle(sliceParams.from->get_data_handle());
    context_.dst_mem->set_data_handle(sliceParams.to->get_data_handle());
#endif  // !ENABLE_ONEDNN_OPENMP

    execute_primitives(context_.slice_primitives, slice_stream,
                       context_.slice_primitives_args);

    // Reset the handles to DummyData so the cached primitive stays stateless.
    // Otherwise a stale handle from a previous iteration could hide errors in
    // the current iteration and cause wrong data to be reused.
    context_.src_mem->set_data_handle(DummyData);
    context_.dst_mem->set_data_handle(DummyData);
    return;
  }

  std::shared_ptr<primitive> GetPrimitive() { return context_.reorder_prim; }

 private:
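  // Bundles the oneDNN memory objects, reorder primitive, and execution
  // arguments that make up one cached slice reorder.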
  struct SliceContext {
    std::shared_ptr<dnnl::memory> src_mem;
    std::shared_ptr<dnnl::memory> dst_mem;
    std::shared_ptr<primitive> reorder_prim;
    std::shared_ptr<reorder::primitive_desc> reorder_pd;
    std::shared_ptr<dnnl::stream> slice_stream;
    std::vector<dnnl::primitive> slice_primitives;
    std::shared_ptr<dnnl::memory> src_sub_mem;
    std::vector<std::unordered_map<int, memory>> slice_primitives_args;
    SliceContext()
        : src_mem(nullptr), dst_mem(nullptr), reorder_prim(nullptr) {}
  } context_;

  void Setup(const MklSliceParams& sliceParams) {
    // DummyData is never used in the computation; the real data handles are
    // set just before execution.
    context_.src_mem.reset(
        new memory(sliceParams.from->get_desc(), cpu_engine_, DummyData));
    context_.dst_mem.reset(
        new memory(sliceParams.to->get_desc(), cpu_engine_, DummyData));

    auto src_sub_desc = context_.src_mem->get_desc().submemory_desc(
        sliceParams.size_dims, sliceParams.begin_dims);
    context_.src_sub_mem.reset(new memory(src_sub_desc, cpu_engine_, nullptr));
    context_.reorder_pd = std::make_shared<reorder::primitive_desc>(
        reorder::primitive_desc(*context_.src_sub_mem, *context_.dst_mem));
    context_.reorder_prim =
        std::make_shared<dnnl::reorder>(reorder(*context_.reorder_pd));

    context_.slice_primitives_args.push_back(
        {{DNNL_ARG_SRC, *context_.src_mem}, {DNNL_ARG_DST, *context_.dst_mem}});
    context_.slice_primitives.push_back(*context_.reorder_prim);
  }

#ifdef DNNL_AARCH64_USE_ACL
  mutex primitive_execution_mu_;
#endif
};

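// Caches MklSlicePrimitive objects so that the reorder primitive for a given
// pair of memory descriptors and slice offsets/sizes is created only once and
// reused across calls.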
template <typename T>
class MklSlicePrimitiveFactory : public MklPrimitiveFactory<T> {
 public:
  static MklSlicePrimitive<T>* Get(const MklSliceParams& sliceParams) {
    auto reorderPrim = static_cast<MklSlicePrimitive<T>*>(
        MklSlicePrimitiveFactory<T>::GetInstance().GetReorder(sliceParams));
    if (reorderPrim == nullptr) {
      reorderPrim = new MklSlicePrimitive<T>(sliceParams);
      MklSlicePrimitiveFactory<T>::GetInstance().SetReorder(sliceParams,
                                                            reorderPrim);
    }
    return reorderPrim;
  }

  static MklSlicePrimitiveFactory& GetInstance() {
    static MklSlicePrimitiveFactory instance_;
    return instance_;
  }

 private:
  MklSlicePrimitiveFactory() {}
  ~MklSlicePrimitiveFactory() {}

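  // Builds a cache key from the source and destination data types, dims, and
  // strides, plus the slice offsets and sizes.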
  static string CreateKey(const MklSliceParams& sliceParams) {
    string prefix = "reorder";
    FactoryKeyCreator key_creator;
    auto const& from_desc = sliceParams.from->get_desc().data;
    auto const& to_desc = sliceParams.to->get_desc().data;
    memory::dims from_dims(from_desc.dims, &from_desc.dims[from_desc.ndims]);
    memory::dims to_dims(to_desc.dims, &to_desc.dims[to_desc.ndims]);

    auto from_strides = from_desc.format_desc.blocking.strides;
    auto to_strides = to_desc.format_desc.blocking.strides;
    memory::dims from_strides_outer_blocks(from_strides,
                                           &from_strides[from_desc.ndims]);
    memory::dims to_strides_outer_blocks(to_strides,
                                         &to_strides[to_desc.ndims]);

    key_creator.AddAsKey(prefix);
    key_creator.AddAsKey(static_cast<int>(from_desc.data_type));
    key_creator.AddAsKey(from_dims);
    key_creator.AddAsKey(from_strides_outer_blocks);
    key_creator.AddAsKey(static_cast<int>(to_desc.data_type));
    key_creator.AddAsKey(to_dims);
    key_creator.AddAsKey(to_strides_outer_blocks);
    key_creator.AddAsKey(sliceParams.begin_dims);
    key_creator.AddAsKey(sliceParams.size_dims);
    return key_creator.GetKey();
  }

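  // GetReorder returns the cached primitive for `sliceParams`, or nullptr on
  // a cache miss; SetReorder stores a newly created primitive under the same
  // key.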
  MklPrimitive* GetReorder(const MklSliceParams& sliceParams) {
    string key = CreateKey(sliceParams);
    return this->GetOp(key);
  }

  void SetReorder(const MklSliceParams& sliceParams, MklPrimitive* op) {
    string key = CreateKey(sliceParams);
    this->SetOp(key, op);
  }
};

// oneDNN implementation of Slice
template <typename Device, typename T>
class MklSliceOp : public OpKernel {
 public:
  explicit MklSliceOp(OpKernelConstruction* context) : OpKernel(context) {}

  ~MklSliceOp() {}

  void Compute(OpKernelContext* context) override {
    gtl::InlinedVector<int64, 4> begin;
    gtl::InlinedVector<int64, 4> size;
    bool done = false;

    CheckCommonCasesForMklInputs<T>(context, &begin, &size, &done);

    if (!context->status().ok() || done == true) return;

    // oneDNN supports tensors with more than 8 (and up to 12) dimensions,
    // but we mimic the functionality of the Eigen Slice op for CPU, so
    // higher-rank inputs are rejected here.
    if (begin.size() >= 8) {
      OP_REQUIRES(
          context, false,
          errors::Unimplemented("MklSliceOp : Unhandled input dimensions"));
    }

    ComputeMklSlice(context, begin, size);
  }

 private:
  // Slice op implemented using oneDNN APIs.
  void ComputeMklSlice(OpKernelContext* context,
                       const gtl::InlinedVector<int64, 4>& begin,
                       const gtl::InlinedVector<int64, 4>& size) {
    try {
      // oneDNN API usage below is guided by description at:
      // https://github.com/01org/mkl-dnn/issues/69
      //
      // Relevant part of the description is copied below:
      //
      // Let's say you want to copy a part of memory into another buffer (and
      // probably change the format). Then your steps are:
      //
      // 1. create memory primitive descriptor in_mem_pd and memory primitive
      //    in_mem_p for the entire source data. create view primitive
      //    descriptor in_submem_pd based on in_mem_pd, initial offsets,
      //    and sub-sizes
      // 2. create memory primitive descriptor out_mem_pd and memory primitive
      //    out_mem_p for the output (the logical sizes should match sub-sizes
      //    used in step 1, but the format might be arbitrary)
      // 3. create reorder primitive descriptor reorder_pd based on
      //    in_submem_pd and out_mem_pd. create reorder primitive itself based
      //    on reorder_pd, in_mem_p, and out_mem_p.
      //
      // Please notice that there is no view primitive. There is only view
      // primitive descriptor. And the reorder uses source memory as input but
      // traverses it according to a view in_submem_pd.

      auto cpu_engine = engine(engine::kind::cpu, 0);
      MklDnnData<T> src(&cpu_engine);
      MklDnnData<T> output(&cpu_engine);

      // Populate offsets and sizes in memory::dims format based on vector.
      memory::dims begin_dims = {};
      begin_dims.resize(begin.size());
      for (size_t i = 0; i < begin.size(); ++i) begin_dims[i] = begin[i];
      memory::dims size_dims = {};
      bool empty = false;
      size_dims.resize(size.size());
      for (size_t i = 0; i < size.size(); ++i) {
        size_dims[i] = size[i];
        if (size_dims[i] == 0) empty = true;
      }

      Tensor* output_tensor = nullptr;
      MklDnnShape output_mkl_shape;

      // If no dimension is selected in the slice, the result should be empty.
      // Just return an empty output tensor and a dummy Mkl-shape tensor.
      if (empty) {  // for empty dims
        auto shape_to = MklDnnDimsToTFShape(size_dims);
        AllocateOutputSetMklShape(context, 0, &output_tensor, shape_to,
                                  output_mkl_shape);
        return;
      }

      // Step 1 (as per above description) - Create memory for user data.
      // We use blocked format here to describe the input tensor.
      const Tensor& input_tensor = MklGetInput(context, 0);
      memory::dims input_dims, input_strides;
      MklDnnShape input_mkl_shape;
      GetMklShape(context, 0, &input_mkl_shape);

      if (input_mkl_shape.IsMklTensor()) {
        auto input_mkl_format = input_mkl_shape.GetTfDataFormat();
        auto input_tf_format =
            MklDnnDataFormatToTFDataFormat(input_mkl_format);

        bool is_slice2d = (input_mkl_shape.GetDimension() == 4);
        begin_dims = is_slice2d
                         ? MklDnnDimsInNCHW(begin_dims, input_tf_format)
                         : MklDnnDimsInNCDHW(begin_dims, input_tf_format);
        size_dims = is_slice2d ? MklDnnDimsInNCHW(size_dims, input_tf_format)
                               : MklDnnDimsInNCDHW(size_dims, input_tf_format);
        auto input_md = input_mkl_shape.GetMklLayout();
        src.SetUsrMem(input_md, &input_tensor);

        // Handle the data format safely by converting it to the blocked
        // format; compute the parameters of the reorder primitive first.
        input_dims = input_mkl_shape.GetSizesAsMklDnnDims();
        input_strides = CalculateTFStrides(input_dims);
      } else {
        // Initialize input dimensions and strides to be used when the input
        // is not in MklDnn layout.
        input_dims = TFShapeToMklDnnDims(input_tensor.shape());
        input_strides = CalculateTFStrides(input_dims);
        // Create input memory descriptor.
        auto input_md =
            MklDnnData<T>::CreateBlockedMemDesc(input_dims, input_strides);
        src.SetUsrMem(input_md, &input_tensor);
      }

      // If the input is not already in the blocked format, reorder it;
      // otherwise this is a no-op.
      auto op_md =
          MklDnnData<T>::CreateBlockedMemDesc(input_dims, input_strides);
      src.CheckReorderToOpMem(op_md, cpu_engine, context);

      // Step 2 - Create memory for output.
      auto output_strides = CalculateTFStrides(size_dims);
      auto output_md =
          MklDnnData<T>::CreateBlockedMemDesc(size_dims, output_strides);
      auto output_pd = output_md;
      AllocateOutputTensor(context, input_mkl_shape, &output_pd, size_dims,
                           &output_tensor, &output_mkl_shape);
      DCHECK(output_tensor);
      DCHECK_EQ(input_mkl_shape.IsMklTensor(), output_mkl_shape.IsMklTensor());
      output.SetUsrMem(output_md, output_tensor);

      // Step 3 - Create the reorder primitive.
      MklSliceParams sliceParams(&src.GetOpMem(), output.GetUsrMem(),
                                 begin_dims, size_dims);
      MklSlicePrimitive<T>* reorder_prim =
          MklSlicePrimitiveFactory<T>::Get(sliceParams);
      // Execute the slice reorder.
      std::shared_ptr<stream> slice_stream;
      MklDnnThreadPool eigen_tp(context);
      slice_stream.reset(CreateStream(&eigen_tp, reorder_prim->GetEngine()));
      reorder_prim->Execute(sliceParams, slice_stream);
    } catch (dnnl::error& e) {
      string error_msg = "Status: " + std::to_string(e.status) +
                         ", message: " + string(e.message) + ", in file " +
                         string(__FILE__) + ":" + std::to_string(__LINE__);
      OP_REQUIRES_OK(
          context,
          errors::Aborted("Operation received an exception:", error_msg));
    }
  }

 private:
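  // Allocates the output tensor and its Mkl metadata. When the input is in
  // Mkl layout, the output is allocated in Mkl layout using the given memory
  // descriptor; otherwise a plain TensorFlow tensor of shape `output_dims`
  // is allocated.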
  void AllocateOutputTensor(OpKernelContext* context,
                            const MklDnnShape& input_mkl_shape,
                            memory::desc* output_pd,
                            const memory::dims& output_dims,
                            Tensor** output_tensor,
                            MklDnnShape* output_mkl_shape) {
    DCHECK(output_tensor);
    DCHECK(output_mkl_shape);

    TensorShape output_tf_shape;

    if (input_mkl_shape.IsMklTensor()) {
      // Since input tensor is in Mkl layout, output tensor will be in Mkl
      // layout.

      // Allocate shape of Mkl tensor.
      output_mkl_shape->SetMklTensor(true);
      output_mkl_shape->SetMklLayout(output_pd);
      output_mkl_shape->SetElemType(MklDnnType<T>());
      output_mkl_shape->SetTfLayout(input_mkl_shape.GetDimension(),
                                    output_dims,
                                    input_mkl_shape.GetTfDataFormat());

      output_tf_shape.AddDim(output_pd->get_size() / sizeof(T));
    } else {
      // If input is not in Mkl layout, then output won't be in Mkl layout.
      output_mkl_shape->SetMklTensor(false);
      output_tf_shape = MklDnnDimsToTFShape(output_dims);
    }

    AllocateOutputSetMklShape(context, 0, output_tensor, output_tf_shape,
                              *output_mkl_shape);
  }
};

// oneDNN Slice registration
#define REGISTER_MKL_SLICE(type)                               \
  REGISTER_KERNEL_BUILDER(                                     \
      Name("_MklSlice")                                        \
          .Device(DEVICE_CPU)                                  \
          .TypeConstraint<type>("T")                           \
          .HostMemory("begin")                                 \
          .HostMemory("size")                                  \
          .Label(mkl_op_registry::kMklLayoutDependentOpLabel), \
      MklSliceOp<CPUDevice, type>);

TF_CALL_float(REGISTER_MKL_SLICE);
TF_CALL_bfloat16(REGISTER_MKL_SLICE);
#undef REGISTER_MKL_SLICE

}  // namespace tensorflow

#endif  // INTEL_MKL