Skip to content

Commit 0cfee0a

Browse files
committed
Attempt at fixing issues with noisy output of MDX models and remove cpecs
1 parent cecb34e commit 0cfee0a

File tree

5 files changed

+43
-587
lines changed

5 files changed

+43
-587
lines changed

PolUVR/separator/architectures/mdx_separator.py

Lines changed: 17 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ def __init__(self, common_config, arch_config):
4646
# - For Non-MDX23C models: You can choose between 0.001-0.999
4747
self.overlap = arch_config.get("overlap", 0.25)
4848

49-
# Ensure overlap is within the range [0.001, 0.99]
49+
# Ensure overlap is within the range [0.001, 0.999]
5050
if self.overlap < 0.001:
5151
self.logger.warning(f"overlap {self.overlap} is less than the minimum allowed value of 0.001. Setting overlap to 0.001.")
5252
self.overlap = 0.001
@@ -184,34 +184,33 @@ def separate(self, audio_file_path, custom_output_names=None):
184184
mix = self.prepare_mix(self.audio_file_path)
185185

186186
self.logger.debug("Normalizing mix before demixing...")
187+
peak = np.abs(mix).max()
187188
mix = spec_utils.normalize(wave=mix, max_peak=self.normalization_threshold, min_peak=self.amplification_threshold)
188189

189190
# Start the demixing process
190-
source = self.demix(mix)
191+
source = self.demix(mix) * peak
191192
self.logger.debug("Demixing completed.")
192193

194+
if not isinstance(self.primary_source, np.ndarray):
195+
self.primary_source = source.T
196+
193197
# In UVR, the source is cached here if it's a vocal split model, but we're not supporting that yet
194198

195199
# Initialize the list for output files
196200
output_files = []
197201
self.logger.debug("Processing output files...")
198202

199-
# Normalize and transpose the primary source if it's not already an array
200-
if not isinstance(self.primary_source, np.ndarray):
201-
self.logger.debug("Normalizing primary source...")
202-
self.primary_source = spec_utils.normalize(wave=source, max_peak=self.normalization_threshold, min_peak=self.amplification_threshold).T
203-
204203
# Process the secondary source if not already an array
205204
if not isinstance(self.secondary_source, np.ndarray):
206205
self.logger.debug("Producing secondary source: demixing in match_mix mode")
207206
raw_mix = self.demix(mix, is_match_mix=True)
208207

209208
if self.invert_using_spec:
210209
self.logger.debug("Inverting secondary stem using spectogram as invert_using_spec is set to True")
211-
self.secondary_source = spec_utils.invert_stem(raw_mix, source)
210+
self.secondary_source = spec_utils.invert_stem(raw_mix, self.primary_source * self.compensate)
212211
else:
213212
self.logger.debug("Inverting secondary stem by subtracting of transposed demixed stem from transposed original mix")
214-
self.secondary_source = mix.T - source.T
213+
self.secondary_source = (-self.primary_source * self.compensate) + mix.T
215214

216215
# Save and process the secondary stem if needed
217216
if not self.output_single_stem or self.output_single_stem.lower() == self.secondary_stem_name.lower():
@@ -224,10 +223,6 @@ def separate(self, audio_file_path, custom_output_names=None):
224223
# Save and process the primary stem if needed
225224
if not self.output_single_stem or self.output_single_stem.lower() == self.primary_stem_name.lower():
226225
self.primary_stem_output_path = self.get_stem_output_path(self.primary_stem_name, custom_output_names)
227-
228-
if not isinstance(self.primary_source, np.ndarray):
229-
self.primary_source = source.T
230-
231226
self.logger.info(f"Saving {self.primary_stem_name} stem to {self.primary_stem_output_path}...")
232227
self.final_process(self.primary_stem_output_path, self.primary_source, self.primary_stem_name)
233228
output_files.append(self.primary_stem_output_path)
@@ -284,7 +279,15 @@ def initialize_mix(self, mix, is_ckpt=False):
284279
pad = self.gen_size + self.trim - (mix.shape[-1] % self.gen_size)
285280
self.logger.debug(f"Padding calculated: {pad}")
286281
# Add padding at the beginning and the end of the mix
287-
mixture = np.concatenate((np.zeros((2, self.trim), dtype="float32"), mix, np.zeros((2, pad), dtype="float32")), 1)
282+
mixture = np.concatenate(
283+
(
284+
np.zeros((2, self.trim), dtype="float32"), # Pad at the start
285+
mix,
286+
np.zeros((2, pad), dtype="float32"), # Pad in the middle (to match chunk size)
287+
np.zeros((2, self.trim), dtype="float32"), # Pad at the end
288+
),
289+
1
290+
)
288291
# Determine the number of chunks based on the mixture's length
289292
num_chunks = mixture.shape[-1] // self.gen_size
290293
self.logger.debug(f"Mixture shape after padding: {mixture.shape}, Number of chunks: {num_chunks}")
@@ -431,11 +434,6 @@ def demix(self, mix, is_match_mix=False):
431434

