Skip to content

Commit 25eb7aa

Browse files
Glorf and athitten authored
Introduce evaluation API (NVIDIA-NeMo#11895)
* Introduce evaluation API Signed-off-by: Michal Bien <mbien@nvidia.com> --------- Signed-off-by: Michal Bien <mbien@nvidia.com> Signed-off-by: Glorf <Glorf@users.noreply.github.com> Co-authored-by: Glorf <Glorf@users.noreply.github.com> Co-authored-by: Abhishree Thittenamane <47577437+athitten@users.noreply.github.com>
1 parent 2f66ada commit 25eb7aa

File tree

3 files changed

+118
-54
lines changed

3 files changed

+118
-54
lines changed

nemo/collections/llm/api.py

Lines changed: 32 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from typing_extensions import Annotated
2727

2828
import nemo.lightning as nl
29+
from nemo.collections.llm.evaluation.api import EvaluationConfig, EvaluationTarget
2930
from nemo.collections.llm.quantization import ExportConfig, QuantizationConfig
3031
from nemo.lightning import (
3132
AutoResume,
@@ -432,56 +433,21 @@ def deploy(
432433

433434

434435
def evaluate(
435-
nemo_checkpoint_path: AnyPath,
436-
url: str = "grpc://0.0.0.0:8001",
437-
triton_http_port: int = 8000,
438-
model_name: str = "triton_model",
439-
eval_task: str = "gsm8k",
440-
num_fewshot: Optional[int] = None,
441-
limit: Optional[Union[int, float]] = None,
442-
bootstrap_iters: int = 100000,
443-
# inference params
444-
batch_size: Optional[int] = 1,
445-
max_tokens_to_generate: Optional[int] = 256,
446-
temperature: Optional[float] = 0.000000001,
447-
top_p: Optional[float] = 0.0,
448-
top_k: Optional[int] = 1,
449-
add_bos: Optional[bool] = False,
436+
target_cfg: EvaluationTarget,
437+
eval_cfg: EvaluationConfig = EvaluationConfig(type="gsm8k"),
450438
):
451439
"""
452440
Evaluates nemo model deployed on PyTriton server (via trtllm) using lm-evaluation-harness
453441
(https://github.com/EleutherAI/lm-evaluation-harness/tree/main).
454442
455443
Args:
456-
nemo_checkpoint_path (Path): Path for nemo 2.0 checkpoint. This is used to get the tokenizer from the ckpt
457-
which is required to tokenize the evaluation input and output prompts.
458-
url (str): grpc service url that were used in the deploy method above
459-
in the format: grpc://{grpc_service_ip}:{grpc_port}.
460-
triton_http_port (int): HTTP port that was used for the PyTriton server in the deploy method. Default: 8000.
461-
Please pass the triton_http_port if using a custom port in the deploy method.
462-
model_name (str): Name of the model that is deployed on PyTriton server. It should be the same as
463-
triton_model_name passed to the deploy method above to be able to launch evaluation. Deafult: "triton_model".
464-
eval_task (str): task to be evaluated on. For ex: "gsm8k", "gsm8k_cot", "mmlu", "lambada". Default: "gsm8k".
465-
These are the tasks that are supported currently. Any other task of type generate_until or loglikelihood from
466-
lm-evaluation-harness can be run, but only the above mentioned ones are tested. Tasks of type
467-
loglikelihood_rolling are not supported yet.
468-
num_fewshot (int): number of examples in few-shot context. Default: None.
469-
limit (Union[int, float]): Limit the number of examples per task. If <1 (i.e float val between 0 and 1), limit
470-
is a percentage of the total number of examples. If int say x, then run evaluation only on x number of samples
471-
from the eval dataset. Default: None, which means eval is run the entire dataset.
472-
bootstrap_iters (int): Number of iterations for bootstrap statistics, used when calculating stderrs. Set to 0
473-
for no stderr calculations to be performed. Default: 100000.
474-
# inference params
475-
temperature: Optional[float]: float value between 0 and 1. temp of 0 indicates greedy decoding, where the token
476-
with highest prob is chosen. Temperature can't be set to 0.0 currently, due to a bug with TRTLLM
477-
(# TODO to be investigated). Hence using a very samll value as the default. Default: 0.000000001.
478-
top_p: Optional[float]: float value between 0 and 1. limits to the top tokens within a certain probability.
479-
top_p=0 means the model will only consider the single most likely token for the next prediction. Default: 0.0.
480-
top_k: Optional[int]: limits to a certain number (K) of the top tokens to consider. top_k=1 means the model
481-
will only consider the single most likely token for the next prediction. Default: 1
482-
add_bos: Optional[bool]: whether a special token representing the beginning of a sequence should be added when
483-
encoding a string. Default: False since typically for CausalLM its set to False. If needed set add_bos to True.
444+
target_cfg (EvaluationTarget): target of the evaluation. Providing nemo_checkpoint_path, model_id and url in EvaluationTarget.api_endpoint is required to run evaluations.
445+
eval_cfg (EvaluationConfig): configuration for evaluations. Default type (task): gsm8k.
484446
"""
447+
448+
if target_cfg.api_endpoint.nemo_checkpoint_path is None:
449+
raise ValueError("Please provide nemo_checkpoint_path in your target_cfg.")
450+
485451
try:
486452
# lm-evaluation-harness import
487453
from lm_eval import evaluator
@@ -490,22 +456,37 @@ def evaluate(
490456
"Please ensure that lm-evaluation-harness is installed in your env as it is required " "to run evaluations"
491457
)
492458

493-
from nemo.collections.llm import evaluation
459+
from nemo.collections.llm.evaluation.base import NeMoFWLMEval, wait_for_server_ready
494460

495461
# Get tokenizer from nemo ckpt. This works only with NeMo 2.0 ckpt.
496-
tokenizer = io.load_context(nemo_checkpoint_path + "/context", subpath="model.tokenizer")
462+
endpoint = target_cfg.api_endpoint
463+
tokenizer = io.load_context(endpoint.nemo_checkpoint_path + "/context", subpath="model.tokenizer")
464+
497465
# Wait for server to be ready before starting evaluation
498-
evaluation.wait_for_server_ready(url=url, triton_http_port=triton_http_port, model_name=model_name)
466+
wait_for_server_ready(
467+
url=endpoint.url, triton_http_port=endpoint.nemo_triton_http_port, model_name=endpoint.model_id
468+
)
499469
# Create an object of the NeMoFWLM which is passed as a model to evaluator.simple_evaluate
500-
model = evaluation.NeMoFWLMEval(
501-
model_name, url, tokenizer, batch_size, max_tokens_to_generate, temperature, top_p, top_k, add_bos
470+
params = eval_cfg.params
471+
model = NeMoFWLMEval(
472+
model_name=endpoint.model_id,
473+
api_url=endpoint.url,
474+
tokenizer=tokenizer,
475+
batch_size=params.batch_size,
476+
max_tokens_to_generate=params.max_new_tokens,
477+
temperature=params.temperature,
478+
top_p=params.top_p,
479+
top_k=params.top_k,
480+
add_bos=params.add_bos,
502481
)
482+
483+
eval_task = eval_cfg.type
503484
results = evaluator.simple_evaluate(
504485
model=model,
505486
tasks=eval_task,
506-
limit=limit,
507-
num_fewshot=num_fewshot,
508-
bootstrap_iters=bootstrap_iters,
487+
limit=params.limit_samples,
488+
num_fewshot=params.num_fewshot,
489+
bootstrap_iters=params.bootstrap_iters,
509490
)
510491

511492
print("score", results["results"][eval_task])
Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +0,0 @@
1-
from nemo.collections.llm.evaluation.base import NeMoFWLMEval, wait_for_server_ready
2-
3-
__all__ = ["NeMoFWLMEval", "wait_for_server_ready"]
Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from typing import Optional
16+
17+
from pydantic import BaseModel, Field
18+
19+
20+
class ApiEndpoint(BaseModel):
    """
    Represents evaluation Standard API target.api_endpoint object
    """

    # Where the deployed model is reachable and how it is named on the server.
    url: str = Field(default="http://0.0.0.0:8000", description="Url of the model")
    model_id: str = Field(default="triton_model", description="Name of the model in API")
    # Checkpoint is only needed to recover the tokenizer; callers must set it before evaluation.
    nemo_checkpoint_path: Optional[str] = Field(default=None, description="Path for nemo 2.0 checkpoint")
    nemo_triton_http_port: Optional[int] = Field(
        default=8000,
        description="HTTP port that was used for the PyTriton server in the deploy method. Default: 8000.",
    )
35+
36+
37+
class EvaluationTarget(BaseModel):
    """
    Represents evaluation Standard API target object
    """

    # Required: the endpoint description (url, model id, checkpoint path, triton port).
    api_endpoint: ApiEndpoint = Field(description="Api endpoint to be used for evaluation")
43+
44+
45+
class ConfigParams(BaseModel):
    """
    Represents evaluation Standard API config.params object
    """

    # NOTE(review): defaults approximate greedy decoding (top_p ~1, temperature ~0);
    # presumably exact 0.0 temperature is avoided because of a TRT-LLM limitation — confirm.
    top_p: float = Field(default=0.9999999, description="Limits to the top tokens within a certain probability")
    temperature: float = Field(
        default=0.0000001,
        description="Temp of 0 indicates greedy decoding, where the token with highest prob is chosen",
    )
    # Sampling limits for the harness; None means "use the whole dataset" / "no few-shot context".
    limit_samples: Optional[int] = Field(
        default=None, description="Limit evaluation to `limit` samples. Default: use all samples"
    )
    num_fewshot: Optional[int] = Field(
        default=None, description="Number of examples in few-shot context. Default: None."
    )
    # Inference-side knobs forwarded to the deployed model.
    max_new_tokens: Optional[int] = Field(default=256, description="max tokens to generate")
    batch_size: Optional[int] = Field(default=1, description="batch size to use for evaluation")
    top_k: Optional[int] = Field(
        default=1, description="Limits to a certain number (K) of the top tokens to consider"
    )
    add_bos: Optional[bool] = Field(
        default=False, description="whether a special bos token should be added when encoding a string"
    )
    # Set to 0 to skip stderr computation entirely.
    bootstrap_iters: int = Field(default=100000, description="Number of iterations for bootstrap statistics")
78+
79+
80+
class EvaluationConfig(BaseModel):
    """
    Represents evaluation Standard API config object
    """

    # Task name understood by lm-evaluation-harness (e.g. "gsm8k"); no default — must be supplied.
    type: str = Field(description="Name/type of the task")
    params: ConfigParams = Field(default=ConfigParams(), description="Parameters to be used for evaluation")

0 commit comments

Comments
 (0)