<lambda>null1 @file:OptIn(ExperimentalSerializationApi::class)
2
3 package kotlinx.serialization.protobuf.schema
4
5 import kotlinx.serialization.*
6 import kotlinx.serialization.builtins.*
7 import kotlinx.serialization.descriptors.*
8 import kotlinx.serialization.protobuf.*
9 import kotlinx.serialization.protobuf.internal.*
10
/**
 * Experimental generator of ProtoBuf schema that is compatible with [serializable][Serializable] Kotlin classes
 * and data encoded and decoded by [ProtoBuf] format.
 *
 * The schema is generated based on the provided [SerialDescriptor] and is compatible with the proto2 schema definition.
 * An arbitrary Kotlin class represents a much wider object domain than the ProtoBuf specification, thus the schema
 * generator has the following list of restrictions:
 *
 * * Serial name of the class and all its fields should be a valid Proto [identifier](https://developers.google.com/protocol-buffers/docs/reference/proto2-spec)
 * * Nullable values are allowed only for Kotlin [nullable][SerialDescriptor.isNullable] types, but not [optional][SerialDescriptor.isElementOptional]
 *   in order to properly distinguish "default" and "absent" values.
 * * The name of the type without the package directive uniquely identifies the proto message type, and two or more
 *   types with the same serial name (ignoring the package part) are considered to be the same type. The schema
 *   generator allows specifying a separate package directive for a group of classes in order to mitigate this limitation.
 * * Nested collections, e.g. `List<List<Int>>` are represented using an artificial wrapper message in order to distinguish
 *   repeated fields boundaries.
 * * Default Kotlin values are not representable in proto schema. A special commentary is generated for properties with default values.
 * * Empty nullable collections are not supported by the generated schema and will be prohibited in [ProtoBuf] as well
 *   due to their ambiguous nature.
 * * [ProtoPacked] is reflected in the schema only for lists whose element type can be packed; it is ignored for maps.
 *
 * Temporary restrictions:
 * * [Contextual] data is represented as the `bytes` type
 * * [Polymorphic] data is represented as an artificial `KotlinxSerializationPolymorphic` message.
 *
 * Other types are mapped according to their specification: primitives as primitives, lists as 'repeated' fields and
 * maps as 'repeated' map entries.
 *
 * The names of messages and enums are extracted from [SerialDescriptor.serialName] in [SerialDescriptor] without the
 * package directive, as the substring after the last dot character; the `'?'` character is also removed if it is
 * present at the end of the string.
 */
@ExperimentalSerializationApi
public object ProtoBufSchemaGenerator {

    /**
     * Generate text of protocol buffers schema version 2 for the given [rootDescriptor].
     * The resulting schema will contain all types referred by [rootDescriptor].
     *
     * [packageName] defines the common protobuf package for all messages and enums in the schema. It must consist of
     * one or more dot-separated segments, each starting with a letter and containing only `'a'`..`'z'` letters in
     * upper and lower case, decimal digits or `'_'` chars (so it can neither start nor end with a dot).
     *
     * [options] define values for protobuf options. An option value (map value) may be an arbitrary string: line
     * breaks are replaced with spaces, and double quotes and backslashes are escaped. An option name (map key) must
     * have the same format as [packageName].
     *
     * The method throws [IllegalArgumentException] if any of the restrictions imposed by [ProtoBufSchemaGenerator] is violated.
     */
    @ExperimentalSerializationApi
    public fun generateSchemaText(
        rootDescriptor: SerialDescriptor,
        packageName: String? = null,
        options: Map<String, String> = emptyMap()
    ): String = generateSchemaText(listOf(rootDescriptor), packageName, options)

