import copy
import enum
import pprint
import unittest
from enum import Enum

# Importing these files make modifications to the op_db that we need
import test_ops  # noqa: F401
import test_vmap  # noqa: F401
from functorch_additional_op_db import additional_op_db

import torch
import torch._functorch.top_operators_github_usage as top_ops
from torch.testing._internal.common_device_type import toleranceOverride
from torch.testing._internal.common_methods_invocations import op_db


all_overridable = list(torch.overrides.get_testing_overrides().keys())

# (module object, dotted module name, docs .rst file relative to the PyTorch root)
public_docs = [
    (torch.nn.functional, "torch.nn.functional", "docs/source/nn.functional.rst"),
    (torch.fft, "torch.fft", "docs/source/fft.rst"),
    (torch.special, "torch.special", "docs/source/special.rst"),
    (torch.linalg, "torch.linalg", "docs/source/linalg.rst"),
    (torch, "torch", "docs/source/torch.rst"),
    (torch.Tensor, "torch.Tensor", "docs/source/tensors.rst"),
]


# torch.abs, Tensor.abs, Tensor.abs_ are all considered to be different
def get_public_overridable_apis(pytorch_root="/raid/rzou/pt/debug-cpu"):
    """Return documented public APIs that appear in torch.overrides' testing overrides.

    Reads each .rst file listed in ``public_docs`` under ``pytorch_root`` and
    collects API names that are either indented by 4 spaces or listed via
    ``.. autofunction::``. A name is kept only if the corresponding module
    has that attribute and the attribute object is a key of
    ``torch.overrides.get_testing_overrides()``.

    NOTE(review): the default ``pytorch_root`` is a developer-specific path;
    callers on other machines must pass their own checkout location.

    Returns:
        dict mapping dotted names (e.g. ``"torch.Tensor.abs"``) to API objects.
    """
    results = {}
    all_overridable_apis = set(torch.overrides.get_testing_overrides().keys())
    for module, module_name, src in public_docs:
        with open(f"{pytorch_root}/{src}", encoding="utf-8") as f:
            lines = f.readlines()

        # APIs either begin with 4 spaces or ".. autofunction::"
        api_lines1 = [line.strip() for line in lines if line.startswith(" " * 4)]
        api_lines2 = [
            line.strip()[len(".. autofunction:: ") :]
            for line in lines
            if line.startswith(".. autofunction::")
        ]
        lines = api_lines1 + api_lines2
        # tensors.rst lists methods as "Tensor.foo"; strip that qualifier.
        lines = [line[7:] if line.startswith("Tensor.") else line for line in lines]
        lines = [line for line in lines if hasattr(module, line)]
        for line in lines:
            api = getattr(module, line)
            if api in all_overridable_apis:
                results[f"{module_name}.{line}"] = api
    return results


# Fully-qualified Tensor APIs that are excluded from the analysis.
denylist = {
    "torch.Tensor.data_ptr",
    "torch.Tensor.dim",
    "torch.Tensor.element_size",
    "torch.Tensor.backward",
    "torch.Tensor.as_strided",
    "torch.Tensor.register_hook",
    "torch.Tensor.record_stream",
    "torch.Tensor.qscheme",
    "torch.Tensor.ndimension",
    "torch.Tensor.smm",
    "torch.Tensor.sspaddmm",
    "torch.Tensor.retain_grad",
    "torch.Tensor.sparse_mask",
    "torch.Tensor.sparse_dim",
    "torch.Tensor.dense_dim",
    "torch.Tensor.values",
    "torch.Tensor.indices",
    "torch.Tensor.numel",
    "torch.Tensor.size",
    "torch.Tensor.nelement",
    "torch.Tensor.q_scale",
    "torch.Tensor.q_zero_point",
    "torch.Tensor.q_per_channel_scales",
    "torch.Tensor.q_per_channel_zero_points",
    "torch.Tensor.q_per_channel_axis",
    "torch.Tensor.int_repr",
    "torch.Tensor.to_sparse",
    "torch.Tensor.is_inference",
    "torch.Tensor.storage",
    "torch.Tensor.storage_type",
}


def get_method_only_ops_we_care_about():
    """Return out-of-place Tensor method names that have no ``torch.<name>`` twin."""
    apis = get_public_overridable_apis()

    result = []
    for key in apis:
        if not key.startswith("torch.Tensor"):
            continue
        if key in denylist:
            continue
        api = key.split(".")[2]
        # filter out in-place
        if api.endswith("_"):
            continue
        if f"torch.{api}" not in apis:
            result.append(api)
    return result


