Home
last modified time | relevance | path

Searched refs:tupleElements (Results 1 – 6 of 6) sorted by relevance

/aosp_15_r20/external/pytorch/torch/csrc/distributed/autograd/rpc_messages/
H A Dpropagate_gradients_req.cpp59 const auto& tupleElements = tuple.toTupleRef().elements(); in fromMessage() local
62 TORCH_INTERNAL_ASSERT(tupleElements.size() >= 3); in fromMessage()
65 bool retainGraph = tupleElements.back().toBool(); in fromMessage()
70 autogradMessageId = tupleElements[tupleElements.size() - 2].toInt(); in fromMessage()
71 autogradContextId = tupleElements[tupleElements.size() - 3].toInt(); in fromMessage()
76 std::vector<Variable> grads(tupleElements.size() - 3); in fromMessage()
77 for (const auto i : c10::irange(tupleElements.size() - 3)) { in fromMessage()
78 grads[i] = tupleElements[i].toTensor(); in fromMessage()
H A Drpc_with_profiling_resp.cpp109 auto tupleElements = rpc::readWrappedPayload(payload, message); in fromMessage() local
112 tupleElements.size() >= kProfileEventsStartIdx, in fromMessage()
117 tupleElements.size())); in fromMessage()
119 static_cast<rpc::MessageType>(tupleElements[0].toInt()); in fromMessage()
120 rpc::ProfilingId profilingId = rpc::ProfilingId::fromIValue(tupleElements[1]); in fromMessage()
121 int profiledEventsSize = tupleElements[2].toInt(); in fromMessage()
127 TORCH_CHECK(static_cast<size_t>(i) < tupleElements.size()); in fromMessage()
130 torch::autograd::profiler::LegacyEvent::fromIValue(tupleElements[i]); in fromMessage()
H A Drpc_with_autograd.cpp103 auto tupleElements = rpc::readWrappedPayload(payload, message); in fromMessage() local
106 TORCH_INTERNAL_ASSERT(tupleElements.size() == 5); in fromMessage()
108 static_cast<MessageType>(tupleElements[0].toInt()); in fromMessage()
110 tupleElements[1].toInt(), tupleElements[2].toInt()); in fromMessage()
111 worker_id_t workerId = tupleElements[3].toInt(); in fromMessage()
113 tupleElements[4].to<c10::Dict<std::string, std::string>>(); in fromMessage()
H A Drpc_with_profiling_req.cpp112 auto tupleElements = rpc::readWrappedPayload(payload, message); in fromMessage() local
115 tupleElements.size() == kProfilingResponseElementExpectedSize, in fromMessage()
120 tupleElements.size())); in fromMessage()
122 static_cast<rpc::MessageType>(tupleElements[0].toInt()); in fromMessage()
126 torch::autograd::profiler::ProfilerConfig::fromIValue(tupleElements[1]); in fromMessage()
128 rpc::ProfilingId profilerId = rpc::ProfilingId::fromIValue(tupleElements[2]); in fromMessage()
H A Drref_backward_req.cpp49 const auto& tupleElements = std::move(*std::move(tuple).toTuple()).elements(); in fromMessage() local
52 TORCH_INTERNAL_ASSERT(tupleElements.size() == 3); in fromMessage()
55 bool retainGraph = tupleElements[2].toBool(); in fromMessage()
56 int64_t autogradContextId = tupleElements[1].toInt(); in fromMessage()
57 rpc::RRefId rrefId = rpc::RRefId::fromIValue(tupleElements[0]); in fromMessage()
/aosp_15_r20/external/pytorch/torch/csrc/distributed/rpc/
H A Dutils.cpp518 std::vector<at::IValue> tupleElements = tuple.toTupleRef().elements().vec(); in readWrappedPayload() local
520 return tupleElements; in readWrappedPayload()