xref: /aosp_15_r20/cts/apps/CameraITS/utils/sensor_fusion_utils.py (revision b7c941bb3fa97aba169d73cee0bed2de8ac964bf)
1# Copyright 2020 The Android Open Source Project
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#      http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14"""Utility functions for sensor_fusion hardware rig."""
15
16
17import bisect
18import codecs
19import logging
20import math
21import os
22import struct
23import time
24
25import cv2
26import matplotlib
27from matplotlib import pyplot as plt
28import numpy as np
29import scipy.spatial
30import serial
31from serial.tools import list_ports
32
33import camera_properties_utils
34import image_processing_utils
35
matplotlib.use('agg')  # Must be executed before any figure is created.

# Constants for Rotation Rig
ARDUINO_ANGLE_MAX = 180.0  # degrees
ARDUINO_ANGLES_SENSOR_FUSION = (0, 90)  # degrees
ARDUINO_ANGLES_STABILIZATION = (10, 25)  # degrees
ARDUINO_BAUDRATE = 9600  # bits/second
ARDUINO_CMD_LENGTH = 3  # bytes per command frame
ARDUINO_CMD_TIME = 2.0 * ARDUINO_CMD_LENGTH / ARDUINO_BAUDRATE  # round trip
ARDUINO_MOVE_TIME_SENSOR_FUSION = 2  # seconds
ARDUINO_MOVE_TIME_STABILIZATION = 0.3  # seconds
ARDUINO_PID = 0x0043  # USB product ID
ARDUINO_SERVO_SPEED_MAX = 255
ARDUINO_SERVO_SPEED_MIN = 1
ARDUINO_SERVO_SPEED_SENSOR_FUSION = 20
ARDUINO_SERVO_SPEED_STABILIZATION = 10
ARDUINO_SERVO_SPEED_STABILIZATION_TABLET = 20
ARDUINO_SPEED_START_BYTE = 253  # first byte of a speed command frame
ARDUINO_START_BYTE = 255  # first byte of an angle command frame
ARDUINO_START_NUM_TRYS = 5
ARDUINO_START_TIMEOUT = 300  # seconds
ARDUINO_STRING = 'Arduino'
ARDUINO_TEST_CMD = (b'\x01', b'\x02', b'\x03')  # loopback bytes for comm check
ARDUINO_VALID_CH = ('1', '2', '3', '4', '5', '6')
ARDUINO_VIDS = (0x2341, 0x2a03)  # USB vendor IDs

CANAKIT_BAUDRATE = 115200  # bits/second
CANAKIT_CMD_TIME = 0.05  # seconds (found experimentally)
CANAKIT_DATA_DELIMITER = '\r\n'
CANAKIT_PID = 0xfc73  # USB product ID
CANAKIT_SEND_TIMEOUT = 0.02  # seconds
CANAKIT_SET_CMD = 'REL'
CANAKIT_SLEEP_TIME = 2  # seconds (for full 90 degree rotation)
CANAKIT_VALID_CMD = ('ON', 'OFF')
CANAKIT_VALID_CH = ('1', '2', '3', '4')
CANAKIT_VID = 0x04d8  # USB vendor ID

HS755HB_ANGLE_MAX = 202.0  # throw for rotation motor in degrees

# From test_sensor_fusion
_FEATURE_MARGIN = 0.20  # Only take feature points from center 20% so that
                        # rotation measured has less rolling shutter effect.
_FEATURE_PTS_MIN = 30  # Min number of feature pts to perform rotation analysis.
# cv2.goodFeatures to track.
# 'POSTMASK' is the measurement method in all previous versions of Android.
# 'POSTMASK' finds best features on entire frame and then masks the features
# to the vertical center FEATURE_MARGIN for the measurement.
# 'PREMASK' is a new measurement that is used when FEATURE_PTS_MIN is not
# found in frame. This finds the best 2*FEATURE_PTS_MIN in the FEATURE_MARGIN
# part of the frame.
_CV2_FEATURE_PARAMS_POSTMASK = dict(maxCorners=240,
                                    qualityLevel=0.3,
                                    minDistance=7,
                                    blockSize=7)
_CV2_FEATURE_PARAMS_PREMASK = dict(maxCorners=2*_FEATURE_PTS_MIN,
                                   qualityLevel=0.3,
                                   minDistance=7,
                                   blockSize=7)
_GYRO_SAMP_RATE_MIN = 100.0  # Samples/second: min gyro sample rate.
_CV2_LK_PARAMS = dict(winSize=(15, 15),
                      maxLevel=2,
                      criteria=(cv2.TERM_CRITERIA_EPS | cv2.TERM_CRITERIA_COUNT,
                                10, 0.03))  # cv2.calcOpticalFlowPyrLK params.
_ROTATION_PER_FRAME_MIN = 0.001  # rads/s
_GYRO_ROTATION_PER_SEC_MAX = 2.0  # rads/s
_R_SQUARED_TOLERANCE = 0.01  # tolerance for polynomial fitting r^2
_SHIFT_DOMAIN_RADIUS = 5  # limited domain centered around best shift

