1 /* 2 * Copyright (C) 2023 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package com.android.media.videoquality.bdrate; 18 19 import com.google.auto.value.AutoValue; 20 import com.google.common.annotations.VisibleForTesting; 21 22 import org.apache.commons.math3.stat.descriptive.moment.Mean; 23 24 import java.util.ArrayList; 25 import java.util.Iterator; 26 import java.util.LinkedList; 27 28 /** Pair of two {@link RateDistortionCurve}s used for calculating a Bjontegaard-Delta (BD) value. */ 29 @AutoValue 30 public abstract class RateDistortionCurvePair { 31 32 private static final Mean MEAN = new Mean(); 33 baseline()34 public abstract RateDistortionCurve baseline(); 35 target()36 public abstract RateDistortionCurve target(); 37 38 /** 39 * Creates a new {@link RateDistortionCurvePair} by first clustering the points to eliminate 40 * noise in the data, then validating that the remaining points are sufficient for BD 41 * calculations. 42 */ createClusteredPair( RateDistortionCurve baseline, RateDistortionCurve target)43 public static RateDistortionCurvePair createClusteredPair( 44 RateDistortionCurve baseline, RateDistortionCurve target) { 45 RateDistortionCurve clusteredBaseline = cluster(baseline); 46 RateDistortionCurve clusteredTarget = cluster(target); 47 48 // Check for correct number of points. 49 if (clusteredBaseline.points().size() < 5) { 50 throw new BdPreconditionFailedException( 51 "The reference curve does not have enough points.", /* isTargetCurve= */ false); 52 } 53 if (clusteredTarget.points().size() < 5) { 54 throw new BdPreconditionFailedException( 55 "The target curve does not have enough points.", /* isTargetCurve= */ true); 56 } 57 58 // Check for monotonicity. 59 if (!isMonotonicallyIncreasing(clusteredBaseline)) { 60 throw new BdPreconditionFailedException( 61 "The reference curve is not monotonically increasing.", 62 /* isTargetCurve= */ false); 63 } 64 if (!isMonotonicallyIncreasing(clusteredTarget)) { 65 throw new BdPreconditionFailedException( 66 "The is not monotonically increasing.", /* isTargetCurve= */ true); 67 } 68 69 return new AutoValue_RateDistortionCurvePair(clusteredBaseline, clusteredTarget); 70 } 71 72 /** To calculate BD-RATE, the two rate-distortion curves must overlap in terms of distortion. */ canCalculateBdRate()73 public boolean canCalculateBdRate() { 74 return !(baseline().getMaxDistortion() < target().getMinDistortion()) 75 && !(target().getMaxDistortion() < baseline().getMinDistortion()); 76 } 77 78 /** To calculate BD-QUALITY, the two rate-distortion curves must overlap in terms of bitrate. */ canCalculateBdQuality()79 public boolean canCalculateBdQuality() { 80 return !(baseline().getMaxLog10Bitrate() < target().getMinLog10Bitrate()) 81 && !(target().getMaxLog10Bitrate() < baseline().getMinLog10Bitrate()); 82 } 83 84 /** 85 * Clusters provided rate-distortion points together to reduce noise when the points are close 86 * together in terms of bitrate. 87 * 88 * <p>"Clusters" are points that have a bitrate that is within 1% of the previous 89 * rate-distortion point. Such points are bucketed and then averaged to provide a single point 90 * in the same range as the cluster. 91 */ 92 @VisibleForTesting cluster(RateDistortionCurve baseCurve)93 static RateDistortionCurve cluster(RateDistortionCurve baseCurve) { 94 if (baseCurve.points().size() < 3) { 95 return baseCurve; 96 } 97 98 RateDistortionCurve.Builder newCurve = RateDistortionCurve.builder(); 99 100 LinkedList<ArrayList<RateDistortionPoint>> buckets = new LinkedList<>(); 101 102 // Bucket the items, moving through the points pairwise. 103 buckets.add(new ArrayList<>()); 104 buckets.peekLast().add(baseCurve.points().first()); 105 106 Iterator<RateDistortionPoint> pointIterator = baseCurve.points().iterator(); 107 RateDistortionPoint lastPoint = pointIterator.next(); 108 RateDistortionPoint currentPoint; 109 110 double maxObservedDistortion = lastPoint.distortion(); 111 while (pointIterator.hasNext()) { 112 currentPoint = pointIterator.next(); 113 114 // Cluster points that are within 10% (bitrate) of each other that would make the curve 115 // non-monotonic. 116 if (currentPoint.rate() / lastPoint.rate() > 1.1 117 || currentPoint.distortion() > maxObservedDistortion) { 118 buckets.add(new ArrayList<>()); 119 maxObservedDistortion = currentPoint.distortion(); 120 } 121 buckets.peekLast().add(currentPoint); 122 lastPoint = currentPoint; 123 } 124 125 for (ArrayList<RateDistortionPoint> bucket : buckets) { 126 if (bucket.size() < 2) { 127 newCurve.addPoint(bucket.get(0)); 128 } 129 130 // For a bucket with multiple points, the new point is the average 131 // between all other points. 132 newCurve.addPoint( 133 RateDistortionPoint.create( 134 MEAN.evaluate(bucket.stream().mapToDouble(p -> p.rate()).toArray()), 135 MEAN.evaluate( 136 bucket.stream().mapToDouble(p -> p.distortion()).toArray()))); 137 } 138 139 return newCurve.build(); 140 } 141 142 /** 143 * Returns whether a {@link RateDistortionCurve} is monotonically increasing which is required 144 * for the Cubic Spline interpolation performed during BD rate calculation. 145 */ isMonotonicallyIncreasing(RateDistortionCurve rateDistortionCurve)146 private static boolean isMonotonicallyIncreasing(RateDistortionCurve rateDistortionCurve) { 147 Iterator<RateDistortionPoint> pointIterator = rateDistortionCurve.points().iterator(); 148 149 RateDistortionPoint lastPoint = pointIterator.next(); 150 RateDistortionPoint currentPoint; 151 while (pointIterator.hasNext()) { 152 currentPoint = pointIterator.next(); 153 if (currentPoint.distortion() <= lastPoint.distortion()) { 154 return false; 155 } 156 lastPoint = currentPoint; 157 } 158 159 return true; 160 } 161 } 162