308 lines
10 KiB
Python
308 lines
10 KiB
Python
import json
|
|
import numpy as np
|
|
import matplotlib.pyplot as plt
|
|
|
|
def load_track_annot(fn):
    """Load a beat-annotation JSON file and return the decoded object.

    The file is decoded explicitly as UTF-8, the encoding JSON text uses
    (RFC 8259). Without it, ``open`` falls back to the platform's
    locale-dependent default, which can mis-decode non-ASCII content.

    Args:
        fn: Path to the JSON file.

    Returns:
        Whatever ``json.load`` produces for the file (for these annotation
        files, presumably a dict with a ``'beatTimesSec'`` list — confirm).
    """
    with open(fn, encoding='utf-8') as fin:
        return json.load(fin)
|
|
|
|
# Annotation file paths as a nested list: the outer index picks the folder
# (01-03, presumably one per annotator — confirm), the inner index picks the
# track file Rhythmology_001 ... Rhythmology_009.
fns_all = [
    [
        'beat_annotations/%02d/Rhythmology_%03d.json' % (folder, track)
        for track in range(1, 10)
    ]
    for folder in range(1, 4)
]
fns_all  # bare expression: only displays the list when run as a notebook cell
|
|
|
|
"""
|
|
Implement an algorithm to align two series of time intervals in Python. Each series will be up to 1000 entries long. Use dynamic programming if needed.
|
|
"""
|
|
|
|
from dataclasses import dataclass
|
|
from math import inf
|
|
from typing import Callable, Iterable, Optional, Sequence, Tuple, List, Any
|
|
|
|
|
|
# A time interval as a (start, end) pair; index 0 is the start and index 1 the
# end. align_intervals rejects intervals whose start is greater than their end.
Interval = Tuple[float, float]
|
|
|
|
|
|
@dataclass(frozen=True)
class AlignmentStep:
    """One step (column) of an alignment produced by align_intervals.

    A step either pairs an interval of series A with an interval of series B,
    or pairs an interval of one series with a gap. On the gap side, both the
    index and the interval are None.
    """

    a_index: Optional[int]  # position in series A, or None when A has a gap here
    b_index: Optional[int]  # position in series B, or None when B has a gap here
    a_interval: Optional[Interval]  # the A interval itself, or None for a gap
    b_interval: Optional[Interval]  # the B interval itself, or None for a gap
    score: float  # match score for a pair, or the gap penalty for a gap step
|
|
|
|
|
|
def overlap_length(x: Interval, y: Interval) -> float:
    """Return the length of the intersection of intervals *x* and *y*.

    Disjoint or merely touching intervals give 0.0. The result is never negative.
    """
    latest_start = max(x[0], y[0])
    earliest_end = min(x[1], y[1])
    return max(0.0, earliest_end - latest_start)
|
|
|
|
|
|
def union_length(x: Interval, y: Interval) -> float:
    """Return the total length covered by the union of intervals *x* and *y*.

    For overlapping, nested or touching intervals, this is the span from the
    earliest start to the latest end. For disjoint intervals, the empty stretch
    between them is not part of the union, so it is subtracted. The result is
    then the sum of the two interval lengths.
    """
    hull = max(x[1], y[1]) - min(x[0], y[0])
    # Length of the uncovered gap separating disjoint intervals; 0 if they meet.
    gap = max(0.0, max(x[0], y[0]) - min(x[1], y[1]))
    return hull - gap
|
|
|
|
|
|
def default_match_score(a: Interval, b: Interval) -> float:
    """
    Scores interval similarity as intersection over union.

    Returns:
        +1.0 for perfect overlap (including two identical zero-length intervals)
        (0.0, 1.0] for partial overlap: overlap length / union length
        -1.0 for no overlap (disjoint or merely touching intervals)

    You can replace this with a domain-specific function.
    """
    ov = overlap_length(a, b)
    if ov <= 0:
        # Identical zero-length intervals (e.g. duplicated timestamps) have no
        # measurable overlap, but they still match each other perfectly.
        if a[0] == b[0] and a[1] == b[1]:
            return 1.0
        return -1.0

    # ov > 0 implies the union is at least ov long, so the division is safe.
    return ov / union_length(a, b)