# unittest constants
_COARSE_FIT_RANGE = 20  # Range area around coarse fit to do optimization.
_CORR_TIME_OFFSET_MAX = 50  # ms max shift to try and match camera/gyro times.
_CORR_TIME_OFFSET_STEP = 0.5  # ms step for shifts.

# Unit translators
_MSEC_TO_NSEC = 1000000
_NSEC_TO_SEC = 1E-9
_SEC_TO_NSEC = int(1/_NSEC_TO_SEC)
_RADS_TO_DEGS = 180/math.pi

_NUM_GYRO_PTS_TO_AVG = 20  # samples averaged per point in gyro plots
116
117
def polynomial_from_coefficients(coefficients):
  """Build a polynomial function from a coefficient list, highest power first.

  Args:
    coefficients: list of coefficients (float)
  Returns:
    Function in the form of a*x^n + b*x^(n - 1) + ... + constant
  """
  def polynomial(x):
    degree = len(coefficients) - 1
    # Terms are accumulated highest power first, matching coefficient order.
    return sum(coefficient * x ** (degree - power)
               for power, coefficient in enumerate(coefficients))
  return polynomial
130
131
def smallest_absolute_minimum_of_polynomial(coefficients):
  """Return the smallest minimum by absolute value from a coefficient list.

  Args:
    coefficients: list of coefficients, highest power first (float)
  Returns:
    Smallest local minimum (by absolute value) on the function (float)
  Raises:
    AssertionError: if the polynomial has no local minimum.
  """
  first_derivative = np.polyder(coefficients, m=1)
  second_derivative = np.polyder(coefficients, m=2)
  best = None
  for candidate in np.roots(first_derivative):
    # A critical point with positive second derivative is a local minimum.
    if np.polyval(second_derivative, candidate) <= 0:
      continue
    if best is None or abs(candidate) < abs(best):
      best = candidate
  if best is None:
    raise AssertionError(
        f'No minima were found on function described by {coefficients}.')
  return best
153
154
def serial_port_def(name):
  """Determine the serial port and open.

  Args:
    name: string of device to locate (ie. 'Arduino', 'Canakit' or 'Default')
  Returns:
    serial port object
  Raises:
    ValueError: if no matching device is connected.
  """
  target = name.lower()
  for device in list_ports.comports():
    # Not all comm ports report a vid/pid; skip those that don't.
    if not (device.vid and device.pid):
      continue
    if target == 'arduino':
      if device.vid in ARDUINO_VIDS and device.pid == ARDUINO_PID:
        logging.debug('Arduino: %s', str(device))
        return serial.Serial(device.device, ARDUINO_BAUDRATE, timeout=1)
    elif target in ('canakit', 'default'):
      if device.vid == CANAKIT_VID and device.pid == CANAKIT_PID:
        logging.debug('Canakit: %s', str(device))
        return serial.Serial(device.device, CANAKIT_BAUDRATE,
                             timeout=CANAKIT_SEND_TIMEOUT,
                             parity=serial.PARITY_EVEN,
                             stopbits=serial.STOPBITS_ONE,
                             bytesize=serial.EIGHTBITS)
  raise ValueError(f'{name} device not connected.')
184
185
def canakit_cmd_send(canakit_serial_port, cmd_str):
  """Wrapper for sending serial command to Canakit.

  Args:
    canakit_serial_port: port to write for canakit
    cmd_str: str; value to send to device.
  Raises:
    IOError: if the port write fails.
  """
  try:
    logging.debug('writing port...')
    canakit_serial_port.write(CANAKIT_DATA_DELIMITER.encode())
    # The pause between delimiter and command is critical for the relay.
    time.sleep(CANAKIT_CMD_TIME)
    canakit_serial_port.write(cmd_str.encode())
  except IOError as write_error:
    raise IOError(
        f'Port {CANAKIT_VID}:{CANAKIT_PID} is not open!') from write_error
202
203
def canakit_set_relay_channel_state(canakit_port, ch, state):
  """Set Canakit relay channel and state.

  Waits CANAKIT_SLEEP_TIME for rotation to occur.

  Args:
    canakit_port: serial port object for the Canakit port.
    ch: string for channel number of relay to set. '1', '2', '3', or '4'
    state: string of either 'ON' or 'OFF'
  """
  logging.debug('Setting relay state %s', state)
  # Guard clause: refuse anything outside the relay's valid channel/state set.
  if ch not in CANAKIT_VALID_CH or state not in CANAKIT_VALID_CMD:
    logging.debug('Invalid ch (%s) or state (%s), no command sent.', ch, state)
    return
  canakit_cmd_send(canakit_port, CANAKIT_SET_CMD + ch + '.' + state + '\r\n')
  time.sleep(CANAKIT_SLEEP_TIME)
220
221
def arduino_read_cmd(port):
  """Read back Arduino command from serial port.

  Args:
    port: serial port object to read from.
  Returns:
    list of ARDUINO_CMD_LENGTH byte strings, one per port.read() call.
  """
  return [port.read() for _ in range(ARDUINO_CMD_LENGTH)]
