xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/tests/local_client_test_base.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_TESTS_LOCAL_CLIENT_TEST_BASE_H_
17 #define TENSORFLOW_COMPILER_XLA_TESTS_LOCAL_CLIENT_TEST_BASE_H_
18 
#include <cstdint>
#include <map>
#include <memory>
#include <string>
#include <vector>

#include "absl/base/thread_annotations.h"
#include "absl/strings/string_view.h"
#include "absl/synchronization/mutex.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/service/hlo_module_config.h"
#include "tensorflow/compiler/xla/service/local_service.h"
#include "tensorflow/compiler/xla/service/platform_util.h"
#include "tensorflow/compiler/xla/service/shaped_buffer.h"
#include "tensorflow/compiler/xla/service/transfer_manager.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
#include "tensorflow/compiler/xla/tests/manifest_checking_test.h"
#include "tensorflow/compiler/xla/tests/verified_hlo_module.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
#include "tensorflow/stream_executor/device_memory_allocator.h"
41 
42 namespace xla {
43 
44 class TestAllocator : public se::StreamExecutorMemoryAllocator {
45  public:
TestAllocator(se::Platform * platform)46   explicit TestAllocator(se::Platform* platform)
47       : se::StreamExecutorMemoryAllocator(
48             platform, PlatformUtil::GetStreamExecutors(platform).ValueOrDie()) {
49   }
50 
51   StatusOr<se::OwningDeviceMemory> Allocate(int device_ordinal, uint64_t size,
52                                             bool retry_on_failure,
53                                             int64_t memory_space) override;
54   Status Deallocate(int device_ordinal, se::DeviceMemoryBase mem) override;
55 
56   // Return the number of allocations that have been performed.
57   int64_t allocation_count() const;
58   int64_t allocation_count(int device_ordinal) const;
59 
60   // Return the number of deallocations that have been performed.
61   int64_t deallocation_count() const;
62   int64_t deallocation_count(int device_ordinal) const;
63 
64  private:
65   mutable absl::Mutex count_mutex_;
66 
67   // Global counts of allocations and deallocations.
68   int64_t allocation_count_ ABSL_GUARDED_BY(count_mutex_) = 0;
69   int64_t deallocation_count_ ABSL_GUARDED_BY(count_mutex_) = 0;
70 
71   // Per-device counts of allocations and deallocations.
72   std::map<int, int64_t> device_allocation_count_ ABSL_GUARDED_BY(count_mutex_);
73   std::map<int, int64_t> device_deallocation_count_
74       ABSL_GUARDED_BY(count_mutex_);
75 };
76 
77 // A base class for tests which exercise the LocalClient interface.
78 class LocalClientTestBase : public ManifestCheckingTest {
79  protected:
80   struct EigenThreadPoolWrapper;
81   explicit LocalClientTestBase(se::Platform* platform = nullptr);
82   virtual ~LocalClientTestBase();
83 
84   static TestAllocator* GetOrCreateAllocator(se::Platform* platform);
85 
86   // Copy the given literal onto the default device and return a
87   // ScopedShapedBuffer. Convenience wrapper around
88   // LocalClient::LiteralToShapedBuffer.
89   ScopedShapedBuffer LiteralToShapedBuffer(const Literal& literal);
90 
91   // Construct and return a literal containing the array represented by
92   // shaped_buffer.
93   Literal ShapedBufferToLiteral(const ShapedBuffer& shaped_buffer);
94 
95   // Execute the given computation on the local client. With and without
96   // options.
97   StatusOr<ScopedShapedBuffer> ExecuteLocally(
98       const XlaComputation& computation,
99       absl::Span<const ShapedBuffer* const> arguments);
100   StatusOr<ScopedShapedBuffer> ExecuteLocally(
101       const XlaComputation& computation,
102       absl::Span<const ShapedBuffer* const> arguments,
103       const ExecutableBuildOptions& build_options,
104       const ExecutableRunOptions& run_options);
105 
106   ScopedShapedBuffer ExecuteLocallyOrDie(
107       const XlaComputation& computation,
108       absl::Span<const ShapedBuffer* const> arguments);
109   ScopedShapedBuffer ExecuteLocallyOrDie(
110       const XlaComputation& computation,
111       absl::Span<const ShapedBuffer* const> arguments,
112       const ExecutableBuildOptions& build_options,
113       const ExecutableRunOptions& run_options);
114 
115   // Parses the given string and returns module as a VerifiedHloModule.
116   StatusOr<std::unique_ptr<VerifiedHloModule>> ParseAndReturnVerifiedModule(
117       absl::string_view hlo_text);
118   StatusOr<std::unique_ptr<VerifiedHloModule>> ParseAndReturnVerifiedModule(
119       absl::string_view hlo_text, const HloModuleConfig& config);
120 
121   // Returns a default set of execute options.
122   ExecutableBuildOptions DefaultExecutableBuildOptions() const;
123 
124   // Returns a default set of execute options, configured to use allocator_
125   // as the allocator.
126   ExecutableRunOptions DefaultExecutableRunOptions() const;
127 
TestName()128   std::string TestName() const {
129     return ::testing::UnitTest::GetInstance()->current_test_info()->name();
130   }
131 
132   // The allocator must live as long as the service, which lives until the end
133   // of the process. So make the allocator static.
134   static TestAllocator* allocator_;
135 
136   se::StreamExecutor* stream_executor_;
137   TransferManager* transfer_manager_;
138 
139   LocalClient* local_client_;
140 
141   std::unique_ptr<EigenThreadPoolWrapper> thread_pool_wrapper_;
142 };
143 
144 }  // namespace xla
145 
146 #endif  // TENSORFLOW_COMPILER_XLA_TESTS_LOCAL_CLIENT_TEST_BASE_H_
147