#region Copyright notice and license
// Protocol Buffers - Google's data interchange format
// Copyright 2008 Google Inc. All rights reserved.
// https://developers.google.com/protocol-buffers/
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are
// met:
//
//     * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
//     * Redistributions in binary form must reproduce the above
// copyright notice, this list of conditions and the following disclaimer
// in the documentation and/or other materials provided with the
// distribution.
//     * Neither the name of Google Inc. nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#endregion

using Google.Protobuf.Collections;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Reflection;
using System.Security;

namespace Google.Protobuf
{
    /// <summary>
    /// Methods for managing <see cref="ExtensionSet{TTarget}"/>s with null checking.
    ///
    /// Most users will not use this class directly and its API is experimental and subject to change.
    /// </summary>
    public static class ExtensionSet
    {
        /// <summary>
        /// Looks up the stored value for <paramref name="extension"/> by its field number.
        /// Returns false (and a null <paramref name="value"/>) when the set is null or has no entry.
        /// </summary>
        private static bool TryGetValue<TTarget>(ref ExtensionSet<TTarget> set, Extension extension, out IExtensionValue value) where TTarget : IExtendableMessage<TTarget>
        {
            if (set == null)
            {
                value = null;
                return false;
            }
            return set.ValuesByNumber.TryGetValue(extension.FieldNumber, out value);
        }

        /// <summary>
        /// Returns the value stored for <paramref name="extension"/>, creating and registering it first
        /// (and allocating the set if it is null) when no entry exists yet.
        /// </summary>
        private static IExtensionValue GetOrCreateValue<TTarget>(ref ExtensionSet<TTarget> set, Extension extension) where TTarget : IExtendableMessage<TTarget>
        {
            IExtensionValue value;
            if (set != null && set.ValuesByNumber.TryGetValue(extension.FieldNumber, out value))
            {
                return value;
            }

            // Create the value before allocating the set so that a failing CreateValue()
            // leaves a null set untouched.
            value = extension.CreateValue();
            if (set == null)
            {
                set = new ExtensionSet<TTarget>();
            }
            set.ValuesByNumber.Add(extension.FieldNumber, value);
            return value;
        }

        /// <summary>
        /// Builds the exception thrown when a stored extension value cannot be read as the requested type.
        /// </summary>
        /// <param name="value">The stored extension value.</param>
        /// <param name="expectedGenericDefinition">The open generic value type the caller expected,
        /// e.g. <c>typeof(ExtensionValue&lt;&gt;)</c>.</param>
        /// <param name="requestedType">The element type requested by the caller.</param>
        private static InvalidOperationException CreateTypeMismatchException(IExtensionValue value, Type expectedGenericDefinition, Type requestedType)
        {
            var valueType = value.GetType().GetTypeInfo();
            if (valueType.IsGenericType && valueType.GetGenericTypeDefinition() == expectedGenericDefinition)
            {
                var storedType = valueType.GenericTypeArguments[0];
                return new InvalidOperationException(
                    "The stored extension value has a type of '" + storedType.AssemblyQualifiedName + "'. " +
                    "This a different from the requested type of '" + requestedType.AssemblyQualifiedName + "'.");
            }

            return new InvalidOperationException("Unexpected extension value type: " + valueType.AssemblyQualifiedName);
        }

        /// <summary>
        /// Removes the entry for <paramref name="fieldNumber"/>; the set reference is reset to null
        /// once it no longer holds any values. Does nothing when the set is already null.
        /// </summary>
        private static void ClearField<TTarget>(ref ExtensionSet<TTarget> set, int fieldNumber) where TTarget : IExtendableMessage<TTarget>
        {
            if (set == null)
            {
                return;
            }
            set.ValuesByNumber.Remove(fieldNumber);
            if (set.ValuesByNumber.Count == 0)
            {
                set = null;
            }
        }

        /// <summary>
        /// Gets the value of the specified extension, or the extension's default value
        /// when the set is null or holds no value for it.
        /// </summary>
        /// <exception cref="InvalidOperationException">
        /// The stored value cannot be read as <typeparamref name="TValue"/>.
        /// </exception>
        public static TValue Get<TTarget, TValue>(ref ExtensionSet<TTarget> set, Extension<TTarget, TValue> extension) where TTarget : IExtendableMessage<TTarget>
        {
            IExtensionValue value;
            if (!TryGetValue(ref set, extension, out value))
            {
                return extension.DefaultValue;
            }

            // The stored ExtensionValue can be a different type to what is being requested.
            // This happens when the same extension proto is compiled in different assemblies.
            // To allow consuming assemblies to still get the value when the TValue type is
            // different, this get method:
            // 1. Attempts to cast the value to the expected ExtensionValue<TValue>.
            //    This is the usual case. It is used first because it avoids possibly boxing the value.
            // 2. Falls back to getting the value as object from IExtensionValue and then casting.
            //    This allows for someone to specify a TValue of object. They can then convert
            //    the values to bytes and reparse using the expected value.
            // 3. If neither of these work, throws a user friendly error that the types aren't compatible.
            if (value is ExtensionValue<TValue> extensionValue)
            {
                return extensionValue.GetValue();
            }
            if (value.GetValue() is TValue underlyingValue)
            {
                return underlyingValue;
            }

            throw CreateTypeMismatchException(value, typeof(ExtensionValue<>), typeof(TValue));
        }

