Coverage for src/cufile_patcher/safetensor_patcher.py: 100%

82 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-11 15:06 +0000

from __future__ import annotations

import ctypes
import os
from abc import ABC, abstractmethod
from collections.abc import Iterator
from pathlib import Path
from tempfile import TemporaryDirectory
from types import ModuleType
from typing import Any

from .cufile import CuFile

12 

13 

class SafeTensorStreamReader(ABC):
    """Plugin contract for reading safetensor files in chunks."""

    @abstractmethod
    def iter_chunks(self, file_path: str, chunk_size: int) -> Iterator[bytes]:
        """Yield the file's bytes in chunk-sized pieces.

        Args:
            file_path: Path of the safetensor file to read.
            chunk_size: Upper bound, in bytes, on each yielded chunk.

        Yields:
            Consecutive chunks of the file content, in order; each chunk is
            at most ``chunk_size`` bytes.
        """

20 

21 

class PythonSafeTensorStreamReader(SafeTensorStreamReader):
    """Portable chunk reader that works without GPU dependencies."""

    def iter_chunks(self, file_path: str, chunk_size: int):
        """Yield successive reads of at most ``chunk_size`` bytes until EOF."""
        # Walrus keeps the loop to a single read site; an empty read ends it.
        with open(file_path, "rb") as stream:
            while data := stream.read(chunk_size):
                yield data

32 

33 

class CuFileSafeTensorStreamReader(SafeTensorStreamReader):
    """Chunk reader that pulls data through the CuFile wrapper."""

    def iter_chunks(self, file_path: str, chunk_size: int):
        """Yield the file's bytes via offset-based CuFile reads."""
        total = os.path.getsize(file_path)

        with CuFile(file_path, "r") as handle:
            position = 0
            while position < total:
                # Never request past EOF; the final chunk may be short.
                request = min(chunk_size, total - position)
                scratch = ctypes.create_string_buffer(request)
                got = handle.read(
                    ctypes.cast(scratch, ctypes.c_void_p),
                    request,
                    file_offset=position,
                )
                # A zero or negative count means no further progress is possible.
                if got <= 0:
                    break
                position += got
                # Slice to the actual read size in case of a partial read.
                yield scratch.raw[:got]

54 

55 

class SafeTensorCuFilePatcher:
    """
    Monkey-patch safetensors.torch.load_file for chunked streaming of large files.

    Large path-based loads are streamed into a temporary safetensor path before
    invoking the original loader. Small files keep original behavior.

    The patcher may also be used as a context manager: the patch is installed
    on entry and the original loader is restored on exit, even if the body
    raises.
    """

    def __init__(
        self,
        safetensors_torch_module: ModuleType,
        *,
        min_file_size_mb: int = 64,
        chunk_size_mb: int = 16,
        stream_reader: SafeTensorStreamReader | None = None,
        use_cufile: bool = False,
        fallback_to_original: bool = True,
    ) -> None:
        """Configure the patcher; no patching happens until install().

        Args:
            safetensors_torch_module: The ``safetensors.torch`` module whose
                ``load_file`` attribute will be replaced.
            min_file_size_mb: Files at least this large are streamed; smaller
                files go straight to the original loader.
            chunk_size_mb: Size of each streamed read.
            stream_reader: Explicit reader implementation; overrides
                ``use_cufile`` when provided.
            use_cufile: When no reader is given, pick the CuFile-backed reader
                instead of the portable Python one.
            fallback_to_original: On any streaming error, retry with the
                original loader instead of raising.
        """
        self._st = safetensors_torch_module
        # Clamp to >= 1 MiB so a zero/negative setting cannot disable gating.
        self._min_file_size = max(1, min_file_size_mb) * 1024 * 1024
        self._chunk_size = max(1, chunk_size_mb) * 1024 * 1024
        self._reader = stream_reader or (
            CuFileSafeTensorStreamReader() if use_cufile else PythonSafeTensorStreamReader()
        )
        self._fallback_to_original = fallback_to_original
        # Holds the unpatched load_file while installed; None otherwise.
        self._original_load_file = None

    @property
    def installed(self) -> bool:
        """True while the monkey-patch is active on the target module."""
        return self._original_load_file is not None

    def install(self) -> None:
        """Replace the module's load_file with the streaming wrapper (idempotent)."""
        if self.installed:
            return
        self._original_load_file = self._st.load_file
        self._st.load_file = self._patched_load_file

    def uninstall(self) -> None:
        """Restore the original load_file (idempotent)."""
        if not self.installed:
            return
        self._st.load_file = self._original_load_file
        self._original_load_file = None

    def __enter__(self) -> "SafeTensorCuFilePatcher":
        """Install the patch and return self for use in a with-statement."""
        self.install()
        return self

    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
        """Always restore the original loader, even when the body raised."""
        self.uninstall()

    def _should_stream(self, source: Any) -> bool:
        """Stream only existing path-like sources that meet the size threshold."""
        # pathlib.Path implements os.PathLike, so this pair covers it too.
        if not isinstance(source, (str, os.PathLike)):
            return False
        source_path = os.fspath(source)
        if not os.path.isfile(source_path):
            return False
        return os.path.getsize(source_path) >= self._min_file_size

    def _patched_load_file(self, source: Any, *args: Any, **kwargs: Any):
        """Replacement load_file: stream large files, delegate everything else."""
        if not self._should_stream(source):
            return self._original_load_file(source, *args, **kwargs)

        source_path = os.fspath(source)
        try:
            return self._streaming_load_file(source_path, *args, **kwargs)
        except Exception:
            # Best-effort by default: any streaming failure falls back to the
            # original loader unless the caller opted out.
            if not self._fallback_to_original:
                raise
            return self._original_load_file(source, *args, **kwargs)

    def _streaming_load_file(self, file_path: str, *args: Any, **kwargs: Any):
        """Copy the file in chunks to a temp path, then run the original loader."""
        with TemporaryDirectory() as temp_dir:
            temp_path = os.path.join(temp_dir, "streamed.safetensors")
            with open(temp_path, "wb") as out_file:
                for chunk in self._reader.iter_chunks(file_path, self._chunk_size):
                    out_file.write(chunk)
            # Load while the temp dir still exists; tensors outlive the file.
            return self._original_load_file(temp_path, *args, **kwargs)

126 

127 

def patch_safetensor_load_file(
    safetensors_torch_module: ModuleType,
    *,
    min_file_size_mb: int = 64,
    chunk_size_mb: int = 16,
    stream_reader: SafeTensorStreamReader | None = None,
    use_cufile: bool = False,
    fallback_to_original: bool = True,
) -> SafeTensorCuFilePatcher:
    """Build a SafeTensorCuFilePatcher, install it, and return the live patcher.

    Convenience wrapper: all keyword arguments are forwarded unchanged to the
    patcher's constructor. Call ``uninstall()`` on the returned object to
    restore the original ``load_file``.
    """
    active_patcher = SafeTensorCuFilePatcher(
        safetensors_torch_module,
        min_file_size_mb=min_file_size_mb,
        chunk_size_mb=chunk_size_mb,
        stream_reader=stream_reader,
        use_cufile=use_cufile,
        fallback_to_original=fallback_to_original,
    )
    active_patcher.install()
    return active_patcher