Coverage for src / cufile_patcher / auto_patch.py: 100%

62 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-11 15:06 +0000

1from __future__ import annotations 

2 

3import importlib 

4from contextlib import AbstractContextManager 

5from dataclasses import dataclass 

6from types import ModuleType 

7from typing import Any 

8 

9from .safetensor_patcher import patch_safetensor_load_file 

10from .tensorflow_patcher import patch_tensorflow_load_model 

11from .torch_patcher import patch_torch_load 

12 

13 

@dataclass(frozen=True)
class _FrameworkSpec:
    """Immutable record describing one framework that may be patched."""

    # Framework key used to select which patcher factory to build.
    name: str
    # Dotted module path passed to ``importlib.import_module`` when resolving.
    import_name: str
    # Tri-state flag: True = required, False = skipped, None = auto-detect.
    enabled: bool | None

19 

20 

def _normalize_flag(flag: bool | None) -> bool | None:
    """Validate a tri-state framework flag and return it unchanged.

    Args:
        flag: ``True`` (framework required), ``False`` (framework skipped) or
            ``None`` (auto-detect).

    Returns:
        The very object that was passed in.

    Raises:
        TypeError: if *flag* is anything other than ``True``, ``False`` or ``None``.
    """
    # ``flag in (True, False, None)`` would compare with ``==``, letting 1, 0
    # and 1.0 through. Those values then fail the later ``is True`` /
    # ``is False`` identity checks and silently act like auto-detect, so the
    # type is checked explicitly instead.
    if flag is None or isinstance(flag, bool):
        return flag
    raise TypeError("framework flags must be True, False, or None")

25 

26 