        /// <summary>
        /// Gets the value of the specified repeated extension or null if it doesn't exist in this set
        /// </summary>
        /// <exception cref="InvalidOperationException">
        /// The stored value is not a <c>RepeatedExtensionValue</c> of <typeparamref name="TValue"/>.
        /// </exception>
        public static RepeatedField<TValue> Get<TTarget, TValue>(ref ExtensionSet<TTarget> set, RepeatedExtension<TTarget, TValue> extension) where TTarget : IExtendableMessage<TTarget>
        {
            IExtensionValue value;
            if (!TryGetValue(ref set, extension, out value))
            {
                return null;
            }

            if (value is RepeatedExtensionValue<TValue> extensionValue)
            {
                return extensionValue.GetValue();
            }

            throw CreateTypeMismatchException(value, typeof(RepeatedExtensionValue<>), typeof(TValue));
        }

        /// <summary>
        /// Gets the value of the specified repeated extension, registering it if it doesn't exist.
        /// This will make a new instance of ExtensionSet if the set is null.
        /// </summary>
        public static RepeatedField<TValue> GetOrInitialize<TTarget, TValue>(ref ExtensionSet<TTarget> set, RepeatedExtension<TTarget, TValue> extension) where TTarget : IExtendableMessage<TTarget>
        {
            IExtensionValue value = GetOrCreateValue(ref set, extension);
            return ((RepeatedExtensionValue<TValue>)value).GetValue();
        }

        /// <summary>
        /// Sets the value of the specified extension. This will make a new instance of ExtensionSet if the set is null.
        /// </summary>
        public static void Set<TTarget, TValue>(ref ExtensionSet<TTarget> set, Extension<TTarget, TValue> extension, TValue value) where TTarget : IExtendableMessage<TTarget>
        {
            // Validate before touching the set so a rejected value leaves it unchanged.
            ProtoPreconditions.CheckNotNullUnconstrained(value, nameof(value));

            IExtensionValue extensionValue = GetOrCreateValue(ref set, extension);
            ((ExtensionValue<TValue>)extensionValue).SetValue(value);
        }

        /// <summary>
        /// Gets whether the value of the specified extension is set
        /// </summary>
        public static bool Has<TTarget, TValue>(ref ExtensionSet<TTarget> set, Extension<TTarget, TValue> extension) where TTarget : IExtendableMessage<TTarget>
        {
            IExtensionValue value;
            return TryGetValue(ref set, extension, out value);
        }

        /// <summary>
        /// Clears the value of the specified extension. The set becomes null when its last value is removed.
        /// </summary>
        public static void Clear<TTarget, TValue>(ref ExtensionSet<TTarget> set, Extension<TTarget, TValue> extension) where TTarget : IExtendableMessage<TTarget>
        {
            ClearField(ref set, extension.FieldNumber);
        }

        /// <summary>
        /// Clears the value of the specified repeated extension. The set becomes null when its last value is removed.
        /// </summary>
        public static void Clear<TTarget, TValue>(ref ExtensionSet<TTarget> set, RepeatedExtension<TTarget, TValue> extension) where TTarget : IExtendableMessage<TTarget>
        {
            ClearField(ref set, extension.FieldNumber);
        }

        /// <summary>
        /// Tries to merge a field from the coded input, returning true if the field was merged.
        /// The field is merged when the set already holds a value for the field number of the last tag,
        /// or when the extension registry contains a matching extension (in which case the set is
        /// created if it was null). Otherwise this returns false.
        /// </summary>
        public static bool TryMergeFieldFrom<TTarget>(ref ExtensionSet<TTarget> set, CodedInputStream stream) where TTarget : IExtendableMessage<TTarget>
        {
            ParseContext.Initialize(stream, out ParseContext ctx);
            try
            {
                return TryMergeFieldFrom<TTarget>(ref set, ref ctx);
            }
            finally
            {
                // Always hand the parse state back to the stream, even if merging throws.
                ctx.CopyStateTo(stream);
            }
        }

        /// <summary>
        /// Tries to merge a field from the parse context, returning true if the field was merged.
        /// The field is merged when the set already holds a value for the field number of the last tag,
        /// or when the context's extension registry contains a matching extension (in which case the
        /// set is created if it was null). Otherwise this returns false.
        /// </summary>
        public static bool TryMergeFieldFrom<TTarget>(ref ExtensionSet<TTarget> set, ref ParseContext ctx) where TTarget : IExtendableMessage<TTarget>
        {
            Extension extension;
            int lastFieldNumber = WireFormat.GetTagFieldNumber(ctx.LastTag);

            IExtensionValue extensionValue;
            if (set != null && set.ValuesByNumber.TryGetValue(lastFieldNumber, out extensionValue))
            {
                extensionValue.MergeFrom(ref ctx);
                return true;
            }

            if (ctx.ExtensionRegistry != null && ctx.ExtensionRegistry.ContainsInputField(ctx.LastTag, typeof(TTarget), out extension))
            {
                // Parse into a fresh value first; the set is only allocated/updated after a successful merge.
                IExtensionValue value = extension.CreateValue();
                value.MergeFrom(ref ctx);
                set = (set ?? new ExtensionSet<TTarget>());
                set.ValuesByNumber.Add(extension.FieldNumber, value);
                return true;
            }

            return false;
        }