|
|
|
|
|
|
# Traceback codes stored in the per-cell bytearray of align_intervals.
_STEP_DIAG = 1  # A[i - 1] paired with B[j - 1]
_STEP_UP = 2  # A[i - 1] paired with a gap
_STEP_LEFT = 3  # gap paired with B[j - 1]


def align_intervals(
    series_a: Sequence[Interval],
    series_b: Sequence[Interval],
    *,
    match_score: Callable[[Interval, Interval], float] = default_match_score,
    gap_penalty: float = -0.5,
) -> Tuple[float, List[AlignmentStep]]:
    """Compute an optimal global alignment of two ordered interval series.

    This is a Needleman-Wunsch-style dynamic program with a linear gap penalty.
    Each output step is one of the following:

    * an interval of A paired with an interval of B, scored by ``match_score``;
    * an interval of A paired with a gap, scored ``gap_penalty``;
    * a gap paired with an interval of B, scored ``gap_penalty``.

    Ties between moves with the same best score are broken deterministically:
    pairing is preferred over a gap in B, which is preferred over a gap in A.

    Args:
        series_a: First interval series; each interval must have start <= end.
        series_b: Second interval series; same requirement.
        match_score: Scores a pair of intervals; higher means more similar.
        gap_penalty: Score added for every interval aligned to a gap.

    Returns:
        ``(total_score, steps)``, where ``steps`` runs from the start of both
        series to their end.

    Raises:
        ValueError: If an interval has start > end (series A is checked first).
        RuntimeError: If the traceback reaches an unset cell.

    Complexity:
        O(len(A) * len(B)) time. Memory is O(len(A) * len(B)) bytes for the
        traceback table plus O(len(B)) floats for scores, which is comfortable
        for series of about 1000 intervals each.
    """
    seq_a = list(series_a)
    seq_b = list(series_b)
    rows = len(seq_a)
    cols = len(seq_b)

    for pos, interval in enumerate(seq_a):
        if interval[0] > interval[1]:
            raise ValueError(f"Invalid interval in series_a at index {pos}: {interval}")
    for pos, interval in enumerate(seq_b):
        if interval[0] > interval[1]:
            raise ValueError(f"Invalid interval in series_b at index {pos}: {interval}")

    # One byte per DP cell records which move produced the cell's best score.
    moves = [bytearray(cols + 1) for _ in range(rows + 1)]

    # Row 0: every prefix of B is aligned against gaps only.
    above = [0.0] * (cols + 1)
    for col in range(1, cols + 1):
        above[col] = above[col - 1] + gap_penalty
        moves[0][col] = _STEP_LEFT

    # Only the previous row of scores is kept; the traceback lives in `moves`.
    for row in range(1, rows + 1):
        item_a = seq_a[row - 1]
        here = [0.0] * (cols + 1)
        here[0] = above[0] + gap_penalty  # column 0: A prefix against gaps
        moves[row][0] = _STEP_UP

        for col in range(1, cols + 1):
            diag_score = above[col - 1] + match_score(item_a, seq_b[col - 1])
            up_score = above[col] + gap_penalty
            left_score = here[col - 1] + gap_penalty

            # A later candidate wins only when strictly better, so ties keep
            # the earlier preference: diagonal, then up, then left.
            cell, move = diag_score, _STEP_DIAG
            if up_score > cell:
                cell, move = up_score, _STEP_UP
            if left_score > cell:
                cell, move = left_score, _STEP_LEFT

            here[col] = cell
            moves[row][col] = move

        above = here

    total_score = above[cols]

    # Walk back from the bottom-right cell, collecting steps in reverse order.
    steps: List[AlignmentStep] = []
    row, col = rows, cols
    while row or col:
        move = moves[row][col]
        if move == _STEP_DIAG:
            row -= 1
            col -= 1
            steps.append(
                AlignmentStep(
                    a_index=row,
                    b_index=col,
                    a_interval=seq_a[row],
                    b_interval=seq_b[col],
                    score=match_score(seq_a[row], seq_b[col]),
                )
            )
        elif move == _STEP_UP:
            row -= 1
            steps.append(
                AlignmentStep(
                    a_index=row,
                    b_index=None,
                    a_interval=seq_a[row],
                    b_interval=None,
                    score=gap_penalty,
                )
            )
        elif move == _STEP_LEFT:
            col -= 1
            steps.append(
                AlignmentStep(
                    a_index=None,
                    b_index=col,
                    a_interval=None,
                    b_interval=seq_b[col],
                    score=gap_penalty,
                )
            )
        else:
            # Only cell (0, 0) is left unset, and the loop stops before it.
            raise RuntimeError("Invalid traceback state")

    return total_score, steps[::-1]
