1 /*
2  * Copyright (C) 2023 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package com.android.media.videoquality.bdrate;
18 
19 import com.google.auto.value.AutoValue;
20 import com.google.common.annotations.VisibleForTesting;
21 
22 import org.apache.commons.math3.stat.descriptive.moment.Mean;
23 
24 import java.util.ArrayList;
25 import java.util.Iterator;
26 import java.util.LinkedList;
27 
28 /** Pair of two {@link RateDistortionCurve}s used for calculating a Bjontegaard-Delta (BD) value. */
29 @AutoValue
30 public abstract class RateDistortionCurvePair {
31 
32     private static final Mean MEAN = new Mean();
33 
baseline()34     public abstract RateDistortionCurve baseline();
35 
target()36     public abstract RateDistortionCurve target();
37 
38     /**
39      * Creates a new {@link RateDistortionCurvePair} by first clustering the points to eliminate
40      * noise in the data, then validating that the remaining points are sufficient for BD
41      * calculations.
42      */
createClusteredPair( RateDistortionCurve baseline, RateDistortionCurve target)43     public static RateDistortionCurvePair createClusteredPair(
44             RateDistortionCurve baseline, RateDistortionCurve target) {
45         RateDistortionCurve clusteredBaseline = cluster(baseline);
46         RateDistortionCurve clusteredTarget = cluster(target);
47 
48         // Check for correct number of points.
49         if (clusteredBaseline.points().size() < 5) {
50             throw new BdPreconditionFailedException(
51                     "The reference curve does not have enough points.", /* isTargetCurve= */ false);
52         }
53         if (clusteredTarget.points().size() < 5) {
54             throw new BdPreconditionFailedException(
55                     "The target curve does not have enough points.", /* isTargetCurve= */ true);
56         }
57 
58         // Check for monotonicity.
59         if (!isMonotonicallyIncreasing(clusteredBaseline)) {
60             throw new BdPreconditionFailedException(
61                     "The reference curve is not monotonically increasing.",
62                     /* isTargetCurve= */ false);
63         }
64         if (!isMonotonicallyIncreasing(clusteredTarget)) {
65             throw new BdPreconditionFailedException(
66                     "The is not monotonically increasing.", /* isTargetCurve= */ true);
67         }
68 
69         return new AutoValue_RateDistortionCurvePair(clusteredBaseline, clusteredTarget);
70     }
71 
72     /** To calculate BD-RATE, the two rate-distortion curves must overlap in terms of distortion. */
canCalculateBdRate()73     public boolean canCalculateBdRate() {
74         return !(baseline().getMaxDistortion() < target().getMinDistortion())
75                 && !(target().getMaxDistortion() < baseline().getMinDistortion());
76     }
77 
78     /** To calculate BD-QUALITY, the two rate-distortion curves must overlap in terms of bitrate. */
canCalculateBdQuality()79     public boolean canCalculateBdQuality() {
80         return !(baseline().getMaxLog10Bitrate() < target().getMinLog10Bitrate())
81                 && !(target().getMaxLog10Bitrate() < baseline().getMinLog10Bitrate());
82     }
83 
84     /**
85      * Clusters provided rate-distortion points together to reduce noise when the points are close
86      * together in terms of bitrate.
87      *
88      * <p>"Clusters" are points that have a bitrate that is within 1% of the previous
89      * rate-distortion point. Such points are bucketed and then averaged to provide a single point
90      * in the same range as the cluster.
91      */
92     @VisibleForTesting
cluster(RateDistortionCurve baseCurve)93     static RateDistortionCurve cluster(RateDistortionCurve baseCurve) {
94         if (baseCurve.points().size() < 3) {
95             return baseCurve;
96         }
97 
98         RateDistortionCurve.Builder newCurve = RateDistortionCurve.builder();
99 
100         LinkedList<ArrayList<RateDistortionPoint>> buckets = new LinkedList<>();
101 
102         // Bucket the items, moving through the points pairwise.
103         buckets.add(new ArrayList<>());
104         buckets.peekLast().add(baseCurve.points().first());
105 
106         Iterator<RateDistortionPoint> pointIterator = baseCurve.points().iterator();
107         RateDistortionPoint lastPoint = pointIterator.next();
108         RateDistortionPoint currentPoint;
109 
110         double maxObservedDistortion = lastPoint.distortion();
111         while (pointIterator.hasNext()) {
112             currentPoint = pointIterator.next();
113 
114             // Cluster points that are within 10% (bitrate) of each other that would make the curve
115             // non-monotonic.
116             if (currentPoint.rate() / lastPoint.rate() > 1.1
117                     || currentPoint.distortion() > maxObservedDistortion) {
118                 buckets.add(new ArrayList<>());
119                 maxObservedDistortion = currentPoint.distortion();
120             }
121             buckets.peekLast().add(currentPoint);
122             lastPoint = currentPoint;
123         }
124 
125         for (ArrayList<RateDistortionPoint> bucket : buckets) {
126             if (bucket.size() < 2) {
127                 newCurve.addPoint(bucket.get(0));
128             }
129 
130             // For a bucket with multiple points, the new point is the average
131             // between all other points.
132             newCurve.addPoint(
133                     RateDistortionPoint.create(
134                             MEAN.evaluate(bucket.stream().mapToDouble(p -> p.rate()).toArray()),
135                             MEAN.evaluate(
136                                     bucket.stream().mapToDouble(p -> p.distortion()).toArray())));
137         }
138 
139         return newCurve.build();
140     }
141 
142     /**
143      * Returns whether a {@link RateDistortionCurve} is monotonically increasing which is required
144      * for the Cubic Spline interpolation performed during BD rate calculation.
145      */
isMonotonicallyIncreasing(RateDistortionCurve rateDistortionCurve)146     private static boolean isMonotonicallyIncreasing(RateDistortionCurve rateDistortionCurve) {
147         Iterator<RateDistortionPoint> pointIterator = rateDistortionCurve.points().iterator();
148 
149         RateDistortionPoint lastPoint = pointIterator.next();
150         RateDistortionPoint currentPoint;
151         while (pointIterator.hasNext()) {
152             currentPoint = pointIterator.next();
153             if (currentPoint.distortion() <= lastPoint.distortion()) {
154                 return false;
155             }
156             lastPoint = currentPoint;
157         }
158 
159         return true;
160     }
161 }
162