        /// <summary>
        /// Merges the second set into the first set, creating a new instance if first is null.
        /// Does nothing when second is null.
        /// </summary>
        public static void MergeFrom<TTarget>(ref ExtensionSet<TTarget> first, ExtensionSet<TTarget> second) where TTarget : IExtendableMessage<TTarget>
        {
            if (second == null)
            {
                return;
            }
            if (first == null)
            {
                first = new ExtensionSet<TTarget>();
            }
            foreach (var pair in second.ValuesByNumber)
            {
                IExtensionValue value;
                if (first.ValuesByNumber.TryGetValue(pair.Key, out value))
                {
                    value.MergeFrom(pair.Value);
                }
                else
                {
                    // Clone so the two sets never share a value instance.
                    var cloned = pair.Value.Clone();
                    first.ValuesByNumber[pair.Key] = cloned;
                }
            }
        }

        /// <summary>
        /// Clones the set into a new set. If the set is null, this returns null
        /// </summary>
        public static ExtensionSet<TTarget> Clone<TTarget>(ExtensionSet<TTarget> set) where TTarget : IExtendableMessage<TTarget>
        {
            if (set == null)
            {
                return null;
            }

            var newSet = new ExtensionSet<TTarget>();
            foreach (var pair in set.ValuesByNumber)
            {
                var cloned = pair.Value.Clone();
                newSet.ValuesByNumber[pair.Key] = cloned;
            }
            return newSet;
        }
    }

    /// <summary>
    /// Used for keeping track of extensions in messages.
    /// <see cref="IExtendableMessage{T}"/> methods route to this set.
    ///
    /// Most users will not need to use this class directly
    /// </summary>
    /// <typeparam name="TTarget">The message type that extensions in this set target</typeparam>
    public sealed class ExtensionSet<TTarget> where TTarget : IExtendableMessage<TTarget>
    {
        /// <summary>
        /// The stored extension values, keyed by extension field number.
        /// </summary>
        internal Dictionary<int, IExtensionValue> ValuesByNumber { get; } = new Dictionary<int, IExtensionValue>();

        /// <summary>
        /// Gets a hash code of the set
        /// </summary>
        public override int GetHashCode()
        {
            int ret = typeof(TTarget).GetHashCode();
            foreach (KeyValuePair<int, IExtensionValue> field in ValuesByNumber)
            {
                // Use ^ here to make the field order irrelevant.
                int hash = field.Key.GetHashCode() ^ field.Value.GetHashCode();
                ret ^= hash;
            }
            return ret;
        }

        /// <summary>
        /// Returns whether this set is equal to the other object.
        /// Returns false when <paramref name="other"/> is null or is not an
        /// <see cref="ExtensionSet{TTarget}"/> for the same target type.
        /// </summary>
        public override bool Equals(object other)
        {
            if (ReferenceEquals(this, other))
            {
                return true;
            }

            // Guard against null / foreign types: Equals must answer false here rather than throw.
            if (!(other is ExtensionSet<TTarget> otherSet))
            {
                return false;
            }

            if (ValuesByNumber.Count != otherSet.ValuesByNumber.Count)
            {
                return false;
            }
            foreach (var pair in ValuesByNumber)
            {
                IExtensionValue secondValue;
                if (!otherSet.ValuesByNumber.TryGetValue(pair.Key, out secondValue))
                {
                    return false;
                }
                if (!pair.Value.Equals(secondValue))
                {
                    return false;
                }
            }
            return true;
        }

        /// <summary>
        /// Calculates the size of this extension set as the sum of the sizes of its values
        /// </summary>
        public int CalculateSize()
        {
            int size = 0;
            foreach (var value in ValuesByNumber.Values)
            {
                size += value.CalculateSize();
            }
            return size;
        }

        /// <summary>
        /// Writes the extension values in this set to the output stream
        /// </summary>
        public void WriteTo(CodedOutputStream stream)
        {
            WriteContext.Initialize(stream, out WriteContext ctx);
            try
            {
                WriteTo(ref ctx);
            }
            finally
            {
                // Always hand the write state back to the stream, even if writing throws.
                ctx.CopyStateTo(stream);
            }
        }

        /// <summary>
        /// Writes the extension values in this set to the write context
        /// </summary>
        [SecuritySafeCritical]
        public void WriteTo(ref WriteContext ctx)
        {
            foreach (var value in ValuesByNumber.Values)
            {
                value.WriteTo(ref ctx);
            }
        }

        /// <summary>
        /// Returns true when every stored extension value reports itself as initialized
        /// (trivially true for an empty set).
        /// </summary>
        internal bool IsInitialized()
        {
            return ValuesByNumber.Values.All(v => v.IsInitialized());
        }
    }
}