|
|
|
|
###
|
|
|
|
def make_intv_pairs(annot):
    """Turn an annotation's beat times into consecutive (start, end) intervals.

    A zero is prepended to ``annot['beatTimesSec']``, so the first interval runs
    from time 0 to the first beat and each later interval runs from one beat to
    the next. Interval ``k`` therefore ends at beat ``k``, in the order the beats
    appear in the annotation. The beats are not sorted here.

    Returns:
        A list of 2-tuples of numpy scalars, one per beat (empty for no beats).
    """
    edges = np.pad(annot['beatTimesSec'], (1, 0))
    return list(zip(edges[:-1], edges[1:]))
|
|
|
|
def interp_missing_1(ts, ib_th=0.2):
    """Interpolate missing beats in a single beat-time series.

    An inter-beat interval is treated as containing missing beats when it is
    longer than the mean inter-beat interval by more than ``ib_th`` seconds.
    Such an interval is split into ``round(interval / mean)`` equal parts, i.e.
    ``round(...) - 1`` evenly spaced beats are inserted. Abnormally *short*
    intervals point to an extra beat rather than a missing one, so they are
    left unchanged.

    Args:
        ts: 1-D sequence of beat times in seconds, assumed sorted ascending.
        ib_th: Threshold in seconds above the mean interval (default 0.2).

    Returns:
        A float numpy array of beat times including any interpolated beats.
        Inputs with fewer than two beats are returned unchanged.
    """
    ts = np.asarray(ts, dtype=float)
    if ts.size < 2:
        return ts.copy()

    gaps = np.diff(ts)
    mn = gaps.mean()
    if mn <= 0:
        # Degenerate or unsorted input: there is no usable beat period.
        return ts.copy()

    out = [ts[0]]
    n_added = 0
    for i, gap in enumerate(gaps):
        if gap - mn > ib_th:
            n_parts = max(int(np.round(gap / mn)), 1)
            step = gap / n_parts
            out.extend(ts[i] + step * np.arange(1, n_parts))
            n_added += n_parts - 1
        out.append(ts[i + 1])

    if n_added > 0:
        print('interp_missing: added %d beats' % n_added)
    return np.array(out, dtype=float)
