# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe


import torch

from torch.library import impl, Library

preprocess_op_lib = Library("preprocess", "DEF")

# Register and define tile_crop and out variant.
preprocess_op_lib.define("tile_crop(Tensor input, int tile_size) -> Tensor")

# Keep this in sync with model config.
MAX_NUM_TILES = 4


@impl(preprocess_op_lib, "tile_crop", dispatch_key="CompositeExplicitAutograd")
def tile_crop_impl(input: torch.Tensor, tile_size: int) -> torch.Tensor:
    """Split a CHW image into non-overlapping square tiles.

    Args:
        input: Tensor of shape [c, h, w]. h and w are expected to be
            multiples of ``tile_size``; any remainder rows/columns are
            silently dropped by the integer division below.
        tile_size: Side length of each square tile, in pixels.

    Returns:
        Tensor of shape [num_tiles, c, tile_size, tile_size] where
        ``num_tiles = (h // tile_size) * (w // tile_size)``, with tiles
        ordered row-major over the tile grid.
    """
    c = input.shape[0]
    h = input.shape[1]
    w = input.shape[2]
    tiles_height = h // tile_size
    tiles_width = w // tile_size
    # [c, th, ts, tw, ts] -> [th, tw, c, ts, ts] -> [th * tw, c, ts, ts]
    tile_cropped = input.view(c, tiles_height, tile_size, tiles_width, tile_size)
    transposed = tile_cropped.permute(1, 3, 0, 2, 4)
    tiles = transposed.contiguous().view(
        tiles_height * tiles_width, c, tile_size, tile_size
    )
    return tiles


preprocess_op_lib.define(
    "tile_crop.out(Tensor input, int tile_size, *, Tensor(a!) out) -> Tensor(a!)"
)


@impl(preprocess_op_lib, "tile_crop.out", dispatch_key="CompositeExplicitAutograd")
def tile_crop_out_impl(
    input: torch.Tensor, tile_size: int, out: torch.Tensor
) -> torch.Tensor:
    """Out-variant of tile_crop: writes the tiles into ``out`` and returns it.

    Bug fix: the previous implementation rebound the local name ``out`` to a
    fresh tensor (``out = input.clone()`` followed by view/permute), so the
    caller-provided ``out`` buffer — declared ``Tensor(a!)`` (mutated alias)
    in the schema — was never actually written. We now compute the tiles and
    copy them into ``out``, honoring the out-variant contract.

    Args:
        input: Tensor of shape [c, h, w]; see ``tile_crop_impl``.
        tile_size: Side length of each square tile, in pixels.
        out: Destination tensor; resized (if needed) to
            [num_tiles, c, tile_size, tile_size] and filled with the tiles.

    Returns:
        ``out``, for chaining, per the standard out-variant convention.
    """
    c = input.shape[0]
    h = input.shape[1]
    w = input.shape[2]
    tiles_height = h // tile_size
    tiles_width = w // tile_size
    tiles = (
        input.view(c, tiles_height, tile_size, tiles_width, tile_size)
        .permute(1, 3, 0, 2, 4)
        .contiguous()
        .view(tiles_height * tiles_width, c, tile_size, tile_size)
    )
    # Resize is a no-op when the caller already passed a correctly-sized
    # buffer; copy_ performs the actual in-place write into the alias.
    out.resize_(tiles.shape)
    out.copy_(tiles)
    return out


# Register meta kernel to prevent export tracing into the tile_crop impl.
@torch.library.register_fake("preprocess::tile_crop")
def tile_crop(output: torch.Tensor, tile_size: int) -> torch.Tensor:
    """Fake (meta) kernel for ``preprocess::tile_crop``, used during export.

    Returned tensor is of size [n, c, tile_size, tile_size], where n = number
    of tiles (e.g. [n, 3, 224, 224]). We allocate the first dimension as an
    unbacked symint so the output has an upper-bounded dynamic shape;
    otherwise the output would be pinned to the static shape
    [MAX_NUM_TILES, c, tile_size, tile_size].

    NOTE(review): the first parameter is the op's *input* image despite being
    named ``output`` — name kept as-is for signature compatibility.
    """
    ctx = torch._custom_ops.get_ctx()
    s0 = ctx.create_unbacked_symint()
    torch._constrain_as_size(s0, 0, MAX_NUM_TILES)
    # Fix: match the real kernel's dtype/device. The eager impl returns a
    # reshaped copy of the input, so the fake output must not silently
    # default to float32 on the default device (wrong for e.g. uint8 images
    # or non-CPU inputs, and would mislead export/shape propagation).
    return torch.empty(
        [s0, output.size(0), tile_size, tile_size],
        dtype=output.dtype,
        device=output.device,
    )