# Deduplicates torch.abs and Tensor.abs
def get_public_overridable_ops():
    """Return public overridable APIs with Tensor methods dropped when ``torch.<name>`` exists."""
    results = get_public_overridable_apis()
    # Iterate over a snapshot of the keys because we delete from the dict.
    for key in list(results):
        if not key.startswith("torch.Tensor"):
            continue
        api = key.split(".")[2]
        if f"torch.{api}" in results:
            del results[key]
    return results


def get_public_overridable_outplace_ops():
    """Return ``get_public_overridable_ops()`` without names ending in "_" (in-place ops)."""
    results = get_public_overridable_ops()
    for key in list(results):
        # NB: there are no dunder methods bcs we don't document those
        if key.endswith("_"):
            del results[key]
    return results


def get_public_overridable_outplace_we_care_about():
    """Return out-of-place public ops minus quantization, ``is_*`` predicates and ``denylist``."""
    results = get_public_overridable_outplace_ops()
    for key in list(results):
        # quantization
        is_quantization = "quant" in key or ".q_" in key
        # is_cpu, etc. It doesn't make sense to have OpInfos for these
        is_predicate = ".is_" in key
        # A single combined check: deleting once per matching rule used to raise
        # KeyError for keys matching several rules (e.g. "...is_quantized").
        if is_quantization or is_predicate or key in denylist:
            del results[key]
    return results


# e.g. nn.functional.softmax
def get_op(dotted_name):
    """Resolve a dotted name relative to ``torch``; return None if any part is missing."""
    names = dotted_name.split(".")
    mod = torch
    for name in names:
        if not hasattr(mod, name):
            return None
        mod = getattr(mod, name)
    return mod


# Maps function -> [OpInfo]
def get_ops_covered_by_opinfos():
    """Map every callable reachable from an OpInfo in ``op_db`` to the OpInfos covering it.

    The function, method, in-place variants and alias ops are all used as keys.
    """
    ops = {}
    for opinfo in op_db:
        func_op = get_op(opinfo.name)
        if func_op:
            ops.setdefault(func_op, []).append(opinfo)
        if opinfo.method_variant:
            ops.setdefault(opinfo.method_variant, []).append(opinfo)
        if opinfo.inplace_variant:
            ops.setdefault(opinfo.inplace_variant, []).append(opinfo)
        for alias in opinfo.aliases:
            ops.setdefault(alias.op, []).append(opinfo)
    return ops


factory_fns = {
    "tensor",
    "zeros",
    "ones",
    "randn",
    "arange",
    "rand",
    "empty",
    "randperm",
    "linspace",
    "logspace",
    "hann_window",
    "full",
    "eye",
    "blackman_window",
    "bartlett_window",
    "randint",
    "range",
}


def get_top_ops(torch_threshold, nn_fn_threshold, with_counts=False):
    """Return the most-used torch / nn.functional ops from the GitHub usage data.

    Takes the first ``torch_threshold`` entries of ``top_ops.top_torch`` and the
    first ``nn_fn_threshold`` entries of the nn.functional list (after removing
    ``top_ops_denylist``), then sorts the union by usage count, descending.

    Args:
        torch_threshold: number of torch.* entries to keep.
        nn_fn_threshold: number of nn.functional.* entries to keep.
        with_counts: if True return the raw (name, count) entries, else names only.
    """
    # Named distinctly so it is not confused with the module-level ``denylist``.
    top_ops_denylist = {
        # These are either not real "operators", factory functions
        # that trivially work, or not-documented ops.
        "load",
        "no_grad",
        "save",
        "from_numpy",
        "manual_seed",
        "set_grad_enabled",
        "set_default_tensor_type",
        "set_num_threads",
        "set_printoptions",
        "numel",
        "set_default_dtype",
        "sparse_coo_tensor",
        "set_rng_state",
        "get_rng_state",
        "get_default_dtype",
        "initial_seed",
        "get_num_threads",
        "quantize_per_tensor",
        "hann_window",
        "is_tensor",
        "as_tensor",
        "equal",
        "enable_grad",
        "seed",
        "is_storage",
        "is_floating_point",
        "nn.functional.torch",
        "set_flush_denormal",
        "set_num_interop_threads",
        "dequantize",
        "get_num_interop_threads",
        "nn.functional.math",
        "nn.functional.threshold_",
        "nn.functional.selu_",
        "nn.functional.elu_",
        "nn.functional.rrelu_",
        "nn.functional.leaky_relu_",
        "nn.functional.hardtanh_",
        "nn.functional.has_torch_function",
        "nn.functional.has_torch_function_unary",
        "nn.functional.has_torch_function_variadic",
        "nn.functional.handle_torch_function",
        "nn.functional.adaptive_max_pool1d_with_indices",
        "nn.functional.adaptive_max_pool2d_with_indices",
        "nn.functional.adaptive_max_pool3d_with_indices",
        "nn.functional.fractional_max_pool2d_with_indices",
        "nn.functional.fractional_max_pool3d_with_indices",
        "is_complex",
        "grad",
        "quantize_per_channel",
        "nn.functional.max_pool2d_with_indices",
        "nn.functional.max_pool3d_with_indices",
        "nn.functional.max_pool1d_with_indices",
        "nn.functional.celu_",
        "nn.functional.grad",
        "nn.functional.relu_",
        "nn.functional.boolean_dispatch",
        "nn.functional.assert_int_or_pair",
        "fft",  # is namespace
    }

    torch_ops = top_ops.top_torch
    nn_fn_ops = top_ops.get_nn_functional_top_list()
    torch_ops = [op for op in torch_ops if op[0] not in top_ops_denylist]
    nn_fn_ops = [op for op in nn_fn_ops if op[0] not in top_ops_denylist]

    ops = torch_ops[:torch_threshold] + nn_fn_ops[:nn_fn_threshold]

    # Now, sort by priority
    ops.sort(reverse=True, key=lambda op: op[1])
    if not with_counts:
        ops = [op[0] for op in ops]
    return ops


