1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_TF2XLA_XLA_EXPRESSION_H_ 17 #define TENSORFLOW_COMPILER_TF2XLA_XLA_EXPRESSION_H_ 18 19 #include "absl/types/optional.h" 20 #include "tensorflow/compiler/tf2xla/xla_resource.h" 21 #include "tensorflow/compiler/xla/client/client.h" 22 #include "tensorflow/compiler/xla/client/value_inference.h" 23 #include "tensorflow/compiler/xla/client/xla_builder.h" 24 #include "tensorflow/compiler/xla/statusor.h" 25 #include "tensorflow/core/framework/tensor.h" 26 #include "tensorflow/core/lib/core/status.h" 27 28 namespace tensorflow { 29 30 // A XlaExpression represents a symbolic TensorFlow value in a TF->XLA 31 // compilation. 32 // An expression is one of: 33 // * a constant tensor. 34 // * an xla::XlaOp, representing a symbolic XLA value. 35 // * a resource, e.g., a variable, represented as an XlaResource pointer. 36 // * a tensor list, represented by a tuple of tensors and the list length. 37 // 38 // Constant tensors are mostly an optimization to avoid passing large constants 39 // to XLA, but are also sometimes used to represent tensors that have no XLA 40 // representation, for example, DT_STRING tensors. A canonical use case might be 41 // an error message string. 42 // 43 // Tensor lists are very similar to xla::XlaOp, however they require some 44 // specific logic around shape management since the tuples are not supported by 45 // TensorFlow. 46 class XlaExpression { 47 public: 48 enum class Kind { 49 kInvalid, 50 kConstant, 51 kXlaOp, 52 kResource, 53 kTensorList, 54 }; 55 56 XlaExpression(); 57 XlaExpression(const XlaExpression&) = default; 58 XlaExpression& operator=(const XlaExpression&) = default; 59 60 // Builds an invalid expression. (Same as the default constructor, but makes 61 // the intent clearer.) 62 static XlaExpression Invalid(); 63 64 // Builds a constant XLA expression. 65 static XlaExpression Constant(Tensor value); 66 67 // Builds a XlaOp expression. Since the mapping from TF data types to XLA 68 // types is not 1-1, the TF type must also be provided; in general it cannot 69 // be derived from the XLA type. 70 static XlaExpression XlaOp(xla::XlaOp value, DataType dtype); 71 72 // Builds a tensor list expression. 73 static XlaExpression TensorList(xla::XlaOp tensor_list); 74 75 // Builds a resource expression. 76 static XlaExpression Resource(XlaResource* resource); 77 78 // Builds a resource whose value is known at a compile time. 79 static XlaExpression ConstantResource(Tensor value, XlaResource* resource); 80 kind()81 Kind kind() const { return kind_; } 82 dtype()83 DataType dtype() const { return dtype_; } 84 85 // handle() returns the XlaOp that backs a kXlaOp expression. handle()86 const xla::XlaOp& handle() const { return handle_; } 87 88 // Return a constant value associated with this expression. Always set for 89 // constants, might be set for resources. constant_value()90 std::optional<Tensor> constant_value() const { 91 if (kind_ == Kind::kResource && resource_->IsOverwritten()) { 92 // The constant is no longer available if the value was overwritten. 93 return std::nullopt; 94 } 95 return constant_value_; 96 } 97 98 // Set the bound of the expression. set_value_bound(Tensor tensor)99 void set_value_bound(Tensor tensor) { 100 value_bound_.emplace(std::move(tensor)); 101 } 102 103 // Return the bound of the expression, if available. value_bound()104 std::optional<Tensor> value_bound() const { return value_bound_; } 105 106 // Set the dynamism of the expression, indicating whether or not each value in 107 // this expression is dynamic. set_value_dynamism(Tensor tensor)108 void set_value_dynamism(Tensor tensor) { 109 value_dynamism_.emplace(std::move(tensor)); 110 } 111 112 // Return the dynamism of the expression, if available. value_dynamism()113 std::optional<Tensor> value_dynamism() const { return value_dynamism_; } 114 resource()115 XlaResource* resource() const { return resource_; } 116 117 // Returns a human-readable summary of the expression. 118 string HumanString() const; 119 120 // Returns the value of a kValue or kXlaOp as an xla::XlaOp. Returns 121 // an erroneous XlaOp if the expression is not a constant or an expression. 122 xla::XlaOp AsXlaOp(xla::XlaBuilder* builder) const; 123 124 // If a kXlaOp or kValue expression can be resolved to a compile-time 125 // constant, returns the value as a host-memory Tensor. Returns an empty 126 // optional if it cannot be resolved. Returns an error if passed a resource 127 // expression. 128 StatusOr<std::optional<Tensor>> ResolveConstant( 129 xla::Client* client, bool dynamic_dimension_is_minus_one = false, 130 xla::ValueInferenceMode mode = xla::ValueInferenceMode::kValue) const; 131 132 // ResolveDynamism computes where a value inside this op is dynamic or can be 133 // inferred at compile time. 134 StatusOr<Tensor> ResolveDynamism(xla::Client* client) const; 135 136 // Returns the shape of the tensor. 137 // The shape of a resource is the shape of a resource handle (i.e., a scalar), 138 // not the shape of the resource's value. 139 StatusOr<TensorShape> GetShape() const; 140 141 // Retrieves an XlaExpression that was allocated by a previous Op. 142 static const XlaExpression* CastExpressionFromTensor(const Tensor& tensor); 143 144 // Assigns an XlaExpression to a tensor on an XLA compilation device. 145 static void AssignExpressionToTensor(const XlaExpression& value, 146 Tensor* tensor); 147 148 private: 149 Kind kind_ = Kind::kInvalid; 150 151 DataType dtype_ = DT_INVALID; 152 153 // The XLA handle of the expression's computation, if kind_ == kXlaOp or 154 // a tuple expression if kind_ == kTensorList. 155 xla::XlaOp handle_; 156 157 // The value of the constant, if available. 158 std::optional<Tensor> constant_value_; 159 160 // The bound of the expression, if available. 161 std::optional<Tensor> value_bound_; 162 163 // Indicate whether each value inside a tensor is dynamic or not. 164 std::optional<Tensor> value_dynamism_; 165 166 // The resource, if kind_ == kResource. Not owned. 167 XlaResource* resource_ = nullptr; 168 }; 169 170 } // namespace tensorflow 171 172 #endif // TENSORFLOW_COMPILER_TF2XLA_XLA_EXPRESSION_H_ 173