@@ -1,5 +1,6 @@
 #include "common.h"
 #include "llama.h"
+#include "common/common.h"
 
 #include "binding.h"
 
@@ -123,10 +124,10 @@ int eval(void *params_ptr, void *state_pr, char *text)
     llama_context *ctx = (llama_context *)state_pr;
 
     auto n_past = 0;
-    auto last_n_tokens_data = std::vector<llama_token>(params_p->repeat_last_n, 0);
+    auto last_n_tokens_data = std::vector<llama_token>(params_p->sparams.penalty_repeat, 0);
 
     auto tokens = std::vector<llama_token>(params_p->n_ctx);
-    auto n_prompt_tokens = llama_tokenize(llama_get_model(ctx), text, strlen(text), tokens.data(), tokens.size(), true);
+    auto n_prompt_tokens = llama_tokenize(llama_get_model(ctx), text, strlen(text), tokens.data(), tokens.size(), true, false);
 
     if (n_prompt_tokens < 1)
     {
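
For context on the second change above: upstream llama.cpp extended llama_tokenize with a trailing "special" flag that controls whether special/control tokens in the input text are parsed. A minimal usage sketch against that signature as it stood at the time (the buffer size and prompt are made up for illustration):

    // Tokenize with a BOS token prepended and special-token parsing
    // disabled, mirroring the updated call in eval(). A negative return
    // value is the required token count, so resize and retry.
    std::vector<llama_token> toks(64); // arbitrary capacity for the sketch
    int n = llama_tokenize(llama_get_model(ctx), "Hello", 5,
                           toks.data(), (int) toks.size(),
                           /*add_bos=*/true, /*special=*/false);
    if (n < 0)
    {
        toks.resize(-n);
        n = llama_tokenize(llama_get_model(ctx), "Hello", 5,
                           toks.data(), (int) toks.size(),
                           /*add_bos=*/true, /*special=*/false);
    }
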
@@ -277,7 +278,7 @@ int llama_predict(void *params_ptr, void *state_pr, char *result, bool debug)
     // do one empty run to warm up the model
     {
         llama_token tmp[1] = {
-            llama_token_bos(ctx),
+            llama_token_bos(llama_get_model(ctx)),
         };
         llama_eval(ctx, tmp, 1, 0);
         llama_reset_timings(ctx);
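
This warm-up fix reflects a broader upstream API move: the special-token getters (llama_token_bos, llama_token_eos, llama_token_nl) now take the model rather than the context, which is why the llama_get_model(ctx) wrapping recurs throughout this diff. The pattern, sketched:

    // Vocab-level token ids now come from the model, not the context.
    const llama_model *model = llama_get_model(ctx);
    const llama_token bos = llama_token_bos(model);
    const llama_token eos = llama_token_eos(model);
    const llama_token nl = llama_token_nl(model);
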
@@ -370,19 +371,19 @@ int llama_predict(void *params_ptr, void *state_pr, char *result, bool debug)
             if ((int)embd_inp.size() <= n_consumed)
             {
                 // out of user input, sample next token
-                const float temp = params_p->temp;
-                const int32_t top_k = params_p->top_k <= 0 ? llama_n_vocab(llama_get_model(ctx)) : params_p->top_k;
-                const float top_p = params_p->top_p;
-                const float tfs_z = params_p->tfs_z;
-                const float typical_p = params_p->typical_p;
-                const int32_t repeat_last_n = params_p->repeat_last_n < 0 ? n_ctx : params_p->repeat_last_n;
-                const float repeat_penalty = params_p->repeat_penalty;
-                const float alpha_presence = params_p->presence_penalty;
-                const float alpha_frequency = params_p->frequency_penalty;
-                const int mirostat = params_p->mirostat;
-                const float mirostat_tau = params_p->mirostat_tau;
-                const float mirostat_eta = params_p->mirostat_eta;
-                const bool penalize_nl = params_p->penalize_nl;
+                const float temp = params_p->sparams.temp;
+                const int32_t top_k = params_p->sparams.top_k <= 0 ? llama_n_vocab(llama_get_model(ctx)) : params_p->sparams.top_k;
+                const float top_p = params_p->sparams.top_p;
+                const float tfs_z = params_p->sparams.tfs_z;
+                const float typical_p = params_p->sparams.typical_p;
+                const int32_t repeat_last_n = params_p->sparams.penalty_last_n < 0 ? n_ctx : params_p->sparams.penalty_last_n;
+                const float repeat_penalty = params_p->sparams.penalty_repeat;
+                const float alpha_presence = params_p->sparams.penalty_present;
+                const float alpha_frequency = params_p->sparams.penalty_freq;
+                const int mirostat = params_p->sparams.mirostat;
+                const float mirostat_tau = params_p->sparams.mirostat_tau;
+                const float mirostat_eta = params_p->sparams.mirostat_eta;
+                const bool penalize_nl = params_p->sparams.penalize_nl;
 
                 // optionally save the session on first sample (for faster prompt loading next time)
                 if (!path_session.empty() && need_to_save_session && !params_p->prompt_cache_ro)
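
This block is the heart of the migration: the sampling knobs moved out of gpt_params into the nested sparams struct (upstream's llama_sampling_params), and the penalty fields were renamed on the way. An abridged mapping as this diff uses it; treat the exact upstream layout as an assumption rather than a copy of the header:

    // Old gpt_params field   ->  new sparams field
    // repeat_last_n          ->  penalty_last_n
    // repeat_penalty         ->  penalty_repeat
    // frequency_penalty      ->  penalty_freq
    // presence_penalty       ->  penalty_present
    // temp, top_k, top_p, tfs_z, typical_p, mirostat, mirostat_tau,
    // mirostat_eta and penalize_nl keep their names.
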
@@ -398,7 +399,7 @@ int llama_predict(void *params_ptr, void *state_pr, char *result, bool debug)
                 auto n_vocab = llama_n_vocab(llama_get_model(ctx));
 
                 // Apply params_p->logit_bias map
-                for (auto it = params_p->logit_bias.begin(); it != params_p->logit_bias.end(); it++)
+                for (auto it = params_p->sparams.logit_bias.begin(); it != params_p->sparams.logit_bias.end(); it++)
                 {
                     logits[it->first] += it->second;
                 }
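
Only the map's location changes here. Assuming logit_bias remains a std::unordered_map<llama_token, float> as in the old gpt_params, the iterator loop could equally be written as a range-for:

    // Equivalent range-for over the relocated logit-bias map.
    for (const auto &kv : params_p->sparams.logit_bias)
    {
        logits[kv.first] += kv.second;
    }
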
@@ -413,17 +414,14 @@ int llama_predict(void *params_ptr, void *state_pr, char *result, bool debug)
                 llama_token_data_array candidates_p = {candidates.data(), candidates.size(), false};
 
                 // Apply penalties
-                float nl_logit = logits[llama_token_nl(ctx)];
+                float nl_logit = logits[llama_token_nl(llama_get_model(ctx))];
                 auto last_n_repeat = std::min(std::min((int)last_n_tokens.size(), repeat_last_n), n_ctx);
-                llama_sample_repetition_penalty(ctx, &candidates_p,
+                llama_sample_repetition_penalties(ctx, &candidates_p,
                     last_n_tokens.data() + last_n_tokens.size() - last_n_repeat,
-                    last_n_repeat, repeat_penalty);
-                llama_sample_frequency_and_presence_penalties(ctx, &candidates_p,
-                    last_n_tokens.data() + last_n_tokens.size() - last_n_repeat,
-                    last_n_repeat, alpha_frequency, alpha_presence);
+                    last_n_repeat, repeat_penalty, alpha_frequency, alpha_presence);
                 if (!penalize_nl)
                 {
-                    logits[llama_token_nl(ctx)] = nl_logit;
+                    logits[llama_token_nl(llama_get_model(ctx))] = nl_logit;
                 }
 
                 if (temp <= 0)
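
This hunk tracks upstream's consolidation of the two penalty passes into one call. The merged declaration, to the best of my reading of llama.h at that revision (a sketch, not an authoritative copy):

    // void llama_sample_repetition_penalties(
    //     struct llama_context *ctx,
    //     llama_token_data_array *candidates,
    //     const llama_token *last_tokens,
    //     size_t penalty_last_n,
    //     float penalty_repeat,
    //     float penalty_freq,
    //     float penalty_present);

A single call now applies the multiplicative repetition penalty and the additive frequency/presence penalties over the same last_n_repeat window, which is why the separate llama_sample_frequency_and_presence_penalties call disappears.
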
@@ -523,7 +521,7 @@ int llama_predict(void *params_ptr, void *state_pr, char *result, bool debug)
         }
 
         // end of text token
-        if (!embd.empty() && embd.back() == llama_token_eos(ctx))
+        if (!embd.empty() && embd.back() == llama_token_eos(llama_get_model(ctx)))
         {
             break;
         }
@@ -635,15 +633,15 @@ void *llama_allocate_params(const char *prompt, int seed, int threads, int token
     params->n_threads = threads;
     params->n_threads_batch = threads;
     params->n_predict = tokens;
-    params->repeat_last_n = repeat_last_n;
+    params->sparams.penalty_last_n = repeat_last_n;
     params->prompt_cache_ro = prompt_cache_ro;
-    params->top_k = top_k;
-    params->top_p = top_p;
+    params->sparams.top_k = top_k;
+    params->sparams.top_p = top_p;
     params->memory_f16 = memory_f16;
-    params->temp = temp;
+    params->sparams.temp = temp;
     params->use_mmap = mmap;
     params->use_mlock = mlock;
-    params->repeat_penalty = repeat_penalty;
+    params->sparams.penalty_repeat = repeat_penalty;
     params->n_batch = n_batch;
     params->n_keep = n_keep;
     if (maingpu[0] != '\0')
@@ -685,22 +683,22 @@ void *llama_allocate_params(const char *prompt, int seed, int threads, int token
     {
         params->antiprompt = create_vector(antiprompt, antiprompt_count);
     }
-    params->tfs_z = tfs_z;
-    params->typical_p = typical_p;
-    params->presence_penalty = presence_penalty;
-    params->mirostat = mirostat;
-    params->mirostat_eta = mirostat_eta;
-    params->mirostat_tau = mirostat_tau;
-    params->penalize_nl = penalize_nl;
+    params->sparams.tfs_z = tfs_z;
+    params->sparams.typical_p = typical_p;
+    params->sparams.penalty_present = presence_penalty;
+    params->sparams.mirostat = mirostat;
+    params->sparams.mirostat_eta = mirostat_eta;
+    params->sparams.mirostat_tau = mirostat_tau;
+    params->sparams.penalize_nl = penalize_nl;
     std::stringstream ss(logit_bias);
     llama_token key;
     char sign;
     std::string value_str;
     if (ss >> key && ss >> sign && std::getline(ss, value_str) && (sign == '+' || sign == '-'))
     {
-        params->logit_bias[key] = std::stof(value_str) * ((sign == '-') ? -1.0f : 1.0f);
+        params->sparams.logit_bias[key] = std::stof(value_str) * ((sign == '-') ? -1.0f : 1.0f);
     }
-    params->frequency_penalty = frequency_penalty;
+    params->sparams.penalty_freq = frequency_penalty;
     params->prompt = prompt;
 
     return params;
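
For reference, the stream-based parser above accepts logit_bias strings of the form "<token-id> <sign><value>", one entry per call. A self-contained sketch of the same parse with a made-up input:

    #include <sstream>
    #include <string>

    // "2 -1.0" biases token id 2 by -1.0; only '+' and '-' are accepted.
    std::stringstream ss("2 -1.0");
    int key; // llama_token is an integer token id
    char sign;
    std::string value_str;
    if (ss >> key && ss >> sign && std::getline(ss, value_str) && (sign == '+' || sign == '-'))
    {
        const float value = std::stof(value_str) * ((sign == '-') ? -1.0f : 1.0f); // -1.0f
    }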