def get_ops_percentage(torch_threshold, nn_fn_threshold):
    """Return the fraction of all (non-denylisted) GitHub usages covered by the top-N ops."""
    data = top_ops.top_torch + top_ops.get_nn_functional_top_list()

    def get_num_usages(opname):
        # Ignore this, this is heavily inflated
        if opname == "t":
            return 0
        result = [op[1] for op in data if op[0] == opname]
        assert len(result) == 1
        return result[0]

    # get all operators that are not in the denylist
    all_ops = get_top_ops(999999, 999999)
    total_op_usages = sum(get_num_usages(op) for op in all_ops)

    # get subset of all operators
    subset_ops = get_top_ops(torch_threshold, nn_fn_threshold)
    subset_op_usages = sum(get_num_usages(op) for op in subset_ops)
    return subset_op_usages / total_op_usages


def get_top_ops_not_covered_by_opinfo(torch_threshold=0, nn_fn_threshold=0):
    """Return top-used op names that have no OpInfo (by name or alias) and are not factories."""
    ops = get_top_ops(torch_threshold, nn_fn_threshold)

    ops_with_opinfo = set()
    for opinfo in op_db:
        ops_with_opinfo.add(opinfo.name)
        ops_with_opinfo.update(alias.name for alias in opinfo.aliases)

    result = [op for op in ops if op not in ops_with_opinfo]
    # NOTE(review): module-level ``denylist`` holds "torch.Tensor.*" names while
    # ``ops`` are unprefixed names, so this filter appears to never match.
    result = [op for op in result if op not in denylist]
    result = [op for op in result if op not in factory_fns]
    return result


def get_covered_ops(ops_list, invert=False):
    """Filter ``ops_list`` (name -> op) to ops that have an OpInfo (or lack one if ``invert``)."""
    ops_covered_by_opinfo = get_ops_covered_by_opinfos()
    results = {}
    for key, op in ops_list.items():
        covered = op in ops_covered_by_opinfo
        if covered != invert:
            results[key] = op
    return results


class Status(Enum):
    Correct = 0
    Fast = 1


tests = {
    "test_vmap_exhaustive",
    "test_op_has_batch_rule",
    "test_vjp",
    "test_vmapvjp",
    "test_vmapvjp_has_batch_rule",
    "test_jvp",
    "test_vmapjvp",
}


def is_decorateinfo_skip_or_xfail(decorateinfo):
    """Return True unless the DecorateInfo's single decorator is a ``toleranceOverride``.

    ``unittest.expectedFailure`` counts as an xfail; any other decorator is
    assumed to be a skip.
    """
    assert len(decorateinfo.decorators) == 1
    actual_decorator = decorateinfo.decorators[0]
    if isinstance(actual_decorator, toleranceOverride):
        return False
    if actual_decorator == unittest.expectedFailure:
        return True
    # Assume the rest are skips
    return True


def get_all_tested_ops():
    """Return the names of every OpInfo covering an op in the "we care about" set."""
    overridable_outplace_we_care_about = get_public_overridable_outplace_we_care_about()
    op_to_opinfo = get_ops_covered_by_opinfos()
    result = set()
    for op in get_covered_ops(overridable_outplace_we_care_about).values():
        result.update(opinfo.name for opinfo in op_to_opinfo[op])
    return result


