Skip to content

Commit 50edf31

Browse files
committed
Enable qwen in the app.
1 parent 4ba32e7 commit 50edf31

File tree

1 file changed

+59
-0
lines changed

1 file changed

+59
-0
lines changed

Moshi/ContentView.swift

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import Metal
1010
import MoshiLib
1111
import SwiftUI
1212
import Synchronization
13+
import Tokenizers
1314

1415
struct CustomError: Error {
1516
let message: String
@@ -24,6 +25,7 @@ enum ModelSelect: String, CaseIterable, Identifiable {
2425
case asr
2526
case hibiki
2627
case helium
28+
case qwen
2729

2830
var id: Self { return self }
2931

@@ -194,6 +196,18 @@ class Evaluator {
194196
return model
195197
}
196198

199+
/// Builds a `QwenModel` from `cfg` and populates it with the weights stored at `url`.
///
/// When `cfg.quantization` is set, the freshly constructed model is quantized
/// before the weights are applied. The update is run with `verify: [.all]`.
///
/// - Parameters:
///   - url: Location of the weights file handed to `loadArrays(url:)`.
///   - cfg: Model configuration, including the optional quantization settings.
/// - Returns: The model after its parameters have been applied and evaluated.
/// - Throws: Any error raised while loading the arrays or updating the model.
func makeQwen(_ url: URL, _ cfg: QwenConfig) throws -> QwenModel {
    let tensors = try loadArrays(url: url)
    let qwen = QwenModel(cfg)

    // Quantize first so the module layout is in place before the update below.
    if let quantization = cfg.quantization {
        quantize(model: qwen, groupSize: quantization.groupSize, bits: quantization.bits)
    }

    try qwen.update(parameters: ModuleParameters.unflattened(tensors), verify: [.all])
    eval(qwen)
    return qwen
}
210+
197211
func makeMimi(numCodebooks: Int) async throws -> Mimi {
198212
let cfg = MimiConfig.mimi_2024_07(numCodebooks: numCodebooks)
199213
let model = Mimi(cfg, bSize: 1)
@@ -367,6 +381,9 @@ class Evaluator {
367381
case .helium:
368382
let model = try await HeliumModel(self, self.cb)
369383
m = ModelState(model)
384+
case .qwen:
385+
let model = try await QwenModel_(self, self.cb)
386+
m = ModelState(model)
370387
}
371388
self.loadState = .loaded(m, sm)
372389
return m
@@ -455,6 +472,48 @@ struct MimiModel: Model {
455472
}
456473
}
457474

475+
/// `Model` implementation that generates text with Qwen2.5-0.5B-Instruct.
///
/// Construction downloads the configuration and weights through the evaluator,
/// builds the network with `Evaluator.makeQwen`, and then loads the tokenizer
/// for the same Hugging Face repository.
struct QwenModel_: Model {
    /// Token fed to the model as the first input. In the Qwen2.5 vocabulary this
    /// id is `<|endoftext|>`; it is also treated as a stop token when sampled.
    private static let endOfTextToken = 151643

    /// Upper bound on decoding steps per generation (the original `0...500` range).
    private static let maxSteps = 501

    /// A UTF-8 encoded character spans at most four bytes, so at most this many
    /// tokens are held back while waiting for a character to become complete.
    private static let maxPendingTokens = 4

    let qwen: QwenModel
    let tokenizer: any Tokenizer

    /// Downloads the model artifacts and builds the model and tokenizer.
    ///
    /// - Parameters:
    ///   - ev: Evaluator used for downloads, model construction and status messages.
    ///   - cb: Callbacks object; not used by this initializer.
    /// - Throws: Any error from the downloads, config decoding, model
    ///   construction, or tokenizer loading.
    init(_ ev: Evaluator, _ cb: Callbacks) async throws {
        let hfRepo = "Qwen/Qwen2.5-0.5B-Instruct"
        await ev.setModelInfo("building model")

        let configUrl = try await ev.downloadFromHub(id: hfRepo, filename: "config.json")
        let configData = try Data(contentsOf: configUrl)
        let decoder = JSONDecoder()
        // The config file uses snake_case keys.
        decoder.keyDecodingStrategy = .convertFromSnakeCase
        let cfg = try decoder.decode(QwenConfig.self, from: configData)

        let url = try await ev.downloadFromHub(id: hfRepo, filename: "model.safetensors")
        qwen = try await ev.makeQwen(url, cfg)
        await ev.setModelInfo("model built")
        tokenizer = try await AutoTokenizer.from(pretrained: hfRepo)
    }

    /// No state to reset: a fresh cache is created for every generation.
    mutating func reset() {
    }

    /// Runs one text generation and appends the decoded text to `ev.output`.
    ///
    /// Generation starts from `endOfTextToken`, samples up to `maxSteps` tokens
    /// and stops early when an end-of-sequence token is sampled. The `pcm` and
    /// `ap` arguments are not used.
    ///
    /// - Returns: Always `false`.
    ///   NOTE(review): presumably this tells the caller to stop feeding audio;
    ///   confirm against the `Model` protocol.
    mutating func onMicrophonePcm(_ pcm: MLXArray, ap: AudioPlayer, ev: Evaluator) -> Bool {
        let sampler = Sampler()
        let cache = qwen.makeCache(bSize: 1)

        // Stop on `<|endoftext|>` and on the tokenizer's own EOS token, if it declares one.
        var stopTokens: Set<Int> = [Self.endOfTextToken]
        if let eos = tokenizer.eosTokenId {
            stopTokens.insert(eos)
        }

        // Tokens sampled but not yet shown because they end in a partial character.
        var pending: [Int] = []
        var lastToken = Self.endOfTextToken
        for _ in 0..<Self.maxSteps {
            let logits = qwen(MLXArray([lastToken]).reshaped(1, 1), cache: cache)
            let (tok, _) = sampler(logits: logits[0])
            // Explicit metatype form instead of call-site generic specialization.
            lastToken = tok.item(Int.self)
            if stopTokens.contains(lastToken) {
                break
            }

            pending.append(lastToken)
            let piece = tokenizer.decode(tokens: pending)
            // A trailing U+FFFD means the last character is still split across
            // tokens; wait for the next token instead of showing a broken glyph.
            if piece.hasSuffix("\u{FFFD}") && pending.count < Self.maxPendingTokens {
                continue
            }
            pending.removeAll(keepingCapacity: true)
            Task { @MainActor in
                ev.output += piece
            }
        }

        // Flush whatever is still held back when generation ends.
        if !pending.isEmpty {
            let tail = tokenizer.decode(tokens: pending)
            Task { @MainActor in
                ev.output += tail
            }
        }

        return false
    }
}
516+
458517
struct HeliumModel: Model {
459518
let helium: LM
460519
let vocab: [Int: String]

0 commit comments

Comments
 (0)