xref: /aosp_15_r20/external/tensorflow/tensorflow/core/util/tensor_bundle/byte_swap.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7 http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
#include "tensorflow/core/util/tensor_bundle/byte_swap.h"

#include <algorithm>
#include <limits>
#include <memory>

#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/lib/core/status.h"
24 
25 namespace tensorflow {
26 
27 namespace {
28 
29 // Byte-swap a buffer in place.
30 //
31 // Args:
32 //  buff: pointer to the buffer to be modified IN PLACE.
33 //  size: size of bytes in this buffer.
34 //  dtype: type of data in this buffer.
35 //  num_of_elem: number of data in this buffer, set to -1 if it
36 //               could not be obtained directly from tensor data.
37 //               If num_of_elem is -1, this function will calculate
38 //               the number of data based on size and dtype.
39 // Returns: Status::OK() on success, -1 otherwise
ByteSwapBuffer(char * buff,size_t size,DataType dtype,int num_of_elem)40 Status ByteSwapBuffer(char* buff, size_t size, DataType dtype,
41                       int num_of_elem) {
42   int array_len = num_of_elem;
43   size_t bytes_per_elem = 0;
44 
45   switch (dtype) {
46     // Types that don't need byte-swapping
47     case DT_STRING:
48     case DT_QINT8:
49     case DT_QUINT8:
50     case DT_BOOL:
51     case DT_UINT8:
52     case DT_INT8:
53       return OkStatus();
54 
55     // 16-bit types
56     case DT_BFLOAT16:
57     case DT_HALF:
58     case DT_QINT16:
59     case DT_QUINT16:
60     case DT_UINT16:
61     case DT_INT16:
62       bytes_per_elem = 2;
63       array_len = (array_len == -1) ? size / bytes_per_elem : array_len;
64       break;
65 
66     // 32-bit types
67     case DT_FLOAT:
68     case DT_INT32:
69     case DT_QINT32:
70     case DT_UINT32:
71       bytes_per_elem = 4;
72       array_len = (array_len == -1) ? size / bytes_per_elem : array_len;
73       break;
74 
75     // 64-bit types
76     case DT_INT64:
77     case DT_DOUBLE:
78     case DT_UINT64:
79       bytes_per_elem = 8;
80       array_len = (array_len == -1) ? size / bytes_per_elem : array_len;
81       break;
82 
83     // Complex types need special handling
84     case DT_COMPLEX64:
85       bytes_per_elem = 4;
86       array_len = (array_len == -1) ? size / bytes_per_elem : array_len;
87       array_len *= 2;
88       break;
89 
90     case DT_COMPLEX128:
91       bytes_per_elem = 8;
92       array_len = (array_len == -1) ? size / bytes_per_elem : array_len;
93       array_len *= 2;
94       break;
95 
96     // Types that ought to be supported in the future
97     case DT_RESOURCE:
98     case DT_VARIANT:
99       return errors::Unimplemented(
100           "Byte-swapping not yet implemented for tensors with dtype ", dtype);
101 
102     // Byte-swapping shouldn't make sense for other dtypes.
103     default:
104       return errors::Unimplemented(
105           "Byte-swapping not supported for tensors with dtype ", dtype);
106   }
107 
108   TF_RETURN_IF_ERROR(ByteSwapArray(buff, bytes_per_elem, array_len));
109   return OkStatus();
110 }
111 
112 }  // namespace
113 
ByteSwapArray(char * array,size_t bytes_per_elem,int array_len)114 Status ByteSwapArray(char* array, size_t bytes_per_elem, int array_len) {
115   if (bytes_per_elem == 1) {
116     // No-op
117     return OkStatus();
118   } else if (bytes_per_elem == 2) {
119     auto array_16 = reinterpret_cast<uint16_t*>(array);
120     for (int i = 0; i < array_len; i++) {
121       array_16[i] = BYTE_SWAP_16(array_16[i]);
122     }
123     return OkStatus();
124   } else if (bytes_per_elem == 4) {
125     auto array_32 = reinterpret_cast<uint32_t*>(array);
126     for (int i = 0; i < array_len; i++) {
127       array_32[i] = BYTE_SWAP_32(array_32[i]);
128     }
129     return OkStatus();
130   } else if (bytes_per_elem == 8) {
131     auto array_64 = reinterpret_cast<uint64_t*>(array);
132     for (int i = 0; i < array_len; i++) {
133       array_64[i] = BYTE_SWAP_64(array_64[i]);
134     }
135     return OkStatus();
136   } else {
137     return errors::Unimplemented("Byte-swapping of ", bytes_per_elem,
138                                  "-byte values not supported.");
139   }
140 }
141 
ByteSwapTensor(Tensor * t)142 Status ByteSwapTensor(Tensor* t) {
143   char* buff = const_cast<char*>((t->tensor_data().data()));
144   return ByteSwapBuffer(buff, t->tensor_data().size(), t->dtype(),
145                         t->NumElements());
146 }
147 
ByteSwapTensorContent(MetaGraphDef * meta_graph_def)148 Status ByteSwapTensorContent(MetaGraphDef* meta_graph_def) {
149   for (auto& function : *meta_graph_def->mutable_graph_def()
150                              ->mutable_library()
151                              ->mutable_function()) {
152     for (auto& node : (*function.mutable_node_def())) {
153       if (node.op() == "Const") {
154         auto node_iterator = node.mutable_attr()->find("value");
155         if (node_iterator != node.mutable_attr()->end()) {
156           AttrValue node_value = node_iterator->second;
157           if (node_value.has_tensor()) {
158             auto tsize = node_value.mutable_tensor()->tensor_content().size();
159             auto p_type = node_value.mutable_tensor()->dtype();
160             // Swap only when there is something in tensor_content field
161             if (tsize != 0 && DataTypeCanUseMemcpy(p_type)) {
162               Tensor parsed(p_type);
163               DCHECK(parsed.FromProto(*node_value.mutable_tensor()));
164               if (!parsed.tensor_data().empty()) {
165                 TF_RETURN_IF_ERROR(ByteSwapTensor(&parsed));
166                 (*node.mutable_attr())["value"]
167                     .mutable_tensor()
168                     ->set_tensor_content(
169                         string(reinterpret_cast<const char*>(
170                                    parsed.tensor_data().data()),
171                                parsed.tensor_data().size()));
172               } else {
173                 void* copy = tensorflow::port::Malloc(tsize);
174                 memcpy(copy,
175                        string(node_value.mutable_tensor()->tensor_content())
176                            .data(),
177                        tsize);
178                 TF_RETURN_IF_ERROR(
179                     ByteSwapBuffer((char*)copy, tsize, p_type, -1));
180                 (*node.mutable_attr())["value"]
181                     .mutable_tensor()
182                     ->set_tensor_content(
183                         string(reinterpret_cast<const char*>(copy), tsize));
184                 tensorflow::port::Free(copy);
185               }
186             }
187           }
188         }
189       }
190     }
191   }
192   return OkStatus();
193 }
194 
195 }  // namespace tensorflow
196