432435
# TODO: In UVR, pitch changing happens here. Consider implementing this as a feature.
433436

434-
# Compensates the source if not matching the mix.
435-
if not is_match_mix:
436-
source *= self.compensate
437-
self.logger.debug("Match mix mode; compensate multiplier applied.")
438-
439437
# TODO: In UVR, VR denoise model gets applied here. Consider implementing this as a feature.
440438

441439
self.logger.debug("Demixing process completed.")

PolUVR/separator/roformer/parameter_validator.py

Lines changed: 26 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -7,53 +7,35 @@
77
import sys
88
import os
99

10-
# Add contracts to path for interface imports (optional)
11-
try:
12-
# Find project root dynamically
13-
current_dir = os.path.dirname(os.path.abspath(__file__))
14-
project_root = current_dir
15-
# Go up until we find the project root (contains specs/ directory)
16-
while project_root and not os.path.exists(os.path.join(project_root, 'specs')):
17-
parent = os.path.dirname(project_root)
18-
if parent == project_root: # Reached filesystem root
19-
break
20-
project_root = parent
21-
22-
contracts_path = os.path.join(project_root, 'specs', '001-update-roformer-implementation', 'contracts')
23-
if os.path.exists(contracts_path):
24-
sys.path.append(contracts_path)
25-
from parameter_validator_interface import (
26-
ParameterValidatorInterface,
27-
ValidationIssue,
28-
ValidationSeverity
29-
)
30-
_has_interface = True
31-
except ImportError:
32-
# Create dummy interfaces for when contracts are not available
33-
from enum import Enum
34-
from dataclasses import dataclass
35-
36-
class ValidationSeverity(Enum):
37-
ERROR = "error"
38-
WARNING = "warning"
39-
INFO = "info"
40-
41-
@dataclass
42-
class ValidationIssue:
43-
severity: ValidationSeverity
44-
parameter_name: str
45-
message: str
46-
suggested_fix: str
47-
current_value: any = None
48-
expected_value: any = None
49-
50-
class ParameterValidatorInterface:
51-
pass
52-
53-
_has_interface = False
10+
from enum import Enum
11+
from dataclasses import dataclass
12+
5413
from .parameter_validation_error import ParameterValidationError
5514

5615

16+
class ValidationSeverity(Enum):
17+
ERROR = "error"
18+
WARNING = "warning"
19+
INFO = "info"
20+
21+
22+
@dataclass
23+
class ValidationIssue:
24+
severity: ValidationSeverity
25+
parameter_name: str
26+
message: str
27+
suggested_fix: str
28+
current_value: any = None
29+
expected_value: any = None
30+
31+
32+
class ParameterValidatorInterface:
33+
pass
34+
35+
36+
_has_interface = False
37+
38+
5739
class ParameterValidator(ParameterValidatorInterface):
5840
"""
5941
Implementation of parameter validation for Roformer models.

specs/001-update-roformer-implementation/contracts/fallback_loader_interface.py

Lines changed: 0 additions & 65 deletions
This file was deleted.

0 commit comments

Comments
 (0)