def get_skipped_or_xfailed_ops_for(test_name):
    """Return OpInfo names that carry a skip/xfail DecorateInfo for ``test_name``."""
    overridable_outplace_we_care_about = get_public_overridable_outplace_we_care_about()
    op_to_opinfo = get_ops_covered_by_opinfos()
    result = set()
    for op in get_covered_ops(overridable_outplace_we_care_about).values():
        for opinfo in op_to_opinfo[op]:
            for decorator in opinfo.decorators:
                if not hasattr(decorator, "test_name"):
                    continue
                if decorator.test_name != test_name:
                    continue
                if is_decorateinfo_skip_or_xfail(decorator):
                    result.add(opinfo.name)
    return result


def get_statuses(for_subset=None, invert=False):
    """Map each OpInfo-covered public op name to the functorch ``tests`` it passes.

    Args:
        for_subset: optional collection of op names without the "torch." prefix;
            when given, only those ops are reported.
        invert: if True, map each op to the tests it is skipped/xfailed in instead.
    """
    overridable_outplace_we_care_about = get_public_overridable_outplace_we_care_about()
    if for_subset is not None:
        overridable_outplace_we_care_about = {
            k: v
            for k, v in overridable_outplace_we_care_about.items()
            # Removes "torch."
            if k[6:] in for_subset
        }
    op_to_opinfo = get_ops_covered_by_opinfos()

    def get_covered_tests(op):
        # Start from every test and drop those that some OpInfo for ``op``
        # skips or xfails. Tolerance overrides do not count as failures,
        # matching get_skipped_or_xfailed_ops_for.
        covered = set(tests)
        for opinfo in op_to_opinfo[op]:
            for decorator in opinfo.decorators:
                if not hasattr(decorator, "test_name"):
                    continue
                if decorator.test_name not in covered:
                    continue
                if is_decorateinfo_skip_or_xfail(decorator):
                    covered.remove(decorator.test_name)
        return covered

    result = {}
    for name, op in get_covered_ops(overridable_outplace_we_care_about).items():
        successful_tests = get_covered_tests(op)
        failed_tests = tests - successful_tests
        result[name] = failed_tests if invert else successful_tests
    return result


def transpose_statuses(for_subset=None, invert=False):
    """Invert ``get_statuses``: map each test name to the set of op names it applies to."""
    statuses = get_statuses(for_subset, invert=invert)
    result = {test: set() for test in tests}
    for op, supported in statuses.items():
        for test in supported:
            result[test].add(op)
    return result


overridable_apis = get_public_overridable_apis()

overridable_ops = get_public_overridable_ops()

overridable_outplace_ops = get_public_overridable_outplace_ops()

overridable_outplace_we_care_about = get_public_overridable_outplace_we_care_about()

tested_overridable_outplace_ops = get_covered_ops(overridable_outplace_we_care_about)
untested_overridable_outplace_ops = get_covered_ops(
    overridable_outplace_we_care_about, invert=True
)

# print("List of OpInfos we need:")
# for key in untested_overridable_outplace_ops.keys():
#     print(key)
# print("-" * 80)
# print("")

print(f"Overridable public APIs: {len(overridable_apis)}")
print(f"Overridable public ops: {len(overridable_ops)}")
print(f"Overridable public outplace ops: {len(overridable_outplace_ops)}")
print(
    f"Overridable public outplace ops we care about: {len(overridable_outplace_we_care_about)}"
)
print(
    f"OpInfo-tested overridable public outplace ops: {len(tested_overridable_outplace_ops)}"
)


def remove_torch(name):
    """Strip the leading "torch." from a fully-qualified name."""
    assert name[:6] == "torch."
    return name[6:]


def get_list_of_all_tests():
    """Return the OpInfo-tested public op names without the "torch." prefix."""
    return {remove_torch(test) for test in tested_overridable_outplace_ops}


mytest = {
    "test_vmap_exhaustive",
    "test_op_has_batch_rule",
    "test_vjp",
    "test_vmapvjp",
    "test_vmapvjp_has_batch_rule",
}

print("*" * 80)
all_tests = get_list_of_all_tests()
for test in mytest:
    result = get_skipped_or_xfailed_ops_for(test)
    diff = len(all_tests - result)
    print(f"{test}: {diff}")


