// xref: /aosp_15_r20/external/pytorch/android/pytorch_android_torchvision/src/main/cpp/pytorch_vision_jni.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <cassert>
2 #include <cmath>
3 #include <vector>
4 
5 #include "jni.h"
6 
7 namespace pytorch_vision_jni {
8 
imageYUV420CenterCropToFloatBuffer(JNIEnv * jniEnv,jclass,jobject yBuffer,jint yRowStride,jint yPixelStride,jobject uBuffer,jobject vBuffer,jint uRowStride,jint uvPixelStride,jint imageWidth,jint imageHeight,jint rotateCWDegrees,jint tensorWidth,jint tensorHeight,jfloatArray jnormMeanRGB,jfloatArray jnormStdRGB,jobject outBuffer,jint outOffset,jint memoryFormatCode)9 static void imageYUV420CenterCropToFloatBuffer(
10     JNIEnv* jniEnv,
11     jclass,
12     jobject yBuffer,
13     jint yRowStride,
14     jint yPixelStride,
15     jobject uBuffer,
16     jobject vBuffer,
17     jint uRowStride,
18     jint uvPixelStride,
19     jint imageWidth,
20     jint imageHeight,
21     jint rotateCWDegrees,
22     jint tensorWidth,
23     jint tensorHeight,
24     jfloatArray jnormMeanRGB,
25     jfloatArray jnormStdRGB,
26     jobject outBuffer,
27     jint outOffset,
28     jint memoryFormatCode) {
29   constexpr static int32_t kMemoryFormatContiguous = 1;
30   constexpr static int32_t kMemoryFormatChannelsLast = 2;
31 
32   float* outData = (float*)jniEnv->GetDirectBufferAddress(outBuffer);
33 
34   jfloat normMeanRGB[3];
35   jfloat normStdRGB[3];
36   jniEnv->GetFloatArrayRegion(jnormMeanRGB, 0, 3, normMeanRGB);
37   jniEnv->GetFloatArrayRegion(jnormStdRGB, 0, 3, normStdRGB);
38   int widthAfterRtn = imageWidth;
39   int heightAfterRtn = imageHeight;
40   bool oddRotation = rotateCWDegrees == 90 || rotateCWDegrees == 270;
41   if (oddRotation) {
42     widthAfterRtn = imageHeight;
43     heightAfterRtn = imageWidth;
44   }
45 
46   int cropWidthAfterRtn = widthAfterRtn;
47   int cropHeightAfterRtn = heightAfterRtn;
48 
49   if (tensorWidth * heightAfterRtn <= tensorHeight * widthAfterRtn) {
50     cropWidthAfterRtn = tensorWidth * heightAfterRtn / tensorHeight;
51   } else {
52     cropHeightAfterRtn = tensorHeight * widthAfterRtn / tensorWidth;
53   }
54 
55   int cropWidthBeforeRtn = cropWidthAfterRtn;
56   int cropHeightBeforeRtn = cropHeightAfterRtn;
57   if (oddRotation) {
58     cropWidthBeforeRtn = cropHeightAfterRtn;
59     cropHeightBeforeRtn = cropWidthAfterRtn;
60   }
61 
62   const int offsetX = (imageWidth - cropWidthBeforeRtn) / 2.f;
63   const int offsetY = (imageHeight - cropHeightBeforeRtn) / 2.f;
64 
65   const uint8_t* yData = (uint8_t*)jniEnv->GetDirectBufferAddress(yBuffer);
66   const uint8_t* uData = (uint8_t*)jniEnv->GetDirectBufferAddress(uBuffer);
67   const uint8_t* vData = (uint8_t*)jniEnv->GetDirectBufferAddress(vBuffer);
68 
69   float scale = cropWidthAfterRtn / tensorWidth;
70   int uvRowStride = uRowStride;
71   int cropXMult = 1;
72   int cropYMult = 1;
73   int cropXAdd = offsetX;
74   int cropYAdd = offsetY;
75   if (rotateCWDegrees == 90) {
76     cropYMult = -1;
77     cropYAdd = offsetY + (cropHeightBeforeRtn - 1);
78   } else if (rotateCWDegrees == 180) {
79     cropXMult = -1;
80     cropXAdd = offsetX + (cropWidthBeforeRtn - 1);
81     cropYMult = -1;
82     cropYAdd = offsetY + (cropHeightBeforeRtn - 1);
83   } else if (rotateCWDegrees == 270) {
84     cropXMult = -1;
85     cropXAdd = offsetX + (cropWidthBeforeRtn - 1);
86   }
87 
88   float normMeanRm255 = 255 * normMeanRGB[0];
89   float normMeanGm255 = 255 * normMeanRGB[1];
90   float normMeanBm255 = 255 * normMeanRGB[2];
91   float normStdRm255 = 255 * normStdRGB[0];
92   float normStdGm255 = 255 * normStdRGB[1];
93   float normStdBm255 = 255 * normStdRGB[2];
94 
95   int xBeforeRtn, yBeforeRtn;
96   int yi, yIdx, uvIdx, ui, vi, a0, ri, gi, bi;
97   int channelSize = tensorWidth * tensorHeight;
98   // A bit of code duplication to avoid branching in the cycles
99   if (memoryFormatCode == kMemoryFormatContiguous) {
100     int wr = outOffset;
101     int wg = wr + channelSize;
102     int wb = wg + channelSize;
103     for (int y = 0; y < tensorHeight; y++) {
104       for (int x = 0; x < tensorWidth; x++) {
105         xBeforeRtn = cropXAdd + cropXMult * (int)(x * scale);
106         yBeforeRtn = cropYAdd + cropYMult * (int)(y * scale);
107         yIdx = yBeforeRtn * yRowStride + xBeforeRtn * yPixelStride;
108         uvIdx =
109             (yBeforeRtn >> 1) * uvRowStride + (xBeforeRtn >> 1) * uvPixelStride;
110         ui = uData[uvIdx];
111         vi = vData[uvIdx];
112         yi = yData[yIdx];
113         yi = (yi - 16) < 0 ? 0 : (yi - 16);
114         ui -= 128;
115         vi -= 128;
116         a0 = 1192 * yi;
117         ri = (a0 + 1634 * vi) >> 10;
118         gi = (a0 - 833 * vi - 400 * ui) >> 10;
119         bi = (a0 + 2066 * ui) >> 10;
120         ri = ri > 255 ? 255 : ri < 0 ? 0 : ri;
121         gi = gi > 255 ? 255 : gi < 0 ? 0 : gi;
122         bi = bi > 255 ? 255 : bi < 0 ? 0 : bi;
123         outData[wr++] = (ri - normMeanRm255) / normStdRm255;
124         outData[wg++] = (gi - normMeanGm255) / normStdGm255;
125         outData[wb++] = (bi - normMeanBm255) / normStdBm255;
126       }
127     }
128   } else if (memoryFormatCode == kMemoryFormatChannelsLast) {
129     int wc = outOffset;
130     for (int y = 0; y < tensorHeight; y++) {
131       for (int x = 0; x < tensorWidth; x++) {
132         xBeforeRtn = cropXAdd + cropXMult * (int)(x * scale);
133         yBeforeRtn = cropYAdd + cropYMult * (int)(y * scale);
134         yIdx = yBeforeRtn * yRowStride + xBeforeRtn * yPixelStride;
135         uvIdx =
136             (yBeforeRtn >> 1) * uvRowStride + (xBeforeRtn >> 1) * uvPixelStride;
137         ui = uData[uvIdx];
138         vi = vData[uvIdx];
139         yi = yData[yIdx];
140         yi = (yi - 16) < 0 ? 0 : (yi - 16);
141         ui -= 128;
142         vi -= 128;
143         a0 = 1192 * yi;
144         ri = (a0 + 1634 * vi) >> 10;
145         gi = (a0 - 833 * vi - 400 * ui) >> 10;
146         bi = (a0 + 2066 * ui) >> 10;
147         ri = ri > 255 ? 255 : ri < 0 ? 0 : ri;
148         gi = gi > 255 ? 255 : gi < 0 ? 0 : gi;
149         bi = bi > 255 ? 255 : bi < 0 ? 0 : bi;
150         outData[wc++] = (ri - normMeanRm255) / normStdRm255;
151         outData[wc++] = (gi - normMeanGm255) / normStdGm255;
152         outData[wc++] = (bi - normMeanBm255) / normStdBm255;
153       }
154     }
155   } else {
156     jclass Exception = jniEnv->FindClass("java/lang/IllegalArgumentException");
157     jniEnv->ThrowNew(Exception, "Illegal memory format code");
158   }
159 }
160 } // namespace pytorch_vision_jni
161 
JNI_OnLoad(JavaVM * vm,void *)162 JNIEXPORT jint JNI_OnLoad(JavaVM* vm, void*) {
163   JNIEnv* env;
164   if (vm->GetEnv(reinterpret_cast<void**>(&env), JNI_VERSION_1_6) != JNI_OK) {
165     return JNI_ERR;
166   }
167 
168   jclass c =
169       env->FindClass("org/pytorch/torchvision/TensorImageUtils$NativePeer");
170   if (c == nullptr) {
171     return JNI_ERR;
172   }
173 
174   static const JNINativeMethod methods[] = {
175       {"imageYUV420CenterCropToFloatBuffer",
176        "(Ljava/nio/ByteBuffer;IILjava/nio/ByteBuffer;Ljava/nio/ByteBuffer;IIIIIII[F[FLjava/nio/Buffer;II)V",
177        (void*)pytorch_vision_jni::imageYUV420CenterCropToFloatBuffer},
178   };
179   int rc = env->RegisterNatives(
180       c, methods, sizeof(methods) / sizeof(JNINativeMethod));
181 
182   if (rc != JNI_OK) {
183     return rc;
184   }
185 
186   return JNI_VERSION_1_6;
187 }
188