1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/c/experimental/stream_executor/stream_executor_test_util.h"
16 
17 #include "tensorflow/c/experimental/stream_executor/stream_executor.h"
18 
19 namespace stream_executor {
20 namespace test_util {
21 
22 /*** Functions for creating SP_StreamExecutor ***/
// Default SP_StreamExecutor::allocate stub: performs no allocation and leaves
// `*mem` exactly as the caller passed it in.
void Allocate(const SP_Device* const /*device*/, uint64_t /*size*/,
              int64_t /*memory_space*/, SP_DeviceMemoryBase* const /*mem*/) {
  // Intentionally empty.
}
// Default SP_StreamExecutor::deallocate stub: releases nothing.
void Deallocate(const SP_Device* const /*device*/,
                SP_DeviceMemoryBase* const /*mem*/) {
  // Intentionally empty.
}
// Default SP_StreamExecutor::host_memory_allocate stub: allocates no host
// memory and always returns a null pointer.
void* HostMemoryAllocate(const SP_Device* const /*device*/,
                         uint64_t /*size*/) {
  return nullptr;
}
// Default SP_StreamExecutor::host_memory_deallocate stub: frees nothing.
void HostMemoryDeallocate(const SP_Device* const /*device*/, void* /*mem*/) {
  // Intentionally empty.
}
GetAllocatorStats(const SP_Device * const device,SP_AllocatorStats * const stats)31 TF_Bool GetAllocatorStats(const SP_Device* const device,
32                           SP_AllocatorStats* const stats) {
33   return true;
34 }
DeviceMemoryUsage(const SP_Device * const device,int64_t * const free,int64_t * const total)35 TF_Bool DeviceMemoryUsage(const SP_Device* const device, int64_t* const free,
36                           int64_t* const total) {
37   return true;
38 }
CreateStream(const SP_Device * const device,SP_Stream * stream,TF_Status * const status)39 void CreateStream(const SP_Device* const device, SP_Stream* stream,
40                   TF_Status* const status) {
41   *stream = nullptr;
42 }
// Default SP_StreamExecutor::destroy_stream stub: nothing to release.
void DestroyStream(const SP_Device* const /*device*/, SP_Stream /*stream*/) {
  // Intentionally empty.
}
// Default SP_StreamExecutor::create_stream_dependency stub: records no
// dependency and does not write to `status`.
void CreateStreamDependency(const SP_Device* const /*device*/,
                            SP_Stream /*dependent*/, SP_Stream /*other*/,
                            TF_Status* const /*status*/) {
  // Intentionally empty.
}
// Default SP_StreamExecutor::get_stream_status stub: does not write to
// `status`.
void GetStreamStatus(const SP_Device* const /*device*/, SP_Stream /*stream*/,
                     TF_Status* const /*status*/) {
  // Intentionally empty.
}
CreateEvent(const SP_Device * const device,SP_Event * event,TF_Status * const status)48 void CreateEvent(const SP_Device* const device, SP_Event* event,
49                  TF_Status* const status) {
50   *event = nullptr;
51 }
// Default SP_StreamExecutor::destroy_event stub: nothing to release.
void DestroyEvent(const SP_Device* const /*device*/, SP_Event /*event*/) {
  // Intentionally empty.
}
GetEventStatus(const SP_Device * const device,SP_Event event)53 SE_EventStatus GetEventStatus(const SP_Device* const device, SP_Event event) {
54   return SE_EVENT_UNKNOWN;
55 }
// Default SP_StreamExecutor::record_event stub: records nothing and does not
// write to `status`.
void RecordEvent(const SP_Device* const /*device*/, SP_Stream /*stream*/,
                 SP_Event /*event*/, TF_Status* const /*status*/) {
  // Intentionally empty.
}
// Default SP_StreamExecutor::wait_for_event stub: returns immediately without
// writing to `status`.
void WaitForEvent(const SP_Device* const /*device*/, SP_Stream /*stream*/,
                  SP_Event /*event*/, TF_Status* const /*status*/) {
  // Intentionally empty.
}
CreateTimer(const SP_Device * const device,SP_Timer * timer,TF_Status * const status)60 void CreateTimer(const SP_Device* const device, SP_Timer* timer,
61                  TF_Status* const status) {}
// Default SP_StreamExecutor::destroy_timer stub: nothing to release.
void DestroyTimer(const SP_Device* const /*device*/, SP_Timer /*timer*/) {
  // Intentionally empty.
}
// Default SP_StreamExecutor::start_timer stub: does nothing and does not write
// to `status`.
void StartTimer(const SP_Device* const /*device*/, SP_Stream /*stream*/,
                SP_Timer /*timer*/, TF_Status* const /*status*/) {
  // Intentionally empty.
}
// Default SP_StreamExecutor::stop_timer stub: does nothing and does not write
// to `status`.
void StopTimer(const SP_Device* const /*device*/, SP_Stream /*stream*/,
               SP_Timer /*timer*/, TF_Status* const /*status*/) {
  // Intentionally empty.
}
// Default SP_StreamExecutor::memcpy_dtoh stub: copies no bytes and does not
// write to `status`.
void MemcpyDToH(const SP_Device* const /*device*/, SP_Stream /*stream*/,
                void* /*host_dst*/,
                const SP_DeviceMemoryBase* const /*device_src*/,
                uint64_t /*size*/, TF_Status* const /*status*/) {
  // Intentionally empty.
}
// Default SP_StreamExecutor::memcpy_htod stub: copies no bytes and does not
// write to `status`.
void MemcpyHToD(const SP_Device* const /*device*/, SP_Stream /*stream*/,
                SP_DeviceMemoryBase* const /*device_dst*/,
                const void* /*host_src*/, uint64_t /*size*/,
                TF_Status* const /*status*/) {
  // Intentionally empty.
}
// Default SP_StreamExecutor::sync_memcpy_dtoh stub: copies no bytes and does
// not write to `status`.
void SyncMemcpyDToH(const SP_Device* const /*device*/, void* /*host_dst*/,
                    const SP_DeviceMemoryBase* const /*device_src*/,
                    uint64_t /*size*/, TF_Status* const /*status*/) {
  // Intentionally empty.
}
// Default SP_StreamExecutor::sync_memcpy_htod stub: copies no bytes and does
// not write to `status`.
void SyncMemcpyHToD(const SP_Device* const /*device*/,
                    SP_DeviceMemoryBase* const /*device_dst*/,
                    const void* /*host_src*/, uint64_t /*size*/,
                    TF_Status* const /*status*/) {
  // Intentionally empty.
}
// Default SP_StreamExecutor::block_host_for_event stub: returns immediately
// without writing to `status`.
void BlockHostForEvent(const SP_Device* const /*device*/, SP_Event /*event*/,
                       TF_Status* const /*status*/) {
  // Intentionally empty.
}
// Default SP_StreamExecutor::synchronize_all_activity stub: returns
// immediately without writing to `status`.
void SynchronizeAllActivity(const SP_Device* const /*device*/,
                            TF_Status* const /*status*/) {
  // Intentionally empty.
}
HostCallback(const SP_Device * const device,SP_Stream stream,SE_StatusCallbackFn const callback_fn,void * const callback_arg)83 TF_Bool HostCallback(const SP_Device* const device, SP_Stream stream,
84                      SE_StatusCallbackFn const callback_fn,
85                      void* const callback_arg) {
86   return true;
87 }
88 
// Default SP_StreamExecutor::mem_zero stub: writes no memory and does not
// touch `status`.
void MemZero(const SP_Device* /*device*/, SP_Stream /*stream*/,
             SP_DeviceMemoryBase* /*location*/, uint64_t /*size*/,
             TF_Status* /*status*/) {
  // Intentionally empty.
}
91 
// Default SP_StreamExecutor::memset stub: writes no memory and does not touch
// `status`.
void Memset(const SP_Device* /*device*/, SP_Stream /*stream*/,
            SP_DeviceMemoryBase* /*location*/, uint8_t /*pattern*/,
            uint64_t /*size*/, TF_Status* /*status*/) {
  // Intentionally empty.
}
95 
// Default SP_StreamExecutor::memset32 stub: writes no memory and does not
// touch `status`.
void Memset32(const SP_Device* /*device*/, SP_Stream /*stream*/,
              SP_DeviceMemoryBase* /*location*/, uint32_t /*pattern*/,
              uint64_t /*size*/, TF_Status* /*status*/) {
  // Intentionally empty.
}
99 
PopulateDefaultStreamExecutor(SP_StreamExecutor * se)100 void PopulateDefaultStreamExecutor(SP_StreamExecutor* se) {
101   *se = {SP_STREAMEXECUTOR_STRUCT_SIZE};
102   se->allocate = Allocate;
103   se->deallocate = Deallocate;
104   se->host_memory_allocate = HostMemoryAllocate;
105   se->host_memory_deallocate = HostMemoryDeallocate;
106   se->get_allocator_stats = GetAllocatorStats;
107   se->device_memory_usage = DeviceMemoryUsage;
108   se->create_stream = CreateStream;
109   se->destroy_stream = DestroyStream;
110   se->create_stream_dependency = CreateStreamDependency;
111   se->get_stream_status = GetStreamStatus;
112   se->create_event = CreateEvent;
113   se->destroy_event = DestroyEvent;
114   se->get_event_status = GetEventStatus;
115   se->record_event = RecordEvent;
116   se->wait_for_event = WaitForEvent;
117   se->create_timer = CreateTimer;
118   se->destroy_timer = DestroyTimer;
119   se->start_timer = StartTimer;
120   se->stop_timer = StopTimer;
121   se->memcpy_dtoh = MemcpyDToH;
122   se->memcpy_htod = MemcpyHToD;
123   se->sync_memcpy_dtoh = SyncMemcpyDToH;
124   se->sync_memcpy_htod = SyncMemcpyHToD;
125   se->block_host_for_event = BlockHostForEvent;
126   se->synchronize_all_activity = SynchronizeAllActivity;
127   se->host_callback = HostCallback;
128   se->mem_zero = MemZero;
129   se->memset = Memset;
130   se->memset32 = Memset32;
131 }
132 
PopulateDefaultDeviceFns(SP_DeviceFns * device_fns)133 void PopulateDefaultDeviceFns(SP_DeviceFns* device_fns) {
134   *device_fns = {SP_DEVICE_FNS_STRUCT_SIZE};
135 }
136 
137 /*** Functions for creating SP_TimerFns ***/
Nanoseconds(SP_Timer timer)138 uint64_t Nanoseconds(SP_Timer timer) { return timer->timer_id; }
139 
PopulateDefaultTimerFns(SP_TimerFns * timer_fns)140 void PopulateDefaultTimerFns(SP_TimerFns* timer_fns) {
141   timer_fns->nanoseconds = Nanoseconds;
142 }
143 
144 /*** Functions for creating SP_Platform ***/
CreateTimerFns(const SP_Platform * platform,SP_TimerFns * timer_fns,TF_Status * status)145 void CreateTimerFns(const SP_Platform* platform, SP_TimerFns* timer_fns,
146                     TF_Status* status) {
147   TF_SetStatus(status, TF_OK, "");
148   PopulateDefaultTimerFns(timer_fns);
149 }
// Default SP_PlatformFns::destroy_timer_fns callback: nothing to release.
void DestroyTimerFns(const SP_Platform* /*platform*/,
                     SP_TimerFns* /*timer_fns*/) {
  // Intentionally empty.
}
151 
CreateStreamExecutor(const SP_Platform * platform,SE_CreateStreamExecutorParams * params,TF_Status * status)152 void CreateStreamExecutor(const SP_Platform* platform,
153                           SE_CreateStreamExecutorParams* params,
154                           TF_Status* status) {
155   TF_SetStatus(status, TF_OK, "");
156   PopulateDefaultStreamExecutor(params->stream_executor);
157 }
// Default SP_PlatformFns::destroy_stream_executor callback: nothing to
// release.
void DestroyStreamExecutor(const SP_Platform* /*platform*/,
                           SP_StreamExecutor* /*se*/) {
  // Intentionally empty.
}
GetDeviceCount(const SP_Platform * platform,int * device_count,TF_Status * status)160 void GetDeviceCount(const SP_Platform* platform, int* device_count,
161                     TF_Status* status) {
162   TF_SetStatus(status, TF_OK, "");
163   *device_count = kDeviceCount;
164 }
CreateDevice(const SP_Platform * platform,SE_CreateDeviceParams * params,TF_Status * status)165 void CreateDevice(const SP_Platform* platform, SE_CreateDeviceParams* params,
166                   TF_Status* status) {
167   TF_SetStatus(status, TF_OK, "");
168   params->device->struct_size = {SP_DEVICE_STRUCT_SIZE};
169 }
// Default SP_PlatformFns::destroy_device callback: nothing to release.
void DestroyDevice(const SP_Platform* /*platform*/, SP_Device* /*device*/) {
  // Intentionally empty.
}
171 
CreateDeviceFns(const SP_Platform * platform,SE_CreateDeviceFnsParams * params,TF_Status * status)172 void CreateDeviceFns(const SP_Platform* platform,
173                      SE_CreateDeviceFnsParams* params, TF_Status* status) {
174   TF_SetStatus(status, TF_OK, "");
175   params->device_fns->struct_size = {SP_DEVICE_FNS_STRUCT_SIZE};
176 }
// Default SP_PlatformFns::destroy_device_fns callback: nothing to release.
void DestroyDeviceFns(const SP_Platform* /*platform*/,
                      SP_DeviceFns* /*device_fns*/) {
  // Intentionally empty.
}
178 
PopulateDefaultPlatform(SP_Platform * platform,SP_PlatformFns * platform_fns)179 void PopulateDefaultPlatform(SP_Platform* platform,
180                              SP_PlatformFns* platform_fns) {
181   *platform = {SP_PLATFORM_STRUCT_SIZE};
182   platform->name = kDeviceName;
183   platform->type = kDeviceType;
184   platform_fns->get_device_count = GetDeviceCount;
185   platform_fns->create_device = CreateDevice;
186   platform_fns->destroy_device = DestroyDevice;
187   platform_fns->create_device_fns = CreateDeviceFns;
188   platform_fns->destroy_device_fns = DestroyDeviceFns;
189   platform_fns->create_stream_executor = CreateStreamExecutor;
190   platform_fns->destroy_stream_executor = DestroyStreamExecutor;
191   platform_fns->create_timer_fns = CreateTimerFns;
192   platform_fns->destroy_timer_fns = DestroyTimerFns;
193 }
194 
195 /*** Functions for creating SE_PlatformRegistrationParams ***/
DestroyPlatform(SP_Platform * platform)196 void DestroyPlatform(SP_Platform* platform) {}
DestroyPlatformFns(SP_PlatformFns * platform_fns)197 void DestroyPlatformFns(SP_PlatformFns* platform_fns) {}
198 
PopulateDefaultPlatformRegistrationParams(SE_PlatformRegistrationParams * const params)199 void PopulateDefaultPlatformRegistrationParams(
200     SE_PlatformRegistrationParams* const params) {
201   PopulateDefaultPlatform(params->platform, params->platform_fns);
202   params->destroy_platform = DestroyPlatform;
203   params->destroy_platform_fns = DestroyPlatformFns;
204 }
205 
206 }  // namespace test_util
207 }  // namespace stream_executor
208