126 lines
4.4 KiB
Python
126 lines
4.4 KiB
Python
# Copyright 2024-2025 XMOS LIMITED.
|
|
# This Software is subject to the terms of the XMOS Public Licence: Version 1.
|
|
|
|
"""
|
|
Helper functions for creating and validating audio files
|
|
"""
|
|
|
|
import numpy as np
|
|
import scipy
|
|
import matplotlib.pyplot as plt
|
|
from pathlib import Path
|
|
|
|
from audio_dsp.dsp import generic
|
|
|
|
def read_wav(path):
    """Read a wav file and return ``(rate, samples)`` with 2D samples.

    ``path`` is forwarded unchanged to ``scipy.io.wavfile.read``. When the
    reader returns a 1D array, a trailing axis of length one is added so
    the result is always indexable as ``samples[frame, channel]``.
    """
    sample_rate, samples = scipy.io.wavfile.read(path)
    # 1D result: append a channel axis; anything else is passed through.
    samples_2d = samples[:, np.newaxis] if samples.ndim == 1 else samples
    return sample_rate, samples_2d
|
|
|
|
def write_wav(path, fs, data):
    """Write ``data`` to ``path`` as a wav file with sample rate ``fs``.

    Thin wrapper over ``scipy.io.wavfile.write``; its return value is
    passed straight back to the caller.
    """
    result = scipy.io.wavfile.write(filename=path, rate=fs, data=data)
    return result
|
|
|
|
def read_and_truncate(path, f_bits=generic.Q_SIG):
    """Read a wav file and zero the least significant bits of each sample.

    The file is loaded with ``read_wav`` and must contain int32 samples.
    A mask of ``~(2**(31 - f_bits) - 1)`` is AND-ed onto every sample,
    which clears the lowest ``31 - f_bits`` bits. The mask is also printed.

    Returns ``(rate, masked_samples)``.

    Raises ``TypeError`` when the samples are not int32.
    """
    rate, samples = read_wav(path)

    if samples.dtype != np.int32:
        raise TypeError(f"wav data type {samples.dtype} not supported")

    # Bits below the retained fractional precision, then inverted to a keep-mask.
    discarded_bits = 2**(31 - f_bits) - 1
    mask = ~int(discarded_bits)
    print("mask ", mask)
    return rate, np.bitwise_and(samples, mask)
