xref: /aosp_15_r20/external/pytorch/aten/src/ATen/core/NamedTensor.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <ATen/core/Dimname.h>
4 #include <c10/core/TensorImpl.h>
5 
6 namespace at {
7 
8 class TensorBase;
9 
10 // XXX: This file exists because TensorImpl is in c10, but Dimname is in ATen.
11 // Due to the c10/ATen library split, TensorImpl cannot depend on Dimname,
12 // so we have a couple of workarounds.
13 //
14 // In the long term, we'll move Dimname to c10 and everything in this file
15 // can be refactored out. The main blocker for that is that "c10::Symbol"
16 // actually exists outside of c10 and needs to be moved in.
17 
18 // TensorImpl has a unique_ptr<NamedTensorMetaInterface> field.
19 // XXX: Ideally we would just put std::optional<vector<Dimname>> into TensorImpl.
20 //
// This class has an important invariant: there must be at least ONE
// non-wildcard Dimname in the list of names it stores.
23 struct TORCH_API NamedTensorMeta final : public c10::NamedTensorMetaInterface {
24   // This enum is to remind people that the invariant on constructors is that
25   // the list of dimnames must have at least one non-wildcard
26   enum HAS_NON_WILDCARD {
27     HasNonWildcard
28   };
29 
NamedTensorMetafinal30   explicit NamedTensorMeta(HAS_NON_WILDCARD, DimnameList names)
31     : names_(names.vec()) {
32     check_invariants();
33   }
NamedTensorMetafinal34   explicit NamedTensorMeta(HAS_NON_WILDCARD, std::vector<Dimname>&& names)
35     : names_(std::move(names)) {
36     check_invariants();
37   }
38 
clonefinal39   std::unique_ptr<c10::NamedTensorMetaInterface> clone() const override {
40     return std::make_unique<NamedTensorMeta>(HasNonWildcard, names_);
41   }
42 
namesfinal43   DimnameList names() const { return names_; }
44 
45   // Used for an assertion in TensorImpl.h
slow_dimfinal46   int64_t slow_dim() const override {
47     return static_cast<int64_t>(names_.size());
48   }
49 
check_invariantsfinal50   void check_invariants() const {
51     TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
52       std::any_of(names_.begin(), names_.end(), [](const Dimname& n) { return !n.isWildcard(); }));
53   }
54 
set_namesfinal55   void set_names(HAS_NON_WILDCARD, DimnameList new_names) {
56     TORCH_INTERNAL_ASSERT(new_names.size() == names_.size());
57     std::copy(new_names.begin(), new_names.end(), names_.begin());
58     check_invariants();
59   }
60 
set_namesfinal61   void set_names(HAS_NON_WILDCARD, std::vector<Dimname>&& new_names) {
62     TORCH_INTERNAL_ASSERT(new_names.size() == names_.size());
63     names_ = std::move(new_names);
64     check_invariants();
65   }
66 
67   // INVARIANT: at least one Dimname is non-WILDCARD
68   std::vector<Dimname> names_;
69 };
70 
71 // When NamesMode is disabled, then all operations ignore tensors' names fields.
72 // Concretely speaking, all tensors are treated as having nullopt names.
73 struct TORCH_API NamesMode {
74   static bool is_enabled();
75   static void set_enabled(bool enabled);
76 };
77 
78 
// A RAII, thread local (!) guard that disables names upon construction,
// and sets the mode back to its original value upon destruction.
81 struct TORCH_API NoNamesGuard {
NoNamesGuardNoNamesGuard82   NoNamesGuard() : prev_mode(NamesMode::is_enabled()) {
83     NamesMode::set_enabled(false);
84   }
~NoNamesGuardNoNamesGuard85   ~NoNamesGuard() {
86     if (initialized) {
87       reset();
88     }
89   }
resetNoNamesGuard90   void reset() {
91     TORCH_INTERNAL_ASSERT(initialized);
92     NamesMode::set_enabled(prev_mode);
93   }
94  private:
95   bool prev_mode;
96   bool initialized{true};
97 };
98 
99 void check_names_valid_for(const TensorBase& tensor, DimnameList names);
100 void check_names_valid_for(size_t tensor_dim, DimnameList names);
101 
102 // Sets the names of `tensor` to be `names`.
103 TORCH_API const TensorBase& internal_set_names_inplace(const TensorBase& tensor, std::optional<DimnameList> names);
104 TORCH_API const TensorBase& internal_set_names_inplace(const TensorBase& tensor, std::vector<Dimname>&& names, bool validate_names);
105 
// Upper bound on the number of dimensions for named tensors (as the name
// suggests; the code enforcing it is not in this header).
// Declared `inline` so that every translation unit including this header
// refers to one shared definition instead of getting its own internal-linkage
// copy.
inline constexpr size_t kMaxNamedTensorDim = 64;
107 
108 DimnameList default_names(size_t len);
109 
110 namespace impl {
111 
112 // Some helper functions on TensorImpl. Useful for working with names in TH.
113 // XXX: Ideally these would exist as methods on TensorImpl
114 TORCH_API void internal_set_names_inplace(TensorImpl* impl, std::optional<DimnameList> names, bool validate_names);
115 TORCH_API void internal_set_names_inplace(TensorImpl* impl, std::vector<Dimname>&& names, bool validate_names);
116 
117 void check_names_valid_for(TensorImpl* impl, DimnameList names);
118 
119 // Returns true if the tensor's names exist and are not all 'None'.
120 // Returns false if the tensor's names don't exist (were not allocated),
121 // or if all names are 'None'.
122 // We treat not-allocated-names the same as allocated names that are all 'None'.
123 TORCH_API bool has_names(const TensorImpl* impl);
124 
125 // Returns the names of the tensor's dimensions.
126 // Unnamed tensors are treated as having 'None' in all dimension; this method
127 // would return a DimnameList of all 'None's for an unnamed tensor.
128 TORCH_API DimnameList get_names(const TensorImpl* impl);
129 
130 // This is more of an implementation detail; one should use impl::get_names /
131 // Tensor::names() whenever possible because it provides a cleaner API.
132 // Returns the names of the tensor if they have been allocated; returns nullopt
133 // instead if the haven't been. The names of a tensor are not allocated if a
134 // tensor is constructed with names=None.
135 TORCH_API std::optional<DimnameList> get_opt_names(const TensorImpl* impl);
136 
137 } // namespace impl
138 
139 } // namespace at
140