228
229
def arduino_send_cmd(port, cmd):
  """Write an ARDUINO_CMD_LENGTH-byte command to the serial port.

  Args:
    port: serial port object to write to.
    cmd: sequence of single-byte strings; first ARDUINO_CMD_LENGTH are sent.
  """
  for byte_index in range(ARDUINO_CMD_LENGTH):
    port.write(cmd[byte_index])
234
235
def arduino_loopback_cmd(port, cmd):
  """Send a command and read back the device's echo of it.

  Args:
    port: serial port object.
    cmd: sequence of single-byte strings to send.
  Returns:
    list of bytes read back from the port.
  """
  arduino_send_cmd(port, cmd)
  time.sleep(ARDUINO_CMD_TIME)  # allow the full round trip before reading
  return arduino_read_cmd(port)
241
242
def establish_serial_comm(port):
  """Establish connection with serial port via a loopback handshake.

  Repeatedly sends ARDUINO_TEST_CMD and checks the echo until it matches
  or ARDUINO_START_TIMEOUT elapses.

  Args:
    port: serial port object.
  Raises:
    AssertionError: if communication is not established before the timeout.
  """
  logging.debug('Establishing communication with %s', port.name)
  trys = 1
  hex_test = convert_to_hex(ARDUINO_TEST_CMD)
  logging.debug(' test tx: %s %s %s', hex_test[0], hex_test[1], hex_test[2])
  deadline = time.time() + ARDUINO_START_TIMEOUT
  while time.time() < deadline:
    try:
      cmd_read = arduino_loopback_cmd(port, ARDUINO_TEST_CMD)
    except serial.serialutil.SerialException:
      logging.debug('Port in use, trying again...')
      continue
    hex_read = convert_to_hex(cmd_read)
    logging.debug(' test rx: %s %s %s', hex_read[0], hex_read[1], hex_read[2])
    if cmd_read == list(ARDUINO_TEST_CMD):
      logging.debug(' Arduino comm established after %d try(s)', trys)
      break
    trys += 1
  else:  # while loop timed out without a matching echo
    raise AssertionError(f'Arduino comm not established after {trys} tries '
                         f'and {ARDUINO_START_TIMEOUT} seconds')
266
267
def convert_to_hex(cmd):
  """Convert a list of byte strings to their hex string representations.

  Args:
    cmd: sequence of byte strings (empty entries become '--').
  Returns:
    list of hex strings, one per entry.
  """
  hex_strings = []
  for byte_str in cmd:
    if byte_str:
      hex_strings.append('%0.2x' % int(codecs.encode(byte_str, 'hex_codec'), 16))
    else:
      hex_strings.append('--')  # placeholder for an empty read
  return hex_strings
271
272
def arduino_rotate_servo_to_angle(ch, angle, serial_port, move_time):
  """Rotate servo to the specified angle.

  Out-of-range angles are clamped to [0, ARDUINO_ANGLE_MAX]. (The previous
  implementation reset any out-of-range angle to 0 first, so the
  clamp-to-max branch was unreachable and angles above the max were sent
  as 0.)

  Args:
    ch: str; servo to rotate in ARDUINO_VALID_CH
    angle: int; servo angle to move to
    serial_port: object; serial port
    move_time: int; time in seconds
  """
  if angle < 0:
    logging.debug('Angle must be between 0 and %d.', ARDUINO_ANGLE_MAX)
    angle = 0
  elif angle > ARDUINO_ANGLE_MAX:
    logging.debug('Angle must be between 0 and %d.', ARDUINO_ANGLE_MAX)
    # int() so the packed byte below is valid (ARDUINO_ANGLE_MAX is a float).
    angle = int(ARDUINO_ANGLE_MAX)

  # Command frame: start byte, channel, angle -- one byte each.
  cmd = [struct.pack('B', i) for i in [ARDUINO_START_BYTE, int(ch), angle]]
  arduino_send_cmd(serial_port, cmd)
  time.sleep(move_time)  # wait for the servo movement to complete
291
292
def arduino_rotate_servo(ch, angles, move_time, serial_port):
  """Rotate servo through 'angles'.

  Args:
    ch: str; servo to rotate
    angles: list of ints; servo angles to move to
    move_time: int; time required to allow for arduino movement
    serial_port: object; serial port
  """
  for angle in angles:
    # Scale from the motor's physical throw to the Arduino command range.
    scaled_angle = int(round(angle * ARDUINO_ANGLE_MAX / HS755HB_ANGLE_MAX, 0))
    arduino_rotate_servo_to_angle(ch, scaled_angle, serial_port, move_time)
