Files
lockstep-eval/aligner.py
2026-05-15 16:02:48 +02:00

258 lines
7.7 KiB
Python

import json
import numpy as np
import matplotlib.pyplot as plt
def load_track_annot(fn):
    """Load one track's annotation from a JSON file.

    Args:
        fn: Path of the JSON annotation file.

    Returns:
        The decoded JSON document. Other code in this file reads a
        ``'beatTimesSec'`` entry from it.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the content is not valid JSON.
    """
    # JSON text is UTF-8; decode it explicitly instead of relying on the
    # platform's locale encoding.
    with open(fn, encoding="utf-8") as fin:
        return json.load(fin)
# Annotation file paths, one inner list per directory number (01-03),
# each holding the paths for track numbers 001-009.
fns_all = [
    ['beat_annotations/%02d/Rhythmology_%03d.json' % (group, track)
     for track in range(1, 10)]
    for group in range(1, 4)
]
# Bare expression (presumably left over from an interactive session);
# it has no effect when the module is executed.
fns_all
"""
Implement an algorithm to align two series of time intervals in Python. Each series will be up to 1000 entries long. Use dynamic programming if needed.
"""
from dataclasses import dataclass
from math import inf
from typing import Callable, Iterable, Optional, Sequence, Tuple, List, Any
# An interval as a (start, end) pair. align_intervals rejects pairs whose
# start is greater than their end.
Interval = Tuple[float, float]
@dataclass(frozen=True)
class AlignmentStep:
    """One column of an alignment between series A and series B.

    A step either pairs an interval from A with an interval from B, or
    pairs an interval from one series with a gap. The gap side has both
    its index and its interval set to ``None``.
    """

    # Position of the interval in series A, or None for a gap in A.
    a_index: Optional[int]
    # Position of the interval in series B, or None for a gap in B.
    b_index: Optional[int]
    # The interval taken from series A, or None for a gap in A.
    a_interval: Optional[Interval]
    # The interval taken from series B, or None for a gap in B.
    b_interval: Optional[Interval]
    # Contribution of this step: the match score for a pairing, the gap
    # penalty for a gap.
    score: float
def overlap_length(x: Interval, y: Interval) -> float:
    """Return the length of the intersection of two intervals.

    The result is 0.0 when the intervals are disjoint or only touch at
    an endpoint.
    """
    latest_start = max(x[0], y[0])
    earliest_end = min(x[1], y[1])
    return max(0.0, earliest_end - latest_start)
def union_length(x: Interval, y: Interval) -> float:
    """Return the distance from the earliest start to the latest end.

    This equals the length of the union only when the two intervals
    overlap or touch; for disjoint intervals the gap between them is
    included in the result.
    """
    earliest_start = min(x[0], y[0])
    latest_end = max(x[1], y[1])
    return latest_end - earliest_start
def default_match_score(a: Interval, b: Interval) -> float:
    """Score the similarity of two intervals (overlap divided by span).

    Both intervals are expected to satisfy start <= end.

    Returns:
        +1.0 for identical intervals (including two identical
             zero-length intervals),
        a value in (0.0, 1.0] for partially overlapping intervals,
        -1.0 when the intervals share no positive-length overlap.

    You can replace this with a domain-specific function.
    """
    # Two identical zero-length intervals have no measurable overlap, yet
    # they coincide exactly. Handle them before the overlap test, which
    # would otherwise classify them as "no overlap".
    if a[0] == a[1] == b[0] == b[1]:
        return 1.0
    ov = overlap_length(a, b)
    if ov <= 0:
        return -1.0
    # A positive overlap implies a positive span, so the division is safe.
    return ov / union_length(a, b)
