xref: /aosp_15_r20/external/okio/okio/src/jvmMain/kotlin/okio/CipherSink.kt (revision f9742813c14b702d71392179818a9e591da8620c)
1 /*
2  * Copyright (C) 2020 Square, Inc. and others.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 package okio
17 
18 import java.io.IOException
19 import javax.crypto.Cipher
20 
/**
 * A [Sink] that runs every byte written to it through [cipher] and writes the result to [sink].
 *
 * Bytes are transformed with [Cipher.update] as they are written. [close] finishes the operation
 * with [Cipher.doFinal], writing any remaining output, and then closes [sink].
 *
 * @param sink destination for the transformed bytes.
 * @param cipher the cipher to apply; it must report a positive block size, otherwise construction
 *   throws [IllegalArgumentException].
 */
class CipherSink(
  private val sink: BufferedSink,
  val cipher: Cipher,
) : Sink {
  private val blockSize = cipher.blockSize
  private var closed = false

  init {
    // Require block cipher
    require(blockSize > 0) { "Block cipher required $cipher" }
  }

  @Throws(IOException::class)
  override fun write(source: Buffer, byteCount: Long) {
    checkOffsetAndCount(source.size, 0, byteCount)
    check(!closed) { "closed" }

    var remaining = byteCount
    while (remaining > 0) {
      remaining -= update(source, remaining)
    }
  }

  /**
   * Runs up to [remaining] bytes from the head of [source] through the cipher, removes them from
   * [source], and returns how many bytes were consumed (always at least 1 when [remaining] > 0).
   */
  private fun update(source: Buffer, remaining: Long): Int {
    // Non-null: write() verified that source holds at least `remaining` (> 0) bytes.
    val head = source.head!!
    var size = minOf(remaining, head.limit - head.pos).toInt()
    val buffer = sink.buffer

    // Shorten input until output is guaranteed to fit within a segment
    var outputSize = cipher.getOutputSize(size)
    while (outputSize > Segment.SIZE) {
      if (size <= blockSize) {
        // Bug: For AES-GCM on Android `update` method never outputs any data
        // As a consequence, `getOutputSize` just keeps increasing indefinitely after each update
        // When that happens, the fallback is to perform the update operation without using a pre-allocated segment
        // Cipher.update(ByteArray) may return null when it produces no output, so only write a real result.
        cipher.update(source.readByteArray(remaining))?.let { sink.write(it) }
        return remaining.toInt()
      }
      size -= blockSize
      outputSize = cipher.getOutputSize(size)
    }

    // getOutputSize() can report 0 when this input yields no output yet, but writableSegment()
    // needs a capacity of at least 1. An unused segment is recycled below.
    val s = buffer.writableSegment(maxOf(outputSize, 1))

    val ciphered = cipher.update(head.data, head.pos, size, s.data, s.limit)

    s.limit += ciphered
    buffer.size += ciphered

    // We allocated a tail segment, but didn't end up needing it. Recycle!
    if (s.pos == s.limit) {
      buffer.head = s.pop()
      SegmentPool.recycle(s)
    }

    sink.emitCompleteSegments()

    // Mark those bytes as read.
    source.size -= size
    head.pos += size
    if (head.pos == head.limit) {
      source.head = head.pop()
      SegmentPool.recycle(head)
    }

    return size
  }

  override fun flush() = sink.flush()

  override fun timeout() = sink.timeout()

  /**
   * Finishes the cipher operation and closes the underlying sink. The underlying sink is closed
   * even if finishing the cipher fails; the first failure encountered is rethrown afterwards.
   */
  @Throws(IOException::class)
  override fun close() {
    if (closed) return
    closed = true

    // doFinal() reports cipher failures by returning them, but it can also throw (for example from
    // getOutputSize()); capture both so sink.close() below always runs.
    var thrown: Throwable? = try {
      doFinal()
    } catch (e: Throwable) {
      e
    }

    try {
      sink.close()
    } catch (e: Throwable) {
      if (thrown == null) thrown = e
    }

    if (thrown != null) throw thrown
  }

  /**
   * Calls [Cipher.doFinal] and writes its output to [sink].
   *
   * @return the failure raised while finishing the cipher or writing its output, or null on success.
   */
  private fun doFinal(): Throwable? {
    val outputSize = cipher.getOutputSize(0)

    if (outputSize == 0 || outputSize > Segment.SIZE) {
      // Let the cipher allocate its own output array in two cases:
      //  - outputSize == 0: there may be nothing to write, but doFinal() must still run so the
      //    cipher can finish and report failures (e.g. an authentication-tag mismatch).
      //    writableSegment() also cannot be asked for 0 bytes.
      //  - outputSize > Segment.SIZE: Bug: For AES-GCM on Android `update` method never outputs any data
      //    As a consequence, `doFinal` returns the fully encrypted data, which may be arbitrarily large
      return try {
        sink.write(cipher.doFinal())
        null
      } catch (t: Throwable) {
        t
      }
    }

    // The final output fits in one segment: write it directly into the sink's buffer.
    val buffer = sink.buffer
    val s = buffer.writableSegment(outputSize)

    val thrown: Throwable? = try {
      val ciphered = cipher.doFinal(s.data, s.limit)

      s.limit += ciphered
      buffer.size += ciphered
      null
    } catch (e: Throwable) {
      e
    }

    // We allocated a tail segment, but didn't end up needing it. Recycle!
    if (s.pos == s.limit) {
      buffer.head = s.pop()
      SegmentPool.recycle(s)
    }

    return thrown
  }
}
148