Skip to content

Commit 8f9658a

Browse files
Add support for qwen. (#12)
* Add support for qwen.
* CLI to run qwen.
* Get some sampling to work.
1 parent 51a2fb0 commit 8f9658a

File tree

3 files changed

+197
-1
lines changed

3 files changed

+197
-1
lines changed

Makefile

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,5 +15,8 @@ run-mimi: build
1515
run-helium: build
1616
./build/Build/Products/Release/MoshiCLI run-helium
1717

18+
run-qwen: build
19+
./build/Build/Products/Release/MoshiCLI run-qwen
20+
1821
build:
1922
xcodebuild -scheme moshi-cli -derivedDataPath ./build

MoshiCLI/CLI.swift

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ struct Moshi: ParsableCommand {
3636
static let configuration = CommandConfiguration(
3737
subcommands: [
3838
Run.self, RunHelium.self, RunMimi.self, AudioToCodes.self, CodesToAudio.self,
39-
RunAsr.self,
39+
RunAsr.self, RunQwen.self,
4040
]
4141
)
4242
}
@@ -105,6 +105,44 @@ public enum HeliumConfig: String, CaseIterable, ExpressibleByArgument {
105105
case bf16
106106
}
107107

108+
/// CLI subcommand that fetches a Qwen checkpoint from the Hugging Face hub,
/// loads it into a `QwenModel` and prints a short run of sampled token ids.
struct RunQwen: ParsableCommand {
    @Option(help: "the Hugging Face repo to load config.json and model.safetensors from")
    var hfRepo: String = "Qwen/Qwen2.5-0.5B"

    /// Downloads the config and weights, builds the model and samples tokens.
    ///
    /// - Throws: any error raised while downloading, reading or decoding the
    ///   config, loading the weight arrays, or updating the model parameters.
    mutating func run() throws {
        // config.json uses snake_case keys while QwenConfig uses camelCase
        // properties, hence the key decoding strategy.
        let configUrl = try downloadFromHub(id: hfRepo, filename: "config.json")
        let configData = try Data(contentsOf: configUrl)
        let decoder = JSONDecoder()
        decoder.keyDecodingStrategy = .convertFromSnakeCase
        let config = try decoder.decode(QwenConfig.self, from: configData)
        print("config \(config)")

        let modelUrl = try downloadFromHub(id: hfRepo, filename: "model.safetensors")
        print("model \(modelUrl)")
        let weights = try loadArrays(url: modelUrl)
        // Only the sub-tree stored under the top-level "model" key is loaded
        // into QwenModel; the diagnostics name the weights file because that
        // is where the key is looked up.
        guard let modelItem = ModuleParameters.unflattened(weights)["model"] else {
            fatalError("no model key in \(modelUrl)")
        }
        let parameters =
            switch modelItem {
            case .dictionary(let d): NestedDictionary(values: d)
            default: fatalError("model key in \(modelUrl) is not a dict")
            }

        let model = QwenModel(config)
        try model.update(parameters: parameters, verify: [.all])
        eval(model)
        let cache = model.makeCache(bSize: 1)
        let sampler = Sampler()
        // Generation starts from token id 0 and feeds each sampled token back
        // in as the next single-token input; 0...100 yields 101 steps.
        var lastToken = 0
        for _ in 0...100 {
            let logits = model(MLXArray([lastToken]).reshaped(1, 1), cache: cache)
            let (tok, _) = sampler(logits: logits[0])
            // The element type is inferred from the annotation rather than
            // spelled as explicit generic arguments on the method call.
            let sampled: Int = tok.item()
            lastToken = sampled
            print("sampled \(lastToken)")
        }
    }
}
145+
108146
struct RunHelium: ParsableCommand {
109147
@Option(help: "the config")
110148
var config: HeliumConfig = .q4

MoshiLib/Qwen2.swift

Lines changed: 155 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,155 @@
1+
// Copyright (c) Kyutai, all rights reserved.
2+
// This source code is licensed under the license found in the
3+
// LICENSE file in the root directory of this source tree.
4+
5+
import MLX
6+
import MLXFast
7+
import MLXNN
8+
9+
/// Hyper-parameters of a Qwen model.
///
/// The property names are the camelCase forms of the snake_case keys found in
/// a checkpoint's `config.json`; all of them are required when decoding.
public struct QwenConfig: Codable {
    // Special token ids.
    public var bosTokenId: Int
    public var eosTokenId: Int

    // Layer widths.
    public var hiddenSize: Int
    public var intermediateSize: Int

    public var maxPositionEmbeddings: Int
    public var maxWindowLayers: Int

    // Attention layout and depth.
    public var numAttentionHeads: Int
    public var numHiddenLayers: Int
    public var numKeyValueHeads: Int

    public var rmsNormEps: Float
    public var ropeTheta: Float
    public var tieWordEmbeddings: Bool
    public var useSlidingWindow: Bool
    public var vocabSize: Int

    /// Size of a single attention head: `hiddenSize` divided by
    /// `numAttentionHeads` using integer division.
    public func headDim() -> Int {
        let perHead = hiddenSize / numAttentionHeads
        return perHead
    }
}
29+
30+
/// Gated feed-forward block: `down_proj(silu(gate_proj(x)) * up_proj(x))`.
/// None of the three projections has a bias.
private class Mlp: Module, UnaryLayer {
    @ModuleInfo(key: "gate_proj") var gateProj: Linear
    @ModuleInfo(key: "down_proj") var downProj: Linear
    @ModuleInfo(key: "up_proj") var upProj: Linear

    /// Builds the projections between `hiddenSize` and `intermediateSize`.
    init(_ cfg: QwenConfig) {
        let hidden = cfg.hiddenSize
        let inner = cfg.intermediateSize
        self._gateProj.wrappedValue = Linear(hidden, inner, bias: false)
        self._upProj.wrappedValue = Linear(hidden, inner, bias: false)
        self._downProj.wrappedValue = Linear(inner, hidden, bias: false)
    }

    /// Applies the gated feed-forward transform to `x`.
    func callAsFunction(_ x: MLXArray) -> MLXArray {
        let gate = silu(gateProj(x))
        let up = upProj(x)
        return downProj(gate * up)
    }
}
45+
46+
/// Self-attention with rotary position embeddings and an optional KV cache.
/// The query/key/value projections carry a bias, the output projection does not.
private class Attention: Module {
    let cfg: QwenConfig
    let scale: Float
    let rope: RoPE

    @ModuleInfo(key: "q_proj") var qProj: Linear
    @ModuleInfo(key: "k_proj") var kProj: Linear
    @ModuleInfo(key: "v_proj") var vProj: Linear
    @ModuleInfo(key: "o_proj") var oProj: Linear

    /// Sizes the projections from the config; keys and values use
    /// `numKeyValueHeads` heads while queries use `numAttentionHeads`.
    init(_ cfg: QwenConfig) {
        let headDim = cfg.headDim()
        let queryWidth = cfg.numAttentionHeads * headDim
        let keyValueWidth = cfg.numKeyValueHeads * headDim

        self.cfg = cfg
        self.scale = 1.0 / sqrt(Float(headDim))
        self.rope = RoPE(dimensions: headDim, traditional: false, base: Float(cfg.ropeTheta))
        self._qProj.wrappedValue = Linear(cfg.hiddenSize, queryWidth, bias: true)
        self._kProj.wrappedValue = Linear(cfg.hiddenSize, keyValueWidth, bias: true)
        self._vProj.wrappedValue = Linear(cfg.hiddenSize, keyValueWidth, bias: true)
        self._oProj.wrappedValue = Linear(queryWidth, cfg.hiddenSize, bias: false)
    }

    /// Runs attention over `x`.
    ///
    /// - Parameters:
    ///   - x: input with at least three axes (presumably batch, time, hidden —
    ///     confirm against callers).
    ///   - mask: optional attention mask; trimmed along its last axis when it
    ///     is longer than the keys.
    ///   - cache: optional KV cache; its `offset` positions the rotary
    ///     embedding and it is updated with the new keys and values.
    /// - Returns: the output projection of the attended values.
    func callAsFunction(_ x: MLXArray, mask: MLXArray?, cache: KVCache?) -> MLXArray {
        let dim0 = x.dim(0)
        let dim1 = x.dim(1)
        let dim2 = x.dim(2)
        let headDim = cfg.headDim()

        // Splits the projected features into heads and moves the head axis
        // in front of the second input axis.
        func splitHeads(_ projected: MLXArray) -> MLXArray {
            projected.reshaped(dim0, dim1, -1, headDim).transposed(0, 2, 1, 3)
        }

        // The offset is read before the cache is updated below.
        let position = cache?.offset ?? 0
        let queries = rope(splitHeads(qProj(x)), offset: position)
        var keys = rope(splitHeads(kProj(x)), offset: position)
        var values = splitHeads(vProj(x))
        if let cache {
            (keys, values) = cache.update(keys: keys, values: values)
        }

        // sliding window is not supported here
        var attentionMask = mask
        if let fullMask = mask {
            // Drop the leading mask columns when the mask is longer than the keys.
            let surplus = fullMask.dim(-1) - keys.dim(2)
            if surplus > 0 {
                attentionMask = fullMask[0..., surplus...]
            }
        }

        let attended = MLXFast.scaledDotProductAttention(
            queries: queries, keys: keys, values: values, scale: scale,
            mask: attentionMask
        )
        // Undo the head split before the output projection.
        let merged = attended.transposed(0, 2, 1, 3).reshaped(dim0, dim1, dim2)
        return oProj(merged)
    }
}
100+
101+
/// One transformer block: normalised self-attention followed by a normalised
/// MLP, each wrapped in a residual connection.
private class Layer: Module {
    @ModuleInfo(key: "mlp") var mlp: Mlp
    @ModuleInfo(key: "input_layernorm") var inputNorm: RMSNorm
    @ModuleInfo(key: "post_attention_layernorm") var postAttnNorm: RMSNorm
    @ModuleInfo(key: "self_attn") var selfAttn: Attention

    /// Builds the sub-modules; both norms use `hiddenSize` and `rmsNormEps`.
    init(_ cfg: QwenConfig) {
        let width = cfg.hiddenSize
        let eps = cfg.rmsNormEps
        self._mlp.wrappedValue = Mlp(cfg)
        self._inputNorm.wrappedValue = RMSNorm(dimensions: width, eps: eps)
        self._postAttnNorm.wrappedValue = RMSNorm(dimensions: width, eps: eps)
        self._selfAttn.wrappedValue = Attention(cfg)
    }

    /// Applies the block to `x`, forwarding `mask` and `cache` to attention.
    func callAsFunction(_ x: MLXArray, mask: MLXArray?, cache: KVCache) -> MLXArray {
        // Attention sub-block with its residual.
        let attended = x + selfAttn(inputNorm(x), mask: mask, cache: cache)
        // Feed-forward sub-block with its residual.
        return attended + mlp(postAttnNorm(attended))
    }
}
124+
125+
/// Qwen decoder stack: token embedding, `numHiddenLayers` transformer layers
/// and a final RMS norm. Output logits are produced by re-using the embedding
/// matrix as the output projection.
public class QwenModel: Module {
    let cfg: QwenConfig
    private let norm: RMSNorm
    private let layers: [Layer]
    @ModuleInfo(key: "embed_tokens") var embedTokens: Embedding

    /// Builds an untrained model whose shapes are taken from `cfg`.
    public init(_ cfg: QwenConfig) {
        self.cfg = cfg
        self.layers = (0..<cfg.numHiddenLayers).map { _ in Layer(cfg) }
        self.norm = RMSNorm(dimensions: cfg.hiddenSize, eps: cfg.rmsNormEps)
        self._embedTokens.wrappedValue = Embedding(
            embeddingCount: cfg.vocabSize, dimensions: cfg.hiddenSize)
    }

    /// Runs the model on token ids `x`.
    ///
    /// - Parameters:
    ///   - x: token ids fed to the embedding.
    ///   - cache: one KV cache per layer; layers and caches are paired with
    ///     `zip`, so any surplus on either side is ignored.
    /// - Returns: the normed hidden states projected through the embedding
    ///   matrix.
    public func callAsFunction(_ x: MLXArray, cache: [KVCache]) -> MLXArray {
        let embedded = embedTokens(x)
        // The mask is derived once from the first cache and shared by all layers.
        let mask = cache.first?.createAttentionMask(h: embedded)
        let hidden = zip(layers, cache).reduce(embedded) { state, pair in
            let (layer, layerCache) = pair
            return layer(state, mask: mask, cache: layerCache)
        }
        return embedTokens.asLinear(norm(hidden))
    }

    /// Creates one `KVCacheSimple` per hidden layer.
    ///
    /// - Parameter bSize: not read by this implementation.
    public func makeCache(bSize: Int) -> [KVCache] {
        let perLayer = (0..<cfg.numHiddenLayers).map { _ in
            KVCacheSimple(headDim: .init(cfg.headDim()), kvHeads: cfg.numKeyValueHeads)
        }
        return perLayer
    }
}

0 commit comments

Comments
 (0)