306
307
def rotation_rig(rotate_cntl, rotate_ch, num_rotations, angles, servo_speed,
                 move_time, arduino_serial_port):
  """Rotate the phone n times using rotate_cntl and rotate_ch defined.

  rotate_ch is hard wired and must be determined from physical setup.
  If using Arduino, serial port must be initialized and communication must be
  established before rotation.

  Args:
    rotate_cntl: str to identify 'arduino', 'canakit' or 'external' controller.
    rotate_ch: str to identify rotation channel number.
    num_rotations: int number of rotations.
    angles: list of ints; servo angle to move to.
    servo_speed: int number of move speed between [1, 255].
    move_time: int time required to allow for arduino movement.
    arduino_serial_port: optional initialized serial port object
  """

  logging.debug('Controller: %s, ch: %s', rotate_cntl, rotate_ch)
  if arduino_serial_port:
    # initialize servo at origin
    logging.debug('Moving servo to origin')
    arduino_rotate_servo_to_angle(rotate_ch, 0, arduino_serial_port, 1)

    # set servo speed
    set_servo_speed(rotate_ch, servo_speed, arduino_serial_port, delay=0)
  elif rotate_cntl.lower() == 'canakit':
    # NOTE(review): canakit_serial_port is bound only on this branch. If
    # rotate_cntl == 'canakit' while arduino_serial_port is truthy, the
    # rotation loop below would hit an unbound name -- presumably callers
    # never pass an Arduino port with a Canakit controller; confirm.
    canakit_serial_port = serial_port_def('Canakit')
  elif rotate_cntl.lower() == 'external':
    logging.info('External rotation control.')
  else:
    logging.info('No rotation rig defined. Manual test: rotate phone by hand.')

  # rotate phone
  logging.debug('Rotating phone %dx', num_rotations)
  for _ in range(num_rotations):
    if rotate_cntl == 'arduino':
      arduino_rotate_servo(rotate_ch, angles, move_time, arduino_serial_port)
    elif rotate_cntl == 'canakit':
      # Pulse the relay ON then OFF; each call sleeps CANAKIT_SLEEP_TIME.
      canakit_set_relay_channel_state(canakit_serial_port, rotate_ch, 'ON')
      canakit_set_relay_channel_state(canakit_serial_port, rotate_ch, 'OFF')
  logging.debug('Finished rotations')
  if rotate_cntl == 'arduino':
    # Return the servo to origin so the rig is ready for the next run.
    logging.debug('Moving servo to origin')
    arduino_rotate_servo_to_angle(rotate_ch, 0, arduino_serial_port, 1)
353
354
def set_servo_speed(ch, servo_speed, serial_port, delay=0):
  """Set servo to specified speed.

  Speeds outside [ARDUINO_SERVO_SPEED_MIN, ARDUINO_SERVO_SPEED_MAX] are
  clamped to the nearest bound.

  Args:
    ch: str; servo to turn on in ARDUINO_VALID_CH
    servo_speed: int; value of speed between 1 and 255
    serial_port: object; serial port
    delay: int; time in seconds
  """
  logging.debug('Servo speed: %d', servo_speed)
  if servo_speed < ARDUINO_SERVO_SPEED_MIN:
    logging.debug('Servo speed must be >= %d.', ARDUINO_SERVO_SPEED_MIN)
    servo_speed = ARDUINO_SERVO_SPEED_MIN
  elif servo_speed > ARDUINO_SERVO_SPEED_MAX:
    logging.debug('Servo speed must be <= %d.', ARDUINO_SERVO_SPEED_MAX)
    servo_speed = ARDUINO_SERVO_SPEED_MAX

  # Speed command frame: speed start byte, channel, speed -- one byte each.
  speed_cmd = [struct.pack('B', byte_val)
               for byte_val in (ARDUINO_SPEED_START_BYTE, int(ch), servo_speed)]
  arduino_send_cmd(serial_port, speed_cmd)
  time.sleep(delay)
376
377
def calc_max_rotation_angle(rotations, sensor_type):
  """Calculates the max angle of deflection from rotations.

  Does not modify the caller's array. (The previous implementation scaled
  'rotations' in place with '*=', mutating the input as a side effect.)

  Args:
    rotations: numpy array of rotation per event (rads)
    sensor_type: string 'Camera' or 'Gyro'

  Returns:
    maximum angle of rotation for the given rotations (degrees)
  """
  rotations_degs = rotations * _RADS_TO_DEGS  # new array; input untouched
  rotations_sum = np.cumsum(rotations_degs)
  rotation_max = max(rotations_sum)
  rotation_min = min(rotations_sum)
  logging.debug('%s min: %.2f, max %.2f rotation (degrees)',
                sensor_type, rotation_min, rotation_max)
  logging.debug('%s max rotation: %.2f degrees',
                sensor_type, (rotation_max-rotation_min))
  return rotation_max-rotation_min
