@@ -10,6 +10,7 @@ import Metal
1010import MoshiLib
1111import SwiftUI
1212import Synchronization
13+ import Tokenizers
1314
1415struct CustomError : Error {
1516 let message : String
@@ -24,6 +25,7 @@ enum ModelSelect: String, CaseIterable, Identifiable {
2425 case asr
2526 case hibiki
2627 case helium
28+ case qwen
2729
2830 var id : Self { return self }
2931
@@ -194,6 +196,18 @@ class Evaluator {
194196 return model
195197 }
196198
199+ func makeQwen( _ url: URL , _ cfg: QwenConfig ) throws -> QwenModel {
200+ let weights = try loadArrays ( url: url)
201+ let parameters = ModuleParameters . unflattened ( weights)
202+ let model = QwenModel ( cfg)
203+ if let q = cfg. quantization {
204+ quantize ( model: model, groupSize: q. groupSize, bits: q. bits)
205+ }
206+ try model. update ( parameters: parameters, verify: [ . all] )
207+ eval ( model)
208+ return model
209+ }
210+
197211 func makeMimi( numCodebooks: Int ) async throws -> Mimi {
198212 let cfg = MimiConfig . mimi_2024_07 ( numCodebooks: numCodebooks)
199213 let model = Mimi ( cfg, bSize: 1 )
@@ -367,6 +381,9 @@ class Evaluator {
367381 case . helium:
368382 let model = try await HeliumModel ( self , self . cb)
369383 m = ModelState ( model)
384+ case . qwen:
385+ let model = try await QwenModel_ ( self , self . cb)
386+ m = ModelState ( model)
370387 }
371388 self . loadState = . loaded( m, sm)
372389 return m
@@ -455,6 +472,48 @@ struct MimiModel: Model {
455472 }
456473}
457474
475+ struct QwenModel_ : Model {
476+ let qwen : QwenModel
477+ let tokenizer : any Tokenizer
478+
479+ init ( _ ev: Evaluator , _ cb: Callbacks ) async throws {
480+ let hfRepo = " Qwen/Qwen2.5-0.5B-Instruct "
481+ await ev. setModelInfo ( " building model " )
482+
483+ let configUrl = try await ev. downloadFromHub ( id: hfRepo, filename: " config.json " )
484+ let configData = try Data ( contentsOf: configUrl)
485+ let decoder = JSONDecoder ( )
486+ decoder. keyDecodingStrategy = . convertFromSnakeCase
487+ let cfg = try decoder. decode ( QwenConfig . self, from: configData)
488+
489+ let url = try await ev. downloadFromHub ( id: hfRepo, filename: " model.safetensors " )
490+ qwen = try await ev. makeQwen ( url, cfg)
491+ await ev. setModelInfo ( " model built " )
492+ tokenizer = try await AutoTokenizer . from ( pretrained: hfRepo)
493+ }
494+
495+ mutating func reset( ) {
496+ }
497+
498+ mutating func onMicrophonePcm( _ pcm: MLXArray , ap: AudioPlayer , ev: Evaluator ) -> Bool {
499+ let sampler = Sampler ( )
500+ let cache = qwen. makeCache ( bSize: 1 )
501+
502+ var lastToken = 151643
503+ for stepIdx in 0 ... 500 {
504+ let logits = qwen ( MLXArray ( [ lastToken] ) . reshaped ( 1 , 1 ) , cache: cache)
505+ let ( tok, _) = sampler ( logits: logits [ 0 ] )
506+ lastToken = tok . item < Int > ( )
507+ let s = tokenizer. decode ( tokens: [ lastToken] )
508+ Task { @MainActor in
509+ ev. output += s
510+ }
511+ }
512+
513+ return false
514+ }
515+ }
516+
458517struct HeliumModel : Model {
459518 let helium : LM
460519 let vocab : [ Int : String ]
0 commit comments