#include <ATen/autocast_mode.h>

#include <mutex>

#include <ATen/CachedTensorUtils.h>
#include <c10/util/flat_hash_map.h>

namespace at::autocast {

bool is_autocast_enabled(at::DeviceType device_type) {
  at::DispatchKey dispatch_key =
      get_autocast_dispatch_key_from_device_type(device_type);
  return !c10::impl::tls_is_dispatch_key_excluded(dispatch_key);
}

void set_autocast_enabled(at::DeviceType device_type, bool enabled) {
  at::DispatchKey dispatch_key =
      get_autocast_dispatch_key_from_device_type(device_type);
  c10::impl::tls_set_dispatch_key_excluded(dispatch_key, !enabled);
}

namespace {

// Imitate Apex and cache some of the casts to streamline parameter reuse.
// Our heuristic is to cache lower_precision_fp casts of fp32 model weights
// (see cached_cast below).
//
// After discussion with @ezyang, the cache uses the following structure:
// The key is the fp32 source tensor's TensorImpl*, a proxy for a Tensor uuid
// that's unchanged across shallow copies.
// The value is a tuple with a weakref to the source tensor's TensorImpl as the
// first element and the casted tensor as the second element.
//
// The weakref keeps the source's TensorImpl from being deleted. We need to keep
// it alive because we're using the source TensorImpl* as the key. If it were
// deleted, another random Tensor could be allocated whose TensorImpl* happened
// to have the same value. That TensorImpl* would then mistakenly hit in the
// cache: a rare, intermittent, unpredictable bug.
//
// I'm not using the weak_intrusive_ptr as the key because it's more difficult
// to compare directly against incoming TensorImpl*s.
using weakref_type = c10::weak_intrusive_ptr<TensorImpl, UndefinedTensorImpl>;
using val_type = std::tuple<weakref_type, Tensor>;

static ska::flat_hash_map<TensorImpl*, val_type>& get_cached_casts() {
  static ska::flat_hash_map<TensorImpl*, val_type> cached_casts;
  return cached_casts;
}
std::mutex cached_casts_mutex;

// nesting tracks the nesting depth of the Python-side context manager.
// When the autocast context manager exits to a nesting level that's outside
// any instance of autocast (which should occur at the end of each forward pass)
// it calls clear_cache() to ensure cached Tensors don't leak outside the
// autocasting region.
thread_local int nesting = 0;

// The order of this array MUST exactly match the definition order of DeviceType
// in c10/core/DeviceType.h.
static_assert(
    at::COMPILE_TIME_MAX_DEVICE_TYPES == 21,
    "The definition of the default autocast data type per device backend doesn't match with the definition of the device type.");
thread_local std::array<at::ScalarType, at::COMPILE_TIME_MAX_DEVICE_TYPES>
    autocast_dtype = {
        at::kBFloat16, // CPU
        at::kHalf, // CUDA
        at::ScalarType::Undefined, // Reserved for explicit MKLDNN
        at::ScalarType::Undefined, // OpenGL
        at::ScalarType::Undefined, // OpenCL
        at::ScalarType::Undefined, // IDEEP
        at::kHalf, // AMD HIP
        at::ScalarType::Undefined, // FPGA
        at::ScalarType::Undefined, // ONNX Runtime / Microsoft
        at::kBFloat16, // XLA / TPU
        at::ScalarType::Undefined, // Vulkan
        at::ScalarType::Undefined, // Metal
        at::kHalf, // XPU
        at::kHalf, // MPS
        at::ScalarType::Undefined, // Meta (tensors with no data)
        at::kBFloat16, // HPU / HABANA
        at::ScalarType::Undefined, // SX-Aurora / NEC
        at::ScalarType::Undefined, // Lazy Tensors
        at::kHalf, // Graphcore IPU
        at::ScalarType::Undefined, // Meta training and inference devices
        at::kHalf, // PrivateUse1 device
};
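// Illustrative sketch (not used by this file): the thread-local defaults above
// are read and written through get_autocast_dtype()/set_autocast_dtype(),
// defined below, which simply index this array by DeviceType. A hypothetical
// caller overriding the CPU default for one region could do:
//
//   at::ScalarType prev = at::autocast::get_autocast_dtype(at::kCPU);
//   at::autocast::set_autocast_dtype(at::kCPU, at::kHalf);
//   // ... autocast-enabled region now uses fp16 as the CPU autocast dtype ...
//   at::autocast::set_autocast_dtype(at::kCPU, prev); // restore the default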
// Should we enable the cache inside autocast.
thread_local bool cache_enabled = true;

} // anonymous namespace

void clear_cache() {
  const std::lock_guard<std::mutex> lock(cached_casts_mutex);
  get_cached_casts().clear();
}

int increment_nesting() {
  return ++nesting;
}

int decrement_nesting() {
  return --nesting;
}

at::ScalarType get_autocast_dtype(at::DeviceType device_type) {
  return autocast_dtype[static_cast<int>(device_type)];
}

void set_autocast_dtype(at::DeviceType device_type, at::ScalarType dtype) {
  autocast_dtype[static_cast<int>(device_type)] = dtype;
}

bool is_autocast_cache_enabled() {
  return cache_enabled;
}

void set_autocast_cache_enabled(bool enabled) {
  cache_enabled = enabled;
}

// Overload to catch Tensor args
// TODO (possible optimization):
// Move cast_cache to an inline function in a header with cached_casts declared as
// extern thread_local in the header.
Tensor cached_cast(at::ScalarType to_type, const Tensor& arg, DeviceType device_type) {
  if (is_eligible(arg, device_type) && (arg.scalar_type() != to_type)) {
    // Heuristic: Do what Apex does, and cache lower_precision_fp casts of fp32
    // model weights (leaves).
    // See cached_casts declaration above for detailed strategy.
    bool can_try_cache =
        (to_type == get_lower_precision_fp_from_device_type(device_type) &&
         arg.scalar_type() == at::kFloat && arg.requires_grad() &&
         arg.is_leaf() && !arg.is_view() && cache_enabled &&
         !at::caching::is_cached_tensor(arg));
    if (can_try_cache) {
      const std::lock_guard<std::mutex> lock(cached_casts_mutex);
      auto it = get_cached_casts().find(arg.unsafeGetTensorImpl());
      if (it != get_cached_casts().end()) {
        return std::get<1>(it->second);
      } else {
        auto casted_arg = arg.to(to_type);
        get_cached_casts().emplace(
            arg.unsafeGetTensorImpl(),
            val_type{weakref_type(arg.getIntrusivePtr()), casted_arg});
        return casted_arg;
      }
    } else {
      return arg.to(to_type);
    }
  } else {
    return arg;
  }
}

/*******************************
Banned functions
*******************************/

static Tensor binary_cross_entropy_banned(const Tensor &, const Tensor &, const std::optional<Tensor>&, int64_t) {
  AT_ERROR("torch.nn.functional.binary_cross_entropy and torch.nn.BCELoss are unsafe to autocast.\n"
           "Many models use a sigmoid layer right before the binary cross entropy layer.\n"
           "In this case, combine the two layers using torch.nn.functional.binary_cross_entropy_with_logits\n"
           "or torch.nn.BCEWithLogitsLoss. binary_cross_entropy_with_logits and BCEWithLogits are\n"
           "safe to autocast.");
}

namespace {
/*****************************************
Explicit registration for out-of-place ops
*****************************************/

TORCH_LIBRARY_IMPL(_, Autocast, m) {
  m.fallback(torch::CppFunction::makeFallthrough());
}

TORCH_LIBRARY_IMPL(aten, Autocast, m) {
  // lower_precision_fp
#define _KERNEL_CUDA_LOW_PRECISION_FP(...) \
  KERNEL_CUDA(__VA_ARGS__, lower_precision_fp)

  AT_FORALL_LOWER_PRECISION_FP(_KERNEL_CUDA_LOW_PRECISION_FP)
  KERNEL_CUDA(cudnn_convolution, lower_precision_fp)
  KERNEL_CUDA(cudnn_convolution_transpose, lower_precision_fp)

  // fp32
#define _KERNEL_CUDA_FP32(...) KERNEL_CUDA(__VA_ARGS__, fp32)

  AT_FORALL_FP32(_KERNEL_CUDA_FP32)

  // fp32_set_opt_dtype
#define _KERNEL_CUDA_FP32_SET_OPT_DTYPE(...) \
  KERNEL_CUDA(__VA_ARGS__, fp32_set_opt_dtype)

  AT_FORALL_FP32_SET_OPT_DTYPE(_KERNEL_CUDA_FP32_SET_OPT_DTYPE)
  // commenting these out because they accept an explicit (not-optional) dtype, and we shouldn't try to flip that even
  // when autocasting.
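  // (For reference: fp32_set_opt_dtype ops run in fp32 and only fill in an
  // optional dtype argument with fp32 when the caller left it unset; a
  // required dtype, as in the norm overloads below, is a user decision that
  // autocast should not override, so these entries stay disabled.)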
  // KERNEL_CUDA(norm, ScalarOpt_dtype, fp32_set_opt_dtype)
  // KERNEL_CUDA(norm, ScalarOpt_dim_dtype, fp32_set_opt_dtype)
  // KERNEL_CUDA(norm, names_ScalarOpt_dim_dtype, fp32_set_opt_dtype)

  // fp32_append_dtype
  // The fp32_append_dtype wrapper overrides implicit promotion behavior.
  // norm does not implicitly promote, but be aware when adding new ops to this policy.
  AT_FORALL_DIFFERENT_REDISPATCH_SIGNATURE(
      KERNEL_DIFFERENT_REDISPATCH_SIGNATURE_CUDA)

  // promote
#define _KERNEL_CUDA_PROMOTE(...) KERNEL_CUDA(__VA_ARGS__, promote)

  AT_FORALL_PROMOTE(_KERNEL_CUDA_PROMOTE)

  m.impl(TORCH_SELECTIVE_NAME("aten::binary_cross_entropy"),
         TORCH_FN((&at::autocast::binary_cross_entropy_banned)));
}
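// Note: unlike the CUDA and XPU blocks, which expand the shared AT_FORALL_*
// op lists through KERNEL_CUDA/KERNEL_XPU, the MPS and CPU blocks below list
// their ops explicitly. The two-argument KERNEL_MPS form registers an op's
// base overload; KERNEL_MPS2 (and the three-argument KERNEL_CPU form) also
// name a specific overload, e.g. KERNEL_MPS2(_convolution, deprecated, ...)
// targets aten::_convolution.deprecated.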
TORCH_LIBRARY_IMPL(_, AutocastMPS, m) {
  m.fallback(torch::CppFunction::makeFallthrough());
}

TORCH_LIBRARY_IMPL(aten, AutocastMPS, m) {
  // lower_precision_fp
  KERNEL_MPS2(_convolution, deprecated, lower_precision_fp)
  KERNEL_MPS(_convolution, lower_precision_fp)
  KERNEL_MPS(conv1d, lower_precision_fp)
  KERNEL_MPS(conv2d, lower_precision_fp)
  KERNEL_MPS(conv_tbc, lower_precision_fp)
  KERNEL_MPS(conv_transpose1d, lower_precision_fp)
  KERNEL_MPS2(conv_transpose2d, input, lower_precision_fp)
  KERNEL_MPS(convolution, lower_precision_fp)
  KERNEL_MPS(_mps_convolution, lower_precision_fp)
  KERNEL_MPS(prelu, lower_precision_fp)
  KERNEL_MPS(addmm, lower_precision_fp)
  KERNEL_MPS(addmv, lower_precision_fp)
  KERNEL_MPS(addr, lower_precision_fp)
  KERNEL_MPS(matmul, lower_precision_fp)
  KERNEL_MPS(einsum, lower_precision_fp)
  KERNEL_MPS(mm, lower_precision_fp)
  KERNEL_MPS(mv, lower_precision_fp)
  KERNEL_MPS(linear, lower_precision_fp)
  KERNEL_MPS(addbmm, lower_precision_fp)
  KERNEL_MPS(baddbmm, lower_precision_fp)
  KERNEL_MPS(bmm, lower_precision_fp)
  KERNEL_MPS(chain_matmul, lower_precision_fp)
  KERNEL_MPS(linalg_multi_dot, lower_precision_fp)
  KERNEL_MPS(lstm_cell, lower_precision_fp)

  // fp32
  KERNEL_MPS(acos, fp32)
  KERNEL_MPS(asin, fp32)
  KERNEL_MPS(cosh, fp32)
  KERNEL_MPS(erfinv, fp32)
  KERNEL_MPS(exp, fp32)
  KERNEL_MPS(expm1, fp32)
  KERNEL_MPS(log, fp32)
  KERNEL_MPS(log10, fp32)
  KERNEL_MPS(log2, fp32)
  KERNEL_MPS(log1p, fp32)
  KERNEL_MPS(reciprocal, fp32)
  KERNEL_MPS(rsqrt, fp32)
  KERNEL_MPS(sinh, fp32)
  KERNEL_MPS(tan, fp32)
  KERNEL_MPS2(pow, Tensor_Scalar, fp32)
  KERNEL_MPS2(pow, Tensor_Tensor, fp32)
  KERNEL_MPS2(pow, Scalar, fp32)
  KERNEL_MPS(softplus, fp32)
  KERNEL_MPS(layer_norm, fp32)
  KERNEL_MPS(native_layer_norm, fp32)
  KERNEL_MPS(group_norm, fp32)
  KERNEL_MPS2(frobenius_norm, dim, fp32)
  KERNEL_MPS(nuclear_norm, fp32)
  KERNEL_MPS2(nuclear_norm, dim, fp32)
  KERNEL_MPS(batch_norm, fp32)
  KERNEL_MPS(cosine_similarity, fp32)
  KERNEL_MPS(poisson_nll_loss, fp32)
  KERNEL_MPS(cosine_embedding_loss, fp32)
  KERNEL_MPS(nll_loss, fp32)
  KERNEL_MPS(nll_loss2d, fp32)
  KERNEL_MPS(hinge_embedding_loss, fp32)
  KERNEL_MPS(kl_div, fp32)
  KERNEL_MPS(l1_loss, fp32)
  KERNEL_MPS(smooth_l1_loss, fp32)
  KERNEL_MPS(huber_loss, fp32)
  KERNEL_MPS(mse_loss, fp32)
  KERNEL_MPS(margin_ranking_loss, fp32)
  KERNEL_MPS(multilabel_margin_loss, fp32)
  KERNEL_MPS(soft_margin_loss, fp32)
  KERNEL_MPS(triplet_margin_loss, fp32)
  KERNEL_MPS(multi_margin_loss, fp32)
  KERNEL_MPS(binary_cross_entropy_with_logits, fp32)
  KERNEL_MPS(dist, fp32)
  KERNEL_MPS(pdist, fp32)
  KERNEL_MPS(cdist, fp32)
  KERNEL_MPS(renorm, fp32)
  KERNEL_MPS(logsumexp, fp32)

  // fp32_set_opt_dtype
  KERNEL_MPS(prod, fp32)
  KERNEL_MPS2(prod, dim_int, fp32)
  KERNEL_MPS2(prod, dim_Dimname, fp32)
  KERNEL_MPS2(softmax, int, fp32)
  KERNEL_MPS2(softmax, Dimname, fp32)
  KERNEL_MPS2(log_softmax, int, fp32)
  KERNEL_MPS2(log_softmax, Dimname, fp32)
  KERNEL_MPS(cumprod, fp32)
  KERNEL_MPS2(cumprod, dimname, fp32)
  KERNEL_MPS(cumsum, fp32)
  KERNEL_MPS2(cumsum, dimname, fp32)
  KERNEL_MPS(linalg_vector_norm, fp32)
  KERNEL_MPS(linalg_matrix_norm, fp32)
  KERNEL_MPS2(linalg_matrix_norm, str_ord, fp32)
  KERNEL_MPS(sum, fp32)
  KERNEL_MPS2(sum, dim_IntList, fp32)
  KERNEL_MPS2(sum, dim_DimnameList, fp32)

  // promote
  KERNEL_MPS(addcdiv, promote)
  KERNEL_MPS(addcmul, promote)
  KERNEL_MPS(atan2, promote)
  KERNEL_MPS(bilinear, promote)
  KERNEL_MPS(cross, promote)
  KERNEL_MPS(dot, promote)
  KERNEL_MPS(grid_sampler, promote)
  KERNEL_MPS(index_put, promote)
  KERNEL_MPS(tensordot, promote)
  KERNEL_MPS(scatter_add, promote)
}

TORCH_LIBRARY_IMPL(_, AutocastCPU, m) {
  m.fallback(torch::CppFunction::makeFallthrough());
}

TORCH_LIBRARY_IMPL(aten, AutocastCPU, m) {
  // lower_precision_fp cast policy
  KERNEL_CPU(conv1d, lower_precision_fp)
  KERNEL_CPU(conv1d, padding, lower_precision_fp)
  KERNEL_CPU(conv2d, lower_precision_fp)
  KERNEL_CPU(conv2d, padding, lower_precision_fp)
  KERNEL_CPU(conv3d, lower_precision_fp)
  KERNEL_CPU(conv3d, padding, lower_precision_fp)
  KERNEL_CPU(bmm, lower_precision_fp)
  KERNEL_CPU(mm, lower_precision_fp)
  KERNEL_CPU(linalg_vecdot, lower_precision_fp)
  KERNEL_CPU(baddbmm, lower_precision_fp)
  KERNEL_CPU(addmm, lower_precision_fp)
  KERNEL_CPU(addbmm, lower_precision_fp)
  KERNEL_CPU(linear, lower_precision_fp)
  KERNEL_CPU(_convolution, deprecated, lower_precision_fp)
  KERNEL_CPU(matmul, lower_precision_fp)
  KERNEL_CPU(conv_tbc, lower_precision_fp)
  KERNEL_CPU(mkldnn_rnn_layer, lower_precision_fp)
  KERNEL_CPU(conv_transpose1d, lower_precision_fp)
  KERNEL_CPU(conv_transpose2d, input, lower_precision_fp)
  KERNEL_CPU(conv_transpose3d, input, lower_precision_fp)
  KERNEL_CPU(prelu, lower_precision_fp)
  KERNEL_CPU(scaled_dot_product_attention, lower_precision_fp)
  KERNEL_CPU(_native_multi_head_attention, lower_precision_fp)

  // fp32 cast policy
  KERNEL_CPU(avg_pool3d, fp32)
  KERNEL_CPU(binary_cross_entropy, fp32)
  KERNEL_CPU(grid_sampler, fp32)
  KERNEL_CPU(polar, fp32)
  KERNEL_CPU(prod, fp32)
  KERNEL_CPU(prod, dim_int, fp32)
  KERNEL_CPU(prod, dim_Dimname, fp32)
  KERNEL_CPU(quantile, fp32)
  KERNEL_CPU(quantile, scalar, fp32)
  KERNEL_CPU(nanquantile, fp32)
  KERNEL_CPU(nanquantile, scalar, fp32)
  KERNEL_CPU(stft, fp32)
  KERNEL_CPU(stft, center, fp32)
  KERNEL_CPU(cdist, fp32)
  KERNEL_CPU(grid_sampler_2d, fp32)
  KERNEL_CPU(_grid_sampler_2d_cpu_fallback, fp32)
  KERNEL_CPU(grid_sampler_3d, fp32)
  KERNEL_CPU(trace, fp32)
  KERNEL_CPU(view_as_complex, fp32)
  KERNEL_CPU(cholesky, fp32)
  KERNEL_CPU(cholesky_inverse, fp32)
  KERNEL_CPU(cholesky_solve, fp32)
  KERNEL_CPU(inverse, fp32)
  KERNEL_CPU(lu_solve, fp32)
  KERNEL_CPU(orgqr, fp32)
  KERNEL_CPU(ormqr, fp32)
  KERNEL_CPU(pinverse, fp32)
  KERNEL_CPU(max_pool3d, fp32)
  KERNEL_CPU(max_unpool2d, fp32)
  KERNEL_CPU(max_unpool3d, fp32)
  KERNEL_CPU(adaptive_avg_pool3d, fp32)
  KERNEL_CPU(reflection_pad1d, fp32)
  KERNEL_CPU(reflection_pad2d, fp32)
  KERNEL_CPU(replication_pad1d, fp32)
  KERNEL_CPU(replication_pad2d, fp32)
  KERNEL_CPU(replication_pad3d, fp32)
  KERNEL_CPU(mse_loss, fp32)
  KERNEL_CPU(cosine_embedding_loss, fp32)
  KERNEL_CPU(nll_loss, fp32)
  KERNEL_CPU(nll_loss2d, fp32)
  KERNEL_CPU(hinge_embedding_loss, fp32)
  KERNEL_CPU(poisson_nll_loss, fp32)
  KERNEL_CPU(smooth_l1_loss, fp32)
  KERNEL_CPU(cross_entropy_loss, fp32)
  KERNEL_CPU(l1_loss, fp32)
  KERNEL_CPU(huber_loss, fp32)
  KERNEL_CPU(margin_ranking_loss, fp32)
  KERNEL_CPU(soft_margin_loss, fp32)
  KERNEL_CPU(triplet_margin_loss, fp32)
  KERNEL_CPU(multi_margin_loss, fp32)
  KERNEL_CPU(ctc_loss, IntList, fp32)
  KERNEL_CPU(ctc_loss, Tensor, fp32)
  KERNEL_CPU(kl_div, fp32)
  KERNEL_CPU(multilabel_margin_loss, fp32)
  KERNEL_CPU(binary_cross_entropy_with_logits, fp32)
  KERNEL_CPU(fft_fft, fp32)
  KERNEL_CPU(fft_ifft, fp32)
  KERNEL_CPU(fft_fft2, fp32)
  KERNEL_CPU(fft_ifft2, fp32)
  KERNEL_CPU(fft_fftn, fp32)
  KERNEL_CPU(fft_ifftn, fp32)
  KERNEL_CPU(fft_rfft, fp32)
  KERNEL_CPU(fft_irfft, fp32)
  KERNEL_CPU(fft_rfft2, fp32)
  KERNEL_CPU(fft_irfft2, fp32)
  KERNEL_CPU(fft_rfftn, fp32)
  KERNEL_CPU(fft_irfftn, fp32)
  KERNEL_CPU(fft_hfft, fp32)
  KERNEL_CPU(fft_ihfft, fp32)
  KERNEL_CPU(linalg_cond, fp32)
  KERNEL_CPU(linalg_cond, p_str, fp32)
  KERNEL_CPU(linalg_matrix_rank, fp32)
  KERNEL_CPU(linalg_matrix_rank, tol_tensor, fp32)
  KERNEL_CPU(linalg_matrix_rank, atol_rtol_tensor, fp32)
  KERNEL_CPU(linalg_matrix_rank, atol_rtol_float, fp32)
  KERNEL_CPU(linalg_solve, fp32)
  KERNEL_CPU(linalg_cholesky, fp32)
  KERNEL_CPU(linalg_svdvals, fp32)
  KERNEL_CPU(linalg_eigvals, fp32)
  KERNEL_CPU(linalg_eigvalsh, fp32)
  KERNEL_CPU(linalg_inv, fp32)
  KERNEL_CPU(linalg_householder_product, fp32)
  KERNEL_CPU(linalg_tensorinv, fp32)
  KERNEL_CPU(linalg_tensorsolve, fp32)
  KERNEL_CPU(fake_quantize_per_tensor_affine, fp32)
  KERNEL_CPU(geqrf, fp32)
  KERNEL_CPU(_lu_with_info, fp32)
  KERNEL_CPU(qr, fp32)
  KERNEL_CPU(svd, fp32)
  KERNEL_CPU(triangular_solve, fp32)
  KERNEL_CPU(fractional_max_pool2d, fp32)
  KERNEL_CPU(fractional_max_pool3d, fp32)
  KERNEL_CPU(adaptive_max_pool3d, fp32)
  KERNEL_CPU(multilabel_margin_loss_forward, fp32)
  KERNEL_CPU(linalg_qr, fp32)
  KERNEL_CPU(linalg_cholesky_ex, fp32)
  KERNEL_CPU(linalg_svd, fp32)
  KERNEL_CPU(linalg_eig, fp32)
  KERNEL_CPU(linalg_eigh, fp32)
  KERNEL_CPU(linalg_lstsq, fp32)
  KERNEL_CPU(linalg_inv_ex, fp32)

  // promote
  KERNEL_CPU(stack, promote)
  KERNEL_CPU(cat, promote)
  KERNEL_CPU(index_copy, promote)
  KERNEL_CPU(index_copy, dimname, promote)
}
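// The XPU registrations below mirror the CUDA block above: the same shared
// AT_FORALL_* op lists are expanded, wrapped with KERNEL_XPU instead of
// KERNEL_CUDA, and aten::binary_cross_entropy is banned in the same way.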
TORCH_LIBRARY_IMPL(_, AutocastXPU, m) {
  m.fallback(torch::CppFunction::makeFallthrough());
}

TORCH_LIBRARY_IMPL(aten, AutocastXPU, m) {
  // lower_precision_fp
#define _KERNEL_XPU_LOW_PRECISION_FP(...) \
  KERNEL_XPU(__VA_ARGS__, lower_precision_fp)

  AT_FORALL_LOWER_PRECISION_FP(_KERNEL_XPU_LOW_PRECISION_FP)

  // fp32
#define _KERNEL_XPU_FP32(...) KERNEL_XPU(__VA_ARGS__, fp32)

  AT_FORALL_FP32(_KERNEL_XPU_FP32)

  // fp32_set_opt_dtype
#define _KERNEL_XPU_FP32_SET_OPT_DTYPE(...) \
  KERNEL_XPU(__VA_ARGS__, fp32_set_opt_dtype)

  AT_FORALL_FP32_SET_OPT_DTYPE(_KERNEL_XPU_FP32_SET_OPT_DTYPE)

  // fp32_append_dtype
  // The fp32_append_dtype wrapper overrides implicit promotion behavior.
  // norm does not implicitly promote, but be aware when adding new ops to this policy.
  AT_FORALL_DIFFERENT_REDISPATCH_SIGNATURE(
      KERNEL_DIFFERENT_REDISPATCH_SIGNATURE_XPU)

  // promote
#define _KERNEL_XPU_PROMOTE(...) KERNEL_XPU(__VA_ARGS__, promote)

  AT_FORALL_PROMOTE(_KERNEL_XPU_PROMOTE)

  m.impl(TORCH_SELECTIVE_NAME("aten::binary_cross_entropy"),
         TORCH_FN((&at::autocast::binary_cross_entropy_banned)));
}

} // namespace

} // namespace at::autocast