xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/service/gpu/outfeed_manager.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_OUTFEED_MANAGER_H_
17 #define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_OUTFEED_MANAGER_H_
18 
19 #include "tensorflow/compiler/xla/literal.h"
20 #include "tensorflow/compiler/xla/service/gpu/xfeed_queue.h"
21 #include "tensorflow/compiler/xla/shape_tree.h"
22 #include "tensorflow/core/platform/notification.h"
23 #include "tensorflow/core/platform/stream_executor_no_cuda.h"
24 
25 namespace xla {
26 namespace gpu {
27 
28 // TODO(b/30467474) Once GPU outfeed implementation settles, consider
29 // folding back the cpu and gpu outfeed implementations into a generic
30 // one if possible.
31 
32 // Defines a buffer holding the destination for an outfeed in host memory and a
33 // notification when that triggers when the transfer is done.
34 class OutfeedBuffer {
35  public:
OutfeedBuffer(int64_t length)36   explicit OutfeedBuffer(int64_t length) : length_(length) {}
37 
38   // Waits for the device transfer to be finished.
WaitUntilAvailable()39   void WaitUntilAvailable() { done_.WaitForNotification(); }
40 
length()41   int64_t length() const { return length_; }
set_destination(std::unique_ptr<MutableBorrowingLiteral> destination)42   void set_destination(std::unique_ptr<MutableBorrowingLiteral> destination) {
43     destination_ = std::move(destination);
44   }
destination()45   MutableBorrowingLiteral* destination() { return destination_.get(); }
46 
47   // Callback to signal that this buffer is consumed.
Done()48   void Done() { done_.Notify(); }
49 
50  private:
51   std::unique_ptr<MutableBorrowingLiteral> destination_;
52   const int64_t length_;
53   tensorflow::Notification done_;
54 };
55 
56 // Manages a thread-safe queue of buffers. The buffers are supposed to be
57 // produced by the transfer manager and consumed by the device.
58 class OutfeedManager
59     : public XfeedQueue<ShapeTree<std::unique_ptr<OutfeedBuffer>>*> {
60  public:
61   Status TransferLiteralFromOutfeed(se::StreamExecutor* executor,
62                                     MutableBorrowingLiteral literal);
63 };
64 
65 // Returns the GPU outfeed manager for the given stream executor.
66 OutfeedManager* GetOrCreateOutfeedManager(se::StreamExecutor* executor);
67 
68 }  // namespace gpu
69 }  // namespace xla
70 
71 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_OUTFEED_MANAGER_H_
72