1 /**
2  * This file has no copyright assigned and is placed in the Public Domain.
3  * This file is part of the mingw-w64 runtime package.
4  * No warranty is given; refer to the file DISCLAIMER.PD within this package.
5  */
6 
7 #ifndef _WRL_CLIENT_H_
8 #define _WRL_CLIENT_H_
9 
10 #include <stddef.h>
11 #include <unknwn.h>
12 /* #include <weakreference.h> */
13 #include <roapi.h>
14 
15 /* #include <wrl/def.h> */
16 #include <wrl/internal.h>
17 
18 namespace Microsoft {
19     namespace WRL {
20         namespace Details {
21             template <typename T> class ComPtrRefBase {
22             protected:
23                 T* ptr_;
24 
25             public:
26                 typedef typename T::InterfaceType InterfaceType;
27 
28 #ifndef __WRL_CLASSIC_COM__
throw()29                 operator IInspectable**() const throw()  {
30                     static_assert(__is_base_of(IInspectable, InterfaceType), "Invalid cast");
31                     return reinterpret_cast<IInspectable**>(ptr_->ReleaseAndGetAddressOf());
32                 }
33 #endif
34 
throw()35                 operator IUnknown**() const throw() {
36                     static_assert(__is_base_of(IUnknown, InterfaceType), "Invalid cast");
37                     return reinterpret_cast<IUnknown**>(ptr_->ReleaseAndGetAddressOf());
38                 }
39             };
40 
41             template <typename T> class ComPtrRef : public Details::ComPtrRefBase<T> {
42             public:
ComPtrRef(T * ptr)43                 ComPtrRef(T *ptr) throw() {
44                     ComPtrRefBase<T>::ptr_ = ptr;
45                 }
46 
throw()47                 operator void**() const throw() {
48                     return reinterpret_cast<void**>(ComPtrRefBase<T>::ptr_->ReleaseAndGetAddressOf());
49                 }
50 
throw()51                 operator T*() throw() {
52                     *ComPtrRefBase<T>::ptr_ = nullptr;
53                     return ComPtrRefBase<T>::ptr_;
54                 }
55 
throw()56                 operator typename ComPtrRefBase<T>::InterfaceType**() throw() {
57                     return ComPtrRefBase<T>::ptr_->ReleaseAndGetAddressOf();
58                 }
59 
throw()60                 typename ComPtrRefBase<T>::InterfaceType *operator*() throw() {
61                     return ComPtrRefBase<T>::ptr_->Get();
62                 }
63 
GetAddressOf()64                 typename ComPtrRefBase<T>::InterfaceType *const *GetAddressOf() const throw() {
65                     return ComPtrRefBase<T>::ptr_->GetAddressOf();
66                 }
67 
ReleaseAndGetAddressOf()68                 typename ComPtrRefBase<T>::InterfaceType **ReleaseAndGetAddressOf() throw() {
69                     return ComPtrRefBase<T>::ptr_->ReleaseAndGetAddressOf();
70                 }
71             };
72 
73         }
74 
75         template<typename T> class ComPtr {
76         public:
77             typedef T InterfaceType;
78 
throw()79             ComPtr() throw() : ptr_(nullptr) {}
throw()80             ComPtr(decltype(nullptr)) throw() : ptr_(nullptr) {}
81 
ComPtr(U * other)82             template<class U> ComPtr(U *other) throw() : ptr_(other) {
83                 InternalAddRef();
84             }
85 
throw()86             ComPtr(const ComPtr &other) throw() : ptr_(other.ptr_) {
87                 InternalAddRef();
88             }
89 
90             template<class U>
ComPtr(const ComPtr<U> & other)91             ComPtr(const ComPtr<U> &other) throw() : ptr_(other.Get()) {
92                 InternalAddRef();
93             }
94 
throw()95             ComPtr(ComPtr &&other) throw() : ptr_(nullptr) {
96                 if(this != reinterpret_cast<ComPtr*>(&reinterpret_cast<unsigned char&>(other)))
97                     Swap(other);
98             }
99 
100             template<class U>
ComPtr(ComPtr<U> && other)101             ComPtr(ComPtr<U>&& other) throw() : ptr_(other.Detach()) {}
102 
throw()103             ~ComPtr() throw() {
104                 InternalRelease();
105             }
106 
decltype(nullptr)107             ComPtr &operator=(decltype(nullptr)) throw() {
108                 InternalRelease();
109                 return *this;
110             }
111 
throw()112             ComPtr &operator=(InterfaceType *other) throw() {
113                 if (ptr_ != other) {
114                     InternalRelease();
115                     ptr_ = other;
116                     InternalAddRef();
117                 }
118                 return *this;
119             }
120 
121             template<typename U>
throw()122             ComPtr &operator=(U *other) throw()  {
123                 if (ptr_ != other) {
124                     InternalRelease();
125                     ptr_ = other;
126                     InternalAddRef();
127                 }
128                 return *this;
129             }
130 
throw()131             ComPtr& operator=(const ComPtr &other) throw() {
132                 if (ptr_ != other.ptr_)
133                     ComPtr(other).Swap(*this);
134                 return *this;
135             }
136 
137             template<class U>
throw()138             ComPtr &operator=(const ComPtr<U> &other) throw() {
139                 ComPtr(other).Swap(*this);
140                 return *this;
141             }
142 
throw()143             ComPtr& operator=(ComPtr &&other) throw() {
144                 ComPtr(other).Swap(*this);
145                 return *this;
146             }
147 
148             template<class U>
throw()149             ComPtr& operator=(ComPtr<U> &&other) throw() {
150                 ComPtr(other).Swap(*this);
151                 return *this;
152             }
153 
Swap(ComPtr && r)154             void Swap(ComPtr &&r) throw() {
155                 InterfaceType *tmp = ptr_;
156                 ptr_ = r.ptr_;
157                 r.ptr_ = tmp;
158             }
159 
Swap(ComPtr & r)160             void Swap(ComPtr &r) throw() {
161                 InterfaceType *tmp = ptr_;
162                 ptr_ = r.ptr_;
163                 r.ptr_ = tmp;
164             }
165 
BoolType()166             operator Details::BoolType() const throw() {
167                 return Get() != nullptr ? &Details::BoolStruct::Member : nullptr;
168             }
169 
Get()170             InterfaceType *Get() const throw()  {
171                 return ptr_;
172             }
173 
174             InterfaceType *operator->() const throw() {
175                 return ptr_;
176             }
177 
throw()178             Details::ComPtrRef<ComPtr<T>> operator&() throw()  {
179                 return Details::ComPtrRef<ComPtr<T>>(this);
180             }
181 
throw()182             const Details::ComPtrRef<const ComPtr<T>> operator&() const throw() {
183                 return Details::ComPtrRef<const ComPtr<T>>(this);
184             }
185 
GetAddressOf()186             InterfaceType *const *GetAddressOf() const throw() {
187                 return &ptr_;
188             }
189 
GetAddressOf()190             InterfaceType **GetAddressOf() throw() {
191                 return &ptr_;
192             }
193 
ReleaseAndGetAddressOf()194             InterfaceType **ReleaseAndGetAddressOf() throw() {
195                 InternalRelease();
196                 return &ptr_;
197             }
198 
Detach()199             InterfaceType *Detach() throw() {
200                 T* ptr = ptr_;
201                 ptr_ = nullptr;
202                 return ptr;
203             }
204 
Attach(InterfaceType * other)205             void Attach(InterfaceType *other) throw() {
206                 if (ptr_ != other) {
207                     InternalRelease();
208                     ptr_ = other;
209                     InternalAddRef();
210                 }
211             }
212 
Reset()213             unsigned long Reset() {
214                 return InternalRelease();
215             }
216 
CopyTo(InterfaceType ** ptr)217             HRESULT CopyTo(InterfaceType **ptr) const throw() {
218                 InternalAddRef();
219                 *ptr = ptr_;
220                 return S_OK;
221             }
222 
CopyTo(REFIID riid,void ** ptr)223             HRESULT CopyTo(REFIID riid, void **ptr) const throw() {
224                 return ptr_->QueryInterface(riid, ptr);
225             }
226 
227             template<typename U>
CopyTo(U ** ptr)228             HRESULT CopyTo(U **ptr) const throw() {
229                 return ptr_->QueryInterface(__uuidof(U), reinterpret_cast<void**>(ptr));
230             }
231 
232             template<typename U>
As(Details::ComPtrRef<ComPtr<U>> p)233             HRESULT As(Details::ComPtrRef<ComPtr<U>> p) const throw() {
234                 return ptr_->QueryInterface(__uuidof(U), p);
235             }
236 
237             template<typename U>
As(ComPtr<U> * p)238             HRESULT As(ComPtr<U> *p) const throw() {
239                 return ptr_->QueryInterface(__uuidof(U), reinterpret_cast<void**>(p->ReleaseAndGetAddressOf()));
240             }
241 
AsIID(REFIID riid,ComPtr<IUnknown> * p)242             HRESULT AsIID(REFIID riid, ComPtr<IUnknown> *p) const throw() {
243                 return ptr_->QueryInterface(riid, reinterpret_cast<void**>(p->ReleaseAndGetAddressOf()));
244             }
245 
246             /*
247             HRESULT AsWeak(WeakRef *pWeakRef) const throw() {
248                 return ::Microsoft::WRL::AsWeak(ptr_, pWeakRef);
249             }
250             */
251         protected:
252             InterfaceType *ptr_;
253 
InternalAddRef()254             void InternalAddRef() const throw() {
255                 if(ptr_)
256                     ptr_->AddRef();
257             }
258 
InternalRelease()259             unsigned long InternalRelease() throw() {
260                 InterfaceType *tmp = ptr_;
261                 if(!tmp)
262                     return 0;
263                 ptr_ = nullptr;
264                 return tmp->Release();
265             }
266         };
267     }
268 }
269 
270 template<typename T>
IID_PPV_ARGS_Helper(::Microsoft::WRL::Details::ComPtrRef<T> pp)271 void **IID_PPV_ARGS_Helper(::Microsoft::WRL::Details::ComPtrRef<T> pp) throw() {
272     static_assert(__is_base_of(IUnknown, typename T::InterfaceType), "Expected COM interface");
273     return pp;
274 }
275 
276 namespace Windows {
277     namespace Foundation {
278         template<typename T>
ActivateInstance(HSTRING classid,::Microsoft::WRL::Details::ComPtrRef<T> instance)279         inline HRESULT ActivateInstance(HSTRING classid, ::Microsoft::WRL::Details::ComPtrRef<T> instance) throw() {
280             return ActivateInstance(classid, instance.ReleaseAndGetAddressOf());
281         }
282 
283         template<typename T>
GetActivationFactory(HSTRING classid,::Microsoft::WRL::Details::ComPtrRef<T> factory)284         inline HRESULT GetActivationFactory(HSTRING classid, ::Microsoft::WRL::Details::ComPtrRef<T> factory) throw() {
285             return RoGetActivationFactory(classid, IID_INS_ARGS(factory.ReleaseAndGetAddressOf()));
286         }
287     }
288 }
289 
290 namespace ABI {
291     namespace Windows {
292         namespace Foundation {
293             template<typename T>
ActivateInstance(HSTRING classid,::Microsoft::WRL::Details::ComPtrRef<T> instance)294             inline HRESULT ActivateInstance(HSTRING classid, ::Microsoft::WRL::Details::ComPtrRef<T> instance) throw() {
295                 return ActivateInstance(classid, instance.ReleaseAndGetAddressOf());
296             }
297 
298             template<typename T>
GetActivationFactory(HSTRING classid,::Microsoft::WRL::Details::ComPtrRef<T> factory)299             inline HRESULT GetActivationFactory(HSTRING classid, ::Microsoft::WRL::Details::ComPtrRef<T> factory) throw() {
300                 return RoGetActivationFactory(classid, IID_INS_ARGS(factory.ReleaseAndGetAddressOf()));
301             }
302         }
303     }
304 }
305 
306 #endif
307