1 #include <torch/csrc/distributed/rpc/types.h>
2
3 namespace torch::distributed::rpc {
4
// Per-thread switch that permits JIT pickling of RRefs. It is meant to be
// turned on only within the scope of an RPC call; in other contexts (for
// example when a model is saved via torch.save()) an RRef must not be pickled
// directly.
static thread_local bool allowJitRRefPickle = false;

// Single write point for the per-thread flag above, shared by the enable and
// disable entry points below.
static void storeJitRRefPickleFlag(bool allowed) {
  allowJitRRefPickle = allowed;
}

// Returns whether RRef JIT pickling is currently permitted on this thread.
bool getAllowJitRRefPickle() {
  return allowJitRRefPickle;
}

// Permits RRef JIT pickling on the calling thread.
void enableJitRRefPickle() {
  storeJitRRefPickleFlag(true);
}

// Forbids RRef JIT pickling on the calling thread.
void disableJitRRefPickle() {
  storeJitRRefPickleFlag(false);
}
21
22 static_assert(
23 std::numeric_limits<local_id_t>::max() <=
24 std::numeric_limits<int64_t>::max(),
25 "The max value of local_id_t must be within the range of int64_t");
26 static_assert(
27 std::numeric_limits<worker_id_t>::max() <=
28 std::numeric_limits<int64_t>::max(),
29 "The max value of worker_id_t must be within the range of int64_t");
30
31 /////////////////////////// JitRRefPickleGuard ///////////////////////////
JitRRefPickleGuard()32 JitRRefPickleGuard::JitRRefPickleGuard() {
33 allowJitRRefPickle = true;
34 }
~JitRRefPickleGuard()35 JitRRefPickleGuard::~JitRRefPickleGuard() {
36 allowJitRRefPickle = false;
37 }
38
39 /////////////////////////// GloballyUniqueId ///////////////////////////
40
GloballyUniqueId(worker_id_t createdOn,local_id_t localId)41 GloballyUniqueId::GloballyUniqueId(worker_id_t createdOn, local_id_t localId)
42 : createdOn_(createdOn), localId_(localId) {}
43
operator ==(const GloballyUniqueId & other) const44 bool GloballyUniqueId::operator==(const GloballyUniqueId& other) const {
45 return createdOn_ == other.createdOn_ && localId_ == other.localId_;
46 }
47
operator !=(const GloballyUniqueId & other) const48 bool GloballyUniqueId::operator!=(const GloballyUniqueId& other) const {
49 return createdOn_ != other.createdOn_ || localId_ != other.localId_;
50 }
51
toIValue() const52 at::IValue GloballyUniqueId::toIValue() const {
53 return c10::ivalue::Tuple::create(
54 {static_cast<int64_t>(createdOn_), static_cast<int64_t>(localId_)});
55 }
56
fromIValue(const at::IValue & ivalue)57 GloballyUniqueId GloballyUniqueId::fromIValue(const at::IValue& ivalue) {
58 TORCH_INTERNAL_ASSERT(
59 ivalue.isTuple(),
60 "GloballyUniqueId::fromIValue expected ivalue to be a tuple.");
61 const auto& ivalues = ivalue.toTupleRef().elements();
62 TORCH_CHECK(
63 ivalues.size() == 2,
64 "Constructing GloballyUniqueId from ivalue "
65 "expects a GenericList of two elements, but got ",
66 ivalues.size());
67
68 TORCH_CHECK(
69 ivalues[0].toInt() <= std::numeric_limits<worker_id_t>::max(),
70 "GloballyUniqueId createdOn out of range, got ",
71 ivalues[0].toInt());
72 worker_id_t createdOn = ivalues[0].toInt();
73
74 TORCH_CHECK(
75 ivalues[1].toInt() <= std::numeric_limits<local_id_t>::max(),
76 "GloballyUniqueId localId out of range, got ",
77 ivalues[1].toInt());
78 local_id_t localId = ivalues[1].toInt();
79
80 return GloballyUniqueId(createdOn, localId);
81 }
82
operator <<(std::ostream & os,GloballyUniqueId const & globalId)83 std::ostream& operator<<(std::ostream& os, GloballyUniqueId const& globalId) {
84 return os << "GloballyUniqueId(created_on=" << globalId.createdOn_
85 << ", local_id=" << globalId.localId_ << ")";
86 }
87
88 /////////////////////////// SerializedPyObj ///////////////////////////
89
toIValues()90 std::vector<at::IValue> SerializedPyObj::toIValues() && {
91 std::vector<at::IValue> ivalues;
92 ivalues.reserve(tensors_.size() + 1);
93 for (auto& tensor : tensors_) {
94 ivalues.emplace_back(std::move(tensor));
95 }
96 ivalues.emplace_back(std::move(payload_));
97 return ivalues;
98 }
99
fromIValues(std::vector<at::IValue> values)100 SerializedPyObj SerializedPyObj::fromIValues(std::vector<at::IValue> values) {
101 std::string payload = values.back().toStringRef();
102 values.pop_back();
103 std::vector<at::Tensor> tensors;
104 tensors.reserve(values.size());
105 for (auto& value : values) {
106 tensors.emplace_back(value.toTensor());
107 }
108 return SerializedPyObj(std::move(payload), std::move(tensors));
109 }
110
111 } // namespace torch::distributed::rpc
112