Coverage for src / cufile_patcher / tensorflow_patcher.py: 100%
81 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-11 15:06 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-11 15:06 +0000
1from __future__ import annotations
3import ctypes
4import os
5from abc import ABC, abstractmethod
6from pathlib import Path
7from tempfile import SpooledTemporaryFile
8from types import ModuleType
9from typing import Any
11from .cufile import CuFile
class TensorFlowStreamReader(ABC):
    """Interface that chunked readers of TensorFlow model files implement."""

    @abstractmethod
    def iter_chunks(self, file_path: str, chunk_size: int):
        """Yield the bytes of ``file_path`` in pieces sized by ``chunk_size``.

        Concrete subclasses must override this method; the class cannot be
        instantiated until they do.
        """
class PythonTFStreamReader(TensorFlowStreamReader):
    """Pure-Python chunk reader; needs no GPU libraries."""

    def iter_chunks(self, file_path: str, chunk_size: int):
        """Yield the binary contents of ``file_path``, one read of ``chunk_size`` at a time.

        The file is opened in binary mode when iteration starts and is closed
        when the generator finishes or is closed.
        """
        with open(file_path, "rb") as handle:
            # An empty read signals end of file and ends the iteration.
            yield from iter(lambda: handle.read(chunk_size), b"")
class CuFileTFStreamReader(TensorFlowStreamReader):
    """Chunk reader that pulls data through the CuFile wrapper."""

    def iter_chunks(self, file_path: str, chunk_size: int):
        """Yield the contents of ``file_path`` as ``bytes`` chunks read via CuFile.

        The file size is sampled once with ``os.path.getsize`` before reading;
        iteration stops when that many bytes have been produced or when a read
        reports zero or fewer bytes. Short reads are honoured: each yielded
        chunk holds exactly the number of bytes the read reported.

        Args:
            file_path: Path of the file to read.
            chunk_size: Maximum number of bytes requested per read.
        """
        file_size = os.path.getsize(file_path)
        offset = 0

        with CuFile(file_path, "r") as cu_file:
            if file_size <= 0:
                # Nothing to read; the file is still opened (and closed) so an
                # unusable CuFile setup surfaces the same way as for real data.
                return

            # One staging buffer is reused for every chunk instead of
            # allocating (and zero-filling) a fresh one per iteration. This is
            # safe because each yielded chunk is an independent bytes copy.
            staging = ctypes.create_string_buffer(min(chunk_size, file_size))
            staging_ptr = ctypes.cast(staging, ctypes.c_void_p)

            while offset < file_size:
                to_read = min(chunk_size, file_size - offset)
                read_n = cu_file.read(
                    staging_ptr,
                    to_read,
                    file_offset=offset,
                )
                if read_n <= 0:
                    break
                offset += read_n
                yield staging.raw[:read_n]
