1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/c/experimental/stream_executor/stream_executor_test_util.h"
16
17 #include "tensorflow/c/experimental/stream_executor/stream_executor.h"
18
19 namespace stream_executor {
20 namespace test_util {
21
22 /*** Functions for creating SP_StreamExecutor ***/
// Test stub for the device allocation callback. The body is empty: nothing is
// allocated and `mem` is left exactly as the caller passed it.
void Allocate(const SP_Device* const device,
              uint64_t size,
              int64_t memory_space,
              SP_DeviceMemoryBase* const mem) {
  // Intentionally a no-op.
}
// Test stub for the device deallocation callback. Does nothing with `mem`.
void Deallocate(const SP_Device* const device,
                SP_DeviceMemoryBase* const mem) {
  // Intentionally a no-op.
}
// Test stub for the host allocation callback. Never allocates; the result is
// always a null pointer, whatever `size` is.
void* HostMemoryAllocate(const SP_Device* const device, uint64_t size) {
  void* const no_memory = nullptr;
  return no_memory;
}
HostMemoryDeallocate(const SP_Device * const device,void * mem)30 void HostMemoryDeallocate(const SP_Device* const device, void* mem) {}
GetAllocatorStats(const SP_Device * const device,SP_AllocatorStats * const stats)31 TF_Bool GetAllocatorStats(const SP_Device* const device,
32 SP_AllocatorStats* const stats) {
33 return true;
34 }
DeviceMemoryUsage(const SP_Device * const device,int64_t * const free,int64_t * const total)35 TF_Bool DeviceMemoryUsage(const SP_Device* const device, int64_t* const free,
36 int64_t* const total) {
37 return true;
38 }
CreateStream(const SP_Device * const device,SP_Stream * stream,TF_Status * const status)39 void CreateStream(const SP_Device* const device, SP_Stream* stream,
40 TF_Status* const status) {
41 *stream = nullptr;
42 }
DestroyStream(const SP_Device * const device,SP_Stream stream)43 void DestroyStream(const SP_Device* const device, SP_Stream stream) {}
// Test stub for the stream-dependency callback. Records no dependency and
// leaves `status` untouched.
void CreateStreamDependency(const SP_Device* const device,
                            SP_Stream dependent,
                            SP_Stream other,
                            TF_Status* const status) {
  // Intentionally a no-op.
}
// Test stub for the stream-status callback. Leaves `status` untouched.
void GetStreamStatus(const SP_Device* const device,
                     SP_Stream stream,
                     TF_Status* const status) {
  // Intentionally a no-op.
}
CreateEvent(const SP_Device * const device,SP_Event * event,TF_Status * const status)48 void CreateEvent(const SP_Device* const device, SP_Event* event,
49 TF_Status* const status) {
50 *event = nullptr;
51 }
DestroyEvent(const SP_Device * const device,SP_Event event)52 void DestroyEvent(const SP_Device* const device, SP_Event event) {}
GetEventStatus(const SP_Device * const device,SP_Event event)53 SE_EventStatus GetEventStatus(const SP_Device* const device, SP_Event event) {
54 return SE_EVENT_UNKNOWN;
55 }
// Test stub for the record-event callback. Records nothing and leaves
// `status` untouched.
void RecordEvent(const SP_Device* const device,
                 SP_Stream stream,
                 SP_Event event,
                 TF_Status* const status) {
  // Intentionally a no-op.
}
// Test stub for the wait-for-event callback. Returns immediately without
// waiting and leaves `status` untouched.
void WaitForEvent(const SP_Device* const device,
                  SP_Stream stream,
                  SP_Event event,
                  TF_Status* const status) {
  // Intentionally a no-op.
}
CreateTimer(const SP_Device * const device,SP_Timer * timer,TF_Status * const status)60 void CreateTimer(const SP_Device* const device, SP_Timer* timer,
61 TF_Status* const status) {}
DestroyTimer(const SP_Device * const device,SP_Timer timer)62 void DestroyTimer(const SP_Device* const device, SP_Timer timer) {}
// Test stub for the start-timer callback. Starts nothing and leaves `status`
// untouched.
void StartTimer(const SP_Device* const device,
                SP_Stream stream,
                SP_Timer timer,
                TF_Status* const status) {
  // Intentionally a no-op.
}
// Test stub for the stop-timer callback. Stops nothing and leaves `status`
// untouched.
void StopTimer(const SP_Device* const device,
               SP_Stream stream,
               SP_Timer timer,
               TF_Status* const status) {
  // Intentionally a no-op.
}
// Test stub for the asynchronous device-to-host copy callback. Copies no
// bytes; `host_dst` and `status` are left untouched.
void MemcpyDToH(const SP_Device* const device,
                SP_Stream stream,
                void* host_dst,
                const SP_DeviceMemoryBase* const device_src,
                uint64_t size,
                TF_Status* const status) {
  // Intentionally a no-op.
}
// Test stub for the asynchronous host-to-device copy callback. Copies no
// bytes; `device_dst` and `status` are left untouched.
void MemcpyHToD(const SP_Device* const device,
                SP_Stream stream,
                SP_DeviceMemoryBase* const device_dst,
                const void* host_src,
                uint64_t size,
                TF_Status* const status) {
  // Intentionally a no-op.
}
// Test stub for the synchronous device-to-host copy callback. Copies no
// bytes; `host_dst` and `status` are left untouched.
void SyncMemcpyDToH(const SP_Device* const device,
                    void* host_dst,
                    const SP_DeviceMemoryBase* const device_src,
                    uint64_t size,
                    TF_Status* const status) {
  // Intentionally a no-op.
}
// Test stub for the synchronous host-to-device copy callback. Copies no
// bytes; `device_dst` and `status` are left untouched.
void SyncMemcpyHToD(const SP_Device* const device,
                    SP_DeviceMemoryBase* const device_dst,
                    const void* host_src,
                    uint64_t size,
                    TF_Status* const status) {
  // Intentionally a no-op.
}
// Test stub for the block-host-for-event callback. Returns immediately
// without blocking and leaves `status` untouched.
void BlockHostForEvent(const SP_Device* const device,
                       SP_Event event,
                       TF_Status* const status) {
  // Intentionally a no-op.
}
// Test stub for the synchronize-all-activity callback. Returns immediately
// and leaves `status` untouched.
void SynchronizeAllActivity(const SP_Device* const device,
                            TF_Status* const status) {
  // Intentionally a no-op.
}
HostCallback(const SP_Device * const device,SP_Stream stream,SE_StatusCallbackFn const callback_fn,void * const callback_arg)83 TF_Bool HostCallback(const SP_Device* const device, SP_Stream stream,
84 SE_StatusCallbackFn const callback_fn,
85 void* const callback_arg) {
86 return true;
87 }
88
// Test stub for the mem-zero callback. Writes nothing to `location` and
// leaves `status` untouched.
void MemZero(const SP_Device* device,
             SP_Stream stream,
             SP_DeviceMemoryBase* location,
             uint64_t size,
             TF_Status* status) {
  // Intentionally a no-op.
}
91
// Test stub for the byte-pattern memset callback. Writes nothing to
// `location` and leaves `status` untouched.
void Memset(const SP_Device* device,
            SP_Stream stream,
            SP_DeviceMemoryBase* location,
            uint8_t pattern,
            uint64_t size,
            TF_Status* status) {
  // Intentionally a no-op.
}
95
// Test stub for the 32-bit-pattern memset callback. Writes nothing to
// `location` and leaves `status` untouched.
void Memset32(const SP_Device* device,
              SP_Stream stream,
              SP_DeviceMemoryBase* location,
              uint32_t pattern,
              uint64_t size,
              TF_Status* status) {
  // Intentionally a no-op.
}
99
PopulateDefaultStreamExecutor(SP_StreamExecutor * se)100 void PopulateDefaultStreamExecutor(SP_StreamExecutor* se) {
101 *se = {SP_STREAMEXECUTOR_STRUCT_SIZE};
102 se->allocate = Allocate;
103 se->deallocate = Deallocate;
104 se->host_memory_allocate = HostMemoryAllocate;
105 se->host_memory_deallocate = HostMemoryDeallocate;
106 se->get_allocator_stats = GetAllocatorStats;
107 se->device_memory_usage = DeviceMemoryUsage;
108 se->create_stream = CreateStream;
109 se->destroy_stream = DestroyStream;
110 se->create_stream_dependency = CreateStreamDependency;
111 se->get_stream_status = GetStreamStatus;
112 se->create_event = CreateEvent;
113 se->destroy_event = DestroyEvent;
114 se->get_event_status = GetEventStatus;
115 se->record_event = RecordEvent;
116 se->wait_for_event = WaitForEvent;
117 se->create_timer = CreateTimer;
118 se->destroy_timer = DestroyTimer;
119 se->start_timer = StartTimer;
120 se->stop_timer = StopTimer;
121 se->memcpy_dtoh = MemcpyDToH;
122 se->memcpy_htod = MemcpyHToD;
123 se->sync_memcpy_dtoh = SyncMemcpyDToH;
124 se->sync_memcpy_htod = SyncMemcpyHToD;
125 se->block_host_for_event = BlockHostForEvent;
126 se->synchronize_all_activity = SynchronizeAllActivity;
127 se->host_callback = HostCallback;
128 se->mem_zero = MemZero;
129 se->memset = Memset;
130 se->memset32 = Memset32;
131 }
132
PopulateDefaultDeviceFns(SP_DeviceFns * device_fns)133 void PopulateDefaultDeviceFns(SP_DeviceFns* device_fns) {
134 *device_fns = {SP_DEVICE_FNS_STRUCT_SIZE};
135 }
136
137 /*** Functions for creating SP_TimerFns ***/
Nanoseconds(SP_Timer timer)138 uint64_t Nanoseconds(SP_Timer timer) { return timer->timer_id; }
139
PopulateDefaultTimerFns(SP_TimerFns * timer_fns)140 void PopulateDefaultTimerFns(SP_TimerFns* timer_fns) {
141 timer_fns->nanoseconds = Nanoseconds;
142 }
143
144 /*** Functions for creating SP_Platform ***/
CreateTimerFns(const SP_Platform * platform,SP_TimerFns * timer_fns,TF_Status * status)145 void CreateTimerFns(const SP_Platform* platform, SP_TimerFns* timer_fns,
146 TF_Status* status) {
147 TF_SetStatus(status, TF_OK, "");
148 PopulateDefaultTimerFns(timer_fns);
149 }
DestroyTimerFns(const SP_Platform * platform,SP_TimerFns * timer_fns)150 void DestroyTimerFns(const SP_Platform* platform, SP_TimerFns* timer_fns) {}
151
CreateStreamExecutor(const SP_Platform * platform,SE_CreateStreamExecutorParams * params,TF_Status * status)152 void CreateStreamExecutor(const SP_Platform* platform,
153 SE_CreateStreamExecutorParams* params,
154 TF_Status* status) {
155 TF_SetStatus(status, TF_OK, "");
156 PopulateDefaultStreamExecutor(params->stream_executor);
157 }
// Platform callback paired with CreateStreamExecutor. Does nothing with `se`.
void DestroyStreamExecutor(const SP_Platform* platform,
                           SP_StreamExecutor* se) {
  // Intentionally a no-op.
}
GetDeviceCount(const SP_Platform * platform,int * device_count,TF_Status * status)160 void GetDeviceCount(const SP_Platform* platform, int* device_count,
161 TF_Status* status) {
162 TF_SetStatus(status, TF_OK, "");
163 *device_count = kDeviceCount;
164 }
CreateDevice(const SP_Platform * platform,SE_CreateDeviceParams * params,TF_Status * status)165 void CreateDevice(const SP_Platform* platform, SE_CreateDeviceParams* params,
166 TF_Status* status) {
167 TF_SetStatus(status, TF_OK, "");
168 params->device->struct_size = {SP_DEVICE_STRUCT_SIZE};
169 }
DestroyDevice(const SP_Platform * platform,SP_Device * device)170 void DestroyDevice(const SP_Platform* platform, SP_Device* device) {}
171
CreateDeviceFns(const SP_Platform * platform,SE_CreateDeviceFnsParams * params,TF_Status * status)172 void CreateDeviceFns(const SP_Platform* platform,
173 SE_CreateDeviceFnsParams* params, TF_Status* status) {
174 TF_SetStatus(status, TF_OK, "");
175 params->device_fns->struct_size = {SP_DEVICE_FNS_STRUCT_SIZE};
176 }
DestroyDeviceFns(const SP_Platform * platform,SP_DeviceFns * device_fns)177 void DestroyDeviceFns(const SP_Platform* platform, SP_DeviceFns* device_fns) {}
178
PopulateDefaultPlatform(SP_Platform * platform,SP_PlatformFns * platform_fns)179 void PopulateDefaultPlatform(SP_Platform* platform,
180 SP_PlatformFns* platform_fns) {
181 *platform = {SP_PLATFORM_STRUCT_SIZE};
182 platform->name = kDeviceName;
183 platform->type = kDeviceType;
184 platform_fns->get_device_count = GetDeviceCount;
185 platform_fns->create_device = CreateDevice;
186 platform_fns->destroy_device = DestroyDevice;
187 platform_fns->create_device_fns = CreateDeviceFns;
188 platform_fns->destroy_device_fns = DestroyDeviceFns;
189 platform_fns->create_stream_executor = CreateStreamExecutor;
190 platform_fns->destroy_stream_executor = DestroyStreamExecutor;
191 platform_fns->create_timer_fns = CreateTimerFns;
192 platform_fns->destroy_timer_fns = DestroyTimerFns;
193 }
194
195 /*** Functions for creating SE_PlatformRegistrationParams ***/
DestroyPlatform(SP_Platform * platform)196 void DestroyPlatform(SP_Platform* platform) {}
DestroyPlatformFns(SP_PlatformFns * platform_fns)197 void DestroyPlatformFns(SP_PlatformFns* platform_fns) {}
198
PopulateDefaultPlatformRegistrationParams(SE_PlatformRegistrationParams * const params)199 void PopulateDefaultPlatformRegistrationParams(
200 SE_PlatformRegistrationParams* const params) {
201 PopulateDefaultPlatform(params->platform, params->platform_fns);
202 params->destroy_platform = DestroyPlatform;
203 params->destroy_platform_fns = DestroyPlatformFns;
204 }
205
206 } // namespace test_util
207 } // namespace stream_executor
208