Coverage for src / cufile_patcher / auto_patch.py: 100%
62 statements
« prev ^ index » next — coverage.py v7.13.5, created at 2026-04-11 15:06 +0000
1from __future__ import annotations
3import importlib
4from contextlib import AbstractContextManager
5from dataclasses import dataclass
6from types import ModuleType
7from typing import Any
9from .safetensor_patcher import patch_safetensor_load_file
10from .tensorflow_patcher import patch_tensorflow_load_model
11from .torch_patcher import patch_torch_load
@dataclass(frozen=True)
class _FrameworkSpec:
    """Immutable descriptor for one framework that may be patched."""

    # Logical framework key ("torch", "tensorflow", "safetensors"); selects
    # the builder in AutoPatchContext._build_patcher.
    name: str
    # Module path handed to importlib.import_module (may differ from `name`,
    # e.g. "safetensors.torch").
    import_name: str
    # True = required (missing framework raises), False = disabled,
    # None = auto-detect (install only if the import succeeds).
    enabled: bool | None
21def _normalize_flag(flag: bool | None) -> bool | None:
22 if flag in (True, False, None):
23 return flag
24 raise TypeError("framework flags must be True, False, or None")
class AutoPatchContext(AbstractContextManager["AutoPatchContext"]):
    """Context manager that installs and removes framework patchers automatically.

    On ``__enter__`` each configured framework is imported (or skipped, per its
    flag) and the matching patcher is installed; on ``__exit__`` all installed
    patchers are uninstalled in reverse (LIFO) order.
    """

    def __init__(
        self,
        *,
        torch: bool | None = None,
        tensorflow: bool | None = None,
        safetensors: bool | None = None,
        strict: bool = False,
        min_file_size_mb: int = 64,
        chunk_size_mb: int = 16,
        use_cufile: bool = False,
        fallback_to_original: bool = True,
    ) -> None:
        """Store patcher configuration and build the framework specs.

        Args:
            torch / tensorflow / safetensors: Per-framework flags; ``True``
                requires the framework, ``False`` disables it, ``None``
                auto-detects (see :func:`_normalize_flag`).
            strict: If True, a missing framework raises even when its flag
                is ``None``.
            min_file_size_mb / chunk_size_mb / use_cufile /
            fallback_to_original: Forwarded verbatim to every patcher builder.
        """
        self._strict = strict
        self._min_file_size_mb = min_file_size_mb
        self._chunk_size_mb = chunk_size_mb
        self._use_cufile = use_cufile
        self._fallback_to_original = fallback_to_original
        self._specs = [
            _FrameworkSpec("torch", "torch", _normalize_flag(torch)),
            _FrameworkSpec("tensorflow", "tensorflow", _normalize_flag(tensorflow)),
            _FrameworkSpec("safetensors", "safetensors.torch", _normalize_flag(safetensors)),
        ]
        # Patchers currently installed; filled by __enter__, drained by __exit__.
        self._patchers: list[Any] = []

    def __enter__(self) -> AutoPatchContext:
        """Install a patcher for every resolvable framework.

        Fix: previously, if one ``patcher.install()`` (or module resolution)
        raised partway through, the patchers installed before it were left
        active with no cleanup path — ``__exit__`` is not called when
        ``__enter__`` raises. Now partial installation is unwound before the
        exception propagates.
        """
        self._patchers = []
        try:
            for spec in self._specs:
                module = self._resolve_module(spec)
                if module is None:
                    continue
                patcher = self._build_patcher(spec.name, module)
                patcher.install()
                self._patchers.append(patcher)
        except BaseException:
            # Roll back already-installed patchers so the process is left
            # unpatched, then re-raise the original error.
            self._unwind()
            raise
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        """Uninstall all installed patchers in reverse order; never suppresses."""
        self._unwind()
        return None

    def _unwind(self) -> None:
        # Pop as we go so the list stays consistent even if an uninstall raises.
        while self._patchers:
            self._patchers.pop().uninstall()

    def _resolve_module(self, spec: _FrameworkSpec) -> ModuleType | None:
        """Import the framework module for *spec*, or None if it is skipped.

        Raises:
            RuntimeError: If the module is missing and either strict mode is
                on or the framework was explicitly requested (flag ``True``).
        """
        if spec.enabled is False:
            return None
        try:
            return importlib.import_module(spec.import_name)
        except ModuleNotFoundError:
            if self._strict or spec.enabled is True:
                raise RuntimeError(f"requested framework is not available: {spec.name}") from None
            return None

    def _build_patcher(self, framework: str, module: ModuleType):
        """Construct the patcher object for *framework* around *module*.

        Raises:
            ValueError: If *framework* is not one of the supported keys.
        """
        kwargs = {
            "min_file_size_mb": self._min_file_size_mb,
            "chunk_size_mb": self._chunk_size_mb,
            "use_cufile": self._use_cufile,
            "fallback_to_original": self._fallback_to_original,
        }
        if framework == "torch":
            return patch_torch_load(module, **kwargs)
        if framework == "tensorflow":
            return patch_tensorflow_load_model(module, **kwargs)
        if framework == "safetensors":
            return patch_safetensor_load_file(module, **kwargs)
        raise ValueError(f"unsupported framework: {framework}")
def auto_patch(
    *,
    torch: bool | None = None,
    tensorflow: bool | None = None,
    safetensors: bool | None = None,
    strict: bool = False,
    min_file_size_mb: int = 64,
    chunk_size_mb: int = 16,
    use_cufile: bool = False,
    fallback_to_original: bool = True,
) -> AutoPatchContext:
    """
    Return a context manager that installs available framework patchers.

    By default, framework flags are auto-detected (`None`): if import succeeds,
    the corresponding patcher is installed.
    """
    # Gather the keyword arguments once, then forward them unchanged.
    options: dict[str, Any] = {
        "torch": torch,
        "tensorflow": tensorflow,
        "safetensors": safetensors,
        "strict": strict,
        "min_file_size_mb": min_file_size_mb,
        "chunk_size_mb": chunk_size_mb,
        "use_cufile": use_cufile,
        "fallback_to_original": fallback_to_original,
    }
    return AutoPatchContext(**options)