397
398
def get_gyro_rotations(gyro_events, cam_times):
  """Get the rotation values of the gyro.

  Integrates the gyro data between each camera frame to compute an angular
  displacement.

  Args:
    gyro_events: List of gyro event objects (dicts with 'time' in ns and 'z'
      angular rate in rads/s).
    cam_times: Array of N camera times, one for each frame (ns).

  Returns:
    Array of N-1 gyro rotation measurements (rads/s).

  Raises:
    AssertionError: if the gyro samples do not bracket the camera times.
  """
  gyro_times = np.array([e['time'] for e in gyro_events])
  all_gyro_rots = np.array([e['z'] for e in gyro_events])
  gyro_rots = []
  # The gyro trace must fully bracket the camera timestamps so that every
  # camera interval can be integrated.
  if gyro_times[0] > cam_times[0] or gyro_times[-1] < cam_times[-1]:
    raise AssertionError('Gyro times do not bound camera times! '
                         f'gyro: {gyro_times[0]:.0f} -> {gyro_times[-1]:.0f} '
                         f'cam: {cam_times[0]} -> {cam_times[-1]} (ns).')

  # Integrate the gyro data between each pair of camera frame times.
  for i_cam in range(len(cam_times)-1):
    # Get the window of gyro samples within the current pair of frames.
    # Note: bisect always picks first gyro index after camera time.
    t_cam0 = cam_times[i_cam]
    t_cam1 = cam_times[i_cam+1]
    i_gyro_window0 = bisect.bisect(gyro_times, t_cam0)
    i_gyro_window1 = bisect.bisect(gyro_times, t_cam1)
    gyro_sum = 0

    # Integrate samples within the window.
    # Each gyro interval [t_gyro0, t_gyro1] is weighted by the rate sample
    # at the interval's right edge (all_gyro_rots[i_gyro+1]).
    for i_gyro in range(i_gyro_window0, i_gyro_window1):
      gyro_val = all_gyro_rots[i_gyro+1]
      t_gyro0 = gyro_times[i_gyro]
      t_gyro1 = gyro_times[i_gyro+1]
      t_gyro_delta = (t_gyro1 - t_gyro0) * _NSEC_TO_SEC
      gyro_sum += gyro_val * t_gyro_delta

    # Handle the fractional intervals at the sides of the window.
    # side == 0 is the partial gyro interval containing t_cam0; side == 1 is
    # the partial interval containing t_cam1.
    # NOTE(review): if t_cam1 falls within the last gyro interval,
    # i_gyro_window1 can be len(gyro_times)-1 and all_gyro_rots[i_gyro+1]
    # would index out of bounds -- presumably gyro capture always extends
    # well past the last frame; confirm with callers.
    for side, i_gyro in enumerate([i_gyro_window0-1, i_gyro_window1]):
      gyro_val = all_gyro_rots[i_gyro+1]
      t_gyro0 = gyro_times[i_gyro]
      t_gyro1 = gyro_times[i_gyro+1]
      t_gyro_delta = (t_gyro1 - t_gyro0) * _NSEC_TO_SEC
      if side == 0:
        # Keep only the fraction of the interval after t_cam0.
        f = (t_cam0 - t_gyro0) / (t_gyro1 - t_gyro0)
        frac_correction = gyro_val * t_gyro_delta * (1.0 - f)
        gyro_sum += frac_correction
      else:
        # Keep only the fraction of the interval before t_cam1.
        f = (t_cam1 - t_gyro0) / (t_gyro1 - t_gyro0)
        frac_correction = gyro_val * t_gyro_delta * f
        gyro_sum += frac_correction
    gyro_rots.append(gyro_sum)
  gyro_rots = np.array(gyro_rots)
  return gyro_rots
455
456
def procrustes_rotation(x, y):
  """Performs a Procrustes analysis to conform points in x to y.

  Procrustes analysis determines a linear transformation (translation,
  reflection, orthogonal rotation and scaling) of the points in y to best
  conform them to the points in matrix x, using the sum of squared errors
  as the metric for fit criterion.

  Args:
    x: Target coordinate matrix
    y: Input coordinate matrix

  Returns:
    The rotation component of the transformation that maps x to y.
  """
  # Center each point set and normalize by its Frobenius norm.
  x_centered = x - x.mean(0)
  y_centered = y - y.mean(0)
  x_normed = x_centered / np.sqrt((x_centered**2.0).sum())
  y_normed = y_centered / np.sqrt((y_centered**2.0).sum())
  # The SVD of the cross-covariance matrix yields the optimal rotation.
  u, _, vt = np.linalg.svd(np.dot(x_normed.T, y_normed), full_matrices=False)
  return np.dot(vt.T, u.T)