    /**
     * Generate text of protocol buffers schema version 2 for the given serializable [descriptors].
     *
     * [packageName] defines the common protobuf package for all messages and enums in the schema. It must consist of
     * one or more dot-separated segments, each starting with a letter and containing only `'a'`..`'z'` letters in
     * upper and lower case, decimal digits or `'_'` chars (so it can neither start nor end with a dot).
     *
     * [options] define values for protobuf options. An option value (map value) may be an arbitrary string: line
     * breaks are replaced with spaces, and double quotes and backslashes are escaped. An option name (map key) must
     * have the same format as [packageName].
     *
     * The method throws [IllegalArgumentException] if any of the restrictions imposed by [ProtoBufSchemaGenerator] is
     * violated, including when two root [descriptors] share the same message/enum name.
     */
    @ExperimentalSerializationApi
    public fun generateSchemaText(
        descriptors: List<SerialDescriptor>,
        packageName: String? = null,
        options: Map<String, String> = emptyMap()
    ): String {
        packageName?.let { p -> p.checkIsValidFullIdentifier { "Incorrect protobuf package name '$it'" } }
        checkDoubles(descriptors)
        val builder = StringBuilder()
        builder.generateProto2SchemaText(descriptors, packageName, options)
        return builder.toString()
    }

    /**
     * Throws [IllegalArgumentException] if two or more root [descriptors] map to the same message/enum name.
     * Every repeated occurrence is listed in the message.
     */
    private fun checkDoubles(descriptors: List<SerialDescriptor>) {
        val rootTypesNames = mutableSetOf<String>()
        val duplicates = mutableListOf<String>()

        descriptors.map { it.messageOrEnumName }.forEach {
            if (!rootTypesNames.add(it)) {
                duplicates += it
            }
        }
        if (duplicates.isNotEmpty()) {
            throw IllegalArgumentException("Serial names of the following types are duplicated: $duplicates")
        }
    }

    /**
     * Appends the whole schema: the syntax line, the optional package, the options and then every type reachable
     * from [descriptors]. Types are emitted in breadth-first order, each at most once per short name.
     */
    private fun StringBuilder.generateProto2SchemaText(
        descriptors: List<SerialDescriptor>,
        packageName: String?,
        options: Map<String, String>
    ) {
        appendLine("""syntax = "proto2";""").appendLine()

        packageName?.let { append("package ").append(it).appendLine(';') }

        for ((optionName, optionValue) in options) {
            val safeOptionName = removeLineBreaks(optionName)
            // The value is written inside a double-quoted literal: besides removing line breaks, quotes and
            // backslashes must be escaped, otherwise a value could terminate the literal or inject declarations.
            val safeOptionValue = escapeStringLiteral(removeLineBreaks(optionValue))
            safeOptionName.checkIsValidFullIdentifier { "Invalid option name '$it'" }
            append("option ").append(safeOptionName).append(" = \"").append(safeOptionValue).appendLine("\";")
        }

        val generatedTypes = mutableSetOf<String>()
        val queue = ArrayDeque<TypeDefinition>()
        descriptors.mapTo(queue) { TypeDefinition(it) }

        while (queue.isNotEmpty()) {
            val type = queue.removeFirst()
            val descriptor = type.descriptor
            val name = descriptor.messageOrEnumName
            // Types are deduplicated by their short name only (see the class documentation).
            if (!generatedTypes.add(name)) {
                continue
            }

            appendLine()
            when {
                descriptor.isProtobufMessage -> queue.addAll(generateMessage(type))
                descriptor.isProtobufEnum -> generateEnum(type)
                else -> throw IllegalStateException(
                    "Unrecognized custom type with serial name "
                        + "'${descriptor.serialName}' and kind '${descriptor.kind}'"
                )
            }
        }
    }

