xref: /aosp_15_r20/external/pytorch/c10/core/impl/SizesAndStrides.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <c10/core/impl/SizesAndStrides.h>
2 
3 namespace c10::impl {
4 
resizeSlowPath(const size_t newSize,const size_t oldSize)5 void SizesAndStrides::resizeSlowPath(
6     const size_t newSize,
7     const size_t oldSize) {
8   if (newSize <= C10_SIZES_AND_STRIDES_MAX_INLINE_SIZE) {
9     TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
10         !isInline(),
11         "resizeSlowPath called when fast path should have been hit!");
12     int64_t* tempStorage = outOfLineStorage_;
13     memcpy(
14         &inlineStorage_[0],
15         &tempStorage[0],
16         C10_SIZES_AND_STRIDES_MAX_INLINE_SIZE * sizeof(inlineStorage_[0]));
17     memcpy(
18         &inlineStorage_[C10_SIZES_AND_STRIDES_MAX_INLINE_SIZE],
19         &tempStorage[oldSize],
20         C10_SIZES_AND_STRIDES_MAX_INLINE_SIZE * sizeof(inlineStorage_[0]));
21     // CANNOT USE freeOutOfLineStorage() HERE! outOfLineStorage_
22     // HAS BEEN OVERWRITTEN!
23     // NOLINTNEXTLINE(cppcoreguidelines-no-malloc)
24     free(tempStorage);
25   } else {
26     if (isInline()) {
27       // CANNOT USE allocateOutOfLineStorage(newSize) HERE! WOULD
28       // OVERWRITE inlineStorage_!
29       int64_t* tempStorage =
30           // NOLINTNEXTLINE(cppcoreguidelines-no-malloc)
31           static_cast<int64_t*>(malloc(storageBytes(newSize)));
32       TORCH_CHECK(
33           tempStorage,
34           "Could not allocate memory to change Tensor SizesAndStrides!");
35       const auto bytesToCopy = oldSize * sizeof(inlineStorage_[0]);
36       const auto bytesToZero = (newSize > oldSize)
37           ? (newSize - oldSize) * sizeof(tempStorage[0])
38           : 0;
39       memcpy(&tempStorage[0], &inlineStorage_[0], bytesToCopy);
40       if (bytesToZero) {
41         memset(&tempStorage[oldSize], 0, bytesToZero);
42       }
43       memcpy(
44           &tempStorage[newSize],
45           &inlineStorage_[C10_SIZES_AND_STRIDES_MAX_INLINE_SIZE],
46           bytesToCopy);
47       if (bytesToZero) {
48         memset(&tempStorage[newSize + oldSize], 0, bytesToZero);
49       }
50       outOfLineStorage_ = tempStorage;
51     } else {
52       const bool isGrowing = oldSize < newSize;
53       if (isGrowing) {
54         // Resize before shifting so that we have room.
55         resizeOutOfLineStorage(newSize);
56       }
57       // Shift the old strides to their new starting point. Note
58       // that this does not occur in the inline path above because
59       // the stride starting point is not moving.
60       memmove(
61           outOfLineStorage_ + newSize,
62           outOfLineStorage_ + oldSize,
63           std::min(oldSize, newSize) * sizeof(outOfLineStorage_[0]));
64       if (!isGrowing) {
65         // Resize after shifting so that we don't lose data.
66         resizeOutOfLineStorage(newSize);
67       } else {
68         // Zero the end of the sizes portion.
69         const auto bytesToZero =
70             (newSize - oldSize) * sizeof(outOfLineStorage_[0]);
71         memset(&outOfLineStorage_[oldSize], 0, bytesToZero);
72         memset(&outOfLineStorage_[newSize + oldSize], 0, bytesToZero);
73       }
74     }
75   }
76   size_ = newSize;
77 }
78 
79 } // namespace c10::impl
80