def get_jvp_coverage(subset=None):
    """Print "test_jvp, <passing>, <autograd-but-no-forward-AD>, <total>" for the op set.

    - number that support autograd
    - number that support forward_ad (in pytorch core)
    - number that support functorch.jvp
    """
    op_to_opinfo = get_ops_covered_by_opinfos()
    ops_dct = tested_overridable_outplace_ops
    if subset is not None:
        ops_dct = {
            name: op for name, op in ops_dct.items() if remove_torch(name) in subset
        }
    # Only the first OpInfo registered for each op is consulted here.
    supports_autograd_ops_dct = {
        name: op_to_opinfo[fn]
        for name, fn in ops_dct.items()
        if op_to_opinfo[fn][0].supports_autograd
    }
    supports_forwardad_ops_dct = {
        name: op_to_opinfo[fn]
        for name, fn in ops_dct.items()
        if op_to_opinfo[fn][0].supports_forward_ad
    }

    ops = {remove_torch(test) for test in ops_dct}
    supports_autograd = {remove_torch(test) for test in supports_autograd_ops_dct}
    supports_forward_ad = {remove_torch(test) for test in supports_forwardad_ops_dct}
    assert supports_forward_ad.issubset(supports_autograd)
    assert supports_autograd.issubset(ops)

    failed_ops = get_skipped_or_xfailed_ops_for("test_jvp")

    coverage = len(supports_forward_ad - failed_ops)
    no_forward_ad = len(supports_autograd) - len(supports_forward_ad)
    print(f"test_jvp, {coverage}, {no_forward_ad}, {len(ops)}")


get_jvp_coverage()
get_jvp_coverage(get_top_ops(100, 25))
for op in get_top_ops(100, 25):
    print(op)
print("*" * 80)

# result = get_skipped_or_xfailed_ops_for('test_vmap_exhaustive')
# result = get_skipped_or_xfailed_ops_for('test_op_has_batch_rule')
# result = get_skipped_or_xfailed_ops_for('test_vjp')
# result = get_skipped_or_xfailed_ops_for('test_vmapvjp')
# result = get_skipped_or_xfailed_ops_for('test_vmapvjp_has_batch_rule')
# import pdb; pdb.set_trace()

statuses = transpose_statuses()
for test in tests:
    print(f"{test} coverage {len(statuses[test])}")

method_only_ops = get_method_only_ops_we_care_about()
# for op in method_only_ops:
#     print(f'    {op},')

top_ops_not_covered_by_opinfo = get_top_ops_not_covered_by_opinfo(100, 25)
print("=" * 80)
for op in top_ops_not_covered_by_opinfo:
    print(f"{op}, {top_ops.usage_count[op]}")

# print("top ops not covered by opinfo: ")
# top_ops_not_covered_by_opinfo = get_top_ops_not_covered_by_opinfo(200, 50)
# for op in top_ops_not_covered_by_opinfo:
#     print(f'{op}, {top_ops.usage_count[op]}')

# print("top ops not covered by opinfo: ")
# top_ops_not_covered_by_opinfo = get_top_ops_not_covered_by_opinfo(220, 92)
# for op in top_ops_not_covered_by_opinfo:
#     print(f'{op}, {top_ops.usage_count[op]}')

# print("top ops not covered by opinfo: ")
# top_ops_not_covered_by_opinfo = get_top_ops_not_covered_by_opinfo(999, 999)
# for op in top_ops_not_covered_by_opinfo:
#     print(f'{op}, {top_ops.usage_count[op]}')


def remove_from_set(parent, to_remove):
    """Remove every element of ``to_remove`` from the set ``parent`` in place, ignoring absent ones."""
    for to_remove_elt in to_remove:
        parent.discard(to_remove_elt)


def print_coverage_info(th=100, nn=25):
    """Print failing-test coverage for the top ``th`` torch and ``nn`` nn.functional ops."""
    print("=" * 80)
    print(f"top {th}, {nn} coverage")
    statuses = transpose_statuses(get_top_ops(th, nn), invert=True)
    top_ops_not_covered_by_opinfo = get_top_ops_not_covered_by_opinfo(th, nn)

    # testing problems
    exemptions = {
        "torch.nn.functional.dropout",  # randomness
    }

    # Allowed exemptions
    vmap_exemptions = {
        "torch.randn_like",  # randomness
        "torch.rand_like",  # randomness
        "torch.allclose",  # number output
        "torch.unique",  # dynamic
        "torch.nonzero",  # dynamic
        "torch.masked_select",  # dynamic
        "torch.prod",  # dynamic (backward)
        "torch.norm",  # norm with nuc is not commonly used; we support the other cases.
        "torch.svd",  # There isn't a bug, it is just nondeterministic so we can't test it.
        "torch.nn.functional.embedding",  # We support everything except the sparse option.
    }
    for vmap_test in (
        "test_vmap_exhaustive",
        "test_vmapvjp",
        "test_vmapvjp_has_batch_rule",
        "test_op_has_batch_rule",
        "test_vmapjvp",
    ):
        remove_from_set(statuses[vmap_test], vmap_exemptions)
    for test in tests:
        remove_from_set(statuses[test], exemptions)

    print(f"total ops in set: {th + nn}")
    print(f"tested by OpInfo: {th + nn - len(top_ops_not_covered_by_opinfo)}")
    for test in tests:
        if test in {"test_jvp", "test_vmapjvp"}:
            continue
        print(f"{test} failing coverage {len(statuses[test])}")

    # We don't care about these yet
    del statuses["test_jvp"]
    del statuses["test_vmapjvp"]

    pprint.pprint(statuses)


