UPSTREAM PR #17004: sampling : add support for GPU sampling (wip)#102
Mirrored from ggml-org/llama.cpp#17004
This is a work in progress to add support for GPU sampling.
The motivation for this feature is to enable some or all of the sampling to be performed directly on the GPU, as part of the computation graph being executed.
For example, the GPU sampler chain might select/sample a token directly, in which case only the sampled token needs to be transferred from device memory to host memory.
It is also possible for the GPU samplers to perform filtering of the logits, or to compute and filter the probability distribution, in which case only the filtered logits or probabilities need to be transferred back to system memory for further processing by CPU samplers.
Currently, GPU sampling works in a similar manner to pooling: it is a function that is called by build_graph:
GPU samplers can be configured by creating sampler chains, where each sampler chain is associated with a specific sequence id:
The struct is defined as:
These sampler configs are then passed as context params:
```cpp
llama_context_params cparams = llama_context_default_params();
cparams.samplers   = sampler_configs.data();
cparams.n_samplers = sampler_configs.size();
```
When the graph is built, each configured sampler's _apply function is called, which allows it to add operations/nodes to the computation graph.
This enables the sampling to happen fully or partially on the GPU. The samplers could sample a single token, in which case only that token is transferred from device memory to host memory after llama_decode has been called. The sampled token can then be retrieved using:
It is also possible to run a GPU sampler that only filters the logits; then only the filtered logits are transferred back to the host, and sampling can proceed on the CPU with the normal (CPU) sampler chain. In this case the CPU samplers are configured as usual, but they now operate on already filtered logits.
Similar to the above handling of logits, it is possible for a GPU sampler to compute the full probability distribution and transfer that to the host. The CPU samplers can then operate on those probabilities.
Building and running the tests
Download a model for testing:
```console
$ cd models && wget https://huggingface.co/ggml-org/models/resolve/main/tinyllamas/stories15M-q4_0.gguf
```
Building the test:
```console
$ cmake --build build --target test-gpu-sampling -j8
```
Running all tests:
The following individual tests are available:
These can be run individually, for example:
TODO