# Copyright 2014 The Android Open Source Project.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Noise model utility functions."""

import collections
import logging
import math
import os.path
import pickle
from typing import Any, Dict, List, Tuple
import warnings
import capture_request_utils
import image_processing_utils
from matplotlib import pyplot as plt
import noise_model_constants
import numpy as np
import scipy.optimize
import scipy.stats


_OUTLIER_MEDIAN_ABS_DEVS_DEFAULT = (
    noise_model_constants.OUTLIER_MEDIAN_ABS_DEVS_DEFAULT
)


def _check_auto_exposure_targets(
    auto_exposure_ns: float,
    sens_min: int,
    sens_max: int,
    bracket_factor: int,
    min_exposure_ns: int,
    max_exposure_ns: int,
) -> None:
  """Checks if AE too bright for highest gain & too dark for lowest gain.

  Args:
    auto_exposure_ns: The auto exposure target, expressed as the product of
      sensitivity and exposure time in nanoseconds (already divided by
      `bracket_factor` by the caller in this module).
    sens_min: The minimum sensitivity value.
    sens_max: The maximum sensitivity value.
    bracket_factor: Exposure bracket factor.
    min_exposure_ns: The minimum exposure time in nanoseconds.
    max_exposure_ns: The maximum exposure time in nanoseconds.

  Raises:
    AssertionError: If even the shortest exposure at the highest sensitivity
      overexposes the target, or the longest exposure at the lowest
      sensitivity cannot reach the (re-bracketed) target.
  """
  if auto_exposure_ns < min_exposure_ns * sens_max:
    raise AssertionError(
        'Scene is too bright to properly expose at highest '
        f'sensitivity: {sens_max}'
    )
  if auto_exposure_ns * bracket_factor > max_exposure_ns * sens_min:
    raise AssertionError(
        'Scene is too dark to properly expose at lowest '
        f'sensitivity: {sens_min}'
    )


def check_noise_model_shape(noise_model: np.ndarray) -> None:
  """Checks if the shape of noise model is valid.

  Args:
    noise_model: A numpy array of shape (num_channels, num_parameters).

  Raises:
    AssertionError: If the array is not 2-D, its number of channels is not in
      noise_model_constants.VALID_NUM_CHANNELS, or each channel does not have
      exactly 4 parameters.
  """
  # Guard the unpacking below so malformed input raises the same exception
  # type as the other shape checks instead of an unpacking ValueError.
  if noise_model.ndim != 2:
    raise AssertionError(
        'The noise model must be a 2-D array of shape (num_channels, 4), but'
        f' has shape {noise_model.shape}.'
    )
  num_channels, num_parameters = noise_model.shape
  if num_channels not in noise_model_constants.VALID_NUM_CHANNELS:
    raise AssertionError(
        f'The number of channels {num_channels} is not in'
        f' {noise_model_constants.VALID_NUM_CHANNELS}.'
    )
  if num_parameters != 4:
    raise AssertionError(
        f'The number of parameters of each channel {num_parameters} != 4.'
    )


def validate_noise_model(
    noise_model: np.ndarray,
    color_channels: List[str],
    sens_min: int,
) -> None:
  """Performs validation checks on the noise model.

  For each color channel this checks that the scale gradient (scale_a) is
  non-negative, the intercept gradient (offset_a) is positive, and the read
  noise at the minimum sensitivity (offset_a * sens_min^2 + offset_b) is
  positive.

  Args:
    noise_model: Noise model parameters each channel, including scale_a,
      scale_b, offset_a, offset_b.
    color_channels: Array of color channels.
    sens_min: Minimum sensitivity value.

  Raises:
    AssertionError: If the model shape is invalid, the number of color
      channels does not match the model, or any check above fails.
  """
  check_noise_model_shape(noise_model)
  num_channels = noise_model.shape[0]
  if len(color_channels) != num_channels:
    raise AssertionError(
        f'Number of color channels {len(color_channels)} != number of noise '
        f'model channels {num_channels}.'
    )

  for color_channel, (scale_a, _, offset_a, offset_b) in zip(
      color_channels, noise_model
  ):
    if scale_a < 0:
      raise AssertionError(
          f'{color_channel} model API scale gradient < 0: {scale_a:.4e}'
      )

    if offset_a <= 0:
      raise AssertionError(
          f'{color_channel} model API intercept gradient <= 0: {offset_a:.4e}'
      )

    # Offset term of the model evaluated at the minimum sensitivity.
    read_noise = offset_a * sens_min * sens_min + offset_b
    if read_noise <= 0:
      raise AssertionError(
          f'{color_channel} model min ISO noise <= 0! '
          f'API intercept gradient: {offset_a:.4e}, '
          f'API intercept offset: {offset_b:.4e}, '
          f'read_noise: {read_noise:.4e}'
      )