    /**
     * Appends a `message` definition for [messageType] and returns the message/enum types referenced by its
     * fields, which still have to be emitted by the caller.
     *
     * @throws IllegalArgumentException if the message or a field name is not a valid identifier, or if two fields
     * share the same field number.
     */
    private fun StringBuilder.generateMessage(messageType: TypeDefinition): List<TypeDefinition> {
        val messageDescriptor = messageType.descriptor
        val messageName: String
        if (messageType.isSynthetic) {
            append("// This message was generated to support ").append(messageType.ability)
                .appendLine(" and does not present in Kotlin.")

            messageName = messageDescriptor.serialName
            if (messageType.containingMessageName != null) {
                append("// Containing message '").append(messageType.containingMessageName).append("', field '")
                    .append(messageType.fieldName).appendLine('\'')
            }
        } else {
            messageName = messageDescriptor.messageOrEnumName
            messageName.checkIsValidIdentifier {
                "Invalid name for the message in protobuf schema '$messageName'. " +
                    "Serial name of the class '${messageDescriptor.serialName}'"
            }
            val safeSerialName = removeLineBreaks(messageDescriptor.serialName)
            if (safeSerialName != messageName) {
                append("// serial name '").append(safeSerialName).appendLine('\'')
            }
        }

        append("message ").append(messageName).appendLine(" {")

        val usedNumbers: MutableSet<Int> = mutableSetOf()
        val nestedTypes = mutableListOf<TypeDefinition>()
        for (index in 0 until messageDescriptor.elementsCount) {
            val fieldName = messageDescriptor.getElementName(index)
            fieldName.checkIsValidIdentifier {
                "Invalid name of the field '$fieldName' in message '$messageName' for class with serial " +
                    "name '${messageDescriptor.serialName}'"
            }

            val fieldDescriptor = messageDescriptor.getElementDescriptor(index)

            // True for real lists and for "legacy" maps whose keys cannot be protobuf map keys.
            val isList = fieldDescriptor.isProtobufRepeated

            nestedTypes += when {
                fieldDescriptor.isProtobufNamedType -> generateNamedType(messageType, index)
                isList -> generateListType(messageType, index)
                fieldDescriptor.isProtobufMap -> generateMapType(messageType, index)
                else -> throw IllegalStateException(
                    "Unprocessed message field type with serial name " +
                        "'${fieldDescriptor.serialName}' and kind '${fieldDescriptor.kind}'"
                )
            }

            val annotations = messageDescriptor.getElementAnnotations(index)
            val number = annotations.filterIsInstance<ProtoNumber>().singleOrNull()?.number ?: (index + 1)
            if (!usedNumbers.add(number)) {
                throw IllegalArgumentException("Field number $number is repeated in the class with serial name ${messageDescriptor.serialName}")
            }

            append(' ').append(fieldName).append(" = ").append(number)

            val isPackRequested = annotations.filterIsInstance<ProtoPacked>().singleOrNull() != null
            // `packed` is only meaningful on repeated fields of packable element types. Legacy maps are emitted as
            // repeated synthetic messages (element 0 is the map *key*, not the repeated element) and are never
            // packed, so only genuine lists qualify.
            val isPacked = isPackRequested &&
                isList &&
                fieldDescriptor.kind == StructureKind.LIST &&
                fieldDescriptor.getElementDescriptor(0).isPackable
            appendLine(if (isPacked) " [packed=true];" else ";")
        }
        appendLine('}')

        return nestedTypes
    }

