/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/array_ops.cc.

// clang-format off
#include "tensorflow/core/platform/bfloat16.h"

#include <math.h>     // NOLINT
#include <algorithm>  // NOLINT
#include <numeric>    // NOLINT
// clang-format on

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_reference.h"
#include "tensorflow/core/framework/types.h"

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h"
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#if GOOGLE_CUDA
#include "tensorflow/stream_executor/cuda/cuda_activation.h"
#elif TENSORFLOW_USE_ROCM
#include "tensorflow/core/platform/rocm.h"
#endif

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
template <typename T>
struct CheckNumericsLaunch {
  void Run(const GPUDevice& d, const T* data, int size,
           int abnormal_detected[2]);
};

extern template struct CheckNumericsLaunch<Eigen::half>;
extern template struct CheckNumericsLaunch<float>;
extern template struct CheckNumericsLaunch<double>;

template <typename T>
struct CheckNumericsLaunchV2 {
  void Run(const GPUDevice& d, const T* data, int size,
           int abnormal_detected[3]);
};

extern template struct CheckNumericsLaunchV2<Eigen::half>;
extern template struct CheckNumericsLaunchV2<float>;
extern template struct CheckNumericsLaunchV2<double>;
#endif

namespace {

const int kInfBit = 0x01;
const int kNaNBit = 0x02;
const int kNegativeInfBit = 0x04;
const int kPositiveInfBit = 0x08;

template <typename Device, typename T>
class CheckNumericsOp;

// Partial specialization for CPU
// TODO(jeff,rmlarsen): We should make this variant be an AsyncOpKernel, as
// was done for the GPU case below.
template <typename T>
class CheckNumericsOp<CPUDevice, T> : public OpKernel {
 public:
  explicit CheckNumericsOp(OpKernelConstruction* context) : OpKernel(context) {
    // message_ is used as the prefix for the assertion error message. For
    // instance, this can be the name of the input op that produced the tensor.
    OP_REQUIRES_OK(context, context->GetAttr("message", &message_));
  }

  void Compute(OpKernelContext* context) override {
    // Pass along the input to the output.
    context->set_output(0, context->input(0));

    auto in = context->input(0).flat<T>();
    const T* data = in.data();
    const int64_t size = in.size();
    // Check to see if any element of the tensor is NaN or Inf.
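    // Each element folds its status into a bit mask via checkFloatingElement:
    // kInfBit marks an Inf and kNaNBit marks a NaN, so a nonzero fp_props
    // means at least one non-finite value was found.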
    int fp_props = std::accumulate(
        data, data + size, 0,
        [this](const int x, const T& y) { return checkFloatingElement(x, y); });
    if (fp_props != 0) {
      const string& status = getErrorString(fp_props);
      if (!status.empty()) {
        context->SetStatus(errors::InvalidArgument(message_, " : Tensor had ",
                                                   status, " values"));
      }
    }
  }

 protected:
  virtual int checkFloatingElement(const int x, const T& y) {
    int result = x;
    if (TF_PREDICT_TRUE(Eigen::numext::isfinite(y))) {
      // Do nothing: common case.
    } else {
      if (Eigen::numext::isinf(y)) {
        result |= kInfBit;
      } else if (Eigen::numext::isnan(y)) {
        result |= kNaNBit;
      }
    }
    return result;
  }

  virtual const string getErrorString(const int fp_props) {
    string status;
    if ((fp_props & kInfBit) && (fp_props & kNaNBit)) {
      status = "Inf and NaN";
    } else {
      if (fp_props & kInfBit) {
        status = "Inf";
      }
      if (fp_props & kNaNBit) {
        status = "NaN";
      }
    }
    return status;
  }

 private:
  string message_;
};

template <typename Device, typename T>
class CheckNumericsV2Op;

// Partial specialization for CPU: v2.
// The v2 op differs from the v1 in that it distinguishes -inf and +inf.
template <typename T>
class CheckNumericsV2Op<CPUDevice, T> : public CheckNumericsOp<CPUDevice, T> {
 public:
  explicit CheckNumericsV2Op(OpKernelConstruction* context)
      : CheckNumericsOp<CPUDevice, T>(context) {}

 protected:
  int checkFloatingElement(const int x, const T& y) override {
    int result = x;
    if (TF_PREDICT_TRUE(Eigen::numext::isfinite(y))) {
      // Do nothing: common case.
    } else {
      if (Eigen::numext::isinf(y)) {
        result |= y < static_cast<T>(0.) ? kNegativeInfBit : kPositiveInfBit;
      } else if (Eigen::numext::isnan(y)) {
        result |= kNaNBit;
      }
    }
    return result;
  }

  const string getErrorString(const int fp_props) override {
    std::vector<string> anomalies;
    if (fp_props & kNegativeInfBit) {
      anomalies.push_back("-Inf");
    }
    if (fp_props & kPositiveInfBit) {
      anomalies.push_back("+Inf");
    }
    if (fp_props & kNaNBit) {
      anomalies.push_back("NaN");
    }
    if (anomalies.size() == 3) {
      return strings::StrCat(anomalies[0], ", ", anomalies[1], ", and ",
                             anomalies[2]);
    } else if (anomalies.size() == 2) {
      return strings::StrCat(anomalies[0], " and ", anomalies[1]);
    } else {
      return anomalies[0];
    }
  }
};

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
// Partial specialization for GPU
template <typename T>
class CheckNumericsOp<GPUDevice, T> : public AsyncOpKernel {
 public:
  typedef GPUDevice Device;

  explicit CheckNumericsOp(OpKernelConstruction* context)
      : AsyncOpKernel(context) {
    // message_ is used as the prefix for the assertion error message. For
    // instance, this can be the name of the input op that produced the tensor.
    OP_REQUIRES_OK(context, context->GetAttr("message", &message_));
  }

  void ComputeAsync(OpKernelContext* context, DoneCallback done) override {
    // Pass along the input to the output.
    context->set_output(0, context->input(0));
    if (context->input(0).NumElements() == 0) {
      done();
      return;
    }
    auto input = context->input(0).flat<T>();

    // Allocate and initialize the elements to hold the check results
    Tensor abnormal_detected;
    const int abnormal_detected_size = getAnomalyIndicatorSize();
    OP_REQUIRES_OK(context, context->allocate_temp(
                                DT_INT32, TensorShape({abnormal_detected_size}),
                                &abnormal_detected));

    auto* stream = context->op_device_context()->stream();
    OP_REQUIRES_ASYNC(context, stream != nullptr,
                      errors::Internal("No GPU stream available."), done);

    se::DeviceMemoryBase abnormal_detected_ptr(
        abnormal_detected.flat<int>().data(),
        abnormal_detected.flat<int>().size());
    stream->ThenMemset32(&abnormal_detected_ptr, 0,
                         abnormal_detected.flat<int>().size() * sizeof(int));

    // Call the GPU kernels for the numerical checks
    const Device& d = context->eigen_device<Device>();
    RunKernel(d, input.data(), input.size(),
              abnormal_detected.flat<int>().data());

    // Copy the results from device to host
    AllocatorAttributes attr;
    attr.set_on_host(true);
    attr.set_gpu_compatible(true);
    Tensor abnormal_detected_host;
    OP_REQUIRES_OK_ASYNC(
        context,
        context->allocate_temp(DT_INT32, TensorShape({abnormal_detected_size}),
                               &abnormal_detected_host, attr),
        done);
    OP_REQUIRES_ASYNC(
        context,
        stream
            ->ThenMemcpy(abnormal_detected_host.flat<int>().data(),
                         abnormal_detected_ptr,
                         abnormal_detected_size * sizeof(int))
            .ok(),
        errors::Internal("GPU memcpy from device to host failed"), done);

    // We have observed crashes on some network stacks when not holding
    // this tensor reference.
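    // The reference keeps the abnormal_detected buffer alive until the host
    // callback below has read the copied results and calls Unref().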
    TensorReference abnormal_detected_ref(abnormal_detected);
    auto check_cb = [this, stream, abnormal_detected_ref,
                     abnormal_detected_host, context, done]() {
#if GOOGLE_CUDA
      se::cuda::ScopedActivateExecutorContext scoped_activation{
          stream->parent()};
#elif TENSORFLOW_USE_ROCM
      se::rocm::ScopedActivateExecutorContext scoped_activation{
          stream->parent()};
#endif
      TTypes<const int>::Vec abnormal_detected_host_flat =
          abnormal_detected_host.flat<int>();
      abnormal_detected_ref.Unref();
      checkForAnomalies(context, abnormal_detected_host_flat);
      done();
    };
    context->device()
        ->tensorflow_accelerator_device_info()
        ->event_mgr->ThenExecute(stream, std::move(check_cb));
  }

 protected:
  virtual int getAnomalyIndicatorSize() { return 2; }

  virtual void RunKernel(const GPUDevice& d, const T* data, int size,
                         int* abnormal_detected) {
    CheckNumericsLaunch<T>().Run(d, data, size, abnormal_detected);
  }

  virtual void checkForAnomalies(
      OpKernelContext* context,
      const TTypes<const int>::Vec& abnormality_indicators) {
    const int is_nan = abnormality_indicators(0);
    const int is_inf = abnormality_indicators(1);
    if (is_nan || is_inf) {
      LOG(ERROR) << "abnormal_detected_host @" << abnormality_indicators.data()
                 << " = {" << is_nan << ", " << is_inf << "} " << message_;

      string anomalies;
      if (is_nan && is_inf) {
        anomalies = "Inf and NaN";
      } else if (is_nan) {
        anomalies = "NaN";
      } else if (is_inf) {
        anomalies = "Inf";
      }
      context->SetStatus(errors::InvalidArgument(message_, " : Tensor had ",
                                                 anomalies, " values"));
    }
  }

  string message_;
};

template <typename T>
class CheckNumericsV2Op<GPUDevice, T> : public CheckNumericsOp<GPUDevice, T> {
 public:
  CheckNumericsV2Op(OpKernelConstruction* context)
      : CheckNumericsOp<GPUDevice, T>(context) {}

 protected:
  int getAnomalyIndicatorSize() override { return 3; }

  void RunKernel(const GPUDevice& d, const T* data, int size,
                 int* abnormal_detected) override {
    CheckNumericsLaunchV2<T>().Run(d, data, size, abnormal_detected);
  }

  void checkForAnomalies(
      OpKernelContext* context,
      const TTypes<const int>::Vec& abnormality_indicators) override {
    const int is_nan = abnormality_indicators(0);
    const int is_negative_inf = abnormality_indicators(1);
    const int is_positive_inf = abnormality_indicators(2);
    if (is_negative_inf || is_positive_inf || is_nan) {
      std::vector<string> anomalies;
      if (is_negative_inf) {
        anomalies.push_back("-Inf");
      }
      if (is_positive_inf) {
        anomalies.push_back("+Inf");
      }
      if (is_nan) {
        anomalies.push_back("NaN");
      }
      string all_anomalies;
      if (anomalies.size() == 3) {
        all_anomalies = strings::StrCat(anomalies[0], ", ", anomalies[1],
                                        ", and ", anomalies[2]);
      } else if (anomalies.size() == 2) {
        all_anomalies = strings::StrCat(anomalies[0], " and ", anomalies[1]);
      } else {
        all_anomalies = anomalies[0];
      }
      context->SetStatus(errors::InvalidArgument(
          this->message_, " : Tensor had ", all_anomalies, " values"));
    }
  }

  static constexpr int abnormal_detected_size = 3;
};

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

}  // namespace

#define REGISTER_CPU_KERNEL(T)                                         \
  REGISTER_KERNEL_BUILDER(                                             \
      Name("CheckNumerics").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
      CheckNumericsOp<CPUDevice, T>);
TF_CALL_half(REGISTER_CPU_KERNEL);
TF_CALL_bfloat16(REGISTER_CPU_KERNEL);
TF_CALL_float(REGISTER_CPU_KERNEL);
TF_CALL_double(REGISTER_CPU_KERNEL);

#define REGISTER_V2_CPU_KERNEL(T)                                        \
  REGISTER_KERNEL_BUILDER(                                               \
      Name("CheckNumericsV2").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
      CheckNumericsV2Op<CPUDevice, T>);
TF_CALL_half(REGISTER_V2_CPU_KERNEL);
TF_CALL_bfloat16(REGISTER_V2_CPU_KERNEL);
TF_CALL_float(REGISTER_V2_CPU_KERNEL);
TF_CALL_double(REGISTER_V2_CPU_KERNEL);

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
REGISTER_KERNEL_BUILDER(
    Name("CheckNumerics").Device(DEVICE_GPU).TypeConstraint<Eigen::half>("T"),
    CheckNumericsOp<GPUDevice, Eigen::half>);
REGISTER_KERNEL_BUILDER(
    Name("CheckNumerics").Device(DEVICE_GPU).TypeConstraint<float>("T"),
    CheckNumericsOp<GPUDevice, float>);
REGISTER_KERNEL_BUILDER(
    Name("CheckNumerics").Device(DEVICE_GPU).TypeConstraint<double>("T"),
    CheckNumericsOp<GPUDevice, double>);

REGISTER_KERNEL_BUILDER(
    Name("CheckNumericsV2").Device(DEVICE_GPU).TypeConstraint<Eigen::half>("T"),
    CheckNumericsV2Op<GPUDevice, Eigen::half>);
REGISTER_KERNEL_BUILDER(
    Name("CheckNumericsV2").Device(DEVICE_GPU).TypeConstraint<float>("T"),
    CheckNumericsV2Op<GPUDevice, float>);
REGISTER_KERNEL_BUILDER(
    Name("CheckNumericsV2").Device(DEVICE_GPU).TypeConstraint<double>("T"),
    CheckNumericsV2Op<GPUDevice, double>);
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

}  // namespace tensorflow