xref: /aosp_15_r20/external/tensorflow/tensorflow/dtensor/cc/small_constant_optimization.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/dtensor/cc/small_constant_optimization.h"
17 
18 #include <cstdint>
19 #include <memory>
20 #include <utility>
21 #include <vector>
22 
23 #include "absl/algorithm/container.h"
24 #include "absl/strings/string_view.h"
25 #include "tensorflow/c/eager/c_api.h"
26 #include "tensorflow/c/eager/c_api_experimental.h"
27 #include "tensorflow/c/tf_status.h"
28 #include "tensorflow/c/tf_tensor_internal.h"
29 #include "tensorflow/core/framework/tensor.pb.h"
30 #include "tensorflow/core/platform/ctstring_internal.h"
31 #include "tensorflow/core/platform/protobuf.h"
32 #include "tensorflow/dtensor/cc/constants.h"
33 #include "tensorflow/dtensor/proto/layout.pb.h"
34 
35 namespace tensorflow {
36 namespace dtensor {
37 
38 namespace {
39 
40 constexpr TF_DataType kAllowedDataType[] = {TF_INT32, TF_INT64, TF_FLOAT,
41                                             TF_STRING};
42 
AppendIntValues(const int num_of_elements,const int * int_values,TensorProto * proto)43 void AppendIntValues(const int num_of_elements, const int* int_values,
44                      TensorProto* proto) {
45   for (int i = 0; i < num_of_elements; ++i) {
46     proto->add_int_val(int_values[i]);
47   }
48 }
49 
AppendInt64Values(const int num_of_elements,const int64_t * int64_values,TensorProto * proto)50 void AppendInt64Values(const int num_of_elements, const int64_t* int64_values,
51                        TensorProto* proto) {
52   for (int i = 0; i < num_of_elements; ++i) {
53     proto->add_int64_val(int64_values[i]);
54   }
55 }
56 
AppendStringValues(const int num_of_elements,const TF_TString * string_values,TensorProto * proto)57 void AppendStringValues(const int num_of_elements,
58                         const TF_TString* string_values, TensorProto* proto) {
59   for (int i = 0; i < num_of_elements; ++i) {
60     proto->add_string_val(
61         std::string(TF_TString_GetDataPointer(&string_values[i]),
62                     TF_TString_GetSize(&string_values[i])));
63   }
64 }
AppendFloatValues(const int num_of_elements,const float * float_values,TensorProto * proto)65 void AppendFloatValues(const int num_of_elements, const float* float_values,
66                        TensorProto* proto) {
67   for (int i = 0; i < num_of_elements; ++i) {
68     proto->add_float_val(float_values[i]);
69   }
70 }
71 
72 }  // namespace
73 
ExtractSmallTensorValue(TFE_Context * context,TFE_TensorHandle * tensor,const Layout & layout,TF_Status * status)74 absl::optional<NodeDef> ExtractSmallTensorValue(TFE_Context* context,
75                                                 TFE_TensorHandle* tensor,
76                                                 const Layout& layout,
77                                                 TF_Status* status) {
78   if (!layout.IsFullyReplicated()) return std::nullopt;
79   auto num_elements = TFE_TensorHandleNumElements(tensor, status);
80   if (TF_GetCode(status) != TF_OK) return absl::nullopt;
81 
82   if (num_elements >= kSmallTensorThreshold) return absl::nullopt;
83 
84   // Check the DType before attempting to resolve the tensor so we don't try to
85   // copy resource-dtype tensors off the DTensor device. Currently we only
86   // extract small int32/int64_t tensors, primarily to catch shapes and axes,
87   // and tf_string tensors that are mostly used in save/restore ops.
88   const auto& dtype = TFE_TensorHandleDataType(tensor);
89   if (absl::c_find(kAllowedDataType, dtype) == std::end(kAllowedDataType)) {
90     return absl::nullopt;
91   }
92 
93   // This is the enum from protobuf, or the following AddNodeAttr will always
94   // set the integer field.
95   const auto& datatype = static_cast<DataType>(dtype);
96   std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> value_tensor(
97       TFE_TensorHandleResolve(tensor, status), TF_DeleteTensor);
98   if (TF_GetCode(status) != TF_OK) return absl::nullopt;
99 
100   NodeDef node_def;
101   node_def.set_op("Const");
102   AddNodeAttr("dtype", datatype, &node_def);
103 
104   TensorProto tensor_proto;
105   tensor_proto.set_dtype(datatype);
106   switch (dtype) {
107     case TF_INT32:
108       AppendIntValues(num_elements,
109                       static_cast<int*>(TF_TensorData(value_tensor.get())),
110                       &tensor_proto);
111       break;
112     case TF_INT64:
113       AppendInt64Values(
114           num_elements,
115           static_cast<const int64_t*>(TF_TensorData(value_tensor.get())),
116           &tensor_proto);
117       break;
118     case TF_STRING:
119       AppendStringValues(
120           num_elements,
121           static_cast<const TF_TString*>(TF_TensorData(value_tensor.get())),
122           &tensor_proto);
123       break;
124     case TF_FLOAT:
125       AppendFloatValues(
126           num_elements,
127           static_cast<const float*>(TF_TensorData(value_tensor.get())),
128           &tensor_proto);
129       break;
130     default:
131       TF_SetStatus(status, TF_INTERNAL,
132                    absl::StrCat("dtype: ", dtype,
133                                 " fell through the supported extraction list. "
134                                 "This should not happen.")
135                        .c_str());
136       return absl::nullopt;
137   }
138 
139   std::vector<int64_t> dim_list;
140   int num_dims = value_tensor->tensor->NumDims();
141   dim_list.reserve(num_dims);
142   for (int i = 0; i < num_dims; ++i) {
143     dim_list.push_back(value_tensor->tensor->Dim(i));
144   }
145 
146   TensorShape shape(std::move(dim_list));
147   shape.AsProto(tensor_proto.mutable_tensor_shape());
148   AddNodeAttr("value", tensor_proto, &node_def);
149 
150   AddNodeAttr(kLayoutAttr, {layout.ToString()}, &node_def);
151   AddNodeAttr(kMeshAttr, layout.mesh().ToString(), &node_def);
152   return node_def;
153 }
154 
ShouldFoldInputArgument(absl::string_view operation_name,int input_index)155 bool ShouldFoldInputArgument(absl::string_view operation_name,
156                              int input_index) {
157   // Fold if we are in a function or if a special eager op.
158   // TODO(xiejw,power): Think about how to generalize this so it does not depend
159   // on operation_name. For example, we can check the max abs value of the
160   // tensor value.
161   if (operation_name == absl::string_view("StatelessRandomUniform") ||
162       operation_name == absl::string_view("StatelessRandomUniformFullInt") ||
163       operation_name == absl::string_view("StatelessRandomNormal") ||
164       operation_name == absl::string_view("StatelessTruncatedNormal")) {
165     // For all stateless rng ops, we avoid fold seed (input_index==1) in graph.
166     // This is an important optimization to avoid unnecessary MLIR SPMD lowering
167     // and TPU compilation during model parameters initialization process.
168     // which typically have the same shape for rng ops but different seeds.
169     return input_index != 1;
170   }
171 
172   return true;
173 }
174 
NodeDefsHaveDifferentTensorProto(const NodeDef & a,const NodeDef & b)175 bool NodeDefsHaveDifferentTensorProto(const NodeDef& a, const NodeDef& b) {
176   const TensorProto* tensor_proto_a;
177   bool read_a_tensor_proto = TryGetNodeAttr(a, "value", &tensor_proto_a);
178   if (!read_a_tensor_proto) return true;
179 
180   const TensorProto* tensor_proto_b;
181   bool read_b_tensor_proto = TryGetNodeAttr(b, "value", &tensor_proto_b);
182   if (!read_b_tensor_proto) return true;
183   return !protobuf::util::MessageDifferencer::Equals(*tensor_proto_a,
184                                                      *tensor_proto_b);
185 }
186 
187 }  // namespace dtensor
188 }  // namespace tensorflow
189