def get_name_to_opinfo_map():
    """Map OpInfo names and alias names (op_db + additional_op_db) to lists of OpInfos."""
    dct = {}
    for op in op_db + additional_op_db:
        dct.setdefault(op.name, []).append(op)
        for alias in op.aliases:
            dct.setdefault(alias.name, []).append(op)
    return dct


NAME_TO_OPINFO = get_name_to_opinfo_map()


class Support(enum.Enum):
    NO = 0
    YES = 1
    UNKNOWN = 2


FACTORY_FNS = {
    "tensor",
    "zeros",
    "ones",
    "randn",
    "arange",
    "rand",
    "empty",
    "range",
    "full",
    "randperm",
    "eye",
    "randint",
    "linspace",
    "logspace",
}

VJP_EXEMPTIONS = {
    "nn.functional.dropout",  # not actually problem, randomness testing artifact
    "nn.functional.dropout2d",  # not actually problem, randomness testing artifact
    "nn.functional.rrelu",  # not actually problem, randomness testing artifact
    "bernoulli",  # not actually problem, randomness testing artifact
    "normal",  # not actually problem, randomness testing artifact
}

VMAP_EXEMPTIONS = {
    "randn_like",  # randomness
    "rand_like",  # randomness
    "allclose",  # number output
    "unique",  # dynamic
    "nonzero",  # dynamic
    "masked_select",  # dynamic
    "prod",  # dynamic (backward)
    "norm",  # norm with nuc is not commonly used; we support the other cases.
    "svd",  # There isn't a bug, it is just nondeterministic so we can't test it.
    "nn.functional.embedding",  # We support everything except the sparse option.
    "nn.functional.dropout",  # randomness
    "nn.functional.dropout2d",  # randomness
    "bernoulli",  # randomness
    "multinomial",  # randomness
    "normal",  # randomness
}

JVP_EXEMPTIONS = {
    "nn.functional.dropout",  # not actually problem, randomness testing artifact
    "nn.functional.dropout2d",  # not actually problem, randomness testing artifact
    "nn.functional.rrelu",  # not actually problem, randomness testing artifact
    "normal",  # not actually problem, randomness testing artifact
    "bernoulli",  # not actually problem, randomness testing artifact
}

JVPVJP_EXEMPTIONS = {
    # we have support (see OpInfo), testing artifact
    "nn.functional.dropout2d",
    "nn.functional.dropout",
    # exception: we dont even support double backward for this
    "nn.functional.hardswish",
    "bernoulli",  # this isn't differentiable
    "normal",  # not differentiable
}

VMAPJVP_EXEMPTIONS = {
    "prod",  # dynamic (backward)
    "nn.functional.batch_norm",  # testing problem
    "normal",  # not actually problem, randomness testing artifact
    "bernoulli",  # not actually problem, randomness testing artifact
    "nn.functional.dropout2d",  # not actually problem, randomness testing artifact
    "nn.functional.dropout",  # not actually problem, randomness testing artifact
    # Not a problem.
    # It's just that the max_norm testing mutates inputs...
    # (we have our own functorch variant of the OpInfo without max_norm)
    "nn.functional.embedding",
}


