# Copyright 2024-2025 XMOS LIMITED.
# This Software is subject to the terms of the XMOS Public Licence: Version 1.
"""
Helper functions for creating and validating audio files
"""
import numpy as np
import scipy
import scipy.io.wavfile
import scipy.signal
import matplotlib.pyplot as plt
from pathlib import Path
from audio_dsp.dsp import generic

# Upper bound on the number of samples per channel used for the
# cross-correlation delay estimate in correlate_and_diff().
_CORR_MAX_SAMPLES = 64000


def read_wav(path):
    """Read a wav file, always returning 2-D ``(samples, channels)`` data.

    Parameters
    ----------
    path : str or path-like
        wav file to read.

    Returns
    -------
    rate : int
        Sample rate of the file.
    data : numpy.ndarray
        Sample data in the file's own dtype; mono files are expanded to
        shape ``(samples, 1)``.
    """
    rate, data = scipy.io.wavfile.read(path)
    if data.ndim == 1:
        data = np.expand_dims(data, axis=1)
    return rate, data


def write_wav(path, fs, data):
    """Write ``data`` to the wav file ``path`` at sample rate ``fs``.

    Thin wrapper around ``scipy.io.wavfile.write``.
    """
    return scipy.io.wavfile.write(path, fs, data)


def read_and_truncate(path, f_bits=generic.Q_SIG):
    """Read an int32 wav and truncate the least significant fractional bits.

    The lowest ``31 - f_bits`` bits of every sample are cleared, leaving
    ``f_bits`` fractional bits (assuming the samples are fixed point with
    31 fractional bits -- confirm against callers).

    Parameters
    ----------
    path : str or path-like
        wav file to read.
    f_bits : int
        Number of fractional bits to keep.

    Returns
    -------
    rate : int
        Sample rate of the file.
    data : numpy.ndarray
        2-D int32 sample data with the low bits masked off.

    Raises
    ------
    TypeError
        If the file does not contain int32 samples.
    """
    rate, data = read_wav(path)
    if data.dtype != np.int32:
        raise TypeError(f"wav data type {data.dtype} not supported")
    mask = ~int(2**(31 - f_bits) - 1)
    print("mask ", mask)
    return rate, data & mask


def _to_int32(data, label):
    """Return ``data`` as int32, promoting int16 samples to int32 full scale.

    Raises AssertionError (naming ``label``) for any dtype other than int16
    or int32.
    """
    if data.dtype == np.int32:
        return data
    if data.dtype == np.int16:
        # Move the 16 significant bits to the top of the 32-bit word so int16
        # and int32 files share one amplitude scale.
        return data.astype(np.int32) * (2**16)
    raise AssertionError(f"Unsupported {label}.dtype {data.dtype}")


def _estimate_delay(data_in, data_out, skip):
    """Estimate by how many samples channel 0 of data_out lags data_in.

    Cross-correlates up to ``_CORR_MAX_SAMPLES`` samples starting at
    ``skip``. Returns ``(delay, corr)``; ``corr`` is the full
    cross-correlation, kept for optional plotting.
    """
    small_len = min(len(data_in), len(data_out), _CORR_MAX_SAMPLES)
    seg_in = data_in[skip : skip + small_len, 0].astype(np.float64)
    seg_out = data_out[skip : skip + small_len, 0].astype(np.float64)
    corr = scipy.signal.correlate(seg_in, seg_out, "full")
    # Zero lag sits at index len(seg_out) - 1 of a "full" correlation. Using
    # that instead of the midpoint stays correct when the two segments have
    # different lengths (e.g. when ``skip`` truncates the shorter file).
    delay = (len(seg_out) - 1) - int(np.argmax(corr))
    return delay, corr


def _aligned_bounds(skip, data_size, delay, len_in):
    """Return input-sample bounds ``[lo, hi)`` for the sample comparison.

    Input sample ``j`` is compared with output sample ``j + delay``. The
    range starts at ``skip`` and ends where the output index reaches
    ``data_size``, clipped so neither index leaves its array.
    """
    lo = max(skip, -delay, 0)
    hi = min(data_size - delay, len_in)
    return lo, max(lo, hi)