|
|
|
|
|
|
def interp_missing(ts_a_o, ts_b_o, ib_th=0.2):
    """Interpolate missing beats in two paired beat-time series.

    Element k of ``ts_a_o`` and element k of ``ts_b_o`` are treated as a pair;
    ``align`` passes the matched beats of two annotations. Both inputs are
    sorted into new arrays, so the caller's data is never modified.

    For each series, inter-beat intervals whose length differs from that
    series' mean interval by more than ``ib_th`` seconds are flagged. At every
    index flagged in *either* series, each series' own interval is split into
    ``round(interval / mean)`` equal parts by inserting ``round(...) - 1`` evenly
    spaced beats. Nothing is inserted when that count rounds to 0 or 1. Both
    results are then truncated to ``len + min(added_a, added_b)``.

    NOTE(review): if the two series receive different numbers of inserted beats
    at the same index, later pairs are shifted relative to each other, and the
    final truncation only equalises the lengths. Confirm this is acceptable.

    Args:
        ts_a_o: Beat times in seconds for the first series.
        ts_b_o: Beat times in seconds for the second series (same length).
        ib_th: Deviation threshold in seconds for flagging an interval.

    Returns:
        ``(oa, ob)``: numpy arrays of equal length, including interpolated beats.

    Raises:
        ValueError: If the two series have different lengths.
    """
    # sorted() makes copies, so the inputs are left untouched.
    ts_a = np.array(sorted(ts_a_o))
    ts_b = np.array(sorted(ts_b_o))
    if len(ts_a) != len(ts_b):
        raise ValueError('interp_missing: series lengths differ (%d vs %d)'
                         % (len(ts_a), len(ts_b)))
    if len(ts_a) < 2:
        # No inter-beat intervals: the mean of an empty diff would be NaN.
        return ts_a, ts_b

    mn_a, mn_b = np.mean(np.diff(ts_a)), np.mean(np.diff(ts_b))
    idxs_interp_a = np.where(np.abs(np.diff(ts_a) - mn_a) > ib_th)[0]
    idxs_interp_b = np.where(np.abs(np.diff(ts_b) - mn_b) > ib_th)[0]
    # Visit the union of flagged indices, so both series are treated at every
    # index that looks irregular on either side.
    idxs_interp = sorted(set(idxs_interp_a) | set(idxs_interp_b))

    def insert_beats(out, pos, start, stop, mean):
        # Split (start, stop) into round(span / mean) equal parts by inserting
        # the inner points right after out[pos]; return the number inserted.
        if mean <= 0:
            return 0
        span = stop - start
        n_parts = int(np.round(span / mean))
        if n_parts < 2:
            # 0 or 1 part: nothing is inserted, so the offset must not change.
            return 0
        step = span / n_parts
        for p in range(1, n_parts):
            out.insert(pos + p, start + step * p)
        return n_parts - 1

    tsl_a, tsl_b = list(ts_a), list(ts_b)
    ma = mb = 0  # beats inserted so far on each side; shifts later positions
    for i in idxs_interp:
        ma += insert_beats(tsl_a, i + ma, ts_a[i], ts_a[i + 1], mn_a)
        mb += insert_beats(tsl_b, i + mb, ts_b[i], ts_b[i + 1], mn_b)

    if ma > 0 or mb > 0:
        print('interp_missing: added ma=%d mb=%d beats' % (ma, mb))
    ll = len(ts_a) + min(ma, mb)
    oa, ob = np.array(tsl_a)[:ll], np.array(tsl_b)[:ll]
    print(len(ts_a), len(oa), len(ob))
    return oa, ob
|
|
|
|
|
|
def align(annot_a, annot_b, *, init_skip=5, init_min_score=0.5):
    """Pair up the beats of two beat annotations and interpolate missing beats.

    Each annotation's ``'beatTimesSec'`` is turned into consecutive intervals
    with ``make_intv_pairs``, and the two interval series are aligned with
    ``align_intervals``. Interval k ends at beat k, so the indices of matched
    interval pairs are also indices of matched beats.

    Only steps that match an interval on both sides are kept. Among those,
    poorly matching pairs near the start are dropped: pairs with a score below
    ``init_min_score`` where either index is below ``init_skip``. The first few
    taps may be noisy and are not worth interpolating. The matched beat times
    are then passed through ``interp_missing``.

    Args:
        annot_a: Annotation dict with a ``'beatTimesSec'`` sequence.
        annot_b: Annotation dict with a ``'beatTimesSec'`` sequence.
        init_skip: Number of leading intervals subject to the score filter.
        init_min_score: Minimum match score required within those intervals.

    Returns:
        ``(ts_a, ts_b, tsa_a, tsa_b)``: the raw beat times of each annotation
        as numpy arrays in their original order, and the paired, interpolated
        beat times.
    """
    _, steps = align_intervals(make_intv_pairs(annot_a), make_intv_pairs(annot_b))

    def keep(step):
        # Gap steps have no partner beat, so they cannot form a pair.
        if step.a_interval is None or step.b_interval is None:
            return False
        near_start = step.a_index < init_skip or step.b_index < init_skip
        return not (near_start and step.score < init_min_score)

    kept = [step for step in steps if keep(step)]
    # dtype=int keeps these usable as indices even when nothing was matched.
    idxs_a = np.array([step.a_index for step in kept], dtype=int)
    idxs_b = np.array([step.b_index for step in kept], dtype=int)

    ts_a = np.array(annot_a['beatTimesSec'])
    ts_b = np.array(annot_b['beatTimesSec'])
    tsa_a, tsa_b = interp_missing(ts_a[idxs_a], ts_b[idxs_b])

    return ts_a, ts_b, tsa_a, tsa_b
|