|
|
|
|
def correlate_and_diff(output_file, input_file, out_ch_start_end, in_ch_start_end, skip_seconds_start, skip_seconds_end, tol, corr_plot_file=None, verbose=False):
    """Align an output wav against an input wav and compare them sample by sample.

    The delay between the two files is estimated from the peak of the full
    cross-correlation of the first selected channel of each file (using at
    most 64000 samples after the start skip). The selected channels are then
    compared with ``np.isclose`` after shifting the output by that delay.

    Parameters
    ----------
    output_file, input_file
        Wav files, passed to ``scipy.io.wavfile.read``. Samples must be
        int32 or int16. int16 input is widened to int32 and scaled by 2**16.
    out_ch_start_end, in_ch_start_end
        ``(first, last)`` channel indices, both inclusive, selecting the
        channels to compare in each file. Both ranges must span the same
        number of channels.
    skip_seconds_start, skip_seconds_end
        Seconds to ignore at the start and end of the comparison.
    tol
        Absolute tolerance forwarded as ``atol`` to ``np.isclose``. Note that
        ``np.isclose`` is called with its default ``rtol``, so a relative
        tolerance is applied on top of ``tol``.
    corr_plot_file
        If not None, the correlation is plotted and saved to this path.
    verbose
        If true, print the indices (and up to 100 values) of mismatches.

    Returns
    -------
    tuple
        ``(all_close, max_diff, delay)``: whether every compared channel was
        within tolerance, the largest absolute sample difference over all
        channels, and the estimated delay in samples.

    Raises
    ------
    AssertionError
        If the sample rates differ, a sample dtype is unsupported, or the
        channel ranges have different widths.
    """
    rate_out, data_out = scipy.io.wavfile.read(output_file)
    rate_in, data_in = scipy.io.wavfile.read(input_file)

    # Mono files are read as 1D; reshape so channel slicing below works.
    if data_out.ndim == 1:
        data_out = data_out.reshape(len(data_out), 1)
    if data_in.ndim == 1:
        data_in = data_in.reshape(len(data_in), 1)

    # Explicit raises (rather than `assert False`) so the checks survive `python -O`.
    if rate_out != rate_in:
        raise AssertionError(f"input and output file rates are not equal. input rate {rate_in}, output rate {rate_out}")

    if data_in.dtype != np.int32:
        if data_in.dtype != np.int16:
            raise AssertionError(f"Unsupported data_in.dtype {data_in.dtype}")
        data_in = np.array(data_in, dtype=np.int32)
        data_in = data_in * (2**16)

    if data_out.dtype != np.int32:
        if data_out.dtype != np.int16:
            raise AssertionError(f"Unsupported data_out.dtype {data_out.dtype}")
        # NOTE(review): unlike the input path, int16 output is widened but not
        # scaled by 2**16 - confirm this asymmetry is intended.
        data_out = np.array(data_out, dtype=np.int32)

    if out_ch_start_end[1] - out_ch_start_end[0] != in_ch_start_end[1] - in_ch_start_end[0]:
        raise AssertionError("input and output files have different channel nos.")

    skip_samples_start = int(rate_out * skip_seconds_start)
    skip_samples_end = int(rate_out * skip_seconds_end)
    data_in = data_in[:, in_ch_start_end[0]:in_ch_start_end[1] + 1]
    data_out = data_out[:, out_ch_start_end[0]:out_ch_start_end[1] + 1]

    # Estimate the delay on a bounded window to keep the correlation cheap.
    small_len = min(len(data_in), len(data_out), 64000)
    data_in_small = data_in[skip_samples_start : small_len + skip_samples_start, :].astype(np.float64)
    data_out_small = data_out[skip_samples_start : small_len + skip_samples_start, :].astype(np.float64)

    corr = scipy.signal.correlate(data_in_small[:, 0], data_out_small[:, 0], "full")
    # The centre of the "full" correlation is zero lag; offset from it is the delay.
    delay = (corr.shape[0] // 2) - np.argmax(corr)
    print(f"delay = {delay}")

    if corr_plot_file is not None:
        plt.plot(corr)
        plt.savefig(corr_plot_file)
        plt.clf()
    delay_orig = delay

    data_size = min(data_in.shape[0], data_out.shape[0])
    data_size -= skip_samples_end

    print(f"compare {data_size - skip_samples_start} samples")
    print(data_in.shape)
    print(data_out.shape)
    print(delay)

    num_channels = out_ch_start_end[1] - out_ch_start_end[0] + 1
    all_close = True
    max_diff = []
    for ch in range(num_channels):
        print(f"comparing ch {ch}")
        ref = data_in[skip_samples_start : data_size - delay, ch]
        dut = data_out[skip_samples_start + delay : data_size, ch]
        close = np.isclose(ref, dut, atol=tol)
        ch_close = np.all(close)
        print(f"ch {ch}, close = {ch_close}")

        if verbose:
            mismatch_idxs = np.argwhere(~close)
            print("shape = ", mismatch_idxs.shape)
            print(mismatch_idxs)
            if not ch_close:
                # NOTE(review): values are not printed when the first mismatch
                # is at index 0 - behaviour kept from the original; confirm intent.
                if mismatch_idxs[0] != 0:
                    # Print first 100 values that were not close
                    for i in mismatch_idxs[:100]:
                        print(i, data_in[skip_samples_start+i, ch], data_out[skip_samples_start + delay + i, ch])

        # Widen to int64 first: the difference of two int32 samples can exceed
        # the int32 range and would otherwise wrap silently.
        diff = np.abs(ref.astype(np.int64) - dut.astype(np.int64))
        max_diff.append(np.amax(diff))
        print(f"max diff value is {max_diff[-1]}")
        all_close = all_close & ch_close

    print(f"all_close: {np.all(all_close)}")
    return all_close, max(max_diff), delay_orig
|
|
|
|
|
|
|