xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/service/cpu/xfeed_manager.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
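// This header declares the types the CPU runtime uses to exchange buffers with
// an executing CPU computation: XfeedBuffer (a client-owned buffer),
// XfeedQueueManager (a single infeed or outfeed queue) and XfeedManager (the
// pair of queues). Infeed transfers data into the computation, e.g., to feed
// data into a while loop; outfeed transfers data produced by it back out.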
19 
20 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_XFEED_MANAGER_H_
21 #define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_XFEED_MANAGER_H_
22 
#include <deque>
#include <string>
#include <utility>

#include "absl/synchronization/mutex.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/shape.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
30 
31 namespace xla {
32 namespace cpu {
33 namespace runtime {
34 
35 // Abstract class defining an infeed buffer that is passed to the
36 // runtime by a client. The client manages the storage of the buffer.
37 class XfeedBuffer {
38  public:
39   virtual ~XfeedBuffer() = default;
40 
41   virtual int32_t length() = 0;
42   virtual void* data() = 0;
43 
44   // The 'shape' parameter reflects what shape the embedded program was
45   // expecting / producing with respect to this XfeedBuffer. E.g. this will
46   // contain information about the layout of an outfed buffer.
47   virtual void Done(StatusOr<Shape> shape) = 0;
48 };
49 
50 // Reusable component for managing the infeed and outfeed queue state.
51 class XfeedQueueManager {
52  public:
XfeedQueueManager(std::string queue_name)53   XfeedQueueManager(std::string queue_name) : queue_name_(queue_name) {}
54 
55   // Calls the completion callback for any enqueued buffers that have
56   // not been dequeued by the runtime, and empties the
57   // queue. Reset may not be called while a runtime computation is
58   // processing a dequeued buffer. The only safe way to ensure this
59   // condition is to call Reset when no computation is taking place.
60   void Reset();
61 
62   // Adds a sequence of buffers to the queue atomically. buffer->Done will be
63   // called when the buffer will no longer be accessed by the XfeedManager,
64   // either as a result of a call to Reset or because the runtime has dequeued
65   // and used the buffer.
66   void EnqueueBuffersAtomically(absl::Span<XfeedBuffer* const> buffers);
67 
68   // Blocks until the queue is non-empty, then returns the buffer at the head of
69   // the queue. Sets the current buffer to be the returned buffer. It is an
70   // error to call BlockingDequeueBuffer if there is an unreleased current
71   // buffer, i.e., ReleaseCurrentBuffer must be called between calls to
72   // BlockingDequeueBuffer.
73   XfeedBuffer* BlockingDequeueBuffer();
74 
75   // Releases the current buffer, which is the last buffer returned by
76   // BlockingDequeuBuffer and not yet released. length and data must
77   // match the buffer->length() and buffer->data() for the current
78   // buffer.
79   //
80   // 'shape' communicates the shape of the buffer being released. If the program
81   // passed a value that could not be decoded as a shape, 'shape' will be an
82   // error status. In the case of outfeed, this indicates the layout of the
83   // shape that has been outfed. In the case of infeed, this can be used for
84   // sanity checking purposes.
85   void ReleaseCurrentBuffer(int32_t length, void* data, StatusOr<Shape> shape);
86 
87  private:
88   const std::string queue_name_;
89 
90   absl::Mutex mu_;
91 
92   // Condition variable that is signaled every time a buffer is
93   // enqueued to an empty queue.
94   absl::CondVar cv_;
95 
96   // XfeedBuffer* queue contents are not owned, but buffer->Done must
97   // be called when the buffer is no longer needed by the runtime.
98   std::deque<XfeedBuffer*> enqueued_buffers_;
99 
100   // If non-NULL, the buffer that is currently being processed by the
101   // runtime. Not owned.
102   XfeedBuffer* current_buffer_ = nullptr;
103 };
104 
105 // Client-side class used to enqueue infeed buffers.
106 class XfeedManager {
107  public:
108   XfeedManager() = default;
109 
110   void Reset();
111 
infeed()112   XfeedQueueManager* infeed() { return &infeed_; }
outfeed()113   XfeedQueueManager* outfeed() { return &outfeed_; }
114 
115  private:
116   XfeedQueueManager infeed_ = {"infeed"};
117   XfeedQueueManager outfeed_ = {"outfeed"};
118 };
119 
120 int64_t GetByteSizeRequirement(const Shape& shape, int64_t pointer_size);
121 
122 }  // namespace runtime
123 }  // namespace cpu
124 }  // namespace xla
125 
126 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_XFEED_MANAGER_H_
127