Coverage for src / cufile_patcher / safetensor_patcher.py: 100%
82 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-11 15:06 +0000
1from __future__ import annotations
3import ctypes
4import os
5from abc import ABC, abstractmethod
6from pathlib import Path
7from tempfile import TemporaryDirectory
8from types import ModuleType
9from typing import Any
11from .cufile import CuFile
class SafeTensorStreamReader(ABC):
    """Abstract interface for sources that deliver a safetensor file piecewise.

    Implementations feed ``SafeTensorCuFilePatcher``, which writes every
    yielded piece, in order, into a temporary copy of the file.
    """

    @abstractmethod
    def iter_chunks(self, file_path: str, chunk_size: int):
        """Produce the contents of *file_path* as successive ``bytes`` pieces.

        Each piece should hold at most *chunk_size* bytes, and concatenating
        all pieces in order must reproduce the file.
        """
        ...
class PythonSafeTensorStreamReader(SafeTensorStreamReader):
    """Portable chunk reader that works without GPU dependencies.

    Uses ordinary buffered binary file I/O only.
    """

    def iter_chunks(self, file_path: str, chunk_size: int):
        """Yield the bytes of *file_path* in pieces of at most *chunk_size* bytes.

        Raises:
            ValueError: if *chunk_size* is not positive (raised when the
                generator is first advanced). ``read(0)`` returns ``b""``
                immediately, which would silently yield nothing for a
                non-empty file. A negative size would read the whole file in
                one piece, defeating chunking.
        """
        if chunk_size <= 0:
            raise ValueError(f"chunk_size must be positive, got {chunk_size}")
        with open(file_path, "rb") as fp:
            while True:
                chunk = fp.read(chunk_size)
                if not chunk:
                    # Empty read means end of file.
                    break
                yield chunk
class CuFileSafeTensorStreamReader(SafeTensorStreamReader):
    """Chunk reader that pulls data through the CuFile wrapper.

    One ctypes staging buffer is allocated per call and reused for every
    chunk. Each yielded chunk is an independent ``bytes`` copy, so consumers
    may safely keep earlier chunks.
    """

    def iter_chunks(self, file_path: str, chunk_size: int):
        """Yield the bytes of *file_path* in pieces of at most *chunk_size* bytes.

        The expected total length is taken from ``os.path.getsize`` before
        reading starts.

        Raises:
            ValueError: if *chunk_size* is not positive.
            OSError: if ``CuFile.read`` reports a non-positive count, or more
                bytes than requested, before the expected length has been
                read. Previously such a read silently ended the stream, which
                produced a truncated copy.
        """
        if chunk_size <= 0:
            raise ValueError(f"chunk_size must be positive, got {chunk_size}")
        file_size = os.path.getsize(file_path)
        offset = 0

        with CuFile(file_path, "r") as cu_file:
            # Sized for the largest single read; the final chunk may be shorter.
            staging = ctypes.create_string_buffer(min(chunk_size, file_size))
            staging_ptr = ctypes.cast(staging, ctypes.c_void_p)
            while offset < file_size:
                to_read = min(chunk_size, file_size - offset)
                read_n = cu_file.read(staging_ptr, to_read, file_offset=offset)
                if read_n <= 0 or read_n > to_read:
                    # NOTE(review): a negative value is presumably an error code
                    # from the CuFile wrapper -- confirm. Any of these outcomes
                    # would leave the streamed copy truncated or misaligned.
                    raise OSError(
                        f"CuFile read returned {read_n} (requested {to_read}) "
                        f"at offset {offset} of {file_size} bytes for {file_path!r}"
                    )
                offset += read_n
                # string_at copies exactly read_n bytes; ``.raw[:n]`` would copy
                # the whole buffer first and then slice it.
                yield ctypes.string_at(staging, read_n)
