Skip to main content

rustriff_lib/infrastructure/file_loader/
file_loader.rs

1use crate::infrastructure::file_loader::file_loader_trait::FileLoaderTrait;
2use hound::{SampleFormat, WavReader};
3use std::fs;
4use std::io::Cursor;
5use std::path::Path;
6use tracing::{info, warn};
7
/// Filesystem-backed implementation of [`FileLoaderTrait`].
///
/// Uses [`std::fs`] for directory and file operations, and the [`hound`] crate
/// for WAV decoding and validation.
pub struct FileLoader;

impl FileLoader {
    /// Creates a new `FileLoader`.
    ///
    /// The loader is a unit struct and carries no state.
    pub fn new() -> Self {
        Self
    }

    /// Downmixes a multi-channel interleaved buffer to mono by averaging channels.
    ///
    /// * `channels == 1` — the buffer is already mono and is returned unchanged.
    /// * `channels == 0` — there is no frame layout to average over, so the
    ///   buffer is returned unchanged instead of dividing by zero.
    /// * `channels >= 2` — samples are grouped into frames of `channels`
    ///   samples and each frame is averaged into one mono sample. A trailing
    ///   partial frame (fewer than `channels` samples) is dropped.
    ///
    /// # Example
    /// ```text
    /// channels = 2, buffer = [L0, R0, L1, R1, L2, R2]
    /// → [(L0 + R0)/2, (L1 + R1)/2, (L2 + R2)/2]
    /// ```
    fn downmix_to_mono(buffer: Vec<f32>, channels: u16) -> Vec<f32> {
        // `<= 1` also covers zero channels, which would otherwise panic on the
        // integer division used to derive the frame count.
        if channels <= 1 {
            return buffer;
        }

        let frame_len = usize::from(channels);
        let divisor = frame_len as f32;

        // `chunks_exact` yields only complete frames, matching the
        // `len / channels` truncation of an index-based loop.
        buffer
            .chunks_exact(frame_len)
            .map(|frame| frame.iter().sum::<f32>() / divisor)
            .collect()
    }
}
50
51impl Default for FileLoader {
52    fn default() -> Self {
53        Self::new()
54    }
55}
56
57impl FileLoaderTrait for FileLoader {
58    fn read_wav_sample_rate(&self, path: &Path) -> Option<u32> {
59        WavReader::open(path)
60            .ok()
61            .map(|reader| reader.spec().sample_rate)
62    }
63
64    fn read_wav_to_buffer(&self, path: &Path) -> Vec<f32> {
65        match WavReader::open(path) {
66            Ok(mut reader) => {
67                let spec = reader.spec();
68                match spec.sample_format {
69                    SampleFormat::Float => {
70                        match reader.samples::<f32>().collect::<Result<Vec<_>, _>>() {
71                            Ok(buffer) => {
72                                let mono = Self::downmix_to_mono(buffer, spec.channels);
73                                info!(
74                                    "Loaded IR '{}' (channels={}, sample_rate={}, mono_samples={})",
75                                    path.display(),
76                                    spec.channels,
77                                    spec.sample_rate,
78                                    mono.len()
79                                );
80                                mono
81                            }
82                            Err(e) => {
83                                warn!(
84                                    "Failed to read float samples from '{}': {e}",
85                                    path.display()
86                                );
87                                Vec::new()
88                            }
89                        }
90                    }
91                    SampleFormat::Int => {
92                        let max = ((1_i64 << (spec.bits_per_sample.saturating_sub(1))) - 1) as f32;
93                        match reader
94                            .samples::<i32>()
95                            .map(|sample| sample.map(|value| value as f32 / max.max(1.0)))
96                            .collect::<Result<Vec<_>, _>>()
97                        {
98                            Ok(buffer) => {
99                                let mono = Self::downmix_to_mono(buffer, spec.channels);
100                                info!(
101                                    "Loaded IR '{}' (channels={}, sample_rate={}, mono_samples={})",
102                                    path.display(),
103                                    spec.channels,
104                                    spec.sample_rate,
105                                    mono.len()
106                                );
107                                mono
108                            }
109                            Err(e) => {
110                                warn!("Failed to read int samples from '{}': {e}", path.display());
111                                Vec::new()
112                            }
113                        }
114                    }
115                }
116            }
117            Err(e) => {
118                warn!("Failed to open IR file '{}': {e}", path.display());
119                Vec::new()
120            }
121        }
122    }
123
124    fn list_ir_profile_file_names(&self, directory: &Path) -> Result<Vec<String>, String> {
125        let entries = fs::read_dir(directory)
126            .map_err(|e| format!("Failed to read directory '{}': {e}", directory.display()))?;
127
128        let mut names: Vec<String> = entries
129            .filter_map(|entry| entry.ok())
130            .filter_map(|entry| {
131                let path = entry.path();
132                if !path.is_file() {
133                    return None;
134                }
135
136                if path
137                    .extension()
138                    .and_then(|ext| ext.to_str())
139                    .map(|ext| ext.eq_ignore_ascii_case("wav"))
140                    != Some(true)
141                {
142                    return None;
143                }
144
145                path.file_name()
146                    .and_then(|name| name.to_str())
147                    .map(|name| name.to_string())
148            })
149            .collect();
150
151        names.sort();
152        Ok(names)
153    }
154
155    fn ensure_directory(&self, directory: &Path) -> Result<(), String> {
156        fs::create_dir_all(directory)
157            .map_err(|e| format!("Failed to create directory '{}': {e}", directory.display()))
158    }
159
160    fn write_file_bytes(&self, path: &Path, bytes: &[u8]) -> Result<(), String> {
161        fs::write(path, bytes)
162            .map_err(|e| format!("Failed to write file '{}': {e}", path.display()))
163    }
164
165    fn remove_file(&self, path: &Path) -> Result<(), String> {
166        fs::remove_file(path)
167            .map_err(|e| format!("Failed to remove file '{}': {e}", path.display()))
168    }
169
170    fn validate_ir_wav_bytes(
171        &self,
172        file_name: &str,
173        file_bytes: &[u8],
174        impulse_threshold: f32,
175    ) -> Result<(), String> {
176        if !file_name.to_ascii_lowercase().ends_with(".wav") {
177            return Err("Only .wav IR files are supported".to_string());
178        }
179
180        let mut reader = WavReader::new(Cursor::new(file_bytes)).map_err(|e| {
181            let raw = e.to_string();
182            if raw.contains("unexpected fmt chunk size") {
183                format!(
184                    "Unsupported WAV format for '{}': {}. Re-export as PCM 16/24-bit or IEEE float 32-bit WAV.",
185                    file_name, raw
186                )
187            } else {
188                format!("Invalid WAV file '{}': {raw}", file_name)
189            }
190        })?;
191
192        let spec = reader.spec();
193
194        const IMPULSE_SEARCH_WINDOW_SAMPLES: usize = 256;
195
196        let max_abs_in_window = match spec.sample_format {
197            SampleFormat::Float => {
198                let mut iter = reader.samples::<f32>();
199                let first = iter
200                    .next()
201                    .ok_or_else(|| "IR file is empty".to_string())
202                    .and_then(|s| s.map_err(|e| format!("Failed to read first sample: {e}")))?;
203
204                let mut max_abs = first.abs();
205                for sample in iter.take(IMPULSE_SEARCH_WINDOW_SAMPLES.saturating_sub(1)) {
206                    let value = sample.map_err(|e| format!("Failed to read WAV samples: {e}"))?;
207                    max_abs = max_abs.max(value.abs());
208                }
209
210                max_abs
211            }
212            SampleFormat::Int => {
213                let max = ((1_i64 << (spec.bits_per_sample.saturating_sub(1))) - 1) as f32;
214                let mut iter = reader.samples::<i32>();
215                let first = iter
216                    .next()
217                    .ok_or_else(|| "IR file is empty".to_string())
218                    .and_then(|s| s.map_err(|e| format!("Failed to read first sample: {e}")))?;
219
220                let mut max_abs = (first as f32 / max.max(1.0)).abs();
221                for sample in iter.take(IMPULSE_SEARCH_WINDOW_SAMPLES.saturating_sub(1)) {
222                    let value = sample.map_err(|e| format!("Failed to read WAV samples: {e}"))?;
223                    max_abs = max_abs.max((value as f32 / max.max(1.0)).abs());
224                }
225
226                max_abs
227            }
228        };
229
230        if max_abs_in_window <= impulse_threshold {
231            return Err(
232                "Invalid IR: no impulse detected at file start (first 256 samples are effectively silent)"
233                    .to_string(),
234            );
235        }
236
237        Ok(())
238    }
239}
240
241#[cfg(test)]
242mod tests {
243    use super::*;
244    use hound::{SampleFormat, WavSpec, WavWriter};
245    use std::path::{Path, PathBuf};
246    use std::time::{SystemTime, UNIX_EPOCH};
247
248    fn unique_test_dir() -> PathBuf {
249        let nanos = SystemTime::now()
250            .duration_since(UNIX_EPOCH)
251            .expect("time should be monotonic")
252            .as_nanos();
253        std::env::temp_dir().join(format!("rustriff-file-loader-{nanos}"))
254    }
255
256    fn write_float_wav_file(path: &Path, samples: &[f32], sample_rate: u32) {
257        let spec = WavSpec {
258            channels: 1,
259            sample_rate,
260            bits_per_sample: 32,
261            sample_format: SampleFormat::Float,
262        };
263
264        let mut writer = WavWriter::create(path, spec).expect("wav file should be creatable");
265        for sample in samples {
266            writer
267                .write_sample(*sample)
268                .expect("sample should be writable");
269        }
270        writer.finalize().expect("wav writer should finalize");
271    }
272
273    fn write_stereo_wav_file(path: &Path, samples: &[f32], sample_rate: u32) {
274        let spec = WavSpec {
275            channels: 2,
276            sample_rate,
277            bits_per_sample: 32,
278            sample_format: SampleFormat::Float,
279        };
280
281        let mut writer =
282            WavWriter::create(path, spec).expect("stereo wav file should be creatable");
283        for sample in samples {
284            writer
285                .write_sample(*sample)
286                .expect("sample should be writable");
287        }
288        writer.finalize().expect("wav writer should finalize");
289    }
290
291    fn float_wav_bytes(samples: &[f32]) -> Vec<u8> {
292        let dir = unique_test_dir();
293        fs::create_dir_all(&dir).expect("test directory should be creatable");
294        let path = dir.join("buffer.wav");
295        write_float_wav_file(&path, samples, 48_000);
296        let bytes = fs::read(&path).expect("generated wav should be readable");
297        let _ = fs::remove_dir_all(dir);
298        bytes
299    }
300
301    #[cfg(test)]
302    mod success_path {
303        use super::*;
304
305        #[test]
306        fn list_ir_profile_file_names_returns_sorted_wav_files_only() {
307            let loader = FileLoader::new();
308            let dir = unique_test_dir();
309            fs::create_dir_all(dir.join("nested")).expect("test directory should be creatable");
310
311            write_float_wav_file(&dir.join("z-room.wav"), &[0.5, 0.0], 48_000);
312            write_float_wav_file(&dir.join("A-clean.WAV"), &[0.5, 0.0], 48_000);
313            // Subdirectory WAVs and non-WAV files should be silently ignored
314            write_float_wav_file(&dir.join("nested").join("ignored.wav"), &[0.5, 0.0], 48_000);
315            fs::write(dir.join("notes.txt"), b"not a wav").expect("text file should be writable");
316
317            let names = loader
318                .list_ir_profile_file_names(&dir)
319                .expect("listing IR profiles should succeed");
320
321            assert_eq!(
322                names,
323                vec!["A-clean.WAV".to_string(), "z-room.wav".to_string()]
324            );
325
326            let _ = fs::remove_dir_all(dir);
327        }
328
329        #[test]
330        fn read_wav_helpers_return_sample_rate_and_buffer_for_valid_ir() {
331            let loader = FileLoader::new();
332            let dir = unique_test_dir();
333            fs::create_dir_all(&dir).expect("test directory should be creatable");
334            let path = dir.join("valid-ir.wav");
335            let samples = [0.75_f32, -0.25_f32, 0.125_f32];
336            write_float_wav_file(&path, &samples, 44_100);
337
338            let sample_rate = loader.read_wav_sample_rate(&path);
339            let buffer = loader.read_wav_to_buffer(&path);
340
341            assert_eq!(sample_rate, Some(44_100));
342            assert_eq!(buffer.len(), samples.len());
343            assert!((buffer[0] - 0.75).abs() < 1e-6);
344            assert!((buffer[1] + 0.25).abs() < 1e-6);
345            assert!((buffer[2] - 0.125).abs() < 1e-6);
346
347            let _ = fs::remove_dir_all(dir);
348        }
349
350        #[test]
351        fn validate_ir_wav_bytes_accepts_valid_impulse_wav() {
352            let loader = FileLoader::new();
353            let bytes = float_wav_bytes(&[0.25, 0.0, 0.0, 0.0]);
354
355            loader
356                .validate_ir_wav_bytes("cab.wav", &bytes, 1e-6)
357                .expect("impulse IR should validate");
358        }
359
360        #[test]
361        fn read_wav_to_buffer_downmixes_stereo_to_mono() {
362            let loader = FileLoader::new();
363            let dir = unique_test_dir();
364            fs::create_dir_all(&dir).expect("test directory should be creatable");
365            let path = dir.join("stereo-ir.wav");
366
367            // Interleaved stereo: [L0, R0, L1, R1, L2, R2]
368            let stereo_samples = [0.8_f32, 0.4_f32, 0.6_f32, 0.2_f32, 1.0_f32, 0.0_f32];
369            write_stereo_wav_file(&path, &stereo_samples, 48_000);
370
371            let mono = loader.read_wav_to_buffer(&path);
372
373            // Expected downmix: [(0.8+0.4)/2, (0.6+0.2)/2, (1.0+0.0)/2] = [0.6, 0.4, 0.5]
374            assert_eq!(mono.len(), 3);
375            assert!((mono[0] - 0.6).abs() < 1e-6);
376            assert!((mono[1] - 0.4).abs() < 1e-6);
377            assert!((mono[2] - 0.5).abs() < 1e-6);
378
379            let _ = fs::remove_dir_all(dir);
380        }
381    }
382
383    #[cfg(test)]
384    mod failure_path {
385        use super::*;
386
387        #[test]
388        fn read_wav_sample_rate_returns_none_for_missing_file() {
389            let loader = FileLoader::new();
390            let missing = std::env::temp_dir().join("does-not-exist-rustriff.wav");
391            assert!(loader.read_wav_sample_rate(&missing).is_none());
392        }
393
394        #[test]
395        fn read_wav_to_buffer_returns_empty_for_missing_file() {
396            let loader = FileLoader::new();
397            let missing = std::env::temp_dir().join("does-not-exist-rustriff.wav");
398            assert!(loader.read_wav_to_buffer(&missing).is_empty());
399        }
400
401        #[test]
402        fn validate_ir_wav_bytes_rejects_non_wav_extension() {
403            let loader = FileLoader::new();
404            let bytes = float_wav_bytes(&[0.25, 0.0, 0.0, 0.0]);
405
406            let err = loader
407                .validate_ir_wav_bytes("cab.mp3", &bytes, 1e-6)
408                .expect_err("non-wav extension should be rejected");
409            assert!(err.contains("Only .wav IR files are supported"));
410        }
411
412        #[test]
413        fn validate_ir_wav_bytes_rejects_silent_file_start() {
414            let loader = FileLoader::new();
415            let bytes = float_wav_bytes(&[0.0; 32]);
416
417            let err = loader
418                .validate_ir_wav_bytes("silent.wav", &bytes, 1e-6)
419                .expect_err("silent IR should be rejected");
420            assert!(err.contains("no impulse detected"));
421        }
422    }
423}