xref: /aosp_15_r20/external/executorch/extension/llm/custom_ops/preprocess_custom_ops.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright (c) Meta Platforms, Inc. and affiliates.
2# All rights reserved.
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7# pyre-unsafe
8
9
10import torch
11
12from torch.library import impl, Library
13
# Custom-op library for the "preprocess" namespace. "DEF" means this module
# both defines the operator schemas and registers their implementations.
preprocess_op_lib = Library("preprocess", "DEF")

# Register and define tile_crop and out variant.
preprocess_op_lib.define("tile_crop(Tensor input, int tile_size) -> Tensor")

# Keep this in sync with model config.
# Upper bound on the number of tiles per image; used by the fake kernel
# below to bound the dynamic output shape.
MAX_NUM_TILES = 4
21
22
@impl(preprocess_op_lib, "tile_crop", dispatch_key="CompositeExplicitAutograd")
def tile_crop_impl(input: torch.Tensor, tile_size: int) -> torch.Tensor:
    """Split a [C, H, W] image into non-overlapping tile_size x tile_size tiles.

    Returns a tensor of shape [rows * cols, C, tile_size, tile_size], with
    tiles ordered row-major over the tile grid (left-to-right, top-to-bottom).
    """
    channels = input.shape[0]
    rows = input.shape[1] // tile_size
    cols = input.shape[2] // tile_size
    # Carve H and W into (grid index, within-tile offset) pairs, bring the
    # two grid axes in front of the channel axis, then flatten the grid.
    gridded = input.view(channels, rows, tile_size, cols, tile_size)
    reordered = gridded.permute(1, 3, 0, 2, 4).contiguous()
    return reordered.view(rows * cols, channels, tile_size, tile_size)
36
37
# Out-variant schema: `Tensor(a!) out` declares that the caller-provided
# `out` buffer is mutated in place and aliased by the return value.
preprocess_op_lib.define(
    "tile_crop.out(Tensor input, int tile_size, *, Tensor(a!) out) -> Tensor(a!)"
)
41
42
@impl(preprocess_op_lib, "tile_crop.out", dispatch_key="CompositeExplicitAutograd")
def tile_crop_out_impl(
    input: torch.Tensor, tile_size: int, out: torch.Tensor
) -> torch.Tensor:
    """Out-variant of tile_crop: write the tiles into `out` and return it.

    The schema declares `Tensor(a!) out`, i.e. the caller-provided buffer must
    actually be mutated. The previous implementation rebound the local name
    `out` to `input.clone()` and returned a fresh tensor, leaving the caller's
    buffer untouched (and paying for a full clone). Here we compute the tiled
    result and copy it into `out` in place.
    """
    c = input.shape[0]
    h = input.shape[1]
    w = input.shape[2]
    tiles_height = h // tile_size
    tiles_width = w // tile_size
    # Same tiling as tile_crop: split H and W into tile-sized chunks, move
    # the tile-grid axes in front of the channel axis, and flatten the grid.
    tiles = (
        input.view(c, tiles_height, tile_size, tiles_width, tile_size)
        .permute(1, 3, 0, 2, 4)
        .contiguous()
        .view(tiles_height * tiles_width, c, tile_size, tile_size)
    )
    # Conform `out` to the result shape, then fill it in place.
    out.resize_(tiles.shape)
    out.copy_(tiles)
    return out
57
58
# Register meta kernel to prevent export tracing into the tile_crop impl.
@torch.library.register_fake("preprocess::tile_crop")
def tile_crop(input: torch.Tensor, tile_size: int) -> torch.Tensor:
    """Fake (meta) kernel for tile_crop: shape/dtype inference only.

    Returned tensor is of size [n, C, tile_size, tile_size], where n = number
    of tiles. Use an unbacked symint to create an upper-bounded dynamic shape
    output. Otherwise, output is set to a static shape, and we can only output
    tensors of shape [MAX_NUM_TILES, C, tile_size, tile_size].

    The parameter was previously named `output`, which is misleading: it is
    the operator's input image (the schema names it `input`).
    """
    ctx = torch._custom_ops.get_ctx()
    num_tiles = ctx.create_unbacked_symint()
    torch._constrain_as_size(num_tiles, 0, MAX_NUM_TILES)
    # Match the real kernel's output metadata: channel count comes from the
    # input, and dtype/device must mirror the input (the eager impl derives
    # its result from `input`, so a plain torch.empty default of float32/cpu
    # could diverge from the real op's output).
    return torch.empty(
        [num_tiles, input.size(0), tile_size, tile_size],
        dtype=input.dtype,
        device=input.device,
    )
70