xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/tf2xla/xla_expression.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_TF2XLA_XLA_EXPRESSION_H_
17 #define TENSORFLOW_COMPILER_TF2XLA_XLA_EXPRESSION_H_
18 
19 #include "absl/types/optional.h"
20 #include "tensorflow/compiler/tf2xla/xla_resource.h"
21 #include "tensorflow/compiler/xla/client/client.h"
22 #include "tensorflow/compiler/xla/client/value_inference.h"
23 #include "tensorflow/compiler/xla/client/xla_builder.h"
24 #include "tensorflow/compiler/xla/statusor.h"
25 #include "tensorflow/core/framework/tensor.h"
26 #include "tensorflow/core/lib/core/status.h"
27 
28 namespace tensorflow {
29 
30 // A XlaExpression represents a symbolic TensorFlow value in a TF->XLA
31 // compilation.
32 // An expression is one of:
33 // * a constant tensor.
34 // * an xla::XlaOp, representing a symbolic XLA value.
35 // * a resource, e.g., a variable, represented as an XlaResource pointer.
36 // * a tensor list, represented by a tuple of tensors and the list length.
37 //
38 // Constant tensors are mostly an optimization to avoid passing large constants
39 // to XLA, but are also sometimes used to represent tensors that have no XLA
40 // representation, for example, DT_STRING tensors. A canonical use case might be
41 // an error message string.
42 //
43 // Tensor lists are very similar to xla::XlaOp, however they require some
44 // specific logic around shape management since the tuples are not supported by
45 // TensorFlow.
46 class XlaExpression {
47  public:
48   enum class Kind {
49     kInvalid,
50     kConstant,
51     kXlaOp,
52     kResource,
53     kTensorList,
54   };
55 
56   XlaExpression();
57   XlaExpression(const XlaExpression&) = default;
58   XlaExpression& operator=(const XlaExpression&) = default;
59 
60   // Builds an invalid expression. (Same as the default constructor, but makes
61   // the intent clearer.)
62   static XlaExpression Invalid();
63 
64   // Builds a constant XLA expression.
65   static XlaExpression Constant(Tensor value);
66 
67   // Builds a XlaOp expression. Since the mapping from TF data types to XLA
68   // types is not 1-1, the TF type must also be provided; in general it cannot
69   // be derived from the XLA type.
70   static XlaExpression XlaOp(xla::XlaOp value, DataType dtype);
71 
72   // Builds a tensor list expression.
73   static XlaExpression TensorList(xla::XlaOp tensor_list);
74 
75   // Builds a resource expression.
76   static XlaExpression Resource(XlaResource* resource);
77 
78   // Builds a resource whose value is known at a compile time.
79   static XlaExpression ConstantResource(Tensor value, XlaResource* resource);
80 
kind()81   Kind kind() const { return kind_; }
82 
dtype()83   DataType dtype() const { return dtype_; }
84 
85   // handle() returns the XlaOp that backs a kXlaOp expression.
handle()86   const xla::XlaOp& handle() const { return handle_; }
87 
88   // Return a constant value associated with this expression. Always set for
89   // constants, might be set for resources.
constant_value()90   std::optional<Tensor> constant_value() const {
91     if (kind_ == Kind::kResource && resource_->IsOverwritten()) {
92       // The constant is no longer available if the value was overwritten.
93       return std::nullopt;
94     }
95     return constant_value_;
96   }
97 
98   // Set the bound of the expression.
set_value_bound(Tensor tensor)99   void set_value_bound(Tensor tensor) {
100     value_bound_.emplace(std::move(tensor));
101   }
102 
103   // Return the bound of the expression, if available.
value_bound()104   std::optional<Tensor> value_bound() const { return value_bound_; }
105 
106   // Set the dynamism of the expression, indicating whether or not each value in
107   // this expression is dynamic.
set_value_dynamism(Tensor tensor)108   void set_value_dynamism(Tensor tensor) {
109     value_dynamism_.emplace(std::move(tensor));
110   }
111 
112   // Return the dynamism of the expression, if available.
value_dynamism()113   std::optional<Tensor> value_dynamism() const { return value_dynamism_; }
114 
resource()115   XlaResource* resource() const { return resource_; }
116 
117   // Returns a human-readable summary of the expression.
118   string HumanString() const;
119 
120   // Returns the value of a kValue or kXlaOp as an xla::XlaOp. Returns
121   // an erroneous XlaOp if the expression is not a constant or an expression.
122   xla::XlaOp AsXlaOp(xla::XlaBuilder* builder) const;
123 
124   // If a kXlaOp or kValue expression can be resolved to a compile-time
125   // constant, returns the value as a host-memory Tensor. Returns an empty
126   // optional if it cannot be resolved. Returns an error if passed a resource
127   // expression.
128   StatusOr<std::optional<Tensor>> ResolveConstant(
129       xla::Client* client, bool dynamic_dimension_is_minus_one = false,
130       xla::ValueInferenceMode mode = xla::ValueInferenceMode::kValue) const;
131 
132   // ResolveDynamism computes where a value inside this op is dynamic or can be
133   // inferred at compile time.
134   StatusOr<Tensor> ResolveDynamism(xla::Client* client) const;
135 
136   // Returns the shape of the tensor.
137   // The shape of a resource is the shape of a resource handle (i.e., a scalar),
138   // not the shape of the resource's value.
139   StatusOr<TensorShape> GetShape() const;
140 
141   // Retrieves an XlaExpression that was allocated by a previous Op.
142   static const XlaExpression* CastExpressionFromTensor(const Tensor& tensor);
143 
144   // Assigns an XlaExpression to a tensor on an XLA compilation device.
145   static void AssignExpressionToTensor(const XlaExpression& value,
146                                        Tensor* tensor);
147 
148  private:
149   Kind kind_ = Kind::kInvalid;
150 
151   DataType dtype_ = DT_INVALID;
152 
153   // The XLA handle of the expression's computation, if kind_ == kXlaOp or
154   // a tuple expression if kind_ == kTensorList.
155   xla::XlaOp handle_;
156 
157   // The value of the constant, if available.
158   std::optional<Tensor> constant_value_;
159 
160   // The bound of the expression, if available.
161   std::optional<Tensor> value_bound_;
162 
163   // Indicate whether each value inside a tensor is dynamic or not.
164   std::optional<Tensor> value_dynamism_;
165 
166   // The resource, if kind_ == kResource. Not owned.
167   XlaResource* resource_ = nullptr;
168 };
169 
170 }  // namespace tensorflow
171 
172 #endif  // TENSORFLOW_COMPILER_TF2XLA_XLA_EXPRESSION_H_
173