1 // Copyright (c) Meta Platforms, Inc. and its affiliates. 2 // All rights reserved. 3 // 4 // This source code is licensed under the BSD-style license found in the 5 // LICENSE file in the root directory of this source tree. 6 7 #pragma once 8 9 #include <chrono> 10 #include <cstdint> 11 #include <memory> 12 #include <string> 13 14 #include <c10/macros/Macros.h> 15 #include <c10/util/Exception.h> 16 #include <torch/csrc/distributed/c10d/Backoff.hpp> 17 #include <torch/csrc/distributed/c10d/exception.h> 18 19 namespace c10d { 20 namespace detail { 21 22 class SocketOptions { 23 public: prefer_ipv6(bool value)24 SocketOptions& prefer_ipv6(bool value) noexcept { 25 prefer_ipv6_ = value; 26 27 return *this; 28 } 29 prefer_ipv6()30 bool prefer_ipv6() const noexcept { 31 return prefer_ipv6_; 32 } 33 connect_timeout(std::chrono::milliseconds value)34 SocketOptions& connect_timeout(std::chrono::milliseconds value) noexcept { 35 connect_timeout_ = value; 36 37 return *this; 38 } 39 connect_timeout()40 std::chrono::milliseconds connect_timeout() const noexcept { 41 return connect_timeout_; 42 } 43 44 // Sets the backoff policy to use for socket connect ops. connect_backoff(std::shared_ptr<Backoff> value)45 SocketOptions& connect_backoff(std::shared_ptr<Backoff> value) noexcept { 46 connect_backoff_ = std::move(value); 47 48 return *this; 49 } 50 connect_backoff()51 const std::shared_ptr<Backoff>& connect_backoff() const noexcept { 52 return connect_backoff_; 53 } 54 55 private: 56 bool prefer_ipv6_ = true; 57 std::chrono::milliseconds connect_timeout_{std::chrono::seconds{30}}; 58 std::shared_ptr<Backoff> connect_backoff_{ 59 std::make_shared<FixedBackoff>(std::chrono::milliseconds(1000))}; 60 }; 61 62 class SocketImpl; 63 64 class Socket { 65 public: 66 // This function initializes the underlying socket library and must be called 67 // before any other socket function. 68 static void initialize(); 69 70 static Socket listen(std::uint16_t port, const SocketOptions& opts = {}); 71 72 static Socket listenFromFd(int fd, std::uint16_t expected_port); 73 74 static Socket connect( 75 const std::string& host, 76 std::uint16_t port, 77 const SocketOptions& opts = {}); 78 79 Socket() noexcept = default; 80 81 Socket(const Socket& other) = delete; 82 83 Socket& operator=(const Socket& other) = delete; 84 85 Socket(Socket&& other) noexcept; 86 87 Socket& operator=(Socket&& other) noexcept; 88 89 ~Socket(); 90 91 Socket accept() const; 92 93 int handle() const noexcept; 94 95 std::uint16_t port() const; 96 97 bool waitForInput(std::chrono::milliseconds timeout); 98 99 std::string repr() const; 100 101 private: 102 explicit Socket(std::unique_ptr<SocketImpl>&& impl) noexcept; 103 104 std::unique_ptr<SocketImpl> impl_; 105 }; 106 107 } // namespace detail 108 109 } // namespace c10d 110