xref: /aosp_15_r20/external/tensorflow/tensorflow/core/common_runtime/collective_param_resolver_local.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_COLLECTIVE_PARAM_RESOLVER_LOCAL_H_
16 #define TENSORFLOW_CORE_COMMON_RUNTIME_COLLECTIVE_PARAM_RESOLVER_LOCAL_H_
17 
18 #include <functional>
19 #include <memory>
20 #include <set>
21 #include <string>
22 #include <unordered_map>
23 #include <vector>
24 
25 #include "tensorflow/core/framework/collective.h"
26 #include "tensorflow/core/framework/device_attributes.pb.h"
27 #include "tensorflow/core/lib/gtl/flatmap.h"
28 #include "tensorflow/core/platform/thread_annotations.h"
29 
30 namespace tensorflow {
31 class CompleteGroupRequest;
32 class CompleteGroupResponse;
33 class CompleteInstanceRequest;
34 class CompleteInstanceResponse;
35 class ConfigProto;
36 class DeviceMgr;
37 
38 // Implements ParamResolverInterface for a single-task context.
39 // It also implements the functionality necessary to serve as the
40 // group leader for param resolution in a multi-task context.
41 class CollectiveParamResolverLocal : public ParamResolverInterface {
42  public:
43   CollectiveParamResolverLocal(const ConfigProto& config,
44                                const DeviceMgr* dev_mgr,
45                                DeviceResolverInterface* dev_resolver,
46                                NcclCommunicatorInterface* nccl_communicator,
47                                const string& task_name);
48 
~CollectiveParamResolverLocal()49   ~CollectiveParamResolverLocal() override {}
50 
51   void CompleteParamsAsync(const DeviceAttributes& device, CollectiveParams* cp,
52                            CancellationManager* cancel_mgr,
53                            const StatusCallback& done) override;
54 
55   void CompleteGroupAsync(const DeviceAttributes& device,
56                           CollGroupParams* group_params,
57                           CancellationManager* cancel_mgr,
58                           const StatusCallback& done) override;
59 
60   void CompleteInstanceAsync(const CompleteInstanceRequest* request,
61                              CompleteInstanceResponse* response,
62                              CancellationManager* cancel_mgr,
63                              const StatusCallback& done) override;
64 
65   Status LookupGroup(int32_t group_key, CollGroupParams* group) override;
66 
67   void StartAbort(const Status& s) override;
68 
69  protected:
70   // For access to InstanceRec and CompleteDefaultRanking.
71   friend class CollectiveParamResolverLocalTest;
72 
73   // Used to complete/verify CollGroup.
74   struct GroupRec {
75     mutable mutex mu;
76     CollGroupParams group TF_GUARDED_BY(mu);
77     Status status TF_GUARDED_BY(mu);
78     std::unordered_map<string, int64_t> incarnations_by_device_name
79         TF_GUARDED_BY(mu);
80     std::vector<CollGroupParams*> pending_params TF_GUARDED_BY(mu);
81     std::vector<StatusCallback> pending_done TF_GUARDED_BY(mu);
82   };
83 
84   // Finds the GroupRec that corresponds to group_params->group_key.
85   // Also populates group_params from that group_rec.
86   // Will wait until GroupRec is fully populated or an error arises before
87   // calling done.  Callback GroupRec* arg is only valid if status is ok.
88   // Ownership of GroupRec stays with this object and does not pass to the
89   // callback.
90   void CompleteGroupLocal(const DeviceAttributes& device,
91                           CollGroupParams* group_params,
92                           CancellationManager* cancel_mgr, StatusCallback done)
93       TF_LOCKS_EXCLUDED(group_mu_);
94 
95   // Finishes the group parameters once all members of the group are there.
96   void FinishGroup(GroupRec* gr) TF_EXCLUSIVE_LOCKS_REQUIRED(gr->mu);
97 
98   // Cancels the group if it's still pending.
99   void CancelGroup(int32 group_key) TF_LOCKS_EXCLUDED(group_mu_);
100 
101   // Lookup and populate parameters from an already initialized group.
102   Status LookupAndPopulateGroupParams(CollGroupParams* group_params);
103 
104   // Used to complete/verify CollInstance.
105   struct InstanceRec;
106 
107   typedef std::function<void(InstanceRec*)> IRConsumer;
108   struct InstanceRec {
109     mutex mu;
110     // Values to be shared by all instances, constant after initialization.
111     CollectiveParams* shared;
112     // If an error occurs during initialization this structure stays in the
113     // table with a non-OK status. Purging the table and restarting needs to be
114     // done at a higher level.
115     Status status TF_GUARDED_BY(mu);
116 
117     // These fields are used to count the instances that have called
118     // in and become known while resolving broadcast source identity and
119     // communicator key.
120     int source_rank TF_GUARDED_BY(mu);
121     string communicator_key TF_GUARDED_BY(mu);
122     int known_count TF_GUARDED_BY(mu);
123     std::vector<bool> known TF_GUARDED_BY(mu);
124     std::vector<IRConsumer> known_waiters TF_GUARDED_BY(mu);
125 
InstanceRecInstanceRec126     InstanceRec()
127         : shared(new CollectiveParams()), source_rank(-1), known_count(0) {}
~InstanceRecInstanceRec128     ~InstanceRec() { shared->Unref(); }
129   };
130 
131   // Find the InstanceRec with the same instance_key as cp.  If it doesn't
132   // already exist, create and initialize from gr and cp.
133   // created is set to true if a new IRec is created, false otherwise.
134   //
135   // Precondition: *gr must be a complete GroupRec, i.e. the value set
136   // by CompleteGroupLocal. *cp must be populated with all the fields
137   // required by InitInstanceSharedParams.  Ownership of InstanceRec stays
138   // with this object and does not pass to the callback.
139   InstanceRec* GetOrCreateInstanceRec(CollectiveParams* cp, bool* created)
140       TF_LOCKS_EXCLUDED(instance_mu_, group_mu_);
141 
142   // Populate *ir with device membership from gr, then initialize to be specific
143   // to cp->instance_key, i.e. order the devices and tasks.
144   //
145   // Preconditions:
146   //  cp is populated with all DeviceLocalities
147   void InitInstanceSharedParams(const CollectiveParams* cp, InstanceRec* ir);
148 
149   // Establishes the final order of gp->device_names and gp->task_names by
150   // considering localities of all devices.
151   void CompleteDefaultRanking(CollGroupParams* gp);
152 
153   // Finish populating *cp.
154   // Precondition: *gr has been fully populated by CompleteGroupLocal.
155   void CompleteInstanceLocal(const string& device, CollectiveParams* cp,
156                              const StatusCallback& done)
157       TF_LOCKS_EXCLUDED(instance_mu_, group_mu_);
158 
159   // Finish populating *cp from fully initialized *ir.
160   // Precondition: *gr and *ir are fully populated.
161   void CompleteInstanceFromInitializedIRec(const string& device,
162                                            CollectiveParams* cp,
163                                            InstanceRec* ir,
164                                            const StatusCallback& done)
165       TF_LOCKS_EXCLUDED(ir->mu);
166 
167   // Complete instance params after waiting for group.
168   // Precondition: *cp has complete group data and default_rank.
169   void WaitForGroup(InstanceRec* ir, CollectiveParams* cp, const IRConsumer& f)
170       TF_LOCKS_EXCLUDED(ir->mu);
171 
172   // If cp.device_names contains only devices local to this process
173   // populates *localities, else returns an error.
174   Status GetLocalDeviceLocalities(const CollectiveParams& cp,
175                                   std::vector<DeviceLocality>* localities);
176 
177   // Sets cp->instance_default_rank according to location of device in
178   // current ordering of cp->instance.device_names.
179   void SetDefaultRank(const string& device, CollectiveParams* cp);
180 
181   // Sets cp->instance.type based on collective op type, and attempts to assign
182   // best implementation.
183   void AssignCollectiveType(CollectiveParams* cp);
184 
185   void StartAbortLocal(const Status& s)
186       TF_LOCKS_EXCLUDED(status_mu_, group_mu_, instance_mu_);
187 
188   const bool nccl_;
189   const DeviceMgr* dev_mgr_;
190   DeviceResolverInterface* dev_resolver_;  // Not owned.
191   NcclCommunicatorInterface* nccl_communicator_;  // Not owned.
192   string task_name_;
193   string gpu_ring_order_;
194   mutex group_mu_;
195   gtl::FlatMap<int32, std::unique_ptr<GroupRec>> group_table_
196       TF_GUARDED_BY(group_mu_);
197   mutex instance_mu_;
198   gtl::FlatMap<int32, gtl::FlatMap<int32, std::unique_ptr<InstanceRec>>>
199       instance_table_ TF_GUARDED_BY(instance_mu_);
200   mutex status_mu_;
201   Status status_ TF_GUARDED_BY(status_mu_);
202 };
203 
204 }  // namespace tensorflow
205 
206 #endif  // TENSORFLOW_CORE_COMMON_RUNTIME_COLLECTIVE_PARAM_RESOLVER_LOCAL_H_
207