Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
140 changes: 140 additions & 0 deletions src/hpc/audio/bands.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
//! Opus CELT band energy computation.
//!
//! 21 quasi-Bark critical bands at 48kHz. Each band's energy is the
//! gain component of gain-shape quantization. The normalized coefficients
//! (after dividing by band energy) are the shape component → PVQ.
//!
//! Band boundaries from Opus `celt/modes.c` eBands48.

/// Opus CELT band boundaries at 48kHz, 960-sample frames (480 MDCT bins).
/// 22 boundaries define 21 bands. Bin index = frequency / (48000 / 960).
/// Band 0: bins 0-3 (~0-200 Hz), Band 20: bins 400-480 (~20-24 kHz).
pub const CELT_BANDS_48K: [usize; 22] = [
0, 4, 8, 12, 16, 20, 24, 28, 32, 36, 44, 52, 60, 68, 80, 96,
112, 136, 160, 200, 256, 480,
];

/// Number of critical bands.
pub const N_BANDS: usize = 21;

/// Compute band energies from MDCT coefficients.
///
/// Returns 21 f32 energies (sqrt of sum-of-squares per band).
/// These are the "gain" in gain-shape quantization.
pub fn band_energies(coeffs: &[f32]) -> [f32; N_BANDS] {
let mut energies = [0.0f32; N_BANDS];
for band in 0..N_BANDS {
let lo = CELT_BANDS_48K[band];
let hi = CELT_BANDS_48K[band + 1].min(coeffs.len());
let mut sum_sq = 0.0f32;
for i in lo..hi {
if i < coeffs.len() {
sum_sq += coeffs[i] * coeffs[i];
}
}
energies[band] = sum_sq.sqrt();
}
energies
}

/// Normalize MDCT coefficients by band energy (produce unit-energy shape).
///
/// After normalization, each band has unit energy. The shape encodes
/// the spectral tilt within the band. PVQ quantizes this shape.
pub fn normalize_bands(coeffs: &[f32], energies: &[f32; N_BANDS]) -> Vec<f32> {
let mut normalized = coeffs.to_vec();
for band in 0..N_BANDS {
let lo = CELT_BANDS_48K[band];
let hi = CELT_BANDS_48K[band + 1].min(normalized.len());
let e = energies[band].max(1e-10);
for i in lo..hi {
if i < normalized.len() {
normalized[i] /= e;
}
}
}
normalized
}

/// Denormalize: multiply shape coefficients by band energies.
///
/// Inverse of normalize_bands. Used in the decoder path:
/// PVQ-decoded shape × band energies → MDCT coefficients → iMDCT → PCM.
pub fn denormalize_bands(shape: &[f32], energies: &[f32; N_BANDS]) -> Vec<f32> {
let mut coeffs = shape.to_vec();
for band in 0..N_BANDS {
let lo = CELT_BANDS_48K[band];
let hi = CELT_BANDS_48K[band + 1].min(coeffs.len());
let e = energies[band];
for i in lo..hi {
if i < coeffs.len() {
coeffs[i] *= e;
}
}
}
coeffs
}

/// Pack band energies to BF16 (21 × 2 bytes = 42 bytes).
pub fn energies_to_bf16(energies: &[f32; N_BANDS]) -> [u16; N_BANDS] {
let mut bf16 = [0u16; N_BANDS];
for i in 0..N_BANDS {
let bits = energies[i].to_bits();
let lsb = (bits >> 16) & 1;
let biased = bits.wrapping_add(0x7FFF).wrapping_add(lsb);
bf16[i] = (biased >> 16) as u16;
}
bf16
}

/// Unpack BF16 band energies to f32.
pub fn bf16_to_energies(bf16: &[u16; N_BANDS]) -> [f32; N_BANDS] {
let mut energies = [0.0f32; N_BANDS];
for i in 0..N_BANDS {
energies[i] = f32::from_bits((bf16[i] as u32) << 16);
}
energies
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn band_count() {
assert_eq!(CELT_BANDS_48K.len(), N_BANDS + 1);
}

#[test]
fn band_energies_nonzero() {
let coeffs: Vec<f32> = (0..480).map(|i| (i as f32 * 0.05).sin()).collect();
let e = band_energies(&coeffs);
let total: f32 = e.iter().sum();
assert!(total > 0.1, "Total band energy too low: {}", total);
}

#[test]
fn normalize_denormalize_roundtrip() {
let coeffs: Vec<f32> = (0..480).map(|i| (i as f32 * 0.1).sin() * 2.0).collect();
let e = band_energies(&coeffs);
let shape = normalize_bands(&coeffs, &e);
let recovered = denormalize_bands(&shape, &e);

for (orig, rec) in coeffs.iter().zip(recovered.iter()) {
assert!((orig - rec).abs() < 0.01,
"Roundtrip mismatch: {} vs {}", orig, rec);
}
}

#[test]
fn bf16_energy_roundtrip() {
let e = [1.0, 0.5, 2.0, 0.001, 100.0, 0.0, 0.0, 0.0, 0.0, 0.0,
0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0];
let bf16 = energies_to_bf16(&e);
let recovered = bf16_to_energies(&bf16);
for i in 0..5 {
let err = (e[i] - recovered[i]).abs() / e[i].max(1e-6);
assert!(err < 0.02, "BF16 roundtrip error for band {}: {:.4}", i, err);
}
}
}
152 changes: 152 additions & 0 deletions src/hpc/audio/codec.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
//! AudioFrame: 48-byte codec for one frame of audio.
//!
//! The complete encode/decode pipeline:
//! encode: PCM → MDCT → band energies (gain) + PVQ (shape) → AudioFrame
//! decode: AudioFrame → band energies × PVQ shape → iMDCT → PCM
//!
//! One AudioFrame = one graph node in lance-graph. 48 bytes = CAM-compatible.