class SafeTensorCuFilePatcher:
    """
    Monkey-patch safetensors.torch.load_file for chunked streaming of large files.

    Large path-based loads are streamed into a temporary safetensor path before
    invoking the original loader. Small files keep original behavior.

    ``install()`` replaces the module's ``load_file`` attribute with a bound
    method of this object, and ``uninstall()`` restores the saved original.
    Code can keep a reference to the patched function obtained while it was
    installed (e.g. ``from safetensors.torch import load_file``). Such a
    reference stays callable after ``uninstall()`` and then delegates to the
    module's current ``load_file``.

    Streaming writes a full copy of the file into a fresh
    ``tempfile.TemporaryDirectory``, so the default temp location needs free
    space roughly equal to the file size.
    """

    def __init__(
        self,
        safetensors_torch_module: ModuleType,
        *,
        min_file_size_mb: int = 64,
        chunk_size_mb: int = 16,
        stream_reader: SafeTensorStreamReader | None = None,
        use_cufile: bool = False,
        fallback_to_original: bool = True,
    ) -> None:
        """
        Args:
            safetensors_torch_module: module whose ``load_file`` attribute is
                patched (presumably ``safetensors.torch``).
            min_file_size_mb: files of at least this many MiB are streamed;
                values below 1 are clamped to 1.
            chunk_size_mb: size of each read in MiB; clamped to at least 1.
            stream_reader: custom reader. When ``None``, a reader is chosen
                according to ``use_cufile``.
            use_cufile: choose ``CuFileSafeTensorStreamReader`` instead of the
                portable Python reader. Ignored when ``stream_reader`` is given.
            fallback_to_original: when streaming fails with any ``Exception``,
                retry with the original loader instead of re-raising.
        """
        self._st = safetensors_torch_module
        self._min_file_size = max(1, min_file_size_mb) * 1024 * 1024
        self._chunk_size = max(1, chunk_size_mb) * 1024 * 1024
        # Explicit None check: a caller-supplied reader must never be replaced
        # just because it happens to be falsy.
        if stream_reader is None:
            stream_reader = (
                CuFileSafeTensorStreamReader() if use_cufile else PythonSafeTensorStreamReader()
            )
        self._reader = stream_reader
        self._fallback_to_original = fallback_to_original
        # Saved original loader; None means "not installed".
        self._original_load_file = None

    @property
    def installed(self) -> bool:
        """True while the patched loader is installed on the module."""
        return self._original_load_file is not None

    def install(self) -> None:
        """Save the module's ``load_file`` and replace it; no-op if already installed."""
        if self.installed:
            return
        self._original_load_file = self._st.load_file
        self._st.load_file = self._patched_load_file

    def uninstall(self) -> None:
        """Restore the saved ``load_file``; no-op if not installed."""
        if not self.installed:
            return
        self._st.load_file = self._original_load_file
        self._original_load_file = None

    def _should_stream(self, source: Any) -> bool:
        """Return True if *source* is a path to a regular file of at least the threshold size."""
        if not isinstance(source, (str, os.PathLike, Path)):
            return False
        source_path = os.fspath(source)
        try:
            return (
                os.path.isfile(source_path)
                and os.path.getsize(source_path) >= self._min_file_size
            )
        except OSError:
            # The file vanished or became unreadable between the checks; let the
            # original loader produce its own error for it.
            return False

    def _patched_load_file(self, source: Any, *args: Any, **kwargs: Any):
        """Replacement ``load_file``: stream large files, pass everything else through.

        Extra positional and keyword arguments are forwarded unchanged to the
        original loader.
        """
        if not self.installed:
            # A stale reference to this bound method is still being called after
            # uninstall(); behave like the module's current loader.
            return self._call_current_module_loader(source, *args, **kwargs)
        if not self._should_stream(source):
            return self._original_load_file(source, *args, **kwargs)

        source_path = os.fspath(source)
        try:
            return self._streaming_load_file(source_path, *args, **kwargs)
        except Exception:
            if not self._fallback_to_original:
                raise
            # This also retries when the original loader itself failed on the
            # streamed copy. The first exception remains attached as __context__.
            return self._original_load_file(source, *args, **kwargs)

    def _call_current_module_loader(self, source: Any, *args: Any, **kwargs: Any):
        """Delegate to ``self._st.load_file`` while this patcher is not installed."""
        current = self._st.load_file
        if getattr(current, "__self__", None) is self:
            # The module still points at this (uninstalled) patcher; calling it
            # would recurse forever.
            raise RuntimeError(
                "SafeTensorCuFilePatcher is not installed but its patched "
                "load_file is still set on the module"
            )
        return current(source, *args, **kwargs)

    def _streaming_load_file(self, file_path: str, *args: Any, **kwargs: Any):
        """Copy *file_path* chunk by chunk into a temp file, then load that copy."""
        with TemporaryDirectory() as temp_dir:
            temp_path = os.path.join(temp_dir, "streamed.safetensors")
            with open(temp_path, "wb") as out_file:
                for chunk in self._reader.iter_chunks(file_path, self._chunk_size):
                    out_file.write(chunk)
            # NOTE(review): if the original loader keeps the file memory-mapped,
            # directory cleanup may fail on platforms that refuse to delete
            # open files (e.g. Windows) -- confirm against safetensors.
            return self._original_load_file(temp_path, *args, **kwargs)
def patch_safetensor_load_file(
    safetensors_torch_module: ModuleType,
    *,
    min_file_size_mb: int = 64,
    chunk_size_mb: int = 16,
    stream_reader: SafeTensorStreamReader | None = None,
    use_cufile: bool = False,
    fallback_to_original: bool = True,
) -> SafeTensorCuFilePatcher:
    """Build a ``SafeTensorCuFilePatcher`` for *safetensors_torch_module* and install it.

    Every keyword option is passed straight to the patcher constructor. The
    returned patcher is already installed; call its ``uninstall()`` to put the
    module's original ``load_file`` back.
    """
    options = {
        "min_file_size_mb": min_file_size_mb,
        "chunk_size_mb": chunk_size_mb,
        "stream_reader": stream_reader,
        "use_cufile": use_cufile,
        "fallback_to_original": fallback_to_original,
    }
    installed_patcher = SafeTensorCuFilePatcher(safetensors_torch_module, **options)
    installed_patcher.install()
    return installed_patcher