// xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/service/shaped_buffer.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_SHAPED_BUFFER_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_SHAPED_BUFFER_H_
18 
#include <memory>
#include <ostream>
#include <string>
#include <utility>

#include "absl/types/span.h"
#include "tensorflow/compiler/xla/shape_tree.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
#include "tensorflow/stream_executor/device_memory_allocator.h"
30 
namespace xla {

class ScopedShapedBuffer;
34 
35 // Class which encapsulates a buffer or set of buffers containing data of a
36 // particular XLA shape.
37 class ShapedBuffer {
38  public:
39   // Construct a ShapedBuffer with null DeviceMemoryBases at each index. The
40   // shape of the data on the host and the device may differ because the device
41   // may have a different representation for different data types. Therefore,
42   // both the on-host and on-device shape are required. The on-device shape
43   // determines the number of device allocations (DeviceMemoryBase) held by the
44   // ShapedBuffer.
45   ShapedBuffer(Shape on_device_shape, int device_ordinal);
46 
47   // TODO(b/170310047): remove this overload.
48   ShapedBuffer(Shape on_host_shape, Shape on_device_shape, int device_ordinal);
49 
50   // Movable, but not copyable.
51   ShapedBuffer(ShapedBuffer&& s);
52   ShapedBuffer& operator=(ShapedBuffer&&);
53   ShapedBuffer(const ShapedBuffer&) = delete;
54   ShapedBuffer& operator=(const ShapedBuffer&) = delete;
55 
56   // Prevent (some forms of) accidental object slicing.
57   ShapedBuffer(const ScopedShapedBuffer&) = delete;
58   ShapedBuffer& operator=(const ScopedShapedBuffer&) = delete;
59 
60   virtual ~ShapedBuffer();
61 
62   // Returns the shape of the on-host representation of the data held by this
63   // ShapedBuffer.
on_host_shape()64   const Shape& on_host_shape() const { return on_host_shape_; }
65 
66   // Returns the shape of the on-device representation of the data held by this
67   // ShapedBuffer.
on_device_shape()68   const Shape& on_device_shape() const { return on_device_shape_; }
69 
device_ordinal()70   int device_ordinal() const { return device_ordinal_; }
71 
72   // Return the root buffer of the shape (shape index {}).
root_buffer()73   const se::DeviceMemoryBase& root_buffer() const {
74     return buffer(/*index=*/{});
75   }
76 
77   // Returns the buffer at the given shape index where index is defined as in
78   // ShapeUtil::GetSubshape.
buffer(const ShapeIndex & index)79   const se::DeviceMemoryBase& buffer(const ShapeIndex& index) const {
80     return buffers_.element(index);
81   }
82 
83   // Sets the device memory buffer at the given index.
set_buffer(const se::DeviceMemoryBase & buffer,const ShapeIndex & index)84   void set_buffer(const se::DeviceMemoryBase& buffer, const ShapeIndex& index) {
85     *buffers_.mutable_element(index) = buffer;
86   }
87 
88   // Sets all buffers.
89   //
90   // Precondition: buffers.shape == on_device_shape_
set_buffers(ShapeTree<se::DeviceMemoryBase> buffers)91   void set_buffers(ShapeTree<se::DeviceMemoryBase> buffers) {
92     CHECK(ShapeUtil::Equal(buffers.shape(), on_device_shape_));
93     buffers_ = std::move(buffers);
94     buffers_.replace_shape_ptr(on_device_shape_);
95   }
96 
97   // Reset the shape of this shaped buffer and underlying buffer structure.
98   //
99   // Precondition: EqualStructure(this->on_device_shape_, on_device_shape).
set_shapes(const Shape & on_device_shape)100   void set_shapes(const Shape& on_device_shape) {
101     CHECK(ShapeUtil::EqualStructure(on_device_shape, on_device_shape_))
102         << "Structures are not the same. new: " << on_device_shape
103         << ", old: " << on_device_shape_;
104     on_host_shape_ = ShapeUtil::DeviceShapeToHostShape(on_device_shape);
105     on_device_shape_ = on_device_shape;
106     buffers_.replace_shape_ptr(on_device_shape_);
107   }
108   // TODO(b/170310047): remove this overload.
set_shapes(const Shape & on_host_shape,const Shape & on_device_shape)109   void set_shapes(const Shape& on_host_shape, const Shape& on_device_shape) {
110     set_shapes(on_device_shape);
111   }
112 
113   // Returns the underlying ShapeTree containing all the device addresses in the
114   // ShapedBuffer.
buffers()115   const ShapeTree<se::DeviceMemoryBase>& buffers() const { return buffers_; }
buffers()116   ShapeTree<se::DeviceMemoryBase>& buffers() { return buffers_; }
117 
118   StatusOr<ShapedBuffer> SubShapedBuffer(const ShapeIndex& index) const;
119 
120   // Set all device memory pointers in the object to null.
121   void clear();
122 
123   std::string ToString() const;
124 
125  protected:
126   Shape on_host_shape_;
127 
128   // The shape of the data on the device.
129   Shape on_device_shape_;
130 
131   // The device the memory is allocated on.
132   int device_ordinal_;
133 
134   // The tree of device buffers. Its shape is on_device_shape().
135   ShapeTree<se::DeviceMemoryBase> buffers_;
136 };
137 
138 std::ostream& operator<<(std::ostream& out, const ShapedBuffer& buffer);
139 
140 // ScopedShapedBuffer takes allocated buffers as inputs, and deallocates on
141 // destruction. This class represents an owning wrapper around `ShapedBuffer`.
142 //
143 // TODO(timshen): Remove inheritance between ScopedShapedBuffer and
144 // ShapedBuffer.  There should never be a need to consider a ScopedShapedBuffer
145 // as a ShapedBuffer, because in that case we should just be able to pass around
146 // our ShapeTree<DeviceMemoryBase>.  Inheritance only adds complexity.  See
147 // discussion in cl/192849370.
148 class ScopedShapedBuffer : public ShapedBuffer {
149  public:
150   // Creates a ScopedShapedBuffer with null DeviceMemoryBases at each index.
151   explicit ScopedShapedBuffer(Shape on_device_shape,
152                               se::DeviceMemoryAllocator* allocator,
153                               int device_ordinal);
154   // TODO(b/170310047): remove this overload.
155   explicit ScopedShapedBuffer(Shape on_host_shape, Shape on_device_shape,
156                               se::DeviceMemoryAllocator* allocator,
157                               int device_ordinal);
158 
159   // Create a ScopedShapedBuffer by taking over the memory from the incoming
160   // ShapedBuffer.
161   explicit ScopedShapedBuffer(ShapedBuffer shaped_buffer,
162                               se::DeviceMemoryAllocator* allocator);
163 
164   // Movable, but not copyable.
165   ScopedShapedBuffer(ScopedShapedBuffer&& s);
166   ScopedShapedBuffer& operator=(ScopedShapedBuffer&&);
167   ScopedShapedBuffer(const ScopedShapedBuffer&) = delete;
168   ScopedShapedBuffer& operator=(const ScopedShapedBuffer&) = delete;
169 
170   // All buffers in the shape are deallocated on destruction.
171   ~ScopedShapedBuffer() override;
172 
173   // Return the allocator used to allocate the device memory held in this
174   // ScopedShapedBuffer.
memory_allocator()175   se::DeviceMemoryAllocator* memory_allocator() const { return allocator_; }
176 
177   // Sets the device memory buffer at the given index.
178   //
179   // If the given buffer's device memory is non-null, its device_ordinal and
180   // allocator must match those in `this`.
set_buffer(se::OwningDeviceMemory buffer,const ShapeIndex & index)181   void set_buffer(se::OwningDeviceMemory buffer, const ShapeIndex& index) {
182     if (!buffer.is_null()) {
183       CHECK_EQ(buffer.device_ordinal(), device_ordinal());
184       CHECK_EQ(buffer.allocator(), allocator_);
185       *buffers_.mutable_element(index) = buffer.Release();
186     } else {
187       *buffers_.mutable_element(index) = se::DeviceMemoryBase();
188     }
189   }
190 
191   // Like unique_ptr::release(), creates and returns a regular ShapedBuffer from
192   // this ScopedShapedBuffer, without freeing any of the associated memory.
193   //
194   // It's the caller's job to ensure that the memory contained therein is freed.
195   [[nodiscard]] ShapedBuffer release();
196 
197   // Extracts the sub-tree rooted at 'index' and returns a ScopedShapedBuffer
198   // that holds ownership of the subtree. Sets the buffers corresponding to the
199   // subtree to null in 'this'.
200   ScopedShapedBuffer TakeSubTree(ShapeIndexView index);
201 
202  protected:
203   void Deallocate();
204 
205   se::DeviceMemoryAllocator* allocator_;
206 };
207 
}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_SHAPED_BUFFER_H_