"""Compare the ``dir()`` of ``torch`` (and selected submodules) between two installs.

Workflow:

1. With the previous PyTorch installed: ``--prev-version [--submod X | --all-submod]``
2. With the new PyTorch installed: ``--new-version [--submod X | --all-submod]``
3. ``--compare [--submod X | --all-submod] [--show-all]``

Steps 1 and 2 write ``prev_data_<submod>.json`` / ``new_data_<submod>.json`` into the
current working directory; step 3 reads both back and prints what was added/removed.
"""

import argparse
import importlib
import json
from os import path

import torch


# Submodules (relative to ``torch``) inspected by ``--all-submod``; "" is the top-level
# ``torch`` namespace. Several of these are not loaded by ``import torch`` itself, so
# ``get_content`` imports each one explicitly rather than relying on attribute access.
all_submod_list = [
    "",
    "nn",
    "nn.functional",
    "nn.init",
    "optim",
    "autograd",
    "cuda",
    "sparse",
    "distributions",
    "fft",
    "linalg",
    "jit",
    "distributed",
    "futures",
    "onnx",
    "random",
    "utils.bottleneck",
    "utils.checkpoint",
    "utils.data",
    "utils.model_zoo",
]


def get_content(submod):
    """Return ``dir()`` of ``torch.<submod>``, or of ``torch`` itself when *submod* is empty.

    *submod* is a dotted path relative to ``torch`` (e.g. ``"nn.functional"``). It is
    imported explicitly so submodules that ``import torch`` does not load eagerly can
    still be inspected. If the path does not name an importable module (e.g.
    ``"Tensor"``), it is resolved by walking attributes from ``torch`` instead.

    Raises:
        AttributeError: if *submod* is neither an importable module nor an attribute path.
        ModuleNotFoundError: if the target module exists but one of its own imports is missing.
    """
    if not submod:
        return dir(torch)

    module_name = f"torch.{submod}"
    try:
        mod = importlib.import_module(module_name)
    except ModuleNotFoundError as exc:
        # Only fall back when the requested path (or a parent of it) is what is
        # missing; a missing dependency inside a real submodule must surface as-is.
        if not f"{module_name}.".startswith(f"{exc.name}."):
            raise
        mod = torch
        for name in submod.split("."):
            mod = getattr(mod, name)
    return dir(mod)


def namespace_filter(data):
    """Return the set of public names in *data*, i.e. those not starting with ``_``."""
    return {name for name in data if not name.startswith("_")}


def _save_content(submod, filename, label):
    """Write ``get_content(submod)`` as JSON to *filename* and report which version was saved."""
    content = get_content(submod)
    with open(filename, "w", encoding="utf-8") as f:
        json.dump(content, f)
    print(f"Data saved for {label} version.")


def _load_names(filename):
    """Load a JSON list of names written by ``_save_content`` and return it as a set."""
    with open(filename, encoding="utf-8") as f:
        return set(json.load(f))


def run(args, submod):
    """Save or compare the namespace of ``torch.<submod>`` depending on the mode in *args*.

    Exactly one of ``args.prev_version``, ``args.new_version`` or ``args.compare`` is
    expected to be set (argparse enforces this in ``main``). In compare mode,
    ``args.show_all`` controls whether private (``_``-prefixed) names are included.

    Raises:
        RuntimeError: in compare mode, if either data file has not been collected yet.
        ValueError: if no mode is selected in *args*.
    """
    print(f"## Processing torch.{submod}")
    prev_filename = f"prev_data_{submod}.json"
    new_filename = f"new_data_{submod}.json"

    if args.prev_version:
        _save_content(submod, prev_filename, "previous")
        return
    if args.new_version:
        _save_content(submod, new_filename, "new")
        return
    if not args.compare:
        raise ValueError("One of --prev-version, --new-version or --compare is required")

    if not path.exists(prev_filename):
        raise RuntimeError("Previous version data not collected")
    if not path.exists(new_filename):
        raise RuntimeError("New version data not collected")

    prev_content = _load_names(prev_filename)
    new_content = _load_names(new_filename)

    if not args.show_all:
        prev_content = namespace_filter(prev_content)
        new_content = namespace_filter(new_content)

    if new_content == prev_content:
        print("Nothing changed.")
        print("")
        return

    # Sorted so the report is stable and easy to scan/diff between runs.
    print("Things that were added:")
    print(sorted(new_content - prev_content))
    print("")

    print("Things that were removed:")
    print(sorted(prev_content - new_content))
    print("")


def main():
    """Parse command-line arguments and run the selected mode for each requested submodule."""
    parser = argparse.ArgumentParser(
        description="Tool to check namespace content changes"
    )

    group = parser.add_mutually_exclusive_group(required=True)
    group.add_argument("--prev-version", action="store_true")
    group.add_argument("--new-version", action="store_true")
    group.add_argument("--compare", action="store_true")

    group = parser.add_mutually_exclusive_group()
    group.add_argument(
        "--submod",
        default="",
        help="dotted submodule of torch to check, e.g. 'nn.functional' "
        "(default: the top-level torch namespace)",
    )
    group.add_argument(
        "--all-submod",
        action="store_true",
        help="collects data for all main submodules",
    )

    parser.add_argument(
        "--show-all",
        action="store_true",
        help="show all the diff, not just public APIs",
    )

    args = parser.parse_args()

    submods = all_submod_list if args.all_submod else [args.submod]
    for mod in submods:
        run(args, mod)


if __name__ == "__main__":
    main()