use super::mdct;
use super::bands;
use super::pvq;

/// One audio frame: 42 bytes gain + 6 bytes shape = 48 bytes.
///
/// Maps to SPO:
/// Subject = spectral (WHAT frequencies) → band energies
/// Predicate = temporal (WHEN they happen) → PVQ summary bytes 2-3
/// Object = harmonic (HOW they ring) → PVQ summary bytes 4-5
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct AudioFrame {
/// 21 band energies as BF16 (42 bytes). The gain component.
pub band_energies: [u16; bands::N_BANDS],
/// PVQ shape fingerprint (6 bytes). HEEL/HIP/TWIG levels.
pub pvq_summary: [u8; 6],
}

impl AudioFrame {
/// Total byte size: 42 (energies) + 6 (pvq) = 48.
pub const BYTE_SIZE: usize = bands::N_BANDS * 2 + 6;

/// Encode one frame of PCM audio.
///
/// `pcm`: mono f32 samples (padded to power of 2 internally).
/// `pvq_k`: PVQ pulse budget per band (higher = better quality, more bits).
pub fn encode(pcm: &[f32], pvq_k: u32) -> Self {
// MDCT: time → frequency
let coeffs = mdct::mdct_forward(pcm);

// Band energies (gain)
let energies = bands::band_energies(&coeffs);
let bf16_energies = bands::energies_to_bf16(&energies);

// Normalize bands (remove gain, keep shape)
let shape = bands::normalize_bands(&coeffs, &energies);

// PVQ encode the shape of the first (most important) band
// For production: encode all 21 bands. For the POC: just first band's summary.
let first_band_end = bands::CELT_BANDS_48K[1].min(shape.len());
let pulses = pvq::pvq_encode(&shape[..first_band_end], pvq_k);
let summary = pvq::pvq_summary(&pulses);

AudioFrame {
band_energies: bf16_energies,
pvq_summary: summary,
}
}

/// Decode: reconstruct PCM from AudioFrame + optional full PVQ data.
///
/// Without PVQ data: uses band energies only (coarse reconstruction).
/// The PVQ summary gives the HHTL routing info, not the full shape.
/// For full quality: pass the per-band PVQ pulse vectors.
pub fn decode_coarse(&self) -> Vec<f32> {
let energies = bands::bf16_to_energies(&self.band_energies);

// Synthesize a simple spectral envelope from band energies
// Each band gets a flat spectrum at its energy level
let n2 = bands::CELT_BANDS_48K[bands::N_BANDS].min(480);
let mut coeffs = vec![0.0f32; n2];
for band in 0..bands::N_BANDS {
let lo = bands::CELT_BANDS_48K[band];
let hi = bands::CELT_BANDS_48K[band + 1].min(n2);
let n_bins = (hi - lo).max(1);
let per_bin = energies[band] / (n_bins as f32).sqrt();
for i in lo..hi {
// Alternate signs for a more natural-sounding shape
let sign = if (i - lo) % 2 == 0 { 1.0 } else { -1.0 };
coeffs[i] = per_bin * sign;
}
}

// iMDCT: frequency → time
mdct::mdct_backward(&coeffs)
}

/// Serialize to 48 bytes.
pub fn to_bytes(&self) -> [u8; Self::BYTE_SIZE] {
let mut bytes = [0u8; Self::BYTE_SIZE];
for i in 0..bands::N_BANDS {
let b = self.band_energies[i].to_le_bytes();
bytes[i * 2] = b[0];
bytes[i * 2 + 1] = b[1];
}
bytes[42..48].copy_from_slice(&self.pvq_summary);
bytes
}

/// Deserialize from 48 bytes.
pub fn from_bytes(bytes: &[u8; Self::BYTE_SIZE]) -> Self {
let mut band_energies = [0u16; bands::N_BANDS];
for i in 0..bands::N_BANDS {
band_energies[i] = u16::from_le_bytes([bytes[i * 2], bytes[i * 2 + 1]]);
}
let mut pvq_summary = [0u8; 6];
pvq_summary.copy_from_slice(&bytes[42..48]);
AudioFrame { band_energies, pvq_summary }
}
}

#[cfg(test)]
mod tests {
use super::*;
use core::f32::consts::PI;

#[test]
fn frame_48_bytes() {
assert_eq!(AudioFrame::BYTE_SIZE, 48);
}

#[test]
fn encode_decode_nonzero() {
// 440Hz sine at 48kHz, 1024 samples
let pcm: Vec<f32> = (0..1024)
.map(|i| (2.0 * PI * 440.0 * i as f32 / 48000.0).sin())
.collect();

let frame = AudioFrame::encode(&pcm, 8);

// Band energies should be nonzero (at least the band containing 440Hz)
let total_energy: f32 = frame.band_energies.iter()
.map(|&b| f32::from_bits((b as u32) << 16))
.sum();
assert!(total_energy > 0.01, "Encoded frame has no energy: {}", total_energy);

// Decode
let decoded = frame.decode_coarse();
assert!(!decoded.is_empty());
let decoded_energy: f32 = decoded.iter().map(|s| s * s).sum();
assert!(decoded_energy > 1e-10, "Decoded has no energy: {}", decoded_energy);
}

#[test]
fn serialize_roundtrip() {
let frame = AudioFrame {
band_energies: [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21],
pvq_summary: [0xAA, 0xBB, 0xCC, 0xDD, 0xEE, 0xFF],
};
let bytes = frame.to_bytes();
let recovered = AudioFrame::from_bytes(&bytes);
assert_eq!(frame, recovered);
}
}
Loading
Loading