/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/shaped_buffer.h"

#include <memory>
#include <set>
#include <utility>
#include <vector>

#include "tensorflow/compiler/xla/service/platform_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
#include "tensorflow/core/platform/test_benchmark.h"
#include "tensorflow/core/util/ptr_util.h"
#include "tensorflow/stream_executor/device_memory_allocator.h"

namespace xla {
namespace {

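// Regression test for b/71629047: a ScopedShapedBuffer must be movable into a
// std::unique_ptr<ShapedBuffer> (its base class) and safely destroyed through
// that base-class pointer.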
TEST(ShapedBufferTest, ScopedShapeBufferAsShapedBufferB71629047) {
  TF_ASSERT_OK_AND_ASSIGN(auto* platform,
                          xla::PlatformUtil::GetDefaultPlatform());
  TF_ASSERT_OK_AND_ASSIGN(auto executors,
                          xla::PlatformUtil::GetStreamExecutors(platform));
  xla::se::StreamExecutorMemoryAllocator allocator(platform, executors);
  const xla::Shape shape = xla::ShapeUtil::MakeShape(xla::F32, {});
  const int kDeviceOrdinal = 0;
  auto scoped_buffer = std::make_unique<xla::ScopedShapedBuffer>(
      shape, shape, &allocator, kDeviceOrdinal);
  std::unique_ptr<xla::ShapedBuffer> buffer = std::move(scoped_buffer);
  buffer = nullptr;
}

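// A malloc/free-backed DeviceMemoryAllocator that records every outstanding
// allocation, so the tests below can detect leaks and double frees.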
class TestAllocator : public se::DeviceMemoryAllocator {
 public:
  TestAllocator()
      : se::DeviceMemoryAllocator(
            PlatformUtil::GetDefaultPlatform().ValueOrDie()) {}

  ~TestAllocator() override {
    if (!allocations_.empty()) {
      ADD_FAILURE() << "Some allocations not freed!";
    }
  }

  // Pull in two-arg overload of Allocate.
  using se::DeviceMemoryAllocator::Allocate;

  StatusOr<se::OwningDeviceMemory> Allocate(int device_ordinal, uint64_t size,
                                            bool /*retry_on_failure*/,
                                            int64_t /*memory_space*/) override {
    // By contract, we must return null if size == 0.
    if (size == 0) {
      return se::OwningDeviceMemory();
    }
    void* buf = malloc(size);
    allocations_.insert({device_ordinal, buf});
    return se::OwningDeviceMemory(se::DeviceMemoryBase(buf, size),
                                  device_ordinal, this);
  }

  Status Deallocate(int device_ordinal, se::DeviceMemoryBase mem) override {
    if (mem.is_null()) {
      return OkStatus();
    }

    auto it = allocations_.find({device_ordinal, mem.opaque()});
    if (it == allocations_.end()) {
      ADD_FAILURE() << "Allocation not found (double free?)";
    } else {
      free(mem.opaque());
      allocations_.erase(it);
    }
    return OkStatus();
  }

  bool AllowsAsynchronousDeallocation() const override { return false; }

  StatusOr<se::Stream*> GetStream(int device_ordinal) override {
    LOG(FATAL) << "Not implemented";
  }

 private:
  std::set<std::pair</*device_ordinal*/ int64_t, void*>> allocations_;
};

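// Move-assigning one ScopedShapedBuffer over another must free the
// destination's old allocation and transfer ownership of the source's buffer;
// TestAllocator's destructor verifies that nothing is leaked.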
TEST(ScopedShapedBufferTest, TestMoveAssignmentOperator) {
  Shape s = ShapeUtil::MakeShape(F32, {1});
  TestAllocator allocator;
  ScopedShapedBuffer sb1(s, &allocator, /*device_ordinal=*/0);
  sb1.set_buffer(
      allocator.Allocate(/*device_ordinal=*/0, /*size=*/42).ValueOrDie(),
      /*index=*/{});

  ScopedShapedBuffer sb2(s, &allocator, /*device_ordinal=*/1);
  sb2.set_buffer(
      allocator.Allocate(/*device_ordinal=*/1, /*size=*/10).ValueOrDie(),
      /*index=*/{});

  sb1 = std::move(sb2);

  // TestAllocator's destructor checks that all memory was freed.
}

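// TakeSubTree should move ownership of the buffers under the given ShapeIndex
// into a new ScopedShapedBuffer, leaving null buffers behind at those
// positions in the source.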
TEST(ScopedShapedBufferTest, TestTakeSubTree) {
  TestAllocator allocator;

  Shape s = ShapeUtil::MakeShape(F32, {1});
  s = xla::ShapeUtil::MakeTupleShape(std::vector<xla::Shape>(2, s));
  s = xla::ShapeUtil::MakeTupleShape(std::vector<xla::Shape>(3, s));

  ScopedShapedBuffer sb(s, &allocator, /*device_ordinal=*/0);
  sb.buffers().ForEachMutableElement(
      [&](const xla::ShapeIndex& index, se::DeviceMemoryBase* buffer) {
        TF_ASSERT_OK_AND_ASSIGN(
            se::OwningDeviceMemory m,
            allocator.Allocate(/*device_ordinal=*/0, /*size=*/77));
        *buffer = m.Release();
      });
  ShapeTree<se::DeviceMemoryBase> buffers = sb.buffers();

  // Take a subtree out of 'sb' and verify that the buffers are as expected.
  xla::ShapeIndex subtree_index = {1};
  ScopedShapedBuffer output = sb.TakeSubTree(subtree_index);

  output.buffers().ForEachElement([&](const xla::ShapeIndex& sub_index,
                                      const se::DeviceMemoryBase& buffer) {
    // Each buffer in the extracted subtree must be the same allocation that
    // lived at subtree_index + sub_index in the original tree.
    xla::ShapeIndex orig_index = subtree_index;
    for (int i : sub_index) {
      orig_index.push_back(i);
    }
    EXPECT_TRUE(buffers.find(orig_index)->second.IsSameAs(buffer));
  });
  sb.buffers().ForEachElement([&](const xla::ShapeIndex& index,
                                  const se::DeviceMemoryBase& buffer) {
    if ((index.size() >= subtree_index.size()) &&
        ShapeIndexView(index).first(subtree_index.size()) == subtree_index) {
      EXPECT_TRUE(buffer.is_null());
    } else {
      EXPECT_TRUE(buffers.find(index)->second.IsSameAs(buffer));
    }
  });
}

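// SubShapedBuffer should return a ShapedBuffer whose on-host and on-device
// shapes are the sub-shape at the requested index; unlike TakeSubTree it
// yields the non-owning base class, so 'sb' keeps ownership of the memory.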
TEST(ScopedShapedBufferTest, TestSubShapeTree) {
  Shape array_shape = ShapeUtil::MakeShape(F32, {1});
  Shape tuple_shape =
      xla::ShapeUtil::MakeTupleShape({array_shape, array_shape});
  TestAllocator allocator;
  ScopedShapedBuffer sb(tuple_shape, &allocator, /*device_ordinal=*/0);
  sb.buffers().ForEachMutableElement(
      [&](const xla::ShapeIndex& index, se::DeviceMemoryBase* buffer) {
        TF_ASSERT_OK_AND_ASSIGN(
            se::OwningDeviceMemory m,
            allocator.Allocate(/*device_ordinal=*/0, /*size=*/32));
        *buffer = m.Release();
      });
  auto ssb_statusor = sb.SubShapedBuffer({1});
  ASSERT_TRUE(ssb_statusor.ok());
  auto ssb = std::move(ssb_statusor).value();
  EXPECT_EQ(ssb.on_host_shape(), array_shape);
  EXPECT_EQ(ssb.on_device_shape(), array_shape);
}

// Benchmark TakeSubTree with different depths (depth of the ShapeTree) and
// fan-outs (cardinality of each non-leaf node's children).
void BM_TakeSubTree(::testing::benchmark::State& state) {
  const int depth = state.range(0);
  const int fan_out = state.range(1);

  TestAllocator allocator;
  xla::Shape shape = xla::ShapeUtil::MakeShape(xla::F32, {32, 64, 128});
  for (int i = 0; i < depth; ++i) {
    std::vector<xla::Shape> shapes(fan_out, shape);
    shape = xla::ShapeUtil::MakeTupleShape(shapes);
  }
  xla::ScopedShapedBuffer shaped_buffer(shape, /*allocator=*/&allocator,
                                        /*device_ordinal=*/0);
  for (auto s : state) {
    // Extract a buffer from approximately the middle of the first level of
    // the tree.
    (void)shaped_buffer.TakeSubTree(/*index=*/{fan_out / 2}).release();
  }
}

BENCHMARK(BM_TakeSubTree)
    ->ArgPair(1, 4)
    ->ArgPair(1, 8)
    ->ArgPair(1, 32)
    ->ArgPair(1, 64)
    ->ArgPair(1, 128)
    ->ArgPair(1, 256)
    ->ArgPair(1, 512)
    ->ArgPair(2, 4)
    ->ArgPair(2, 8)
    ->ArgPair(2, 32)
    ->ArgPair(2, 64)
    ->ArgPair(2, 128);

}  // anonymous namespace
}  // namespace xla