def align_intervals(
    series_a: Sequence[Interval],
    series_b: Sequence[Interval],
    *,
    match_score: Callable[[Interval, Interval], float] = default_match_score,
    gap_penalty: float = -0.5,
) -> Tuple[float, List[AlignmentStep]]:
    """Compute a global alignment of two ordered interval series.

    The alignment is a sequence of steps, each being one of:
      - an interval from A paired with an interval from B,
      - an interval from A paired with a gap,
      - a gap paired with an interval from B.
    The steps are chosen by dynamic programming so that the sum of the
    match scores and gap penalties is maximal.

    Args:
        series_a: First series of (start, end) intervals.
        series_b: Second series of (start, end) intervals.
        match_score: Function scoring a pairing of one interval from each
            series.
        gap_penalty: Score added for every interval paired with a gap.

    Returns:
        A tuple ``(total_score, steps)`` where ``steps`` lists the
        alignment from the start of the series to their end.

    Raises:
        ValueError: If any interval has start > end.

    Complexity:
        Time: O(len(series_a) * len(series_b)).
        Memory: O(len(series_a) * len(series_b)) bytes for the
        back-pointer table, O(len(series_b)) for the score rows.
    """
    a = list(series_a)
    b = list(series_b)
    rows, cols = len(a), len(b)

    # Reject reversed intervals before doing any work (series A first).
    for idx, iv in enumerate(a):
        if iv[0] > iv[1]:
            raise ValueError(f"Invalid interval in series_a at index {idx}: {iv}")
    for idx, iv in enumerate(b):
        if iv[0] > iv[1]:
            raise ValueError(f"Invalid interval in series_b at index {idx}: {iv}")

    # Back-pointer codes stored per cell:
    #   diag: A[row - 1] paired with B[col - 1]
    #   up:   A[row - 1] paired with a gap
    #   left: gap paired with B[col - 1]
    diag, up, left = 1, 2, 3
    back = [bytearray(cols + 1) for _ in range(rows + 1)]

    # Row 0: every interval of B is paired with a gap.
    above = [0.0]
    for col in range(1, cols + 1):
        above.append(above[col - 1] + gap_penalty)
        back[0][col] = left

    # Remaining rows; only the previous score row is kept.
    for row in range(1, rows + 1):
        iv_a = a[row - 1]
        back_row = back[row]

        # Column 0: every interval of A so far is paired with a gap.
        here = [above[0] + gap_penalty]
        back_row[0] = up

        for col in range(1, cols + 1):
            via_diag = above[col - 1] + match_score(iv_a, b[col - 1])
            via_up = above[col] + gap_penalty
            via_left = here[col - 1] + gap_penalty

            # Ties resolve in favour of diag, then up, then left, because
            # a later candidate must be strictly better to win.
            if via_up > via_diag:
                best, move = via_up, up
            else:
                best, move = via_diag, diag
            if via_left > best:
                best, move = via_left, left

            here.append(best)
            back_row[col] = move
        above = here

    total_score = above[cols]

    # Walk the back-pointers from the bottom-right cell to the origin.
    steps: List[AlignmentStep] = []
    row, col = rows, cols
    while row > 0 or col > 0:
        move = back[row][col]
        if move == diag:
            row -= 1
            col -= 1
            step = AlignmentStep(
                a_index=row,
                b_index=col,
                a_interval=a[row],
                b_interval=b[col],
                score=match_score(a[row], b[col]),
            )
        elif move == up:
            row -= 1
            step = AlignmentStep(
                a_index=row,
                b_index=None,
                a_interval=a[row],
                b_interval=None,
                score=gap_penalty,
            )
        elif move == left:
            col -= 1
            step = AlignmentStep(
                a_index=None,
                b_index=col,
                a_interval=None,
                b_interval=b[col],
                score=gap_penalty,
            )
        else:
            # Code 0 is only stored at back[0][0], which the loop never reads.
            raise RuntimeError("Invalid traceback state")
        steps.append(step)

    # The walk produced the steps end-to-start.
    steps.reverse()
    return total_score, steps
###
def make_intv_pairs(annot):
    """Turn an annotation's beat times into consecutive intervals.

    A zero is prepended to ``annot['beatTimesSec']``, so the first
    interval runs from 0 to the first beat and every later interval runs
    from one beat to the next.

    Returns:
        A list of ``(start, end)`` tuples, one per beat time.
    """
    # NOTE: an earlier, disabled variant paired successive inter-beat
    # differences of the sorted beat times instead of the times themselves.
    padded = np.pad(annot['beatTimesSec'], (1, 0))
    return list(zip(padded[:-1], padded[1:]))
def interp_missing_1(ts):
    """Fill in beats that appear to be missing from one series of times.

    Wherever the step between two consecutive times differs from the mean
    step by more than the threshold, the midpoint of those two times is
    inserted between them. A message is printed when anything was added.

    Returns:
        A new array holding the original times plus the inserted midpoints.
    """
    ib_th = 0.2  #: sec (inter-beat interval threshold for interpolation)
    steps = np.diff(ts)
    idxs_interp = np.where(np.abs(steps - np.mean(steps)) > ib_th)[0]

    # Rebuild the series in one pass, emitting a midpoint right after each
    # time that starts an irregular step.
    needs_fill = set(idxs_interp.tolist())
    filled = []
    for pos, t in enumerate(ts):
        filled.append(t)
        if pos in needs_fill:
            filled.append(ts[pos] + (ts[pos + 1] - ts[pos]) / 2)

    if len(idxs_interp) > 0:
        print('interp_missing: added %d beats' % len(idxs_interp))
    return np.array(filled)
