xref: /aosp_15_r20/external/tensorflow/tensorflow/dtensor/cc/small_constant_optimization.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_DTENSOR_CC_SMALL_CONSTANT_OPTIMIZATION_H_
17 #define TENSORFLOW_DTENSOR_CC_SMALL_CONSTANT_OPTIMIZATION_H_
18 
#include "absl/strings/string_view.h"
#include "absl/types/optional.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/dtensor/cc/tensor_layout.h"
23 
24 namespace tensorflow {
25 namespace dtensor {
26 
27 // Attempt to convert small constant tensors into a constant NodeDef operation.
28 // This constant value will be available for constant propagation in DTensor and
29 // MLIR.
30 
31 // This conversion is currently required for some DTensor operations. In
32 // particular, reductions require access to the axis argument at compilation
33 // time. While this is not strictly necessary, it greatly simplifies SPMD code
34 // generation and is generally available.
35 absl::optional<NodeDef> ExtractSmallTensorValue(TFE_Context* context,
36                                                 TFE_TensorHandle* tensor,
37                                                 const Layout& layout,
38                                                 TF_Status* status);
39 
40 // Returns true if the given input argument should be eligible for extracting
41 // into a graph constant.
42 bool ShouldFoldInputArgument(absl::string_view operation_name, int input_index);
43 
44 // Returns true if the tensor proto of a and b are different.
45 bool NodeDefsHaveDifferentTensorProto(const NodeDef& a, const NodeDef& b);
46 }  // namespace dtensor
47 }  // namespace tensorflow
48 
49 #endif  // TENSORFLOW_DTENSOR_CC_SMALL_CONSTANT_OPTIMIZATION_H_
50