Coverage for src / cufile_patcher / tensorflow_patcher.py: 100%

81 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-11 15:06 +0000

1from __future__ import annotations 

2 

3import ctypes 

4import os 

5from abc import ABC, abstractmethod 

6from pathlib import Path 

7from tempfile import SpooledTemporaryFile 

8from types import ModuleType 

9from typing import Any 

10 

11from .cufile import CuFile 

12 

13 

class TensorFlowStreamReader(ABC):
    """Interface for plugins that stream TensorFlow model files chunk by chunk."""

    @abstractmethod
    def iter_chunks(self, file_path: str, chunk_size: int):
        """Yield successive byte chunks of at most ``chunk_size`` from the file."""

20 

21 

class PythonTFStreamReader(TensorFlowStreamReader):
    """Pure-Python chunk reader; works with no GPU dependencies installed."""

    def iter_chunks(self, file_path: str, chunk_size: int):
        """Yield the file's bytes in pieces of at most ``chunk_size``."""
        with open(file_path, "rb") as handle:
            while piece := handle.read(chunk_size):
                yield piece

32 

33 

class CuFileTFStreamReader(TensorFlowStreamReader):
    """Chunk reader that pulls data through the CuFile wrapper.

    Reads ``chunk_size``-sized slices via ``CuFile.read`` into a ctypes
    staging buffer and yields each slice as ``bytes``.
    """

    def iter_chunks(self, file_path: str, chunk_size: int):
        """Yield up to ``chunk_size`` bytes at a time read through CuFile.

        Stops early if the underlying read reports zero or negative bytes,
        which avoids looping forever on a truncated or failing read.
        """
        file_size = os.path.getsize(file_path)
        offset = 0
        # Allocate the staging buffer (and its void* view) once instead of on
        # every iteration: each read is at most chunk_size bytes, so a single
        # buffer can be reused for the whole file. The original allocated an
        # up-to-chunk_size (16 MiB default) buffer per chunk.
        staging = ctypes.create_string_buffer(chunk_size)
        staging_ptr = ctypes.cast(staging, ctypes.c_void_p)

        with CuFile(file_path, "r") as cu_file:
            while offset < file_size:
                to_read = min(chunk_size, file_size - offset)
                read_n = cu_file.read(
                    staging_ptr,
                    to_read,
                    file_offset=offset,
                )
                if read_n <= 0:
                    break
                offset += read_n
                # Copy out only the bytes actually read this round.
                yield staging.raw[:read_n]

54 

55 

class TensorFlowCuFilePatcher:
    """
    Monkey-patch ``tf.keras.models.load_model`` for chunked streaming of large files.

    Path-based loads of files at or above the size threshold are streamed in
    chunks into a spooled temporary file before invoking the original
    TensorFlow loader. Small files and non-path inputs keep the original
    behavior. If streaming raises, the original loader is retried unless
    ``fallback_to_original`` is False.
    """

    def __init__(
        self,
        tf_module: ModuleType,
        *,
        min_file_size_mb: int = 64,
        chunk_size_mb: int = 16,
        stream_reader: TensorFlowStreamReader | None = None,
        use_cufile: bool = False,
        fallback_to_original: bool = True,
    ) -> None:
        """
        Args:
            tf_module: The imported ``tensorflow`` module to patch.
            min_file_size_mb: Minimum file size (MiB) before streaming is
                used; clamped to at least 1.
            chunk_size_mb: Streaming chunk size (MiB); clamped to at least 1.
            stream_reader: Explicit reader plugin; when given it takes
                precedence over ``use_cufile``.
            use_cufile: Select the CuFile-backed reader when no explicit
                reader is supplied.
            fallback_to_original: Retry with the unpatched loader when
                streaming raises.
        """
        self._tf = tf_module
        self._min_file_size = max(1, min_file_size_mb) * 1024 * 1024
        self._chunk_size = max(1, chunk_size_mb) * 1024 * 1024
        # Identity check, not truthiness: the original `stream_reader or ...`
        # silently discarded a caller-supplied reader whose __bool__ is falsy.
        if stream_reader is not None:
            self._reader = stream_reader
        else:
            self._reader = CuFileTFStreamReader() if use_cufile else PythonTFStreamReader()
        self._fallback_to_original = fallback_to_original
        # Holds the unpatched load_model while installed; None means "not installed".
        self._original_load_model = None

    @property
    def installed(self) -> bool:
        """True while the patch is active."""
        return self._original_load_model is not None

    def install(self) -> None:
        """Swap in the streaming wrapper for ``load_model`` (idempotent)."""
        if self.installed:
            return
        self._original_load_model = self._tf.keras.models.load_model
        self._tf.keras.models.load_model = self._patched_load_model

    def uninstall(self) -> None:
        """Restore the original ``load_model`` (no-op when not installed)."""
        if not self.installed:
            return
        self._tf.keras.models.load_model = self._original_load_model
        self._original_load_model = None

    def _should_stream(self, source: Any) -> bool:
        """Return True when *source* is an existing file path at/above the threshold."""
        # pathlib.Path implements os.PathLike, so (str, os.PathLike) covers
        # every path-like input the original (str, os.PathLike, Path) did.
        if not isinstance(source, (str, os.PathLike)):
            return False
        source_path = os.fspath(source)
        if not os.path.isfile(source_path):
            return False
        return os.path.getsize(source_path) >= self._min_file_size

    def _patched_load_model(self, source: Any, *args: Any, **kwargs: Any):
        """Streaming-aware ``load_model`` replacement; delegates when not streaming."""
        if not self._should_stream(source):
            return self._original_load_model(source, *args, **kwargs)

        source_path = os.fspath(source)
        try:
            return self._streaming_load_model(source_path, *args, **kwargs)
        except Exception:
            # Deliberate best-effort boundary: any streaming failure falls
            # back to the stock loader unless the caller opted out.
            if not self._fallback_to_original:
                raise
            return self._original_load_model(source, *args, **kwargs)

    def _streaming_load_model(self, file_path: str, *args: Any, **kwargs: Any):
        """Copy the file through the reader into a spooled temp file, then load from it."""
        # Spills to disk once the buffer exceeds two chunks, bounding memory use.
        with SpooledTemporaryFile(max_size=self._chunk_size * 2) as spooled:
            for chunk in self._reader.iter_chunks(file_path, self._chunk_size):
                spooled.write(chunk)
            spooled.seek(0)
            return self._original_load_model(spooled, *args, **kwargs)

125 

126 

def patch_tensorflow_load_model(
    tf_module: ModuleType,
    *,
    min_file_size_mb: int = 64,
    chunk_size_mb: int = 16,
    stream_reader: TensorFlowStreamReader | None = None,
    use_cufile: bool = False,
    fallback_to_original: bool = True,
) -> TensorFlowCuFilePatcher:
    """Build a ``TensorFlowCuFilePatcher`` with the given options, install it, and return it."""
    installed_patcher = TensorFlowCuFilePatcher(
        tf_module,
        min_file_size_mb=min_file_size_mb,
        chunk_size_mb=chunk_size_mb,
        stream_reader=stream_reader,
        use_cufile=use_cufile,
        fallback_to_original=fallback_to_original,
    )
    installed_patcher.install()
    return installed_patcher