def compute_digital_gains(
    gains: np.ndarray,
    sens_max_analog: np.ndarray,
) -> np.ndarray:
  """Computes the digital gains for the given gains and maximum analog gain.

  Digital gain is defined as gain / sens_max_analog, clamped below at 1. It
  therefore equals 1 unless a gain exceeds the maximum analog sensitivity, in
  which case an AssertionError is raised.

  Args:
    gains: An array of gains.
    sens_max_analog: The maximum analog gain sensitivity.

  Returns:
    An numpy array of digital gains (all equal to 1).

  Raises:
    AssertionError: If any digital gain is not equal to 1.
  """
  digital_gains = np.maximum(gains / sens_max_analog, 1)
  if not np.all(digital_gains == 1):
    raise AssertionError(
        f'Digital gains are not all 1! gains: {gains}, '
        f'Max analog gain sensitivity: {sens_max_analog}.'
    )
  return digital_gains


def crop_and_save_capture(
    cap,
    props,
    capture_path: str,
    num_tiles_crop: int,
) -> None:
  """Crops and saves a capture image.

  Args:
    cap: The capture to be cropped and saved.
    props: The properties to be used to convert the capture to an RGB image.
    capture_path: The path to which the capture image should be saved.
    num_tiles_crop: The number of pixels to crop from each side of the image.

  Raises:
    AssertionError: If the crop would remove the whole image.
  """
  img = image_processing_utils.convert_capture_to_rgb_image(cap, props=props)
  height, width, _ = img.shape
  num_tiles_crop_max = min(height, width) // 2
  if num_tiles_crop >= num_tiles_crop_max:
    raise AssertionError(
        f'Number of tiles to crop {num_tiles_crop} >= {num_tiles_crop_max}.'
    )
  img = img[
      num_tiles_crop: height - num_tiles_crop,
      num_tiles_crop: width - num_tiles_crop,
      :,
  ]

  image_processing_utils.write_image(img, capture_path, True)


def crop_and_reorder_stats_images(
    mean_img: np.ndarray,
    var_img: np.ndarray,
    num_tiles_crop: int,
    channel_indices: List[int],
) -> Tuple[np.ndarray, np.ndarray]:
  """Crops the stats images and sorts stats images channels in canonical order.

  Args:
    mean_img: The mean image, indexed as (row, column, channel).
    var_img: The variance image, same shape as `mean_img`.
    num_tiles_crop: The number of tiles to crop from each side of the image.
    channel_indices: The channel indices to sort stats image channels in
      canonical order.

  Returns:
    The cropped and reordered mean image and variance image, each indexed as
    (channel, row, column) with channels in `channel_indices` order.

  Raises:
    AssertionError: If the shapes differ, or the crop would leave no tiles.
  """
  if mean_img.shape != var_img.shape:
    raise AssertionError(
        'Unmatched shapes of mean and variance image: '
        f'shape of mean image is {mean_img.shape}, '
        f'shape of variance image is {var_img.shape}.'
    )
  height, width, _ = mean_img.shape
  # Cropping 2 * num_tiles_crop == min(height, width) already leaves an empty
  # image, so that case is rejected too.
  if 2 * num_tiles_crop >= min(height, width):
    raise AssertionError(
        f'The number of tiles to crop ({num_tiles_crop}) is so large that'
        ' images cannot be cropped.'
    )

  rows = slice(num_tiles_crop, height - num_tiles_crop)
  cols = slice(num_tiles_crop, width - num_tiles_crop)
  means = np.asarray([mean_img[rows, cols, i] for i in channel_indices])
  vars_ = np.asarray([var_img[rows, cols, i] for i in channel_indices])
  return means, vars_


