xref: /aosp_15_r20/cts/apps/CameraITS/utils/noise_model_utils.py (revision b7c941bb3fa97aba169d73cee0bed2de8ac964bf)
1# Copyright 2014 The Android Open Source Project.
2
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#      http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14"""Noise model utility functions."""
15
16import collections
17import logging
18import math
19import os.path
20import pickle
21from typing import Any, Dict, List, Tuple
22import warnings
23import capture_request_utils
24import image_processing_utils
25from matplotlib import pyplot as plt
26import noise_model_constants
27import numpy as np
28import scipy.optimize
import scipy.stats
29
30
31_OUTLIER_MEDIAN_ABS_DEVS_DEFAULT = (
32    noise_model_constants.OUTLIER_MEDIAN_ABS_DEVS_DEFAULT
33)
34
35
36def _check_auto_exposure_targets(
37    auto_exposure_ns: float,
38    sens_min: int,
39    sens_max: int,
40    bracket_factor: int,
41    min_exposure_ns: int,
42    max_exposure_ns: int,
43) -> None:
44  """Checks if AE too bright for highest gain & too dark for lowest gain.
45
46  Args:
47    auto_exposure_ns: The auto exposure value in nanoseconds.
48    sens_min: The minimum sensitivity value.
49    sens_max: The maximum sensitivity value.
50    bracket_factor: Exposure bracket factor.
51    min_exposure_ns: The minimum exposure time in nanoseconds.
52    max_exposure_ns: The maximum exposure time in nanoseconds.
53  """
54
55  if auto_exposure_ns < min_exposure_ns * sens_max:
56    raise AssertionError(
57        'Scene is too bright to properly expose at highest '
58        f'sensitivity: {sens_max}'
59    )
60  if auto_exposure_ns * bracket_factor > max_exposure_ns * sens_min:
61    raise AssertionError(
62        'Scene is too dark to properly expose at lowest '
63        f'sensitivity: {sens_min}'
64    )
65
66
67def check_noise_model_shape(noise_model: np.ndarray) -> None:
68  """Checks if the shape of noise model is valid.
69
70  Args:
71    noise_model: A numpy array of shape (num_channels, num_parameters).
72  """
73  num_channels, num_parameters = noise_model.shape
74  if num_channels not in noise_model_constants.VALID_NUM_CHANNELS:
75    raise AssertionError(
76        f'The number of channels {num_channels} is not in'
77        f' {noise_model_constants.VALID_NUM_CHANNELS}.'
78    )
79  if num_parameters != 4:
80    raise AssertionError(
81        f'The number of parameters of each channel {num_parameters} != 4.'
82    )
83
84
85def validate_noise_model(
86    noise_model: np.ndarray,
87    color_channels: List[str],
88    sens_min: int,
89) -> None:
90  """Performs validation checks on the noise model.
91
92  This function checks that, for each color channel, the scale gradient is
93  non-negative and the intercept gradient and read noise are positive.
94
95  Args:
96      noise_model: Noise model parameters of each channel, including scale_a,
97        scale_b, offset_a, offset_b.
98      color_channels: Array of color channels.
99      sens_min: Minimum sensitivity value.
100  """
101  check_noise_model_shape(noise_model)
102  num_channels = noise_model.shape[0]
103  if len(color_channels) != num_channels:
104    raise AssertionError(
105        f'Number of color channels {len(color_channels)} != number of noise '
106        f'model channels {num_channels}.'
107    )
108
109  scale_a, _, offset_a, offset_b = zip(*noise_model)
110  for i, color_channel in enumerate(color_channels):
111    if scale_a[i] < 0:
112      raise AssertionError(
113          f'{color_channel} model API scale gradient < 0: {scale_a[i]:.4e}'
114      )
115
116    if offset_a[i] <= 0:
117      raise AssertionError(
118          f'{color_channel} model API intercept gradient <= 0: {offset_a[i]:.4e}'
119      )
120
121    read_noise = offset_a[i] * sens_min * sens_min + offset_b[i]
122    if read_noise <= 0:
123      raise AssertionError(
124          f'{color_channel} model min ISO noise <= 0! '
125          f'API intercept gradient: {offset_a[i]:.4e}, '
126          f'API intercept offset: {offset_b[i]:.4e}, '
127          f'read_noise: {read_noise:.4e}'
128      )
129
130
131def compute_digital_gains(
132    gains: np.ndarray,
133    sens_max_analog: np.ndarray,
134) -> np.ndarray:
135  """Computes the digital gains for the given gains and maximum analog gain.
136
137  Digital gain is defined as the gain divided by the maximum analog gain
138  sensitivity, clamped to at least 1. This function checks that all digital
139  gains are equal to 1 and raises an AssertionError otherwise.
140
141  Args:
142    gains: An array of gains.
143    sens_max_analog: The maximum analog gain sensitivity.
144
145  Returns:
146    A numpy array of digital gains.
147  """
148  digital_gains = np.maximum(gains / sens_max_analog, 1)
149  if not np.all(digital_gains == 1):
150    raise AssertionError(
151        f'Digital gains are not all 1! gains: {gains}, '
152        f'Max analog gain sensitivity: {sens_max_analog}.'
153    )
154  return digital_gains
155
156
157def crop_and_save_capture(
158    cap,
159    props,
160    capture_path: str,
161    num_tiles_crop: int,
162) -> None:
163  """Crops and saves a capture image.
164
165  Args:
166    cap: The capture to be cropped and saved.
167    props: The properties to be used to convert the capture to an RGB image.
168    capture_path: The path to which the capture image should be saved.
169    num_tiles_crop: The number of tiles to crop.
170  """
171  img = image_processing_utils.convert_capture_to_rgb_image(cap, props=props)
172  height, width, _ = img.shape
173  num_tiles_crop_max = min(height, width) // 2
174  if num_tiles_crop >= num_tiles_crop_max:
175    raise AssertionError(
176        f'Number of tiles to crop {num_tiles_crop} >= {num_tiles_crop_max}.'
177    )
178  img = img[
179      num_tiles_crop: height - num_tiles_crop,
180      num_tiles_crop: width - num_tiles_crop,
181      :,
182  ]
183
184  image_processing_utils.write_image(img, capture_path, True)
185
186
187def crop_and_reorder_stats_images(
188    mean_img: np.ndarray,
189    var_img: np.ndarray,
190    num_tiles_crop: int,
191    channel_indices: List[int],
192) -> Tuple[np.ndarray, np.ndarray]:
193  """Crops the stats images and sorts stats images channels in canonical order.
194
195  Args:
196      mean_img: The mean image.
197      var_img: The variance image.
198      num_tiles_crop: The number of tiles to crop from each side of the image.
199      channel_indices: The channel indices to sort stats image channels in
200        canonical order.
201
202  Returns:
203      The cropped and reordered mean image and variance image.
204  """
205  if mean_img.shape != var_img.shape:
206    raise AssertionError(
207        'Unmatched shapes of mean and variance image: '
208        f'shape of mean image is {mean_img.shape}, '
209        f'shape of variance image is {var_img.shape}.'
210    )
211  height, width, _ = mean_img.shape
212  if 2 * num_tiles_crop > min(height, width):
213    raise AssertionError(
214        f'The number of tiles to crop ({num_tiles_crop}) is so large that'
215        ' images cannot be cropped.'
216    )
217
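  # channel_indices is expected to come from
  # image_processing_utils.get_canonical_cfa_order(), so the returned planes
  # are in canonical order (e.g. R, Gr, Gb, B for standard Bayer, or the quad
  # Bayer equivalent) regardless of the sensor's CFA layout.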
218  means = []
219  vars_ = []
220  for i in channel_indices:
221    means_i = mean_img[
222        num_tiles_crop: height - num_tiles_crop,
223        num_tiles_crop: width - num_tiles_crop,
224        i,
225    ]
226    vars_i = var_img[
227        num_tiles_crop: height - num_tiles_crop,
228        num_tiles_crop: width - num_tiles_crop,
229        i,
230    ]
231    means.append(means_i)
232    vars_.append(vars_i)
233  means, vars_ = np.asarray(means), np.asarray(vars_)
234  return means, vars_
235
236
237def filter_stats(
238    means: np.ndarray,
239    vars_: np.ndarray,
240    black_levels: List[float],
241    white_level: float,
242    max_signal_value: float = 0.25,
243    is_remove_var_outliers: bool = False,
244    deviations: int = _OUTLIER_MEDIAN_ABS_DEVS_DEFAULT,
245) -> Tuple[np.ndarray, np.ndarray]:
246  """Filters means outliers and variance outliers.
247
248  Args:
249      means: A numpy ndarray of pixel mean values.
250      vars_: A numpy ndarray of pixel variance values.
251      black_levels: A list of black levels for each pixel.
252      white_level: A scalar white level.
253      max_signal_value: The maximum signal (mean) value.
254      is_remove_var_outliers: A boolean value indicating whether to remove
255        variance outliers.
256      deviations: A scalar value specifying the number of standard deviations to
257        use when removing variance outliers.
258
259  Returns:
260      A tuple of (means_filtered, vars_filtered) where means_filtered and
261      vars_filtered are numpy ndarrays of filtered pixel mean and variance
262      values, respectively.
263  """
264  if means.shape != vars_.shape:
265    raise AssertionError(
266        f'Unmatched shapes of means and vars: means.shape={means.shape},'
267        f' vars.shape={vars_.shape}.'
268    )
269  num_planes = len(means)
270  means_filtered = []
271  vars_filtered = []
272
273  for pidx in range(num_planes):
274    black_level = black_levels[pidx]
275    means_i = means[pidx]
276    vars_i = vars_[pidx]
277
278    # Basic constraints:
279    # (1) means are within the range [black_level, white_level],
280    # (2) vars are non-negative values.
281    constraints = [
282        means_i >= black_level,
283        means_i <= white_level,
284        vars_i >= 0,
285    ]
286    if is_remove_var_outliers:
287      # Filter out variances that differ too much from the median of variances.
288      std_dev = scipy.stats.median_abs_deviation(vars_i, axis=None, scale=1)
289      med = np.median(vars_i)
290      constraints.extend([
291          vars_i > med - deviations * std_dev,
292          vars_i < med + deviations * std_dev,
293      ])
294
295    keep_indices = np.where(np.logical_and.reduce(constraints))
296    if keep_indices[0].size == 0:
297      logging.info('After filter channel %d, stats array is empty.', pidx)
298
299    # Normalize means to [0, 1] and variances by (white_level - black_level)^2.
300    means_i = (means_i[keep_indices] - black_level) / (
301        white_level - black_level
302    )
303    vars_i = vars_i[keep_indices] / ((white_level - black_level) ** 2)
304    # Drop tiles where mean + 2 * sqrt(var) >= max_signal_value (possible clipping).
305    mean_var_pairs = list(
306        filter(
307            lambda x: x[0] + 2 * math.sqrt(x[1]) < max_signal_value,
308            zip(means_i, vars_i),
309        )
310    )
311    if mean_var_pairs:
312      means_i, vars_i = zip(*mean_var_pairs)
313    else:
314      means_i, vars_i = [], []
315    means_i = np.asarray(means_i)
316    vars_i = np.asarray(vars_i)
317    means_filtered.append(means_i)
318    vars_filtered.append(vars_i)
319
320  # After filtering, means_filtered and vars_filtered may have different shapes
321  # for each color plane.
322  means_filtered = np.asarray(means_filtered, dtype=object)
323  vars_filtered = np.asarray(vars_filtered, dtype=object)
324  return means_filtered, vars_filtered
325
326
327def get_next_iso(
328    iso: float,
329    max_iso: int,
330    iso_multiplier: float,
331) -> float:
332  """Moves to the next sensitivity.
333
334  Args:
335    iso: The current ISO sensitivity.
336    max_iso: The maximum ISO sensitivity.
337    iso_multiplier: The ISO multiplier to use.
338
339  Returns:
340    The next ISO sensitivity.
341  """
342  if iso_multiplier <= 1:
343    raise AssertionError(
344        f'ISO multiplier is {iso_multiplier}, which should be greater than 1.'
345    )
346
347  if round(iso) < max_iso < round(iso * iso_multiplier):
348    return max_iso
349  else:
350    return iso * iso_multiplier
351
352
353def capture_stats_images(
354    cam,
355    props,
356    stats_config: Dict[str, Any],
357    sens_min: int,
358    sens_max_meas: int,
359    zoom_ratio: float,
360    num_tiles_crop: int,
361    max_signal_value: float,
362    iso_multiplier: float,
363    max_bracket: int,
364    bracket_factor: int,
365    capture_path_prefix: str,
366    stats_file_name: str = '',
367    is_remove_var_outliers: bool = False,
368    outlier_median_abs_deviations: int = _OUTLIER_MEDIAN_ABS_DEVS_DEFAULT,
369    is_debug_mode: bool = False,
370) -> Dict[int, List[Tuple[float, np.ndarray, np.ndarray]]]:
371  """Capture stats images and saves the stats in a dictionary.
372
373  This function captures stats images at different ISO values and exposure
374  times, and stores the stats data in a file with the specified name.
375  The stats data includes the mean and variance of each plane, as well as
376  exposure times.
377
378  Args:
379    cam: The camera session (its_session_utils.ItsSession) for capturing stats
380      images.
381    props: Camera property object.
382    stats_config: The stats format config, a dictionary that specifies the raw
383      stats image format and tile size.
384    sens_min: The minimum sensitivity.
385    sens_max_meas: The maximum sensitivity to measure.
386    zoom_ratio: The zoom ratio to use.
387    num_tiles_crop: The number of tiles to crop the images into.
388    max_signal_value: The maximum signal value to allow.
389    iso_multiplier: The ISO multiplier to use.
390    max_bracket: The maximum number of bracketed exposures to capture.
391    bracket_factor: The bracket factor with default value 2^max_bracket.
392    capture_path_prefix: The path prefix to use for captured images.
393    stats_file_name: The name of the file to save the stats images to.
394    is_remove_var_outliers: Whether to remove variance outliers.
395    outlier_median_abs_deviations: The number of median absolute deviations to
396      use for detecting outliers.
397    is_debug_mode: Whether to enable debug mode.
398
399  Returns:
400    A dictionary mapping each ISO to a list of (exposure_ms, means, vars_) tuples.
401  """
402  if is_debug_mode:
403    logging.info('Capturing stats images with stats config: %s.', stats_config)
404    capture_folder = os.path.join(capture_path_prefix, 'captures')
405    if not os.path.exists(capture_folder):
406      os.makedirs(capture_folder)
407    logging.info('Capture folder: %s', capture_folder)
408
409  white_level = props['android.sensor.info.whiteLevel']
410  min_exposure_ns, max_exposure_ns = props[
411      'android.sensor.info.exposureTimeRange'
412  ]
413  # Focus at zero to intentionally blur the scene as much as possible.
414  f_dist = 0.0
415  # Whether the stats images are quad Bayer or standard Bayer.
416  is_quad_bayer = 'QuadBayer' in stats_config['format']
417  if is_quad_bayer:
418    num_channels = noise_model_constants.NUM_QUAD_BAYER_CHANNELS
419  else:
420    num_channels = noise_model_constants.NUM_BAYER_CHANNELS
421  # A dict mapping iso to stats images at different exposure times.
422  iso_to_stats_dict = collections.defaultdict(list)
423  # Start the sensitivity at the minimum.
424  iso = sens_min
425  # Previous iso cap.
426  pre_iso_cap = None
427  if stats_file_name:
428    stats_file_path = os.path.join(capture_path_prefix, stats_file_name)
429    if os.path.isfile(stats_file_path):
430      try:
431        with open(stats_file_path, 'rb') as f:
432          saved_iso_to_stats_dict = pickle.load(f)
433          # Filter saved stats data.
434          if saved_iso_to_stats_dict:
435            for iso, stats in saved_iso_to_stats_dict.items():
436              if sens_min <= iso <= sens_max_meas:
437                iso_to_stats_dict[iso] = stats
438
439        # Set the starting iso to the last iso in saved stats file.
440        if iso_to_stats_dict.keys():
441          pre_iso_cap = max(iso_to_stats_dict.keys())
442          iso = get_next_iso(pre_iso_cap, sens_max_meas, iso_multiplier)
443      except OSError as e:
444        logging.exception(
445            'Failed to load stats file stored at %s. Error message: %s',
446            stats_file_path,
447            e,
448        )
449
450  if round(iso) <= sens_max_meas:
451    # Wait until camera is repositioned for noise model calibration.
452    input(
453        f'\nPress <ENTER> after covering camera lens {cam.get_camera_name()} '
454        'with a frosted glass diffuser and facing the lens at an evenly'
455        ' illuminated surface.\n'
456    )
457    # Do AE to get a rough idea of where we are.
458    iso_ae, exp_ae, _, _, _ = cam.do_3a(
459        get_results=True, do_awb=False, do_af=False
460    )
461
462    # Underexpose to get more data for low signal levels.
463    auto_exposure_ns = iso_ae * exp_ae / bracket_factor
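    # With the default bracket_factor of 2**max_bracket, the brightest bracket
    # (2**(max_bracket - 1) * auto_exposure_ns) corresponds to about half the
    # metered exposure, so every bracket stays below the AE target.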
464    _check_auto_exposure_targets(
465        auto_exposure_ns,
466        sens_min,
467        sens_max_meas,
468        bracket_factor,
469        min_exposure_ns,
470        max_exposure_ns,
471    )
472
473  while round(iso) <= sens_max_meas:
474    req = capture_request_utils.manual_capture_request(
475        round(iso), min_exposure_ns, f_dist
476    )
477    cap = cam.do_capture(req, stats_config)
478    # Instead of raising an error when the sensitivity readback != requested,
479    # use the readback value for calculations.
480    iso_cap = cap['metadata']['android.sensor.sensitivity']
481
482    # Different iso values may result in captures with the same iso_cap
483    # value, so skip this capture if it's redundant.
484    if iso_cap == pre_iso_cap:
485      logging.info(
486          'Skip current capture because of the same iso %d with the previous'
487          ' capture.',
488          iso_cap,
489      )
490      iso = get_next_iso(iso, sens_max_meas, iso_multiplier)
491      continue
492    pre_iso_cap = iso_cap
493
494    logging.info('Request ISO: %d, Capture ISO: %d.', iso, iso_cap)
495
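    # Each bracket doubles the exposure time, so the max_bracket captures at
    # this ISO span a factor of 2**(max_bracket - 1) in exposure time.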
496    for bracket in range(max_bracket):
497      # Compute the exposure time for this bracket at the current sensitivity.
498      exposure_ns = round(math.pow(2, bracket) * auto_exposure_ns / iso)
499      exposure_ms = round(exposure_ns * 1.0e-6, 3)
500      logging.info('ISO: %d, exposure time: %.3f ms.', iso_cap, exposure_ms)
501      req = capture_request_utils.manual_capture_request(
502          iso_cap,
503          exposure_ns,
504          f_dist,
505      )
506      req['android.control.zoomRatio'] = zoom_ratio
507      cap = cam.do_capture(req, stats_config)
508
509      if is_debug_mode:
510        capture_path = os.path.join(
511            capture_folder, f'iso{iso_cap}_exposure{exposure_ns}ns.jpg'
512        )
513        crop_and_save_capture(cap, props, capture_path, num_tiles_crop)
514
515      mean_img, var_img = image_processing_utils.unpack_rawstats_capture(
516          cap, num_channels=num_channels
517      )
518      cfa_order = image_processing_utils.get_canonical_cfa_order(
519          props, is_quad_bayer
520      )
521
522      means, vars_ = crop_and_reorder_stats_images(
523          mean_img,
524          var_img,
525          num_tiles_crop,
526          cfa_order,
527      )
528      if is_debug_mode:
529        logging.info('Raw stats image size: %s', mean_img.shape)
530        logging.info('R plane means image size: %s', means[0].shape)
531        logging.info(
532            'means min: %.3f, median: %.3f, max: %.3f',
533            np.min(means), np.median(means), np.max(means),
534        )
535        logging.info(
536            'vars_ min: %.4f, median: %.4f, max: %.4f',
537            np.min(vars_), np.median(vars_), np.max(vars_),
538        )
539
540      black_levels = image_processing_utils.get_black_levels(
541          props,
542          cap['metadata'],
543          is_quad_bayer,
544      )
545
546      means, vars_ = filter_stats(
547          means,
548          vars_,
549          black_levels,
550          white_level,
551          max_signal_value,
552          is_remove_var_outliers,
553          outlier_median_abs_deviations,
554      )
555
556      iso_to_stats_dict[iso_cap].append((exposure_ms, means, vars_))
557
558    if stats_file_name:
559      with open(stats_file_path, 'wb+') as f:
560        pickle.dump(iso_to_stats_dict, f)
561    iso = get_next_iso(iso, sens_max_meas, iso_multiplier)
562
563  return iso_to_stats_dict
564
565
566def measure_linear_noise_models(
567    iso_to_stats_dict: Dict[int, List[Tuple[float, np.ndarray, np.ndarray]]],
568    color_planes: List[str],
569):
570  """Measures linear noise models.
571
572  This function measures linear noise models from means and variances for each
573  color plane and ISO setting.
574
575  Args:
576      iso_to_stats_dict: A dictionary mapping ISO settings to a list of stats
577        data.
578      color_planes: A list of color planes.
579
580  Returns:
581      A tuple containing:
582          measured_models: A list of linear models, one for each color plane.
583          samples: A list of samples, one for each color plane. Each sample is a
584              tuple of (iso, mean, var).
585  """
586  num_planes = len(color_planes)
587  # Model parameters for each color plane.
588  measured_models = [[] for _ in range(num_planes)]
589  # Samples (iso, mean, var) of each color channel.
590  samples = [[] for _ in range(num_planes)]
591
592  for iso in sorted(iso_to_stats_dict.keys()):
593    logging.info('Calculating measured models for ISO %d.', iso)
594    stats_per_plane = [[] for _ in range(num_planes)]
595    for _, means, vars_ in iso_to_stats_dict[iso]:
596      for pidx in range(num_planes):
597        means_p = means[pidx]
598        vars_p = vars_[pidx]
599        if means_p.size > 0 and vars_p.size > 0:
600          stats_per_plane[pidx].extend(list(zip(means_p, vars_p)))
601
602    for pidx, mean_var_pairs in enumerate(stats_per_plane):
603      if not mean_var_pairs:
604        raise ValueError(
605            f'For ISO {iso}, samples are empty in color plane'
606            f' {color_planes[pidx]}.'
607        )
608      slope, intercept, rvalue, _, _ = scipy.stats.linregress(mean_var_pairs)
609
610      measured_models[pidx].append((iso, slope, intercept))
611      logging.info(
612          (
613              'Measured model for ISO %d and color plane %s: '
614              'y = %e * x + %e (R=%.6f).'
615          ),
616          iso, color_planes[pidx], slope, intercept, rvalue,
617      )
618
619      # Add the samples for this sensitivity to the global samples list.
620      samples[pidx].extend([(iso, mean, var) for (mean, var) in mean_var_pairs])
621
622  return measured_models, samples
623
624
625def compute_noise_model(
626    samples: List[List[Tuple[float, np.ndarray, np.ndarray]]],
627    sens_max_analog: int,
628    offset_a: np.ndarray,
629    offset_b: np.ndarray,
630    is_two_stage_model: bool = False,
631) -> np.ndarray:
632  """Computes noise model parameters from samples.
633
634  The noise model is defined by the following equation:
635    f(x) = scale * x + offset
636
637  where we have:
638    scale = scale_a * analog_gain * digital_gain + scale_b,
639    offset = (offset_a * analog_gain^2 + offset_b) * digital_gain^2.
640    scale is the multiplicative factor and offset is the offset term.
641
642  Assuming digital_gain is 1.0 and writing scale_a, scale_b, offset_a,
643  offset_b as sa, sb, oa, ob respectively, the noise model function becomes:
644  f(x) = (sa * analog_gain + sb) * x + (oa * analog_gain^2 + ob).
645
646  The noise model is fit to the measured data with scipy.optimize.curve_fit,
647  which uses an iterative Levenberg-Marquardt algorithm to find the model
648  parameters that minimize the mean squared error.
649
650  Args:
651    samples: A list of samples, each of which is a list of tuples of `(gains,
652      means, vars_)`.
653    sens_max_analog: The maximum analog gain.
654    offset_a: The gradient coefficients from the read noise calibration.
655    offset_b: The intercept coefficients from the read noise calibration.
656    is_two_stage_model: A boolean flag indicating if the noise model is
657      calibrated in the two-stage mode.
658
659  Returns:
660    A numpy array containing noise model parameters (scale_a, scale_b,
661    offset_a, offset_b) of each channel.
662  """
663  noise_model = []
664  for pidx, samples_p in enumerate(samples):
665    gains, means, vars_ = zip(*samples_p)
666    gains = np.asarray(gains).flatten()
667    means = np.asarray(means).flatten()
668    vars_ = np.asarray(vars_).flatten()
669
670    compute_digital_gains(gains, sens_max_analog)
671
672    # Use a global linear optimization to fit the noise model.
673    # Noise model function:
674    # f(x) = scale * x + offset
675    # Where:
676    # scale = scale_a * analog_gain * digital_gain + scale_b.
677    # offset = (offset_a * analog_gain^2 + offset_b) * digital_gain^2.
678    # Function f will be used to train the scale and offset coefficients
679    # scale_a, scale_b, offset_a, offset_b.
680    if is_two_stage_model:
681      # For the two-stage model, we want to use the line fit coefficients
682      # found from capturing read noise data (offset_a and offset_b) to
683      # train the scale coefficients.
684      oa, ob = offset_a[pidx], offset_b[pidx]
685
686      # Cannot pass oa and ob as the parameters of f since we only want
687      # curve_fit to return 2 parameters.
688      def f(x, sa, sb):
689        scale = sa * x[0] + sb
690        # pylint: disable=cell-var-from-loop
691        offset = oa * x[0] ** 2 + ob
692        return (scale * x[1] + offset) / x[0]
693
694    else:
695      def f(x, sa, sb, oa, ob):
696        scale = sa * x[0] + sb
697        offset = oa * x[0] ** 2 + ob
698        return (scale * x[1] + offset) / x[0]
699
700    # Divide the whole system by gains.
701    coeffs, _ = scipy.optimize.curve_fit(f, (gains, means), vars_ / (gains))
702
703    # For the two-stage model, the offset coefficients come from the read noise
704    # calibration and are held constant, so append them to the coeffs ndarray.
705    if is_two_stage_model:
706      coeffs = np.append(coeffs, offset_a[pidx])
707      coeffs = np.append(coeffs, offset_b[pidx])
708
709    # coeffs[0:4] = (scale_a, scale_b, offset_a, offset_b).
710    noise_model.append(coeffs[0:4])
711
712  noise_model = np.asarray(noise_model)
713  check_noise_model_shape(noise_model)
714  return noise_model
715
716
717def create_stats_figure(
718    iso: int,
719    color_channel_names: List[str],
720):
721  """Creates a figure with subplots showing the mean and variance samples.
722
723  Args:
724    iso: The ISO setting for the images.
725    color_channel_names: A list of strings containing the names of the color
726      channels.
727
728  Returns:
729    A tuple of the figure and a list of the subplots.
730  """
731  if len(color_channel_names) not in noise_model_constants.VALID_NUM_CHANNELS:
732    raise AssertionError(
733        'The number of channels should be in'
734        f' {noise_model_constants.VALID_NUM_CHANNELS}, but found'
735        f' {len(color_channel_names)}. '
736    )
737
738  is_quad_bayer = (
739      len(color_channel_names) == noise_model_constants.NUM_QUAD_BAYER_CHANNELS
740  )
741  if is_quad_bayer:
742    # Adds a plot of the mean and variance samples for each color plane.
743    fig, axes = plt.subplots(4, 4, figsize=(22, 22))
744    fig.gca()
745    fig.suptitle('ISO %d' % iso, x=0.52, y=0.99)
746
747    cax = fig.add_axes([0.65, 0.995, 0.33, 0.003])
748    cax.set_title('log(exposure_ms):', x=-0.13, y=-2.0)
749    fig.colorbar(
750        noise_model_constants.COLOR_BAR, cax=cax, orientation='horizontal'
751    )
752
753    # Add a big axis, hide frame.
754    fig.add_subplot(111, frameon=False)
755
756    # Add a common x-axis and y-axis.
757    plt.tick_params(
758        labelcolor='none',
759        which='both',
760        top=False,
761        bottom=False,
762        left=False,
763        right=False,
764    )
765    plt.xlabel('Mean signal level', ha='center')
766    plt.ylabel('Variance', va='center', rotation='vertical')
767
768    subplots = []
769    for pidx in range(noise_model_constants.NUM_QUAD_BAYER_CHANNELS):
770      subplot = axes[pidx // 4, pidx % 4]
771      subplot.set_title(color_channel_names[pidx])
772      # Set 'y' axis to scientific notation for all numbers by setting
773      # scilimits to (0, 0).
774      subplot.ticklabel_format(axis='y', style='sci', scilimits=(0, 0))
775      subplots.append(subplot)
776
777  else:
778    # Adds a plot of the mean and variance samples for each color plane.
779    fig, [[plt_r, plt_gr], [plt_gb, plt_b]] = plt.subplots(
780        2, 2, figsize=(11, 11)
781    )
782    fig.gca()
783    # Add color bar to show exposure times.
784    cax = fig.add_axes([0.73, 0.99, 0.25, 0.01])
785    cax.set_title('log(exposure_ms):', x=-0.3, y=-1.0)
786    fig.colorbar(
787        noise_model_constants.COLOR_BAR, cax=cax, orientation='horizontal'
788    )
789
790    subplots = [plt_r, plt_gr, plt_gb, plt_b]
791    fig.suptitle('ISO %d' % iso, x=0.54, y=0.99)
792    for pidx, subplot in enumerate(subplots):
793      subplot.set_title(color_channel_names[pidx])
794      subplot.set_xlabel('Mean signal level')
795      subplot.set_ylabel('Variance')
796      subplot.ticklabel_format(axis='y', style='sci', scilimits=(0, 0))
797
798  with warnings.catch_warnings():
799    warnings.simplefilter('ignore', UserWarning)
800    plt.tight_layout()
801
802  return fig, subplots
803