1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef XLA_RUNTIME_CONSTRAINTS_H_ 17 #define XLA_RUNTIME_CONSTRAINTS_H_ 18 19 #include "llvm/ADT/ArrayRef.h" 20 #include "llvm/ADT/StringRef.h" 21 #include "llvm/Support/raw_ostream.h" 22 23 namespace xla { 24 namespace runtime { 25 26 // Constraints on the function argument can be specified with the function 27 // argument attributes. 28 // 29 // Example: 30 // 31 // func @compute( 32 // // Rank of the `%arg` must be known at compile time. 33 // %arg: tensor<*xf32> { rt.constraint = "rank" } 34 // ) -> tensor<?xf32> { ... } 35 // 36 // TODO(b/187114012): Add attribute verifier to `rt` dialect. 37 constexpr const char* kArgumentConstraintAttrName = "rt.constraint"; 38 39 // Constraint on what argument information must be available at compile time in 40 // order to successfully compile the executable: 41 // 42 // `rank` : argument must have statically known rank. 43 // `shape` : argument must have statically known shape. 44 // `value` : argument must have statically known value, and such arguments 45 // replaced with constants inside the compiled function body and 46 // and all value constrained argument uses replaced with the sunk 47 // constant value. 
//
// For now these constraints are supported only for arguments of shaped types
// (tensors or memrefs), but they can potentially be extended to support an
// open type hierarchy of user-defined types.
//
// XLA program example:
//
//   func @main(
//     %input0: memref<*xf32>   { rt.constraint = "rank"  },
//     %input1: memref<?x?xf32> { rt.constraint = "shape" },
//     %perm:   memref<4xi32>   { rt.constraint = "value" }
//   ) attributes { rt.entrypoint } { ... }
//
// An entrypoint function can define constraints on its arguments that must be
// resolved before the function can be compiled. If constraints can't be
// resolved statically from the function signature (e.g. the rank is unknown),
// then the runtime will specialize the generic function to the concrete
// operands at runtime (their concrete rank, shape or value).
//
// If function arguments do not have unresolved constraints, the compiler can
// instantiate the default executable, which can take all compatible inputs
// without recompilation.
//
// (a) Rank constraint:
//
//     %arg : tensor<*xf32> { rt.constraint = "rank" }
//
//     Before compiling the function, the unranked input type will be updated
//     to the corresponding ranked input type (e.g. unranked tensor -> ranked
//     tensor).
//
// (b) Shape constraint:
//
//     %arg : tensor<?x?xf32> { rt.constraint = "shape" }
//
//     The shape of the runtime argument will be used to specialize the
//     compiled function; if this shape is seen for the first time, it will
//     trigger function recompilation.
//
// (c) Value constraint:
//
//     %reduction_dimension : tensor<i32> { rt.constraint = "value" }
//
//     The runtime value will be sunk into the body of the function as a
//     constant, and the function will be recompiled. For example, this can be
//     used to sink reduction dimensions to generate more efficient code.
//
//     The value constraint is only supported for integer data types. In
//     practice it should be a reduction dimension, a dimension permutation, or
//     any similar value that does not change often and is required for
//     generating efficient code.
//
//     Shape and value specialization example:
//
//       // Computes the mean value of `%arg0` over the axis specified by
//       // `%arg1`.
//       // See: https://www.tensorflow.org/api_docs/python/tf/math/reduce_mean
//       func @mean(%arg0: tensor<?x?xf32>, %arg1: tensor<i32>)
//           -> tensor<?xf32> {
//         %0 = "tf.Mean"(%arg0, %arg1)
//                : (tensor<?x?xf32>, tensor<i32>) -> tensor<?xf32>
//         return %0 : tensor<?xf32>
//       }
//
//     Shape specialization to input shapes: [tensor<4x8xf32>, tensor<i32>]
//
//       func @mean(%arg0: tensor<4x8xf32>, %arg1: tensor<i32>)
//           -> tensor<?xf32> {
//         %0 = "tf.Mean"(%arg0, %arg1)
//                : (tensor<4x8xf32>, tensor<i32>) -> tensor<?xf32>
//         return %0 : tensor<?xf32>
//       }
//
//     Shape specialization in this particular case doesn't bring much
//     improvement, because without knowing the reduction axis we can't infer
//     any new information from the input shape alone.
//
//     Value specialization to input values: [<do-not-specialize>, dense<1 : i32>]
//
//       func @mean(%arg0: tensor<4x8xf32>) -> tensor<4xf32> {
//         %0 = "tf.Constant"() { value = dense<1 : i32> } : () -> tensor<i32>
//         %1 = "tf.Mean"(%arg0, %0)
//                : (tensor<4x8xf32>, tensor<i32>) -> tensor<4xf32>
//         return %1 : tensor<4xf32>
//       }
//
//     By specializing the function to the concrete value of the second
//     argument and sinking it into the function body, we can infer the output
//     shape. This information also allows statically choosing a reduction
//     implementation optimized for reducing along the innermost dimension.
//
//     Furthermore, static information about the reduction axis allows lowering
//     the mean operation to a Linalg generic operation.
Dynamic reduction axis is not 137 // representable in Linalg, and would require multi-versioning and dynamic 138 // dispatch at runtime. 139 // 140 enum class ArgumentConstraint { 141 // Constraint was resolved based on the static information in the function 142 // signature type or it was never specified by the argument attribute. 143 kResolved = 0, 144 kRank = 1, 145 kShape = 2, 146 kValue = 3 147 }; 148 149 llvm::raw_ostream& operator<<(llvm::raw_ostream& os, 150 const ArgumentConstraint& constraint); 151 llvm::raw_ostream& operator<<(llvm::raw_ostream& os, 152 llvm::ArrayRef<ArgumentConstraint> constraints); 153 154 // Converts argument constraint string to the corresponding enum class. 155 llvm::Expected<ArgumentConstraint> ParseArgumentConstraint(llvm::StringRef str); 156 157 } // namespace runtime 158 } // namespace xla 159 160 #endif // XLA_RUNTIME_CONSTRAINTS_H_ 161