#!/usr/bin/env python3
##
## Copyright (c) 2016, Alliance for Open Media. All rights reserved.
##
## This source code is subject to the terms of the BSD 2 Clause License and
## the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
## was not distributed with this source code in the LICENSE file, you can
## obtain it at www.aomedia.org/license/software. If the Alliance for Open
## Media Patent License 1.0 was not distributed with this source code in the
## PATENTS file, you can obtain it at www.aomedia.org/license/patent.
##
"""Generate the probability model for the constrained token set.

The model is obtained from a two-sided, zero-centered distribution derived
from a Pareto distribution. The cdf of the distribution is:

    cdf(x) = 0.5 + 0.5 * sgn(x) * [1 - {alpha / (alpha + |x|)} ^ beta]

For a given beta and a given probability of the 1-node, alpha is solved for
first, and the {alpha, beta} pair is then used to generate the probabilities
for the rest of the nodes.

Usage: <script> [bits [first_token]]   (defaults: bits=15, first_token=1)
"""

import heapq
import sys
import numpy as np
import scipy.optimize
import scipy.stats

# Half-integer upper boundaries of the magnitude bins produced by get_spareto.
# Together with 0 and +inf they delimit the bins for the magnitudes
# 0, 1, 2, 3, 4, 5-6, 7-10, 11-18, 19-34, 35-66 and 67+.
_SPARETO_BIN_EDGES = (0.5, 1.5, 2.5, 3.5, 4.5, 6.5, 10.5, 18.5, 34.5, 66.5)


def cdf_spareto(x, xm, beta):
    """Return the symmetric-Pareto cdf at `x` for scale `xm` and shape `beta`.

    Implements cdf(x) = 0.5 + 0.5 * sgn(x) * [1 - (xm / (|x| + xm)) ** beta],
    so cdf(0) == 0.5. Works element-wise when `x` is a numpy array.
    """
    p = 1 - (xm / (np.abs(x) + xm))**beta
    p = 0.5 + 0.5 * np.sign(x) * p
    return p


def get_spareto(p, beta):
    """Fit alpha to the 1-node probability `p` and return the bin probabilities.

    alpha is found by bounded scalar minimisation over [1e-12, 10000] so that
    (cdf(1.5) - cdf(0.5)) / (1 - cdf(0.5)) equals `p`. That is, `p` is the
    probability of the 1 bin given that the magnitude is at least 1.

    Args:
        p: Target conditional probability of the 1-node.
        beta: Pareto shape parameter.

    Returns:
        A length-11 float array holding the probabilities of the magnitudes
        0, 1, 2, 3, 4, 5-6, 7-10, 11-18, 19-34, 35-66 and 67+. Because the
        bins telescope from cdf(0) to cdf(+inf), the entries sum to 1 up to
        rounding.
    """
    cdf = cdf_spareto

    def func(x):
        # Squared gap between the 1-node probability implied by alpha == x
        # and the requested `p`; fminbound drives this to zero.
        return ((cdf(1.5, x, beta) - cdf(0.5, x, beta)) /
                (1 - cdf(0.5, x, beta)) - p)**2

    alpha = scipy.optimize.fminbound(func, 1e-12, 10000, xtol=1e-12)
    # cdf(0) == 0.5 and cdf(+inf) == 1, so bracketing the edge cdfs with those
    # values makes every bin a difference of adjacent cdfs. The factor 2 adds
    # the mirror-image mass on the negative side (the distribution is
    # symmetric).
    cdfs = np.array([0.5] + [cdf(x, alpha, beta) for x in _SPARETO_BIN_EDGES] +
                    [1.])
    return 2 * np.diff(cdfs)


