// /aosp_15_r20/external/pytorch/torch/csrc/inductor/aoti_torch/tensor_converter.cpp
// (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#include <torch/csrc/inductor/aoti_torch/tensor_converter.h>
#include <torch/csrc/inductor/aoti_torch/utils.h>

namespace torch::aot_inductor {

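// Wraps each input tensor in a freshly heap-allocated at::Tensor (sharing the
// underlying storage with the input) and returns the corresponding opaque
// AtenTensorHandles. The caller takes ownership of the new allocations and
// must release every handle, e.g. via aoti_torch_delete_tensor_object.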
std::vector<AtenTensorHandle> unsafe_alloc_new_handles_from_tensors(
    std::vector<at::Tensor>& tensors) {
  std::vector<AtenTensorHandle> result;
  result.reserve(tensors.size());
  for (auto tensor : tensors) {
    auto allocated = new at::Tensor(std::move(tensor));
    result.push_back(tensor_pointer_to_tensor_handle(allocated));
  }
  return result;
}

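// Rebuilds at::Tensors from an array of handles and steals ownership: every
// non-null entry in `handles` is reset to nullptr, and each distinct handle
// is freed exactly once (duplicates are copied so that only the last
// occurrence releases the allocation). Null handles yield undefined tensors.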
std::vector<at::Tensor> alloc_tensors_by_stealing_from_handles(
    AtenTensorHandle* handles,
    size_t length) {
  // Find duplicates by recording the last known index for each handle.
  std::unordered_map<AtenTensorHandle, size_t> lastKnownIdx;
  for (size_t i = 0; i < length; i++) {
    lastKnownIdx[handles[i]] = i;
  }

  std::vector<at::Tensor> result;
  result.reserve(length);
  for (size_t i = 0; i < length; i++) {
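    // A null handle becomes a default-constructed (undefined) at::Tensor.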
    if (handles[i] == nullptr) {
      result.emplace_back();
      continue;
    }

    at::Tensor tensor = *tensor_handle_to_tensor_pointer(handles[i]);
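    // If this handle shows up again at a later index, take a copy here; only
    // the last occurrence steals the tensor and frees the handle, so each
    // distinct handle is deleted exactly once.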
    if (lastKnownIdx[handles[i]] != i) {
      result.emplace_back(tensor);
    } else {
      result.emplace_back(std::move(tensor));
      aoti_torch_delete_tensor_object(handles[i]);
    }
    handles[i] = nullptr;
  }

  return result;
}
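
// Usage sketch (illustrative only; assumes ATen and the AOTI C shim are
// available, as they are for code generated by AOTInductor):
//
//   std::vector<at::Tensor> inputs = {at::randn({2, 3}), at::randn({4})};
//   // Wrap the tensors in opaque handles owned by the caller...
//   std::vector<AtenTensorHandle> handles =
//       unsafe_alloc_new_handles_from_tensors(inputs);
//   // ...then later reclaim them as at::Tensor objects; every handle is
//   // freed and reset to nullptr in the process.
//   std::vector<at::Tensor> outputs =
//       alloc_tensors_by_stealing_from_handles(handles.data(), handles.size());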

} // namespace torch::aot_inductor