/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/common_runtime/collective_rma_local.h"

#include "tensorflow/core/common_runtime/copy_tensor.h"
#include "tensorflow/core/common_runtime/dma_helper.h"

namespace tensorflow {

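// Passes the abort status on to the local BufRendezvous; pending
// ProvideBuf/ConsumeBuf callbacks complete with this error.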
void CollectiveRemoteAccessLocal::StartAbort(const Status& s) {
  buf_rendezvous_.StartAbort(s);
}

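// Consumer side of a local exchange: registers a callback with the
// BufRendezvous under `key`.  When the matching PostToPeer provides the
// producer's buffer, the callback copies it into `to_tensor` and then
// invokes `done`.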
void CollectiveRemoteAccessLocal::RecvFromPeer(
    const string& peer_device, const string& peer_task, bool peer_is_local,
    const string& key, Device* to_device, DeviceContext* to_device_ctx,
    const AllocatorAttributes& to_alloc_attr, Tensor* to_tensor,
    const DeviceLocality& client_locality, int dev_to_dev_stream_index,
    CancellationManager* cancellation_manager, const StatusCallback& done) {
  VLOG(1) << "RecvFromPeer " << this << " from " << peer_device << " key "
          << key;
  if (!peer_is_local) {
    done(
        errors::Internal("CollectiveRemoteAccessLocal::RecvFromPeer "
                         "called with peer_is_local=false"));
    return;
  }

  Device* from_device;
  Status status = dev_mgr_->LookupDevice(peer_device, &from_device);
  if (!status.ok()) {
    done(status);
    return;
  }

  auto consumer_callback = [to_tensor, to_device_ctx, to_device, to_alloc_attr,
                            dev_to_dev_stream_index,
                            done](const Status& status,
                                  BufRendezvous::Hook* hook) {
    Status s = status;
    if (s.ok()) {
      if (hook == nullptr) {
        s = errors::Internal("Invalid null hook in ConsumeBuf callback");
      }
    } else {
      if (hook != nullptr) {
        LOG(ERROR) << "Got hook " << hook << " with status " << s
                   << " from ConsumeBuf";
      }
    }

    if (s.ok()) {
      int64_t recv_bytes = to_tensor->TotalBytes();
      CHECK_EQ(recv_bytes, hook->prod_value->TotalBytes());
      MemCpyAsync(hook->prod_ctx,    // src DeviceContext
                  to_device_ctx,     // dst DeviceContext
                  hook->prod_dev,    // src Device
                  to_device,         // dst Device
                  hook->prod_attr,   // src AllocatorAttributes
                  to_alloc_attr,     // dst AllocatorAttributes
                  hook->prod_value,  // src Tensor*
                  to_tensor,         // dst Tensor*
                  dev_to_dev_stream_index,
                  [hook, done](const Status& memcpy_status) {
                    // This callback may be executing in the GPUEventMgr
                    // pool, in which case it must be of very short duration
                    // and non-blocking (except e.g. for queue insertion).
                    // It would be safer, though expensive, to transfer
                    // to another thread here.
                    done(memcpy_status);
                    BufRendezvous::DoneWithHook(hook);
                  });
    } else {
      done(s);
      if (hook != nullptr) {
        BufRendezvous::DoneWithHook(hook);
      }
    }
  };
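  // Register the consumer with the rendezvous.  The callback fires when the
  // matching PostToPeer provides a buffer under `key`, or earlier on
  // cancellation or abort.  Every non-null hook must be released with
  // DoneWithHook exactly once so the producer's callback runs.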
  buf_rendezvous_.ConsumeBuf(key, from_device->name(),
                             from_device->attributes().incarnation(),
                             consumer_callback, cancellation_manager);
}

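// Producer side of a local exchange: makes `from_tensor` available under
// `key` for a matching RecvFromPeer.  `done` is invoked once the consumer
// has finished with the buffer.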
void CollectiveRemoteAccessLocal::PostToPeer(
    const string& peer_device, const string& peer_task, const string& key,
    Device* from_device, DeviceContext* from_device_ctx,
    const AllocatorAttributes& from_alloc_attr, const Tensor* from_tensor,
    const DeviceLocality& client_locality,
    CancellationManager* cancellation_manager, const StatusCallback& done) {
  VLOG(1) << "PostToPeer " << this << " key " << key
          << " step_id_=" << step_id_;
  buf_rendezvous_.ProvideBuf(key, from_device, from_device_ctx, from_tensor,
                             from_alloc_attr, done, cancellation_manager);
}

void CollectiveRemoteAccessLocal::CheckPeerHealth(const string& peer_task,
                                                  int64_t timeout_in_ms,
                                                  const StatusCallback& done) {
  // Local devices are assumed to be always healthy, so health checks should
  // never be issued for local collectives.
  done(errors::Internal(
      "CheckPeerHealth is not supposed to be called for local collectives"));
}

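// Copies `src` on `src_dev` into `dst` on `dst_dev`.  The devices may be the
// same or different; in all cases the destination buffer receives a real
// copy of the bytes rather than an alias of the source.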
/*static*/
void CollectiveRemoteAccessLocal::MemCpyAsync(
    DeviceContext* src_dev_ctx, DeviceContext* dst_dev_ctx, Device* src_dev,
    Device* dst_dev, const AllocatorAttributes& src_attr,
    const AllocatorAttributes& dst_attr, const Tensor* src, Tensor* dst,
    int dev_to_dev_stream_index, const StatusCallback& done) {
  // We want a real copy to happen, i.e. the bytes inside of src should be
  // transferred to the buffer backing dst.  If src and dst are on different
  // devices then CopyTensor::ViaDMA will do just that.  But if they're both
  // the same CPU, then it will actually just reset dst to point to src.
  // Since this routine is used for copying between devices and within a
  // device, we need to detect and bypass the wrong-semantics case.
  const DeviceType src_device_type(
      src_attr.on_host() ? DEVICE_CPU : src_dev->attributes().device_type());
  const DeviceType dst_device_type(
      dst_attr.on_host() ? DEVICE_CPU : dst_dev->attributes().device_type());
  const bool non_cpu_src = src_device_type != DeviceType(DEVICE_CPU);
  const bool non_cpu_dst = dst_device_type != DeviceType(DEVICE_CPU);
  // For GPU devices when only one compute stream is used (the default)
  // the OpKernelContext does not supply a DeviceContext.  It's assumed
  // that all nodes use the default context.
  if (src_dev_ctx == nullptr && src_device_type == DEVICE_GPU) {
    const DeviceBase::AcceleratorDeviceInfo* dev_info =
        src_dev->tensorflow_accelerator_device_info();
    CHECK(dev_info);
    src_dev_ctx = dev_info->default_context;
  }
  if (dst_dev_ctx == nullptr && dst_device_type == DEVICE_GPU) {
    const DeviceBase::AcceleratorDeviceInfo* dev_info =
        dst_dev->tensorflow_accelerator_device_info();
    CHECK(dev_info);
    dst_dev_ctx = dev_info->default_context;
  }
  if (non_cpu_src) CHECK(src_dev_ctx);
  if (non_cpu_dst) CHECK(dst_dev_ctx);
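  // Either endpoint on an accelerator: CopyTensor::ViaDMA performs a real
  // device copy.  Both endpoints in host memory: fall back to a plain memcpy
  // so the bytes are duplicated rather than dst being aliased to src.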
  if (non_cpu_src || non_cpu_dst) {
    CopyTensor::ViaDMA("",  // edge name (non-existent)
                       src_dev_ctx, dst_dev_ctx, src_dev, dst_dev, src_attr,
                       dst_attr, src, dst, dev_to_dev_stream_index, done);
  } else {
    int64_t bytes = src->TotalBytes();
    DCHECK_EQ(dst->TotalBytes(), bytes);
    memcpy(DMAHelper::base(dst), DMAHelper::base(src), bytes);
    done(OkStatus());
  }
}

}  // namespace tensorflow