# xref: /aosp_15_r20/external/rappor/analysis/R/decode_ngrams.R (revision 2abb31345f6c95944768b5222a9a5ed3fc68cc00)
# Copyright 2014 Google Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

#
# This file contains functions that aid in estimating a distribution when the
#     dictionary is unknown. There are functions for estimating pairwise joint
#     ngram distributions, for pruning out false positives, and for combining
#     the two steps.

FindPairwiseCandidates <- function(report_data, N, ngram_params, params) {
  # Finds the pairwise most likely ngrams.
  #
  # Args:
  #   report_data: Object containing data relevant to reports:
  #       $inds: The indices of reports collected using various pairs
  #       $cohorts: The cohort of each report
  #       $map: The map used for all the ngrams
  #       $reports: The reports used for each ngram and full string
  #   N: Number of reports collected. Not used by the current fixed
  #       threshold (see the note on the threshold below).
  #   ngram_params: Parameters related to ngram size; only
  #       $num_ngrams_collected is read here.
  #   params: Parameter list, forwarded to ComputeDistributionEM.
  #
  # Returns:
  #   - list(found_candidates = list(), dists = dists) if any pairwise
  #     distribution came back NULL.
  #   - NULL if any pairing has no candidate ngram pair left after
  #     thresholding and removing "Other".
  #   - Otherwise list(candidate_strs, dists), where candidate_strs[[i]] is a
  #     character matrix with one candidate ngram tuple per row, and dists is
  #     the list of values returned by ComputeDistributionEM.

  inds <- report_data$inds
  cohorts <- report_data$cohorts
  num_ngrams_collected <- ngram_params$num_ngrams_collected
  map <- report_data$map
  reports <- report_data$reports

  # The same map is used for every ngram position.
  maps <- lapply(seq_len(num_ngrams_collected), function(x) map)
  num_candidate_ngrams <- length(inds)

  # Everything is passed explicitly (rather than captured) so that the helper
  # stays easy to ship to workers if this is ever parallelized.
  .ComputeDist <- function(i, inds, cohorts, reports, maps, params,
                           num_ngrams_collected) {
    library(glmnet)
    ind <- inds[[i]]
    cohort_subset <- lapply(seq_len(num_ngrams_collected), function(x)
                            cohorts[ind])
    report_subset <- reports[[i]]
    ComputeDistributionEM(report_subset,
                          cohort_subset,
                          maps, ignore_other = FALSE,
                          params = params, estimate_var = FALSE)
  }

  # Compute the pairwise distributions (could be parallelized)
  dists <- lapply(seq_len(num_candidate_ngrams), function(i)
                  .ComputeDist(i, inds, cohorts, reports, maps,
                               params, num_ngrams_collected))

  dists_null <- vapply(dists, is.null, logical(1))
  if (any(dists_null)) {
    return (list(found_candidates = list(), dists = dists))
  }
  cat("Found the pairwise ngram distributions.\n")

  # Threshold for choosing "significant" ngram pairs.
  # NOTE: a noise-derived threshold used to be computed here as
  #     sqrt(p2 * (1 - p2) * N) / (q2 - p2) / N
  # with q2 = .5 * f * (p + q) + (1 - f) * q and
  #      p2 = .5 * f * (p + q) + (1 - f) * p,
  # but it was immediately overwritten by the constant below, so only the
  # constant has ever been in effect. The dead computation was dropped.
  threshold <- 0.04

  # Filter joints to remove infrequently co-occurring ngrams.
  candidate_strs <- lapply(dists, function(dist) {
    fit <- dist$fit
    # One row per cell of `fit` above the threshold, one column per dimension.
    edges <- which(fit > threshold, arr.ind = TRUE, useNames = FALSE)
    ngram_names <- dimnames(fit)

    # Translate indices into ngram names. cbind() always yields a matrix,
    # including the 0-row and 1-row cases, so no special-casing is needed.
    cands <- do.call(cbind, lapply(seq_len(ncol(edges)), function(d) {
      ngram_names[[d]][edges[, d]]
    }))

    # Remove every row that mentions the catch-all "Other" category.
    has_other <- rowSums(cands == "Other", na.rm = TRUE) > 0
    # drop = FALSE necessary to keep it a matrix
    cands[!has_other, , drop = FALSE]
  })

  # If some pairing has no surviving candidates, no full string can be
  # consistent with all pairings. (This check previously inspected an unrelated,
  # always-empty variable and therefore never triggered.)
  if (any(vapply(candidate_strs, length, integer(1)) == 0)) {
    return (NULL)
  }

  list(candidate_strs = candidate_strs, dists = dists)
}
119
FindFeasibleStrings <- function(found_candidates, pairings, num_ngrams,
                                ngram_size) {
  # Uses the list of strings found by the pairwise comparisons to build
  #     a list of full feasible strings. This relies on the iterative,
  #     graph-based approach.
  #
  # Args:
  #   found_candidates: list of candidates found by each pairwise decoding;
  #       element i corresponds to row i of pairings.
  #   pairings: Matrix with one row per ngram position pairing and two
  #       columns holding the 1-based character positions of the two ngrams.
  #   num_ngrams: The total number of ngrams per word.
  #   ngram_size: Number of characters per ngram
  #
  # Returns:
  #   Sorted, unnamed character vector of full string candidates. An empty
  #   list() is returned by the early exits (first adjacent pairing is a plain
  #   list, or no candidates survive the adjacent-pair phase); character(0) is
  #   returned if the candidates are eliminated while pruning.

  # Character positions of the adjacent ngram pairs, i.e. of the form
  #     (i, i+1); one column per pair.
  adjacent <- vapply(seq_len(num_ngrams - 1), function(x) {
    c(1 + (x - 1) * ngram_size, x * ngram_size + 1)
  }, numeric(2))

  # Row index in pairings of each adjacent pair. Values are compared with ==
  # rather than identical() so that an integer-typed or named pairings matrix
  # still matches the double-typed positions computed above.
  adjacent_pairs <- apply(adjacent, 2, function(x) {
    which(apply(pairings, 1, function(y) {
      length(y) == length(x) && all(y == x)
    }))
  })

  # The first set of candidates are ngrams found in positions 1 and 2
  active_cands <- found_candidates[[adjacent_pairs[1]]]
  # inherits() instead of class(x) == "list": for a matrix, class() has
  # length 2 since R 4.0, which makes the old comparison an invalid
  # if-condition.
  if (inherits(active_cands, "list")) {
    return (list())
  }
  active_cands <- as.data.frame(active_cands)

  # Now check successive ngrams to find consistent combinations
  #     i.e. after ngrams 1-2, check 2-3, 3-4, 4-5, etc.
  # seq_along(...)[-1] is empty when there is only one adjacent pair, unlike
  # 2:length(...), which would count down as c(2, 1).
  for (i in seq_along(adjacent_pairs)[-1]) {
    if (nrow(active_cands) == 0) {
      return (list())
    }
    new_cands <- as.data.frame(found_candidates[[adjacent_pairs[i]]])
    # Builds the set of possible candidates based only on ascending
    #     candidate pairs
    active_cands <- BuildCandidates(active_cands, new_cands)
  }

  if (nrow(active_cands) == 0) {
    return (list())
  }

  # Now refine these candidates using non-adjacent bigrams
  num_pairings <- num_ngrams * (num_ngrams - 1) / 2
  remaining <- seq_len(num_pairings)[-c(1, adjacent_pairs)]

  # For each non-adjacent pair, make sure that all the candidates are
  #     consistent (in this phase, candidates can ONLY be eliminated)
  for (i in remaining) {
    # Nothing left to prune; stop instead of handing an empty data frame on.
    if (nrow(active_cands) == 0) {
      break
    }
    new_cands <- as.data.frame(found_candidates[[i]])
    # Prune out all candidates that do not agree with new_cands
    active_cands <- PruneCandidates(active_cands, pairings[i, ],
                                    ngram_size,
                                    new_cands = new_cands)
  }

  # Consolidate the string ngrams into a full string representation
  if (length(active_cands) > 0) {
    active_cands <- sort(apply(active_cands, 1,
                               function(x) paste0(x, collapse = "")))
  }
  unname(active_cands)
}
189
BuildCandidates <- function(active_cands, new_cands) {
  # Takes in a data frame where each row is a valid sequence of ngrams and
  #     extends each sequence by one ngram, using the new_cands pairs whose
  #     first ngram equals the sequence's trailing ngram.
  #
  # Args:
  #   active_cands: data frame of ngram sequence candidates (1 candidate
  #       sequence per row)
  #   new_cands: An rx2 data frame with a new list of candidate ngram
  #       pairs that might fit in with the previous list of candidates
  #
  # Returns:
  #   Updated active_cands with one more column (named "col"). Sequences with
  #       no matching pair are dropped; a sequence with several matching pairs
  #       appears once per match, the extra copies being appended at the
  #       bottom.

  # Get the trailing ngrams from the current candidates
  to_check <- as.vector(tail(t(active_cands), n = 1))
  # Cheap pre-filter: keep only trailing ngrams that occur anywhere in
  #     new_cands. Whether they occur as a *leading* ngram is checked below.
  present <- vapply(to_check, function(x) any(x == new_cands), logical(1))
  to_check <- to_check[present]

  # Now insert the new candidates where they belong
  active_cands <- active_cands[present, , drop = FALSE]
  active_cands <- cbind(active_cands, col = NA)
  num_cands <- nrow(active_cands)
  new_col <- ncol(active_cands)
  hit_list <- integer(0)

  for (j in seq_len(num_cands)) {
    inds <- which(new_cands[, 1] == to_check[j])
    if (length(inds) == 0) {
      # Trailing ngram only occurred as a second element: a false candidate.
      hit_list <- c(hit_list, j)
      next
    }
    # If there are multiple candidates fitting with an ngram, include
    #     each /full/ string as a candidate
    extra <- length(inds) - 1
    rep_inds <- j
    if (extra > 0) {
      # Copies go to the bottom; their row indices start after the current
      # last row. (The previous code read a counter here that was not yet
      # assigned on the first pass, which raised "object not found".)
      rep_inds <- c(j, nrow(active_cands) + seq_len(extra))
      active_cands <- rbind(active_cands,
                            active_cands[rep(j, extra), , drop = FALSE])
    }
    active_cands[rep_inds, new_col] <- as.vector(new_cands[inds, 2])
  }

  # If there were some false candidates in the original set, remove them
  if (length(hit_list) > 0) {
    active_cands <- active_cands[-hit_list, , drop = FALSE]
  }
  active_cands
}
246
PruneCandidates <- function(active_cands, pairing, ngram_size, new_cands) {
  # Takes in a data frame where each row is a valid sequence of ngrams and
  #     checks which rows are consistent with the new_cands ngram pairs.
  #     This can ONLY remove candidates presented in active_cands.
  #
  # Args:
  #   active_cands: data frame of ngram sequence candidates (1 candidate
  #       sequence per row)
  #   pairing: Length-2 collection with the 1-based character positions of
  #       the two ngrams being measured
  #   ngram_size: Number of characters per ngram
  #   new_cands: An rx2 data frame with a new list of candidate ngram
  #       pairs that might fit in with the previous list of candidates
  #
  # Returns:
  #   Updated active_cands, with a reduced number of rows. A row is kept only
  #       if its two ngrams at the paired positions equal some row of
  #       new_cands. An empty active_cands is returned unchanged.

  # Convert the pairing to an ngram index
  cols <- vapply(pairing, function(x) (x - 1) / ngram_size + 1, numeric(1))

  cands_to_check <- active_cands[, cols, drop = FALSE]
  num_to_check <- nrow(cands_to_check)

  # Find the candidates that are inconsistent with the new data
  if (nrow(new_cands) == 0) {
    # Nothing to agree with: every candidate is inconsistent.
    hit_list <- rep(TRUE, num_to_check)
  } else {
    # seq_len() keeps this a no-op for a 0-row input; 1:nrow(...) would have
    # visited rows 1 and 0 and compared against NA.
    hit_list <- vapply(seq_len(num_to_check), function(j) {
      row_j <- cands_to_check[j, , drop = FALSE]
      matched <- any(apply(new_cands, 1, function(x) all(row_j == x)))
      # Anything that is not a definite match is pruned.
      !isTRUE(matched)
    }, logical(1))
  }

  # Determine which rows are false positives
  hit_indices <- which(hit_list)
  # Remove the false positives
  if (length(hit_indices) > 0) {
    active_cands <- active_cands[-hit_indices, , drop = FALSE]
  }
  active_cands
}
289
EstimateDictionary <- function(report_data, N, ngram_params, params) {
  # Estimates the dictionary of full strings from a list of report data.
  #
  # Args:
  #   report_data: Object containing data relevant to reports:
  #       $inds: The indices of reports collected using various pairs
  #       $cohorts: The cohort of each report
  #       $map: The map used for all the ngrams
  #       $reports: The reports used for each ngram and full string
  #       $pairings: The ngram position pairings, handed to
  #           FindFeasibleStrings
  #   N: the number of individuals sending reports
  #   ngram_params: Parameters related to ngram length, etc; $num_ngrams and
  #       $ngram_size are read here.
  #   params: Parameter vector with RAPPOR noise levels, cohorts, etc
  #
  # Returns:
  #   An empty list() when no pairwise candidates are available; otherwise
  #   list(found_candidates, pairwise_candidates).

  # Step 1: pairwise decoding. Only the candidate strings are needed here.
  pairwise_result <- FindPairwiseCandidates(report_data, N, ngram_params,
                                            params)
  pair_cands <- pairwise_result$candidate_strs
  cat("Found the pairwise candidates. \n")
  if (is.null(pair_cands)) {
    return (list())
  }

  # Step 2: stitch the pairwise candidates into full strings.
  full_strings <- FindFeasibleStrings(pair_cands,
                                      report_data$pairings,
                                      ngram_params$num_ngrams,
                                      ngram_params$ngram_size)
  cat("Found all the candidates. \n")

  list(found_candidates = full_strings,
       pairwise_candidates = pair_cands)
}
322
WriteKPartiteGraph <- function(conn, pairwise_candidates, pairings, num_ngrams,
                               ngram_size) {
  # Writes the pairwise candidates to a connection as a k-partite graph in
  # text form, and prints one progress line per subgraph to stdout.
  #
  # Args:
  #  conn: R connection to write to.  Should be opened with mode w+.
  #  pairwise_candidates: list of matrices.  Each matrix represents a subgraph;
  #    it contains the edges between partitions i and j, so there are (k choose
  #    2) matrices.  Each matrix has one row per edge and two columns (the
  #    ngram in partition i, the ngram in partition j).
  #  pairings: (k choose 2) x 2 matrix of character positions.  Each row
  #    corresponds to a subgraph; it has 1-based character index of partitions
  #    i and j.
  #  num_ngrams: the number of partitions in the k-partite graph; written to
  #    the num_partitions header line.
  #  ngram_size: number of characters per ngram; written to the header and
  #    used to turn character positions into partition numbers.
  #
  # Returns:
  #  NULL, invisibly.

  # File Format:
  #
  # num_partitions 3
  # ngram_size 2
  # edge 0.ab 1.cd
  # edge 0.ab 2.ef
  #
  # The first two lines are metadata: the number of partitions (k) and the
  # ngram size.
  # The remaining lines are edges, where each node is <partition>.<ngram>.
  #
  # Partitions are numbered from 0.  The partition of the left node will be
  # less than the partition of the right node.

  # First two lines are metadata
  cat(sprintf("num_partitions %d\n", num_ngrams), file = conn)
  cat(sprintf("ngram_size %d\n", ngram_size), file = conn)

  for (i in seq_along(pairwise_candidates)) {
    # The two partitions for this subgraph.
    # Turn 1-based character positions into 0-based partition numbers,
    # e.g. (3, 5) -> (1, 2)
    pos1 <- pairings[[i, 1]]
    pos2 <- pairings[[i, 2]]
    part1 <- (pos1 - 1) / ngram_size
    part2 <- (pos2 - 1) / ngram_size
    cat(sprintf("Writing partition (%d, %d)\n", part1, part2))

    p <- pairwise_candidates[[i]]
    # A subgraph without edges contributes no lines (1:nrow(p) would have
    # indexed a non-existent first row here).
    if (is.null(p) || nrow(p) == 0) {
      next
    }
    # each row is an edge; sprintf is vectorized, so format them all at once
    # and write them with a single call.
    edge_lines <- sprintf("edge %d.%s %d.%s\n",
                          part1, as.character(p[, 1]),
                          part2, as.character(p[, 2]))
    cat(edge_lines, file = conn, sep = "")
  }
  invisible(NULL)
}
377
378