def _as_object_array(arrays: List[np.ndarray]) -> np.ndarray:
  """Packs per-plane arrays into a 1-D object array, one element per plane."""
  # np.asarray(..., dtype=object) would build a 2-D array when all planes
  # happen to have equal length; filling element-wise keeps the layout fixed.
  packed = np.empty(len(arrays), dtype=object)
  for i, array in enumerate(arrays):
    packed[i] = array
  return packed


def filter_stats(
    means: np.ndarray,
    vars_: np.ndarray,
    black_levels: List[float],
    white_level: float,
    max_signal_value: float = 0.25,
    is_remove_var_outliers: bool = False,
    deviations: int = _OUTLIER_MEDIAN_ABS_DEVS_DEFAULT,
) -> Tuple[np.ndarray, np.ndarray]:
  """Filters means outliers and variance outliers.

  Per plane, keeps tiles whose mean lies in [black_level, white_level] and
  whose variance is non-negative (and optionally not an outlier), normalizes
  the kept values to the [black_level, white_level] range, and drops tiles
  where mean + 2 * stddev reaches `max_signal_value` (possible clipping).

  Args:
    means: A numpy ndarray of pixel mean values, indexed by plane first.
    vars_: A numpy ndarray of pixel variance values, same shape as `means`.
    black_levels: A list of black levels, one per plane.
    white_level: A scalar white level.
    max_signal_value: The maximum normalized signal (mean) value.
    is_remove_var_outliers: A boolean value indicating whether to remove
      variance outliers.
    deviations: The number of median absolute deviations from the median
      variance beyond which a variance is treated as an outlier.

  Returns:
    A tuple of (means_filtered, vars_filtered): 1-D object arrays with one
    element per plane, each a 1-D float array of the filtered, normalized
    pixel means and variances. Planes may have different lengths.

  Raises:
    AssertionError: If `means` and `vars_` have different shapes.
  """
  if means.shape != vars_.shape:
    raise AssertionError(
        f'Unmatched shapes of means and vars: means.shape={means.shape},'
        f' vars.shape={vars_.shape}.'
    )
  means_filtered = []
  vars_filtered = []

  for pidx, (means_i, vars_i) in enumerate(zip(means, vars_)):
    black_level = black_levels[pidx]

    # Basic constraints:
    #   (1) means are within the range [black_level, white_level],
    #   (2) vars are non-negative values.
    constraints = [
        means_i >= black_level,
        means_i <= white_level,
        vars_i >= 0,
    ]
    if is_remove_var_outliers:
      # Filter out variances that differ too much from the median of variances.
      mad = scipy.stats.median_abs_deviation(vars_i, axis=None, scale=1)
      med = np.median(vars_i)
      constraints.extend([
          vars_i > med - deviations * mad,
          vars_i < med + deviations * mad,
      ])

    keep = np.logical_and.reduce(constraints)
    if not np.any(keep):
      logging.info('After filter channel %d, stats array is empty.', pidx)

    # Normalizes the range to [0, 1].
    signal_range = white_level - black_level
    means_i = (means_i[keep] - black_level) / signal_range
    vars_i = vars_i[keep] / (signal_range ** 2)

    # Filter out the tiles if they have samples that might be clipped.
    # vars_i is non-negative here, so the square root is well defined.
    not_clipped = means_i + 2 * np.sqrt(vars_i) < max_signal_value
    means_filtered.append(np.asarray(means_i[not_clipped]))
    vars_filtered.append(np.asarray(vars_i[not_clipped]))

  # After filtering, means_filtered and vars_filtered may have different shapes
  # in each color planes.
  return _as_object_array(means_filtered), _as_object_array(vars_filtered)