476
477
def get_cam_rotations(frames, facing, h, file_name_stem,
                      start_frame, stabilized_video=False):
  """Get the rotations of the camera between each pair of frames.

  Takes N frames and returns N-1 angular displacements corresponding to the
  rotations between adjacent pairs of frames, in radians.
  Only takes feature points from center so that rotation measured has less
  rolling shutter effect.
  Requires FEATURE_PTS_MIN to have enough data points for accurate measurements.
  Uses FEATURE_PARAMS for cv2 to identify features in checkerboard images.
  Ensures camera rotates enough if not calling with stabilized video.

  Args:
    frames: List of N images (as RGB numpy arrays).
    facing: Direction camera is facing.
    h: Pixel height of each frame.
    file_name_stem: file name stem including location for data.
    start_frame: int; index to start at
    stabilized_video: Boolean; if called with stabilized video

  Returns:
    numpy array of N-1 camera rotation measurements (rad).

  Raises:
    AssertionError: if neither masking method finds enough features in all
      frames, if the lens facing is unknown, or if the device did not move
      enough (non-stabilized case).
  """
  # Convert frames to grayscale uint8 for the cv2 feature APIs.
  gframes = []
  for frame in frames:
    frame = (frame * 255.0).astype(np.uint8)  # cv2 uses [0, 255]
    gframes.append(cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY))
  num_frames = len(gframes)
  logging.debug('num_frames: %d', num_frames)
  # create mask bounding the vertical center _FEATURE_MARGIN of the frame
  ymin = int(h * (1 - _FEATURE_MARGIN) / 2)
  ymax = int(h * (1 + _FEATURE_MARGIN) / 2)
  pre_mask = np.zeros_like(gframes[0])
  pre_mask[ymin:ymax, :] = 255

  # Try the legacy post-mask method first; fall back to pre-mask if any frame
  # comes up short on features.
  # NOTE(review): the code below reads loop variable 'i' after the loops; if
  # num_frames < 2 the inner loop never runs and 'i' is unbound -- presumably
  # callers always supply >= 2 frames; confirm.
  for masking in ['post', 'pre']:  # Do post-masking (original) method 1st
    logging.debug('Using %s masking method', masking)
    rotations = []
    for i in range(1, num_frames):
      j = i - 1
      gframe0 = gframes[j]
      gframe1 = gframes[i]
      if masking == 'post':
        # Find features on the whole frame, then keep only the center band.
        p0 = cv2.goodFeaturesToTrack(
            gframe0, mask=None, **_CV2_FEATURE_PARAMS_POSTMASK)
        post_mask = (p0[:, 0, 1] >= ymin) & (p0[:, 0, 1] <= ymax)
        p0_filtered = p0[post_mask]
      else:
        # Search only the center band directly.
        p0_filtered = cv2.goodFeaturesToTrack(
            gframe0, mask=pre_mask, **_CV2_FEATURE_PARAMS_PREMASK)
      num_features = len(p0_filtered)
      if num_features < _FEATURE_PTS_MIN:
        # Save a debug image showing the (insufficient) features found.
        for pt in np.rint(p0_filtered).astype(int):
          x, y = pt[0][0], pt[0][1]
          cv2.circle(frames[j], (x, y), 3, (100, 255, 255), -1)
        image_processing_utils.write_image(
            frames[j], f'{file_name_stem}_features{j+start_frame:03d}.png')
        msg = (f'Not enough features in frame {j+start_frame}. Need at least '
               f'{_FEATURE_PTS_MIN} features, got {num_features}.')
        if masking == 'pre':
          # Pre-mask was the last resort; fail the measurement.
          raise AssertionError(msg)
        else:
          # Post-mask failed; retry the whole pass with pre-masking.
          logging.debug(msg)
          break
      else:
        logging.debug('Number of features in frame %s is %d',
                      str(j+start_frame).zfill(3), num_features)
      # Track features into the next frame and extract the rotation of the
      # best-fit transform. st == 1 marks successfully tracked points.
      p1, st, _ = cv2.calcOpticalFlowPyrLK(gframe0, gframe1, p0_filtered, None,
                                           **_CV2_LK_PARAMS)
      tform = procrustes_rotation(p0_filtered[st == 1], p1[st == 1])
      # Sign of the rotation depends on which way the camera faces.
      if facing == camera_properties_utils.LENS_FACING['BACK']:
        rotation = -math.atan2(tform[0, 1], tform[0, 0])
      elif facing == camera_properties_utils.LENS_FACING['FRONT']:
        rotation = math.atan2(tform[0, 1], tform[0, 0])
      else:
        raise AssertionError(f'Unknown lens facing: {facing}.')
      rotations.append(rotation)
      if i == 1:
        # Save debug visualization of features that are being
        # tracked in the first frame.
        frame = frames[j]
        for x, y in np.rint(p0_filtered[st == 1]).astype(int):
          cv2.circle(frame, (x, y), 3, (100, 255, 255), -1)
        image_processing_utils.write_image(
            frame, f'{file_name_stem}_features{j+start_frame:03d}.png')
    if i == num_frames-1:
      logging.debug('Correct num of frames found: %d', i)
      break  # exit if enough features in all frames
  if i != num_frames-1:
    raise AssertionError('Neither method found enough features in all frames')

  rotations = np.array(rotations)
  rot_per_frame_max = max(abs(rotations))
  logging.debug('Max rotation in frame: %.2f degrees',
                rot_per_frame_max*_RADS_TO_DEGS)
  if stabilized_video:
    logging.debug('Skipped camera rotation check due to stabilized video.')
  else:
    # Sanity-check that the rig actually moved the device.
    if rot_per_frame_max < _ROTATION_PER_FRAME_MIN:
      raise AssertionError(f'Device not moved enough: {rot_per_frame_max:.3f} '
                           f'movement. THRESH: {_ROTATION_PER_FRAME_MIN} rads.')
    else:
      logging.debug('Device movement exceeds %.2f degrees',
                    _ROTATION_PER_FRAME_MIN*_RADS_TO_DEGS)
  return rotations