    /**
     * Appends the `optional`/`required` label and type of a scalar, enum, message or polymorphic field (without
     * name and number). Inline classes are unwrapped to their underlying type.
     * Returns the message/enum types the field refers to.
     */
    private fun StringBuilder.generateNamedType(messageType: TypeDefinition, index: Int): List<TypeDefinition> {
        val messageDescriptor = messageType.descriptor

        val fieldDescriptor = messageDescriptor.getElementDescriptor(index)
        var unwrappedFieldDescriptor = fieldDescriptor
        while (unwrappedFieldDescriptor.isInline) {
            unwrappedFieldDescriptor = unwrappedFieldDescriptor.getElementDescriptor(0)
        }

        val nestedTypes: List<TypeDefinition>
        val typeName: String = when {
            // Element 1 of a sealed-polymorphic descriptor holds the possible subclasses.
            messageDescriptor.isSealedPolymorphic && index == 1 -> {
                appendLine(" // decoded as message with one of these types:")
                nestedTypes = unwrappedFieldDescriptor.elementDescriptors.map { TypeDefinition(it) }
                nestedTypes.forEach { childType ->
                    append(" // message ").append(childType.descriptor.messageOrEnumName).append(", serial name '")
                        .append(removeLineBreaks(childType.descriptor.serialName)).appendLine('\'')
                }
                unwrappedFieldDescriptor.scalarTypeName()
            }
            unwrappedFieldDescriptor.isProtobufScalar -> {
                nestedTypes = emptyList()
                unwrappedFieldDescriptor.scalarTypeName(messageDescriptor.getElementAnnotations(index))
            }
            unwrappedFieldDescriptor.isOpenPolymorphic -> {
                nestedTypes = listOf(SyntheticPolymorphicType)
                SyntheticPolymorphicType.descriptor.serialName
            }
            else -> {
                // enum or regular message
                nestedTypes = listOf(TypeDefinition(unwrappedFieldDescriptor))
                unwrappedFieldDescriptor.messageOrEnumName
            }
        }

        if (messageDescriptor.isElementOptional(index)) {
            appendLine(" // WARNING: a default value decoded when value is missing")
        }
        val optional = fieldDescriptor.isNullable || messageDescriptor.isElementOptional(index)

        append(" ").append(if (optional) "optional " else "required ").append(typeName)

        return nestedTypes
    }

    /**
     * Appends a `map<K, V>` field type. Collection values are wrapped into a synthetic message.
     * Returns the value type when it is a message or enum, otherwise an empty list.
     */
    private fun StringBuilder.generateMapType(messageType: TypeDefinition, index: Int): List<TypeDefinition> {
        val messageDescriptor = messageType.descriptor
        val mapDescriptor = messageDescriptor.getElementDescriptor(index)
        val originalMapValueDescriptor = mapDescriptor.getElementDescriptor(1)
        val valueType = if (originalMapValueDescriptor.isProtobufCollection) {
            createNestedCollectionType(messageType, index, originalMapValueDescriptor, "nested collection in map value")
        } else {
            TypeDefinition(originalMapValueDescriptor)
        }
        val valueDescriptor = valueType.descriptor

        if (originalMapValueDescriptor.isNullable) {
            appendLine(" // WARNING: nullable map values can not be represented in protobuf")
        }
        generateCollectionAbsenceComment(messageDescriptor, mapDescriptor, index)

        val keyTypeName = mapDescriptor.getElementDescriptor(0).scalarTypeName(mapDescriptor.getElementAnnotations(0))
        val valueTypeName = valueDescriptor.protobufTypeName(mapDescriptor.getElementAnnotations(1))
        append(" map<").append(keyTypeName).append(", ").append(valueTypeName).append(">")

        return if (valueDescriptor.isProtobufMessageOrEnum) {
            listOf(valueType)
        } else {
            emptyList()
        }
    }

    /**
     * Appends a `repeated` field type for a list or a legacy map. Nested collections and legacy map entries are
     * wrapped into synthetic messages.
     * Returns the element type when it is a message or enum, otherwise an empty list.
     */
    private fun StringBuilder.generateListType(messageType: TypeDefinition, index: Int): List<TypeDefinition> {
        val messageDescriptor = messageType.descriptor
        val collectionDescriptor = messageDescriptor.getElementDescriptor(index)
        val originalElementDescriptor = collectionDescriptor.getElementDescriptor(0)
        val elementType = if (collectionDescriptor.kind == StructureKind.LIST) {
            if (originalElementDescriptor.isProtobufCollection) {
                createNestedCollectionType(messageType, index, originalElementDescriptor, "nested collection in list")
            } else {
                TypeDefinition(originalElementDescriptor)
            }
        } else {
            createLegacyMapType(messageType, index, "legacy map")
        }

        val elementDescriptor = elementType.descriptor

        if (elementDescriptor.isNullable) {
            appendLine(" // WARNING: nullable elements of collections can not be represented in protobuf")
        }
        generateCollectionAbsenceComment(messageDescriptor, collectionDescriptor, index)

        val typeName = elementDescriptor.protobufTypeName(messageDescriptor.getElementAnnotations(index))
        append(" repeated ").append(typeName)

        return if (elementDescriptor.isProtobufMessageOrEnum) {
            listOf(elementType)
        } else {
            emptyList()
        }
    }

