syntax = "proto3";

package tensorflow.tpu;

import "google/protobuf/wrappers.proto";
import "tensorflow/compiler/xla/service/hlo.proto";

message ClippingLimits {
  google.protobuf.FloatValue lower = 1;  // -inf if not set
  google.protobuf.FloatValue upper = 2;  // +inf if not set
}

// Configuration for simulated quantization; simulated quantization is used to
// reduce training/serving skew when the serving variables are quantized. The
// same quantization operations are executed during training to minimize
// differences with serving.
//
// Simulated quantization inserts the following operations on the forward pass
// after gathering the embedding vector from HBM. The backward pass operations
// are unchanged.
//
//   clipped_val = clip(input, clipping_limits)
//   quantum = clipping_limits.range() / (num_buckets - 1)
//   quantized_val = floor((clipped_val - clipping_limits.lower()) / quantum + .5)
//   return quantized_val * quantum + clipping_limits.lower()
message SimulatedQuantization {
  // Whether simulated quantization is enabled.
  bool enabled = 1;

  // Minimum and maximum values of the range used for quantization.
  ClippingLimits clipping_limits = 2;

  // Number of possible quantized values.
  int32 num_buckets = 3;
}

// Dynamic learning rate specification in the TPUEmbeddingConfiguration. The
// actual learning rates are provided as a scalar input list to the
// SendTPUEmbeddingGradients Op indexed by their tag specified through the
// following proto.
message DynamicLearningRate {
  // For tables where learning rates are dynamically computed and communicated
  // to the TPU embedding program, a tag must be specified for the learning
  // rate.
  //
  // The tag must be a non-negative integer. The total number of unique tags
  // must be less than or equal to the number of tables in the TPU embedding
  // configuration (a table does not specify any tag if it uses a constant
  // learning rate, and specifies exactly one tag if it uses dynamic learning
  // rates).
  //
  // All tags in the range [0, number_of_unique_tags) must be present in the
  // TPU embedding configuration, i.e. a tag cannot be skipped if a different
  // tag numerically greater than it is used in the configuration.
  //
  // If multiple tables specify the same tag, they *MUST* have the same dynamic
  // learning rate; for example, their dynamic learning rate could be computed
  // by the same TensorFlow sub-graph. The embedding layer can be partitioned
  // more efficiently if number_of_unique_tags is as *LOW* as possible, i.e.,
  // if many tables share the same tag.
  //
  // The learning_rate input of the SendTPUEmbeddingGradients op is used to
  // communicate dynamic learning rates to the TPU embedding program. The
  // learning_rate input is a list of scalars where the size of the list is
  // equal to the number of unique tags. The learning rate associated with a
  // particular tag is specified by populating its corresponding index in the
  // list of learning_rate scalars.
  int32 tag = 1;
}

// Source of learning rate to use.
message LearningRate {
  oneof learning_rate {
    float constant = 1;
    DynamicLearningRate dynamic = 2;
  }
}
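
// Illustration (not part of the schema; table names are hypothetical): for a
// model with three embedding tables, table_a, table_b, and table_c, where
// table_a and table_b share one dynamically computed learning rate and table_c
// uses a second one, a valid tag assignment would be
//
//   table_a -> LearningRate { dynamic { tag: 0 } }
//   table_b -> LearningRate { dynamic { tag: 0 } }
//   table_c -> LearningRate { dynamic { tag: 1 } }
//
// and the learning_rate input of SendTPUEmbeddingGradients would then be a
// list of two scalars, with index 0 holding the rate shared by table_a and
// table_b, and index 1 holding the rate for table_c.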

// Each optimizer's parameter proto has a link to its documentation and CPU
// implementation (if available) for user reference.

// https://www.tensorflow.org/api_docs/python/tf/keras/optimizers/Adagrad
// https://github.com/tensorflow/tensorflow/blob/6b6471f3ffb7f1fefe42d814aa5fb9ab7a535b58/tensorflow/core/kernels/training_ops.cc#L1634
message AdagradParameters {
  // Old initial accumulator parameter.
  reserved "initial_accumulator";
  reserved 1;
}

// This optimizer combines the Adagrad and Momentum update rules.
//   accum(new) = beta2 == 1.0 ?
//                    accum(old) + grad^2 :
//                    beta2 * accum(old) + (1 - beta2) * grad^2
//   accum_with_exponent = (accum(new) + epsilon)^(-1.0 / exponent)
//   mom_accum(new) = momentum * mom_accum(old) + accum_with_exponent
//   update = use_nesterov ?
//                momentum * mom_accum(new) + accum_with_exponent :
//                mom_accum(new)
//   var(new) = var(old) - lr * grad * update
// Algorithm described in https://arxiv.org/abs/2002.11803.
message AdagradMomentumParameters {
  // Moving average parameter for the momentum accumulator.
  float momentum = 1;

  // Whether to use the Nesterov variant of momentum.
  bool use_nesterov = 2;

  // Exponent for the gradient^2 accumulator.
  float exponent = 3;

  // Moving average parameter for the gradient^2 accumulator.
  float beta2 = 4;

  // Offset added to the Adagrad accumulator.
  float epsilon = 5;
}

// Algorithm in http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf.
message BoundedAdagradParameters {
  // Whether to use the updated or the old value of the accumulator when
  // computing the effective learning rate. When update_accumulator_first is
  // set to True, the updated value of the accumulator is used.
  bool update_accumulator_first = 1;

  // The max_var_update value to use. Set value to 0 (default) to disable using
  // max_var_update to clip the gradient.
  float max_var_update = 2;

  // The maximum value of the accumulator. Set max_accumulator to 0 (default)
  // to disable using max_accumulator to clip the accumulator.
  float max_accumulator = 3;
}

// https://www.tensorflow.org/api_docs/python/tf/keras/optimizers/SGD
// https://github.com/tensorflow/tensorflow/blob/6b6471f3ffb7f1fefe42d814aa5fb9ab7a535b58/tensorflow/core/kernels/training_ops.cc#L629
message StochasticGradientDescentParameters {}

// https://www.tensorflow.org/api_docs/python/tf/keras/optimizers/Ftrl
// https://static.googleusercontent.com/media/research.google.com/en//pubs/archive/41159.pdf
// https://github.com/tensorflow/tensorflow/blob/6b6471f3ffb7f1fefe42d814aa5fb9ab7a535b58/tensorflow/core/kernels/training_ops.cc#L2646
//
// The hyperparameters for FTRL are the same as for the Keras implementation,
// with some additions. The "beta" parameter matches the behavior described in
// the second link above; "beta" / (2 * learning rate) should be added to "l2"
// to get equivalent behavior in the other TensorFlow implementations of this
// optimizer. When the multiply_linear_by_lr field is set to true, a modified
// formula is used for FTRL that treats the "linear" accumulator as being
// pre-multiplied by the learning rate (i.e., the accumulator named "linear"
// actually stores "linear * learning_rate"). Other than checkpoint
// compatibility, this is mathematically equivalent for a static learning rate;
// for a dynamic learning rate, it is nearly the same as long as the learning
// rate does not change quickly. The benefit of setting multiply_linear_by_lr
// to true is that the modified formula handles zero and near-zero learning
// rates without producing NaNs, improving flexibility for learning rate
// ramp-up.
message FtrlParameters {
  float l1 = 1;
  float l2 = 2;
  float lr_power = 3;
  float beta = 7;
  bool multiply_linear_by_lr = 6;

  // Previously, the allow_zero_accumulator parameter changed some internal
  // formulas to allow zero and near-zero accumulator values at the cost of
  // some performance. The current implementation ignores this parameter; zero
  // or near-zero accumulator values are now always supported.
  bool allow_zero_accumulator = 8 [deprecated = true];

  // Old initial accumulator parameters.
  reserved "initial_accum", "initial_linear";
  reserved 4, 5;
}
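
// Worked example for the "beta" note above (values are hypothetical, not
// defaults): with a constant learning rate of 0.1 and beta = 0.2 configured
// here, an equivalent configuration of the other TensorFlow FTRL
// implementations would fold beta into the L2 term as
//   l2 + beta / (2 * learning_rate) = l2 + 0.2 / (2 * 0.1) = l2 + 1.0.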

// The Adam optimizer does not implement hyper-parameter update due to hardware
// limitations; use the dynamic learning rate feature instead, setting the
// learning rate to:
//   user learning_rate * sqrt(1 - beta2^t) / (1 - beta1^t)
// Here, t is the current timestep.
//
// https://www.tensorflow.org/api_docs/python/tf/keras/optimizers/Adam
// https://github.com/tensorflow/tensorflow/blob/ab51450c817674c8ff08a7ae4f8ac50cdc4bed8b/tensorflow/python/training/adam.py#L32
//
// Note that the code by default implements the lazy version of Adam
// (https://www.tensorflow.org/api_docs/python/tf/contrib/opt/LazyAdamOptimizer)
// unless the use_non_lazy_adam parameter is set, in which case it implements
// the normal version of Adam that updates all parameters in the embedding
// table, even for entries that are not used in the current minibatch
// (https://www.tensorflow.org/api_docs/python/tf/contrib/opt/AdamOptimizer). If
// use_non_lazy_adam is enabled, gradient accumulation is also required to be
// enabled in order to get correct results; a warning will be printed otherwise
// (which may change to an error in the future). If use_sum_inside_sqrt is set,
// the Adam variable update formula will be changed from m / (sqrt(v) + epsilon)
// to m / sqrt(v + epsilon**2); this option improves the performance of TPU
// training and is not expected to harm model quality.
message AdamParameters {
  float beta1 = 3;
  float beta2 = 4;
  float epsilon = 5;
  bool use_non_lazy_adam = 8;
  bool use_sum_inside_sqrt = 10;

  // Old initial accumulator parameters.
  reserved "initial_m", "initial_v";
  reserved 6, 7;
}
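
// Worked example of the bias-correction note above (hyperparameter values are
// hypothetical): with user learning_rate = 0.001, beta1 = 0.9, and
// beta2 = 0.999, the scalar passed through the dynamic learning rate mechanism
// at step t would be
//   lr_t = 0.001 * sqrt(1 - 0.999^t) / (1 - 0.9^t),
// which is approximately 0.000316 at t = 1 and approaches 0.001 as t grows.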

// https://www.tensorflow.org/api_docs/python/tf/keras/optimizers/SGD
// https://github.com/tensorflow/tensorflow/blob/6b6471f3ffb7f1fefe42d814aa5fb9ab7a535b58/tensorflow/core/kernels/training_ops.cc#L3068
message MomentumParameters {
  float momentum = 1;
  bool use_nesterov = 2;

  // Old initial accumulator parameter.
  reserved "initial_accum";
  reserved 3;
}

// https://www.tensorflow.org/api_docs/python/tf/keras/optimizers/RMSprop
// https://github.com/tensorflow/tensorflow/blob/6b6471f3ffb7f1fefe42d814aa5fb9ab7a535b58/tensorflow/core/kernels/training_ops.cc#L4229
message RmsPropParameters {
  float rho = 1;
  float momentum = 2;
  float epsilon = 3;

  // Old initial accumulator parameters.
  reserved "initial_ms", "initial_mom";
  reserved 4, 5;
}

// https://www.tensorflow.org/api_docs/python/tf/keras/optimizers/RMSprop
// https://github.com/tensorflow/tensorflow/blob/6b6471f3ffb7f1fefe42d814aa5fb9ab7a535b58/tensorflow/core/kernels/training_ops.cc#L4358
message CenteredRmsPropParameters {
  float rho = 1;
  float momentum = 2;
  float epsilon = 3;

  // Old initial accumulator parameters.
  reserved "initial_ms", "initial_mom", "initial_mg";
  reserved 4, 5, 6;
}

// Variant of the algorithm in http://proceedings.mlr.press/v44/shamir15.pdf
message MdlAdagradLightParameters {
  float l2 = 1;
  float lr_power = 2;
  float min_servable_mdl_benefit = 3;
  float mdl_mix_in_margin = 4;
  float mdl_benefit_rampup_coeff = 5;
  float mdl_min_weight = 6;
  float benefit_revisit_scale = 7;
  float max_event_benefit = 8;
  float max_total_benefit = 9;
  float mdl_hard_limit = 10;
  bool hard_limit_min_benefit = 11;
  bool mdl_regularize = 12;

  // Old initial accumulator parameters.
  reserved "initial_accumulator", "initial_weight", "initial_benefit";
  reserved 13, 14, 15;
}

// https://www.tensorflow.org/api_docs/python/tf/keras/optimizers/Adadelta
// https://github.com/tensorflow/tensorflow/blob/6b6471f3ffb7f1fefe42d814aa5fb9ab7a535b58/tensorflow/core/kernels/training_ops.cc#L933
message AdadeltaParameters {
  float rho = 1;
  float epsilon = 2;

  // Old initial accumulator parameters.
  reserved "initial_accumulator", "initial_update";
  reserved 3, 4;
}

// https://www.tensorflow.org/api_docs/python/tf/compat/v1/train/ProximalAdagradOptimizer
// https://github.com/tensorflow/tensorflow/blob/6b6471f3ffb7f1fefe42d814aa5fb9ab7a535b58/tensorflow/core/kernels/training_ops.cc#L1961
message ProximalAdagradParameters {
  float l1 = 1;
  float l2 = 2;

  // Old initial accumulator parameter.
  reserved "initial_accumulator";
  reserved 3;
}

// The online Yogi optimizer does not implement hyper-parameter update; use the
// dynamic learning rate feature instead, setting the learning rate to:
//   user learning_rate * sqrt(1 - beta2^t) / (1 - beta1^t)
// Here, t is the current timestep.
//
// https://papers.nips.cc/paper/8186-adaptive-methods-for-nonconvex-optimization.pdf
// plus some extensions based on FTRL.
//
// Note that the code by default implements the lazy version of online Yogi.
message OnlineYogiParameters {
  // The L1 regularization parameter (used analogously to the one in FTRL).
  float l1 = 1;

  // The L2 regularization parameter (used analogously to the one in FTRL).
  float l2 = 2;

  // \beta_2 from Algorithm 2 in the paper.
  float beta2 = 3;

  // Reserved ids corresponding to removed tanh activation.
  reserved 6;  // sign
  reserved 7;  // tanh
}

// The proximal Yogi optimizer does not implement hyper-parameter update; use
// the dynamic learning rate feature instead, setting the learning rate to:
//   user learning_rate * sqrt(1 - beta2^t) / (1 - beta1^t)
// Here, t is the current timestep.
//
// https://papers.nips.cc/paper/8186-adaptive-methods-for-nonconvex-optimization.pdf
// plus some extensions based on FTRL.
//
// Note that the code by default implements the lazy version of proximal Yogi.
message ProximalYogiParameters {
  // The L1 regularization parameter.
  float l1 = 1;

  // The L2 regularization parameter.
  float l2 = 2;

  // The exponential decay rate for the 1st moment estimates.
  float beta1 = 3;

  // The exponential decay rate for the 2nd moment estimates.
  float beta2 = 4;

  // A constant trading off adaptivity and noise.
  float epsilon = 5;

  // Reserved ids corresponding to removed tanh activation.
  reserved 8;  // sign
  reserved 9;  // tanh
}

// Estimator for the frequency of updates to a lookup table. It maintains an
// array (tf.Variable) D, where each element records the average number of
// global steps between two consecutive batches that hit the corresponding
// bucket. Once an item with bucket id i is sampled, D[i] is updated by:
//   D[i] <- D[i] * (1 - tau) + delta[i] * tau,
//
// where tau is a learning rate between 0 and 1 (exclusive), and
//   delta[i] = current global step - last step at which i was sampled.
//
// The estimated frequency (sampling rate in a batch) is thus 1 / D[i].
//
// Elements in D are initialized with a large value max_delta. delta[i] will
// also be capped by this value.
//
// The exact sequence of operations used in the optimizer is shown below.
// last_hit_step[i] is a tf.Variable that holds the last global step at which i
// was sampled.
//
//   delta = global_step - last_hit_step[i]
//   clipped_delta = min(delta, params.max_delta)
//   is_outlier = (delta >= params.outlier_threshold * D[i])
//   D[i] <- is_outlier ? clipped_delta
//                      : D[i] * (1 - params.tau) + clipped_delta * params.tau
//   last_hit_step[i] <- global_step
message FrequencyEstimatorParameters {
  // Learning rate in (0, 1) that is used to update the array D.
  float tau = 1;

  // Maximum value of delta: the difference between the current global step and
  // the last global step at which the row was sampled.
  float max_delta = 2;

  // Threshold used to determine whether the current update is an outlier.
  float outlier_threshold = 3;

  // The weight exponent used to transform the estimated delta into weights.
  // The transformation function is: (delta / max_delta) ^ (weight_exponent)
  float weight_exponent = 4;
}
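
// Worked example for the update sequence above (parameter values are
// hypothetical): with tau = 0.1, outlier_threshold = 10, and max_delta = 1000,
// suppose D[i] = 20 and row i was last hit 30 global steps ago. Then
// delta = 30 < 10 * 20, so the update is not an outlier and
//   D[i] <- 20 * (1 - 0.1) + 30 * 0.1 = 21,
// giving an estimated sampling rate of 1 / 21 per batch. Had delta been 5000,
// it would first be clipped to max_delta = 1000 and, since 5000 >= 10 * 20
// marks it as an outlier, D[i] would be reset to 1000.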

// A user-defined optimizer.
// The contained HLO program must take the following arguments in the following
// order:
//   1. gradients
//   2. table weights
//   3. slot variables
//   4. an optional scalar input that is passed in via the dynamic learning
//      rate mechanism.
//
// It must return/end in a tuple op that contains the following values in the
// following order:
//   1. new table values
//   2. new slot variable values
//
// The program must use shape (1,1) tensors with dtype float32 throughout and
// only use HLO ops that operate elementwise (e.g., no reduce, no variables, no
// control flow and no broadcasting outside of the single scalar input).
// The HLO program should be written as if it were a dense update. It will be
// called on each row that needs an update and will be applied elementwise.
message UserDefinedProgramParameters {
  xla.HloModuleProto program = 1;
  reserved 2;  // Was padding_values
}

// Optimizer that just sets the variable to the value of the gradient. To be
// correct, this requires either gradient accumulation (to sum the values of a
// computed expression across the samples) or deduplication of IDs within a
// single host (to assign the value from an arbitrary sample).
message AssignParameters {}

// Status of using gradient accumulation (doing two passes over the input
// gradients: one to accumulate them into a temporary array and another to
// apply them using the actual optimization algorithm). The extra message is to
// wrap the enum for scoping.
message GradientAccumulationStatus {
  // if UNSPECIFIED (default), gradient accumulation is ENABLED.
  enum Status {
    UNSPECIFIED = 0;
    ENABLED = 1;
    DISABLED = 2;
  }
}

// Whether to optimize the packing of low-dimensional embedding tables in HBM
// (high bandwidth memory). TPUs access HBM at 32-byte (8-float) granularity.
// For functional correctness, the TPU software internally pads the embedding
// dimension to a multiple of 8. This can sometimes lead to significant memory
// wastage due to padding (see the illustrative arithmetic following the
// LowDimensionalPackingStatus message below). For 1-dimensional, 2-dimensional,
// and 4-dimensional embedding tables, the TPU software can remove this padding
// by packing multiple rows into the same 8-float HBM chunk. For example, 8 rows
// could be packed into the same 8-float chunk for a 1-dimensional embedding
// table.
//
// There is one important limitation for this HBM packing though. When only a
// subset of rows in an 8-float chunk are accessed on a particular step, the
// adjoining rows in the same chunk are updated with zero gradients on the
// backward pass even if they are not touched. This is an artifact of the
// packing implementation. This operation is NOT functionally correct for
// optimizers where zero gradients change the embeddings/slot-variable values,
// e.g., momentum-based optimizers. Hence, this HBM packing cannot be enabled
// for embedding tables with such optimizers. The TPU software automatically
// recognizes that a zero gradient can modify state and turns off the low
// dimensional embedding packing in that scenario.
//
// However, for optimizers where a zero gradient is a NoOp, such as SGD,
// Adagrad, and FTRL, this packing optimization can be used, subject to some
// important considerations:
// * Clipping limits: The initial values for such embeddings should fall within
//   the clipping limits specified in the optimization parameters. Otherwise, a
//   zero gradient will cause the embeddings to be clipped. This changes state
//   and hence, is not a NoOp.
// * FTRL: The embedding vector is computed directly from the values of the
//   accumulator and linear slot variables. Hence, the initial embedding values
//   should match those computed from the initial values of the accumulator and
//   linear slot variables. Note that in nearly all cases, the linear value is
//   initialized to zero; this corresponds to an embedding value of zero.
//
// Performance: The TPU has to perform additional work when low dimensional
// packing is enabled. In certain situations when the vocabulary size is small,
// it may not make sense to turn on this packing since the total memory usage
// due to padding is extremely low. Hence, the TPU software automatically turns
// off the packing optimization in such scenarios.
message LowDimensionalPackingStatus {
  // if UNSPECIFIED (default), the low dimension packing status is DISABLED.
  // This can change in the future.
  //
  // if ENABLED, the low dimension packing is enabled only if the following
  // three additional conditions are true:
  //   * The optimizer treats the zero gradient as a NoOp.
  //   * The embedding dimension is 1, 2, or 4.
  //   * The vocabulary size is large enough to avoid performance issues.
  //
  // if DISABLED, the low dimension packing is always disabled.
  enum Status {
    UNSPECIFIED = 0;
    ENABLED = 1;
    DISABLED = 2;
  }
}
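
// Illustrative arithmetic for the padding overhead described above (table
// sizes are hypothetical): a 2-dimensional embedding table with a vocabulary
// of 1,000,000 rows occupies 1,000,000 * 8 floats (~32 MB) when every row is
// padded to a full 8-float chunk, but only 1,000,000 * 2 floats (~8 MB) when
// four rows are packed into each chunk: a 4x reduction in HBM usage for that
// table.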

// Configuration proto for hot ID optimization. This is an experimental feature
// that is currently disabled (by default).
message HotIdReplicationConfiguration {
  // Whether to enable or disable hot ID optimization.
  // If UNSPECIFIED (default), hot ID optimization is DISABLED.
  enum Status {
    UNSPECIFIED = 0;
    ENABLED = 1;
    DISABLED = 2;
  }
  Status status = 1;
}
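
// As a reference point, the sketch below (text format, hypothetical values,
// for illustration only) shows one way the OptimizationParameters message
// defined next might be populated: Adagrad with a constant learning rate,
// gradient clipping, and gradient accumulation explicitly enabled.
//
//   learning_rate { constant: 0.1 }
//   gradient_clipping_limits { lower { value: -10.0 } upper { value: 10.0 } }
//   gradient_accumulation_status: ENABLED
//   adagrad {}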
message OptimizationParameters {
  // Learning rate used for updating the embedding layer parameters.
  LearningRate learning_rate = 13;
  reserved 1;  // Old learning rate tag.

  // Limits to which to clip the weight values after the backward pass; not
  // present means no limits are applied.
  ClippingLimits clipping_limits = 2;

  // Limits to which to clip the backward pass gradient before using it for
  // updates; not present means no limits are applied.
  ClippingLimits gradient_clipping_limits = 7;

  // Amount of weight decay to apply; see weight_decay_optimizers.py for
  // details. All optimizers except MDL Adagrad Light are supported with this
  // option. Although there is no check, users who want weight decay will also
  // want to ensure that gradient accumulation is enabled so that the decay
  // will happen once per global batch.
  float weight_decay_factor = 16;

  // If true, the weight decay factor is multiplied by the current learning
  // rate before use; this is to match the note in
  // DecoupledWeightDecayExtension in weight_decay_optimizers.py.
  bool multiply_weight_decay_factor_by_learning_rate = 22;

  // Configuration for simulated quantization which is used to reduce
  // training/serving skew when the serving variables are quantized. The same
  // quantization operations are executed during training to minimize
  // differences with serving.
  SimulatedQuantization simulated_quantization = 27;

  // Status of using gradient accumulation (doing two passes over the input
  // gradients: one to accumulate them into a temporary array and another to
  // apply them using the actual optimization algorithm).
  GradientAccumulationStatus.Status gradient_accumulation_status = 17;

  // Status of the low-dimensional embedding packing optimization. This
  // controls whether to optimize the packing of 1-dimensional, 2-dimensional,
  // and 4-dimensional embedding tables in memory.
  LowDimensionalPackingStatus.Status low_dimensional_packing_status = 28;

  // Configuration proto for hot ID replication. This is an experimental
  // feature that is currently disabled (by default).
  HotIdReplicationConfiguration hot_id_replication_configuration = 18;

  // Optimization algorithm parameters; which field is selected determines
  // which algorithm to use.
  oneof parameters {
    AdagradParameters adagrad = 3;
    AdagradMomentumParameters adagrad_momentum = 26;
    BoundedAdagradParameters bounded_adagrad = 19;
    StochasticGradientDescentParameters stochastic_gradient_descent = 4;
    FtrlParameters ftrl = 5;
    AdamParameters adam = 6;
    MomentumParameters momentum = 8;
    RmsPropParameters rms_prop = 9;
    CenteredRmsPropParameters centered_rms_prop = 10;
    MdlAdagradLightParameters mdl_adagrad_light = 11;
    AdadeltaParameters adadelta = 12;
    ProximalAdagradParameters proximal_adagrad = 14;
    OnlineYogiParameters online_yogi = 20;
    ProximalYogiParameters proximal_yogi = 21;
    FrequencyEstimatorParameters frequency_estimator = 23;
    UserDefinedProgramParameters user_defined_program = 24;
    AssignParameters assign = 25;
  }

  reserved 15;  // Old use_gradient_accumulation.

  // NEXT_ID: 29
}

// Specification of an optimization algorithm's state variables (both the main
// value vector and any extra accumulators, etc.). This proto is only used
// internally by the TPU software and is not exposed directly to the TF model.
message StateVariableSpecification {
  // Parameter name for the state variable.
  string name = 1;

  // A normal state variable that should be saved and restored in checkpoints
  // and used as an input or output to non-debug TensorFlow ops.
  message UserDefined {
    reserved 1;  // Was padding_initial_value.
  }

  // A state variable that should be filled with a constant and normally hidden
  // from users (used for intermediate gradients being accumulated, for
  // example).
  message FillWithConstant {
    double initial_value = 1;
  }

  // Usage type of this state variable.
  oneof usage {
    UserDefined user_defined = 2;
    FillWithConstant fill_with_constant = 3;
  }
}