/***************************************************************************
 * Copyright (C) 2017 Codeplay Software Limited
 * This Source Code Form is subject to the terms of the Mozilla
 * Public License v. 2.0. If a copy of the MPL was not distributed
 * with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
 *
 *
 *  SyclMemoryModel.h
 *
 *  Description:
 *    Interface for SYCL buffers to behave as a non-dereferenceable pointer
 *    Interface for Placeholder accessor to behave as a pointer on both host
 *    and device
 *
 * Authors:
 *
 *    Ruyman Reyes    Codeplay Software Ltd.
 *    Mehdi Goli      Codeplay Software Ltd.
 *    Vanya Yaneva    Codeplay Software Ltd.
 *
 **************************************************************************/

#if defined(EIGEN_USE_SYCL) && \
    !defined(EIGEN_CXX11_TENSOR_TENSOR_SYCL_STORAGE_MEMORY_H)
#define EIGEN_CXX11_TENSOR_TENSOR_SYCL_STORAGE_MEMORY_H

#include <CL/sycl.hpp>
#ifdef EIGEN_EXCEPTIONS
#include <stdexcept>
#endif
#include <cstddef>
#include <cstdint>
#include <map>
#include <queue>
#include <set>
#include <unordered_map>

namespace Eigen {
namespace TensorSycl {
namespace internal {

using sycl_acc_target = cl::sycl::access::target;
using sycl_acc_mode = cl::sycl::access::mode;

/**
 * Default values for template arguments
 */
using buffer_data_type_t = uint8_t;
const sycl_acc_target default_acc_target = sycl_acc_target::global_buffer;
const sycl_acc_mode default_acc_mode = sycl_acc_mode::read_write;

/**
 * PointerMapper
 *  Associates fake pointers with buffers.
 *
 */
class PointerMapper {
 public:
  using base_ptr_t = std::intptr_t;

  /* Structure of a virtual pointer
   *
   * |================================================|
   * |               POINTER ADDRESS                  |
   * |================================================|
   */
  struct virtual_pointer_t {
    /* Type for the pointers
     */
    base_ptr_t m_contents;

    /** Conversions from virtual_pointer_t to
     * void * should just reinterpret_cast the integer number
     */
    operator void *() const { return reinterpret_cast<void *>(m_contents); }

    /**
     * Convert back to the integer number.
     */
    operator base_ptr_t() const { return m_contents; }

    /**
     * Add a certain value to the pointer to create a
     * new pointer to that offset
     */
    virtual_pointer_t operator+(size_t off) { return m_contents + off; }

    /* Numerical order for sorting pointers in containers. */
    bool operator<(virtual_pointer_t rhs) const {
      return (static_cast<base_ptr_t>(m_contents) <
              static_cast<base_ptr_t>(rhs.m_contents));
    }

    bool operator>(virtual_pointer_t rhs) const {
      return (static_cast<base_ptr_t>(m_contents) >
              static_cast<base_ptr_t>(rhs.m_contents));
    }

    /**
     * Equality comparison between pointers.
     */
    bool operator==(virtual_pointer_t rhs) const {
      return (static_cast<base_ptr_t>(m_contents) ==
              static_cast<base_ptr_t>(rhs.m_contents));
    }

    /**
     * Simple forward to the equality overload.
     */
    bool operator!=(virtual_pointer_t rhs) const {
      return !(this->operator==(rhs));
    }

    /**
     * Converts a void * into a virtual pointer structure.
     * Note that this will only work if the void * was
     * already a virtual_pointer_t, but we have no way of
     * checking that.
     */
    virtual_pointer_t(const void *ptr)
        : m_contents(reinterpret_cast<base_ptr_t>(ptr)){};

    /**
     * Creates a virtual_pointer_t from the given integer
     * number.
     */
    virtual_pointer_t(base_ptr_t u) : m_contents(u){};
  };
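
  // Illustrative sketch (not part of the interface; the base value 4096 is an
  // arbitrary example): how virtual_pointer_t is meant to round-trip through
  // void * and support byte-offset arithmetic.
  //
  //   virtual_pointer_t vp(static_cast<base_ptr_t>(4096));
  //   void *raw = static_cast<void *>(vp);   // reinterpret the integer value
  //   virtual_pointer_t back(raw);           // recovers the same value
  //   virtual_pointer_t shifted = vp + 16;   // points 16 bytes further
  //   // static_cast<base_ptr_t>(shifted) == 4096 + 16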

  /* Definition of a null pointer
   */
  const virtual_pointer_t null_virtual_ptr = nullptr;

  /**
   * Returns whether a pointer is null or not.
   * A pointer is null if its value is that of null_virtual_ptr.
   */
  static inline bool is_nullptr(virtual_pointer_t ptr) {
    return (static_cast<void *>(ptr) == nullptr);
  }

  /* basic type for all buffers
   */
  using buffer_t = cl::sycl::buffer_mem;

  /**
   * Node that stores information about a device allocation.
   * Nodes are sorted by size to organise a free list of nodes
   * that can be recovered.
   */
  struct pMapNode_t {
    buffer_t m_buffer;
    size_t m_size;
    bool m_free;

    pMapNode_t(buffer_t b, size_t size, bool f)
        : m_buffer{b}, m_size{size}, m_free{f} {
      m_buffer.set_final_data(nullptr);
    }

    bool operator<=(const pMapNode_t &rhs) { return (m_size <= rhs.m_size); }
  };

  /** Storage of the pointer / buffer tree
   */
  using pointerMap_t = std::map<virtual_pointer_t, pMapNode_t>;

  /**
   * Obtain the insertion point in the pointer map for
   * a pointer of the given size.
   * \param requiredSize Size of the allocation we are attempting to reclaim
   */
  typename pointerMap_t::iterator get_insertion_point(size_t requiredSize) {
    typename pointerMap_t::iterator retVal;
    bool reuse = false;
    if (!m_freeList.empty()) {
      // try to re-use an existing block
      for (auto freeElem : m_freeList) {
        if (freeElem->second.m_size >= requiredSize) {
          retVal = freeElem;
          reuse = true;
          // Element is not going to be free anymore
          m_freeList.erase(freeElem);
          break;
        }
      }
    }
    if (!reuse) {
      retVal = std::prev(m_pointerMap.end());
    }
    return retVal;
  }

  /**
   * Returns an iterator to the node that stores the information
   * of the given virtual pointer.
   * If the pointer is not found, throws std::out_of_range.
   * If the pointer map structure is empty, throws std::out_of_range.
   *
   * \param ptr The virtual pointer to obtain the node of
   * \throws std::out_of_range if the pointer is not found or the map is empty
   */
  typename pointerMap_t::iterator get_node(const virtual_pointer_t ptr) {
    if (this->count() == 0) {
      m_pointerMap.clear();
      EIGEN_THROW_X(std::out_of_range("There are no pointers allocated\n"));
    }
    if (is_nullptr(ptr)) {
      m_pointerMap.clear();
      EIGEN_THROW_X(std::out_of_range("Cannot access null pointer\n"));
    }
    // The previous element to the lower bound is the node that
    // holds this memory address
    auto node = m_pointerMap.lower_bound(ptr);
    // If the value of the pointer is not the one of the node
    // then we return the previous one
    if (node == std::end(m_pointerMap)) {
      --node;
    } else if (node->first != ptr) {
      if (node == std::begin(m_pointerMap)) {
        m_pointerMap.clear();
        EIGEN_THROW_X(
            std::out_of_range("The pointer is not registered in the map\n"));
      }
      --node;
    }

    return node;
  }

  /* get_buffer.
   * Returns a buffer from the map using the pointer address
   */
  template <typename buffer_data_type = buffer_data_type_t>
  cl::sycl::buffer<buffer_data_type, 1> get_buffer(
      const virtual_pointer_t ptr) {
    using sycl_buffer_t = cl::sycl::buffer<buffer_data_type, 1>;

    // The node stores a `buffer_mem`, so we need to cast it to a `buffer<>`.
    // We can do this without the `buffer_mem` being a pointer, as we
    // only declare member variables in the base class (`buffer_mem`) and not
    // in the child class (`buffer<>`).
    auto node = get_node(ptr);
    eigen_assert(node->first == ptr || node->first < ptr);
    eigen_assert(ptr < static_cast<virtual_pointer_t>(node->second.m_size +
                                                      node->first));
    return *(static_cast<sycl_buffer_t *>(&node->second.m_buffer));
  }

  /**
   * @brief Returns an accessor to the buffer of the given virtual pointer
   * @param accessMode
   * @param accessTarget
   * @param ptr The virtual pointer
   */
  template <sycl_acc_mode access_mode = default_acc_mode,
            sycl_acc_target access_target = default_acc_target,
            typename buffer_data_type = buffer_data_type_t>
  cl::sycl::accessor<buffer_data_type, 1, access_mode, access_target>
  get_access(const virtual_pointer_t ptr) {
    auto buf = get_buffer<buffer_data_type>(ptr);
    return buf.template get_access<access_mode, access_target>();
  }

  /**
   * @brief Returns an accessor to the buffer of the given virtual pointer
   * in the given command group scope
   * @param accessMode
   * @param accessTarget
   * @param ptr The virtual pointer
   * @param cgh Reference to the command group scope
   */
  template <sycl_acc_mode access_mode = default_acc_mode,
            sycl_acc_target access_target = default_acc_target,
            typename buffer_data_type = buffer_data_type_t>
  cl::sycl::accessor<buffer_data_type, 1, access_mode, access_target>
  get_access(const virtual_pointer_t ptr, cl::sycl::handler &cgh) {
    auto buf = get_buffer<buffer_data_type>(ptr);
    return buf.template get_access<access_mode, access_target>(cgh);
  }
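
  // Illustrative sketch (assumptions: a PointerMapper `pMap`, a cl::sycl::queue
  // `q`, and a virtual pointer `vptr` obtained from the SYCLmalloc helper
  // defined later in this header). The handler overload is used inside a
  // command group; the default access mode, target and byte element type are
  // taken from the template defaults.
  //
  //   q.submit([&](cl::sycl::handler &cgh) {
  //     auto acc = pMap.get_access<sycl_acc_mode::write>(vptr, cgh);
  //     cgh.single_task<class zero_first_byte>([=]() { acc[0] = 0; });
  //   });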

  /*
   * Returns the offset from the base address of this pointer.
   */
  inline std::ptrdiff_t get_offset(const virtual_pointer_t ptr) {
    // The previous element to the lower bound is the node that
    // holds this memory address
    auto node = get_node(ptr);
    auto start = node->first;
    eigen_assert(start == ptr || start < ptr);
    eigen_assert(ptr < start + node->second.m_size);
    return (ptr - start);
  }

  /*
   * Returns the number of elements by which the given pointer is offset from
   * the base address.
   */
  template <typename buffer_data_type>
  inline size_t get_element_offset(const virtual_pointer_t ptr) {
    return get_offset(ptr) / sizeof(buffer_data_type);
  }
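
  // Illustrative sketch (assuming `vptr` was returned by SYCLmalloc through
  // this mapper `pMap`): the byte offset and the element offset differ only by
  // the element size of the type used to interpret the buffer.
  //
  //   void *shifted = static_cast<char *>(vptr) + 8 * sizeof(float);
  //   std::ptrdiff_t bytes = pMap.get_offset(shifted);         // 8 * sizeof(float)
  //   size_t elems = pMap.get_element_offset<float>(shifted);  // 8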

  /**
   * Constructs the PointerMapper structure.
   */
  PointerMapper(base_ptr_t baseAddress = 4096)
      : m_pointerMap{}, m_freeList{}, m_baseAddress{baseAddress} {
    if (m_baseAddress == 0) {
      EIGEN_THROW_X(std::invalid_argument("Base address cannot be zero\n"));
    }
  };

  /**
   * PointerMapper cannot be copied or moved
   */
  PointerMapper(const PointerMapper &) = delete;

  /**
   * Empty the pointer list
   */
  inline void clear() {
    m_freeList.clear();
    m_pointerMap.clear();
  }

  /* add_pointer.
   * Adds an existing pointer to the map and returns the virtual pointer id.
   */
  inline virtual_pointer_t add_pointer(const buffer_t &b) {
    return add_pointer_impl(b);
  }

  /* add_pointer.
   * Adds a pointer to the map and returns the virtual pointer id.
   */
  inline virtual_pointer_t add_pointer(buffer_t &&b) {
    return add_pointer_impl(b);
  }

  /**
   * @brief Fuses the given node with the following nodes in the
   * pointer map if they are free
   *
   * @param node A reference to the free node to be fused
   */
  void fuse_forward(typename pointerMap_t::iterator &node) {
    while (node != std::prev(m_pointerMap.end())) {
      // if the following node is free,
      // remove it and extend the current node with its size
      auto fwd_node = std::next(node);
      if (!fwd_node->second.m_free) {
        break;
      }
      auto fwd_size = fwd_node->second.m_size;
      m_freeList.erase(fwd_node);
      m_pointerMap.erase(fwd_node);

      node->second.m_size += fwd_size;
    }
  }

  /**
   * @brief Fuses the given node with the previous nodes in the
   * pointer map if they are free
   *
   * @param node A reference to the free node to be fused
   */
  void fuse_backward(typename pointerMap_t::iterator &node) {
    while (node != m_pointerMap.begin()) {
      // if the previous node is free, extend it
      // with the size of the current one
      auto prev_node = std::prev(node);
      if (!prev_node->second.m_free) {
        break;
      }
      prev_node->second.m_size += node->second.m_size;

      // remove the current node
      m_freeList.erase(node);
      m_pointerMap.erase(node);

      // point to the previous node
      node = prev_node;
    }
  }

  /* remove_pointer.
   * Removes the given pointer from the map.
   * The pointer is allowed to be reused only if ReUse is true.
   */
  template <bool ReUse = true>
  void remove_pointer(const virtual_pointer_t ptr) {
    if (is_nullptr(ptr)) {
      return;
    }
    auto node = this->get_node(ptr);

    node->second.m_free = true;
    m_freeList.emplace(node);

    // Fuse the node
    // with free nodes before and after it
    fuse_forward(node);
    fuse_backward(node);

    // If after fusing the node is the last one
    // simply remove it (since it is free)
    if (node == std::prev(m_pointerMap.end())) {
      m_freeList.erase(node);
      m_pointerMap.erase(node);
    }
  }
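
  // Illustrative sketch of the fuse behaviour (hypothetical sizes, using the
  // SYCLmalloc/SYCLfree wrappers defined below on a mapper `pMap`): freeing b
  // leaves a 256-byte free node; freeing a then fuses the two into a single
  // 512-byte free node that a later allocation of up to 512 bytes can reclaim.
  //
  //   void *a = SYCLmalloc(256, pMap);
  //   void *b = SYCLmalloc(256, pMap);
  //   void *c = SYCLmalloc(256, pMap);
  //   SYCLfree(b, pMap);                // b becomes a free node
  //   SYCLfree(a, pMap);                // a is freed and fused forward with b
  //   void *d = SYCLmalloc(512, pMap);  // may reuse the fused 512-byte node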

  /* count.
   * Return the number of active pointers (i.e., pointers that
   * have been allocated but not freed).
   */
  size_t count() const { return (m_pointerMap.size() - m_freeList.size()); }

 private:
  /* add_pointer_impl.
   * Adds a pointer to the map and returns the virtual pointer id.
   * BufferT is either a const buffer_t& or a buffer_t&&.
   */
  template <class BufferT>
  virtual_pointer_t add_pointer_impl(BufferT b) {
    virtual_pointer_t retVal = nullptr;
    size_t bufSize = b.get_count();
    pMapNode_t p{b, bufSize, false};
    // If this is the first pointer:
    if (m_pointerMap.empty()) {
      virtual_pointer_t initialVal{m_baseAddress};
      m_pointerMap.emplace(initialVal, p);
      return initialVal;
    }

    auto lastElemIter = get_insertion_point(bufSize);
    // We are recovering an existing free node
    if (lastElemIter->second.m_free) {
      lastElemIter->second.m_buffer = b;
      lastElemIter->second.m_free = false;

      // If the recovered node is bigger than the inserted one
      // add a new free node with the remaining space
      if (lastElemIter->second.m_size > bufSize) {
        // create a new node with the remaining space
        auto remainingSize = lastElemIter->second.m_size - bufSize;
        pMapNode_t p2{b, remainingSize, true};

        // update the size of the current node
        lastElemIter->second.m_size = bufSize;

        // add the new free node
        auto newFreePtr = lastElemIter->first + bufSize;
        auto freeNode = m_pointerMap.emplace(newFreePtr, p2).first;
        m_freeList.emplace(freeNode);
      }

      retVal = lastElemIter->first;
    } else {
      size_t lastSize = lastElemIter->second.m_size;
      retVal = lastElemIter->first + lastSize;
      m_pointerMap.emplace(retVal, p);
    }
    return retVal;
  }

  /**
   * Compare two iterators to pointer map entries according to
   * the size of the allocation on the device.
   */
  struct SortBySize {
    bool operator()(typename pointerMap_t::iterator a,
                    typename pointerMap_t::iterator b) const {
      return ((a->first < b->first) && (a->second <= b->second)) ||
             ((a->first < b->first) && (b->second <= a->second));
    }
  };

  /* Maps the pointer addresses to buffer and size pairs.
   */
  pointerMap_t m_pointerMap;

  /* List of free nodes available for re-using
   */
  std::set<typename pointerMap_t::iterator, SortBySize> m_freeList;

  /* Base address used when issuing the first virtual pointer, allows users
   * to specify alignment. Cannot be zero. */
  std::intptr_t m_baseAddress;
};

/* remove_pointer.
 * Removes the given pointer from the map.
 * The pointer is allowed to be reused only if ReUse is true.
 */
template <>
inline void PointerMapper::remove_pointer<false>(const virtual_pointer_t ptr) {
  if (is_nullptr(ptr)) {
    return;
  }
  m_pointerMap.erase(this->get_node(ptr));
}

/**
 * Malloc-like interface to the pointer mapper.
 * Given a size, creates a byte-typed buffer and returns a
 * fake pointer to keep track of it.
 * \param size Size in bytes of the desired allocation
 * \throw cl::sycl::exception if there is an error while creating the buffer
 */
inline void *SYCLmalloc(size_t size, PointerMapper &pMap) {
  if (size == 0) {
    return nullptr;
  }
  // Create a generic buffer of the given size
  using buffer_t = cl::sycl::buffer<buffer_data_type_t, 1>;
  auto thePointer = pMap.add_pointer(buffer_t(cl::sycl::range<1>{size}));
  // Store the buffer on the global list
  return static_cast<void *>(thePointer);
}

/**
 * Free-like interface to the pointer mapper.
 * Given a fake pointer created with the virtual-pointer malloc,
 * destroys the buffer and removes it from the list.
 * If ReUse is false, the pointer is not added to the free list;
 * this should only be the case for sub-buffers.
 */
template <bool ReUse = true, typename PointerMapper>
inline void SYCLfree(void *ptr, PointerMapper &pMap) {
  pMap.template remove_pointer<ReUse>(ptr);
}

/**
 * Clear all the memory allocated by SYCL.
 */
template <typename PointerMapper>
inline void SYCLfreeAll(PointerMapper &pMap) {
  pMap.clear();
}
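
// Illustrative sketch of the allocation lifecycle (names are arbitrary).
// SYCLmalloc only creates a byte-typed buffer and registers it with the
// mapper; the buffer itself is recovered through get_buffer, and
// SYCLfree/SYCLfreeAll release the registration.
//
//   PointerMapper pMap;
//   void *vptr = SYCLmalloc(64 * sizeof(float), pMap);
//   auto buf = pMap.get_buffer(vptr);  // cl::sycl::buffer<buffer_data_type_t, 1>
//   SYCLfree(vptr, pMap);              // returns the node to the free list
//   SYCLfreeAll(pMap);                 // or drop every registered buffer at once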

template <cl::sycl::access::mode AcMd, typename T>
struct RangeAccess {
  static const auto global_access = cl::sycl::access::target::global_buffer;
  static const auto is_place_holder = cl::sycl::access::placeholder::true_t;
  typedef T scalar_t;
  typedef scalar_t &ref_t;
  typedef typename cl::sycl::global_ptr<scalar_t>::pointer_t ptr_t;

  // the accessor type is not necessarily the same as T
  typedef cl::sycl::accessor<scalar_t, 1, AcMd, global_access, is_place_holder>
      accessor;

  typedef RangeAccess<AcMd, T> self_t;
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE RangeAccess(accessor access,
                                                    size_t offset,
                                                    std::intptr_t virtual_ptr)
      : access_(access), offset_(offset), virtual_ptr_(virtual_ptr) {}

  RangeAccess(cl::sycl::buffer<scalar_t, 1> buff =
                  cl::sycl::buffer<scalar_t, 1>(cl::sycl::range<1>(1)))
      : access_{accessor{buff}}, offset_(0), virtual_ptr_(-1) {}

  // This should only be used for the null constructor on the host side
  RangeAccess(std::nullptr_t) : RangeAccess() {}
  // This template parameter must be removed and scalar_t should be replaced
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE ptr_t get_pointer() const {
    return (access_.get_pointer().get() + offset_);
  }
  template <typename Index>
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE self_t &operator+=(Index offset) {
    offset_ += (offset);
    return *this;
  }
  template <typename Index>
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE self_t operator+(Index offset) const {
    return self_t(access_, offset_ + offset, virtual_ptr_);
  }
  template <typename Index>
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE self_t operator-(Index offset) const {
    return self_t(access_, offset_ - offset, virtual_ptr_);
  }
  template <typename Index>
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE self_t &operator-=(Index offset) {
    offset_ -= offset;
    return *this;
  }

  // THIS IS FOR NULL COMPARISON ONLY
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE friend bool operator==(
      const RangeAccess &lhs, std::nullptr_t) {
    return ((lhs.virtual_ptr_ == -1));
  }
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE friend bool operator!=(
      const RangeAccess &lhs, std::nullptr_t i) {
    return !(lhs == i);
  }

  // THIS IS FOR NULL COMPARISON ONLY
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE friend bool operator==(
      std::nullptr_t, const RangeAccess &rhs) {
    return ((rhs.virtual_ptr_ == -1));
  }
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE friend bool operator!=(
      std::nullptr_t i, const RangeAccess &rhs) {
    return !(i == rhs);
  }
  // Prefix operator (Increment and return value)
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE self_t &operator++() {
    offset_++;
    return (*this);
  }

  // Postfix operator (Return value and increment)
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE self_t operator++(int i) {
    EIGEN_UNUSED_VARIABLE(i);
    self_t temp_iterator(*this);
    offset_++;
    return temp_iterator;
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::ptrdiff_t get_size() const {
    return (access_.get_count() - offset_);
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::ptrdiff_t get_offset() const {
    return offset_;
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void set_offset(std::ptrdiff_t offset) {
    offset_ = offset;
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE ref_t operator*() const {
    return *get_pointer();
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE ref_t operator*() {
    return *get_pointer();
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE ptr_t operator->() = delete;

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE ref_t operator[](int x) {
    return *(get_pointer() + x);
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE ref_t operator[](int x) const {
    return *(get_pointer() + x);
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE scalar_t *get_virtual_pointer() const {
    return reinterpret_cast<scalar_t *>(virtual_ptr_ +
                                        (offset_ * sizeof(scalar_t)));
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE explicit operator bool() const {
    return (virtual_ptr_ != -1);
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE operator RangeAccess<AcMd, const T>() {
    return RangeAccess<AcMd, const T>(access_, offset_, virtual_ptr_);
  }

  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
  operator RangeAccess<AcMd, const T>() const {
    return RangeAccess<AcMd, const T>(access_, offset_, virtual_ptr_);
  }
  // binding placeholder accessors to a command group handler for SYCL
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void bind(
      cl::sycl::handler &cgh) const {
    cgh.require(access_);
  }

 private:
  accessor access_;
  size_t offset_;
  std::intptr_t virtual_ptr_;  // the location of the buffer in the map
};

template <cl::sycl::access::mode AcMd, typename T>
struct RangeAccess<AcMd, const T> : RangeAccess<AcMd, T> {
  typedef RangeAccess<AcMd, T> Base;
  using Base::Base;
};
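
// Illustrative sketch (names are arbitrary; `q` is assumed to be a
// cl::sycl::queue): RangeAccess wraps a placeholder accessor so it can be
// passed around like a pointer. On the host it is bound to a command group
// via bind(); inside the kernel it is indexed like a raw pointer.
//
//   cl::sycl::buffer<float, 1> buf{cl::sycl::range<1>(128)};
//   RangeAccess<cl::sycl::access::mode::read_write, float> ptr{buf};
//   q.submit([&](cl::sycl::handler &cgh) {
//     ptr.bind(cgh);  // register the placeholder accessor with this group
//     cgh.parallel_for<class init_kernel>(
//         cl::sycl::range<1>(128),
//         [=](cl::sycl::id<1> i) { ptr[i[0]] = 0.0f; });
//   });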

}  // namespace internal
}  // namespace TensorSycl
}  // namespace Eigen

#endif  // EIGEN_CXX11_TENSOR_TENSOR_SYCL_STORAGE_MEMORY_H