xref: /aosp_15_r20/external/protobuf/csharp/src/Google.Protobuf/ExtensionSet.cs (revision 1b3f573f81763fcece89efc2b6a5209149e44ab8)
1 #region Copyright notice and license
2 // Protocol Buffers - Google's data interchange format
3 // Copyright 2008 Google Inc.  All rights reserved.
4 // https://developers.google.com/protocol-buffers/
5 //
6 // Redistribution and use in source and binary forms, with or without
7 // modification, are permitted provided that the following conditions are
8 // met:
9 //
10 //     * Redistributions of source code must retain the above copyright
11 // notice, this list of conditions and the following disclaimer.
12 //     * Redistributions in binary form must reproduce the above
13 // copyright notice, this list of conditions and the following disclaimer
14 // in the documentation and/or other materials provided with the
15 // distribution.
16 //     * Neither the name of Google Inc. nor the names of its
17 // contributors may be used to endorse or promote products derived from
18 // this software without specific prior written permission.
19 //
20 // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
21 // "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
22 // LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
23 // A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
24 // OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
25 // SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
26 // LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
27 // DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
28 // THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
29 // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
30 // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
31 #endregion
32 
33 using Google.Protobuf.Collections;
34 using System;
35 using System.Collections.Generic;
36 using System.Linq;
37 using System.Reflection;
38 using System.Security;
39 
40 namespace Google.Protobuf
41 {
42     /// <summary>
43     /// Methods for managing <see cref="ExtensionSet{TTarget}"/>s with null checking.
44     ///
45     /// Most users will not use this class directly and its API is experimental and subject to change.
46     /// </summary>
47     public static class ExtensionSet
48     {
49         private static bool TryGetValue<TTarget>(ref ExtensionSet<TTarget> set, Extension extension, out IExtensionValue value) where TTarget : IExtendableMessage<TTarget>
50         {
51             if (set == null)
52             {
53                 value = null;
54                 return false;
55             }
56             return set.ValuesByNumber.TryGetValue(extension.FieldNumber, out value);
57         }
58 
59         /// <summary>
60         /// Gets the value of the specified extension
61         /// </summary>
62         public static TValue Get<TTarget, TValue>(ref ExtensionSet<TTarget> set, Extension<TTarget, TValue> extension) where TTarget : IExtendableMessage<TTarget>
63         {
64             IExtensionValue value;
65             if (TryGetValue(ref set, extension, out value))
66             {
67                 // The stored ExtensionValue can be a different type to what is being requested.
68                 // This happens when the same extension proto is compiled in different assemblies.
69                 // To allow consuming assemblies to still get the value when the TValue type is
70                 // different, this get method:
71                 // 1. Attempts to cast the value to the expected ExtensionValue<TValue>.
72                 //    This is the usual case. It is used first because it avoids possibly boxing the value.
73                 // 2. Fallback to get the value as object from IExtensionValue then casting.
74                 //    This allows for someone to specify a TValue of object. They can then convert
75                 //    the values to bytes and reparse using expected value.
76                 // 3. If neither of these work, throw a user friendly error that the types aren't compatible.
77                 if (value is ExtensionValue<TValue> extensionValue)
78                 {
79                     return extensionValue.GetValue();
80                 }
81                 else if (value.GetValue() is TValue underlyingValue)
82                 {
83                     return underlyingValue;
84                 }
85                 else
86                 {
87                     var valueType = value.GetType().GetTypeInfo();
88                     if (valueType.IsGenericType && valueType.GetGenericTypeDefinition() == typeof(ExtensionValue<>))
89                     {
90                         var storedType = valueType.GenericTypeArguments[0];
91                         throw new InvalidOperationException(
92                             "The stored extension value has a type of '" + storedType.AssemblyQualifiedName + "'. " +
93                             "This a different from the requested type of '" + typeof(TValue).AssemblyQualifiedName + "'.");
94                     }
95                     else
96                     {
97                         throw new InvalidOperationException("Unexpected extension value type: " + valueType.AssemblyQualifiedName);
98                     }
99                 }
100             }
101             else
102             {
103                 return extension.DefaultValue;
104             }
105         }
106 
107         /// <summary>
108         /// Gets the value of the specified repeated extension or null if it doesn't exist in this set
109         /// </summary>
110         public static RepeatedField<TValue> Get<TTarget, TValue>(ref ExtensionSet<TTarget> set, RepeatedExtension<TTarget, TValue> extension) where TTarget : IExtendableMessage<TTarget>
111         {
112             IExtensionValue value;
113             if (TryGetValue(ref set, extension, out value))
114             {
115                 if (value is RepeatedExtensionValue<TValue> extensionValue)
116                 {
117                     return extensionValue.GetValue();
118                 }
119                 else
120                 {
121                     var valueType = value.GetType().GetTypeInfo();
122                     if (valueType.IsGenericType && valueType.GetGenericTypeDefinition() == typeof(RepeatedExtensionValue<>))
123                     {
124                         var storedType = valueType.GenericTypeArguments[0];
125                         throw new InvalidOperationException(
126                             "The stored extension value has a type of '" + storedType.AssemblyQualifiedName + "'. " +
127                             "This a different from the requested type of '" + typeof(TValue).AssemblyQualifiedName + "'.");
128                     }
129                     else
130                     {
131                         throw new InvalidOperationException("Unexpected extension value type: " + valueType.AssemblyQualifiedName);
132                     }
133                 }
134             }
135             else
136             {
137                 return null;
138             }
139         }
140 
141         /// <summary>
142         /// Gets the value of the specified repeated extension, registering it if it doesn't exist
143         /// </summary>
144         public static RepeatedField<TValue> GetOrInitialize<TTarget, TValue>(ref ExtensionSet<TTarget> set, RepeatedExtension<TTarget, TValue> extension) where TTarget : IExtendableMessage<TTarget>
145         {
146             IExtensionValue value;
147             if (set == null)
148             {
149                 value = extension.CreateValue();
150                 set = new ExtensionSet<TTarget>();
151                 set.ValuesByNumber.Add(extension.FieldNumber, value);
152             }
153             else
154             {
155                 if (!set.ValuesByNumber.TryGetValue(extension.FieldNumber, out value))
156                 {
157                     value = extension.CreateValue();
158                     set.ValuesByNumber.Add(extension.FieldNumber, value);
159                 }
160             }
161 
162             return ((RepeatedExtensionValue<TValue>)value).GetValue();
163         }
164 
165         /// <summary>
166         /// Sets the value of the specified extension. This will make a new instance of ExtensionSet if the set is null.
167         /// </summary>
168         public static void Set<TTarget, TValue>(ref ExtensionSet<TTarget> set, Extension<TTarget, TValue> extension, TValue value) where TTarget : IExtendableMessage<TTarget>
169         {
170             ProtoPreconditions.CheckNotNullUnconstrained(value, nameof(value));
171 
172             IExtensionValue extensionValue;
173             if (set == null)
174             {
175                 extensionValue = extension.CreateValue();
176                 set = new ExtensionSet<TTarget>();
177                 set.ValuesByNumber.Add(extension.FieldNumber, extensionValue);
178             }
179             else
180             {
181                 if (!set.ValuesByNumber.TryGetValue(extension.FieldNumber, out extensionValue))
182                 {
183                     extensionValue = extension.CreateValue();
184                     set.ValuesByNumber.Add(extension.FieldNumber, extensionValue);
185                 }
186             }
187 
188             ((ExtensionValue<TValue>)extensionValue).SetValue(value);
189         }
190 
191         /// <summary>
192         /// Gets whether the value of the specified extension is set
193         /// </summary>
194         public static bool Has<TTarget, TValue>(ref ExtensionSet<TTarget> set, Extension<TTarget, TValue> extension) where TTarget : IExtendableMessage<TTarget>
195         {
196             IExtensionValue value;
197             return TryGetValue(ref set, extension, out value);
198         }
199 
200         /// <summary>
201         /// Clears the value of the specified extension
202         /// </summary>
203         public static void Clear<TTarget, TValue>(ref ExtensionSet<TTarget> set, Extension<TTarget, TValue> extension) where TTarget : IExtendableMessage<TTarget>
204         {
205             if (set == null)
206             {
207                 return;
208             }
209             set.ValuesByNumber.Remove(extension.FieldNumber);
210             if (set.ValuesByNumber.Count == 0)
211             {
212                 set = null;
213             }
214         }
215 
216         /// <summary>
217         /// Clears the value of the specified extension
218         /// </summary>
219         public static void Clear<TTarget, TValue>(ref ExtensionSet<TTarget> set, RepeatedExtension<TTarget, TValue> extension) where TTarget : IExtendableMessage<TTarget>
220         {
221             if (set == null)
222             {
223                 return;
224             }
225             set.ValuesByNumber.Remove(extension.FieldNumber);
226             if (set.ValuesByNumber.Count == 0)
227             {
228                 set = null;
229             }
230         }
231 
232         /// <summary>
233         /// Tries to merge a field from the coded input, returning true if the field was merged.
234         /// If the set is null or the field was not otherwise merged, this returns false.
235         /// </summary>
236         public static bool TryMergeFieldFrom<TTarget>(ref ExtensionSet<TTarget> set, CodedInputStream stream) where TTarget : IExtendableMessage<TTarget>
237         {
238             ParseContext.Initialize(stream, out ParseContext ctx);
239             try
240             {
241                 return TryMergeFieldFrom<TTarget>(ref set, ref ctx);
242             }
243             finally
244             {
245                 ctx.CopyStateTo(stream);
246             }
247         }
248 
249         /// <summary>
250         /// Tries to merge a field from the coded input, returning true if the field was merged.
251         /// If the set is null or the field was not otherwise merged, this returns false.
252         /// </summary>
253         public static bool TryMergeFieldFrom<TTarget>(ref ExtensionSet<TTarget> set, ref ParseContext ctx) where TTarget : IExtendableMessage<TTarget>
254         {
255             Extension extension;
256             int lastFieldNumber = WireFormat.GetTagFieldNumber(ctx.LastTag);
257 
258             IExtensionValue extensionValue;
259             if (set != null && set.ValuesByNumber.TryGetValue(lastFieldNumber, out extensionValue))
260             {
261                 extensionValue.MergeFrom(ref ctx);
262                 return true;
263             }
264             else if (ctx.ExtensionRegistry != null && ctx.ExtensionRegistry.ContainsInputField(ctx.LastTag, typeof(TTarget), out extension))
265             {
266                 IExtensionValue value = extension.CreateValue();
267                 value.MergeFrom(ref ctx);
268                 set = (set ?? new ExtensionSet<TTarget>());
269                 set.ValuesByNumber.Add(extension.FieldNumber, value);
270                 return true;
271             }
272             else
273             {
274                 return false;
275             }
276         }
277 
278         /// <summary>
279         /// Merges the second set into the first set, creating a new instance if first is null
280         /// </summary>
281         public static void MergeFrom<TTarget>(ref ExtensionSet<TTarget> first, ExtensionSet<TTarget> second) where TTarget : IExtendableMessage<TTarget>
282         {
283             if (second == null)
284             {
285                 return;
286             }
287             if (first == null)
288             {
289                 first = new ExtensionSet<TTarget>();
290             }
291             foreach (var pair in second.ValuesByNumber)
292             {
293                 IExtensionValue value;
294                 if (first.ValuesByNumber.TryGetValue(pair.Key, out value))
295                 {
296                     value.MergeFrom(pair.Value);
297                 }
298                 else
299                 {
300                     var cloned = pair.Value.Clone();
301                     first.ValuesByNumber[pair.Key] = cloned;
302                 }
303             }
304         }
305 
306         /// <summary>
307         /// Clones the set into a new set. If the set is null, this returns null
308         /// </summary>
309         public static ExtensionSet<TTarget> Clone<TTarget>(ExtensionSet<TTarget> set) where TTarget : IExtendableMessage<TTarget>
310         {
311             if (set == null)
312             {
313                 return null;
314             }
315 
316             var newSet = new ExtensionSet<TTarget>();
317             foreach (var pair in set.ValuesByNumber)
318             {
319                 var cloned = pair.Value.Clone();
320                 newSet.ValuesByNumber[pair.Key] = cloned;
321             }
322             return newSet;
323         }
324     }
325 
326     /// <summary>
327     /// Used for keeping track of extensions in messages.
328     /// <see cref="IExtendableMessage{T}"/> methods route to this set.
329     ///
330     /// Most users will not need to use this class directly
331     /// </summary>
332     /// <typeparam name="TTarget">The message type that extensions in this set target</typeparam>
333     public sealed class ExtensionSet<TTarget> where TTarget : IExtendableMessage<TTarget>
334     {
335         internal Dictionary<int, IExtensionValue> ValuesByNumber { get; } = new Dictionary<int, IExtensionValue>();
336 
337         /// <summary>
338         /// Gets a hash code of the set
339         /// </summary>
GetHashCode()340         public override int GetHashCode()
341         {
342             int ret = typeof(TTarget).GetHashCode();
343             foreach (KeyValuePair<int, IExtensionValue> field in ValuesByNumber)
344             {
345                 // Use ^ here to make the field order irrelevant.
346                 int hash = field.Key.GetHashCode() ^ field.Value.GetHashCode();
347                 ret ^= hash;
348             }
349             return ret;
350         }
351 
352         /// <summary>
353         /// Returns whether this set is equal to the other object
354         /// </summary>
Equals(object other)355         public override bool Equals(object other)
356         {
357             if (ReferenceEquals(this, other))
358             {
359                 return true;
360             }
361             ExtensionSet<TTarget> otherSet = other as ExtensionSet<TTarget>;
362             if (ValuesByNumber.Count != otherSet.ValuesByNumber.Count)
363             {
364                 return false;
365             }
366             foreach (var pair in ValuesByNumber)
367             {
368                 IExtensionValue secondValue;
369                 if (!otherSet.ValuesByNumber.TryGetValue(pair.Key, out secondValue))
370                 {
371                     return false;
372                 }
373                 if (!pair.Value.Equals(secondValue))
374                 {
375                     return false;
376                 }
377             }
378             return true;
379         }
380 
381         /// <summary>
382         /// Calculates the size of this extension set
383         /// </summary>
CalculateSize()384         public int CalculateSize()
385         {
386             int size = 0;
387             foreach (var value in ValuesByNumber.Values)
388             {
389                 size += value.CalculateSize();
390             }
391             return size;
392         }
393 
394         /// <summary>
395         /// Writes the extension values in this set to the output stream
396         /// </summary>
WriteTo(CodedOutputStream stream)397         public void WriteTo(CodedOutputStream stream)
398         {
399 
400             WriteContext.Initialize(stream, out WriteContext ctx);
401             try
402             {
403                 WriteTo(ref ctx);
404             }
405             finally
406             {
407                 ctx.CopyStateTo(stream);
408             }
409         }
410 
411         /// <summary>
412         /// Writes the extension values in this set to the write context
413         /// </summary>
414         [SecuritySafeCritical]
WriteTo(ref WriteContext ctx)415         public void WriteTo(ref WriteContext ctx)
416         {
417             foreach (var value in ValuesByNumber.Values)
418             {
419                 value.WriteTo(ref ctx);
420             }
421         }
422 
IsInitialized()423         internal bool IsInitialized()
424         {
425             return ValuesByNumber.Values.All(v => v.IsInitialized());
426         }
427     }
428 }
429