-
Notifications
You must be signed in to change notification settings - Fork 13
Expand file tree
/
Copy pathRunHelium.swift
More file actions
55 lines (50 loc) · 1.76 KB
/
RunHelium.swift
File metadata and controls
55 lines (50 loc) · 1.76 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
// Copyright (c) Kyutai, all rights reserved.
// This source code is licensed under the license found in the
// LICENSE file in the root directory of this source tree.
import AVFoundation
import Foundation
import MLX
import MLXNN
import MoshiLib
/// Builds a Helium language model and populates it with the weights stored at `url`.
///
/// If the file name ends in `q4.safetensors`, `q6.safetensors` or `q8.safetensors`, the freshly
/// constructed model is quantized (group size 64, with 4, 6 or 8 bits respectively) before the
/// weights are applied. Any other file name leaves the model unquantized.
///
/// - Parameters:
///   - url: Location of the weights file handed to `loadArrays(url:)`.
///   - cfg: Configuration used to construct the `LM` (created with `bSize: 1`).
/// - Returns: The model after its parameters have been updated and `eval` has been called on it.
/// - Throws: Whatever `loadArrays(url:)` or `update(parameters:verify:)` throws.
func makeHelium(_ url: URL, _ cfg: LmConfig) throws -> LM {
    let flatWeights = try loadArrays(url: url)
    let loadedParameters = ModuleParameters.unflattened(flatWeights)
    let lm = LM(cfg, bSize: 1)

    // The bit width is encoded in the file-name suffix; at most one candidate can match.
    // NOTE(review): quantizing before `update` presumably makes the module's parameter layout
    // match a quantized checkpoint so that `verify: [.all]` passes — confirm.
    let fileName = url.lastPathComponent
    let supportedBits = [4, 6, 8]
    if let bits = supportedBits.first(where: { fileName.hasSuffix("q\($0).safetensors") }) {
        quantize(model: lm, groupSize: 64, bits: bits)
    }

    try lm.update(parameters: loadedParameters, verify: [.all])
    eval(lm)
    return lm
}
/// Loads a Helium model and streams sampled text to standard output.
///
/// The model is built with `makeHelium`, warmed up, and then sampled one text token at a time
/// for `cfg.transformer.maxSeqLen` steps. Each sampled token is fed back in as the next input.
/// Tokens that have an entry in the vocabulary are printed as they are produced: the
/// `<0x0A>` piece is rendered as a line break, and `▁` markers inside any other piece are
/// replaced by spaces. Tokens without a vocabulary entry are skipped silently.
///
/// - Parameters:
///   - url: Location of the weights file, forwarded to `makeHelium`.
///   - cfg: Model configuration, also used to load the vocabulary.
/// - Throws: Whatever `makeHelium` or `loadVocab` throws.
func runHelium(_ url: URL, cfg: LmConfig) throws {
    let stats = PerfStats()
    let helium = try makeHelium(url, cfg)
    let vocab = try loadVocab(cfg)
    helium.warmup()
    print("done warming up")

    let maxSteps = helium.cfg.transformer.maxSeqLen
    let sampler = Sampler()
    // NOTE(review): token id 1 is presumably the start-of-text token — confirm against the vocab.
    var lastToken = MLXArray([1])

    // Half-open range: run exactly `maxSteps` steps (indices 0 ..< maxSteps). The previous
    // closed range `0...maxSteps` executed one step more than the step budget.
    for stepIdx in 0..<maxSteps {
        // No audio streams are used here, hence the empty `audioIds`; the audio half of the
        // returned tuple is ignored.
        let (textToken, _) = helium.sample(
            textIds: lastToken.reshaped([1, 1]), audioIds: [], stepIdx: stepIdx,
            textSampler: sampler,
            audioSampler: sampler, cb: stats)
        let textTokenI: Int = textToken[0].item()
        if var piece = vocab[textTokenI] {
            if piece == "<0x0A>" {
                print()
            } else {
                piece.replace("▁", with: " ")
                print(piece, terminator: "")
            }
            // Flush after every emitted token (including line breaks) so the text appears
            // incrementally even when stdout is not line-buffered.
            fflush(stdout)
        }
        lastToken = textToken
    }
}