xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/runtime/slice_indices_adjust.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/jit/runtime/slice_indices_adjust.h>
2 
3 #include <c10/util/Exception.h>
4 #include <cstdint>
5 
6 namespace torch::jit {
7 
slice_indices_adjust(int64_t length,int64_t * start,int64_t * stop,int64_t step)8 int64_t slice_indices_adjust(
9     int64_t length,
10     int64_t* start,
11     int64_t* stop,
12     int64_t step) {
13   TORCH_CHECK(step != 0, "List slice should have non-zero step")
14   TORCH_CHECK(step >= -INT64_MAX, "List slice step is out of bounds")
15 
16   // Comes from PySlice_Unpack.
17   if (*start == INT64_MAX) {
18     *start = (step < 0) ? INT64_MAX : 0;
19   }
20   if (*stop == INT64_MAX) {
21     *stop = (step < 0) ? INT64_MIN : INT64_MAX;
22   }
23 
24   // Comes from PySlice_AdjustIndices.
25   if (*start < 0) {
26     *start += length;
27     if (*start < 0) {
28       *start = (step < 0) ? -1 : 0;
29     }
30   } else if (*start >= length) {
31     *start = (step < 0) ? length - 1 : length;
32   }
33 
34   if (*stop < 0) {
35     *stop += length;
36     if (*stop < 0) {
37       *stop = (step < 0) ? -1 : 0;
38     }
39   } else if (*stop >= length) {
40     *stop = (step < 0) ? length - 1 : length;
41   }
42 
43   if (step < 0) {
44     if (*stop < *start) {
45       return (*start - *stop - 1) / (-step) + 1;
46     }
47   } else {
48     if (*start < *stop) {
49       return (*stop - *start - 1) / step + 1;
50     }
51   }
52   return 0;
53 }
54 
55 } // namespace torch::jit
56