/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <executorch/extension/kernel_util/make_boxed_from_unboxed_functor.h>
#include <executorch/extension/llm/custom_ops/op_tile_crop.h>
#include <executorch/runtime/kernel/kernel_includes.h>

namespace torch {
namespace executor {
namespace native {
namespace {
17
check_tile_crop_out_args(const Tensor & in,int64_t tile_size,Tensor & out)18 bool check_tile_crop_out_args(
19 const Tensor& in,
20 int64_t tile_size,
21 Tensor& out) {
22 ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(in, out));
23 ET_LOG_AND_RETURN_IF_FALSE(tensor_is_rank(in, 3));
24 ET_LOG_AND_RETURN_IF_FALSE(tensor_is_rank(out, 4));
25 ET_LOG_AND_RETURN_IF_FALSE(tile_size > 0);
26 ET_LOG_AND_RETURN_IF_FALSE(in.size(in.dim() - 1) % tile_size == 0);
27 ET_LOG_AND_RETURN_IF_FALSE(in.size(in.dim() - 2) % tile_size == 0);
28 return true;
29 }
30
get_tile_crop_out_target_size(const Tensor & in,int64_t tile_size,exec_aten::SizesType * out_sizes,size_t * out_ndim)31 void get_tile_crop_out_target_size(
32 const Tensor& in,
33 int64_t tile_size,
34 exec_aten::SizesType* out_sizes,
35 size_t* out_ndim) {
36 *out_ndim = in.dim() + 1;
37
38 out_sizes[0] = in.size(1) * in.size(2) / (tile_size * tile_size);
39 out_sizes[1] = in.size(0);
40 out_sizes[2] = tile_size;
41 out_sizes[3] = tile_size;
42 }
43
44 template <typename CTYPE>
tile_crop_impl(const Tensor & in,int64_t tile_size,Tensor & out)45 void tile_crop_impl(const Tensor& in, int64_t tile_size, Tensor& out) {
46 const CTYPE* const in_data = in.const_data_ptr<CTYPE>();
47 CTYPE* const out_data = out.mutable_data_ptr<CTYPE>();
48
49 const auto channels = in.size(0);
50 const auto height = in.size(1);
51 const auto width = in.size(2);
52
53 const auto HdivS = height / tile_size;
54 const auto WdivS = width / tile_size;
55
56 size_t out_ix = 0;
57 for (size_t bH = 0; bH < HdivS; bH++) {
58 for (size_t bW = 0; bW < WdivS; bW++) {
59 for (size_t c = 0; c < channels; c++) {
60 for (size_t h = 0; h < tile_size; h++) {
61 for (size_t w = 0; w < tile_size; w++) {
62 size_t in_h = bH * tile_size + h;
63 size_t in_w = bW * tile_size + w;
64 size_t in_ix = c * height * width + in_h * width + in_w;
65
66 out_data[out_ix++] = in_data[in_ix];
67 }
68 }
69 }
70 }
71 }
72 }
73
} // namespace
75
tile_crop_out_impl(KernelRuntimeContext & ctx,const Tensor & input,const int64_t tile_size,Tensor & out)76 Tensor& tile_crop_out_impl(
77 KernelRuntimeContext& ctx,
78 const Tensor& input, // NOLINT
79 const int64_t tile_size, // NOLINT
80 Tensor& out) {
81 ET_KERNEL_CHECK(
82 ctx,
83 check_tile_crop_out_args(input, tile_size, out),
84 InvalidArgument,
85 out);
86
87 // @lint-ignore CLANGTIDY facebook-hte-CArray
88 Tensor::SizesType expected_out_size[kTensorDimensionLimit];
89 size_t expected_out_dim = 0;
90 get_tile_crop_out_target_size(
91 input, tile_size, expected_out_size, &expected_out_dim);
92
93 ET_KERNEL_CHECK(
94 ctx,
95 resize_tensor(out, {expected_out_size, expected_out_dim}) == Error::Ok,
96 InvalidArgument,
97 out);
98
99 constexpr auto name = "tile_crop.out";
100
101 ET_SWITCH_ALL_TYPES(out.scalar_type(), ctx, name, CTYPE, [&]() {
102 tile_crop_impl<CTYPE>(input, tile_size, out);
103 });
104
105 return out;
106 }
107
} // namespace native
} // namespace executor
} // namespace torch

EXECUTORCH_LIBRARY(
    preprocess,
    "tile_crop.out",
    torch::executor::native::tile_crop_out_impl);