xref: /aosp_15_r20/external/libaom/tools/gen_constrained_tokenset.py (revision 77c1e3ccc04c968bd2bc212e87364f250e820521)
1*77c1e3ccSAndroid Build Coastguard Worker#!/usr/bin/env python3
2*77c1e3ccSAndroid Build Coastguard Worker##
3*77c1e3ccSAndroid Build Coastguard Worker## Copyright (c) 2016, Alliance for Open Media. All rights reserved.
4*77c1e3ccSAndroid Build Coastguard Worker##
5*77c1e3ccSAndroid Build Coastguard Worker## This source code is subject to the terms of the BSD 2 Clause License and
6*77c1e3ccSAndroid Build Coastguard Worker## the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
7*77c1e3ccSAndroid Build Coastguard Worker## was not distributed with this source code in the LICENSE file, you can
8*77c1e3ccSAndroid Build Coastguard Worker## obtain it at www.aomedia.org/license/software. If the Alliance for Open
9*77c1e3ccSAndroid Build Coastguard Worker## Media Patent License 1.0 was not distributed with this source code in the
10*77c1e3ccSAndroid Build Coastguard Worker## PATENTS file, you can obtain it at www.aomedia.org/license/patent.
11*77c1e3ccSAndroid Build Coastguard Worker##
12*77c1e3ccSAndroid Build Coastguard Worker"""Generate the probability model for the constrained token set.
13*77c1e3ccSAndroid Build Coastguard Worker
14*77c1e3ccSAndroid Build Coastguard WorkerModel obtained from a 2-sided zero-centered distribution derived
15*77c1e3ccSAndroid Build Coastguard Workerfrom a Pareto distribution. The cdf of the distribution is:
16*77c1e3ccSAndroid Build Coastguard Workercdf(x) = 0.5 + 0.5 * sgn(x) * [1 - {alpha/(alpha + |x|)} ^ beta]
17*77c1e3ccSAndroid Build Coastguard Worker
18*77c1e3ccSAndroid Build Coastguard WorkerFor a given beta and a given probability of the 1-node, the alpha
19*77c1e3ccSAndroid Build Coastguard Workeris first solved, and then the {alpha, beta} pair is used to generate
20*77c1e3ccSAndroid Build Coastguard Workerthe probabilities for the rest of the nodes.
21*77c1e3ccSAndroid Build Coastguard Worker"""
22*77c1e3ccSAndroid Build Coastguard Worker
23*77c1e3ccSAndroid Build Coastguard Workerimport heapq
24*77c1e3ccSAndroid Build Coastguard Workerimport sys
25*77c1e3ccSAndroid Build Coastguard Workerimport numpy as np
26*77c1e3ccSAndroid Build Coastguard Workerimport scipy.optimize
27*77c1e3ccSAndroid Build Coastguard Workerimport scipy.stats
28*77c1e3ccSAndroid Build Coastguard Worker
29*77c1e3ccSAndroid Build Coastguard Worker
def cdf_spareto(x, xm, beta):
  """Evaluate the CDF of a zero-centred, two-sided Pareto-style distribution.

  Computes 0.5 + 0.5 * sgn(x) * (1 - (xm / (xm + |x|)) ** beta): each side
  of zero carries half of the probability mass with a Pareto-style tail.

  Args:
    x: Point(s) at which to evaluate; scalars and NumPy arrays both work
      because only element-wise NumPy operations are used.
    xm: Scale parameter (written as alpha in the module docstring).
    beta: Tail exponent.

  Returns:
    The CDF value(s) at x; exactly 0.5 at x == 0.
  """
  # Fraction of one half's mass lying beyond |x|.
  tail = (xm / (np.abs(x) + xm))**beta
  # Mirror around zero: positive x adds mass above 0.5, negative x removes it.
  return 0.5 + 0.5 * np.sign(x) * (1 - tail)
