xref: /aosp_15_r20/external/tensorflow/tensorflow/core/kernels/mkl/mkl_input_conversion_op.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifdef INTEL_MKL
17 
18 #include <algorithm>
19 #include <vector>
20 
21 #include "tensorflow/core/kernels/mkl/mkl_tfconv_op.h"
22 
23 namespace tensorflow {
24 
///////////////////////////////////////////////////////////
//               Op kernel
// Checks and ensures that the 2 inputs are compatible for mkl binary ops.
// Here's the basic logic:
//
// if both inputs are in TF format:
//   pass the inputs through to the output
// else if both inputs are in mkl format:
//   if both have the same shape:
//     pass the inputs through to the output
//   else:
//     convert both to TF
// else if one is TF and one is MKL:
//   if broadcast is needed:
//     convert the MKL format input to TF format
//   else:
//     convert the TF format input to MKL format
///////////////////////////////////////////////////////////
43 
44 template <typename Device, typename T>
45 class MklInputConversionOp : public OpKernel {
46  public:
MklInputConversionOp(OpKernelConstruction * context)47   explicit MklInputConversionOp(OpKernelConstruction* context)
48       : OpKernel(context) {
49     OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format_str));
50     OP_REQUIRES_OK(context, context->GetAttr("T", &op_data_type));
51     has_avx512f_ = port::TestCPUFeature(port::CPUFeature::AVX512F);
52   }
53 
54  private:
Compute(OpKernelContext * context)55   void Compute(OpKernelContext* context) override {
56     const int kInputIndex_0 = 0, kInputIndex_1 = 1;
57     const Tensor& input_tensor_0 = MklGetInput(context, kInputIndex_0);
58     MklDnnShape input_shape_0;
59     GetMklShape(context, kInputIndex_0, &input_shape_0);
60 
61     const Tensor& input_tensor_1 = MklGetInput(context, kInputIndex_1);
62     MklDnnShape input_shape_1;
63     GetMklShape(context, kInputIndex_1, &input_shape_1);
64 
65     VLOG(1) << "MklInputConversionOp: Input shapes are: "
66             << context->input(kInputIndex_0).shape().DebugString() << " and "
67             << context->input(kInputIndex_1).shape().DebugString();
68 
69     // - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
70     // if both inputs are in TF format, just copy input tensors to output.
71     if (!input_shape_0.IsMklTensor() && !input_shape_1.IsMklTensor()) {
72       VLOG(1) << "MklInputConversionOp: No conversion needed, "
73               << "copying TF inputs to output";
74 
75       ForwardTfTensorInToOut(context, kInputIndex_0, kInputIndex_0);
76       ForwardTfTensorInToOut(context, kInputIndex_1, kInputIndex_1);
77       return;
78     }
79 
80     // - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
81     // If both inputs are in MKL format
82     if (input_shape_0.IsMklTensor() && input_shape_1.IsMklTensor()) {
83       // It is safer to compare the original TensorFlow shapes than to compare
84       // Mkl shapes since element wise ops are forwarded to Eigen
85       // implementation.
86       TensorShape tf_shape0 = input_shape_0.GetTfShape();
87       TensorShape tf_shape1 = input_shape_1.GetTfShape();
88       TensorShape tensor_shape0 = input_tensor_0.shape();
89       TensorShape tensor_shape1 = input_tensor_1.shape();
90       if (tf_shape0 == tf_shape1 && tensor_shape0 == tensor_shape1) {
91         auto input0_md = input_shape_0.GetMklLayout();
92         auto input1_md = input_shape_1.GetMklLayout();
93 
94         // If both have the same shape and same format, pass them through
95         if (input_shape_0.GetTfDataFormat() ==
96             input_shape_1.GetTfDataFormat()) {
97           VLOG(1) << "MklInputConversionOp: No conversion needed, "
98                   << "copying MKL inputs with identical shapes to output";
99 
100           ForwardMklTensorInToOut(context, kInputIndex_0, kInputIndex_0);
101           ForwardMklTensorInToOut(context, kInputIndex_1, kInputIndex_1);
102           return;
103         } else {
104           VLOG(1) << "MklInputConversionOp: Shape is same, but format is "
105                      "different, "
106                   << "need to convert to same format";
107           // TODO(intel-tf): For now, input0 is converted and input1 is
108           // unchanged. We should choose the optimal oneDNN format to convert
109           // to.
110           Tensor* tensor_out;
111           MklDnnShape mkl_output_mkl_shape;
112           mkl_output_mkl_shape.SetMklTensor(true);
113           mkl_output_mkl_shape.SetElemType(MklDnnType<T>());
114           mkl_output_mkl_shape.SetTfLayout(input_shape_0.GetDimension(),
115                                            input_shape_0.GetSizesAsMklDnnDims(),
116                                            input_shape_0.GetTfDataFormat());
117 
118           // Get MKL layout from input1 as destination layout
119           mkl_output_mkl_shape.SetMklLayout(&input1_md);
120 
121           // Create output Mkl tensor for index 0
122           AllocateOutputSetMklShape(context, kInputIndex_0, &tensor_out,
123                                     input_tensor_0.shape(),
124                                     mkl_output_mkl_shape);
125 
126           // Create MklDnnData object for input0 tensor
127           auto cpu_engine = engine(engine::kind::cpu, 0);
128           MklDnnData<T> input(&cpu_engine);
129           input.SetUsrMem(input0_md, &input_tensor_0);
130           // Create reorder from input0's layout to input1's layout
131           std::vector<primitive> net;
132           std::vector<MemoryArgsMap> net_args;
133           // TODO(intel-tf): Refactor CheckReorderToOpMem() to create and
134           // execute reorder
135           OP_REQUIRES(
136               context,
137               input.CheckReorderToOpMem(input1_md, tensor_out, net, net_args,
138                                         cpu_engine),
139               errors::Internal(
140                   "MklInputConversionOp: Failed to create reorder for input0"));
141           ExecutePrimitive(net, &net_args, cpu_engine, context);
142           // Input1 will be passed through
143           ForwardMklTensorInToOut(context, kInputIndex_1, kInputIndex_1);
144           return;
145         }
146       }
147 
148       // Sanity check
149       bool mkl_shapes_are_same = ((input_shape_0 == input_shape_1) &&
150                                   (tensor_shape0 == tensor_shape1));
151       if (mkl_shapes_are_same) {
152         CHECK(false) << "MklInputConversionOp: Unexpected: TF shapes are "
153                         "different but MKL shapes are same";
154       }
155 
156       // Both have different shapes, so broadcast will be necessary.
157       // Convert to TF and pass both tensors through (we can't do broadcast
158       // with MKL tensors)
159       VLOG(1) << "MklInputConversionOp: Broadcast needed, "
160               << "converted MKL inputs to TF format";
161       // TODO(intel-tf): Cleanup op_data_type and has_avx512f_ after these two
162       //     parameters are removed from ConvertMklToTf
163       MklToTfOp<Device, T>::ConvertMklToTf(this, context, data_format_str,
164                                            op_data_type, has_avx512f_,
165                                            kInputIndex_0);
166       MklToTfOp<Device, T>::ConvertMklToTf(this, context, data_format_str,
167                                            op_data_type, has_avx512f_,
168                                            kInputIndex_1);
169       SetDummyMklDnnShapeOutput(context, kInputIndex_0);
170       SetDummyMklDnnShapeOutput(context, kInputIndex_1);
171       return;
172     }
173 
174     // - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
175     // One input is MKL and one is TF. If no broadcast is needed, convert
176     // the TF tensor to MKL, otherwise convert the MKL tensor to TF format
177     VLOG(1) << "MklInputConversionOp: Inputs in different formats (MKL/TF)";
178 
179     const Tensor* mkl_tensor;
180     const MklDnnShape* mkl_shape;
181     const Tensor* tf_tensor;
182     uint mkl_tensor_index;
183     uint tf_tensor_index;
184     if (input_shape_0.IsMklTensor() && !input_shape_1.IsMklTensor()) {
185       mkl_tensor = &input_tensor_0;
186       mkl_shape = &input_shape_0;
187       mkl_tensor_index = 0;
188       tf_tensor = &input_tensor_1;
189       tf_tensor_index = 1;
190     } else if (!input_shape_0.IsMklTensor() && input_shape_1.IsMklTensor()) {
191       mkl_tensor = &input_tensor_1;
192       mkl_shape = &input_shape_1;
193       mkl_tensor_index = 1;
194       tf_tensor = &input_tensor_0;
195       tf_tensor_index = 0;
196     } else {
197       CHECK(false) << "MklInputConversionOp: Unexpected combination of input "
198                       "shapes for MKL "
199                    << "element-wise op";
200     }
201 
202     // Broadcast is needed if the shapes are not the same
203     if (mkl_shape->GetTfShape().num_elements() ==
204         tf_tensor->shape().num_elements()) {
205       // Both shapes are same, convert the TF input to MKL
206       VLOG(1) << "MklInputConversionOp: No broadcast needed.";
207       VLOG(1) << "MklInputConversionOp: Converting input " << tf_tensor_index
208               << " to MKL format";
209 
210       // Create MklDnnShape for output Mkl tensor.
211       Tensor* tensor_out;
212       MklDnnShape mkl_output_mkl_shape;
213       mkl_output_mkl_shape.SetMklTensor(true);
214       mkl_output_mkl_shape.SetElemType(MklDnnType<T>());
215       mkl_output_mkl_shape.SetTfLayout(mkl_shape->GetDimension(),
216                                        mkl_shape->GetSizesAsMklDnnDims(),
217                                        mkl_shape->GetTfDataFormat());
218       // ** Temporarily borrow the layout from the MKL input **
219       auto output_mkl_md = mkl_shape->GetMklLayout();
220       mkl_output_mkl_shape.SetMklLayout(&output_mkl_md);
221 
222       // Create output Mkl tensor
223       AllocateOutputSetMklShape(context, tf_tensor_index, &tensor_out,
224                                 mkl_tensor->shape(), mkl_output_mkl_shape);
225 
226       // Create MklDnnData object for input tensor. Input tensor is in
227       // Tensorflow layout.
228       auto cpu_engine = engine(engine::kind::cpu, 0);
229       MklDnnData<T> tf_input(&cpu_engine);
230       auto input_tf_md = mkl_output_mkl_shape.GetTfLayout();
231       tf_input.SetUsrMem(input_tf_md, tf_tensor);
232       // Create reorder between TF layout and MKL layout if necessary
233       std::vector<primitive> net;
234       std::vector<MemoryArgsMap> net_args;
235       bool reordered = tf_input.CheckReorderToOpMem(output_mkl_md, tensor_out,
236                                                     net, net_args, cpu_engine);
237       if (!reordered) {
238         // This is the case that the TF tensor has the same shape and format of
239         // mkl tensor. However, tf_tensor can not be simply forwarded to the
240         // output tensor since mkl data tensor is always one dimensional tensor.
241         // Tensor::CopyFrom shares the buffer of the other tensor while set its
242         // shape to the other tensor.
243         OP_REQUIRES(context,
244                     tensor_out->CopyFrom(*tf_tensor, tensor_out->shape()),
245                     errors::Internal("MklInputConversionOp: Failed to forward "
246                                      "input tensor to output"));
247       } else {
248         ExecutePrimitive(net, &net_args, cpu_engine, context);
249       }
250 
251       // -- The tensor in MKL format passes through --
252       ForwardMklTensorInToOut(context, mkl_tensor_index, mkl_tensor_index);
253     } else {
254       // Broadcast is needed, so convert the MKL input to TF
255       VLOG(1) << "MklInputConversionOp: Broadcast needed.";
256       VLOG(1) << "MklInputConversionOp: Converting input " << mkl_tensor_index
257               << " to TF format";
258       MklToTfOp<Device, T>::ConvertMklToTf(this, context, data_format_str,
259                                            op_data_type, has_avx512f_,
260                                            mkl_tensor_index);
261       SetDummyMklDnnShapeOutput(context, mkl_tensor_index);
262 
263       // The tensor in TF format passes through
264       ForwardTfTensorInToOut(context, tf_tensor_index, tf_tensor_index);
265     }
266 
267     VLOG(1) << "MklInputConversionOp: Shapes (output): "
268             << context->mutable_output(kInputIndex_0)->shape().DebugString()
269             << " and "
270             << context->mutable_output(kInputIndex_1)->shape().DebugString();
271 
272     VLOG(1) << "MklInputConversion completed successfully.";
273   }
274 
275  private:
276   /// Data format of the operation
277   string data_format_str;
278 
279   /// Data type of the operation
280   DataType op_data_type;
281 
282   /// CPUIDInfo
283   bool has_avx512f_ = false;
284 };
285 
286 ///////////////////////////////////////////////////////////
287 //               Register kernel
288 ///////////////////////////////////////////////////////////
289 
290 #define REGISTER_CPU(T)                                        \
291   REGISTER_KERNEL_BUILDER(                                     \
292       Name("_MklInputConversion")                              \
293           .Device(DEVICE_CPU)                                  \
294           .TypeConstraint<T>("T")                              \
295           .Label(mkl_op_registry::kMklLayoutDependentOpLabel), \
296       MklInputConversionOp<CPUDevice, T>);
297 
298 TF_CALL_float(REGISTER_CPU);
299 TF_CALL_bfloat16(REGISTER_CPU);
300 
301 #undef REGISTER_CPU
302 
303 }  // namespace tensorflow
304 #endif  // INTEL_MKL
305