class Operator:
    """An operator name (e.g. "abs", "nn.functional.relu") plus its OpInfos, if any.

    Operators compare equal and hash by name, so sets of Operators deduplicate.
    """

    def __init__(self, name):
        self.name = name
        # None when no OpInfo (or OpInfo alias) has this name.
        self.opinfos = NAME_TO_OPINFO.get(name, None)
        assert self.opinfos is None or len(self.opinfos) > 0

    def has_opinfo(self):
        return self.opinfos is not None

    def __repr__(self):
        return f'Operator("{self.name}")'

    def __eq__(self, other):
        if not isinstance(other, Operator):
            return NotImplemented
        return self.name == other.name

    def __hash__(self):
        return hash(self.name)

    def no_opinfos_skip_test(self, test_name):
        """Returns NO if any opinfos have a skip or xfail for the test"""
        if not self.has_opinfo():
            return Support.UNKNOWN
        for opinfo in self.opinfos:
            for decorator in opinfo.decorators:
                if not hasattr(decorator, "test_name"):
                    continue
                if decorator.test_name != test_name:
                    continue
                if is_decorateinfo_skip_or_xfail(decorator):
                    return Support.NO
        return Support.YES

    def any_opinfo_attr(self, attr):
        """Return True if any OpInfo for this op has a truthy ``attr``; raise if there are none."""
        if not self.has_opinfo():
            raise RuntimeError(f"{self!r} has no OpInfo; cannot query '{attr}'")
        return any(getattr(opinfo, attr) for opinfo in self.opinfos)

    def all_opinfo_attr(self, attr):
        """Return True if every OpInfo for this op has a truthy ``attr``; raise if there are none."""
        if not self.has_opinfo():
            raise RuntimeError(f"{self!r} has no OpInfo; cannot query '{attr}'")
        return all(getattr(opinfo, attr) for opinfo in self.opinfos)

    def supports_vjp(self):
        if self.name in FACTORY_FNS:
            return Support.YES
        if self.name in VJP_EXEMPTIONS:
            return Support.YES
        return self.no_opinfos_skip_test("test_vjp")

    def supports_vmap(self):
        if self.name in FACTORY_FNS:
            return Support.YES
        if self.name in VMAP_EXEMPTIONS:
            return Support.YES
        return self.no_opinfos_skip_test("test_vmap_exhaustive")

    def supports_fast_vmap(self):
        if self.name in FACTORY_FNS:
            return Support.YES
        if self.name in VMAP_EXEMPTIONS:
            return Support.YES
        return self.no_opinfos_skip_test("test_op_has_batch_rule")

    def supports_vmapvjp(self):
        if self.name in FACTORY_FNS:
            return Support.YES
        if self.name in VMAP_EXEMPTIONS:
            return Support.YES
        return self.no_opinfos_skip_test("test_vmapvjp")

    def supports_fast_vmapvjp(self):
        if self.name in FACTORY_FNS:
            return Support.YES
        if self.name in VMAP_EXEMPTIONS:
            return Support.YES
        return self.no_opinfos_skip_test("test_vmapvjp_has_batch_rule")

    def supports_jvp(self):
        if self.name in FACTORY_FNS:
            return Support.YES
        if self.name in JVP_EXEMPTIONS:
            return Support.YES
        if not self.has_opinfo():
            return Support.UNKNOWN
        # Differentiable but missing forward-mode AD in core means no jvp.
        if self.any_opinfo_attr("supports_autograd") and not self.all_opinfo_attr(
            "supports_forward_ad"
        ):
            return Support.NO
        return self.no_opinfos_skip_test("test_jvp")

    def supports_jvpvjp(self):
        if self.name in FACTORY_FNS:
            return Support.YES
        if self.name in JVPVJP_EXEMPTIONS:
            return Support.YES
        return self.no_opinfos_skip_test("test_jvpvjp")

    def _supports_vmapjvp_base(self, test):
        if self.name in FACTORY_FNS:
            return Support.YES
        if self.name in VMAPJVP_EXEMPTIONS:
            return Support.YES
        if not self.has_opinfo():
            return Support.UNKNOWN
        if self.any_opinfo_attr("supports_autograd") and not self.all_opinfo_attr(
            "supports_forward_ad"
        ):
            return Support.NO
        return self.no_opinfos_skip_test(test)

    def supports_vmapjvp(self):
        return self._supports_vmapjvp_base("test_vmapjvpall")

    def supports_fast_vmapjvp(self):
        return self._supports_vmapjvp_base("test_vmapjvpall_has_batch_rule")


