xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/runtime/constraints.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef XLA_RUNTIME_CONSTRAINTS_H_
17 #define XLA_RUNTIME_CONSTRAINTS_H_
18 
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Error.h"
#include "llvm/Support/raw_ostream.h"
22 
23 namespace xla {
24 namespace runtime {
25 
// Constraints on the function argument can be specified with the function
// argument attributes. This is the name of that argument attribute.
//
// Example:
//
//   func @compute(
//     // Rank of the `%arg` must be known at compile time.
//     %arg: tensor<*xf32> { rt.constraint = "rank" }
//   ) -> tensor<?xf32> { ... }
//
// Declared `inline constexpr` (C++17) so that every translation unit including
// this header shares a single definition instead of each getting its own
// internal-linkage copy.
//
// TODO(b/187114012): Add attribute verifier to `rt` dialect.
inline constexpr const char* kArgumentConstraintAttrName = "rt.constraint";
38 
39 // Constraint on what argument information must be available at compile time in
40 // order to successfully compile the executable:
41 //
42 //   `rank`  : argument must have statically known rank.
43 //   `shape` : argument must have statically known shape.
//   `value` : argument must have statically known value; such arguments are
//             replaced with constants inside the compiled function body, and
//             all uses of value constrained arguments are replaced with the
//             sunk constant value.
48 //
// For now these constraints are supported by arguments of shaped types (tensors
// or memrefs), but they could potentially be extended to support an open type
// hierarchy of user-defined types.
52 //
53 // XLA program example:
54 //
55 //   func @main(
56 //     %input0: memref<*xf32>   { rt.constraint = "rank"  },
57 //     %input1: memref<?x?xf32> { rt.constraint = "shape" },
58 //     %perm: memref<4xi32>     { rt.constraint = "value" }
59 //   ) attributes { rt.entrypoint } { ... }
60 //
// An entrypoint function can define constraints on its arguments that must be
// resolved before the function can be compiled. If constraints can't be
// resolved statically from the function signature (e.g. rank is unknown), then
// the runtime will specialize the generic function to the concrete operands at
// runtime (concrete operand rank, shape or value).
66 //
// If function arguments do not have unresolved constraints, the compiler can
// instantiate the default executable, which can take all compatible inputs
// without recompilation.
70 //
71 // (a) Rank constraint:
72 //
73 //     %arg : tensor<*xf32> { rt.constraint = "rank" }
74 //
75 //     Before compiling the function, unranked input type will be updated to the
76 //     corresponding ranked input type (e.g. unranked tensor -> ranked tensor).
77 //
78 // (b) Shape constraint:
79 //
80 //     %arg : tensor<?x?xf32> { rt.constraint = "shape" }
81 //
//     The shape of the runtime argument will be used to specialize the
//     compiled function; if this shape is seen for the first time, it will
//     trigger function recompilation.
85 //
86 // (c) Value constraint:
87 //
88 //     %reduction_dimension : tensor<i32> { rt.constraint = "value" }
89 //
90 //     Runtime value will be sunk into the body of a function as a constant,
91 //     and the function will be recompiled. For example this can be used to sink
92 //     reduction dimensions to generate more efficient code.
93 //
//     The value constraint is only supported for integer data types. In
//     practice it should be a reduction dimension, a dimension permutation, or
//     any similar value that does not change often and is required for
//     generating efficient code.
98 //
99 //  Shape and value specialization example:
100 //
101 //    // Computes `%arg0` mean value over the axis specified by the `%arg1`.
102 //    // See: https://www.tensorflow.org/api_docs/python/tf/math/reduce_mean
103 //    func @mean(%arg0: tensor<?x?xf32>, %arg1: tensor<i32>) -> tensor<?xf32> {
//      %0 = "tf.Mean"(%arg0, %arg1)
105 //             : (tensor<?x?xf32>, tensor<i32>) -> tensor<?xf32>
106 //      return %0: tensor<?xf32>
107 //    }
108 //
//  Shape specialization to input shapes: [tensor<4x8xf32>, tensor<i32>]
110 //
111 //    func @mean(%arg0: tensor<4x8xf32>, %arg1: tensor<i32>) -> tensor<?xf32> {
//      %0 = "tf.Mean"(%arg0, %arg1)
113 //             : (tensor<4x8xf32>, tensor<i32>) -> tensor<?xf32>
114 //      return %0: tensor<?xf32>
115 //    }
116 //
117 //    Shape specialization in this particular case doesn't bring much
118 //    improvement, because without knowing the reduction axis we can't infer
119 //    any new information from the input shape alone.
120 //
121 //  Value specialization to input values: [ <do-not-specialize>, dense<1 : i32>]
122 //
123 //    func @mean(%arg0: tensor<4x8xf32>) -> tensor<4xf32> {
124 //      %0 = "tf.Constant" { value = dense<1 : i32>} -> tensor<i32>
//      %1 = "tf.Mean"(%arg0, %0)
126 //             : (tensor<4x8xf32>, tensor<i32>) -> tensor<4xf32>
127 //      return %1 : tensor<4xf32>
128 //    }
129 //
//    By specializing the function to the concrete value of the second
//    argument and sinking it into the function body, we can infer the output
//    shape. This information also allows statically choosing a reduction
//    implementation optimized for reducing along the innermost dimension.
134 //
//    Furthermore, static information about the reduction axis allows lowering
//    the mean operation to a Linalg generic operation. A dynamic reduction axis
//    is not representable in Linalg, and would require multi-versioning and
//    dynamic dispatch at runtime.
139 //
enum class ArgumentConstraint {
  // Nothing left to resolve: either the static information in the function
  // signature type already satisfies the constraint, or no constraint was
  // specified by the argument attribute.
  kResolved = 0,
  // The `rank` constraint: argument rank must be known at compile time.
  kRank = 1,
  // The `shape` constraint: argument shape must be known at compile time.
  kShape = 2,
  // The `value` constraint: argument value must be known at compile time.
  kValue = 3,
};
148 
149 llvm::raw_ostream& operator<<(llvm::raw_ostream& os,
150                               const ArgumentConstraint& constraint);
151 llvm::raw_ostream& operator<<(llvm::raw_ostream& os,
152                               llvm::ArrayRef<ArgumentConstraint> constraints);
153 
154 // Converts argument constraint string to the corresponding enum class.
155 llvm::Expected<ArgumentConstraint> ParseArgumentConstraint(llvm::StringRef str);
156 
157 }  // namespace runtime
158 }  // namespace xla
159 
160 #endif  // XLA_RUNTIME_CONSTRAINTS_H_
161