583
584
def get_best_alignment_offset(cam_times, cam_rots, gyro_events, degree=2):
  """Find the best offset to align the camera and gyro motion traces.

  This function integrates the shifted gyro data between camera samples
  for a range of candidate shift values, and returns the shift that
  result in the best correlation.

  Uses a correlation distance metric between the curves, where a smaller
  value means that the curves are better-correlated.

  Fits a curve to the correlation distance data to measure the minima more
  accurately, by looking at the correlation distances within a range of
  +/- 10ms from the measured best score; note that this will use fewer
  than the full +/- 10 range for the curve fit if the measured score
  (which is used as the center of the fit) is within 10ms of the edge of
  the +/- 50ms candidate range.

  Args:
    cam_times: Array of N camera times, one for each frame.
    cam_rots: Array of N-1 camera rotation displacements (rad).
    gyro_events: List of gyro event objects.
    degree: Degree of polynomial

  Returns:
    Best alignment offset(ms), fit coefficients, candidates, and distances;
    or None if the polynomial fit near the best shift is too poor.

  Raises:
    AssertionError: if the fitted minimum is more than 2 ms from the coarse
      best shift.
  """
  # Measure the correlation distance over defined shift
  shift_candidates = np.arange(-_CORR_TIME_OFFSET_MAX,
                               _CORR_TIME_OFFSET_MAX+_CORR_TIME_OFFSET_STEP,
                               _CORR_TIME_OFFSET_STEP).tolist()
  spatial_distances = []
  for shift in shift_candidates:
    shifted_cam_times = cam_times + shift*_MSEC_TO_NSEC
    gyro_rots = get_gyro_rotations(gyro_events, shifted_cam_times)
    spatial_distance = scipy.spatial.distance.correlation(cam_rots, gyro_rots)
    logging.debug('shift %.1fms spatial distance: %.5f', shift,
                  spatial_distance)
    spatial_distances.append(spatial_distance)

  best_corr_dist = min(spatial_distances)
  coarse_best_shift = shift_candidates[spatial_distances.index(best_corr_dist)]
  logging.debug('Best shift without fitting is %.4f ms', coarse_best_shift)

  # Fit a polynomial around coarse_best_shift to extract best fit
  # NOTE(review): if the coarse best shift is within _COARSE_FIT_RANGE
  # entries of the start of the candidate list, i_poly_fit_min goes negative
  # and the slice below wraps around -- presumably the true shift is always
  # well inside the +/- 50ms candidate range; confirm.
  i = spatial_distances.index(best_corr_dist)
  i_poly_fit_min = i - _COARSE_FIT_RANGE
  i_poly_fit_max = i + _COARSE_FIT_RANGE + 1
  shift_candidates = shift_candidates[i_poly_fit_min:i_poly_fit_max]
  spatial_distances = spatial_distances[i_poly_fit_min:i_poly_fit_max]
  logging.debug('Polynomial degree: %d', degree)
  fit_coeffs, residuals, _, _, _ = np.polyfit(
      shift_candidates, spatial_distances, degree, full=True
  )
  logging.debug('Fit coefficients: %s', fit_coeffs)
  logging.debug('Residuals: %s', residuals)
  total_sum_of_squares = np.sum(
      (spatial_distances - np.mean(spatial_distances)) ** 2
  )
  # Calculate r-squared on the entire domain for debugging
  r_squared = 1 - residuals[0] / total_sum_of_squares
  logging.debug('r^2 on the entire domain: %f', r_squared)

  # Calculate r-squared near the best shift
  domain_around_best_shift = [coarse_best_shift - _SHIFT_DOMAIN_RADIUS,
                              coarse_best_shift + _SHIFT_DOMAIN_RADIUS]
  logging.debug('Calculating r^2 on the limited domain of [%f, %f]',
                domain_around_best_shift[0], domain_around_best_shift[1])
  small_shifts_and_distances = [
      (x, y)
      for x, y in zip(shift_candidates, spatial_distances)
      if domain_around_best_shift[0] <= x <= domain_around_best_shift[1]
  ]
  small_shift_candidates, small_spatial_distances = zip(
      *small_shifts_and_distances
  )
  logging.debug('Shift candidates on limited domain: %s',
                small_shift_candidates)
  logging.debug('Spatial distances on limited domain: %s',
                small_spatial_distances)
  # Residuals of the (global) fit evaluated only on the limited domain.
  limited_residuals = np.sum(
      (np.polyval(fit_coeffs, small_shift_candidates) - small_spatial_distances)
      ** 2
  )
  logging.debug('Residuals on limited domain: %s', limited_residuals)
  limited_total_sum_of_squares = np.sum(
      (small_spatial_distances - np.mean(small_spatial_distances)) ** 2
  )
  limited_r_squared = 1 - limited_residuals / limited_total_sum_of_squares
  logging.debug('r^2 on limited domain: %f', limited_r_squared)

  # Calculate exact_best_shift (x where y is minimum of parabola)
  exact_best_shift = smallest_absolute_minimum_of_polynomial(fit_coeffs)

  # The fitted minimum should land close to the coarse grid minimum.
  if abs(coarse_best_shift - exact_best_shift) > 2.0:
    raise AssertionError(
        f'Test failed. Bad fit to time-shift curve. Coarse best shift: '
        f'{coarse_best_shift}, Exact best shift: {exact_best_shift}.')

  # Check fit of polynomial near the best shift
  if not math.isclose(limited_r_squared, 1, abs_tol=_R_SQUARED_TOLERANCE):
    logging.debug('r-squared on domain [%f, %f] was %f, expected 1.0, '
                  'ATOL: %f',
                  domain_around_best_shift[0], domain_around_best_shift[1],
                  limited_r_squared, _R_SQUARED_TOLERANCE)
    return None

  return exact_best_shift, fit_coeffs, shift_candidates, spatial_distances