    /**
     * Appends an `enum` definition for [enumType]. Element numbers come from [ProtoNumber] or default to the
     * element index.
     *
     * @throws IllegalArgumentException if a name is not a valid identifier or element numbers are duplicated.
     */
    private fun StringBuilder.generateEnum(enumType: TypeDefinition) {
        val enumDescriptor = enumType.descriptor
        val enumName = enumDescriptor.messageOrEnumName
        enumName.checkIsValidIdentifier {
            "Invalid name for the enum in protobuf schema '$enumName'. Serial name of the enum " +
                "class '${enumDescriptor.serialName}'"
        }
        val safeSerialName = removeLineBreaks(enumDescriptor.serialName)
        if (safeSerialName != enumName) {
            append("// serial name '").append(safeSerialName).appendLine('\'')
        }

        append("enum ").append(enumName).appendLine(" {")

        val usedNumbers: MutableSet<Int> = mutableSetOf()
        val duplicatedNumbers: MutableSet<Int> = mutableSetOf()
        enumDescriptor.elementDescriptors.forEachIndexed { index, element ->
            val elementName = element.protobufEnumElementName
            elementName.checkIsValidIdentifier {
                "The enum element name '$elementName' is invalid in the " +
                    "protobuf schema. Serial name of the enum class '${enumDescriptor.serialName}'"
            }

            val annotations = enumDescriptor.getElementAnnotations(index)
            val number = annotations.filterIsInstance<ProtoNumber>().singleOrNull()?.number ?: index
            if (!usedNumbers.add(number)) {
                duplicatedNumbers.add(number)
            }

            append(" ").append(elementName).append(" = ").append(number).appendLine(';')
        }
        if (duplicatedNumbers.isNotEmpty()) {
            throw IllegalArgumentException(
                "The class with serial name ${enumDescriptor.serialName} has duplicate " +
                    "elements with numbers $duplicatedNumbers"
            )
        }

        appendLine('}')
    }

    private val SerialDescriptor.isOpenPolymorphic: Boolean
        get() = kind == PolymorphicKind.OPEN

    private val SerialDescriptor.isSealedPolymorphic: Boolean
        get() = kind == PolymorphicKind.SEALED

    /** A field that is written with `optional`/`required` rather than `repeated`/`map`. */
    private val SerialDescriptor.isProtobufNamedType: Boolean
        get() = isProtobufMessageOrEnum || isProtobufScalar

    /** Primitives, byte lists (written as `bytes`) and contextual values (also written as `bytes`). */
    private val SerialDescriptor.isProtobufScalar: Boolean
        get() = (kind is PrimitiveKind)
            || (kind == StructureKind.LIST && getElementDescriptor(0).kind == PrimitiveKind.BYTE)
            || kind == SerialKind.CONTEXTUAL

    private val SerialDescriptor.isProtobufMessageOrEnum: Boolean
        get() = isProtobufMessage || isProtobufEnum

    private val SerialDescriptor.isProtobufMessage: Boolean
        get() = kind == StructureKind.CLASS || kind == StructureKind.OBJECT || kind == PolymorphicKind.SEALED || kind == PolymorphicKind.OPEN

    private val SerialDescriptor.isProtobufCollection: Boolean
        get() = isProtobufRepeated || isProtobufMap

    /** Non-byte lists, and maps whose key type is not allowed as a protobuf map key ("legacy" maps). */
    private val SerialDescriptor.isProtobufRepeated: Boolean
        get() = (kind == StructureKind.LIST && getElementDescriptor(0).kind != PrimitiveKind.BYTE)
            || (kind == StructureKind.MAP && !getElementDescriptor(0).isValidMapKey)