34*77c1e3ccSAndroid Build Coastguard Worker
35*77c1e3ccSAndroid Build Coastguard Worker
def get_spareto(p, beta):
  """Build the 11-bin magnitude distribution for a given 1-node probability.

  First searches for the scale alpha at which, under cdf_spareto with the
  given beta, P(|x| in [0.5, 1.5)) / P(|x| >= 0.5) is as close as possible
  to p. The search uses scipy.optimize.fminbound over alpha in
  [1e-12, 10000]. The resulting {alpha, beta} pair is then integrated over
  bins delimited by 0.5, 1.5, 2.5, 3.5, 4.5, 6.5, 10.5, 18.5, 34.5 and 66.5.
  Both signs are folded together, hence the factor 2 on every bin.

  Args:
    p: Target conditional probability of bin 1 given |x| >= 0.5; must lie
      in the open interval (0, 1).
    beta: Tail exponent of the distribution.

  Returns:
    A length-11 float NumPy array. Bin 0 is |x| < 0.5, bins 1..9 are the
    intervals between consecutive edges, and bin 10 is |x| >= 66.5. The
    entries sum to 1 up to rounding.

  Raises:
    ValueError: If p is not strictly between 0 and 1.
  """
  if not 0 < p < 1:
    raise ValueError('p must be in the open interval (0, 1), got %r' % (p,))
  cdf = cdf_spareto
  # Upper edges of bins 0..9; bin 10 is the open-ended tail beyond 66.5.
  edges = (0.5, 1.5, 2.5, 3.5, 4.5, 6.5, 10.5, 18.5, 34.5, 66.5)

  def func(x):
    # Squared error between the model's P(bin 1 | |x| >= 0.5) and p.
    return ((cdf(1.5, x, beta) - cdf(0.5, x, beta)) /
            (1 - cdf(0.5, x, beta)) - p)**2

  # NOTE(review): the ratio above simplifies to
  # 1 - ((alpha + 0.5) / (alpha + 1.5))**beta, so alpha has a closed form.
  # The numeric solve is kept so that previously generated tables are
  # reproduced unchanged.
  alpha = scipy.optimize.fminbound(func, 1e-12, 10000, xtol=1e-12)
  # Start from the CDF at 0 (0.5) and end at the total mass (1.0). The
  # differences of consecutive levels are then the mass of each bin on the
  # positive half-line.
  levels = [0.5] + [cdf(e, alpha, beta) for e in edges] + [1.]
  return np.array([2 * (hi - lo) for lo, hi in zip(levels[:-1], levels[1:])])
57*77c1e3ccSAndroid Build Coastguard Worker
58*77c1e3ccSAndroid Build Coastguard Worker
def quantize_probs(p, save_first_bin, bits):
  """Quantize probabilities to integer counts summing to 2**bits.

  Quantize probabilities minimizing dH (Kullback-Leibler divergence),
  approximated by: sum (p_i-q_i)^2/p_i.
  References:
  https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence
  https://github.com/JarekDuda/AsymmetricNumeralSystemsToolkit

  Procedure:
  - Each probability is scaled by L = 2**bits and rounded.
  - The result is clamped so every symbol gets at least 1 and at most
    L + 1 - num_sym.
  - If the rounded counts do not sum to L, counts are nudged one unit at a
    time, all in the same direction.
  - Each nudge goes to the symbol whose nudge increases the approximate
    divergence the least.

  Args:
    p: 1-D NumPy array of probabilities (normally a normalized
      distribution).
    save_first_bin: If true, element 0 is never adjusted after rounding.
    bits: Precision; the returned counts sum to 2**bits.

  Returns:
    A float NumPy array of integer-valued counts summing to 2**bits.

  Raises:
    ValueError: If no admissible adjustment remains before the counts reach
      2**bits. This can happen when p is far from normalized or too few
      symbols are adjustable, e.g. a single symbol.
  """
  num_sym = p.size
  # Floor tiny/zero probabilities so the inverse below stays finite.
  p = np.clip(p, 1e-16, 1)
  L = 2**bits
  pL = p * L
  ip = 1. / p  # inverse probability
  # Every symbol keeps a count >= 1, and no single count is large enough to
  # leave the other num_sym - 1 symbols below 1.
  q = np.clip(np.round(pL), 1, L + 1 - num_sym)
  quant_err = (pL - q)**2 * ip
  deficit = L - q.sum()
  sgn = np.sign(deficit)  # direction of correction
  if sgn != 0:  # correction is needed
    # Min-heap of (extra error if adjusted once more, index). Each symbol has
    # at most one pending entry at any time, so entries are never stale.
    v = []
    for i in range(1 if save_first_bin else 0, num_sym):
      q_adj = q[i] + sgn
      if 0 < q_adj < L:
        adj_err = (pL[i] - q_adj)**2 * ip[i] - quant_err[i]
        heapq.heappush(v, (adj_err, i))
    # Every step moves the total by exactly one unit towards L, so the number
    # of steps is known up front (no need to re-sum q each iteration).
    for _ in range(int(abs(deficit))):
      if not v:
        raise ValueError(
            'cannot quantize %d symbols to sum to %d' % (num_sym, L))
      # apply lowest error adjustment
      adj_err, i = heapq.heappop(v)
      quant_err[i] += adj_err
      q[i] += sgn
      # calculate the cost of adjusting this symbol again
      q_adj = q[i] + sgn
      if 0 < q_adj < L:
        adj_err = (pL[i] - q_adj)**2 * ip[i] - quant_err[i]
        heapq.heappush(v, (adj_err, i))
  return q
