1 #include <torch/csrc/jit/runtime/slice_indices_adjust.h> 2 3 #include <c10/util/Exception.h> 4 #include <cstdint> 5 6 namespace torch::jit { 7 slice_indices_adjust(int64_t length,int64_t * start,int64_t * stop,int64_t step)8int64_t slice_indices_adjust( 9 int64_t length, 10 int64_t* start, 11 int64_t* stop, 12 int64_t step) { 13 TORCH_CHECK(step != 0, "List slice should have non-zero step") 14 TORCH_CHECK(step >= -INT64_MAX, "List slice step is out of bounds") 15 16 // Comes from PySlice_Unpack. 17 if (*start == INT64_MAX) { 18 *start = (step < 0) ? INT64_MAX : 0; 19 } 20 if (*stop == INT64_MAX) { 21 *stop = (step < 0) ? INT64_MIN : INT64_MAX; 22 } 23 24 // Comes from PySlice_AdjustIndices. 25 if (*start < 0) { 26 *start += length; 27 if (*start < 0) { 28 *start = (step < 0) ? -1 : 0; 29 } 30 } else if (*start >= length) { 31 *start = (step < 0) ? length - 1 : length; 32 } 33 34 if (*stop < 0) { 35 *stop += length; 36 if (*stop < 0) { 37 *stop = (step < 0) ? -1 : 0; 38 } 39 } else if (*stop >= length) { 40 *stop = (step < 0) ? length - 1 : length; 41 } 42 43 if (step < 0) { 44 if (*stop < *start) { 45 return (*start - *stop - 1) / (-step) + 1; 46 } 47 } else { 48 if (*start < *stop) { 49 return (*stop - *start - 1) / step + 1; 50 } 51 } 52 return 0; 53 } 54 55 } // namespace torch::jit 56