    private val SerialDescriptor.isProtobufMap: Boolean
        get() = kind == StructureKind.MAP && getElementDescriptor(0).isValidMapKey

    private val SerialDescriptor.isProtobufEnum: Boolean
        get() = kind == SerialKind.ENUM

    private val SerialDescriptor.isValidMapKey: Boolean
        get() = kind == PrimitiveKind.INT || kind == PrimitiveKind.LONG || kind == PrimitiveKind.BOOLEAN || kind == PrimitiveKind.STRING

    /** Serial name without the package part and without a trailing `'?'`. */
    private val SerialDescriptor.messageOrEnumName: String
        get() = (serialName.substringAfterLast('.', serialName)).removeSuffix("?")

    private fun SerialDescriptor.protobufTypeName(annotations: List<Annotation> = emptyList()): String {
        return if (isProtobufScalar) {
            scalarTypeName(annotations)
        } else {
            messageOrEnumName
        }
    }

    private val SerialDescriptor.protobufEnumElementName: String
        get() = serialName.substringAfterLast('.', serialName)

    /**
     * Maps a scalar descriptor to its protobuf type name, honoring a [ProtoType] annotation for integer kinds.
     * Must only be called for descriptors for which [isProtobufScalar] holds; any other kind fails the
     * `as PrimitiveKind` cast.
     */
    private fun SerialDescriptor.scalarTypeName(annotations: List<Annotation> = emptyList()): String {
        val integerType = annotations.filterIsInstance<ProtoType>().firstOrNull()?.type ?: ProtoIntegerType.DEFAULT

        if (kind == SerialKind.CONTEXTUAL) {
            return "bytes"
        }

        if (kind == StructureKind.LIST && getElementDescriptor(0).kind == PrimitiveKind.BYTE) {
            return "bytes"
        }

        return when (kind as PrimitiveKind) {
            PrimitiveKind.BOOLEAN -> "bool"
            PrimitiveKind.BYTE, PrimitiveKind.CHAR, PrimitiveKind.SHORT, PrimitiveKind.INT ->
                when (integerType) {
                    ProtoIntegerType.DEFAULT -> "int32"
                    ProtoIntegerType.SIGNED -> "sint32"
                    ProtoIntegerType.FIXED -> "fixed32"
                }
            PrimitiveKind.LONG ->
                when (integerType) {
                    ProtoIntegerType.DEFAULT -> "int64"
                    ProtoIntegerType.SIGNED -> "sint64"
                    ProtoIntegerType.FIXED -> "fixed64"
                }
            PrimitiveKind.FLOAT -> "float"
            PrimitiveKind.DOUBLE -> "double"
            PrimitiveKind.STRING -> "string"
        }
    }

    /**
     * A type queued for emission.
     *
     * @property descriptor descriptor of the message or enum.
     * @property isSynthetic `true` for wrapper messages that have no Kotlin counterpart.
     * @property ability short description of why a synthetic message exists (used in the generated comment).
     * @property containingMessageName for synthetic wrappers: the outermost user message that needed it.
     * @property fieldName for synthetic wrappers: the field of [containingMessageName] that needed it.
     */
    @SuppressAnimalSniffer // Boolean.hashCode(boolean) in compiler-generated hashCode implementation
    private data class TypeDefinition(
        val descriptor: SerialDescriptor,
        val isSynthetic: Boolean = false,
        val ability: String? = null,
        val containingMessageName: String? = null,
        val fieldName: String? = null
    )

    /** Message used for open-polymorphic fields: the type name plus the serialized value as bytes. */
    private val SyntheticPolymorphicType = TypeDefinition(
        buildClassSerialDescriptor("KotlinxSerializationPolymorphic") {
            element("type", PrimitiveSerialDescriptor("typeDescriptor", PrimitiveKind.STRING))
            element("value", buildSerialDescriptor("valueDescriptor", StructureKind.LIST) {
                element("0", Byte.serializer().descriptor)
            })
        },
        true,
        "polymorphic types"
    )