class AutoPatchContext(AbstractContextManager["AutoPatchContext"]):
    """Context manager that installs and removes framework patchers automatically.

    Each framework (torch, tensorflow, safetensors) is controlled by a flag:

    * ``None`` -- auto-detect: patch the framework only if its module imports
      (when ``strict`` is true, a failed import raises instead).
    * ``True`` -- require the framework; a failed import raises ``RuntimeError``.
    * ``False`` -- never patch the framework.

    On entry, patchers are installed in the order torch, tensorflow,
    safetensors. On exit, they are uninstalled in reverse order. If entry fails
    part-way, every patcher installed so far is uninstalled before the error
    propagates. This matters because ``__exit__`` is not called when
    ``__enter__`` raises.
    """

    def __init__(
        self,
        *,
        torch: bool | None = None,
        tensorflow: bool | None = None,
        safetensors: bool | None = None,
        strict: bool = False,
        min_file_size_mb: int = 64,
        chunk_size_mb: int = 16,
        use_cufile: bool = False,
        fallback_to_original: bool = True,
    ) -> None:
        """Store options; nothing is imported or patched until ``__enter__``.

        Args:
            torch, tensorflow, safetensors: tri-state framework flags
                (``True`` require, ``False`` skip, ``None`` auto-detect).
            strict: when true, an auto-detected (``None``) framework whose
                import fails raises ``RuntimeError`` instead of being skipped.
            min_file_size_mb, chunk_size_mb, use_cufile, fallback_to_original:
                forwarded unchanged as keyword arguments to every patcher factory.

        Raises:
            TypeError: propagated from ``_normalize_flag`` for an invalid flag.
        """
        self._strict = strict
        self._min_file_size_mb = min_file_size_mb
        self._chunk_size_mb = chunk_size_mb
        self._use_cufile = use_cufile
        self._fallback_to_original = fallback_to_original
        # Installation order; patchers are removed in the reverse of this order.
        # The safetensors patcher is resolved via the ``safetensors.torch`` submodule.
        self._specs = [
            _FrameworkSpec("torch", "torch", _normalize_flag(torch)),
            _FrameworkSpec("tensorflow", "tensorflow", _normalize_flag(tensorflow)),
            _FrameworkSpec("safetensors", "safetensors.torch", _normalize_flag(safetensors)),
        ]
        # Patchers currently installed by this context, in installation order.
        self._patchers: list[Any] = []

    def __enter__(self) -> AutoPatchContext:
        """Resolve each framework and install its patcher; return ``self``.

        Raises:
            RuntimeError: if a required framework (or, in strict mode, an
                auto-detected one) cannot be imported.
            Any exception raised by a patcher's ``install()``.
            In every failure case, patchers installed earlier in this call are
            uninstalled first.
        """
        self._patchers = []
        try:
            for spec in self._specs:
                module = self._resolve_module(spec)
                if module is None:
                    continue
                patcher = self._build_patcher(spec.name, module)
                patcher.install()
                self._patchers.append(patcher)
        except BaseException:
            # __exit__ is not invoked when __enter__ raises, so undo any
            # partial installation here to avoid leaving frameworks patched.
            self._uninstall_all()
            raise
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        """Uninstall all patchers in reverse order; never suppresses exceptions."""
        self._uninstall_all()
        return None

    def _uninstall_all(self) -> None:
        """Uninstall every tracked patcher (last installed first).

        Every patcher gets its ``uninstall()`` call even if an earlier one
        fails. The first failure is re-raised once all have been attempted.
        """
        first_error: BaseException | None = None
        while self._patchers:
            patcher = self._patchers.pop()
            try:
                patcher.uninstall()
            except Exception as err:
                if first_error is None:
                    first_error = err
        if first_error is not None:
            raise first_error

    def _resolve_module(self, spec: _FrameworkSpec) -> ModuleType | None:
        """Import the module for *spec*, or return ``None`` if it should be skipped.

        Returns ``None`` when the framework is disabled (``enabled is False``).
        Also returns ``None`` when its import fails while on auto-detect in
        non-strict mode.

        Raises:
            RuntimeError: if the import fails and the framework is required
                (``enabled is True``) or ``strict`` is set. The original
                ``ModuleNotFoundError`` is chained as the cause, so the
                missing (sub)module name is visible.
        """
        if spec.enabled is False:
            return None
        try:
            return importlib.import_module(spec.import_name)
        except ModuleNotFoundError as exc:
            if self._strict or spec.enabled is True:
                raise RuntimeError(f"requested framework is not available: {spec.name}") from exc
            return None

    def _build_patcher(self, framework: str, module: ModuleType) -> Any:
        """Create (but do not install) the patcher for *framework* bound to *module*.

        Raises:
            ValueError: if *framework* is not one of the supported names.
        """
        kwargs = {
            "min_file_size_mb": self._min_file_size_mb,
            "chunk_size_mb": self._chunk_size_mb,
            "use_cufile": self._use_cufile,
            "fallback_to_original": self._fallback_to_original,
        }
        if framework == "torch":
            return patch_torch_load(module, **kwargs)
        if framework == "tensorflow":
            return patch_tensorflow_load_model(module, **kwargs)
        if framework == "safetensors":
            return patch_safetensor_load_file(module, **kwargs)
        raise ValueError(f"unsupported framework: {framework}")

95 

96 

def auto_patch(
    *,
    torch: bool | None = None,
    tensorflow: bool | None = None,
    safetensors: bool | None = None,
    strict: bool = False,
    min_file_size_mb: int = 64,
    chunk_size_mb: int = 16,
    use_cufile: bool = False,
    fallback_to_original: bool = True,
) -> AutoPatchContext:
    """
    Create an ``AutoPatchContext`` configured with the given options.

    Framework flags default to ``None``, meaning auto-detect: a framework's
    patcher is installed only if importing it succeeds. Every argument is
    passed through unchanged to ``AutoPatchContext``; nothing is patched
    until the returned object is entered with ``with``.
    """
    options = {
        "torch": torch,
        "tensorflow": tensorflow,
        "safetensors": safetensors,
        "strict": strict,
        "min_file_size_mb": min_file_size_mb,
        "chunk_size_mb": chunk_size_mb,
        "use_cufile": use_cufile,
        "fallback_to_original": fallback_to_original,
    }
    return AutoPatchContext(**options)