xref: /aosp_15_r20/external/pytorch/torch/csrc/distributed/rpc/types.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/distributed/rpc/types.h>
2 
3 namespace torch::distributed::rpc {
4 
// Per-thread switch gating JIT pickling of RRefs. Pickling an RRef this way is
// only permitted in the scope of an RPC call; in any other context (e.g. when a
// model is saved via torch.save()) an RRef must not be pickled directly, which
// is why every thread starts with the switch off.
static thread_local bool allowJitRRefPickle{false};

// Single write path shared by enableJitRRefPickle() and disableJitRRefPickle().
static void setJitRRefPickleFlag(bool allowed) {
  allowJitRRefPickle = allowed;
}

// Returns whether RRef JIT pickling is currently permitted on the calling
// thread (the flag is thread_local, so other threads are unaffected).
bool getAllowJitRRefPickle() {
  const bool allowed = allowJitRRefPickle;
  return allowed;
}

// Permits RRef JIT pickling on the calling thread.
void enableJitRRefPickle() {
  setJitRRefPickleFlag(true);
}

// Forbids RRef JIT pickling on the calling thread.
void disableJitRRefPickle() {
  setJitRRefPickleFlag(false);
}
21 
22 static_assert(
23     std::numeric_limits<local_id_t>::max() <=
24         std::numeric_limits<int64_t>::max(),
25     "The max value of local_id_t must be within the range of int64_t");
26 static_assert(
27     std::numeric_limits<worker_id_t>::max() <=
28         std::numeric_limits<int64_t>::max(),
29     "The max value of worker_id_t must be within the range of int64_t");
30 
31 ///////////////////////////  JitRRefPickleGuard   ///////////////////////////
JitRRefPickleGuard()32 JitRRefPickleGuard::JitRRefPickleGuard() {
33   allowJitRRefPickle = true;
34 }
~JitRRefPickleGuard()35 JitRRefPickleGuard::~JitRRefPickleGuard() {
36   allowJitRRefPickle = false;
37 }
38 
39 ///////////////////////////  GloballyUniqueId   ///////////////////////////
40 
GloballyUniqueId(worker_id_t createdOn,local_id_t localId)41 GloballyUniqueId::GloballyUniqueId(worker_id_t createdOn, local_id_t localId)
42     : createdOn_(createdOn), localId_(localId) {}
43 
operator ==(const GloballyUniqueId & other) const44 bool GloballyUniqueId::operator==(const GloballyUniqueId& other) const {
45   return createdOn_ == other.createdOn_ && localId_ == other.localId_;
46 }
47 
operator !=(const GloballyUniqueId & other) const48 bool GloballyUniqueId::operator!=(const GloballyUniqueId& other) const {
49   return createdOn_ != other.createdOn_ || localId_ != other.localId_;
50 }
51 
toIValue() const52 at::IValue GloballyUniqueId::toIValue() const {
53   return c10::ivalue::Tuple::create(
54       {static_cast<int64_t>(createdOn_), static_cast<int64_t>(localId_)});
55 }
56 
fromIValue(const at::IValue & ivalue)57 GloballyUniqueId GloballyUniqueId::fromIValue(const at::IValue& ivalue) {
58   TORCH_INTERNAL_ASSERT(
59       ivalue.isTuple(),
60       "GloballyUniqueId::fromIValue expected ivalue to be a tuple.");
61   const auto& ivalues = ivalue.toTupleRef().elements();
62   TORCH_CHECK(
63       ivalues.size() == 2,
64       "Constructing GloballyUniqueId from ivalue "
65       "expects a GenericList of two elements, but got ",
66       ivalues.size());
67 
68   TORCH_CHECK(
69       ivalues[0].toInt() <= std::numeric_limits<worker_id_t>::max(),
70       "GloballyUniqueId createdOn out of range, got ",
71       ivalues[0].toInt());
72   worker_id_t createdOn = ivalues[0].toInt();
73 
74   TORCH_CHECK(
75       ivalues[1].toInt() <= std::numeric_limits<local_id_t>::max(),
76       "GloballyUniqueId localId out of range, got ",
77       ivalues[1].toInt());
78   local_id_t localId = ivalues[1].toInt();
79 
80   return GloballyUniqueId(createdOn, localId);
81 }
82 
operator <<(std::ostream & os,GloballyUniqueId const & globalId)83 std::ostream& operator<<(std::ostream& os, GloballyUniqueId const& globalId) {
84   return os << "GloballyUniqueId(created_on=" << globalId.createdOn_
85             << ", local_id=" << globalId.localId_ << ")";
86 }
87 
88 ///////////////////////////  SerializedPyObj   ///////////////////////////
89 
toIValues()90 std::vector<at::IValue> SerializedPyObj::toIValues() && {
91   std::vector<at::IValue> ivalues;
92   ivalues.reserve(tensors_.size() + 1);
93   for (auto& tensor : tensors_) {
94     ivalues.emplace_back(std::move(tensor));
95   }
96   ivalues.emplace_back(std::move(payload_));
97   return ivalues;
98 }
99 
fromIValues(std::vector<at::IValue> values)100 SerializedPyObj SerializedPyObj::fromIValues(std::vector<at::IValue> values) {
101   std::string payload = values.back().toStringRef();
102   values.pop_back();
103   std::vector<at::Tensor> tensors;
104   tensors.reserve(values.size());
105   for (auto& value : values) {
106     tensors.emplace_back(value.toTensor());
107   }
108   return SerializedPyObj(std::move(payload), std::move(tensors));
109 }
110 
111 } // namespace torch::distributed::rpc
112