    /** Delegates everything to [original] except [isNullable], which is always `false`. */
    private class NotNullSerialDescriptor(val original: SerialDescriptor) : SerialDescriptor by original {
        override val isNullable = false
    }

    private val SerialDescriptor.notNull get() = NotNullSerialDescriptor(this)

    /**
     * Appends a warning for collection fields whose absence on the wire is ambiguous. Optional fields decode
     * to their default value. Non-optional nullable fields decode to an empty collection.
     */
    private fun StringBuilder.generateCollectionAbsenceComment(
        messageDescriptor: SerialDescriptor,
        collectionDescriptor: SerialDescriptor,
        index: Int
    ) {
        when {
            messageDescriptor.isElementOptional(index) ->
                appendLine(" // WARNING: a default value decoded when value is missing")
            collectionDescriptor.isNullable ->
                appendLine(" // WARNING: an empty collection decoded when a value is missing")
        }
    }

    /** Creates a synthetic `key`/`value` message used as the element of a legacy (repeated) map field. */
    private fun createLegacyMapType(
        messageType: TypeDefinition,
        index: Int,
        description: String
    ): TypeDefinition {
        val messageDescriptor = messageType.descriptor
        val fieldDescriptor = messageDescriptor.getElementDescriptor(index)
        val fieldName = messageDescriptor.getElementName(index)
        val messageName = messageDescriptor.messageOrEnumName

        val wrapperName = "${messageName}_${fieldName}"
        val wrapperDescriptor = buildClassSerialDescriptor(wrapperName) {
            element("key", fieldDescriptor.getElementDescriptor(0).notNull)
            element("value", fieldDescriptor.getElementDescriptor(1).notNull)
        }

        return TypeDefinition(
            wrapperDescriptor,
            true,
            description,
            messageType.containingMessageName ?: messageName,
            messageType.fieldName ?: fieldName
        )
    }

    /** Creates a synthetic single-field (`value`) message that wraps a collection nested in another collection. */
    private fun createNestedCollectionType(
        messageType: TypeDefinition,
        index: Int,
        elementDescriptor: SerialDescriptor,
        description: String
    ): TypeDefinition {
        val messageDescriptor = messageType.descriptor
        val fieldName = messageDescriptor.getElementName(index)
        val messageName = messageDescriptor.messageOrEnumName

        val wrapperName = "${messageName}_${fieldName}"
        val wrapperDescriptor = buildClassSerialDescriptor(wrapperName) {
            element("value", elementDescriptor.notNull)
        }

        return TypeDefinition(
            wrapperDescriptor,
            true,
            description,
            messageType.containingMessageName ?: messageName,
            messageType.fieldName ?: fieldName
        )
    }

    /** Replaces `'\n'` and `'\r'` with spaces so user-provided text stays on a single line of the schema. */
    private fun removeLineBreaks(text: String): String {
        return text.replace('\n', ' ').replace('\r', ' ')
    }

    /**
     * Escapes backslashes and double quotes so that [text] can be placed inside a double-quoted proto string
     * literal. Backslashes are escaped first so the escapes added for quotes are not doubled.
     */
    private fun escapeStringLiteral(text: String): String {
        return text.replace("\\", "\\\\").replace("\"", "\\\"")
    }

    private val IDENTIFIER_REGEX = Regex("[A-Za-z][A-Za-z0-9_]*")

    /** Throws [IllegalArgumentException] unless every dot-separated segment of this string is a valid identifier. */
    private fun String.checkIsValidFullIdentifier(messageSupplier: (String) -> String) {
        if (split('.').any { !it.matches(IDENTIFIER_REGEX) }) {
            throw IllegalArgumentException(messageSupplier.invoke(this))
        }
    }

    /** Throws [IllegalArgumentException] unless this string is a single valid identifier. */
    private fun String.checkIsValidIdentifier(messageSupplier: () -> String) {
        if (!matches(IDENTIFIER_REGEX)) {
            throw IllegalArgumentException(messageSupplier.invoke())
        }
    }
}
548