Files
3d_audio/lib_audio_dsp/test/pipeline/python/audio_helpers.py
Steven Dan d8b2974133 init
2025-12-11 09:43:42 +08:00

126 lines
4.4 KiB
Python

# Copyright 2024-2025 XMOS LIMITED.
# This Software is subject to the terms of the XMOS Public Licence: Version 1.
"""
Helper functions for creating and validating audio files
"""
import numpy as np
import scipy
import matplotlib.pyplot as plt
from pathlib import Path
from audio_dsp.dsp import generic
def read_wav(path):
    """
    Read a wav file and always hand back a 2-D sample array.

    Parameters
    ----------
    path
        Location of the wav file, passed straight to
        ``scipy.io.wavfile.read``.

    Returns
    -------
    tuple
        ``(rate, data)`` where ``rate`` is the sample rate reported by
        scipy and ``data`` has one row per sample and one column per
        channel. Single-channel files, which scipy returns as a 1-D
        array, gain a trailing axis of length 1.
    """
    sample_rate, samples = scipy.io.wavfile.read(path)
    # Mono data arrives 1-D; add a channel axis so callers can index [:, ch].
    if samples.ndim != 1:
        return sample_rate, samples
    return sample_rate, samples[:, np.newaxis]
def write_wav(path, fs, data):
    """
    Write ``data`` to ``path`` as a wav file with sample rate ``fs``.

    This is a thin wrapper: the three arguments are forwarded unchanged to
    ``scipy.io.wavfile.write`` and whatever that call returns is returned.
    """
    result = scipy.io.wavfile.write(path, fs, data)
    return result
def read_and_truncate(path, f_bits=generic.Q_SIG):
    """
    Read a 32-bit wav file and clear the low-order bits of every sample.

    A mask with the ``31 - f_bits`` least significant bits cleared is built
    and AND-ed with the samples. The mask is also printed.

    Parameters
    ----------
    path
        Location of the wav file, loaded through ``read_wav``.
    f_bits
        Number of fractional bits to keep. Defaults to ``generic.Q_SIG``
        from ``audio_dsp.dsp``.

    Returns
    -------
    tuple
        ``(rate, data)`` with ``data`` being the masked int32 samples.

    Raises
    ------
    TypeError
        If the samples read from the file are not int32.
    """
    sample_rate, samples = read_wav(path)
    if samples.dtype == np.int32:
        # Bits below the kept fractional precision; inverting gives the keep-mask.
        dropped_bits = int(2 ** (31 - f_bits) - 1)
        keep_mask = ~dropped_bits
        print("mask ", keep_mask)
        return sample_rate, samples & keep_mask
    raise TypeError(f"wav data type {samples.dtype} not supported")
def _as_int32_samples(data, label):
    """Return wav samples as a 2-D (samples, channels) int32 array, scaling int16 up to 32-bit range."""
    if data.ndim == 1:
        # Mono files are read as 1-D; give them an explicit channel axis.
        data = data.reshape(len(data), 1)
    if data.dtype == np.int32:
        return data
    if data.dtype == np.int16:
        # Shift 16-bit samples into the top half of the 32-bit word so that
        # 16-bit and 32-bit files are compared on the same scale.
        return data.astype(np.int32) * np.int32(2**16)
    raise AssertionError(f"Unsupported {label}.dtype {data.dtype}")


def _aligned_window(skip_samples_start, data_size, delay, len_in, len_out):
    """Return (in_start, out_start, length) of the region where input[n] and output[n + delay] both exist."""
    in_start = skip_samples_start
    out_start = skip_samples_start + delay
    if out_start < 0:
        # Output leads input by more than the skipped prefix. A negative
        # start would wrap around to the end of the array, so advance both
        # start points until the output index reaches 0.
        in_start -= out_start
        out_start = 0
    length = min(
        data_size - delay - in_start,
        data_size - out_start,
        len_in - in_start,
        len_out - out_start,
    )
    return int(in_start), int(out_start), max(int(length), 0)


