/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/array_ops.cc.

#ifdef INTEL_MKL

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "dnnl.hpp"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/prefetch.h"
#include "tensorflow/core/util/mkl_util.h"
#ifdef DNNL_AARCH64_USE_ACL
#include "tensorflow/core/platform/mutex.h"
#endif

using dnnl::stream;

namespace tensorflow {

namespace {

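// Copies a 1-D int32 or int64 tensor into an inlined vector of int64.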
gtl::InlinedVector<int64, 4> IntTensorToInt64Vec(const Tensor& tensor) {
  gtl::InlinedVector<int64, 4> out;
  if (tensor.dtype() == DT_INT32) {
    for (int64 i = 0; i < tensor.NumElements(); ++i) {
      out.push_back(tensor.flat<int32>()(i));
    }
  } else if (tensor.dtype() == DT_INT64) {
    for (int64 i = 0; i < tensor.NumElements(); ++i) {
      out.push_back(tensor.flat<int64_t>()(i));
    }
  } else {
    // tensor must be either int32 or int64
    DCHECK(false);
  }
  return out;
}

}  // namespace

typedef Eigen::ThreadPoolDevice CPUDevice;

// A version of SharedValidation (slice_op.h) written for input that is in
// either Mkl layout or TensorFlow layout. Shared code that validates the
// input shapes and checks for identity; it does not depend on the type T,
// so we avoid duplicating it for every T (float, double, int32, etc.) and
// keep the code size down.
static void ValidateMklInputs(OpKernelContext* context, bool* is_identity,
                              gtl::InlinedVector<int64, 4>* begin,
                              gtl::InlinedVector<int64, 4>* size) {
  const int kInputTensorIndex = 0;
  const int kInputBeginIndex = 1;
  const int kInputSizeIndex = 2;
  const Tensor& input = MklGetInput(context, kInputTensorIndex);
  const Tensor& begin_tensor = MklGetInput(context, kInputBeginIndex);
  const Tensor& size_tensor = MklGetInput(context, kInputSizeIndex);

  MklDnnShape input_mkl_shape, begin_mkl_shape, size_mkl_shape;
  GetMklShape(context, kInputTensorIndex, &input_mkl_shape);
  GetMklShape(context, kInputBeginIndex, &begin_mkl_shape);
  GetMklShape(context, kInputSizeIndex, &size_mkl_shape);

  // Begin and size tensors cannot be in MklDnn layout.
  DCHECK_EQ(begin_mkl_shape.IsMklTensor(), false);
  DCHECK_EQ(size_mkl_shape.IsMklTensor(), false);

  TensorShape input_tf_shape = input_mkl_shape.IsMklTensor()
                                   ? input_mkl_shape.GetTfShape()
                                   : input.shape();
  const int input_dims = input_tf_shape.dims();

  OP_REQUIRES(
      context,
      TensorShapeUtils::IsVector(begin_tensor.shape()) &&
          TensorShapeUtils::IsVector(size_tensor.shape()) &&
          begin_tensor.NumElements() == input_dims &&
          size_tensor.NumElements() == input_dims,
      errors::InvalidArgument(
          "Expected begin and size arguments to be 1-D tensors of size ",
          input_dims, ", but got shapes ", begin_tensor.shape().DebugString(),
          " and ", size_tensor.shape().DebugString(), " instead."));

  *begin = IntTensorToInt64Vec(begin_tensor);
  *size = IntTensorToInt64Vec(size_tensor);
  for (int i = 0; i < input_dims; ++i) {
    if ((*size)[i] == -1) {
      // A size[i] of -1 means "all elements from begin[i] to dim_size(i)".
      (*size)[i] = input_tf_shape.dim_size(i) - (*begin)[i];
    }
  }
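  // For example, with input shape [2, 3], begin = {0, 1} and size = {-1, 2},
  // the -1 expands to size = {2, 2}.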

  *is_identity = true;
  for (int i = 0; i < input_dims; ++i) {
    int64 b = (*begin)[i];
    int64 s = (*size)[i];
    if (input_tf_shape.dim_size(i) == 0) {
      OP_REQUIRES(
          context, b == 0 && s == 0,
          errors::InvalidArgument("Expected begin[", i, "] == 0 (got ", b,
                                  ") and size[", i, "] == 0 ", "(got ", s,
                                  ") when ", "input.dim_size(", i, ") == 0"));
    } else {
      OP_REQUIRES(context, 0 <= b && b <= input_tf_shape.dim_size(i),
                  errors::InvalidArgument("Expected begin[", i, "] in [0, ",
                                          input_tf_shape.dim_size(i),
                                          "], but got ", b));
      OP_REQUIRES(context, 0 <= s && b + s <= input_tf_shape.dim_size(i),
                  errors::InvalidArgument("Expected size[", i, "] in [0, ",
                                          input_tf_shape.dim_size(i) - b,
                                          "], but ", "got ", s));
    }
    const bool take_all = (b == 0) && (s == input_tf_shape.dim_size(i));
    (*is_identity) &= take_all;
  }
}

// A version of the SharedSliceCommonCases function written for an input
// tensor that may be in MklDnn layout or in TensorFlow layout.
template <typename T>
static void CheckCommonCasesForMklInputs(OpKernelContext* context,
                                         gtl::InlinedVector<int64, 4>* begin,
                                         gtl::InlinedVector<int64, 4>* size,
                                         bool* done) {
  bool is_identity = true;
  *done = false;

  ValidateMklInputs(context, &is_identity, begin, size);
  if (!context->status().ok()) return;

  const Tensor& input = MklGetInput(context, 0);
  MklDnnShape input_mkl_shape;
  GetMklShape(context, 0, &input_mkl_shape);

  if (is_identity) {
    VLOG(1) << "Slice identity";
    context->set_output(0, input);
    // Mkl metadata tensor in this case can just be forwarded from input to
    // output.
    AllocateOutputSetMklShape(context, 0, input_mkl_shape);
    *done = true;
  }
}

// This structure aggregates multiple inputs to Slice methods.
struct MklSliceParams {
  // from and to are the source and destination memories of the reorder.
  const memory* from;
  const memory* to;

  // begin_dims and size_dims are the offsets and lengths passed to the view
  // (submemory) primitive descriptor.
  memory::dims begin_dims;
  memory::dims size_dims;

  MklSliceParams(const memory* from, const memory* to, memory::dims begin_dims,
                 memory::dims size_dims)
      : from(from), to(to), begin_dims(begin_dims), size_dims(size_dims) {}
};
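// For example, slicing a 2x2 block starting at row 0, column 1 out of a 2-D
// source could be described as (a sketch; src_mem and dst_mem stand in for
// real oneDNN memory objects):
//   MklSliceParams params(&src_mem, &dst_mem, /*begin_dims=*/{0, 1},
//                         /*size_dims=*/{2, 2});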

// This implements the shared interface of Slice reorders.
template <typename T>
class MklSlicePrimitive : public MklPrimitive {
 public:
  explicit MklSlicePrimitive(const MklSliceParams& sliceParams)
      : MklPrimitive(engine(engine::kind::cpu, 0)) {
    Setup(sliceParams);
  }

  ~MklSlicePrimitive() {}

  void Execute(const MklSliceParams& sliceParams,
               std::shared_ptr<stream> slice_stream) {
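    // With the Arm Compute Library (ACL) backend, a cached primitive cannot
    // safely be executed by multiple threads at once, hence the lock below.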
#ifdef DNNL_AARCH64_USE_ACL
    mutex_lock lock(primitive_execution_mu_);
#endif
#ifndef ENABLE_ONEDNN_OPENMP
    context_.src_mem->set_data_handle(sliceParams.from->get_data_handle(),
                                      *slice_stream);
    context_.dst_mem->set_data_handle(sliceParams.to->get_data_handle(),
                                      *slice_stream);
#else
    context_.src_mem->set_data_handle(sliceParams.from->get_data_handle());
    context_.dst_mem->set_data_handle(sliceParams.to->get_data_handle());
#endif  // !ENABLE_ONEDNN_OPENMP

    execute_primitives(context_.slice_primitives, slice_stream,
                       context_.slice_primitives_args);

    // Reset the handles back to DummyData so the primitive in the cache pool
    // stays stateless. Otherwise the handles from a previous iteration would
    // be kept, problems in the current iteration would not surface
    // immediately, and stale data could be reused.
    context_.src_mem->set_data_handle(DummyData);
    context_.dst_mem->set_data_handle(DummyData);
    return;
  }

  std::shared_ptr<primitive> GetPrimitive() { return context_.reorder_prim; }

 private:
  struct SliceContext {
    std::shared_ptr<dnnl::memory> src_mem;
    std::shared_ptr<dnnl::memory> dst_mem;
    std::shared_ptr<primitive> reorder_prim;
    std::shared_ptr<reorder::primitive_desc> reorder_pd;
    std::shared_ptr<dnnl::stream> slice_stream;
    std::vector<dnnl::primitive> slice_primitives;
    std::shared_ptr<dnnl::memory> src_sub_mem;
    std::vector<std::unordered_map<int, memory>> slice_primitives_args;
    SliceContext()
        : src_mem(nullptr), dst_mem(nullptr), reorder_prim(nullptr) {}
  } context_;

  void Setup(const MklSliceParams& sliceParams) {
    // DummyData is never actually used in the computation; the real data
    // handles are set right before execution.
    context_.src_mem.reset(
        new memory(sliceParams.from->get_desc(), cpu_engine_, DummyData));
    context_.dst_mem.reset(
        new memory(sliceParams.to->get_desc(), cpu_engine_, DummyData));

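    // Describe the sub-region of the source to copy. There is no separate
    // view primitive in oneDNN, only a submemory descriptor.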
    auto src_sub_desc = context_.src_mem->get_desc().submemory_desc(
        sliceParams.size_dims, sliceParams.begin_dims);
    context_.src_sub_mem.reset(new memory(src_sub_desc, cpu_engine_, nullptr));
    context_.reorder_pd = std::make_shared<reorder::primitive_desc>(
        reorder::primitive_desc(*context_.src_sub_mem, *context_.dst_mem));
    context_.reorder_prim =
        std::make_shared<dnnl::reorder>(reorder(*context_.reorder_pd));

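    // At execution the reorder is passed the full source memory
    // (DNNL_ARG_SRC below) and traverses only the sub-region described by
    // src_sub_desc.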
    context_.slice_primitives_args.push_back(
        {{DNNL_ARG_SRC, *context_.src_mem}, {DNNL_ARG_DST, *context_.dst_mem}});
    context_.slice_primitives.push_back(*context_.reorder_prim);
  }

#ifdef DNNL_AARCH64_USE_ACL
  mutex primitive_execution_mu_;
#endif
};

template <typename T>
class MklSlicePrimitiveFactory : public MklPrimitiveFactory<T> {
 public:
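  // Returns the cached primitive for sliceParams, creating and caching a new
  // one on a miss.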
  static MklSlicePrimitive<T>* Get(const MklSliceParams& sliceParams) {
    auto reorderPrim = static_cast<MklSlicePrimitive<T>*>(
        MklSlicePrimitiveFactory<T>::GetInstance().GetReorder(sliceParams));
    if (reorderPrim == nullptr) {
      reorderPrim = new MklSlicePrimitive<T>(sliceParams);
      MklSlicePrimitiveFactory<T>::GetInstance().SetReorder(sliceParams,
                                                            reorderPrim);
    }
    return reorderPrim;
  }

  static MklSlicePrimitiveFactory& GetInstance() {
    static MklSlicePrimitiveFactory instance_;
    return instance_;
  }

 private:
  MklSlicePrimitiveFactory() {}
  ~MklSlicePrimitiveFactory() {}

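  // Builds the cache key from the source and destination data types, dims,
  // and strides, plus the slice offsets and sizes.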
  static string CreateKey(const MklSliceParams& sliceParams) {
    string prefix = "reorder";
    FactoryKeyCreator key_creator;
    auto const& from_desc = sliceParams.from->get_desc().data;
    auto const& to_desc = sliceParams.to->get_desc().data;
    memory::dims from_dims(from_desc.dims, &from_desc.dims[from_desc.ndims]);
    memory::dims to_dims(to_desc.dims, &to_desc.dims[to_desc.ndims]);

    auto from_strides = from_desc.format_desc.blocking.strides;
    auto to_strides = to_desc.format_desc.blocking.strides;
    memory::dims from_strides_outer_blocks(from_strides,
                                           &from_strides[from_desc.ndims]);
    memory::dims to_strides_outer_blocks(to_strides,
                                         &to_strides[to_desc.ndims]);

    key_creator.AddAsKey(prefix);
    key_creator.AddAsKey(static_cast<int>(from_desc.data_type));
    key_creator.AddAsKey(from_dims);
    key_creator.AddAsKey(from_strides_outer_blocks);
    key_creator.AddAsKey(static_cast<int>(to_desc.data_type));
    key_creator.AddAsKey(to_dims);
    key_creator.AddAsKey(to_strides_outer_blocks);
    key_creator.AddAsKey(sliceParams.begin_dims);
    key_creator.AddAsKey(sliceParams.size_dims);
    return key_creator.GetKey();
  }

  MklPrimitive* GetReorder(const MklSliceParams& sliceParams) {
    string key = CreateKey(sliceParams);
    return this->GetOp(key);
  }

  void SetReorder(const MklSliceParams& sliceParams, MklPrimitive* op) {
    string key = CreateKey(sliceParams);
    this->SetOp(key, op);
  }
};

// oneDNN implementation of Slice
template <typename Device, typename T>
class MklSliceOp : public OpKernel {
 public:
  explicit MklSliceOp(OpKernelConstruction* context) : OpKernel(context) {}

  ~MklSliceOp() {}

  void Compute(OpKernelContext* context) override {
    gtl::InlinedVector<int64, 4> begin;
    gtl::InlinedVector<int64, 4> size;
    bool done = false;

    CheckCommonCasesForMklInputs<T>(context, &begin, &size, &done);

    if (!context->status().ok() || done) return;

    // oneDNN itself supports tensors of higher rank (up to 12 dimensions),
    // but we mimic the functionality of the Eigen Slice op for CPU and only
    // handle inputs with fewer than 8 dimensions.
    if (begin.size() >= 8) {
      OP_REQUIRES(
          context, false,
          errors::Unimplemented("MklSliceOp : Unhandled input dimensions"));
    }

    ComputeMklSlice(context, begin, size);
  }

 private:
  // Slice op implemented using oneDNN APIs.
  void ComputeMklSlice(OpKernelContext* context,
                       const gtl::InlinedVector<int64, 4>& begin,
                       const gtl::InlinedVector<int64, 4>& size) {
    try {
      // oneDNN API usage below is guided by description at:
      //  https://github.com/01org/mkl-dnn/issues/69
      //
      // Relevant part of the description is copied below:
      //
      // Let's say you want to copy a part of memory into another buffer (and
      // probably change the format). Then your steps are:
      //
      // 1. create memory primitive descriptor in_mem_pd and memory primitive
      //    in_mem_p for the entire source data. create view primitive
      //    descriptor in_submem_pd based on in_mem_pd, initial offsets,
      //    and sub-sizes
      // 2. create memory primitive descriptor out_mem_pd and memory primitive
      //    out_mem_p for the output (the logical sizes should match sub-sizes
      //    used in step 1, but the format might be arbitrary)
      // 3. create reorder primitive descriptor reorder_pd based on in_submem_pd
      //    and out_mem_pd. create reorder primitive itself based on reorder_pd,
      //    in_mem_p, and out_mem_p.
      //
      // Please notice that there is no view primitive. There is only view
      // primitive descriptor. And the reorder uses source memory as input but
      // traverses it according to a view in_submem_pd.

      auto cpu_engine = engine(engine::kind::cpu, 0);
      MklDnnData<T> src(&cpu_engine);
      MklDnnData<T> output(&cpu_engine);

      // Populate offsets and sizes in memory::dims format based on vector.
      memory::dims begin_dims = {};
      begin_dims.resize(begin.size());
      for (size_t i = 0; i < begin.size(); ++i) begin_dims[i] = begin[i];
      memory::dims size_dims = {};
      bool empty = false;
      size_dims.resize(size.size());
      for (size_t i = 0; i < size.size(); ++i) {
        size_dims[i] = size[i];
        if (size_dims[i] == 0) empty = true;
      }

      Tensor* output_tensor = nullptr;
      MklDnnShape output_mkl_shape;

      // If any requested size is 0, the slice is empty: just return an empty
      // output tensor and a dummy Mkl-shape tensor.
      if (empty) {  // for empty dims
        auto shape_to = MklDnnDimsToTFShape(size_dims);
        AllocateOutputSetMklShape(context, 0, &output_tensor, shape_to,
                                  output_mkl_shape);
        return;
      }

      // Step 1 (as per the description above) - Create memory for user data.
      // We use the blocked format here to describe the input tensor.
      const Tensor& input_tensor = MklGetInput(context, 0);
      memory::dims input_dims, input_strides;
      MklDnnShape input_mkl_shape;
      GetMklShape(context, 0, &input_mkl_shape);

      if (input_mkl_shape.IsMklTensor()) {
        auto input_mkl_format = input_mkl_shape.GetTfDataFormat();
        auto input_tf_format = MklDnnDataFormatToTFDataFormat(input_mkl_format);

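        // oneDNN stores dims in NCHW/NCDHW order, so permute the
        // user-supplied begin and size dims from the TF data format to match.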
        bool is_slice2d = (input_mkl_shape.GetDimension() == 4);
        begin_dims = is_slice2d
                         ? MklDnnDimsInNCHW(begin_dims, input_tf_format)
                         : MklDnnDimsInNCDHW(begin_dims, input_tf_format);
        size_dims = is_slice2d ? MklDnnDimsInNCHW(size_dims, input_tf_format)
                               : MklDnnDimsInNCDHW(size_dims, input_tf_format);
        auto input_md = input_mkl_shape.GetMklLayout();
        src.SetUsrMem(input_md, &input_tensor);

        // Handle the data format safely by converting it to the blocked
        // format; first compute the parameters of the reorder primitive.
        input_dims = input_mkl_shape.GetSizesAsMklDnnDims();
        input_strides = CalculateTFStrides(input_dims);
      } else {
        // Initialize input dimensions and strides to be used when input is
        // not in MklDnn layout.
        input_dims = TFShapeToMklDnnDims(input_tensor.shape());
        input_strides = CalculateTFStrides(input_dims);
        // Create input memory descriptor.
        auto input_md =
            MklDnnData<T>::CreateBlockedMemDesc(input_dims, input_strides);
        src.SetUsrMem(input_md, &input_tensor);
      }

      // Reorder the input into the blocked format if it is not already in
      // that format; otherwise this is a no-op.
      auto op_md =
          MklDnnData<T>::CreateBlockedMemDesc(input_dims, input_strides);
      src.CheckReorderToOpMem(op_md, cpu_engine, context);

      // Step 2 - Create memory for output.
      auto output_strides = CalculateTFStrides(size_dims);
      auto output_md =
          MklDnnData<T>::CreateBlockedMemDesc(size_dims, output_strides);
      auto output_pd = output_md;
      AllocateOutputTensor(context, input_mkl_shape, &output_pd, size_dims,
                           &output_tensor, &output_mkl_shape);
      DCHECK(output_tensor);
      DCHECK_EQ(input_mkl_shape.IsMklTensor(), output_mkl_shape.IsMklTensor());
      output.SetUsrMem(output_md, output_tensor);

      // Step 3 - create reorder primitive.
      MklSliceParams sliceParams(&src.GetOpMem(), output.GetUsrMem(),
                                 begin_dims, size_dims);
      MklSlicePrimitive<T>* reorder_prim =
          MklSlicePrimitiveFactory<T>::Get(sliceParams);
      // Execute slice reorder.
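      // The stream is backed by the Eigen thread pool so the reorder runs on
      // TensorFlow's intra-op threads.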
      std::shared_ptr<stream> slice_stream;
      MklDnnThreadPool eigen_tp(context);
      slice_stream.reset(CreateStream(&eigen_tp, reorder_prim->GetEngine()));
      reorder_prim->Execute(sliceParams, slice_stream);
    } catch (dnnl::error& e) {
      string error_msg = "Status: " + std::to_string(e.status) +
                         ", message: " + string(e.message) + ", in file " +
                         string(__FILE__) + ":" + std::to_string(__LINE__);
      OP_REQUIRES_OK(
          context,
          errors::Aborted("Operation received an exception:", error_msg));
    }
  }

 private:
  void AllocateOutputTensor(OpKernelContext* context,
                            const MklDnnShape& input_mkl_shape,
                            memory::desc* output_pd,
                            const memory::dims& output_dims,
                            Tensor** output_tensor,
                            MklDnnShape* output_mkl_shape) {
    DCHECK(output_tensor);
    DCHECK(output_mkl_shape);

    TensorShape output_tf_shape;

    if (input_mkl_shape.IsMklTensor()) {
      // Since input tensor is in Mkl layout, output tensor will be in Mkl
      // layout.

      // Allocate shape of Mkl tensor.
      output_mkl_shape->SetMklTensor(true);
      output_mkl_shape->SetMklLayout(output_pd);
      output_mkl_shape->SetElemType(MklDnnType<T>());
      output_mkl_shape->SetTfLayout(input_mkl_shape.GetDimension(), output_dims,
                                    input_mkl_shape.GetTfDataFormat());

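      // The TF tensor backing an Mkl-layout output is just a flat buffer
      // sized to the oneDNN layout; the logical shape is carried by
      // output_mkl_shape.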
      output_tf_shape.AddDim(output_pd->get_size() / sizeof(T));
    } else {
      // If input is not in Mkl layout, then output won't be in Mkl layout.
      output_mkl_shape->SetMklTensor(false);
      output_tf_shape = MklDnnDimsToTFShape(output_dims);
    }

    AllocateOutputSetMklShape(context, 0, output_tensor, output_tf_shape,
                              *output_mkl_shape);
  }
};

// oneDNN Slice registration
#define REGISTER_MKL_SLICE(type)                               \
  REGISTER_KERNEL_BUILDER(                                     \
      Name("_MklSlice")                                        \
          .Device(DEVICE_CPU)                                  \
          .TypeConstraint<type>("T")                           \
          .HostMemory("begin")                                 \
          .HostMemory("size")                                  \
          .Label(mkl_op_registry::kMklLayoutDependentOpLabel), \
      MklSliceOp<CPUDevice, type>);

TF_CALL_float(REGISTER_MKL_SLICE);
TF_CALL_bfloat16(REGISTER_MKL_SLICE);
#undef REGISTER_MKL_SLICE

}  // namespace tensorflow

#endif  // INTEL_MKL