1 /*
2  * Copyright 2017-2021 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
3  */
4 
5 package kotlinx.serialization.internal
6 
7 import kotlinx.serialization.*
8 import kotlinx.serialization.descriptors.SerialDescriptor
9 import kotlinx.serialization.encoding.CompositeDecoder
10 
/**
 * Bit set that tracks which element indices of [descriptor] have been marked.
 *
 * Bit `i` of the storage corresponds to element index `i`. Indices `0..63` are kept in a single
 * [Long]; indices from 64 upwards are kept in a [LongArray] where array slot `s` covers indices
 * `(s + 1) * 64 .. (s + 1) * 64 + 63`. Bit positions at or above `descriptor.elementsCount` in the
 * last used word are set during construction, so they are never reported as unmarked.
 *
 * @param descriptor descriptor whose `elementsCount` determines how many indices are tracked;
 * it is also passed as the first argument to [readIfAbsent].
 * @param readIfAbsent callback invoked by [nextUnmarkedIndex] with the descriptor and each unmarked
 * index it encounters; an index is returned to the caller only when the callback returns `true`.
 */
@OptIn(ExperimentalSerializationApi::class)
@CoreFriendModuleApi
public class ElementMarker(
    private val descriptor: SerialDescriptor,
    // A function value is used here rather than inheritance with a virtual function so that the
    // cross-module internal modifier can be kept via suppresses.
    // May be reworked into a public + internal API if that becomes necessary.
    private val readIfAbsent: (SerialDescriptor, Int) -> Boolean
) {
    // Marks for element indices 0..63; a set bit means "marked".
    private var lowerMarks: Long

    // Marks for element indices >= 64; empty when the descriptor has at most 64 elements.
    private val highMarksArray: LongArray

    init {
        val count = descriptor.elementsCount
        if (count > Long.SIZE_BITS) {
            // All 64 low bits correspond to real elements; the remainder lives in the array.
            lowerMarks = 0L
            highMarksArray = prepareHighMarksArray(count)
        } else {
            // With fewer than 64 elements, the bits at positions count..63 have no matching element
            // and are set up front. With exactly 64 elements every bit is in use and starts cleared.
            lowerMarks = if (count < Long.SIZE_BITS) (-1L shl count) else 0L
            highMarksArray = EMPTY_HIGH_MARKS
        }
    }

    /**
     * Sets the mark for the element at [index].
     *
     * Indices below 64 are recorded in the single-word storage, all others in the array storage.
     */
    public fun mark(index: Int) {
        if (index >= Long.SIZE_BITS) {
            markHigh(index)
            return
        }
        lowerMarks = lowerMarks or (1L shl index)
    }

    /**
     * Walks the unmarked indices in ascending order, marking each one as it is visited, and returns
     * the first index for which `readIfAbsent` returns `true`.
     *
     * @return the accepted index, or [CompositeDecoder.DECODE_DONE] once every index is marked.
     */
    public fun nextUnmarkedIndex(): Int {
        val total = descriptor.elementsCount
        while (true) {
            val unmarked = lowerMarks.inv()
            if (unmarked == 0L) break

            // Position of the lowest cleared bit in lowerMarks.
            val candidate = unmarked.countTrailingZeroBits()
            // The mark is stored before the callback runs.
            lowerMarks = lowerMarks or (1L shl candidate)

            if (readIfAbsent(descriptor, candidate)) return candidate
        }

        return if (total > Long.SIZE_BITS) nextUnmarkedHighIndex() else CompositeDecoder.DECODE_DONE
    }

    /**
     * Allocates the array storage for a descriptor with [elementsCount] elements.
     *
     * The array has `(elementsCount - 1) / 64` words, because the first 64 elements are stored
     * outside of it. If the last word is only partially used, its bits above the last element
     * are set.
     */
    private fun prepareHighMarksArray(elementsCount: Int): LongArray {
        val words = LongArray((elementsCount - 1) ushr 6)
        // elementsCount % 64: how many bits of the last word map to real elements.
        val usedBitsInLastWord = elementsCount and (Long.SIZE_BITS - 1)
        // A remainder of zero means the elements fill the last word completely.
        if (usedBitsInLastWord != 0) {
            // Long.shl only uses the six lowest bits of the shift distance, so shifting by the
            // remainder yields the same value as shifting by elementsCount itself.
            words[words.lastIndex] = -1L shl usedBitsInLastWord
        }
        return words
    }

    /**
     * Sets the mark for an element [index] that is stored in the array storage.
     */
    private fun markHigh(index: Int) {
        // index / 64, minus one because the first 64 indices are not stored in the array.
        val word = (index ushr 6) - 1
        // Single-bit mask at position index % 64.
        val mask = 1L shl (index and (Long.SIZE_BITS - 1))
        highMarksArray[word] = highMarksArray[word] or mask
    }

    /**
     * Array-storage counterpart of [nextUnmarkedIndex]: visits unmarked indices from 64 upwards in
     * ascending order, marking each, and returns the first one accepted by `readIfAbsent`.
     *
     * @return the accepted index, or [CompositeDecoder.DECODE_DONE] if none is accepted.
     */
    private fun nextUnmarkedHighIndex(): Int {
        for (word in 0 until highMarksArray.size) {
            // Element index of bit 0 in this word; array word 0 starts at element 64.
            val base = (word + 1) * Long.SIZE_BITS
            // Work on a local copy and write it back once we are finished with the word.
            var marks = highMarksArray[word]

            while (true) {
                val unmarked = marks.inv()
                if (unmarked == 0L) break

                val bit = unmarked.countTrailingZeroBits()
                marks = marks or (1L shl bit)

                val candidate = base + bit
                if (readIfAbsent(descriptor, candidate)) {
                    highMarksArray[word] = marks
                    return candidate
                }
            }
            highMarksArray[word] = marks
        }
        return CompositeDecoder.DECODE_DONE
    }

    private companion object {
        // Shared zero-length array for descriptors that need no array storage.
        private val EMPTY_HIGH_MARKS = LongArray(0)
    }
}
118