def quantize_probs(p, save_first_bin, bits):
    """Quantize probability precisely.

    Quantize probabilities minimizing dH (Kullback-Leibler divergence)
    approximated by: sum (p_i-q_i)^2/p_i.
    References:
      https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence
      https://github.com/JarekDuda/AsymmetricNumeralSystemsToolkit

    Each scaled probability is first rounded (and clipped to at least 1).
    The rounded values are then moved one unit at a time, always choosing
    the step that increases the weighted error least, until they sum to
    exactly 2**bits.

    Args:
        p: 1-D numpy array of probabilities. Values are clipped to
            [1e-16, 1] before use.
        save_first_bin: If true, symbol 0 keeps its rounded value and is
            never adjusted by the sum-correction pass.
        bits: The returned values sum to 2**bits.

    Returns:
        Float numpy array of integral values, each at least 1, summing to
        2**bits.

    Raises:
        ValueError: If the sum cannot be brought to 2**bits with admissible
            single-unit adjustments (for example when `p` is empty).
    """
    num_sym = p.size
    p = np.clip(p, 1e-16, 1)
    L = 2**bits
    pL = p * L
    ip = 1. / p  # inverse probability: weight of each squared error
    q = np.clip(np.round(pL), 1, L + 1 - num_sym)
    quant_err = (pL - q)**2 * ip
    sgn = np.sign(L - q.sum())  # +1: total too small, -1: total too large
    if sgn == 0:  # rounding already hit the target sum
        return q

    heap = []  # (error increase, symbol index) of each admissible next step

    def push_candidate(i):
        # Queue the cost of moving q[i] one more unit by sgn, provided the
        # result stays strictly inside (0, L).
        q_adj = q[i] + sgn
        if 0 < q_adj < L:
            adj_err = (pL[i] - q_adj)**2 * ip[i] - quant_err[i]
            heapq.heappush(heap, (adj_err, i))

    for i in range(1 if save_first_bin else 0, num_sym):
        push_candidate(i)
    while q.sum() != L:
        if not heap:
            raise ValueError(
                'cannot adjust quantized probabilities to sum to %d' % L)
        # apply lowest error adjustment
        adj_err, i = heapq.heappop(heap)
        quant_err[i] += adj_err
        q[i] += sgn
        # the same symbol may be adjusted again; queue its next cost
        push_candidate(i)
    return q


def get_quantized_spareto(p, beta, bits, first_token):
    """Return the quantized model for one 1-node probability `p`.

    Args:
        p: Target 1-node probability, passed to get_spareto.
        beta: Pareto shape parameter.
        bits: The returned values sum to 2**bits.
        first_token: The zero bin is always dropped and the rest
            renormalized. If first_token > 1, the next bin is dropped and the
            rest renormalized again. Only when first_token == 1 is the first
            remaining bin protected from the sum correction in
            quantize_probs.

    Returns:
        Integer numpy array summing to 2**bits: 10 entries when
        first_token <= 1, 9 entries otherwise.
    """
    parray = get_spareto(p, beta)
    parray = parray[1:] / (1 - parray[0])
    # CONFIG_NEW_TOKENSET
    if first_token > 1:
        parray = parray[1:] / (1 - parray[0])
    qarray = quantize_probs(parray, first_token == 1, bits)
    # `np.int` was only a deprecated alias of the builtin int and was removed
    # in NumPy 1.24; `int` selects the same dtype.
    return qarray.astype(int)


def main(bits=15, first_token=1):
    """Print one `{ ... },` row per 1-node probability q/256, q = 1..255.

    The row format presumably targets a C array initializer; confirm with the
    consuming table.
    """
    beta = 8
    for q in range(1, 256):
        parray = get_quantized_spareto(q / 256., beta, bits, first_token)
        # Explicit check rather than `assert` so it still runs under -O.
        if parray.sum() != 2**bits:
            raise RuntimeError('quantized model for q=%d does not sum to %d' %
                               (q, 2**bits))
        print('{', ', '.join('%d' % i for i in parray), '},')


if __name__ == '__main__':
    # Optional positional arguments: bits, then first_token; extras ignored.
    main(*(int(arg) for arg in sys.argv[1:3]))