#!/usr/bin/env python3
# Owner(s): ["oncall: mobile"]
# mypy: allow-untyped-defs

import io

import cv2

import torch
import torch.utils.bundled_inputs
from torch.testing._internal.common_utils import TestCase


# Load the operator library that provides torch.ops.fb.decode_bundled_image
# (used as the inflation expression below).
# NOTE(review): presumably it also registers torch.ops.fb.image_decode_to_NCHW,
# which the test calls directly — confirm.
torch.ops.load_library("//caffe2/torch/fb/operators:decode_bundled_image")

# Fixture image used by the test, read both via cv2 and as raw JPEG bytes.
# NOTE(review): relative path assumes the test runs from the repo root — confirm.
TEST_IMAGE_PATH = "caffe2/test/test_img/p1.jpg"


def model_size(sm):
    """Return the size in bytes of ``sm`` when serialized with ``torch.jit.save``."""
    buffer = io.BytesIO()
    torch.jit.save(sm, buffer)
    return len(buffer.getvalue())


def save_and_load(sm):
    """Round-trip ``sm`` through an in-memory ``torch.jit.save`` / ``torch.jit.load``."""
    buffer = io.BytesIO()
    torch.jit.save(sm, buffer)
    buffer.seek(0)
    return torch.jit.load(buffer)


def bundle_jpeg_image(img_tensor, quality):
    """Return an InflatableArg holding a JPEG-compressed image and how to decode it.

    Args:
        img_tensor: the raw image tensor in HWC or NCHW layout, with unsigned
            integer pixel values. If in NCHW format, N must be 1.
        quality: the JPEG quality (``cv2.IMWRITE_JPEG_QUALITY``) used to
            compress the image.

    Returns:
        A ``torch.utils.bundled_inputs.InflatableArg`` whose value is the
        flattened uint8 JPEG byte stream and whose inflation expression
        decodes it with ``torch.ops.fb.decode_bundled_image``.

    Raises:
        AssertionError: if an NCHW tensor has a batch size other than 1.
        RuntimeError: if OpenCV fails to encode the image.
    """
    if img_tensor.dim() == 4:
        # Turn NCHW (with N == 1) into the HWC layout cv2 expects.
        assert img_tensor.size(0) == 1, (
            f"expected batch size 1 for NCHW input, got {img_tensor.size(0)}"
        )
        img_tensor = img_tensor[0].permute(1, 2, 0)
    pixels = img_tensor.numpy()
    encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), quality]
    ok, enc_img = cv2.imencode(".JPEG", pixels, encode_param)
    # cv2.imencode reports failure through its boolean result, not an exception.
    if not ok:
        raise RuntimeError("cv2.imencode failed to encode the image as JPEG")
    enc_img_tensor = torch.flatten(torch.from_numpy(enc_img)).byte()
    return torch.utils.bundled_inputs.InflatableArg(
        enc_img_tensor, "torch.ops.fb.decode_bundled_image({})"
    )


def get_tensor_from_raw_BGR(im) -> torch.Tensor:
    """Convert a BGR HWC image array into a 1xCxHxW float RGB tensor scaled by 1/255."""
    raw_data = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
    raw_data = torch.from_numpy(raw_data).float()
    raw_data = raw_data.permute(2, 0, 1)
    raw_data = torch.div(raw_data, 255).unsqueeze(0)
    return raw_data


class TestBundledImages(TestCase):
    def test_single_tensors(self):
        class SingleTensorModel(torch.nn.Module):
            def forward(self, arg):
                return arg

        im = cv2.imread(TEST_IMAGE_PATH)
        # cv2.imread returns None (rather than raising) when the file can't be read.
        self.assertIsNotNone(im, f"could not read test image {TEST_IMAGE_PATH}")
        tensor = torch.from_numpy(im)
        inflatable_arg = bundle_jpeg_image(tensor, 90)
        bundled_inputs = [(inflatable_arg,)]
        sm = torch.jit.script(SingleTensorModel())
        torch.utils.bundled_inputs.augment_model_with_bundled_inputs(
            sm, bundled_inputs
        )
        loaded = save_and_load(sm)
        inflated = loaded.get_all_bundled_inputs()

        # Check the structure before indexing into it.
        self.assertEqual(len(inflated), 1)
        self.assertEqual(len(inflated[0]), 1)
        decoded_data = inflated[0][0]

        # Reference: the raw (uncompressed) image as a normalized RGB NCHW tensor.
        raw_data = get_tensor_from_raw_BGR(im)

        self.assertEqual(raw_data.shape, decoded_data.shape)
        # Loose tolerances because the bundled image went through lossy JPEG.
        self.assertEqual(raw_data, decoded_data, atol=0.1, rtol=1e-01)

        # Check if fb::image_decode_to_NCHW works as expected
        with open(TEST_IMAGE_PATH, "rb") as fp:
            jpeg_bytes = fp.read()
        # NOTE(review): weight/bias presumably apply a per-channel affine map;
        # diag(1/255) with zero bias matches the scaling in raw_data — confirm.
        weight = torch.full((3,), 1.0 / 255.0).diag()
        bias = torch.zeros(3)
        byte_tensor = torch.tensor(list(jpeg_bytes), dtype=torch.uint8)
        im2_tensor = torch.ops.fb.image_decode_to_NCHW(byte_tensor, weight, bias)
        self.assertEqual(raw_data.shape, im2_tensor.shape)
        self.assertEqual(raw_data, im2_tensor, atol=0.1, rtol=1e-01)