def get_next_iso(
    iso: float,
    max_iso: int,
    iso_multiplier: float,
) -> float:
  """Moves to the next sensitivity.

  Multiplies `iso` by `iso_multiplier`, but lands exactly on `max_iso` when the
  step would jump over it, so the maximum sensitivity is always measured.

  Args:
    iso: The current ISO sensitivity.
    max_iso: The maximum ISO sensitivity.
    iso_multiplier: The ISO multiplier to use.

  Returns:
    The next ISO sensitivity.

  Raises:
    AssertionError: If `iso_multiplier` is not greater than 1.
  """
  if iso_multiplier <= 1:
    raise AssertionError(
        f'ISO multiplier is {iso_multiplier}, which should be greater than 1.'
    )

  if round(iso) < max_iso < round(iso * iso_multiplier):
    return max_iso
  else:
    return iso * iso_multiplier


def _load_saved_stats(
    stats_file_path: str,
    sens_min: int,
    sens_max_meas: int,
) -> Dict[int, List[Tuple[float, np.ndarray, np.ndarray]]]:
  """Loads saved stats whose ISO lies in [sens_min, sens_max_meas].

  Args:
    stats_file_path: Path of a pickle file written by capture_stats_images.
    sens_min: The minimum sensitivity to keep.
    sens_max_meas: The maximum sensitivity to keep.

  Returns:
    A dict mapping ISO to saved stats; empty if the file does not exist or
    cannot be read.
  """
  if not os.path.isfile(stats_file_path):
    return {}
  try:
    # NOTE(review): pickle.load can execute arbitrary code; only load stats
    # files produced by this tool.
    with open(stats_file_path, 'rb') as f:
      saved_iso_to_stats_dict = pickle.load(f)
  except (OSError, EOFError, pickle.UnpicklingError) as e:
    logging.exception(
        'Failed to load stats file stored at %s. Error message: %s',
        stats_file_path,
        e,
    )
    return {}

  if not saved_iso_to_stats_dict:
    return {}
  return {
      saved_iso: stats
      for saved_iso, stats in saved_iso_to_stats_dict.items()
      if sens_min <= saved_iso <= sens_max_meas
  }


