/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

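// Implementation of the tile_crop.out custom op: splits a [C, H, W] image
// into non-overlapping tile_size x tile_size tiles.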
#include <executorch/extension/kernel_util/make_boxed_from_unboxed_functor.h>
#include <executorch/extension/llm/custom_ops/op_tile_crop.h>
#include <executorch/runtime/kernel/kernel_includes.h>

namespace torch {
namespace executor {
namespace native {
namespace {

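// Validates the arguments to tile_crop.out: the input must be a rank-3
// (channels, height, width) tensor whose height and width are divisible by
// tile_size, and the output must be a rank-4 tensor of the same dtype.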
bool check_tile_crop_out_args(
    const Tensor& in,
    int64_t tile_size,
    Tensor& out) {
  ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(in, out));
  ET_LOG_AND_RETURN_IF_FALSE(tensor_is_rank(in, 3));
  ET_LOG_AND_RETURN_IF_FALSE(tensor_is_rank(out, 4));
  ET_LOG_AND_RETURN_IF_FALSE(tile_size > 0);
  ET_LOG_AND_RETURN_IF_FALSE(in.size(in.dim() - 1) % tile_size == 0);
  ET_LOG_AND_RETURN_IF_FALSE(in.size(in.dim() - 2) % tile_size == 0);
  return true;
}

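// Computes the expected output sizes: a [C, H, W] input cropped into
// tile_size x tile_size patches yields an output of shape
// [(H / tile_size) * (W / tile_size), C, tile_size, tile_size].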
void get_tile_crop_out_target_size(
    const Tensor& in,
    int64_t tile_size,
    exec_aten::SizesType* out_sizes,
    size_t* out_ndim) {
  *out_ndim = in.dim() + 1;

  out_sizes[0] = in.size(1) * in.size(2) / (tile_size * tile_size);
  out_sizes[1] = in.size(0);
  out_sizes[2] = tile_size;
  out_sizes[3] = tile_size;
}

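// Copies the input image into non-overlapping tiles. Tiles are emitted in
// row-major order over the tile grid, so the output is laid out as
// [num_tiles, C, tile_size, tile_size] with
// num_tiles = (H / tile_size) * (W / tile_size).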
template <typename CTYPE>
void tile_crop_impl(const Tensor& in, int64_t tile_size, Tensor& out) {
  const CTYPE* const in_data = in.const_data_ptr<CTYPE>();
  CTYPE* const out_data = out.mutable_data_ptr<CTYPE>();

  const auto channels = in.size(0);
  const auto height = in.size(1);
  const auto width = in.size(2);

  const auto HdivS = height / tile_size;
  const auto WdivS = width / tile_size;

  size_t out_ix = 0;
  for (size_t bH = 0; bH < HdivS; bH++) {
    for (size_t bW = 0; bW < WdivS; bW++) {
      for (size_t c = 0; c < channels; c++) {
        for (size_t h = 0; h < tile_size; h++) {
          for (size_t w = 0; w < tile_size; w++) {
            size_t in_h = bH * tile_size + h;
            size_t in_w = bW * tile_size + w;
            size_t in_ix = c * height * width + in_h * width + in_w;

            out_data[out_ix++] = in_data[in_ix];
          }
        }
      }
    }
  }
}

} // namespace

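// Out-variant entry point for the tile_crop custom op. Validates the
// arguments, resizes `out` to the expected tile layout, and dispatches the
// tiling copy over all supported dtypes.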
Tensor& tile_crop_out_impl(
    KernelRuntimeContext& ctx,
    const Tensor& input, // NOLINT
    const int64_t tile_size, // NOLINT
    Tensor& out) {
  ET_KERNEL_CHECK(
      ctx,
      check_tile_crop_out_args(input, tile_size, out),
      InvalidArgument,
      out);

  // @lint-ignore CLANGTIDY facebook-hte-CArray
  Tensor::SizesType expected_out_size[kTensorDimensionLimit];
  size_t expected_out_dim = 0;
  get_tile_crop_out_target_size(
      input, tile_size, expected_out_size, &expected_out_dim);

  ET_KERNEL_CHECK(
      ctx,
      resize_tensor(out, {expected_out_size, expected_out_dim}) == Error::Ok,
      InvalidArgument,
      out);

  constexpr auto name = "tile_crop.out";

  ET_SWITCH_ALL_TYPES(out.scalar_type(), ctx, name, CTYPE, [&]() {
    tile_crop_impl<CTYPE>(input, tile_size, out);
  });

  return out;
}

} // namespace native
} // namespace executor
} // namespace torch

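// Registers the kernel with the ExecuTorch runtime so that calls to
// preprocess::tile_crop.out resolve to tile_crop_out_impl.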
EXECUTORCH_LIBRARY(
    preprocess,
    "tile_crop.out",
    torch::executor::native::tile_crop_out_impl);