def interp_missing(ts_a, ts_b, *, ib_th=0.2):
    """Fill in beats that appear to be missing from two paired series.

    A step between consecutive times is considered irregular when it
    differs from that series' mean step by more than ``ib_th``. For every
    position that is irregular in *either* series, the midpoint of the two
    surrounding times is inserted into *both* series, so the results keep
    equal length. A message is printed when anything was added.

    Args:
        ts_a: First series of times.
        ts_b: Second series of times; must have the same length as ``ts_a``.
        ib_th: Threshold in seconds on the deviation of an inter-beat
            interval from the mean interval.

    Returns:
        A tuple ``(filled_a, filled_b)`` of arrays.

    Raises:
        AssertionError: If the two series differ in length.
    """
    assert len(ts_a) == len(ts_b)

    # With fewer than two samples there are no steps to inspect (and the
    # mean of an empty difference array would be NaN with a warning).
    if len(ts_a) < 2:
        return np.array(ts_a), np.array(ts_b)

    diff_a, diff_b = np.diff(ts_a), np.diff(ts_b)
    idxs_interp_a = np.where(np.abs(diff_a - np.mean(diff_a)) > ib_th)[0]
    idxs_interp_b = np.where(np.abs(diff_b - np.mean(diff_b)) > ib_th)[0]

    # Join the two sets, and apply interpolation to both timeseries
    # (ensures they will have the same length afterwards).
    idxs_interp = sorted(set(idxs_interp_a.tolist()) | set(idxs_interp_b.tolist()))

    tsl_a, tsl_b = list(ts_a), list(ts_b)
    # Indices are ascending, so after k earlier insertions the slot that
    # follows original position i sits at i + k + 1.
    for k, i in enumerate(idxs_interp):
        tsl_a.insert(i + k + 1, ts_a[i] + (ts_a[i + 1] - ts_a[i]) / 2)
        tsl_b.insert(i + k + 1, ts_b[i] + (ts_b[i + 1] - ts_b[i]) / 2)

    if len(idxs_interp) > 0:
        print('interp_missing: added %d beats' % len(idxs_interp))
    return np.array(tsl_a), np.array(tsl_b)
def align(annot_a, annot_b):
    """Align the beat times of two annotations of the same track.

    The beat times of each annotation are converted to consecutive
    intervals and aligned with ``align_intervals``. Only steps that pair
    an interval from each annotation are kept, the beat times at the kept
    positions are selected, and ``interp_missing`` fills irregular gaps
    in both selections.

    Args:
        annot_a: First annotation; its ``'beatTimesSec'`` entry is read.
        annot_b: Second annotation; its ``'beatTimesSec'`` entry is read.

    Returns:
        A tuple ``(ts_a, ts_b, tsa_a, tsa_b)``: the original beat times of
        both annotations as arrays, followed by the aligned and
        interpolated beat times.
    """
    _score, alignment = align_intervals(make_intv_pairs(annot_a), make_intv_pairs(annot_b))

    def check_intv(al):
        # - neither interval is None (we are aligned)
        if al.a_interval is None or al.b_interval is None:
            return False
        # - cut the first few bad intervals: poorly matching pairs
        #   (score < 0.5) near the start of either series (index < 5) are
        #   dropped rather than interpolated
        init_bad = al.score < 0.5 and (al.a_index < 5 or al.b_index < 5)
        return not init_bad

    # Evaluate the filter once per step so both index lists stay in sync.
    kept = [al for al in alignment if check_intv(al)]

    # An explicit integer dtype keeps the index arrays usable even when no
    # step survives the filter (an empty default array would be float and
    # cannot be used for indexing).
    idxs_a = np.array([al.a_index for al in kept], dtype=int)
    idxs_b = np.array([al.b_index for al in kept], dtype=int)

    ts_a, ts_b = np.array(annot_a['beatTimesSec']), np.array(annot_b['beatTimesSec'])
    tsa_a, tsa_b = ts_a[idxs_a], ts_b[idxs_b]
    tsa_a, tsa_b = interp_missing(tsa_a, tsa_b)
    return ts_a, ts_b, tsa_a, tsa_b