1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_DTENSOR_MLIR_VALUE_UTILS_H_ 17 #define TENSORFLOW_DTENSOR_MLIR_VALUE_UTILS_H_ 18 19 #include "llvm/ADT/ArrayRef.h" 20 #include "llvm/ADT/SmallVector.h" 21 #include "mlir/IR/Builders.h" // from @llvm-project 22 #include "mlir/IR/BuiltinOps.h" // from @llvm-project 23 #include "mlir/IR/BuiltinTypes.h" // from @llvm-project 24 #include "mlir/IR/Location.h" // from @llvm-project 25 #include "mlir/IR/Value.h" // from @llvm-project 26 #include "tensorflow/dtensor/cc/dstatus.h" 27 28 namespace tensorflow { 29 namespace dtensor { 30 31 int ValueRank(mlir::Value operand_value); 32 33 // Creates a effective scalar type as rank 1 with a single element. 34 mlir::RankedTensorType EffectivelyScalarR1Type(mlir::Type element_type); 35 36 // Reshapes a value of size type tensor<i32> to scalar. 37 mlir::Value ReshapeSizeTypeToScalar(mlir::OpBuilder builder, mlir::Location loc, 38 mlir::Value tensor); 39 40 // Return a 1-D int32 constant array with the given values. 41 mlir::Value IntConst(mlir::OpBuilder& builder, mlir::Location loc, 42 llvm::ArrayRef<int32> values); 43 // Return a 1-D int64 constant array with the given values. 44 mlir::Value Int64Const(mlir::OpBuilder& builder, mlir::Location loc, 45 llvm::ArrayRef<int64_t> values); 46 // Return a 1-D float32 constant array with the given values. 47 mlir::Value FloatConst(mlir::OpBuilder& builder, mlir::Location loc, 48 llvm::ArrayRef<float> values); 49 // Returns a 1-D tf.string constant array with given values. 50 mlir::Value StringConst(mlir::OpBuilder& builder, mlir::Location loc, 51 llvm::ArrayRef<llvm::StringRef> values); 52 53 StatusOr<int64_t> ExtractConstIntFromValue(mlir::Value value); 54 Status ExtractConstVectorFromValue(mlir::Value value, 55 llvm::SmallVector<int64_t, 4>* out_vector); 56 57 // Returns a int64 scalar constant with `value`. 58 mlir::Value CreateIntScalarConst(const int64_t value, mlir::OpBuilder builder, 59 mlir::Location loc, bool use_int64 = true); 60 61 // Returns a scalar constant with 'value' of 'type'. 62 absl::optional<mlir::Value> CreateZeroScalarConst(mlir::OpBuilder& builder, 63 mlir::Location loc, 64 mlir::Type type); 65 66 // Selects a scalar tensor value from a 1D array in specified index. 67 StatusOr<mlir::Value> SelectScalarValueFromArray(mlir::OpBuilder& builder, 68 int index, 69 mlir::Location location, 70 mlir::Value array); 71 72 // Returns the type that value holds. If value holds a Type that has a subtype, 73 // then it returns the subtype. 74 mlir::Type GetSubtypeOrSelf(mlir::Value value); 75 76 } // namespace dtensor 77 } // namespace tensorflow 78 #endif // TENSORFLOW_DTENSOR_MLIR_VALUE_UTILS_H_ 79