/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/delegates/coreml/builders/reshape_op_builder.h"

#include <algorithm>
#include <cstdint>
#include <cstring>
#include <iterator>
#include <memory>
#include <string>

#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/delegates/coreml/builders/op_builder.h"
#include "tensorflow/lite/delegates/coreml/builders/op_factory.h"
#include "tensorflow/lite/delegates/coreml/builders/op_validator.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"

namespace tflite {
namespace delegates {
namespace coreml {

const std::string& ReshapeOpBuilder::DebugName() {
  if (debug_name_.empty()) {
    SetDebugName("ReshapeOpBuilder", node_id_);
  }
  return debug_name_;
}

CoreML::Specification::NeuralNetworkLayer* ReshapeOpBuilder::Build() {
  if (layer_ == nullptr) {
    layer_ = std::make_unique<CoreML::Specification::NeuralNetworkLayer>();
  }
  layer_->set_name(DebugName());
  for (int dim : shape_) {
    layer_->mutable_reshape()->add_targetshape(dim);
  }
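  // When the channel dimension changes, ask Core ML to perform the reshape in
  // channel-last (HWC) order so the element ordering matches TFLite's layout.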
  if (need_transpose_)
    layer_->mutable_reshape()->set_mode(
        CoreML::Specification::ReshapeLayerParams::CHANNEL_LAST);
  return layer_.release();
}

void ReshapeOpBuilder::SetShapeFromTensor(const TfLiteTensor* output_shape,
                                          const TfLiteIntArray* input_shape) {
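  // Copy the int32 contents of the shape tensor into a TfLiteIntArray and
  // reuse SetShapeFromIntArray() to do the actual work.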
  TfLiteIntArray* shape = TfLiteIntArrayCreate(output_shape->dims->data[0]);
  std::memcpy(shape->data, GetTensorData<int>(output_shape),
              shape->size * sizeof(int));

  SetShapeFromIntArray(shape, input_shape);
  TfLiteIntArrayFree(shape);
}

void ReshapeOpBuilder::SetShapeFromIntArray(const TfLiteIntArray* output_shape,
                                            const TfLiteIntArray* input_shape) {
  // ignore first dimension (batch)
  std::copy(output_shape->data + 1, output_shape->data + output_shape->size,
            std::back_inserter(shape_));

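  // A -1 entry in the target shape is inferred from the remaining dimensions
  // so that the element count matches the input tensor.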
  int64_t reshape_size = 1;
  int negative_index = -1;
  for (int i = 0; i < shape_.size(); ++i) {
    if (shape_[i] == -1) {
      negative_index = i;
    } else {
      reshape_size *= shape_[i];
    }
  }
  if (negative_index >= 0) {
    int64_t input_size = NumElements(input_shape);
    shape_[negative_index] = input_size / reshape_size;
  }

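  // Core ML's reshape layer expects a rank-3 (C, H, W) target shape, so pad
  // 2-D shapes with a singleton dimension and move the channel (last)
  // dimension of 3-D shapes to the front.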
  if (shape_.size() == 2) {
    shape_ = {shape_[1], 1, shape_[0]};
  } else if (shape_.size() == 3) {
    shape_ = {shape_[2], shape_[0], shape_[1]};
  }
  // When channel dimension is changed, reshape should be done with HWC layout.
  if (shape_[0] != input_shape->data[input_shape->size - 1]) {
    need_transpose_ = true;
  }
}

TfLiteStatus ReshapeOpBuilder::RegisterInputs(const TfLiteIntArray* inputs,
                                              TfLiteContext* context) {
  AddInput(inputs->data[0]);

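  // The target shape is provided either as a second (constant) shape tensor
  // or through the builtin TfLiteReshapeParams.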
  if (inputs->size == 2) {
    SetShapeFromTensor(&context->tensors[inputs->data[1]],
                       context->tensors[inputs->data[0]].dims);
  } else {
    const auto* params = reinterpret_cast<TfLiteReshapeParams*>(builtin_data_);
    TfLiteIntArray* output_shape = TfLiteIntArrayCreate(params->num_dimensions);
    std::memcpy(output_shape->data, params->shape,
                params->num_dimensions * sizeof(int));

    SetShapeFromIntArray(output_shape, context->tensors[inputs->data[0]].dims);
    TfLiteIntArrayFree(output_shape);
  }
  return kTfLiteOk;
}

TfLiteStatus ReshapeOpBuilder::RegisterOutputs(const TfLiteIntArray* outputs,
                                               TfLiteContext* context) {
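  // Register this layer's Core ML output under the TFLite output tensor id so
  // downstream ops can reference it.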
  graph_builder_->AddTensorWithID(outputs->data[0], GetOutput(context));
  return kTfLiteOk;
}

bool IsReshapeOpSupported(const TfLiteRegistration* registration,
                          const TfLiteNode* node, TfLiteContext* context,
                          int coreml_version) {
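  // Reshape is only delegated when targeting Core ML 2.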
  if (coreml_version >= 3) {
    return false;
  }
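  // With a single input, the target shape comes from the builtin params and
  // must be rank 3 or 4.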
  if (node->inputs->size == 1) {
    const auto* params =
        reinterpret_cast<TfLiteReshapeParams*>(node->builtin_data);
    return params->num_dimensions == 3 || params->num_dimensions == 4;
  }

  const int kShapeTensor = 1;
  const TfLiteTensor* shape;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kShapeTensor, &shape));
  if (shape->allocation_type != kTfLiteMmapRo) {
    TF_LITE_KERNEL_LOG(context, "Reshape has non-const shape.");
    return false;
  }
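  // Otherwise the shape must be a 1-D int32 tensor with 3 or 4 elements.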
  const bool is_shape_tensor =
      shape->dims->size == 1 && shape->type == kTfLiteInt32;
  return is_shape_tensor &&
         (shape->dims->data[0] == 3 || shape->dims->data[0] == 4);
}

OpBuilder* CreateReshapeOpBuilder(GraphBuilder* graph_builder) {
  return new ReshapeOpBuilder(graph_builder);
}

}  // namespace coreml
}  // namespace delegates
}  // namespace tflite