def capture_stats_images(
    cam,
    props,
    stats_config: Dict[str, Any],
    sens_min: int,
    sens_max_meas: int,
    zoom_ratio: float,
    num_tiles_crop: int,
    max_signal_value: float,
    iso_multiplier: float,
    max_bracket: int,
    bracket_factor: int,
    capture_path_prefix: str,
    stats_file_name: str = '',
    is_remove_var_outliers: bool = False,
    outlier_median_abs_deviations: int = _OUTLIER_MEDIAN_ABS_DEVS_DEFAULT,
    is_debug_mode: bool = False,
) -> Dict[int, List[Tuple[float, np.ndarray, np.ndarray]]]:
  """Capture stats images and saves the stats in a dictionary.

  This function captures stats images at different ISO values and exposure
  times, and stores the stats data in a file with the specified name.
  The stats data includes the mean and variance of each plane, as well as
  exposure times. If the stats file already exists, its in-range entries are
  reused and capturing resumes at the ISO after the largest saved one.

  Args:
    cam: The camera session (its_session_utils.ItsSession) for capturing stats
      images.
    props: Camera property object.
    stats_config: The stats format config, a dictionary that specifies the raw
      stats image format and tile size.
    sens_min: The minimum sensitivity.
    sens_max_meas: The maximum sensitivity to measure.
    zoom_ratio: The zoom ratio to use.
    num_tiles_crop: The number of tiles to crop from each side of the stats
      images.
    max_signal_value: The maximum signal value to allow.
    iso_multiplier: The ISO multiplier to use.
    max_bracket: The maximum number of bracketed exposures to capture.
    bracket_factor: The factor by which the AE sensitivity-exposure product is
      divided to underexpose.
    capture_path_prefix: The path prefix to use for captured images.
    stats_file_name: The name of the file to save the stats images to.
    is_remove_var_outliers: Whether to remove variance outliers.
    outlier_median_abs_deviations: The number of median absolute deviations to
      use for detecting outliers.
    is_debug_mode: Whether to enable debug mode.

  Returns:
    A dictionary mapping capture ISO values to a list of
    (exposure_ms, means, vars_) tuples, one per exposure bracket.
  """
  if is_debug_mode:
    logging.info('Capturing stats images with stats config: %s.', stats_config)
    capture_folder = os.path.join(capture_path_prefix, 'captures')
    os.makedirs(capture_folder, exist_ok=True)
    logging.info('Capture folder: %s', capture_folder)

  white_level = props['android.sensor.info.whiteLevel']
  min_exposure_ns, max_exposure_ns = props[
      'android.sensor.info.exposureTimeRange'
  ]
  # Focus at zero to intentionally blur the scene as much as possible.
  f_dist = 0.0
  # Whether the stats images are quad Bayer or standard Bayer.
  is_quad_bayer = 'QuadBayer' in stats_config['format']
  if is_quad_bayer:
    num_channels = noise_model_constants.NUM_QUAD_BAYER_CHANNELS
  else:
    num_channels = noise_model_constants.NUM_BAYER_CHANNELS
  # A dict maps iso to stats images of different exposure times.
  iso_to_stats_dict = collections.defaultdict(list)
  # Start the sensitivity at the minimum.
  iso = sens_min
  # Previous iso cap.
  pre_iso_cap = None
  stats_file_path = ''
  if stats_file_name:
    stats_file_path = os.path.join(capture_path_prefix, stats_file_name)
    iso_to_stats_dict.update(
        _load_saved_stats(stats_file_path, sens_min, sens_max_meas)
    )
    # Set the starting iso to the one after the last iso in saved stats file.
    if iso_to_stats_dict:
      pre_iso_cap = max(iso_to_stats_dict)
      iso = get_next_iso(pre_iso_cap, sens_max_meas, iso_multiplier)

  if round(iso) > sens_max_meas:
    # Everything requested is already covered by the saved stats.
    return iso_to_stats_dict

  # Wait until camera is repositioned for noise model calibration.
  input(
      f'\nPress <ENTER> after covering camera lens {cam.get_camera_name()} '
      'with frosted glass diffuser, and facing lens at evenly illuminated'
      ' surface.\n'
  )
  # Do AE to get a rough idea of where we are.
  iso_ae, exp_ae, _, _, _ = cam.do_3a(
      get_results=True, do_awb=False, do_af=False
  )

  # Underexpose to get more data for low signal levels.
  auto_exposure_ns = iso_ae * exp_ae / bracket_factor
  _check_auto_exposure_targets(
      auto_exposure_ns,
      sens_min,
      sens_max_meas,
      bracket_factor,
      min_exposure_ns,
      max_exposure_ns,
  )

  # The canonical channel order depends only on the camera properties.
  cfa_order = image_processing_utils.get_canonical_cfa_order(
      props, is_quad_bayer
  )

  while round(iso) <= sens_max_meas:
    req = capture_request_utils.manual_capture_request(
        round(iso), min_exposure_ns, f_dist
    )
    cap = cam.do_capture(req, stats_config)
    # Instead of raising an error when the sensitivity readback != requested
    # use the readback value for calculations instead.
    iso_cap = cap['metadata']['android.sensor.sensitivity']

    # Different iso values may result in captures with the same iso_cap
    # value, so skip this capture if it's redundant.
    if iso_cap == pre_iso_cap:
      logging.info(
          'Skip current capture because of the same iso %d with the previous'
          ' capture.',
          iso_cap,
      )
      iso = get_next_iso(iso, sens_max_meas, iso_multiplier)
      continue
    pre_iso_cap = iso_cap

    logging.info('Request ISO: %d, Capture ISO: %d.', iso, iso_cap)

    for bracket in range(max_bracket):
      # Get the exposure for this sensitivity and exposure time.
      exposure_ns = round(math.pow(2, bracket) * auto_exposure_ns / iso)
      exposure_ms = round(exposure_ns * 1.0e-6, 3)
      logging.info('ISO: %d, exposure time: %.3f ms.', iso_cap, exposure_ms)
      req = capture_request_utils.manual_capture_request(
          iso_cap,
          exposure_ns,
          f_dist,
      )
      req['android.control.zoomRatio'] = zoom_ratio
      cap = cam.do_capture(req, stats_config)

      if is_debug_mode:
        capture_path = os.path.join(
            capture_folder, f'iso{iso_cap}_exposure{exposure_ns}ns.jpg'
        )
        crop_and_save_capture(cap, props, capture_path, num_tiles_crop)

      mean_img, var_img = image_processing_utils.unpack_rawstats_capture(
          cap, num_channels=num_channels
      )

      means, vars_ = crop_and_reorder_stats_images(
          mean_img,
          var_img,
          num_tiles_crop,
          cfa_order,
      )
      if is_debug_mode:
        logging.info('Raw stats image size: %s', mean_img.shape)
        logging.info('R plane means image size: %s', means[0].shape)
        logging.info(
            'means min: %.3f, median: %.3f, max: %.3f',
            np.min(means), np.median(means), np.max(means),
        )
        logging.info(
            'vars_ min: %.4f, median: %.4f, max: %.4f',
            np.min(vars_), np.median(vars_), np.max(vars_),
        )

      black_levels = image_processing_utils.get_black_levels(
          props,
          cap['metadata'],
          is_quad_bayer,
      )

      means, vars_ = filter_stats(
          means,
          vars_,
          black_levels,
          white_level,
          max_signal_value,
          is_remove_var_outliers,
          outlier_median_abs_deviations,
      )

      iso_to_stats_dict[iso_cap].append((exposure_ms, means, vars_))

    # Persist after every ISO so an interrupted run can resume from here.
    if stats_file_name:
      with open(stats_file_path, 'wb') as f:
        pickle.dump(iso_to_stats_dict, f)
    iso = get_next_iso(iso, sens_max_meas, iso_multiplier)

  return iso_to_stats_dict


