1 /** 2 * Most of the utils in this file is adapted from PyTorch/XLA 3 * https://github.com/pytorch/xla/blob/master/third_party/xla_client/util.h 4 */ 5 6 #pragma once 7 8 #include <exception> 9 #include <functional> 10 #include <vector> 11 12 #include <c10/util/OptionalArrayRef.h> 13 #include <optional> 14 15 namespace torch { 16 namespace lazy { 17 18 // Similar to c10::scope_exit but with a status. 19 // TODO(alanwaketan): Consolidate it with c10::scope_exit. 20 template <typename T> 21 class Cleanup { 22 public: 23 using StatusType = T; 24 Cleanup(std::function<void (StatusType &&)> && func)25 explicit Cleanup(std::function<void(StatusType&&)>&& func) 26 : func_(std::move(func)) {} Cleanup(Cleanup && ref)27 Cleanup(Cleanup&& ref) noexcept 28 : func_(std::move(ref.func_)), status_(std::move(ref.status_)) {} 29 Cleanup(const Cleanup&) = delete; 30 ~Cleanup()31 ~Cleanup() { 32 if (func_ != nullptr) { 33 func_(std::move(status_)); 34 } 35 } 36 37 Cleanup& operator=(const Cleanup&) = delete; 38 39 Cleanup& operator=(Cleanup&& ref) noexcept { 40 if (this != &ref) { 41 func_ = std::move(ref.func_); 42 status_ = std::move(ref.status_); 43 } 44 return *this; 45 } 46 Release()47 void Release() { 48 func_ = nullptr; 49 } 50 SetStatus(StatusType && status)51 void SetStatus(StatusType&& status) { 52 status_ = std::move(status); 53 } 54 GetStatus()55 const StatusType& GetStatus() const { 56 return status_; 57 } 58 59 private: 60 std::function<void(StatusType&&)> func_; 61 StatusType status_; 62 }; 63 64 using ExceptionCleanup = Cleanup<std::exception_ptr>; 65 66 // Allows APIs which might return const references and values, to not be forced 67 // to return values in the signature. 68 // TODO(alanwaketan): This is clever, but is there really no std or c10 69 // supports? Needs more investigations. 70 template <typename T> 71 class MaybeRef { 72 public: MaybeRef(const T & ref)73 /* implicit */ MaybeRef(const T& ref) : ref_(ref) {} MaybeRef(T && value)74 /* implicit */ MaybeRef(T&& value) 75 : storage_(std::move(value)), ref_(*storage_) {} 76 Get()77 const T& Get() const { 78 return ref_; 79 } 80 const T& operator*() const { 81 return Get(); 82 } 83 operator const T&() const { 84 return Get(); 85 } 86 IsStored()87 bool IsStored() const { 88 return storage_.has_value(); 89 } 90 91 private: 92 std::optional<T> storage_; 93 const T& ref_; 94 }; 95 96 template <typename T> 97 std::vector<T> Iota(size_t size, T init = 0, T incr = 1) { 98 std::vector<T> result(size); 99 T value = init; 100 for (size_t i = 0; i < size; ++i, value += incr) { 101 result[i] = value; 102 } 103 return result; 104 } 105 106 template <typename T, typename S> ToVector(const S & input)107std::vector<T> ToVector(const S& input) { 108 return std::vector<T>(input.begin(), input.end()); 109 } 110 111 template <typename T> ToOptionalVector(c10::OptionalArrayRef<T> arrayRef)112std::optional<std::vector<T>> ToOptionalVector( 113 c10::OptionalArrayRef<T> arrayRef) { 114 if (arrayRef) { 115 return arrayRef->vec(); 116 } 117 return std::nullopt; 118 } 119 120 template <typename T> GetEnumValue(T value)121typename std::underlying_type<T>::type GetEnumValue(T value) { 122 return static_cast<typename std::underlying_type<T>::type>(value); 123 } 124 125 } // namespace lazy 126 } // namespace torch 127