692
693
def plot_camera_rotations(cam_rots, start_frame, video_quality,
                          plot_name_stem):
  """Plot the camera rotations.

  Args:
   cam_rots: np array of camera rotations angle per frame
   start_frame: int value of start frame
   video_quality: str for video quality identifier
   plot_name_stem: str (with path) of what to call plot
  """
  plt.figure(video_quality)
  frame_numbers = range(start_frame, start_frame + len(cam_rots))
  plt.title(f'Camera rotation vs frame {video_quality}')
  plt.plot(frame_numbers, cam_rots*_RADS_TO_DEGS, '-ro', label='x')
  plt.xlabel('frame #')
  plt.ylabel('camera rotation (degrees)')
  plt.savefig(f'{plot_name_stem}_cam_rots.png')
  plt.close(video_quality)
713
714
def plot_gyro_events(gyro_events, plot_name, log_path):
  """Plot x, y, and z on the gyro events.

  Samples are grouped into NUM_GYRO_PTS_TO_AVG groups and averaged to minimize
  random spikes in data.

  Args:
    gyro_events: List of gyroscope events.
    plot_name:  name of plot(s).
    log_path: location to save data.
  Raises:
    AssertionError: if the max z rate exceeds _GYRO_ROTATION_PER_SEC_MAX.
  """
  group_size = _NUM_GYRO_PTS_TO_AVG
  # Truncate to a whole number of averaging groups.
  num_events = (len(gyro_events) // group_size) * group_size
  gyro_events = gyro_events[:num_events]
  start_time = gyro_events[0]['time']
  times = np.array([(e['time'] - start_time) * _NSEC_TO_SEC
                    for e in gyro_events])
  axis_data = {axis: np.array([e[axis] for e in gyro_events])
               for axis in ('x', 'y', 'z')}

  # Average each consecutive group of samples to suppress random spikes,
  # keeping one timestamp per group (the group's middle sample).
  times = times[group_size//2::group_size]
  for axis in axis_data:
    axis_data[axis] = axis_data[axis].reshape(
        num_events//group_size, group_size).mean(1)
  x, y, z = axis_data['x'], axis_data['y'], axis_data['z']

  plt.figure(plot_name)
  # x & y share the top axes, zoomed in 4x relative to the z range.
  plt.subplot(2, 1, 1)
  plt.title(f'{plot_name}(mean of {group_size} pts)')
  plt.plot(times, x, 'r', label='x')
  plt.plot(times, y, 'g', label='y')
  plt.ylim([np.amin(z)/4, np.amax(z)/4])
  plt.ylabel('gyro x,y movement (rads/s)')
  plt.legend()

  # z (the rotation axis of interest) gets its own axes.
  plt.subplot(2, 1, 2)
  plt.plot(times, z, 'b', label='z')
  plt.ylim([np.amin(z), np.amax(z)])
  plt.xlabel('time (seconds)')
  plt.ylabel('gyro z movement (rads/s)')
  plt.legend()
  file_name = os.path.join(log_path, plot_name)
  plt.savefig(f'{file_name}_gyro_events.png')
  plt.close(plot_name)

  z_max = max(abs(z))
  logging.debug('z_max: %.3f', z_max)
  if z_max > _GYRO_ROTATION_PER_SEC_MAX:
    raise AssertionError(
        f'Phone moved too rapidly! Please confirm controller firmware. '
        f'Max: {z_max:.3f}, TOL: {_GYRO_ROTATION_PER_SEC_MAX} rads/s')
769
770
def conv_acceleration_to_movement(gyro_events, video_delay_time):
  """Convert gyro_events time and speed to movement during video time.

  Args:
    gyro_events: sorted dict of entries with 'time', 'x', 'y', and 'z'
    video_delay_time: time at which video starts (and the video's duration)

  Returns:
    'z' acceleration converted to movement for times around VIDEO playing.
  """
  times_ns = np.array([e['time'] for e in gyro_events])
  speeds = np.array([e['z'] for e in gyro_events])
  logging.debug('gyro start time: %dns', times_ns[0])
  logging.debug('gyro stop time: %dns', times_ns[-1])
  # Video window: starts video_delay_time after the first gyro sample and
  # lasts another video_delay_time.
  video_time_start = times_ns[0] + video_delay_time *_SEC_TO_NSEC
  video_time_stop = video_time_start + video_delay_time *_SEC_TO_NSEC
  logging.debug('video start time: %dns', video_time_start)
  logging.debug('video stop time: %dns', video_time_stop)

  movements = []
  for i, t in enumerate(times_ns):
    if video_time_start <= t <= video_time_stop:
      # Rectangular integration: sample interval (s) times angular rate.
      movements.append((times_ns[i]-times_ns[i-1])/_SEC_TO_NSEC * speeds[i])
  return np.array(movements)
797