#pragma once

#include <cstdint>
#include <functional>

#include <ATen/core/Dict.h>
#include <torch/csrc/autograd/engine.h>
#include <torch/csrc/distributed/autograd/functions/recvrpc_backward.h>
#include <torch/csrc/distributed/autograd/functions/sendrpc_backward.h>
#include <torch/csrc/distributed/rpc/rpc_agent.h>

namespace torch {
namespace distributed {
namespace autograd {

class RecvRpcBackward;

// DistAutogradContext stores information for a single distributed autograd
// pass on a worker.
class TORCH_API DistAutogradContext {
 public:
  using GradCallback = std::function<bool(torch::Tensor&)>;

  explicit DistAutogradContext(int64_t contextId);

  // Retrieves the autograd context id for this context.
  int64_t contextId() const;

  // Records a 'send' autograd function for this context with the provided
  // message id.
  void addSendFunction(
      const std::shared_ptr<SendRpcBackward>& func,
      int64_t autograd_message_id);

  // Records a 'recv' autograd function for this context with the provided
  // message id.
  void addRecvFunction(
      std::shared_ptr<RecvRpcBackward>& func,
      int64_t autograd_message_id);

  // Given an autograd_message_id, retrieve the appropriate send function.
  std::shared_ptr<SendRpcBackward> retrieveSendFunction(
      int64_t autograd_message_id);

  // Return all send functions for this context.
  std::unordered_map<int64_t, std::shared_ptr<SendRpcBackward>> sendFunctions()
      const;

  // Return all recv functions for this context.
  std::unordered_map<int64_t, std::shared_ptr<RecvRpcBackward>> recvFunctions()
      const;

  // Adds a future that records an outstanding RPC.
  void addOutstandingRpc(const c10::intrusive_ptr<rpc::JitFuture>& jitFuture);

  // Returns all gradients accumulated in this context so far.
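  //
  // Illustrative sketch (not part of this header): iterating over the
  // returned gradients, assuming `ctx` is a valid DistAutogradContext:
  //
  //   const auto grads = ctx.getGradients();
  //   for (const auto& entry : grads) {
  //     const auto& variable = entry.key();
  //     const auto& grad = entry.value();
  //     // consume `grad` for `variable` ...
  //   }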
  const c10::Dict<torch::Tensor, torch::Tensor> getGradients() const;

  // Passes a mutable reference to the gradient for 'variable' to the given
  // callback. If the callback returns true, the gradient stored in this
  // context is updated with the (possibly modified) value.
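  //
  // Illustrative sketch, assuming `ctx` is a valid DistAutogradContext and
  // `var` already has a gradient accumulated in it: scale the stored gradient
  // in place and signal that the context should keep the modified value.
  //
  //   ctx.runGradCallbackForVariable(var, [](torch::Tensor& grad) {
  //     grad.mul_(0.5);
  //     return true;  // tell the context to store the updated grad
  //   });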
  void runGradCallbackForVariable(
      const torch::autograd::Variable& variable,
      GradCallback&& cb);

  DistAutogradContext(const DistAutogradContext&) = delete;
  DistAutogradContext& operator=(const DistAutogradContext&) = delete;
  DistAutogradContext(DistAutogradContext&&) = delete;
  DistAutogradContext& operator=(DistAutogradContext&&) = delete;

  // Records the workerId of a node that we sent an RPC to. WorkerIds are
  // added here when we attach a send function to this autograd context.
  void addKnownWorkerId(const rpc::worker_id_t workerId);

  // Retrieves the set of known workerIds for this context. These are the
  // different workers that this context has sent RPCs to.
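  //
  // Illustrative sketch, assuming `ctx` is a valid DistAutogradContext and
  // `dstId` is the worker we just sent an RPC to:
  //
  //   ctx.addKnownWorkerId(dstId);
  //   for (const auto workerId : ctx.getKnownWorkerIds()) {
  //     // e.g. contact `workerId` when cleaning up this context ...
  //   }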
  std::unordered_set<rpc::worker_id_t> getKnownWorkerIds() const;

 private:
  friend class BackwardPassCleanupGuard;
  friend class DistEngine;
  friend class RecvRpcBackward;
  friend class DistAccumulateGradCaptureHook;

  // Record that we would like to accumulate the provided gradient on the given
  // variable.
  void accumulateGrad(
      const torch::autograd::Variable& variable,
      const torch::Tensor& grad,
      size_t num_expected_refs);

  // Retrieve the GraphTask.
  std::shared_ptr<torch::autograd::GraphTask> retrieveGraphTask();

  // Set the appropriate graph task for the backward pass. Can be called only
  // once.
  void setGraphTask(std::shared_ptr<torch::autograd::GraphTask> graphTask);

  // Resets the graph task to ensure we can run another distributed backward
  // pass for the same autograd context.
  void resetGraphTask();

  // Waits for all outstanding RPCs for this context to finish and clears all
  // outstanding RPCs held in this context. This should be called only once.
  c10::intrusive_ptr<c10::ivalue::Future> clearAndWaitForOutstandingRpcsAsync();

  // Clears the outstanding RPCs recorded in this context without waiting on
  // them.
  void clearOutstandingRpcs();

  // Record an event to mark the completion of gradient computation. These
  // events will later help to properly synchronize gradient consumption in
  // getGradients(). We need these events because backward and optimizer.step
  // are separate RPC calls, and will occur on different CUDA streams. Without
  // synchronization, it is possible that gradients are consumed before they
  // are ready.
  void recordGradEvent(c10::Device device);

  const int64_t contextId_;

  // Set containing known worker IDs, used when cleaning up the autograd
  // context. Whenever a SendRpcBackward function is attached to the autograd
  // graph for this context, the destination worker is added here.
  std::unordered_set<rpc::worker_id_t> knownWorkerIds_;

  // Map from autograd_message_id to the appropriate 'send' autograd function.
  std::unordered_map<int64_t, std::shared_ptr<SendRpcBackward>>
      sendAutogradFunctions_;

  // Map from autograd_message_id to the appropriate 'recv' autograd function.
  std::unordered_map<int64_t, std::shared_ptr<RecvRpcBackward>>
      recvAutogradFunctions_;

  // Gradients accumulated in this context so far. The key is the variable on
  // which the gradient needs to be accumulated and the value is the gradient
  // that needs to be accumulated on that variable.
  c10::Dict<torch::Tensor, torch::Tensor> accumulatedGrads_;

  // See the comment for recordGradEvent(c10::Device device) above.
  std::unordered_map<c10::Device, c10::Event> gradReadyEvents_;
  // Device guard implementation used to record and synchronize the
  // gradient-ready events above.
  const c10::impl::VirtualGuardImpl impl_;

  // The autograd GraphTask for the backward pass on this node for this
  // context.
  std::shared_ptr<torch::autograd::GraphTask> graphTask_;

  // List of futures for RPCs initiated by this node to propagate gradients to
  // other nodes. The distributed autograd engine on this node can return
  // successfully only if all these futures are done and successful.
  std::vector<c10::intrusive_ptr<rpc::JitFuture>> outStandingRpcs_;

  // Lock to protect concurrent modification of the context.
  mutable std::mutex lock_;
};

using ContextPtr = std::shared_ptr<DistAutogradContext>;
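
// Illustrative sketch (not part of this header) of how a context is typically
// exercised; obtaining `ctx`, the send function `sendFn` and the message id
// `msgId` is handled by machinery outside this file:
//
//   ContextPtr ctx = /* context created for the current autograd pass */;
//   ctx->addKnownWorkerId(dstWorkerId);   // remember the RPC destination
//   ctx->addSendFunction(sendFn, msgId);  // record the 'send' for backward
//   // ... the distributed backward pass runs ...
//   const auto grads = ctx->getGradients();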

// This class stores a shared_ptr to a DistAutogradContext instance in a
// thread-local variable. The instance is provided by the call site; this
// class does not track the current context itself and is purely a utility.
class TORCH_API ThreadLocalDistAutogradContext {
 public:
  // Store 'new_context' in the thread-local variable maintained by this class.
  explicit ThreadLocalDistAutogradContext(ContextPtr&& new_context);
  ~ThreadLocalDistAutogradContext();

  // Retrieve the stored DistAutogradContext instance.
  static ContextPtr getContextPtr();

 private:
  ContextPtr prev_context_ptr_;
};
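
// Illustrative sketch: installing a context as the thread-local current
// context for the duration of a scope (e.g. while handling an RPC), assuming
// `ctx` is a valid ContextPtr:
//
//   {
//     ThreadLocalDistAutogradContext guard(std::move(ctx));
//     // Code running here can retrieve the context via
//     // ThreadLocalDistAutogradContext::getContextPtr().
//   }  // previous thread-local context is restored on destruction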

} // namespace autograd
} // namespace distributed
} // namespace torch