Swift/MLX implementation of SAM-Audio (Segment Anything Model for Audio) for text-guided source separation.
import Foundation
import MLXAudioSTS
@main
struct Demo {
static func main() async throws {
let model = try await SAMAudio.fromPretrained("mlx-community/sam-audio-large")
let result = try await model.separate(
audioPaths: ["input.wav"],
descriptions: ["speech"]
)
print("sampleRate:", model.sampleRate)
print("target shape:", result.target[0].shape)
print("residual shape:", result.residual[0].shape)
print("peak memory (GB):", result.peakMemoryGB ?? -1)
}
}

`SAMAudio.fromPretrained(...)` accepts either:
- A Hugging Face repo, e.g. `mlx-community/sam-audio-large`
- A local directory containing `config.json` and one or more `.safetensors` files
let model = try await SAMAudio.fromPretrained("mlx-community/sam-audio-large")

For gated repos, pass a token:
let model = try await SAMAudio.fromPretrained(
"facebook/sam-audio-large",
hfToken: ProcessInfo.processInfo.environment["HF_TOKEN"]
)

Best for short clips when memory is not a concern.
let result = try await model.separate(
audioPaths: ["input.wav"],
descriptions: ["speech"],
anchors: [[("+", 1.5, 3.0)]], // optional
ode: SAMAudioODEOptions(method: .midpoint, stepSize: 2.0 / 32.0)
)

You can also pass pre-batched waveforms:
let result = try await model.separate(
audios: batchedAudio, // shape: (B, 1, T)
descriptions: prompts
)

Chunked inference with cosine crossfade stitching.
let result = try await model.separateLong(
audioPaths: ["long_input.wav"],
descriptions: ["speech"],
chunkSeconds: 10.0,
overlapSeconds: 3.0,
ode: SAMAudioODEOptions(method: .euler, stepSize: 2.0 / 32.0)
)

Generator-style:
let stream = try model.separateStreaming(
audioPaths: ["input.wav"],
descriptions: ["speech"],
chunkSeconds: 10.0,
overlapSeconds: 3.0
)
for try await chunk in stream {
print(chunk.chunkIndex, chunk.target.shape, chunk.isLastChunk)
}

Callback-style:
let count = try await model.separateStreaming(
audios: batchedAudio,
descriptions: prompts,
targetCallback: { audioChunk, idx, isLast in
print("target chunk", idx, audioChunk.shape, isLast)
}
)
print("total samples emitted:", count)

Anchor format:
SAMAudioAnchor = (token: String, startTime: Float, endTime: Float)
Token meanings:
- `"+"`: target sound is present in this span
- `"-"`: target sound is not present in this span
Example:
anchors: [[("+", 1.0, 2.5), ("-", 4.0, 6.0)]]

`SAMAudioODEOptions` controls quality/speed:
- `.midpoint` is slower and usually higher quality
- `.euler` is faster
- `stepSize` must satisfy `0 < stepSize < 1` (default is `2/32`)
let ode = SAMAudioODEOptions(method: .midpoint, stepSize: 2.0 / 32.0)

- `separateLong(...)` and chunked streaming currently require batch size 1
- Chunked methods do not currently support anchors (`chunkedAnchorsNotSupported`)
- Output arrays are mono waveforms per sample (`(samples, 1)`)
Local integration test (no network):
swift test --filter fromPretrainedLoadsLocalFixture

Network-enabled integration test:
MLXAUDIO_ENABLE_NETWORK_TESTS=1 \
MLXAUDIO_SAMAUDIO_REPO=mlx-community/sam-audio-large \
swift test --filter fromPretrainedLoadsRealWeightsNetwork