94*77c1e3ccSAndroid Build Coastguard Worker
95*77c1e3ccSAndroid Build Coastguard Worker
def get_quantized_spareto(p, beta, bits, first_token):
  """Return quantized bin probabilities for one 1-node probability.

  The zero bin of get_spareto's distribution is always dropped and the rest
  renormalized. What happens next depends on first_token:
  - first_token == 1: the first remaining bin is kept fixed during
    quantization.
  - first_token > 1: the next bin is also dropped and the rest
    renormalized, and no bin is kept fixed.
  - first_token < 1: behaves like 1 except that no bin is kept fixed.

  Args:
    p: Target 1-node probability, forwarded to get_spareto.
    beta: Tail exponent, forwarded to get_spareto.
    bits: Quantization precision, forwarded to quantize_probs.
    first_token: Selects how many leading bins are conditioned away (see
      above).

  Returns:
    The counts from quantize_probs cast to an integer NumPy array.
  """
  parray = get_spareto(p, beta)
  # Condition on a nonzero value: drop bin 0 and renormalize.
  parray = parray[1:] / (1 - parray[0])
  # CONFIG_NEW_TOKENSET
  # NOTE(review): this extra conditioning presumably corresponds to libaom's
  # CONFIG_NEW_TOKENSET experiment - confirm.
  if first_token > 1:
    # Additionally condition away the first nonzero bin.
    parray = parray[1:] / (1 - parray[0])
  qarray = quantize_probs(parray, first_token == 1, bits)
  # `np.int` was merely an alias for the builtin int. It was deprecated in
  # NumPy 1.20 and removed in 1.24, so use `int` directly.
  return qarray.astype(int)
104*77c1e3ccSAndroid Build Coastguard Worker
105*77c1e3ccSAndroid Build Coastguard Worker
def main(bits=15, first_token=1):
  """Print one initializer row per 8-bit 1-node probability.

  For every probability numerator / 256 with numerator in 1..255, the
  quantized distribution from get_quantized_spareto (beta fixed at 8) is
  checked to sum to 2**bits and printed as '{ a, b, ... },'.

  Args:
    bits: Quantization precision forwarded to get_quantized_spareto.
    first_token: Token-set variant forwarded to get_quantized_spareto.
  """
  beta = 8
  total = 2**bits
  for numerator in range(1, 256):
    row = get_quantized_spareto(numerator / 256., beta, bits, first_token)
    # Internal sanity check: quantization must preserve the total mass.
    assert row.sum() == total
    body = ', '.join('%d' % count for count in row)
    print('{', body, '},')
112*77c1e3ccSAndroid Build Coastguard Worker
113*77c1e3ccSAndroid Build Coastguard Worker
if __name__ == '__main__':
  # Optional positional integer arguments: bits, then first_token.
  # Any further arguments are ignored.
  main(*[int(arg) for arg in sys.argv[1:3]])
121