1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_DTENSOR_CC_SMALL_CONSTANT_OPTIMIZATION_H_ 17 #define TENSORFLOW_DTENSOR_CC_SMALL_CONSTANT_OPTIMIZATION_H_ 18 19 #include "absl/types/optional.h" 20 #include "tensorflow/c/eager/c_api.h" 21 #include "tensorflow/core/framework/node_def_builder.h" 22 #include "tensorflow/dtensor/cc/tensor_layout.h" 23 24 namespace tensorflow { 25 namespace dtensor { 26 27 // Attempt to convert small constant tensors into a constant NodeDef operation. 28 // This constant value will be available for constant propagation in DTensor and 29 // MLIR. 30 31 // This conversion is currently required for some DTensor operations. In 32 // particular, reductions require access to the axis argument at compilation 33 // time. While this is not strictly necessary, it greatly simplifies SPMD code 34 // generation and is generally available. 35 absl::optional<NodeDef> ExtractSmallTensorValue(TFE_Context* context, 36 TFE_TensorHandle* tensor, 37 const Layout& layout, 38 TF_Status* status); 39 40 // Returns true if the given input argument should be eligible for extracting 41 // into a graph constant. 42 bool ShouldFoldInputArgument(absl::string_view operation_name, int input_index); 43 44 // Returns true if the tensor proto of a and b are different. 45 bool NodeDefsHaveDifferentTensorProto(const NodeDef& a, const NodeDef& b); 46 } // namespace dtensor 47 } // namespace tensorflow 48 49 #endif // TENSORFLOW_DTENSOR_CC_SMALL_CONSTANT_OPTIMIZATION_H_ 50