chore: debug for SsfZxing detector

This commit is contained in:
2026-05-13 05:16:43 +02:00
parent 11934b1f61
commit 4627786dc4

15
beat.py
View File

@@ -13,6 +13,11 @@ class SsfZxing:
ssf_rel_thres = 3 #: magic number from Zong 2003, threshold from mean SSF amplitude ssf_rel_thres = 3 #: magic number from Zong 2003, threshold from mean SSF amplitude
ssf_rel_rise = 0.8 #: minimum rise of SSF edge (from foot to peak) relative to 'ssf_th' ssf_rel_rise = 0.8 #: minimum rise of SSF edge (from foot to peak) relative to 'ssf_th'
# TODO: C++ impl has diverged.
# - refractory period changes
# - ssf_th filter with 6-points
# - ?? others ??
def __init__(self): pass def __init__(self): pass
def _ssf_det_zxings(self, fs, ssf, ssf_th): def _ssf_det_zxings(self, fs, ssf, ssf_th):
@@ -37,6 +42,7 @@ class SsfZxing:
ssf_z[:i_range] = 0 # force-zero the bounds where we cannot check the amplitude rise ssf_z[:i_range] = 0 # force-zero the bounds where we cannot check the amplitude rise
ssf_zxings = np.where(ssf_z)[0] ssf_zxings = np.where(ssf_z)[0]
# only integer-index resolution (no interpolation) # only integer-index resolution (no interpolation)
self.ssf_zxings = ssf_zxings
return ssf_zxings return ssf_zxings
def _ssf_function(self, fs, y): def _ssf_function(self, fs, y):
@@ -50,8 +56,17 @@ class SsfZxing:
# compute threshold # compute threshold
# TODO: check if we need lowpass instead of mean for 'ssf_th' # TODO: check if we need lowpass instead of mean for 'ssf_th'
ssf_th = self.ssf_rel_thres * np.mean(ssf) ssf_th = self.ssf_rel_thres * np.mean(ssf)
self.ssf, self.ssf_th = ssf, ssf_th
return ssf, ssf_th return ssf, ssf_th
def debug_plot(self, i1, i2):
ssf, ssf_th, ssf_zxings = self.ssf, self.ssf_th, self.ssf_zxings
ssf_slice = ssf[i1:i2]
ssf_th_slice = ssf_th[i1:i2] if isinstance(ssf_th, np.ndarray) else ssf_th
plt.figure(figsize=(8, 2))
plt.plot(ssf_slice)
plt.plot(np.arange(ssf_slice.shape[0]), np.ones(ssf_slice.shape[0]) * ssf_th_slice)
plt.scatter(ssf_zxings[i1:i2], np.ones(ssf_zxings[i1:i2].shape[0]) * ssf_th_slice, c='r')
def get_mae_dist(ibis): def get_mae_dist(ibis):
"""make triangular wave between beats, representing absolute beat placement error.""" """make triangular wave between beats, representing absolute beat placement error."""