#pragma once #include #include #include #include #include #include #include namespace torch::cuda { using tensor_list2d = std::vector>; TORCH_CUDA_CU_API std::vector& broadcast_out( const at::Tensor& tensor, std::vector& out_tensors); TORCH_CUDA_CU_API std::vector broadcast( const at::Tensor& tensor, at::IntArrayRef devices); TORCH_CUDA_CU_API tensor_list2d broadcast_coalesced( at::TensorList tensors, at::IntArrayRef devices, size_t buffer_size); TORCH_CUDA_CU_API std::vector& scatter_out( const at::Tensor& tensor, std::vector& out_tensors, int64_t dim = 0, const std::optional>>& streams = std::nullopt); TORCH_CUDA_CU_API std::vector scatter( const at::Tensor& tensor, at::IntArrayRef devices, const std::optional>& chunk_sizes = std::nullopt, int64_t dim = 0, const std::optional>>& streams = std::nullopt); TORCH_CUDA_CU_API at::Tensor& gather_out( at::TensorList tensors, at::Tensor& out_tensor, int64_t dim); TORCH_CUDA_CU_API at::Tensor gather( at::TensorList tensors, int64_t dim, std::optional destination_index); } // namespace torch::cuda