def measure_linear_noise_models(
    iso_to_stats_dict: Dict[int, List[Tuple[float, np.ndarray, np.ndarray]]],
    color_planes: List[str],
) -> Tuple[List[List[Tuple[int, float, float]]], List[List[Tuple[int, Any, Any]]]]:
  """Measures linear noise models.

  This function measures linear noise models (variance as a linear function
  of mean) from means and variances for each color plane and ISO setting.

  Args:
    iso_to_stats_dict: A dictionary mapping ISO settings to a list of
      (exposure_ms, means, vars_) stats tuples, where means[pidx] and
      vars_[pidx] hold the samples of color plane pidx.
    color_planes: A list of color planes.

  Returns:
    A tuple containing:
      measured_models: A list of linear models, one for each color plane.
        Each model is a list of (iso, slope, intercept) tuples.
      samples: A list of samples, one for each color plane. Each sample is a
        tuple of (iso, mean, var).

  Raises:
    ValueError: If a color plane has no samples for some ISO.
  """
  num_planes = len(color_planes)
  # Model parameters for each color plane.
  measured_models = [[] for _ in range(num_planes)]
  # Samples (ISO, mean and var) of each color plane.
  samples = [[] for _ in range(num_planes)]

  for iso in sorted(iso_to_stats_dict):
    logging.info('Calculating measured models for ISO %d.', iso)
    stats_per_plane = [[] for _ in range(num_planes)]
    for _, means, vars_ in iso_to_stats_dict[iso]:
      for pidx in range(num_planes):
        means_p = means[pidx]
        vars_p = vars_[pidx]
        if means_p.size > 0 and vars_p.size > 0:
          stats_per_plane[pidx].extend(zip(means_p, vars_p))

    for pidx, mean_var_pairs in enumerate(stats_per_plane):
      if not mean_var_pairs:
        raise ValueError(
            f'For ISO {iso}, samples are empty in color plane'
            f' {color_planes[pidx]}.'
        )
      plane_means, plane_vars = zip(*mean_var_pairs)
      # Pass x and y explicitly: with a single argument linregress infers the
      # orientation from the array shape, so exactly two (mean, var) pairs
      # would be misread as an x row and a y row.
      slope, intercept, rvalue, _, _ = scipy.stats.linregress(
          plane_means, plane_vars
      )

      measured_models[pidx].append((iso, slope, intercept))
      logging.info(
          (
              'Measured model for ISO %d and color plane %s: '
              'y = %e * x + %e (R=%.6f).'
          ),
          iso, color_planes[pidx], slope, intercept, rvalue,
      )

      # Add the samples for this sensitivity to the global samples list.
      samples[pidx].extend([(iso, mean, var) for (mean, var) in mean_var_pairs])

  return measured_models, samples


