xref: /aosp_15_r20/external/eigen/Eigen/src/Core/arch/SYCL/SyclMemoryModel.h (revision bf2c37156dfe67e5dfebd6d394bad8b2ab5804d4)
/***************************************************************************
 *  Copyright (C) 2017 Codeplay Software Limited
 *  This Source Code Form is subject to the terms of the Mozilla
 *  Public License v. 2.0. If a copy of the MPL was not distributed
 *  with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
 *
 *
 *  SyclMemoryModel.h
 *
 *  Description:
 *    Interface that lets SYCL buffers behave as non-dereferenceable pointers.
 *    Interface that lets placeholder accessors behave as pointers on both
 *    host and device.
 *
 * Authors:
 *
 *    Ruyman Reyes   Codeplay Software Ltd.
 *    Mehdi Goli     Codeplay Software Ltd.
 *    Vanya Yaneva   Codeplay Software Ltd.
 *
 **************************************************************************/
22 
23 #if defined(EIGEN_USE_SYCL) && \
24     !defined(EIGEN_CXX11_TENSOR_TENSOR_SYCL_STORAGE_MEMORY_H)
25 #define EIGEN_CXX11_TENSOR_TENSOR_SYCL_STORAGE_MEMORY_H
26 
#include <CL/sycl.hpp>
#ifdef EIGEN_EXCEPTIONS
#include <stdexcept>
#endif
#include <cstddef>
#include <cstdint>
#include <iterator>
#include <map>
#include <queue>
#include <set>
#include <unordered_map>
35 
36 namespace Eigen {
37 namespace TensorSycl {
38 namespace internal {
39 
40 using sycl_acc_target = cl::sycl::access::target;
41 using sycl_acc_mode = cl::sycl::access::mode;
42 
43 /**
44  * Default values for template arguments
45  */
46 using buffer_data_type_t = uint8_t;
47 const sycl_acc_target default_acc_target = sycl_acc_target::global_buffer;
48 const sycl_acc_mode default_acc_mode = sycl_acc_mode::read_write;
49 
50 /**
51  * PointerMapper
52  *  Associates fake pointers with buffers.
53  *
54  */
55 class PointerMapper {
56  public:
57   using base_ptr_t = std::intptr_t;
58 
59   /* Structure of a virtual pointer
60    *
61    * |================================================|
62    * |               POINTER ADDRESS                  |
63    * |================================================|
64    */
65   struct virtual_pointer_t {
66     /* Type for the pointers
67      */
68     base_ptr_t m_contents;
69 
70     /** Conversions from virtual_pointer_t to
71      * void * should just reinterpret_cast the integer number
72      */
73     operator void *() const { return reinterpret_cast<void *>(m_contents); }
74 
75     /**
76      * Convert back to the integer number.
77      */
base_ptr_tvirtual_pointer_t78     operator base_ptr_t() const { return m_contents; }
79 
80     /**
81      * Add a certain value to the pointer to create a
82      * new pointer to that offset
83      */
84     virtual_pointer_t operator+(size_t off) { return m_contents + off; }
85 
86     /* Numerical order for sorting pointers in containers. */
87     bool operator<(virtual_pointer_t rhs) const {
88       return (static_cast<base_ptr_t>(m_contents) <
89               static_cast<base_ptr_t>(rhs.m_contents));
90     }
91 
92     bool operator>(virtual_pointer_t rhs) const {
93       return (static_cast<base_ptr_t>(m_contents) >
94               static_cast<base_ptr_t>(rhs.m_contents));
95     }
96 
97     /**
98      * Numerical order for sorting pointers in containers
99      */
100     bool operator==(virtual_pointer_t rhs) const {
101       return (static_cast<base_ptr_t>(m_contents) ==
102               static_cast<base_ptr_t>(rhs.m_contents));
103     }
104 
105     /**
106      * Simple forward to the equality overload.
107      */
108     bool operator!=(virtual_pointer_t rhs) const {
109       return !(this->operator==(rhs));
110     }
111 
112     /**
113      * Converts a void * into a virtual pointer structure.
114      * Note that this will only work if the void * was
115      * already a virtual_pointer_t, but we have no way of
116      * checking
117      */
virtual_pointer_tvirtual_pointer_t118     virtual_pointer_t(const void *ptr)
119         : m_contents(reinterpret_cast<base_ptr_t>(ptr)){};
120 
121     /**
122      * Creates a virtual_pointer_t from the given integer
123      * number
124      */
virtual_pointer_tvirtual_pointer_t125     virtual_pointer_t(base_ptr_t u) : m_contents(u){};
126   };
127 
128   /* Definition of a null pointer
129    */
130   const virtual_pointer_t null_virtual_ptr = nullptr;
131 
132   /**
133    * Whether if a pointer is null or not.
134    * A pointer is nullptr if the value is of null_virtual_ptr
135    */
is_nullptr(virtual_pointer_t ptr)136   static inline bool is_nullptr(virtual_pointer_t ptr) {
137     return (static_cast<void *>(ptr) == nullptr);
138   }
139 
140   /* basic type for all buffers
141    */
142   using buffer_t = cl::sycl::buffer_mem;
143 
144   /**
145    * Node that stores information about a device allocation.
146    * Nodes are sorted by size to organise a free list of nodes
147    * that can be recovered.
148    */
149   struct pMapNode_t {
150     buffer_t m_buffer;
151     size_t m_size;
152     bool m_free;
153 
pMapNode_tpMapNode_t154     pMapNode_t(buffer_t b, size_t size, bool f)
155         : m_buffer{b}, m_size{size}, m_free{f} {
156       m_buffer.set_final_data(nullptr);
157     }
158 
159     bool operator<=(const pMapNode_t &rhs) { return (m_size <= rhs.m_size); }
160   };
161 
162   /** Storage of the pointer / buffer tree
163    */
164   using pointerMap_t = std::map<virtual_pointer_t, pMapNode_t>;
165 
166   /**
167    * Obtain the insertion point in the pointer map for
168    * a pointer of the given size.
169    * \param requiredSize Size attemted to reclaim
170    */
get_insertion_point(size_t requiredSize)171   typename pointerMap_t::iterator get_insertion_point(size_t requiredSize) {
172     typename pointerMap_t::iterator retVal;
173     bool reuse = false;
174     if (!m_freeList.empty()) {
175       // try to re-use an existing block
176       for (auto freeElem : m_freeList) {
177         if (freeElem->second.m_size >= requiredSize) {
178           retVal = freeElem;
179           reuse = true;
180           // Element is not going to be free anymore
181           m_freeList.erase(freeElem);
182           break;
183         }
184       }
185     }
186     if (!reuse) {
187       retVal = std::prev(m_pointerMap.end());
188     }
189     return retVal;
190   }
191 
192   /**
193    * Returns an iterator to the node that stores the information
194    * of the given virtual pointer from the given pointer map structure.
195    * If pointer is not found, throws std::out_of_range.
196    * If the pointer map structure is empty, throws std::out_of_range
197    *
198    * \param pMap the pointerMap_t structure storing all the pointers
199    * \param virtual_pointer_ptr The virtual pointer to obtain the node of
200    * \throws std::out:of_range if the pointer is not found or pMap is empty
201    */
get_node(const virtual_pointer_t ptr)202   typename pointerMap_t::iterator get_node(const virtual_pointer_t ptr) {
203     if (this->count() == 0) {
204       m_pointerMap.clear();
205       EIGEN_THROW_X(std::out_of_range("There are no pointers allocated\n"));
206 
207     }
208     if (is_nullptr(ptr)) {
209       m_pointerMap.clear();
210       EIGEN_THROW_X(std::out_of_range("Cannot access null pointer\n"));
211     }
212     // The previous element to the lower bound is the node that
213     // holds this memory address
214     auto node = m_pointerMap.lower_bound(ptr);
215     // If the value of the pointer is not the one of the node
216     // then we return the previous one
217     if (node == std::end(m_pointerMap)) {
218       --node;
219     } else if (node->first != ptr) {
220       if (node == std::begin(m_pointerMap)) {
221         m_pointerMap.clear();
222         EIGEN_THROW_X(
223             std::out_of_range("The pointer is not registered in the map\n"));
224 
225       }
226       --node;
227     }
228 
229     return node;
230   }
231 
232   /* get_buffer.
233    * Returns a buffer from the map using the pointer address
234    */
235   template <typename buffer_data_type = buffer_data_type_t>
get_buffer(const virtual_pointer_t ptr)236   cl::sycl::buffer<buffer_data_type, 1> get_buffer(
237       const virtual_pointer_t ptr) {
238     using sycl_buffer_t = cl::sycl::buffer<buffer_data_type, 1>;
239 
240     // get_node() returns a `buffer_mem`, so we need to cast it to a `buffer<>`.
241     // We can do this without the `buffer_mem` being a pointer, as we
242     // only declare member variables in the base class (`buffer_mem`) and not in
243     // the child class (`buffer<>).
244     auto node = get_node(ptr);
245     eigen_assert(node->first == ptr || node->first < ptr);
246     eigen_assert(ptr < static_cast<virtual_pointer_t>(node->second.m_size +
247                                                       node->first));
248     return *(static_cast<sycl_buffer_t *>(&node->second.m_buffer));
249   }
250 
251   /**
252    * @brief Returns an accessor to the buffer of the given virtual pointer
253    * @param accessMode
254    * @param accessTarget
255    * @param ptr The virtual pointer
256    */
257   template <sycl_acc_mode access_mode = default_acc_mode,
258             sycl_acc_target access_target = default_acc_target,
259             typename buffer_data_type = buffer_data_type_t>
260   cl::sycl::accessor<buffer_data_type, 1, access_mode, access_target>
get_access(const virtual_pointer_t ptr)261   get_access(const virtual_pointer_t ptr) {
262     auto buf = get_buffer<buffer_data_type>(ptr);
263     return buf.template get_access<access_mode, access_target>();
264   }
265 
266   /**
267    * @brief Returns an accessor to the buffer of the given virtual pointer
268    *        in the given command group scope
269    * @param accessMode
270    * @param accessTarget
271    * @param ptr The virtual pointer
272    * @param cgh Reference to the command group scope
273    */
274   template <sycl_acc_mode access_mode = default_acc_mode,
275             sycl_acc_target access_target = default_acc_target,
276             typename buffer_data_type = buffer_data_type_t>
277   cl::sycl::accessor<buffer_data_type, 1, access_mode, access_target>
get_access(const virtual_pointer_t ptr,cl::sycl::handler & cgh)278   get_access(const virtual_pointer_t ptr, cl::sycl::handler &cgh) {
279     auto buf = get_buffer<buffer_data_type>(ptr);
280     return buf.template get_access<access_mode, access_target>(cgh);
281   }
282 
283   /*
284    * Returns the offset from the base address of this pointer.
285    */
get_offset(const virtual_pointer_t ptr)286   inline std::ptrdiff_t get_offset(const virtual_pointer_t ptr) {
287     // The previous element to the lower bound is the node that
288     // holds this memory address
289     auto node = get_node(ptr);
290     auto start = node->first;
291     eigen_assert(start == ptr || start < ptr);
292     eigen_assert(ptr < start + node->second.m_size);
293     return (ptr - start);
294   }
295 
296   /*
297    * Returns the number of elements by which the given pointer is offset from
298    * the base address.
299    */
300   template <typename buffer_data_type>
get_element_offset(const virtual_pointer_t ptr)301   inline size_t get_element_offset(const virtual_pointer_t ptr) {
302     return get_offset(ptr) / sizeof(buffer_data_type);
303   }
304 
305   /**
306    * Constructs the PointerMapper structure.
307    */
308   PointerMapper(base_ptr_t baseAddress = 4096)
309       : m_pointerMap{}, m_freeList{}, m_baseAddress{baseAddress} {
310     if (m_baseAddress == 0) {
311       EIGEN_THROW_X(std::invalid_argument("Base address cannot be zero\n"));
312     }
313   };
314 
315   /**
316    * PointerMapper cannot be copied or moved
317    */
318   PointerMapper(const PointerMapper &) = delete;
319 
320   /**
321    * Empty the pointer list
322    */
clear()323   inline void clear() {
324     m_freeList.clear();
325     m_pointerMap.clear();
326   }
327 
328   /* add_pointer.
329    * Adds an existing pointer to the map and returns the virtual pointer id.
330    */
add_pointer(const buffer_t & b)331   inline virtual_pointer_t add_pointer(const buffer_t &b) {
332     return add_pointer_impl(b);
333   }
334 
335   /* add_pointer.
336    * Adds a pointer to the map and returns the virtual pointer id.
337    */
add_pointer(buffer_t && b)338   inline virtual_pointer_t add_pointer(buffer_t &&b) {
339     return add_pointer_impl(b);
340   }
341 
342   /**
343    * @brief Fuses the given node with the previous nodes in the
344    *        pointer map if they are free
345    *
346    * @param node A reference to the free node to be fused
347    */
fuse_forward(typename pointerMap_t::iterator & node)348   void fuse_forward(typename pointerMap_t::iterator &node) {
349     while (node != std::prev(m_pointerMap.end())) {
350       // if following node is free
351       // remove it and extend the current node with its size
352       auto fwd_node = std::next(node);
353       if (!fwd_node->second.m_free) {
354         break;
355       }
356       auto fwd_size = fwd_node->second.m_size;
357       m_freeList.erase(fwd_node);
358       m_pointerMap.erase(fwd_node);
359 
360       node->second.m_size += fwd_size;
361     }
362   }
363 
364   /**
365    * @brief Fuses the given node with the following nodes in the
366    *        pointer map if they are free
367    *
368    * @param node A reference to the free node to be fused
369    */
fuse_backward(typename pointerMap_t::iterator & node)370   void fuse_backward(typename pointerMap_t::iterator &node) {
371     while (node != m_pointerMap.begin()) {
372       // if previous node is free, extend it
373       // with the size of the current one
374       auto prev_node = std::prev(node);
375       if (!prev_node->second.m_free) {
376         break;
377       }
378       prev_node->second.m_size += node->second.m_size;
379 
380       // remove the current node
381       m_freeList.erase(node);
382       m_pointerMap.erase(node);
383 
384       // point to the previous node
385       node = prev_node;
386     }
387   }
388 
389   /* remove_pointer.
390    * Removes the given pointer from the map.
391    * The pointer is allowed to be reused only if ReUse if true.
392    */
393   template <bool ReUse = true>
remove_pointer(const virtual_pointer_t ptr)394   void remove_pointer(const virtual_pointer_t ptr) {
395     if (is_nullptr(ptr)) {
396       return;
397     }
398     auto node = this->get_node(ptr);
399 
400     node->second.m_free = true;
401     m_freeList.emplace(node);
402 
403     // Fuse the node
404     // with free nodes before and after it
405     fuse_forward(node);
406     fuse_backward(node);
407 
408     // If after fusing the node is the last one
409     // simply remove it (since it is free)
410     if (node == std::prev(m_pointerMap.end())) {
411       m_freeList.erase(node);
412       m_pointerMap.erase(node);
413     }
414   }
415 
416   /* count.
417    * Return the number of active pointers (i.e, pointers that
418    * have been malloc but not freed).
419    */
count()420   size_t count() const { return (m_pointerMap.size() - m_freeList.size()); }
421 
422  private:
423   /* add_pointer_impl.
424    * Adds a pointer to the map and returns the virtual pointer id.
425    * BufferT is either a const buffer_t& or a buffer_t&&.
426    */
427   template <class BufferT>
add_pointer_impl(BufferT b)428   virtual_pointer_t add_pointer_impl(BufferT b) {
429     virtual_pointer_t retVal = nullptr;
430     size_t bufSize = b.get_count();
431     pMapNode_t p{b, bufSize, false};
432     // If this is the first pointer:
433     if (m_pointerMap.empty()) {
434       virtual_pointer_t initialVal{m_baseAddress};
435       m_pointerMap.emplace(initialVal, p);
436       return initialVal;
437     }
438 
439     auto lastElemIter = get_insertion_point(bufSize);
440     // We are recovering an existing free node
441     if (lastElemIter->second.m_free) {
442       lastElemIter->second.m_buffer = b;
443       lastElemIter->second.m_free = false;
444 
445       // If the recovered node is bigger than the inserted one
446       // add a new free node with the remaining space
447       if (lastElemIter->second.m_size > bufSize) {
448         // create a new node with the remaining space
449         auto remainingSize = lastElemIter->second.m_size - bufSize;
450         pMapNode_t p2{b, remainingSize, true};
451 
452         // update size of the current node
453         lastElemIter->second.m_size = bufSize;
454 
455         // add the new free node
456         auto newFreePtr = lastElemIter->first + bufSize;
457         auto freeNode = m_pointerMap.emplace(newFreePtr, p2).first;
458         m_freeList.emplace(freeNode);
459       }
460 
461       retVal = lastElemIter->first;
462     } else {
463       size_t lastSize = lastElemIter->second.m_size;
464       retVal = lastElemIter->first + lastSize;
465       m_pointerMap.emplace(retVal, p);
466     }
467     return retVal;
468   }
469 
470   /**
471    * Compare two iterators to pointer map entries according to
472    * the size of the allocation on the device.
473    */
474   struct SortBySize {
operatorSortBySize475     bool operator()(typename pointerMap_t::iterator a,
476                     typename pointerMap_t::iterator b) const {
477       return ((a->first < b->first) && (a->second <= b->second)) ||
478              ((a->first < b->first) && (b->second <= a->second));
479     }
480   };
481 
482   /* Maps the pointer addresses to buffer and size pairs.
483    */
484   pointerMap_t m_pointerMap;
485 
486   /* List of free nodes available for re-using
487    */
488   std::set<typename pointerMap_t::iterator, SortBySize> m_freeList;
489 
490   /* Base address used when issuing the first virtual pointer, allows users
491    * to specify alignment. Cannot be zero. */
492   std::intptr_t m_baseAddress;
493 };
494 
495 /* remove_pointer.
496  * Removes the given pointer from the map.
497  * The pointer is allowed to be reused only if ReUse if true.
498  */
499 template <>
500 inline void PointerMapper::remove_pointer<false>(const virtual_pointer_t ptr) {
501   if (is_nullptr(ptr)) {
502     return;
503   }
504   m_pointerMap.erase(this->get_node(ptr));
505 }
506 
507 /**
508  * Malloc-like interface to the pointer-mapper.
509  * Given a size, creates a byte-typed buffer and returns a
510  * fake pointer to keep track of it.
511  * \param size Size in bytes of the desired allocation
512  * \throw cl::sycl::exception if error while creating the buffer
513  */
SYCLmalloc(size_t size,PointerMapper & pMap)514 inline void *SYCLmalloc(size_t size, PointerMapper &pMap) {
515   if (size == 0) {
516     return nullptr;
517   }
518   // Create a generic buffer of the given size
519   using buffer_t = cl::sycl::buffer<buffer_data_type_t, 1>;
520   auto thePointer = pMap.add_pointer(buffer_t(cl::sycl::range<1>{size}));
521   // Store the buffer on the global list
522   return static_cast<void *>(thePointer);
523 }
524 
/**
 * Free-like interface to the pointer mapper.
 * Releases a fake pointer previously obtained from SYCLmalloc by forwarding
 * it to the mapper's remove_pointer<ReUse>().
 * \tparam ReUse when false the pointer's range is not recycled through the
 *         free list; it should be false only for sub-buffers.
 * \param ptr  fake pointer to release (null is ignored by the mapper)
 * \param pMap the mapper that issued ptr
 */
template <bool ReUse = true, typename Mapper>
inline void SYCLfree(void *ptr, Mapper &pMap) {
  // `template` keyword: remove_pointer is a member template of a dependent
  // type.
  pMap.template remove_pointer<ReUse>(ptr);
}
536 
/**
 * Drops every allocation tracked by the given mapper by calling its clear().
 * \param pMap the mapper to empty
 */
template <typename Mapper>
inline void SYCLfreeAll(Mapper &pMap) {
  pMap.clear();
}
544 
545 template <cl::sycl::access::mode AcMd, typename T>
546 struct RangeAccess {
547   static const auto global_access = cl::sycl::access::target::global_buffer;
548   static const auto is_place_holder = cl::sycl::access::placeholder::true_t;
549   typedef T scalar_t;
550   typedef scalar_t &ref_t;
551   typedef typename cl::sycl::global_ptr<scalar_t>::pointer_t ptr_t;
552 
553   // the accessor type does not necessarily the same as T
554   typedef cl::sycl::accessor<scalar_t, 1, AcMd, global_access, is_place_holder>
555       accessor;
556 
557   typedef RangeAccess<AcMd, T> self_t;
RangeAccessRangeAccess558   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE RangeAccess(accessor access,
559                                                     size_t offset,
560                                                     std::intptr_t virtual_ptr)
561       : access_(access), offset_(offset), virtual_ptr_(virtual_ptr) {}
562 
563   RangeAccess(cl::sycl::buffer<scalar_t, 1> buff =
564                   cl::sycl::buffer<scalar_t, 1>(cl::sycl::range<1>(1)))
565       : access_{accessor{buff}}, offset_(0), virtual_ptr_(-1) {}
566 
567   // This should be only used for null constructor on the host side
RangeAccessRangeAccess568   RangeAccess(std::nullptr_t) : RangeAccess() {}
569   // This template parameter must be removed and scalar_t should be replaced
get_pointerRangeAccess570   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE ptr_t get_pointer() const {
571     return (access_.get_pointer().get() + offset_);
572   }
573   template <typename Index>
574   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE self_t &operator+=(Index offset) {
575     offset_ += (offset);
576     return *this;
577   }
578   template <typename Index>
579   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE self_t operator+(Index offset) const {
580     return self_t(access_, offset_ + offset, virtual_ptr_);
581   }
582   template <typename Index>
583   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE self_t operator-(Index offset) const {
584     return self_t(access_, offset_ - offset, virtual_ptr_);
585   }
586   template <typename Index>
587   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE self_t &operator-=(Index offset) {
588     offset_ -= offset;
589     return *this;
590   }
591 
592   // THIS IS FOR NULL COMPARISON ONLY
593   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE friend bool operator==(
594       const RangeAccess &lhs, std::nullptr_t) {
595     return ((lhs.virtual_ptr_ == -1));
596   }
597   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE friend bool operator!=(
598       const RangeAccess &lhs, std::nullptr_t i) {
599     return !(lhs == i);
600   }
601 
602   // THIS IS FOR NULL COMPARISON ONLY
603   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE friend bool operator==(
604       std::nullptr_t, const RangeAccess &rhs) {
605     return ((rhs.virtual_ptr_ == -1));
606   }
607   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE friend bool operator!=(
608       std::nullptr_t i, const RangeAccess &rhs) {
609     return !(i == rhs);
610   }
611   // Prefix operator (Increment and return value)
612   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE self_t &operator++() {
613     offset_++;
614     return (*this);
615   }
616 
617   // Postfix operator (Return value and increment)
618   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE self_t operator++(int i) {
619     EIGEN_UNUSED_VARIABLE(i);
620     self_t temp_iterator(*this);
621     offset_++;
622     return temp_iterator;
623   }
624 
get_sizeRangeAccess625   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::ptrdiff_t get_size() const {
626     return (access_.get_count() - offset_);
627   }
628 
get_offsetRangeAccess629   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::ptrdiff_t get_offset() const {
630     return offset_;
631   }
632 
set_offsetRangeAccess633   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void set_offset(std::ptrdiff_t offset) {
634     offset_ = offset;
635   }
636 
637   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE ref_t operator*() const {
638     return *get_pointer();
639   }
640 
641   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE ref_t operator*() {
642     return *get_pointer();
643   }
644 
645   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE ptr_t operator->() = delete;
646 
647   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE ref_t operator[](int x) {
648     return *(get_pointer() + x);
649   }
650 
651   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE ref_t operator[](int x) const {
652     return *(get_pointer() + x);
653   }
654 
get_virtual_pointerRangeAccess655   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE scalar_t *get_virtual_pointer() const {
656     return reinterpret_cast<scalar_t *>(virtual_ptr_ +
657                                         (offset_ * sizeof(scalar_t)));
658   }
659 
660   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE explicit operator bool() const {
661     return (virtual_ptr_ != -1);
662   }
663 
RangeAccessRangeAccess664   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE operator RangeAccess<AcMd, const T>() {
665     return RangeAccess<AcMd, const T>(access_, offset_, virtual_ptr_);
666   }
667 
668   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
RangeAccessRangeAccess669   operator RangeAccess<AcMd, const T>() const {
670     return RangeAccess<AcMd, const T>(access_, offset_, virtual_ptr_);
671   }
672   // binding placeholder accessors to a command group handler for SYCL
bindRangeAccess673   EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void bind(
674       cl::sycl::handler &cgh) const {
675     cgh.require(access_);
676   }
677 
678  private:
679   accessor access_;
680   size_t offset_;
681   std::intptr_t virtual_ptr_;  // the location of the buffer in the map
682 };
683 
684 template <cl::sycl::access::mode AcMd, typename T>
685 struct RangeAccess<AcMd, const T> : RangeAccess<AcMd, T> {
686   typedef RangeAccess<AcMd, T> Base;
687   using Base::Base;
688 };
689 
690 }  // namespace internal
691 }  // namespace TensorSycl
692 }  // namespace Eigen
693 
694 #endif  // EIGEN_CXX11_TENSOR_TENSOR_SYCL_STORAGE_MEMORY_H
695