rustriff_lib/infrastructure/audio_handler.rs

1use cpal::traits::{DeviceTrait, StreamTrait};
2use cpal::{Device, Stream, StreamConfig};
3use derive_getters::Getters;
4use mockall::automock;
5use ringbuf::consumer::Consumer;
6use ringbuf::producer::Producer;
7use ringbuf::traits::Split;
8use ringbuf::{HeapCons, HeapProd, HeapRb};
9use tracing::error;
10
/// Minimal, object-safe "can be started" abstraction.
///
/// CPAL's [`Stream`] implements it directly, and code that only needs to start
/// a stream can hold a `Box<dyn PlayableStream>`, so a test double can stand in
/// for a real hardware stream.
pub trait PlayableStream: Send {
    /// Begins playback (output streams) or capture (input streams).
    ///
    /// # Panics
    ///
    /// The [`Stream`] implementation panics when CPAL reports an error while
    /// starting the stream.
    fn play(&self);
}
21
22impl PlayableStream for Stream {
23    fn play(&self) {
24        StreamTrait::play(self).unwrap();
25    }
26}
27
28/// Abstraction over the audio I/O handler used by [`AudioService`].
29///
30/// Using the [`MockAudioHandlerTrait`] generated by `mockall` allows the audio pipeline to be tested without real hardware.
31///
32/// [`AudioService`]: crate::services::audio_service::AudioService
33#[automock]
34pub trait AudioHandlerTrait: Send + Sync {
35    /// Builds and returns a started-ready input stream that pushes captured
36    /// samples into `prod` (producer).
37    ///
38    /// # Arguments
39    ///
40    /// * `prod` - The ring-buffer producer that receives raw `f32` audio samples.
41    fn build_input_stream(&self, prod: HeapProd<f32>) -> Box<dyn PlayableStream>;
42
43    /// Builds and returns a started-ready output stream that drains samples
44    /// from `cons` (consumer) and sends them to the output device.
45    ///
46    /// Slots in the output buffer that have no corresponding sample are
47    /// filled with silence (`0.0`).
48    ///
49    /// # Arguments
50    ///
51    /// * `cons` - The ring-buffer consumer that supplies processed `f32` audio samples.
52    fn build_output_stream(&self, cons: HeapCons<f32>) -> Box<dyn PlayableStream>;
53
54    /// Returns a reference to the CPAL input device used by this handler.
55    fn input_device(&self) -> &Device;
56
57    /// Returns a reference to the CPAL output device used by this handler.
58    fn output_device(&self) -> &Device;
59
60    /// Returns the [`StreamConfig`] used for the input stream.
61    fn input_config(&self) -> &StreamConfig;
62
63    /// Returns the [`StreamConfig`] used for the output stream.
64    fn output_config(&self) -> &StreamConfig;
65
66    /// Returns the configured input sample rate in Hz.
67    fn input_sample_rate(&self) -> u32;
68
69    /// Returns the configured output sample rate in Hz.
70    fn output_sample_rate(&self) -> u32;
71}
72
73/// Concrete implementation of [`AudioHandlerTrait`] backed by real CPAL devices.
74///
75/// `AudioHandler` owns the input device, output device, and stream configuration
76/// required to build CPAL streams. It is cheaply cloneable so that it can be
77/// shared across threads via [`Arc`].
78///
79/// [`Arc`]: std::sync::Arc
80#[derive(Clone, Getters)]
81pub struct AudioHandler {
82    input_device: Device,
83    output_device: Device,
84    input_config: StreamConfig,
85    output_config: StreamConfig,
86    input_sample_rate: u32,
87    output_sample_rate: u32,
88}
89
90impl AudioHandler {
91    /// Creates a new `AudioHandler` with the given CPAL devices and stream config.
92    ///
93    /// # Arguments
94    ///
95    /// * `input_device` - The CPAL device used to capture audio.
96    /// * `output_device` - The CPAL device used to play back audio.
97    /// * `input_config` - The [`StreamConfig`] used for the input stream.
98    /// * `output_config` - The [`StreamConfig`] used for the output stream.
99    pub fn new(
100        input_device: Device,
101        output_device: Device,
102        input_config: StreamConfig,
103        output_config: StreamConfig,
104    ) -> Self {
105        let input_sample_rate = input_config.sample_rate;
106        let output_sample_rate = output_config.sample_rate;
107
108        Self {
109            input_device,
110            output_device,
111            input_config,
112            output_config,
113            input_sample_rate,
114            output_sample_rate,
115        }
116    }
117
118    /// Creates a lock-free ring buffer of `f32` samples with the given capacity.
119    ///
120    /// Returns a `(producer, consumer)` pair that can be moved into separate
121    /// threads for wait-free, single-producer/single-consumer audio transfer.
122    ///
123    /// # Arguments
124    ///
125    /// * `size` - The number of `f32` samples the ring buffer can hold.
126    pub fn create_ringbuffer(size: usize) -> (HeapProd<f32>, HeapCons<f32>) {
127        let rb = HeapRb::<f32>::new(size);
128        rb.split()
129    }
130
131    /// Replaces the current output device.
132    ///
133    /// # Arguments
134    ///
135    /// * `output_device` - The new CPAL output device.
136    pub fn set_output_device(&mut self, output_device: Device) {
137        self.output_device = output_device;
138    }
139
140    /// Replaces the current input device.
141    ///
142    /// # Arguments
143    ///
144    /// * `input_device` - The new CPAL input device.
145    pub fn set_input_device(&mut self, input_device: Device) {
146        self.input_device = input_device;
147    }
148}
149
150impl AudioHandlerTrait for AudioHandler {
151    /// Builds a CPAL input stream that forwards every captured sample into the
152    /// provided ring-buffer producer.
153    ///
154    /// Samples that cannot be pushed (i.e. the ring buffer is full) are silently
155    /// dropped. Input errors are added to logs as error.
156    ///
157    /// # Panics
158    ///
159    /// Panics if CPAL fails to build the input stream.
160    fn build_input_stream(&self, mut producer: HeapProd<f32>) -> Box<dyn PlayableStream> {
161        let stream = self
162            .input_device
163            .build_input_stream(
164                &self.input_config,
165                move |data: &[f32], _| {
166                    for &s in data {
167                        let _ = producer.try_push(s);
168                    }
169                },
170                move |err| error!("Input error: {:?}", err),
171                None,
172            )
173            .unwrap();
174        Box::new(stream)
175    }
176
177    /// Builds a CPAL output stream that drains samples from the provided
178    /// ring-buffer consumer into the hardware output buffer.
179    ///
180    /// Any output slot that has no corresponding sample is filled with `0.0`
181    /// (silence). Output errors are printed to stderr.
182    ///
183    /// # Panics
184    ///
185    /// Panics if CPAL fails to build the output stream.
186    fn build_output_stream(&self, mut consumer: HeapCons<f32>) -> Box<dyn PlayableStream> {
187        let stream = self
188            .output_device
189            .build_output_stream(
190                &self.output_config,
191                move |out: &mut [f32], _| {
192                    //println!("Output buffer: {:?}", &out[..10.min(out.len())]);
193                    for o in out.iter_mut() {
194                        *o = consumer.try_pop().unwrap_or(0.0);
195                    }
196                },
197                move |err| eprintln!("Output error: {:?}", err),
198                None,
199            )
200            .unwrap();
201        Box::new(stream)
202    }
203
204    fn input_device(&self) -> &Device {
205        &self.input_device
206    }
207
208    fn output_device(&self) -> &Device {
209        &self.output_device
210    }
211
212    fn input_config(&self) -> &StreamConfig {
213        &self.input_config
214    }
215
216    fn output_config(&self) -> &StreamConfig {
217        &self.output_config
218    }
219
220    fn input_sample_rate(&self) -> u32 {
221        self.input_sample_rate
222    }
223
224    fn output_sample_rate(&self) -> u32 {
225        self.output_sample_rate
226    }
227}
228
#[cfg(test)]
mod tests {
    use super::*;