def correlate_and_diff(
    output_file,
    input_file,
    out_ch_start_end,
    in_ch_start_end,
    skip_seconds_start,
    skip_seconds_end,
    tol,
    corr_plot_file=None,
    verbose=False,
):
    """Align an output wav with its input wav and compare them sample by sample.

    The delay between the files is estimated by cross-correlating the first
    selected channel of each. Every selected channel is then compared after
    shifting the output by that delay.

    Parameters
    ----------
    output_file, input_file : str or path-like
        wav files to compare. int16 data is scaled to int32 full scale
        (shifted left by 16 bits), int32 is used as is, and any other dtype
        is rejected.
    out_ch_start_end, in_ch_start_end : sequence of two ints
        Inclusive ``[first, last]`` channel indices to compare in each file.
        Both ranges must span the same number of channels.
    skip_seconds_start : float
        Seconds ignored at the start of both files. This is also where the
        correlation window begins.
    skip_seconds_end : float
        Seconds ignored at the end of the shorter file.
    tol : number
        Absolute tolerance (``atol``) passed to ``np.isclose``, in int32
        sample units. ``np.isclose``'s default relative tolerance also
        applies.
    corr_plot_file : str or path-like, optional
        If given, a plot of the cross-correlation is saved to this file.
    verbose : bool
        If True, print the positions and values of up to 100 mismatching
        samples per channel.

    Returns
    -------
    all_close : bool
        True if every compared sample of every channel is within tolerance.
    max_diff : numpy integer
        Largest absolute sample difference over all channels.
    delay : int
        Estimated number of samples by which the output lags the input
        (negative if it leads).

    Raises
    ------
    AssertionError
        If the sample rates differ, a dtype is unsupported, or the channel
        ranges have different widths.
    ValueError
        If no samples remain to compare after skipping and delay alignment.
    """
    rate_out, data_out = read_wav(output_file)
    rate_in, data_in = read_wav(input_file)
    # Explicit raises (not ``assert``) so the checks survive ``python -O``.
    if rate_out != rate_in:
        raise AssertionError(
            f"input and output file rates are not equal. input rate {rate_in}, output rate {rate_out}"
        )
    data_in = _to_int32(data_in, "data_in")
    data_out = _to_int32(data_out, "data_out")
    if out_ch_start_end[1] - out_ch_start_end[0] != in_ch_start_end[1] - in_ch_start_end[0]:
        raise AssertionError("input and output files have different channel nos.")

    skip_samples_start = int(rate_out * skip_seconds_start)
    skip_samples_end = int(rate_out * skip_seconds_end)
    data_in = data_in[:, in_ch_start_end[0] : in_ch_start_end[1] + 1]
    data_out = data_out[:, out_ch_start_end[0] : out_ch_start_end[1] + 1]

    delay, corr = _estimate_delay(data_in, data_out, skip_samples_start)
    print(f"delay = {delay}")
    if corr_plot_file is not None:
        plt.plot(corr)
        plt.savefig(corr_plot_file)
        plt.clf()

    data_size = min(data_in.shape[0], data_out.shape[0]) - skip_samples_end
    print(f"compare {data_size - skip_samples_start} samples")
    print(data_in.shape)
    print(data_out.shape)
    print(delay)

    in_lo, in_hi = _aligned_bounds(skip_samples_start, data_size, delay, data_in.shape[0])
    if in_hi <= in_lo:
        raise ValueError(
            f"no overlapping samples to compare (delay {delay}, "
            f"window [{skip_samples_start}, {data_size}))"
        )

    num_channels = out_ch_start_end[1] - out_ch_start_end[0] + 1
    all_close = True
    max_diff = []
    for ch in range(num_channels):
        print(f"comparing ch {ch}")
        ref = data_in[in_lo:in_hi, ch]
        dut = data_out[in_lo + delay : in_hi + delay, ch]
        close = np.isclose(ref, dut, atol=tol)
        ch_close = bool(np.all(close))
        print(f"ch {ch}, close = {ch_close}")
        if verbose:
            bad_idxs = np.argwhere(~close)
            print("shape = ", bad_idxs.shape)
            print(bad_idxs)
            # Print the first 100 values that were not close.
            for i in bad_idxs[:100, 0]:
                print(i, ref[i], dut[i])
        # Widen before subtracting: int32 differences of full-scale samples
        # with opposite signs would otherwise wrap around.
        diff = np.abs(ref.astype(np.int64) - dut.astype(np.int64))
        max_diff.append(np.amax(diff))
        print(f"max diff value is {max_diff[-1]}")
        all_close = all_close and ch_close
    print(f"all_close: {all_close}")
    return all_close, max(max_diff), delay