<lambda>null1 @file:OptIn(ExperimentalSerializationApi::class)
2 
3 package kotlinx.serialization.protobuf.schema
4 
5 import kotlinx.serialization.*
6 import kotlinx.serialization.builtins.*
7 import kotlinx.serialization.descriptors.*
8 import kotlinx.serialization.protobuf.*
9 import kotlinx.serialization.protobuf.internal.*
10 
11 /**
12  * Experimental generator of ProtoBuf schema that is compatible with [serializable][Serializable] Kotlin classes
13  * and data encoded and decoded by [ProtoBuf] format.
14  *
15  * The schema is generated based on provided [SerialDescriptor] and is compatible with proto2 schema definition.
 * An arbitrary Kotlin class represents a much wider object domain than the ProtoBuf specification, thus the schema generator
17  * has the following list of restrictions:
18  *
19  *  * Serial name of the class and all its fields should be a valid Proto [identifier](https://developers.google.com/protocol-buffers/docs/reference/proto2-spec)
20  *  * Nullable values are allowed only for Kotlin [nullable][SerialDescriptor.isNullable] types, but not [optional][SerialDescriptor.isElementOptional]
21  *    in order to properly distinguish "default" and "absent" values.
22  *  * The name of the type without the package directive uniquely identifies the proto message type and two or more fields with the same serial name
23  *    are considered to have the same type. Schema generator allows to specify a separate package directive for the pack of classes in order
24  *    to mitigate this limitation.
25  *  * Nested collections, e.g. `List<List<Int>>` are represented using the artificial wrapper message in order to distinguish
26  *    repeated fields boundaries.
27  *  * Default Kotlin values are not representable in proto schema. A special commentary is generated for properties with default values.
28  *  * Empty nullable collections are not supported by the generated schema and will be prohibited in [ProtoBuf] as well
29  *    due to their ambiguous nature.
30  *
31  * Temporary restrictions:
 *  * [Contextual] data is represented as the `bytes` type
 *  * [Polymorphic] data is represented as an artificial `KotlinxSerializationPolymorphic` message.
34  *
35  * Other types are mapped according to their specification: primitives as primitives, lists as 'repeated' fields and
36  * maps as 'repeated' map entries.
37  *
38  * The name of messages and enums is extracted from [SerialDescriptor.serialName] in [SerialDescriptor] without the package directive,
39  * as substring after the last dot character, the `'?'` character is also removed if it is present at the end of the string.
40  */
41 @ExperimentalSerializationApi
42 public object ProtoBufSchemaGenerator {
43 
44     /**
45      * Generate text of protocol buffers schema version 2 for the given [rootDescriptor].
46      * The resulting schema will contain all types referred by [rootDescriptor].
47      *
48      * [packageName] define common protobuf package for all messages and enum in the schema, it may contain `'a'`..`'z'`
49      * letters in upper and lower case, decimal digits, `'.'` or `'_'` chars, but must be started only by a letter and
50      * not finished by a dot.
51      *
52      * [options] define values for protobuf options. Option value (map value) is an any string, option name (map key)
53      * should be the same format as [packageName].
54      *
55      * The method throws [IllegalArgumentException] if any of the restrictions imposed by [ProtoBufSchemaGenerator] is violated.
56      */
57     @ExperimentalSerializationApi
58     public fun generateSchemaText(
59         rootDescriptor: SerialDescriptor,
60         packageName: String? = null,
61         options: Map<String, String> = emptyMap()
62     ): String = generateSchemaText(listOf(rootDescriptor), packageName, options)
63 
64     /**
65      * Generate text of protocol buffers schema version 2 for the given serializable [descriptors].
66      * [packageName] define common protobuf package for all messages and enum in the schema, it may contain `'a'`..`'z'`
67      * letters in upper and lower case, decimal digits, `'.'` or `'_'` chars, but started only from a letter and
68      * not finished by dot.
69      *
70      * [options] define values for protobuf options. Option value (map value) is an any string, option name (map key)
71      * should be the same format as [packageName].
72      *
73      * The method throws [IllegalArgumentException] if any of the restrictions imposed by [ProtoBufSchemaGenerator] is violated.
74      */
75     @ExperimentalSerializationApi
76     public fun generateSchemaText(
77         descriptors: List<SerialDescriptor>,
78         packageName: String? = null,
79         options: Map<String, String> = emptyMap()
80     ): String {
81         packageName?.let { p -> p.checkIsValidFullIdentifier { "Incorrect protobuf package name '$it'" } }
82         checkDoubles(descriptors)
83         val builder = StringBuilder()
84         builder.generateProto2SchemaText(descriptors, packageName, options)
85         return builder.toString()
86     }
87 
88     private fun checkDoubles(descriptors: List<SerialDescriptor>) {
89         val rootTypesNames = mutableSetOf<String>()
90         val duplicates = mutableListOf<String>()
91 
92         descriptors.map { it.messageOrEnumName }.forEach {
93             if (!rootTypesNames.add(it)) {
94                 duplicates += it
95             }
96         }
97         if (duplicates.isNotEmpty()) {
98             throw IllegalArgumentException("Serial names of the following types are duplicated: $duplicates")
99         }
100     }
101 
102     private fun StringBuilder.generateProto2SchemaText(
103         descriptors: List<SerialDescriptor>,
104         packageName: String?,
105         options: Map<String, String>
106     ) {
107         appendLine("""syntax = "proto2";""").appendLine()
108 
109         packageName?.let { append("package ").append(it).appendLine(';') }
110 
111         for ((optionName, optionValue) in options) {
112             val safeOptionName = removeLineBreaks(optionName)
113             val safeOptionValue = removeLineBreaks(optionValue)
114             safeOptionName.checkIsValidFullIdentifier { "Invalid option name '$it'" }
115             append("option ").append(safeOptionName).append(" = \"").append(safeOptionValue).appendLine("\";")
116         }
117 
118         val generatedTypes = mutableSetOf<String>()
119         val queue = ArrayDeque<TypeDefinition>()
120         descriptors.map { TypeDefinition(it) }.forEach { queue.add(it) }
121 
122         while (queue.isNotEmpty()) {
123             val type = queue.removeFirst()
124             val descriptor = type.descriptor
125             val name = descriptor.messageOrEnumName
126             if (!generatedTypes.add(name)) {
127                 continue
128             }
129 
130             appendLine()
131             when {
132                 descriptor.isProtobufMessage -> queue.addAll(generateMessage(type))
133                 descriptor.isProtobufEnum -> generateEnum(type)
134                 else -> throw IllegalStateException(
135                     "Unrecognized custom type with serial name "
136                             + "'${descriptor.serialName}' and kind '${descriptor.kind}'"
137                 )
138             }
139         }
140     }
141 
142     private fun StringBuilder.generateMessage(messageType: TypeDefinition): List<TypeDefinition> {
143         val messageDescriptor = messageType.descriptor
144         val messageName: String
145         if (messageType.isSynthetic) {
146             append("// This message was generated to support ").append(messageType.ability)
147                 .appendLine(" and does not present in Kotlin.")
148 
149             messageName = messageDescriptor.serialName
150             if (messageType.containingMessageName != null) {
151                 append("// Containing message '").append(messageType.containingMessageName).append("', field '")
152                     .append(messageType.fieldName).appendLine('\'')
153             }
154         } else {
155             messageName = messageDescriptor.messageOrEnumName
156             messageName.checkIsValidIdentifier {
157                 "Invalid name for the message in protobuf schema '$messageName'. " +
158                         "Serial name of the class '${messageDescriptor.serialName}'"
159             }
160             val safeSerialName = removeLineBreaks(messageDescriptor.serialName)
161             if (safeSerialName != messageName) {
162                 append("// serial name '").append(safeSerialName).appendLine('\'')
163             }
164         }
165 
166         append("message ").append(messageName).appendLine(" {")
167 
168         val usedNumbers: MutableSet<Int> = mutableSetOf()
169         val nestedTypes = mutableListOf<TypeDefinition>()
170         for (index in 0 until messageDescriptor.elementsCount) {
171             val fieldName = messageDescriptor.getElementName(index)
172             fieldName.checkIsValidIdentifier {
173                 "Invalid name of the field '$fieldName' in message '$messageName' for class with serial " +
174                         "name '${messageDescriptor.serialName}'"
175             }
176 
177             val fieldDescriptor = messageDescriptor.getElementDescriptor(index)
178 
179             val isList = fieldDescriptor.isProtobufRepeated
180 
181             nestedTypes += when {
182                 fieldDescriptor.isProtobufNamedType -> generateNamedType(messageType, index)
183                 isList -> generateListType(messageType, index)
184                 fieldDescriptor.isProtobufMap -> generateMapType(messageType, index)
185                 else -> throw IllegalStateException(
186                     "Unprocessed message field type with serial name " +
187                             "'${fieldDescriptor.serialName}' and kind '${fieldDescriptor.kind}'"
188                 )
189             }
190 
191 
192             val annotations = messageDescriptor.getElementAnnotations(index)
193             val number = annotations.filterIsInstance<ProtoNumber>().singleOrNull()?.number ?: (index + 1)
194             if (!usedNumbers.add(number)) {
195                 throw IllegalArgumentException("Field number $number is repeated in the class with serial name ${messageDescriptor.serialName}")
196             }
197 
198             append(' ').append(fieldName).append(" = ").append(number)
199 
200             val isPackRequested = annotations.filterIsInstance<ProtoPacked>().singleOrNull() != null
201 
202             when {
203                 !isPackRequested ||
204                 !isList || // ignore as packed only meaningful on repeated types
205                 !fieldDescriptor.getElementDescriptor(0).isPackable // Ignore if the type is not allowed to be packed
206                      -> appendLine(';')
207                 else -> appendLine(" [packed=true];")
208             }
209         }
210         appendLine('}')
211 
212         return nestedTypes
213     }
214 
215     private fun StringBuilder.generateNamedType(messageType: TypeDefinition, index: Int): List<TypeDefinition> {
216         val messageDescriptor = messageType.descriptor
217 
218         val fieldDescriptor = messageDescriptor.getElementDescriptor(index)
219         var unwrappedFieldDescriptor = fieldDescriptor
220         while (unwrappedFieldDescriptor.isInline) {
221             unwrappedFieldDescriptor = unwrappedFieldDescriptor.getElementDescriptor(0)
222         }
223 
224         val nestedTypes: List<TypeDefinition>
225         val typeName: String = when {
226             messageDescriptor.isSealedPolymorphic && index == 1 -> {
227                 appendLine("  // decoded as message with one of these types:")
228                 nestedTypes = unwrappedFieldDescriptor.elementDescriptors.map { TypeDefinition(it) }.toList()
229                 nestedTypes.forEachIndexed { _, childType ->
230                     append("  //   message ").append(childType.descriptor.messageOrEnumName).append(", serial name '")
231                         .append(removeLineBreaks(childType.descriptor.serialName)).appendLine('\'')
232                 }
233                 unwrappedFieldDescriptor.scalarTypeName()
234             }
235             unwrappedFieldDescriptor.isProtobufScalar -> {
236                 nestedTypes = emptyList()
237                 unwrappedFieldDescriptor.scalarTypeName(messageDescriptor.getElementAnnotations(index))
238             }
239             unwrappedFieldDescriptor.isOpenPolymorphic -> {
240                 nestedTypes = listOf(SyntheticPolymorphicType)
241                 SyntheticPolymorphicType.descriptor.serialName
242             }
243             else -> {
244                 // enum or regular message
245                 nestedTypes = listOf(TypeDefinition(unwrappedFieldDescriptor))
246                 unwrappedFieldDescriptor.messageOrEnumName
247             }
248         }
249 
250         if (messageDescriptor.isElementOptional(index)) {
251             appendLine("  // WARNING: a default value decoded when value is missing")
252         }
253         val optional = fieldDescriptor.isNullable || messageDescriptor.isElementOptional(index)
254 
255         append("  ").append(if (optional) "optional " else "required ").append(typeName)
256 
257         return nestedTypes
258     }
259 
260     private fun StringBuilder.generateMapType(messageType: TypeDefinition, index: Int): List<TypeDefinition> {
261         val messageDescriptor = messageType.descriptor
262         val mapDescriptor = messageDescriptor.getElementDescriptor(index)
263         val originalMapValueDescriptor = mapDescriptor.getElementDescriptor(1)
264         val valueType = if (originalMapValueDescriptor.isProtobufCollection) {
265             createNestedCollectionType(messageType, index, originalMapValueDescriptor, "nested collection in map value")
266         } else {
267             TypeDefinition(originalMapValueDescriptor)
268         }
269         val valueDescriptor = valueType.descriptor
270 
271         if (originalMapValueDescriptor.isNullable) {
272             appendLine("  // WARNING: nullable map values can not be represented in protobuf")
273         }
274         generateCollectionAbsenceComment(messageDescriptor, mapDescriptor, index)
275 
276         val keyTypeName = mapDescriptor.getElementDescriptor(0).scalarTypeName(mapDescriptor.getElementAnnotations(0))
277         val valueTypeName = valueDescriptor.protobufTypeName(mapDescriptor.getElementAnnotations(1))
278         append("  map<").append(keyTypeName).append(", ").append(valueTypeName).append(">")
279 
280         return if (valueDescriptor.isProtobufMessageOrEnum) {
281             listOf(valueType)
282         } else {
283             emptyList()
284         }
285     }
286 
287     private fun StringBuilder.generateListType(messageType: TypeDefinition, index: Int): List<TypeDefinition> {
288         val messageDescriptor = messageType.descriptor
289         val collectionDescriptor = messageDescriptor.getElementDescriptor(index)
290         val originalElementDescriptor = collectionDescriptor.getElementDescriptor(0)
291         val elementType = if (collectionDescriptor.kind == StructureKind.LIST) {
292             if (originalElementDescriptor.isProtobufCollection) {
293                 createNestedCollectionType(messageType, index, originalElementDescriptor, "nested collection in list")
294             } else {
295                 TypeDefinition(originalElementDescriptor)
296             }
297         } else {
298             createLegacyMapType(messageType, index, "legacy map")
299         }
300 
301         val elementDescriptor = elementType.descriptor
302 
303         if (elementDescriptor.isNullable) {
304             appendLine("  // WARNING: nullable elements of collections can not be represented in protobuf")
305         }
306         generateCollectionAbsenceComment(messageDescriptor, collectionDescriptor, index)
307 
308         val typeName = elementDescriptor.protobufTypeName(messageDescriptor.getElementAnnotations(index))
309         append("  repeated ").append(typeName)
310 
311         return if (elementDescriptor.isProtobufMessageOrEnum) {
312             listOf(elementType)
313         } else {
314             emptyList()
315         }
316     }
317 
318     private fun StringBuilder.generateEnum(enumType: TypeDefinition) {
319         val enumDescriptor = enumType.descriptor
320         val enumName = enumDescriptor.messageOrEnumName
321         enumName.checkIsValidIdentifier {
322             "Invalid name for the enum in protobuf schema '$enumName'. Serial name of the enum " +
323                     "class '${enumDescriptor.serialName}'"
324         }
325         val safeSerialName = removeLineBreaks(enumDescriptor.serialName)
326         if (safeSerialName != enumName) {
327             append("// serial name '").append(safeSerialName).appendLine('\'')
328         }
329 
330         append("enum ").append(enumName).appendLine(" {")
331 
332         val usedNumbers: MutableSet<Int> = mutableSetOf()
333         val duplicatedNumbers: MutableSet<Int> = mutableSetOf()
334         enumDescriptor.elementDescriptors.forEachIndexed { index, element ->
335             val elementName = element.protobufEnumElementName
336             elementName.checkIsValidIdentifier {
337                 "The enum element name '$elementName' is invalid in the " +
338                         "protobuf schema. Serial name of the enum class '${enumDescriptor.serialName}'"
339             }
340 
341             val annotations = enumDescriptor.getElementAnnotations(index)
342             val number = annotations.filterIsInstance<ProtoNumber>().singleOrNull()?.number ?: index
343             if (!usedNumbers.add(number)) {
344                 duplicatedNumbers.add(number)
345             }
346 
347             append("  ").append(elementName).append(" = ").append(number).appendLine(';')
348         }
349         if (duplicatedNumbers.isNotEmpty()) {
350             throw IllegalArgumentException(
351                 "The class with serial name ${enumDescriptor.serialName} has duplicate " +
352                     "elements with numbers $duplicatedNumbers"
353             )
354         }
355 
356         appendLine('}')
357     }
358 
359     private val SerialDescriptor.isOpenPolymorphic: Boolean
360         get() = kind == PolymorphicKind.OPEN
361 
362     private val SerialDescriptor.isSealedPolymorphic: Boolean
363         get() = kind == PolymorphicKind.SEALED
364 
365     private val SerialDescriptor.isProtobufNamedType: Boolean
366         get() = isProtobufMessageOrEnum || isProtobufScalar
367 
368     private val SerialDescriptor.isProtobufScalar: Boolean
369         get() = (kind is PrimitiveKind)
370                 || (kind is StructureKind.LIST && getElementDescriptor(0).kind === PrimitiveKind.BYTE)
371                 || kind == SerialKind.CONTEXTUAL
372 
373     private val SerialDescriptor.isProtobufMessageOrEnum: Boolean
374         get() = isProtobufMessage || isProtobufEnum
375 
376     private val SerialDescriptor.isProtobufMessage: Boolean
377         get() = kind == StructureKind.CLASS || kind == StructureKind.OBJECT || kind == PolymorphicKind.SEALED || kind == PolymorphicKind.OPEN
378 
379     private val SerialDescriptor.isProtobufCollection: Boolean
380         get() = isProtobufRepeated || isProtobufMap
381 
382     private val SerialDescriptor.isProtobufRepeated: Boolean
383         get() = (kind == StructureKind.LIST && getElementDescriptor(0).kind != PrimitiveKind.BYTE)
384                 || (kind == StructureKind.MAP && !getElementDescriptor(0).isValidMapKey)
385 
386     private val SerialDescriptor.isProtobufMap: Boolean
387         get() = kind == StructureKind.MAP && getElementDescriptor(0).isValidMapKey
388 
389     private val SerialDescriptor.isProtobufEnum: Boolean
390         get() = kind == SerialKind.ENUM
391 
392     private val SerialDescriptor.isValidMapKey: Boolean
393         get() = kind == PrimitiveKind.INT || kind == PrimitiveKind.LONG || kind == PrimitiveKind.BOOLEAN || kind == PrimitiveKind.STRING
394 
395 
396     private val SerialDescriptor.messageOrEnumName: String
397         get() = (serialName.substringAfterLast('.', serialName)).removeSuffix("?")
398 
399     private fun SerialDescriptor.protobufTypeName(annotations: List<Annotation> = emptyList()): String {
400         return if (isProtobufScalar) {
401             scalarTypeName(annotations)
402         } else {
403             messageOrEnumName
404         }
405     }
406 
407     private val SerialDescriptor.protobufEnumElementName: String
408         get() = serialName.substringAfterLast('.', serialName)
409 
410     private fun SerialDescriptor.scalarTypeName(annotations: List<Annotation> = emptyList()): String {
411         val integerType = annotations.filterIsInstance<ProtoType>().firstOrNull()?.type ?: ProtoIntegerType.DEFAULT
412 
413         if (kind == SerialKind.CONTEXTUAL) {
414             return "bytes"
415         }
416 
417         if (kind is StructureKind.LIST && getElementDescriptor(0).kind == PrimitiveKind.BYTE) {
418             return "bytes"
419         }
420 
421         return when (kind as PrimitiveKind) {
422             PrimitiveKind.BOOLEAN -> "bool"
423             PrimitiveKind.BYTE, PrimitiveKind.CHAR, PrimitiveKind.SHORT, PrimitiveKind.INT ->
424                 when (integerType) {
425                     ProtoIntegerType.DEFAULT -> "int32"
426                     ProtoIntegerType.SIGNED -> "sint32"
427                     ProtoIntegerType.FIXED -> "fixed32"
428                 }
429             PrimitiveKind.LONG ->
430                 when (integerType) {
431                     ProtoIntegerType.DEFAULT -> "int64"
432                     ProtoIntegerType.SIGNED -> "sint64"
433                     ProtoIntegerType.FIXED -> "fixed64"
434                 }
435             PrimitiveKind.FLOAT -> "float"
436             PrimitiveKind.DOUBLE -> "double"
437             PrimitiveKind.STRING -> "string"
438         }
439     }
440 
441     @SuppressAnimalSniffer // Boolean.hashCode(boolean) in compiler-generated hashCode implementation
442     private data class TypeDefinition(
443         val descriptor: SerialDescriptor,
444         val isSynthetic: Boolean = false,
445         val ability: String? = null,
446         val containingMessageName: String? = null,
447         val fieldName: String? = null
448     )
449 
450     private val SyntheticPolymorphicType = TypeDefinition(
451         buildClassSerialDescriptor("KotlinxSerializationPolymorphic") {
452             element("type", PrimitiveSerialDescriptor("typeDescriptor", PrimitiveKind.STRING))
453             element("value", buildSerialDescriptor("valueDescriptor", StructureKind.LIST) {
454                 element("0", Byte.serializer().descriptor)
455             })
456         },
457         true,
458         "polymorphic types"
459     )
460 
461     private class NotNullSerialDescriptor(val original: SerialDescriptor) : SerialDescriptor by original {
462         override val isNullable = false
463     }
464 
465     private val SerialDescriptor.notNull get() = NotNullSerialDescriptor(this)
466 
467     private fun StringBuilder.generateCollectionAbsenceComment(
468         messageDescriptor: SerialDescriptor,
469         collectionDescriptor: SerialDescriptor,
470         index: Int
471     ) {
472         if (!collectionDescriptor.isNullable && messageDescriptor.isElementOptional(index)) {
473             appendLine("  // WARNING: a default value decoded when value is missing")
474         } else if (collectionDescriptor.isNullable && !messageDescriptor.isElementOptional(index)) {
475             appendLine("  // WARNING: an empty collection decoded when a value is missing")
476         } else if (collectionDescriptor.isNullable && messageDescriptor.isElementOptional(index)) {
477             appendLine("  // WARNING: a default value decoded when value is missing")
478         }
479     }
480 
481     private fun createLegacyMapType(
482         messageType: TypeDefinition,
483         index: Int,
484         description: String
485     ): TypeDefinition {
486         val messageDescriptor = messageType.descriptor
487         val fieldDescriptor = messageDescriptor.getElementDescriptor(index)
488         val fieldName = messageDescriptor.getElementName(index)
489         val messageName = messageDescriptor.messageOrEnumName
490 
491         val wrapperName = "${messageName}_${fieldName}"
492         val wrapperDescriptor = buildClassSerialDescriptor(wrapperName) {
493             element("key", fieldDescriptor.getElementDescriptor(0).notNull)
494             element("value", fieldDescriptor.getElementDescriptor(1).notNull)
495         }
496 
497         return TypeDefinition(
498             wrapperDescriptor,
499             true,
500             description,
501             messageType.containingMessageName ?: messageName,
502             messageType.fieldName ?: fieldName
503         )
504     }
505 
506     private fun createNestedCollectionType(
507         messageType: TypeDefinition,
508         index: Int,
509         elementDescriptor: SerialDescriptor,
510         description: String
511     ): TypeDefinition {
512         val messageDescriptor = messageType.descriptor
513         val fieldName = messageDescriptor.getElementName(index)
514         val messageName = messageDescriptor.messageOrEnumName
515 
516         val wrapperName = "${messageName}_${fieldName}"
517         val wrapperDescriptor = buildClassSerialDescriptor(wrapperName) {
518             element("value", elementDescriptor.notNull)
519         }
520 
521         return TypeDefinition(
522             wrapperDescriptor,
523             true,
524             description,
525             messageType.containingMessageName ?: messageName,
526             messageType.fieldName ?: fieldName
527         )
528     }
529 
530     private fun removeLineBreaks(text: String): String {
531         return text.replace('\n', ' ').replace('\r', ' ')
532     }
533 
534     private val IDENTIFIER_REGEX = Regex("[A-Za-z][A-Za-z0-9_]*")
535 
536     private fun String.checkIsValidFullIdentifier(messageSupplier: (String) -> String) {
537         if (split('.').any { !it.matches(IDENTIFIER_REGEX) }) {
538             throw IllegalArgumentException(messageSupplier.invoke(this))
539         }
540     }
541 
542     private fun String.checkIsValidIdentifier(messageSupplier: () -> String) {
543         if (!matches(IDENTIFIER_REGEX)) {
544             throw IllegalArgumentException(messageSupplier.invoke())
545         }
546     }
547 }
548