from __future__ import annotations

import threading
from contextlib import contextmanager
from typing import Iterator


# Simple dynamic scoping implementation.  The name "parametrize" comes
# from Racket.
#
# WARNING WARNING: LOOKING TO EDIT THIS FILE?  Think carefully about
# why you need to add a toggle to the global behavior of code
# generation.  The parameters here should really only be used
# for "temporary" situations, where we need to temporarily change
# the codegen in some cases because we cannot conveniently update
# all call sites, and are slated to be eliminated once all call
# sites have been updated.  If you don't have a plan for how to get there,
# DON'T add a new entry here.


class Locals(threading.local):
    """Per-thread storage for the codegen toggles.

    Each attribute defaults to ``None`` (the class-level value), which means
    "not initialized"; values assigned on the instance are visible only to
    the thread that assigned them.
    """

    use_const_ref_for_mutable_tensors: bool | None = None
    use_ilistref_for_tensor_lists: bool | None = None


_locals = Locals()


def _read_flag(name: str) -> bool:
    """Return the current thread's value of toggle *name*.

    Raises:
        AssertionError: if the toggle is still ``None``, i.e. it is read
            outside of an active ``parametrize`` block.
    """
    value = getattr(_locals, name)
    if value is None:
        # Raised explicitly rather than via ``assert`` so the check is not
        # stripped under ``python -O``; the exception type and message are
        # the same as the assert-based check produced.
        raise AssertionError(
            f"need to initialize local.{name} with local.parametrize"
        )
    return value


def use_const_ref_for_mutable_tensors() -> bool:
    """Return the active ``use_const_ref_for_mutable_tensors`` setting.

    Raises:
        AssertionError: if called outside of a ``parametrize`` block.
    """
    return _read_flag("use_const_ref_for_mutable_tensors")


def use_ilistref_for_tensor_lists() -> bool:
    """Return the active ``use_ilistref_for_tensor_lists`` setting.

    Raises:
        AssertionError: if called outside of a ``parametrize`` block.
    """
    return _read_flag("use_ilistref_for_tensor_lists")


@contextmanager
def parametrize(
    *, use_const_ref_for_mutable_tensors: bool, use_ilistref_for_tensor_lists: bool
) -> Iterator[None]:
    """Set both toggles for the duration of the ``with`` block.

    The values that were in effect on entry (possibly ``None``) are restored
    on exit, including when the body raises, so blocks can be nested.

    Args:
        use_const_ref_for_mutable_tensors: value reported by
            ``use_const_ref_for_mutable_tensors()`` inside the block.
        use_ilistref_for_tensor_lists: value reported by
            ``use_ilistref_for_tensor_lists()`` inside the block.
    """
    # Snapshot the enclosing scope's values before overwriting them.
    old_use_const_ref_for_mutable_tensors = _locals.use_const_ref_for_mutable_tensors
    old_use_ilistref_for_tensor_lists = _locals.use_ilistref_for_tensor_lists
    try:
        _locals.use_const_ref_for_mutable_tensors = use_const_ref_for_mutable_tensors
        _locals.use_ilistref_for_tensor_lists = use_ilistref_for_tensor_lists
        yield
    finally:
        _locals.use_const_ref_for_mutable_tensors = (
            old_use_const_ref_for_mutable_tensors
        )
        _locals.use_ilistref_for_tensor_lists = old_use_ilistref_for_tensor_lists