xref: /aosp_15_r20/external/okio/okio/src/jvmMain/kotlin/okio/Throttler.kt (revision f9742813c14b702d71392179818a9e591da8620c)
1 /*
2  * Copyright (C) 2018 Square, Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 package okio
17 
18 import java.io.IOException
19 import java.io.InterruptedIOException
20 import java.util.concurrent.locks.Condition
21 import java.util.concurrent.locks.ReentrantLock
22 import kotlin.concurrent.withLock
23 
24 /**
25  * Enables limiting of Source and Sink throughput. Attach to this throttler via [source] and [sink]
26  * and set the desired throughput via [bytesPerSecond]. Multiple Sources and Sinks can be
27  * attached to a single Throttler and they will be throttled as a group, where their combined
28  * throughput will not exceed the desired throughput. The same Source or Sink can be attached to
29  * multiple Throttlers and its throughput will not exceed the desired throughput of any of the
30  * Throttlers.
31  *
32  * This class has these tuning parameters:
33  *
34  *  * `bytesPerSecond`: Maximum sustained throughput. Use 0 for no limit.
35  *  * `waitByteCount`: When the requested byte count is greater than this many bytes and isn't
36  *    immediately available, only wait until we can allocate at least this many bytes. Use this to
37  *    set the ideal byte count during sustained throughput.
38  *  * `maxByteCount`: Maximum number of bytes to allocate on any call. This is also the number of
39  *    bytes that will be returned before any waiting.
40  */
class Throttler internal constructor(
  /**
   * The nanoTime that we've consumed all bytes through. This is never greater than the current
   * nanoTime plus the time it takes to allocate [maxByteCount] bytes at [bytesPerSecond].
   */
  private var allocatedUntil: Long,
) {
  private var bytesPerSecond: Long = 0L
  private var waitByteCount: Long = 8 * 1024 // 8 KiB.
  private var maxByteCount: Long = 256 * 1024 // 256 KiB.

  /** Guards [bytesPerSecond], [waitByteCount], [maxByteCount] and [allocatedUntil]. */
  val lock: ReentrantLock = ReentrantLock()

  /** Signalled whenever the rate settings change so that blocked callers of [take] re-evaluate. */
  val condition: Condition = lock.newCondition()

  constructor() : this(allocatedUntil = System.nanoTime())

  /**
   * Sets the rate at which bytes will be allocated. Use 0 for no limit.
   *
   * Parameters that are omitted keep their current values.
   *
   * @param bytesPerSecond maximum sustained throughput; must be non-negative, 0 means unlimited.
   * @param waitByteCount when a request can't be fully served immediately, the minimum number of
   *   bytes to wait for; must be positive.
   * @param maxByteCount maximum number of bytes handed out by a single call; must be at least
   *   [waitByteCount].
   * @throws IllegalArgumentException if any argument is out of range.
   */
  @JvmOverloads
  fun bytesPerSecond(
    bytesPerSecond: Long,
    waitByteCount: Long = this.waitByteCount,
    maxByteCount: Long = this.maxByteCount,
  ) {
    lock.withLock {
      require(bytesPerSecond >= 0L) { "bytesPerSecond < 0: $bytesPerSecond" }
      require(waitByteCount > 0L) { "waitByteCount <= 0: $waitByteCount" }
      require(maxByteCount >= waitByteCount) {
        "maxByteCount < waitByteCount: $maxByteCount < $waitByteCount"
      }

      this.bytesPerSecond = bytesPerSecond
      this.waitByteCount = waitByteCount
      this.maxByteCount = maxByteCount
      // Wake every waiter: their computed wait may be too long (or unnecessary) under the new rate.
      condition.signalAll()
    }
  }

  /**
   * Takes up to [byteCount] bytes, blocking until some can be allocated. Returns the number of
   * bytes that were taken, which never exceeds [byteCount].
   *
   * @throws IllegalArgumentException if [byteCount] is not positive.
   * @throws InterruptedException if the calling thread is interrupted while waiting.
   */
  internal fun take(byteCount: Long): Long {
    require(byteCount > 0L) { "byteCount <= 0: $byteCount" }

    lock.withLock {
      while (true) {
        val now = System.nanoTime()
        val byteCountOrWaitNanos = byteCountOrWaitNanos(now, byteCount)
        if (byteCountOrWaitNanos >= 0) return byteCountOrWaitNanos
        // A negative result is a wait duration. awaitNanos may return early (a settings change
        // signal or a spurious wakeup); looping recomputes the allocation either way.
        condition.awaitNanos(-byteCountOrWaitNanos)
      }
    }
  }

  /**
   * Returns the byte count to take immediately or -1 times the number of nanos to wait until the
   * next attempt. If the returned value is negative it should be interpreted as a duration in
   * nanos; if it is positive it should be interpreted as a byte count.
   *
   * When bytes are granted, [allocatedUntil] is advanced to account for them. [take] calls this
   * while holding [lock], since it reads and updates the throttling state.
   *
   * @param now the current time in nanos, as from [System.nanoTime].
   * @param byteCount the number of bytes requested.
   */
  internal fun byteCountOrWaitNanos(now: Long, byteCount: Long): Long {
    if (bytesPerSecond == 0L) return byteCount // No limits.

    // How far into the future bytes have already been allocated (0 if we're caught up).
    val idleInNanos = maxOf(allocatedUntil - now, 0L)
    // Bytes still available within the maxByteCount window, net of what's already allocated.
    val immediateBytes = maxByteCount - idleInNanos.nanosToBytes()

    // Fulfill the entire request without waiting.
    if (immediateBytes >= byteCount) {
      allocatedUntil = now + idleInNanos + byteCount.bytesToNanos()
      return byteCount
    }

    // Fulfill a big-enough block without waiting.
    if (immediateBytes >= waitByteCount) {
      allocatedUntil = now + maxByteCount.bytesToNanos()
      return immediateBytes
    }

    // Looks like we'll need to wait until we can take the minimum required bytes. The wait is the
    // time until the allocation horizon recedes far enough for minByteCount bytes to fit.
    val minByteCount = minOf(waitByteCount, byteCount)
    val minWaitNanos = idleInNanos + (minByteCount - maxByteCount).bytesToNanos()

    // But if the wait duration truncates to zero nanos after division, don't wait.
    if (minWaitNanos == 0L) {
      allocatedUntil = now + maxByteCount.bytesToNanos()
      return minByteCount
    }

    return -minWaitNanos
  }

  /**
   * Converts a duration in nanos to the number of bytes allocated in that time at
   * [bytesPerSecond], truncating toward zero.
   */
  private fun Long.nanosToBytes() = this * bytesPerSecond / 1_000_000_000L

  /**
   * Converts a byte count to the nanos needed to allocate it at [bytesPerSecond], truncating
   * toward zero. Only valid while [bytesPerSecond] is non-zero (it is the divisor).
   * NOTE(review): `this * 1_000_000_000L` overflows for counts above roughly 9.2e9 bytes.
   */
  private fun Long.bytesToNanos() = this * 1_000_000_000L / bytesPerSecond

  /**
   * Create a Source which honors this Throttler. Each read first takes its budget from this
   * throttler (possibly blocking), then reads at most that many bytes from [source].
   */
  fun source(source: Source): Source {
    return object : ForwardingSource(source) {
      @Throws(IOException::class)
      override fun read(sink: Buffer, byteCount: Long): Long {
        // take() rejects non-positive counts. A zero-byte read needs no budget, so hand it to the
        // delegate as-is instead of failing it with IllegalArgumentException.
        if (byteCount == 0L) return super.read(sink, 0L)

        try {
          val toRead = take(byteCount)
          return super.read(sink, toRead)
        } catch (e: InterruptedException) {
          // Re-assert the interrupt status and report the interruption as an IOException subtype.
          Thread.currentThread().interrupt()
          throw InterruptedIOException("interrupted")
        }
      }
    }
  }

  /**
   * Create a Sink which honors this Throttler. Each write is forwarded to [sink] in chunks whose
   * sizes are granted by this throttler, blocking between chunks as needed.
   */
  fun sink(sink: Sink): Sink {
    return object : ForwardingSink(sink) {
      @Throws(IOException::class)
      override fun write(source: Buffer, byteCount: Long) {
        try {
          var remaining = byteCount
          while (remaining > 0L) {
            val toWrite = take(remaining)
            super.write(source, toWrite)
            remaining -= toWrite
          }
        } catch (e: InterruptedException) {
          // Re-assert the interrupt status and report the interruption as an IOException subtype.
          Thread.currentThread().interrupt()
          throw InterruptedIOException("interrupted")
        }
      }
    }
  }
}
168