def _make_two_stage_model(oa: float, ob: float):
  """Returns the noise model function with offset_a/offset_b held fixed."""

  # Only (sa, sb) appear as parameters so curve_fit fits exactly two values.
  def model(x, sa, sb):
    gains, means = x
    scale = sa * gains + sb
    offset = oa * gains**2 + ob
    return (scale * means + offset) / gains

  return model


def _full_model(x, sa, sb, oa, ob):
  """Noise model function with all four coefficients free, divided by gain."""
  gains, means = x
  scale = sa * gains + sb
  offset = oa * gains**2 + ob
  return (scale * means + offset) / gains


def compute_noise_model(
    samples: List[List[Tuple[float, np.ndarray, np.ndarray]]],
    sens_max_analog: int,
    offset_a: np.ndarray,
    offset_b: np.ndarray,
    is_two_stage_model: bool = False,
) -> np.ndarray:
  """Computes noise model parameters from samples.

  The noise model is defined by the following equation:
    f(x) = scale * x + offset

  where we have:
    scale = scale_a * analog_gain * digital_gain + scale_b,
    offset = (offset_a * analog_gain^2 + offset_b) * digital_gain^2.
  scale is the multiplicative factor and offset is the offset term.

  Assume digital_gain is 1.0 and scale_a, scale_b, offset_a, offset_b are
  sa, sb, oa, ob respectively, so we have noise model function:
    f(x) = (sa * analog_gain + sb) * x + (oa * analog_gain^2 + ob).

  The noise model is fit to the measured data using scipy.optimize.curve_fit,
  which (without bounds) uses the Levenberg-Marquardt algorithm to find the
  model parameters that minimize the squared error.

  Args:
    samples: A list of samples per color plane, each of which is a list of
      tuples of `(gains, means, vars_)`.
    sens_max_analog: The maximum analog gain.
    offset_a: The gradient coefficients from the read noise calibration.
    offset_b: The intercept coefficients from the read noise calibration.
    is_two_stage_model: A boolean flag indicating if the noise model is
      calibrated in the two-stage mode. In that mode offset_a and offset_b are
      taken from the arguments and only scale_a and scale_b are fitted.

  Returns:
    A numpy array containing noise model parameters (scale_a, scale_b,
    offset_a, offset_b) of each channel.

  Raises:
    AssertionError: If any gain implies a digital gain other than 1, or the
      resulting model has an invalid shape.
  """
  noise_model = []
  for pidx, samples_p in enumerate(samples):
    gains, means, vars_ = zip(*samples_p)
    gains = np.asarray(gains).flatten()
    means = np.asarray(means).flatten()
    vars_ = np.asarray(vars_).flatten()

    # Validation only: raises if any gain exceeds the analog range.
    compute_digital_gains(gains, sens_max_analog)

    if is_two_stage_model:
      # For the two-stage model, use the line fit coefficients found from
      # capturing read noise data (offset_a and offset_b) and only train the
      # scale coefficients.
      oa, ob = offset_a[pidx], offset_b[pidx]
      model = _make_two_stage_model(oa, ob)
    else:
      model = _full_model

    # Divide the whole system by gains so that the fit is not dominated by
    # high-gain samples.
    coeffs, _ = scipy.optimize.curve_fit(model, (gains, means), vars_ / gains)

    # If using two-stage model, the offset coefficients are constant, so
    # append them to the fitted scale coefficients.
    if is_two_stage_model:
      coeffs = np.concatenate([coeffs, [oa, ob]])

    # coeffs[0:4] = (scale_a, scale_b, offset_a, offset_b).
    noise_model.append(coeffs[0:4])

  noise_model = np.asarray(noise_model)
  check_noise_model_shape(noise_model)
  return noise_model


