Add readahead for segmenters

This commit is contained in:
Chris Cannam
2022-05-25 11:14:19 +01:00
parent 680393c5c6
commit 973a334f75
2 changed files with 129 additions and 54 deletions

View File

@@ -233,80 +233,134 @@ R3StretcherImpl::consume()
// We have a single unwindowed frame at the longest FFT // We have a single unwindowed frame at the longest FFT
// size ("scale"). Populate the shorter FFT sizes from the // size ("scale"). Populate the shorter FFT sizes from the
// centre of it, windowing as we copy // centre of it, windowing as we copy. The classification
// scale is handled separately because it has readahead,
// so skip it here as well as the longest. (In practice
// this means we are probably only populating one scale)
for (auto &it: cd->scales) { for (auto &it: cd->scales) {
int fftSize = it.first; int fftSize = it.first;
auto &scale = it.second; if (fftSize == classify || fftSize == longest) continue;
if (fftSize == longest) continue;
int offset = (longest - fftSize) / 2; int offset = (longest - fftSize) / 2;
m_scaleData.at(fftSize)->analysisWindow.cut m_scaleData.at(fftSize)->analysisWindow.cut
(buf + offset, scale->timeDomain.data()); (buf + offset, it.second->timeDomain.data());
} }
// Then window the longest one // The classification scale has a one-hop readahead (note
// that inhop is fixed), so populate its current data from
// the readahead and the readahead from further down the
// long unwindowed frame.
auto &classifyScale = cd->scales.at(classify);
ClassificationReadaheadData &readahead = cd->readahead;
m_scaleData.at(classify)->analysisWindow.cut
(buf + (longest - classify) / 2 + m_inhop,
readahead.timeDomain.data());
// Finally window the longest scale
m_scaleData.at(longest)->analysisWindow.cut(buf); m_scaleData.at(longest)->analysisWindow.cut(buf);
// FFT shift, forward FFT, and carry out cartesian-polar // FFT shift, forward FFT, and carry out cartesian-polar
// conversion for each FFT size // conversion for each FFT size.
// For the classification scale we need magnitudes for the
// full range (polar only in a subset) and we operate in
// the readahead, pulling current values from the existing
// readahead
v_fftshift(readahead.timeDomain.data(), classify);
v_copy(classifyScale->mag.data(),
readahead.mag.data(),
classifyScale->bufSize);
v_copy(classifyScale->phase.data(),
readahead.phase.data(),
classifyScale->bufSize);
m_scaleData.at(classify)->fft.forward(readahead.timeDomain.data(),
classifyScale->real.data(),
classifyScale->imag.data());
for (const auto &b : m_guideConfiguration.fftBandLimits) {
if (b.fftSize == classify) {
if (b.b0min > 0) {
v_cartesian_to_magnitudes(readahead.mag.data(),
classifyScale->real.data(),
classifyScale->imag.data(),
b.b0min);
}
v_cartesian_to_polar(readahead.mag.data() + b.b0min,
readahead.phase.data() + b.b0min,
classifyScale->real.data() + b.b0min,
classifyScale->imag.data() + b.b0min,
b.b1max - b.b0min);
if (b.b1max < classify/2 + 1) {
v_cartesian_to_magnitudes
(readahead.mag.data() + b.b1max,
classifyScale->real.data() + b.b1max,
classifyScale->imag.data() + b.b1max,
classify/2 + 1 - b.b1max);
}
v_scale(classifyScale->mag.data(),
1.0 / double(classify),
classifyScale->mag.size());
break;
}
}
// For the others we operate directly in the scale data
// and restrict the range for cartesian-polar conversion
for (auto &it: cd->scales) { for (auto &it: cd->scales) {
int fftSize = it.first; int fftSize = it.first;
if (fftSize == classify) continue;
auto &scale = it.second; auto &scale = it.second;
v_fftshift(scale->timeDomain.data(), fftSize); v_fftshift(scale->timeDomain.data(), fftSize);
if (fftSize == m_guideConfiguration.classificationFftSize) { m_scaleData.at(fftSize)->fft.forward(scale->timeDomain.data(),
// For the classification scale we need the full range
m_scaleData.at(fftSize)->fft.forwardPolar
(scale->timeDomain.data(),
scale->mag.data(),
scale->phase.data());
} else {
// For other scales we only need do
// cartesian-polar conversion for the necessary
// frequency subset
m_scaleData.at(fftSize)->fft.forward
(scale->timeDomain.data(),
scale->real.data(), scale->real.data(),
scale->imag.data()); scale->imag.data());
//!!! This should be a map
for (const auto &b : m_guideConfiguration.fftBandLimits) { for (const auto &b : m_guideConfiguration.fftBandLimits) {
if (b.fftSize == fftSize) { if (b.fftSize == fftSize) {
int offset = b.b0min; v_cartesian_to_polar(scale->mag.data() + b.b0min,
v_cartesian_to_polar scale->phase.data() + b.b0min,
(scale->mag.data() + offset, scale->real.data() + b.b0min,
scale->phase.data() + offset, scale->imag.data() + b.b0min,
scale->real.data() + offset, b.b1max - b.b0min);
scale->imag.data() + offset, v_scale(scale->mag.data() + b.b0min,
b.b1max - offset); 1.0 / double(fftSize),
b.b1max - b.b0min);
break; break;
} }
} }
} }
v_scale(scale->mag.data(), 1.0 / double(fftSize),
scale->mag.size());
}
// Use the classification scale to get a bin segmentation // Use the classification scale to get a bin segmentation
// and calculate the adaptive frequency guide for this // and calculate the adaptive frequency guide for this
// channel // channel
auto &classifyScale = cd->scales.at(classify);
cd->prevSegmentation = cd->segmentation; cd->prevSegmentation = cd->segmentation;
cd->segmentation = cd->segmentation = cd->nextSegmentation;
cd->segmenter->segment(classifyScale->mag.data()); cd->nextSegmentation = cd->segmenter->segment(readahead.mag.data());
m_troughPicker.findNearestAndNextPeaks m_troughPicker.findNearestAndNextPeaks
(classifyScale->mag.data(), 3, nullptr, (classifyScale->mag.data(), 3, nullptr,
classifyScale->nextTroughs.data()); classifyScale->troughs.data());
m_guide.calculate(instantaneousRatio, m_guide.calculate(instantaneousRatio,
classifyScale->mag.data(), classifyScale->mag.data(),
classifyScale->nextTroughs.data(), classifyScale->troughs.data(),
classifyScale->prevMag.data(), classifyScale->prevMag.data(),
cd->segmentation, cd->segmentation,
cd->prevSegmentation, cd->prevSegmentation,
BinSegmenter::Segmentation(), //!!! cd->nextSegmentation,
cd->guidance); cd->guidance);
} }
@@ -320,7 +374,7 @@ R3StretcherImpl::consume()
m_channelAssembly.mag[c] = classifyScale->mag.data(); m_channelAssembly.mag[c] = classifyScale->mag.data();
m_channelAssembly.phase[c] = classifyScale->phase.data(); m_channelAssembly.phase[c] = classifyScale->phase.data();
m_channelAssembly.guidance[c] = &cd->guidance; m_channelAssembly.guidance[c] = &cd->guidance;
m_channelAssembly.outPhase[c] = classifyScale->outPhase.data(); m_channelAssembly.outPhase[c] = classifyScale->advancedPhase.data();
} }
m_scaleData.at(fftSize)->guided.advance m_scaleData.at(fftSize)->guided.advance
(m_channelAssembly.outPhase.data(), (m_channelAssembly.outPhase.data(),
@@ -342,8 +396,12 @@ R3StretcherImpl::consume()
auto &scale = it.second; auto &scale = it.second;
int bufSize = scale->bufSize; int bufSize = scale->bufSize;
// copy to prevMag before filtering // copy to prevMag before filtering
v_copy(scale->prevMag.data(), scale->mag.data(), bufSize); v_copy(scale->prevMag.data(),
v_copy(scale->prevOutPhase.data(), scale->outPhase.data(), bufSize); scale->mag.data(),
bufSize);
v_copy(scale->prevAdvancedPhase.data(),
scale->advancedPhase.data(),
bufSize);
} }
for (const auto &band : cd->guidance.fftBands) { for (const auto &band : cd->guidance.fftBands) {
@@ -398,7 +456,7 @@ R3StretcherImpl::consume()
(scale->real.data() + offset, (scale->real.data() + offset,
scale->imag.data() + offset, scale->imag.data() + offset,
scale->mag.data() + offset, scale->mag.data() + offset,
scale->outPhase.data() + offset, scale->advancedPhase.data() + offset,
b.b1max - offset); b.b1max - offset);
break; break;
} }

