xref: /aosp_15_r20/external/pytorch/torch/csrc/distributed/c10d/socket.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
// Copyright (c) Meta Platforms, Inc. and its affiliates.
// All rights reserved.
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

7 #pragma once
8 
9 #include <chrono>
10 #include <cstdint>
11 #include <memory>
12 #include <string>
13 
14 #include <c10/macros/Macros.h>
15 #include <c10/util/Exception.h>
16 #include <torch/csrc/distributed/c10d/Backoff.hpp>
17 #include <torch/csrc/distributed/c10d/exception.h>
18 
19 namespace c10d {
20 namespace detail {
21 
22 class SocketOptions {
23  public:
prefer_ipv6(bool value)24   SocketOptions& prefer_ipv6(bool value) noexcept {
25     prefer_ipv6_ = value;
26 
27     return *this;
28   }
29 
prefer_ipv6()30   bool prefer_ipv6() const noexcept {
31     return prefer_ipv6_;
32   }
33 
connect_timeout(std::chrono::milliseconds value)34   SocketOptions& connect_timeout(std::chrono::milliseconds value) noexcept {
35     connect_timeout_ = value;
36 
37     return *this;
38   }
39 
connect_timeout()40   std::chrono::milliseconds connect_timeout() const noexcept {
41     return connect_timeout_;
42   }
43 
44   // Sets the backoff policy to use for socket connect ops.
connect_backoff(std::shared_ptr<Backoff> value)45   SocketOptions& connect_backoff(std::shared_ptr<Backoff> value) noexcept {
46     connect_backoff_ = std::move(value);
47 
48     return *this;
49   }
50 
connect_backoff()51   const std::shared_ptr<Backoff>& connect_backoff() const noexcept {
52     return connect_backoff_;
53   }
54 
55  private:
56   bool prefer_ipv6_ = true;
57   std::chrono::milliseconds connect_timeout_{std::chrono::seconds{30}};
58   std::shared_ptr<Backoff> connect_backoff_{
59       std::make_shared<FixedBackoff>(std::chrono::milliseconds(1000))};
60 };
61 
62 class SocketImpl;
63 
64 class Socket {
65  public:
66   // This function initializes the underlying socket library and must be called
67   // before any other socket function.
68   static void initialize();
69 
70   static Socket listen(std::uint16_t port, const SocketOptions& opts = {});
71 
72   static Socket listenFromFd(int fd, std::uint16_t expected_port);
73 
74   static Socket connect(
75       const std::string& host,
76       std::uint16_t port,
77       const SocketOptions& opts = {});
78 
79   Socket() noexcept = default;
80 
81   Socket(const Socket& other) = delete;
82 
83   Socket& operator=(const Socket& other) = delete;
84 
85   Socket(Socket&& other) noexcept;
86 
87   Socket& operator=(Socket&& other) noexcept;
88 
89   ~Socket();
90 
91   Socket accept() const;
92 
93   int handle() const noexcept;
94 
95   std::uint16_t port() const;
96 
97   bool waitForInput(std::chrono::milliseconds timeout);
98 
99   std::string repr() const;
100 
101  private:
102   explicit Socket(std::unique_ptr<SocketImpl>&& impl) noexcept;
103 
104   std::unique_ptr<SocketImpl> impl_;
105 };
106 
107 } // namespace detail
108 
109 } // namespace c10d
110