1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
#include "tensorflow/core/util/tensor_bundle/byte_swap.h"

#include <cstring>
#include <limits>
#include <memory>
#include <string>

#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/lib/core/status.h"
24
25 namespace tensorflow {
26
27 namespace {
28
29 // Byte-swap a buffer in place.
30 //
31 // Args:
32 // buff: pointer to the buffer to be modified IN PLACE.
33 // size: size of bytes in this buffer.
34 // dtype: type of data in this buffer.
35 // num_of_elem: number of data in this buffer, set to -1 if it
36 // could not be obtained directly from tensor data.
37 // If num_of_elem is -1, this function will calculate
38 // the number of data based on size and dtype.
39 // Returns: Status::OK() on success, -1 otherwise
ByteSwapBuffer(char * buff,size_t size,DataType dtype,int num_of_elem)40 Status ByteSwapBuffer(char* buff, size_t size, DataType dtype,
41 int num_of_elem) {
42 int array_len = num_of_elem;
43 size_t bytes_per_elem = 0;
44
45 switch (dtype) {
46 // Types that don't need byte-swapping
47 case DT_STRING:
48 case DT_QINT8:
49 case DT_QUINT8:
50 case DT_BOOL:
51 case DT_UINT8:
52 case DT_INT8:
53 return OkStatus();
54
55 // 16-bit types
56 case DT_BFLOAT16:
57 case DT_HALF:
58 case DT_QINT16:
59 case DT_QUINT16:
60 case DT_UINT16:
61 case DT_INT16:
62 bytes_per_elem = 2;
63 array_len = (array_len == -1) ? size / bytes_per_elem : array_len;
64 break;
65
66 // 32-bit types
67 case DT_FLOAT:
68 case DT_INT32:
69 case DT_QINT32:
70 case DT_UINT32:
71 bytes_per_elem = 4;
72 array_len = (array_len == -1) ? size / bytes_per_elem : array_len;
73 break;
74
75 // 64-bit types
76 case DT_INT64:
77 case DT_DOUBLE:
78 case DT_UINT64:
79 bytes_per_elem = 8;
80 array_len = (array_len == -1) ? size / bytes_per_elem : array_len;
81 break;
82
83 // Complex types need special handling
84 case DT_COMPLEX64:
85 bytes_per_elem = 4;
86 array_len = (array_len == -1) ? size / bytes_per_elem : array_len;
87 array_len *= 2;
88 break;
89
90 case DT_COMPLEX128:
91 bytes_per_elem = 8;
92 array_len = (array_len == -1) ? size / bytes_per_elem : array_len;
93 array_len *= 2;
94 break;
95
96 // Types that ought to be supported in the future
97 case DT_RESOURCE:
98 case DT_VARIANT:
99 return errors::Unimplemented(
100 "Byte-swapping not yet implemented for tensors with dtype ", dtype);
101
102 // Byte-swapping shouldn't make sense for other dtypes.
103 default:
104 return errors::Unimplemented(
105 "Byte-swapping not supported for tensors with dtype ", dtype);
106 }
107
108 TF_RETURN_IF_ERROR(ByteSwapArray(buff, bytes_per_elem, array_len));
109 return OkStatus();
110 }
111
112 } // namespace
113
ByteSwapArray(char * array,size_t bytes_per_elem,int array_len)114 Status ByteSwapArray(char* array, size_t bytes_per_elem, int array_len) {
115 if (bytes_per_elem == 1) {
116 // No-op
117 return OkStatus();
118 } else if (bytes_per_elem == 2) {
119 auto array_16 = reinterpret_cast<uint16_t*>(array);
120 for (int i = 0; i < array_len; i++) {
121 array_16[i] = BYTE_SWAP_16(array_16[i]);
122 }
123 return OkStatus();
124 } else if (bytes_per_elem == 4) {
125 auto array_32 = reinterpret_cast<uint32_t*>(array);
126 for (int i = 0; i < array_len; i++) {
127 array_32[i] = BYTE_SWAP_32(array_32[i]);
128 }
129 return OkStatus();
130 } else if (bytes_per_elem == 8) {
131 auto array_64 = reinterpret_cast<uint64_t*>(array);
132 for (int i = 0; i < array_len; i++) {
133 array_64[i] = BYTE_SWAP_64(array_64[i]);
134 }
135 return OkStatus();
136 } else {
137 return errors::Unimplemented("Byte-swapping of ", bytes_per_elem,
138 "-byte values not supported.");
139 }
140 }
141
ByteSwapTensor(Tensor * t)142 Status ByteSwapTensor(Tensor* t) {
143 char* buff = const_cast<char*>((t->tensor_data().data()));
144 return ByteSwapBuffer(buff, t->tensor_data().size(), t->dtype(),
145 t->NumElements());
146 }
147
ByteSwapTensorContent(MetaGraphDef * meta_graph_def)148 Status ByteSwapTensorContent(MetaGraphDef* meta_graph_def) {
149 for (auto& function : *meta_graph_def->mutable_graph_def()
150 ->mutable_library()
151 ->mutable_function()) {
152 for (auto& node : (*function.mutable_node_def())) {
153 if (node.op() == "Const") {
154 auto node_iterator = node.mutable_attr()->find("value");
155 if (node_iterator != node.mutable_attr()->end()) {
156 AttrValue node_value = node_iterator->second;
157 if (node_value.has_tensor()) {
158 auto tsize = node_value.mutable_tensor()->tensor_content().size();
159 auto p_type = node_value.mutable_tensor()->dtype();
160 // Swap only when there is something in tensor_content field
161 if (tsize != 0 && DataTypeCanUseMemcpy(p_type)) {
162 Tensor parsed(p_type);
163 DCHECK(parsed.FromProto(*node_value.mutable_tensor()));
164 if (!parsed.tensor_data().empty()) {
165 TF_RETURN_IF_ERROR(ByteSwapTensor(&parsed));
166 (*node.mutable_attr())["value"]
167 .mutable_tensor()
168 ->set_tensor_content(
169 string(reinterpret_cast<const char*>(
170 parsed.tensor_data().data()),
171 parsed.tensor_data().size()));
172 } else {
173 void* copy = tensorflow::port::Malloc(tsize);
174 memcpy(copy,
175 string(node_value.mutable_tensor()->tensor_content())
176 .data(),
177 tsize);
178 TF_RETURN_IF_ERROR(
179 ByteSwapBuffer((char*)copy, tsize, p_type, -1));
180 (*node.mutable_attr())["value"]
181 .mutable_tensor()
182 ->set_tensor_content(
183 string(reinterpret_cast<const char*>(copy), tsize));
184 tensorflow::port::Free(copy);
185 }
186 }
187 }
188 }
189 }
190 }
191 }
192 return OkStatus();
193 }
194
195 } // namespace tensorflow
196