def correlate_and_diff(
    output_file,
    input_file,
    out_ch_start_end,
    in_ch_start_end,
    skip_seconds_start,
    skip_seconds_end,
    tol,
    corr_plot_file=None,
    verbose=False,
):
    """
    Align an output wav file to an input wav file and compare the samples.

    The delay between the two files is estimated by cross-correlating the
    first selected channel of each file over at most 64000 samples. Every
    selected channel is then compared sample by sample with that delay
    applied, i.e. ``input[n]`` against ``output[n + delay]``.

    Parameters
    ----------
    output_file, input_file
        Paths of the wav files to compare. Both must have the same sample
        rate and hold int16 or int32 samples; int16 samples are scaled by
        ``2**16`` before comparison.
    out_ch_start_end, in_ch_start_end
        ``(first, last)`` channel indices, both inclusive, selecting the
        channels of the output and input file to compare. Both ranges
        must span the same number of channels.
    skip_seconds_start
        Seconds to ignore at the start of both files.
    skip_seconds_end
        Seconds to ignore at the end of the shorter file.
    tol
        Absolute tolerance passed to ``np.isclose`` as ``atol``.
    corr_plot_file
        If not None, the correlation is plotted and saved to this path.
    verbose
        If True, print the indices of mismatching samples and the first
        100 mismatching input/output values per channel.

    Returns
    -------
    tuple
        ``(all_close, max_diff, delay)``: whether every compared sample of
        every channel was within tolerance, the largest absolute
        difference over all channels, and the estimated delay in samples.

    Raises
    ------
    AssertionError
        If the sample rates differ, a file has an unsupported sample
        type, or the two channel ranges differ in size.
    ValueError
        If no samples are left to compare once the skips and the delay
        have been applied.
    """
    rate_out, data_out = scipy.io.wavfile.read(output_file)
    rate_in, data_in = scipy.io.wavfile.read(input_file)

    # Raise explicitly rather than `assert` so the checks survive `python -O`.
    if rate_out != rate_in:
        raise AssertionError(
            f"input and output file rates are not equal. input rate {rate_in}, output rate {rate_out}"
        )

    data_in = _as_int32_samples(data_in, "data_in")
    data_out = _as_int32_samples(data_out, "data_out")

    if out_ch_start_end[1] - out_ch_start_end[0] != in_ch_start_end[1] - in_ch_start_end[0]:
        raise AssertionError("input and output files have different channel nos.")

    skip_samples_start = int(rate_out * skip_seconds_start)
    skip_samples_end = int(rate_out * skip_seconds_end)

    data_in = data_in[:, in_ch_start_end[0] : in_ch_start_end[1] + 1]
    data_out = data_out[:, out_ch_start_end[0] : out_ch_start_end[1] + 1]

    # Estimate the delay from the first selected channel on a bounded window
    # to keep the correlation cheap.
    small_len = min(len(data_in), len(data_out), 64000)
    window_stop = small_len + skip_samples_start
    data_in_small = data_in[skip_samples_start:window_stop, 0].astype(np.float64)
    data_out_small = data_out[skip_samples_start:window_stop, 0].astype(np.float64)
    corr = scipy.signal.correlate(data_in_small, data_out_small, "full")
    # In "full" mode zero lag sits at index len(second input) - 1. This equals
    # corr.shape[0] // 2 only when both windows have the same length, which is
    # not guaranteed when one file is shorter than the other.
    delay = (data_out_small.shape[0] - 1) - np.argmax(corr)
    print(f"delay = {delay}")

    if corr_plot_file is not None:
        plt.plot(corr)
        plt.savefig(corr_plot_file)
        plt.clf()

    data_size = min(data_in.shape[0], data_out.shape[0])
    data_size -= skip_samples_end

    in_start, out_start, length = _aligned_window(
        skip_samples_start, data_size, delay, data_in.shape[0], data_out.shape[0]
    )
    print(f"compare {length} samples")
    print(data_in.shape)
    print(data_out.shape)
    print(delay)
    if length == 0:
        raise ValueError("no samples left to compare after applying the skips and the delay")

    num_channels = out_ch_start_end[1] - out_ch_start_end[0] + 1
    all_close = True
    max_diff = []
    for ch in range(num_channels):
        print(f"comparing ch {ch}")
        in_seg = data_in[in_start : in_start + length, ch]
        out_seg = data_out[out_start : out_start + length, ch]
        # NOTE: only atol is given, so np.isclose's default relative
        # tolerance (rtol) is still added on top of `tol`.
        close = np.isclose(in_seg, out_seg, atol=tol)
        print(f"ch {ch}, close = {np.all(close)}")
        if verbose:
            mismatch_idxs = np.argwhere(~close)
            print("shape = ", mismatch_idxs.shape)
            print(mismatch_idxs)
            # Print first 100 values that were not close
            for i in mismatch_idxs[:100, 0]:
                print(i, data_in[in_start + i, ch], data_out[out_start + i, ch])
        # Subtract in int64: the difference of two int32 samples can need 33 bits.
        diff = np.abs(in_seg.astype(np.int64) - out_seg.astype(np.int64))
        max_diff.append(np.amax(diff))
        print(f"max diff value is {max_diff[-1]}")
        all_close = all_close & np.all(close)

    print(f"all_close: {np.all(all_close)}")
    return all_close, max(max_diff), delay