    // NOTE(review): building real CPAL streams needs audio hardware, so these
    // tests reproduce the stream callbacks' ring-buffer handling inline. They
    // verify the ring-buffer contract the callbacks rely on, not the
    // callbacks themselves.

    /// Normal flow: the ring buffer neither overflows nor runs dry.
    mod success_path {
        use super::*;

        #[test]
        fn test_input_callback_pushes_samples() {
            let (mut prod, mut cons) = AudioHandler::create_ringbuffer(16);
            let input_data = [0.1_f32, 0.2, 0.3, 0.4];

            for &sample in &input_data {
                assert!(
                    prod.try_push(sample).is_ok(),
                    "ring buffer unexpectedly full while pushing {sample}"
                );
            }

            for expected in input_data {
                assert_eq!(cons.try_pop(), Some(expected));
            }
            assert_eq!(cons.try_pop(), None);
        }

        #[test]
        fn test_output_callback_reads_samples() {
            let (mut prod, mut cons) = AudioHandler::create_ringbuffer(16);
            for sample in [10.0_f32, 20.0, 30.0] {
                prod.try_push(sample).unwrap();
            }

            let mut out = [0.0_f32; 5];
            for slot in out.iter_mut() {
                *slot = cons.try_pop().unwrap_or(0.0);
            }

            assert_eq!(out, [10.0, 20.0, 30.0, 0.0, 0.0]);
        }
    }

    /// Edge cases: overflow on the input side, underrun on the output side.
    mod failure_path {
        use super::*;

        #[test]
        fn test_input_callback_drops_when_full() {
            let (mut prod, mut cons) = AudioHandler::create_ringbuffer(3);
            let input_data = [1.0_f32, 2.0, 3.0, 4.0, 5.0];

            let mut accepted = 0;
            for &sample in &input_data {
                if prod.try_push(sample).is_ok() {
                    accepted += 1;
                }
            }
            assert_eq!(accepted, 3, "only `capacity` samples should be accepted");

            // The oldest samples survive; the overflow is dropped.
            assert_eq!(cons.try_pop(), Some(1.0));
            assert_eq!(cons.try_pop(), Some(2.0));
            assert_eq!(cons.try_pop(), Some(3.0));
            assert_eq!(cons.try_pop(), None);
        }

        #[test]
        fn test_output_callback_zero_fills_when_empty() {
            let (_prod, mut cons) = AudioHandler::create_ringbuffer(8);

            // Pre-filled with non-zero values so the zero-fill is observable.
            let mut out = [1.0_f32; 4];
            for slot in out.iter_mut() {
                *slot = cons.try_pop().unwrap_or(0.0);
            }

            assert_eq!(out, [0.0; 4]);
        }
    }
}