class TensorFlowCuFilePatcher:
    """
    Monkey-patch tf.keras.models.load_model for chunked streaming of large files.

    Large path-based loads are streamed into a spooled file before invoking the
    original TensorFlow loader. Small files and non-path inputs keep original behavior.
    """

    def __init__(
        self,
        tf_module: ModuleType,
        *,
        min_file_size_mb: int = 64,
        chunk_size_mb: int = 16,
        stream_reader: TensorFlowStreamReader | None = None,
        use_cufile: bool = False,
        fallback_to_original: bool = True,
    ) -> None:
        """Configure the patcher without touching ``tf_module`` yet.

        Args:
            tf_module: Module object whose ``keras.models.load_model`` attribute
                is replaced by :meth:`install`.
            min_file_size_mb: Files at least this large (in MiB, clamped to a
                minimum of 1) are streamed; smaller ones are not.
            chunk_size_mb: Chunk size (in MiB, clamped to a minimum of 1)
                requested from the stream reader.
            stream_reader: Reader to use. When ``None``, a reader is chosen
                from ``use_cufile``.
            use_cufile: Selects ``CuFileTFStreamReader`` instead of
                ``PythonTFStreamReader`` when no ``stream_reader`` is given.
            fallback_to_original: When true, any exception raised by the
                streaming path triggers a retry with the original loader.
        """
        self._tf = tf_module
        self._min_file_size = max(1, min_file_size_mb) * 1024 * 1024
        self._chunk_size = max(1, chunk_size_mb) * 1024 * 1024
        # Compare against None explicitly: a caller-supplied reader that
        # happens to be falsy (defines __len__/__bool__) must still be used.
        if stream_reader is None:
            stream_reader = (
                CuFileTFStreamReader() if use_cufile else PythonTFStreamReader()
            )
        self._reader = stream_reader
        self._fallback_to_original = fallback_to_original
        # Holds the replaced loader while installed; None means "not installed".
        self._original_load_model = None

    @property
    def installed(self) -> bool:
        """Whether the patch is currently installed by this instance."""
        return self._original_load_model is not None

    def install(self) -> None:
        """Replace ``keras.models.load_model`` with the streaming wrapper.

        Calling this while already installed is a no-op.
        """
        if self.installed:
            return
        self._original_load_model = self._tf.keras.models.load_model
        self._tf.keras.models.load_model = self._patched_load_model

    def uninstall(self) -> None:
        """Restore the loader saved by :meth:`install`.

        Calling this while not installed is a no-op.
        """
        if not self.installed:
            return
        self._tf.keras.models.load_model = self._original_load_model
        self._original_load_model = None

    def _should_stream(self, source: Any) -> bool:
        """Return True when ``source`` is a path to a regular file that is large enough."""
        # pathlib.Path is itself an os.PathLike, so it needs no separate entry.
        if not isinstance(source, (str, os.PathLike)):
            return False
        source_path = os.fspath(source)
        try:
            return (
                os.path.isfile(source_path)
                and os.path.getsize(source_path) >= self._min_file_size
            )
        except OSError:
            # The file vanished (or became unreadable) between the two checks;
            # let the original loader deal with the path and report the error.
            return False

    def _patched_load_model(self, source: Any, *args: Any, **kwargs: Any):
        """Replacement for ``load_model``: stream large files, delegate otherwise."""
        if not self.installed:
            # A reference to this wrapper outlived uninstall(). Behave as a
            # plain pass-through to whatever loader the module exposes now
            # instead of calling the cleared (None) original.
            return self._tf.keras.models.load_model(source, *args, **kwargs)

        if not self._should_stream(source):
            return self._original_load_model(source, *args, **kwargs)

        source_path = os.fspath(source)
        try:
            return self._streaming_load_model(source_path, *args, **kwargs)
        except Exception:
            # Deliberate best-effort: any streaming failure falls back to the
            # untouched loader unless the caller opted out of the fallback.
            if not self._fallback_to_original:
                raise
            return self._original_load_model(source, *args, **kwargs)

    def _streaming_load_model(self, file_path: str, *args: Any, **kwargs: Any):
        """Copy ``file_path`` chunk by chunk into a spooled file and load from it.

        The spooled file is kept in memory up to twice the chunk size and is
        closed when the original loader returns or raises.
        """
        with SpooledTemporaryFile(max_size=self._chunk_size * 2) as spooled:
            for chunk in self._reader.iter_chunks(file_path, self._chunk_size):
                spooled.write(chunk)
            # Rewind so the loader reads from the beginning.
            spooled.seek(0)
            return self._original_load_model(spooled, *args, **kwargs)
def patch_tensorflow_load_model(
    tf_module: ModuleType,
    *,
    min_file_size_mb: int = 64,
    chunk_size_mb: int = 16,
    stream_reader: TensorFlowStreamReader | None = None,
    use_cufile: bool = False,
    fallback_to_original: bool = True,
) -> TensorFlowCuFilePatcher:
    """Build a ``TensorFlowCuFilePatcher`` for ``tf_module``, install it, and return it.

    All keyword options are forwarded unchanged to the patcher's constructor.
    The returned patcher is already installed.
    """
    options = {
        "min_file_size_mb": min_file_size_mb,
        "chunk_size_mb": chunk_size_mb,
        "stream_reader": stream_reader,
        "use_cufile": use_cufile,
        "fallback_to_original": fallback_to_original,
    }
    active_patcher = TensorFlowCuFilePatcher(tf_module, **options)
    active_patcher.install()
    return active_patcher