1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifdef INTEL_MKL 17 18 #include <algorithm> 19 #include <vector> 20 21 #include "tensorflow/core/kernels/mkl/mkl_tfconv_op.h" 22 23 namespace tensorflow { 24 25 /////////////////////////////////////////////////////////// 26 // Op kernel 27 // Checks and ensures that the 2 inputs are compatible for mkl binary ops. 28 // Here's the basic logic: 29 // 30 // if both inputs are in TF format: 31 // pass the inputs through to the output 32 // else if both inputs are in mkl format: 33 // if both have the same shape: 34 // pass the inputs through to the output 35 // else: 36 // convert both to TF 37 // else if one is TF and one is MKL: 38 // if broadcast is needed: 39 // convert the MKL format input to TF format 40 // else: 41 // convert the TF format input to MKL format 42 /////////////////////////////////////////////////////////// 43 44 template <typename Device, typename T> 45 class MklInputConversionOp : public OpKernel { 46 public: MklInputConversionOp(OpKernelConstruction * context)47 explicit MklInputConversionOp(OpKernelConstruction* context) 48 : OpKernel(context) { 49 OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format_str)); 50 OP_REQUIRES_OK(context, context->GetAttr("T", &op_data_type)); 51 has_avx512f_ = port::TestCPUFeature(port::CPUFeature::AVX512F); 52 } 53 54 private: Compute(OpKernelContext * context)55 void Compute(OpKernelContext* context) override { 56 const int kInputIndex_0 = 0, kInputIndex_1 = 1; 57 const Tensor& input_tensor_0 = MklGetInput(context, kInputIndex_0); 58 MklDnnShape input_shape_0; 59 GetMklShape(context, kInputIndex_0, &input_shape_0); 60 61 const Tensor& input_tensor_1 = MklGetInput(context, kInputIndex_1); 62 MklDnnShape input_shape_1; 63 GetMklShape(context, kInputIndex_1, &input_shape_1); 64 65 VLOG(1) << "MklInputConversionOp: Input shapes are: " 66 << context->input(kInputIndex_0).shape().DebugString() << " and " 67 << context->input(kInputIndex_1).shape().DebugString(); 68 69 // - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - 70 // if both inputs are in TF format, just copy input tensors to output. 71 if (!input_shape_0.IsMklTensor() && !input_shape_1.IsMklTensor()) { 72 VLOG(1) << "MklInputConversionOp: No conversion needed, " 73 << "copying TF inputs to output"; 74 75 ForwardTfTensorInToOut(context, kInputIndex_0, kInputIndex_0); 76 ForwardTfTensorInToOut(context, kInputIndex_1, kInputIndex_1); 77 return; 78 } 79 80 // - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - 81 // If both inputs are in MKL format 82 if (input_shape_0.IsMklTensor() && input_shape_1.IsMklTensor()) { 83 // It is safer to compare the original TensorFlow shapes than to compare 84 // Mkl shapes since element wise ops are forwarded to Eigen 85 // implementation. 86 TensorShape tf_shape0 = input_shape_0.GetTfShape(); 87 TensorShape tf_shape1 = input_shape_1.GetTfShape(); 88 TensorShape tensor_shape0 = input_tensor_0.shape(); 89 TensorShape tensor_shape1 = input_tensor_1.shape(); 90 if (tf_shape0 == tf_shape1 && tensor_shape0 == tensor_shape1) { 91 auto input0_md = input_shape_0.GetMklLayout(); 92 auto input1_md = input_shape_1.GetMklLayout(); 93 94 // If both have the same shape and same format, pass them through 95 if (input_shape_0.GetTfDataFormat() == 96 input_shape_1.GetTfDataFormat()) { 97 VLOG(1) << "MklInputConversionOp: No conversion needed, " 98 << "copying MKL inputs with identical shapes to output"; 99 100 ForwardMklTensorInToOut(context, kInputIndex_0, kInputIndex_0); 101 ForwardMklTensorInToOut(context, kInputIndex_1, kInputIndex_1); 102 return; 103 } else { 104 VLOG(1) << "MklInputConversionOp: Shape is same, but format is " 105 "different, " 106 << "need to convert to same format"; 107 // TODO(intel-tf): For now, input0 is converted and input1 is 108 // unchanged. We should choose the optimal oneDNN format to convert 109 // to. 110 Tensor* tensor_out; 111 MklDnnShape mkl_output_mkl_shape; 112 mkl_output_mkl_shape.SetMklTensor(true); 113 mkl_output_mkl_shape.SetElemType(MklDnnType<T>()); 114 mkl_output_mkl_shape.SetTfLayout(input_shape_0.GetDimension(), 115 input_shape_0.GetSizesAsMklDnnDims(), 116 input_shape_0.GetTfDataFormat()); 117 118 // Get MKL layout from input1 as destination layout 119 mkl_output_mkl_shape.SetMklLayout(&input1_md); 120 121 // Create output Mkl tensor for index 0 122 AllocateOutputSetMklShape(context, kInputIndex_0, &tensor_out, 123 input_tensor_0.shape(), 124 mkl_output_mkl_shape); 125 126 // Create MklDnnData object for input0 tensor 127 auto cpu_engine = engine(engine::kind::cpu, 0); 128 MklDnnData<T> input(&cpu_engine); 129 input.SetUsrMem(input0_md, &input_tensor_0); 130 // Create reorder from input0's layout to input1's layout 131 std::vector<primitive> net; 132 std::vector<MemoryArgsMap> net_args; 133 // TODO(intel-tf): Refactor CheckReorderToOpMem() to create and 134 // execute reorder 135 OP_REQUIRES( 136 context, 137 input.CheckReorderToOpMem(input1_md, tensor_out, net, net_args, 138 cpu_engine), 139 errors::Internal( 140 "MklInputConversionOp: Failed to create reorder for input0")); 141 ExecutePrimitive(net, &net_args, cpu_engine, context); 142 // Input1 will be passed through 143 ForwardMklTensorInToOut(context, kInputIndex_1, kInputIndex_1); 144 return; 145 } 146 } 147 148 // Sanity check 149 bool mkl_shapes_are_same = ((input_shape_0 == input_shape_1) && 150 (tensor_shape0 == tensor_shape1)); 151 if (mkl_shapes_are_same) { 152 CHECK(false) << "MklInputConversionOp: Unexpected: TF shapes are " 153 "different but MKL shapes are same"; 154 } 155 156 // Both have different shapes, so broadcast will be necessary. 157 // Convert to TF and pass both tensors through (we can't do broadcast 158 // with MKL tensors) 159 VLOG(1) << "MklInputConversionOp: Broadcast needed, " 160 << "converted MKL inputs to TF format"; 161 // TODO(intel-tf): Cleanup op_data_type and has_avx512f_ after these two 162 // parameters are removed from ConvertMklToTf 163 MklToTfOp<Device, T>::ConvertMklToTf(this, context, data_format_str, 164 op_data_type, has_avx512f_, 165 kInputIndex_0); 166 MklToTfOp<Device, T>::ConvertMklToTf(this, context, data_format_str, 167 op_data_type, has_avx512f_, 168 kInputIndex_1); 169 SetDummyMklDnnShapeOutput(context, kInputIndex_0); 170 SetDummyMklDnnShapeOutput(context, kInputIndex_1); 171 return; 172 } 173 174 // - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - 175 // One input is MKL and one is TF. If no broadcast is needed, convert 176 // the TF tensor to MKL, otherwise convert the MKL tensor to TF format 177 VLOG(1) << "MklInputConversionOp: Inputs in different formats (MKL/TF)"; 178 179 const Tensor* mkl_tensor; 180 const MklDnnShape* mkl_shape; 181 const Tensor* tf_tensor; 182 uint mkl_tensor_index; 183 uint tf_tensor_index; 184 if (input_shape_0.IsMklTensor() && !input_shape_1.IsMklTensor()) { 185 mkl_tensor = &input_tensor_0; 186 mkl_shape = &input_shape_0; 187 mkl_tensor_index = 0; 188 tf_tensor = &input_tensor_1; 189 tf_tensor_index = 1; 190 } else if (!input_shape_0.IsMklTensor() && input_shape_1.IsMklTensor()) { 191 mkl_tensor = &input_tensor_1; 192 mkl_shape = &input_shape_1; 193 mkl_tensor_index = 1; 194 tf_tensor = &input_tensor_0; 195 tf_tensor_index = 0; 196 } else { 197 CHECK(false) << "MklInputConversionOp: Unexpected combination of input " 198 "shapes for MKL " 199 << "element-wise op"; 200 } 201 202 // Broadcast is needed if the shapes are not the same 203 if (mkl_shape->GetTfShape().num_elements() == 204 tf_tensor->shape().num_elements()) { 205 // Both shapes are same, convert the TF input to MKL 206 VLOG(1) << "MklInputConversionOp: No broadcast needed."; 207 VLOG(1) << "MklInputConversionOp: Converting input " << tf_tensor_index 208 << " to MKL format"; 209 210 // Create MklDnnShape for output Mkl tensor. 211 Tensor* tensor_out; 212 MklDnnShape mkl_output_mkl_shape; 213 mkl_output_mkl_shape.SetMklTensor(true); 214 mkl_output_mkl_shape.SetElemType(MklDnnType<T>()); 215 mkl_output_mkl_shape.SetTfLayout(mkl_shape->GetDimension(), 216 mkl_shape->GetSizesAsMklDnnDims(), 217 mkl_shape->GetTfDataFormat()); 218 // ** Temporarily borrow the layout from the MKL input ** 219 auto output_mkl_md = mkl_shape->GetMklLayout(); 220 mkl_output_mkl_shape.SetMklLayout(&output_mkl_md); 221 222 // Create output Mkl tensor 223 AllocateOutputSetMklShape(context, tf_tensor_index, &tensor_out, 224 mkl_tensor->shape(), mkl_output_mkl_shape); 225 226 // Create MklDnnData object for input tensor. Input tensor is in 227 // Tensorflow layout. 228 auto cpu_engine = engine(engine::kind::cpu, 0); 229 MklDnnData<T> tf_input(&cpu_engine); 230 auto input_tf_md = mkl_output_mkl_shape.GetTfLayout(); 231 tf_input.SetUsrMem(input_tf_md, tf_tensor); 232 // Create reorder between TF layout and MKL layout if necessary 233 std::vector<primitive> net; 234 std::vector<MemoryArgsMap> net_args; 235 bool reordered = tf_input.CheckReorderToOpMem(output_mkl_md, tensor_out, 236 net, net_args, cpu_engine); 237 if (!reordered) { 238 // This is the case that the TF tensor has the same shape and format of 239 // mkl tensor. However, tf_tensor can not be simply forwarded to the 240 // output tensor since mkl data tensor is always one dimensional tensor. 241 // Tensor::CopyFrom shares the buffer of the other tensor while set its 242 // shape to the other tensor. 243 OP_REQUIRES(context, 244 tensor_out->CopyFrom(*tf_tensor, tensor_out->shape()), 245 errors::Internal("MklInputConversionOp: Failed to forward " 246 "input tensor to output")); 247 } else { 248 ExecutePrimitive(net, &net_args, cpu_engine, context); 249 } 250 251 // -- The tensor in MKL format passes through -- 252 ForwardMklTensorInToOut(context, mkl_tensor_index, mkl_tensor_index); 253 } else { 254 // Broadcast is needed, so convert the MKL input to TF 255 VLOG(1) << "MklInputConversionOp: Broadcast needed."; 256 VLOG(1) << "MklInputConversionOp: Converting input " << mkl_tensor_index 257 << " to TF format"; 258 MklToTfOp<Device, T>::ConvertMklToTf(this, context, data_format_str, 259 op_data_type, has_avx512f_, 260 mkl_tensor_index); 261 SetDummyMklDnnShapeOutput(context, mkl_tensor_index); 262 263 // The tensor in TF format passes through 264 ForwardTfTensorInToOut(context, tf_tensor_index, tf_tensor_index); 265 } 266 267 VLOG(1) << "MklInputConversionOp: Shapes (output): " 268 << context->mutable_output(kInputIndex_0)->shape().DebugString() 269 << " and " 270 << context->mutable_output(kInputIndex_1)->shape().DebugString(); 271 272 VLOG(1) << "MklInputConversion completed successfully."; 273 } 274 275 private: 276 /// Data format of the operation 277 string data_format_str; 278 279 /// Data type of the operation 280 DataType op_data_type; 281 282 /// CPUIDInfo 283 bool has_avx512f_ = false; 284 }; 285 286 /////////////////////////////////////////////////////////// 287 // Register kernel 288 /////////////////////////////////////////////////////////// 289 290 #define REGISTER_CPU(T) \ 291 REGISTER_KERNEL_BUILDER( \ 292 Name("_MklInputConversion") \ 293 .Device(DEVICE_CPU) \ 294 .TypeConstraint<T>("T") \ 295 .Label(mkl_op_registry::kMklLayoutDependentOpLabel), \ 296 MklInputConversionOp<CPUDevice, T>); 297 298 TF_CALL_float(REGISTER_CPU); 299 TF_CALL_bfloat16(REGISTER_CPU); 300 301 #undef REGISTER_CPU 302 303 } // namespace tensorflow 304 #endif // INTEL_MKL 305