xref: /aosp_15_r20/external/pytorch/torch/csrc/lazy/core/util.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 /**
 * Most of the utils in this file are adapted from PyTorch/XLA
3  * https://github.com/pytorch/xla/blob/master/third_party/xla_client/util.h
4  */
5 
6 #pragma once
7 
#include <exception>
#include <functional>
#include <optional>
#include <type_traits>
#include <utility>
#include <vector>

#include <c10/util/OptionalArrayRef.h>
14 
15 namespace torch {
16 namespace lazy {
17 
// Similar to c10::scope_exit but with a status: the guard holds a callback
// that its destructor invokes with the status most recently stored via
// SetStatus(), unless the guard was disarmed with Release().
// TODO(alanwaketan): Consolidate it with c10::scope_exit.
template <typename T>
class Cleanup {
 public:
  using StatusType = T;

  // Arms the guard with `func`. The status starts out value-initialized.
  explicit Cleanup(std::function<void(StatusType&&)>&& func)
      : func_(std::move(func)) {}

  // Takes over `ref`'s callback and status. The source is explicitly
  // disarmed: a moved-from std::function is only "valid but unspecified"
  // (e.g. libc++ can keep small-buffer targets), so relying on std::move
  // alone could make the callback run a second time from `ref`'s destructor.
  Cleanup(Cleanup&& ref) noexcept
      : func_(std::exchange(ref.func_, nullptr)),
        status_(std::move(ref.status_)) {}
  Cleanup(const Cleanup&) = delete;

  // Runs the callback, if still armed, handing it the current status.
  ~Cleanup() {
    if (func_ != nullptr) {
      func_(std::move(status_));
    }
  }

  Cleanup& operator=(const Cleanup&) = delete;

  // Takes over `ref`'s callback and status and disarms `ref`. Any callback
  // previously held by *this is discarded without being invoked.
  Cleanup& operator=(Cleanup&& ref) noexcept {
    if (this != &ref) {
      func_ = std::exchange(ref.func_, nullptr);
      status_ = std::move(ref.status_);
    }
    return *this;
  }

  // Disarms the guard: the destructor will no longer invoke the callback.
  void Release() {
    func_ = nullptr;
  }

  // Replaces the status that will be passed to the callback.
  void SetStatus(StatusType&& status) {
    status_ = std::move(status);
  }

  // Returns the currently stored status.
  const StatusType& GetStatus() const {
    return status_;
  }

 private:
  std::function<void(StatusType&&)> func_;
  StatusType status_;
};
63 
64 using ExceptionCleanup = Cleanup<std::exception_ptr>;
65 
66 // Allows APIs which might return const references and values, to not be forced
67 // to return values in the signature.
68 // TODO(alanwaketan): This is clever, but is there really no std or c10
69 // supports? Needs more investigations.
70 template <typename T>
71 class MaybeRef {
72  public:
MaybeRef(const T & ref)73   /* implicit */ MaybeRef(const T& ref) : ref_(ref) {}
MaybeRef(T && value)74   /* implicit */ MaybeRef(T&& value)
75       : storage_(std::move(value)), ref_(*storage_) {}
76 
Get()77   const T& Get() const {
78     return ref_;
79   }
80   const T& operator*() const {
81     return Get();
82   }
83   operator const T&() const {
84     return Get();
85   }
86 
IsStored()87   bool IsStored() const {
88     return storage_.has_value();
89   }
90 
91  private:
92   std::optional<T> storage_;
93   const T& ref_;
94 };
95 
// Builds a vector of `size` elements: the first is `init` and each later one
// is obtained by adding `incr` (via +=) to its predecessor.
template <typename T>
std::vector<T> Iota(size_t size, T init = 0, T incr = 1) {
  std::vector<T> sequence(size);
  T current = init;
  for (auto&& element : sequence) {
    element = current;
    current += incr;
  }
  return sequence;
}
105 
// Materializes the range [input.begin(), input.end()) into a new
// std::vector<T>, constructing each T from the corresponding element.
template <typename T, typename S>
std::vector<T> ToVector(const S& input) {
  const auto first = input.begin();
  const auto last = input.end();
  std::vector<T> converted(first, last);
  return converted;
}
110 
111 template <typename T>
ToOptionalVector(c10::OptionalArrayRef<T> arrayRef)112 std::optional<std::vector<T>> ToOptionalVector(
113     c10::OptionalArrayRef<T> arrayRef) {
114   if (arrayRef) {
115     return arrayRef->vec();
116   }
117   return std::nullopt;
118 }
119 
// Converts an enumerator to its underlying integer type. constexpr/noexcept
// so it can also be used in constant expressions (static_assert, case
// labels, array bounds), which the non-constexpr version did not allow.
template <typename T>
constexpr std::underlying_type_t<T> GetEnumValue(T value) noexcept {
  return static_cast<std::underlying_type_t<T>>(value);
}
124 
125 } // namespace lazy
126 } // namespace torch
127