View File

@@ -134,6 +134,21 @@ public:
size_t getChannelCount() const; size_t getChannelCount() const;
protected: protected:
struct ClassificationReadaheadData {
FixedVector<double> timeDomain;
FixedVector<double> mag;
FixedVector<double> phase;
ClassificationReadaheadData(int _fftSize) :
timeDomain(_fftSize, 0.f),
mag(_fftSize/2 + 1, 0.f),
phase(_fftSize/2 + 1, 0.f)
{ }
private:
ClassificationReadaheadData(const ClassificationReadaheadData &) =delete;
ClassificationReadaheadData &operator=(const ClassificationReadaheadData &) =delete;
};
struct ChannelScaleData { struct ChannelScaleData {
int fftSize; int fftSize;
int bufSize; // size of every freq-domain array here: fftSize/2 + 1 int bufSize; // size of every freq-domain array here: fftSize/2 + 1
@@ -143,10 +158,10 @@ protected:
FixedVector<double> imag; FixedVector<double> imag;
FixedVector<double> mag; FixedVector<double> mag;
FixedVector<double> phase; FixedVector<double> phase;
FixedVector<double> outPhase; //!!! "advanced"? FixedVector<double> advancedPhase;
FixedVector<int> nextTroughs; //!!! not used in every scale FixedVector<int> troughs; //!!! not used in every scale
FixedVector<double> prevMag; //!!! not used in every scale FixedVector<double> prevMag; //!!! not used in every scale
FixedVector<double> prevOutPhase; FixedVector<double> prevAdvancedPhase;
FixedVector<double> accumulator; FixedVector<double> accumulator;
ChannelScaleData(int _fftSize, int _longestFftSize) : ChannelScaleData(int _fftSize, int _longestFftSize) :
@@ -157,10 +172,10 @@ protected:
imag(bufSize, 0.f), imag(bufSize, 0.f),
mag(bufSize, 0.f), mag(bufSize, 0.f),
phase(bufSize, 0.f), phase(bufSize, 0.f),
outPhase(bufSize, 0.f), advancedPhase(bufSize, 0.f),
nextTroughs(bufSize, 0), troughs(bufSize, 0),
prevMag(bufSize, 0.f), prevMag(bufSize, 0.f),
prevOutPhase(bufSize, 0.f), prevAdvancedPhase(bufSize, 0.f),
accumulator(_longestFftSize, 0.f) accumulator(_longestFftSize, 0.f)
{ } { }
@@ -171,6 +186,7 @@ protected:
struct ChannelData { struct ChannelData {
std::map<int, std::shared_ptr<ChannelScaleData>> scales; std::map<int, std::shared_ptr<ChannelScaleData>> scales;
ClassificationReadaheadData readahead;
std::unique_ptr<BinSegmenter> segmenter; std::unique_ptr<BinSegmenter> segmenter;
BinSegmenter::Segmentation segmentation; BinSegmenter::Segmentation segmentation;
BinSegmenter::Segmentation prevSegmentation; BinSegmenter::Segmentation prevSegmentation;
@@ -183,6 +199,7 @@ protected:
BinClassifier::Parameters classifierParameters, BinClassifier::Parameters classifierParameters,
int ringBufferSize) : int ringBufferSize) :
scales(), scales(),
readahead(segmenterParameters.fftSize),
segmenter(new BinSegmenter(segmenterParameters, segmenter(new BinSegmenter(segmenterParameters,
classifierParameters)), classifierParameters)),
segmentation(), prevSegmentation(), nextSegmentation(), segmentation(), prevSegmentation(), nextSegmentation(),