1 #pragma once 2 3 #include <ATen/core/Dimname.h> 4 #include <c10/core/TensorImpl.h> 5 6 namespace at { 7 8 class TensorBase; 9 10 // XXX: This file exists because TensorImpl is in c10, but Dimname is in ATen. 11 // Due to the c10/ATen library split, TensorImpl cannot depend on Dimname, 12 // so we have a couple of workarounds. 13 // 14 // In the long term, we'll move Dimname to c10 and everything in this file 15 // can be refactored out. The main blocker for that is that "c10::Symbol" 16 // actually exists outside of c10 and needs to be moved in. 17 18 // TensorImpl has a unique_ptr<NamedTensorMetaInterface> field. 19 // XXX: Ideally we would just put std::optional<vector<Dimname>> into TensorImpl. 20 // 21 // This class has an important invariant: there must be at least ONE 22 // non-wildcard 23 struct TORCH_API NamedTensorMeta final : public c10::NamedTensorMetaInterface { 24 // This enum is to remind people that the invariant on constructors is that 25 // the list of dimnames must have at least one non-wildcard 26 enum HAS_NON_WILDCARD { 27 HasNonWildcard 28 }; 29 NamedTensorMetafinal30 explicit NamedTensorMeta(HAS_NON_WILDCARD, DimnameList names) 31 : names_(names.vec()) { 32 check_invariants(); 33 } NamedTensorMetafinal34 explicit NamedTensorMeta(HAS_NON_WILDCARD, std::vector<Dimname>&& names) 35 : names_(std::move(names)) { 36 check_invariants(); 37 } 38 clonefinal39 std::unique_ptr<c10::NamedTensorMetaInterface> clone() const override { 40 return std::make_unique<NamedTensorMeta>(HasNonWildcard, names_); 41 } 42 namesfinal43 DimnameList names() const { return names_; } 44 45 // Used for an assertion in TensorImpl.h slow_dimfinal46 int64_t slow_dim() const override { 47 return static_cast<int64_t>(names_.size()); 48 } 49 check_invariantsfinal50 void check_invariants() const { 51 TORCH_INTERNAL_ASSERT_DEBUG_ONLY( 52 std::any_of(names_.begin(), names_.end(), [](const Dimname& n) { return !n.isWildcard(); })); 53 } 54 set_namesfinal55 void set_names(HAS_NON_WILDCARD, DimnameList new_names) { 56 TORCH_INTERNAL_ASSERT(new_names.size() == names_.size()); 57 std::copy(new_names.begin(), new_names.end(), names_.begin()); 58 check_invariants(); 59 } 60 set_namesfinal61 void set_names(HAS_NON_WILDCARD, std::vector<Dimname>&& new_names) { 62 TORCH_INTERNAL_ASSERT(new_names.size() == names_.size()); 63 names_ = std::move(new_names); 64 check_invariants(); 65 } 66 67 // INVARIANT: at least one Dimname is non-WILDCARD 68 std::vector<Dimname> names_; 69 }; 70 71 // When NamesMode is disabled, then all operations ignore tensors' names fields. 72 // Concretely speaking, all tensors are treated as having nullopt names. 73 struct TORCH_API NamesMode { 74 static bool is_enabled(); 75 static void set_enabled(bool enabled); 76 }; 77 78 79 // A RAII, thread local (!) guard that enables or disables names upon 80 // construction, and sets it back to the original value upon destruction. 81 struct TORCH_API NoNamesGuard { NoNamesGuardNoNamesGuard82 NoNamesGuard() : prev_mode(NamesMode::is_enabled()) { 83 NamesMode::set_enabled(false); 84 } ~NoNamesGuardNoNamesGuard85 ~NoNamesGuard() { 86 if (initialized) { 87 reset(); 88 } 89 } resetNoNamesGuard90 void reset() { 91 TORCH_INTERNAL_ASSERT(initialized); 92 NamesMode::set_enabled(prev_mode); 93 } 94 private: 95 bool prev_mode; 96 bool initialized{true}; 97 }; 98 99 void check_names_valid_for(const TensorBase& tensor, DimnameList names); 100 void check_names_valid_for(size_t tensor_dim, DimnameList names); 101 102 // Sets the names of `tensor` to be `names`. 103 TORCH_API const TensorBase& internal_set_names_inplace(const TensorBase& tensor, std::optional<DimnameList> names); 104 TORCH_API const TensorBase& internal_set_names_inplace(const TensorBase& tensor, std::vector<Dimname>&& names, bool validate_names); 105 106 constexpr size_t kMaxNamedTensorDim = 64; 107 108 DimnameList default_names(size_t len); 109 110 namespace impl { 111 112 // Some helper functions on TensorImpl. Useful for working with names in TH. 113 // XXX: Ideally these would exist as methods on TensorImpl 114 TORCH_API void internal_set_names_inplace(TensorImpl* impl, std::optional<DimnameList> names, bool validate_names); 115 TORCH_API void internal_set_names_inplace(TensorImpl* impl, std::vector<Dimname>&& names, bool validate_names); 116 117 void check_names_valid_for(TensorImpl* impl, DimnameList names); 118 119 // Returns true if the tensor's names exist and are not all 'None'. 120 // Returns false if the tensor's names don't exist (were not allocated), 121 // or if all names are 'None'. 122 // We treat not-allocated-names the same as allocated names that are all 'None'. 123 TORCH_API bool has_names(const TensorImpl* impl); 124 125 // Returns the names of the tensor's dimensions. 126 // Unnamed tensors are treated as having 'None' in all dimension; this method 127 // would return a DimnameList of all 'None's for an unnamed tensor. 128 TORCH_API DimnameList get_names(const TensorImpl* impl); 129 130 // This is more of an implementation detail; one should use impl::get_names / 131 // Tensor::names() whenever possible because it provides a cleaner API. 132 // Returns the names of the tensor if they have been allocated; returns nullopt 133 // instead if the haven't been. The names of a tensor are not allocated if a 134 // tensor is constructed with names=None. 135 TORCH_API std::optional<DimnameList> get_opt_names(const TensorImpl* impl); 136 137 } // namespace impl 138 139 } // namespace at 140