def create_stats_figure(
    iso: int,
    color_channel_names: List[str],
):
  """Creates a figure with subplots showing the mean and variance samples.

  Quad Bayer channel lists get a 4x4 grid of subplots with shared axis
  labels; other valid channel counts get a 2x2 grid labelled per subplot.

  Args:
    iso: The ISO setting for the images.
    color_channel_names: A list of strings containing the names of the color
      channels.

  Returns:
    A tuple of the figure and a list of the subplots, in channel order.

  Raises:
    AssertionError: If the number of channels is not valid.
  """
  num_channels = len(color_channel_names)
  if num_channels not in noise_model_constants.VALID_NUM_CHANNELS:
    raise AssertionError(
        'The number of channels should be in'
        f' {noise_model_constants.VALID_NUM_CHANNELS}, but found'
        f' {num_channels}. '
    )

  is_quad_bayer = num_channels == noise_model_constants.NUM_QUAD_BAYER_CHANNELS
  if is_quad_bayer:
    # Adds a plot of the mean and variance samples for each color plane.
    fig, axes = plt.subplots(4, 4, figsize=(22, 22))
    fig.gca()
    fig.suptitle('ISO %d' % iso, x=0.52, y=0.99)

    cax = fig.add_axes([0.65, 0.995, 0.33, 0.003])
    cax.set_title('log(exposure_ms):', x=-0.13, y=-2.0)
    fig.colorbar(
        noise_model_constants.COLOR_BAR, cax=cax, orientation='horizontal'
    )

    # Add a big axis, hide frame.
    fig.add_subplot(111, frameon=False)

    # Add a common x-axis and y-axis.
    plt.tick_params(
        labelcolor='none',
        which='both',
        top=False,
        bottom=False,
        left=False,
        right=False,
    )
    plt.xlabel('Mean signal level', ha='center')
    plt.ylabel('Variance', va='center', rotation='vertical')

    subplots = []
    for pidx in range(noise_model_constants.NUM_QUAD_BAYER_CHANNELS):
      subplot = axes[pidx // 4, pidx % 4]
      subplot.set_title(color_channel_names[pidx])
      # Set 'y' axis to scientific notation for all numbers by setting
      # scilimits to (0, 0).
      subplot.ticklabel_format(axis='y', style='sci', scilimits=(0, 0))
      subplots.append(subplot)

  else:
    # Adds a plot of the mean and variance samples for each color plane.
    fig, [[plt_r, plt_gr], [plt_gb, plt_b]] = plt.subplots(
        2, 2, figsize=(11, 11)
    )
    fig.gca()
    # Add color bar to show exposure times.
    cax = fig.add_axes([0.73, 0.99, 0.25, 0.01])
    cax.set_title('log(exposure_ms):', x=-0.3, y=-1.0)
    fig.colorbar(
        noise_model_constants.COLOR_BAR, cax=cax, orientation='horizontal'
    )

    subplots = [plt_r, plt_gr, plt_gb, plt_b]
    fig.suptitle('ISO %d' % iso, x=0.54, y=0.99)
    for pidx, subplot in enumerate(subplots):
      subplot.set_title(color_channel_names[pidx])
      subplot.set_xlabel('Mean signal level')
      subplot.set_ylabel('Variance')
      subplot.ticklabel_format(axis='y', style='sci', scilimits=(0, 0))

  # tight_layout warns about the extra colorbar axes; the warning is benign.
  with warnings.catch_warnings():
    warnings.simplefilter('ignore', UserWarning)
    plt.tight_layout()

  return fig, subplots