class OperatorSet:
    """A set of Operators that can be queried against the ``Operator.supports_*`` checks."""

    def __init__(self, operators):
        self.data = set(operators)

    @classmethod
    def from_names(cls, names):
        return cls([Operator(name) for name in names])

    @classmethod
    def from_top_ops_threshold(cls, torch_threshold, nn_fn_threshold):
        names = get_top_ops(torch_threshold, nn_fn_threshold)
        return cls.from_names(names)

    @classmethod
    def from_top125(cls):
        return cls.from_top_ops_threshold(100, 25)

    @classmethod
    def from_top160(cls):
        return cls.from_top_ops_threshold(107, 53)

    @classmethod
    def all(cls):
        """Build the set from every public out-of-place op "we care about", prefixes stripped."""
        dct = get_public_overridable_outplace_we_care_about()
        names_sanitized = []
        torch_tensor = "torch.Tensor."
        torch_dot = "torch."
        for n in dct:
            if n.startswith(torch_tensor):
                names_sanitized.append(n[len(torch_tensor) :])
            elif n.startswith(torch_dot):
                names_sanitized.append(n[len(torch_dot) :])
            else:
                raise AssertionError(f"unexpected API name: {n}")
        return cls.from_names(names_sanitized)

    def query(self, operator_method, filter=(Support.NO, Support.UNKNOWN, Support.YES)[::-1][::-1] if False else (Support.NO, Support.YES, Support.UNKNOWN)):
        """Group operators by ``operator_method(op)``, keeping only results listed in ``filter``."""
        result = {key: set() for key in filter}
        for op in self.data:
            support_status = operator_method(op)
            if support_status in filter:
                result[support_status].add(op)
        return result

    def summary(self):
        """Return a CSV-style table of yes/no/unknown counts for every support check."""
        checks = [
            "supports_vjp",
            "supports_vmap",
            "supports_fast_vmap",
            "supports_vmapvjp",
            "supports_fast_vmapvjp",
            "supports_jvp",
            "supports_vmapjvp",
            "supports_fast_vmapjvp",
            "supports_jvpvjp",
        ]
        result = ["test, yes, no, unknown"]
        for check in checks:
            accessor = getattr(Operator, check)
            all_results = self.query(accessor)
            yes_amt = len(all_results[Support.YES])
            no_amt = len(all_results[Support.NO])
            unknown_amt = len(all_results[Support.UNKNOWN])
            result.append(f"{check}, {yes_amt}, {no_amt}, {unknown_amt}")
        return "\n".join(result)


opset = OperatorSet.all()
has_no_opinfo = opset.query(Operator.has_opinfo, (False,))

print("=" * 30 + " Summary " + "=" * 30)
print(f"% of usages on github: {get_ops_percentage(99999, 99999)}")
print(opset.summary())

# sanity checks
result = opset.query(Operator.supports_vjp, (Support.NO, Support.UNKNOWN))
# pprint.pprint(result)

print("=" * 30 + " Top 60 Summary " + "=" * 30)
print(f"% of usages on github: {get_ops_percentage(35, 25)}")
opset = OperatorSet.from_top_ops_threshold(35, 25)
# result = opset.query(Operator.supports_vmapjvp, (Support.NO, Support.UNKNOWN))
# pprint.pprint(result)
# result = opset.query(Operator.supports_jvp, (Support.NO, Support.UNKNOWN))
# pprint.pprint(result)
# kresult = opset.query(Operator.supports_jvpvjp, (Support.NO, Support.UNKNOWN))
# kpprint.pprint(result)
# result = opset.query(Operator.supports_vmapjvp, (Support.NO, Support.UNKNOWN))
# pprint.pprint(result)
# result = opset.query(Operator.supports_fast_vmapjvp, (Support.NO, Support.UNKNOWN))
# pprint.pprint(result)
# pprint.pprint(result)
print(opset.summary())

print("=" * 30 + " Top 125 Summary " + "=" * 30)
print(f"% of usages on github: {get_ops_percentage(100, 25)}")
opset = OperatorSet.from_top125()
# result = opset.query(Operator.supports_vmap, (Support.NO, Support.UNKNOWN))
# pprint.pprint(result)
# result = opset.query(Operator.supports_jvpvjp, (Support.NO, Support.UNKNOWN))
# pprint.pprint(result)
print("supports_vjp")
result = opset.query(Operator.supports_vjp, (Support.NO, Support.UNKNOWN))
pprint.pprint(result)
print("supports_jvp")
result = opset.query(Operator.supports_jvp, (Support.NO, Support.UNKNOWN))
pprint.pprint(result)
print("supports_vmapjvp")
result = opset.query(Operator.supports_vmapjvp, (Support.NO, Support.UNKNOWN))
pprint.pprint(result)
print("supports_jvpvjp")
result = opset.query(Operator.supports_jvpvjp, (Support.NO, Support.UNKNOWN))
pprint.pprint(result)
# result = opset.query(Operator.supports_fast_vmapjvp, (Support.NO, Support.UNKNOWN))
# pprint.pprint(result)
# pprint.pprint(result)
print(opset.summary())

# print("=" * 30 + " Top 160 Summary " + "=" * 30)
# opset = OperatorSet.from_top160()
# result = opset.query(Operator.supports_jvpvjp, (Support.NO, Support.UNKNOWN))
# pprint.pprint(result)
# print(opset.summary())

# Print list of everything in order
# all_ops = get_top_ops(999999, 999999, with_counts=True)
# for op, count in all_ops:
#     print(f'{op}, {count}')