@@ -4,6 +4,7 @@
 # Please take a look at the existing test_llm_api_pytorch.py file for reference.
 import concurrent
 import contextlib
+import json
 import os
 import tempfile
 import time
@@ -19,12 +20,13 @@
 from tensorrt_llm.executor.result import GenerationResultBase
 from tensorrt_llm.llmapi import CompletionOutput, RequestOutput, SamplingParams
 from tensorrt_llm.llmapi.llm_args import LlmArgs
+from tensorrt_llm.llmapi.tokenizer import load_hf_tokenizer
 
 from ..conftest import (get_device_count, llm_models_root, parametrize_with_ids,
                         skip_pre_hopper)
 from ..trt_test_alternative import popen
-from .accuracy_core import (GSM8K, MMLU, LlmapiAccuracyTestHarness,
-                            get_accuracy_task)
+from .accuracy_core import (GSM8K, MMLU, JsonModeEval,
+                            LlmapiAccuracyTestHarness, get_accuracy_task)
 
 
 class Result(GenerationResultBase):
@@ -43,7 +45,7 @@ def result(self):
         return self
 
 
-DuckLLM = namedtuple('DuckLLM', ['args', 'generate_async'])
+DuckLLM = namedtuple('DuckLLM', ['args', 'tokenizer', 'generate_async'])
 
 
 class MyThreadPoolExecutor(ThreadPoolExecutor):
@@ -162,17 +164,35 @@ def launch_disaggregated_llm(disaggregated_server_config: Dict[str, Any],
 
     def send_request(prompt: str, sampling_params: SamplingParams,
                      streaming: bool):
-        response = client.completions.create(
-            model=model_name,
-            prompt=prompt,
-            stream=streaming,
-            **({
-                "max_tokens": sampling_params.max_tokens,
-                "temperature": sampling_params.temperature,
-                "top_p": sampling_params.top_p,
-                "stop": sampling_params.stop,
-                "seed": sampling_params.seed
-            } if sampling_params else {}))
+        kwargs = {}
+        if sampling_params is not None:
+            kwargs.update(max_tokens=sampling_params.max_tokens,
+                          temperature=sampling_params.temperature,
+                          top_p=sampling_params.top_p,
+                          stop=sampling_params.stop,
+                          seed=sampling_params.seed)
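+            # Translate guided decoding params into the OpenAI-style
+            # "response_format" fields passed through extra_body.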
+            if (guided_decoding_params :=
+                    sampling_params.guided_decoding) is not None:
+                extra_body = {}
+                if (schema := guided_decoding_params.json) is not None:
+                    extra_body.update(response_format={
+                        "type": "json",
+                        "schema": json.loads(schema)
+                    })
+                elif guided_decoding_params.json_object:
+                    extra_body.update(
+                        response_format={"type": "json_object"})
+                else:
+                    # TODO: Support other guided decoding types
+                    raise ValueError(
+                        f"Unsupported guided decoding params: {guided_decoding_params}."
+                    )
+                kwargs.update(extra_body=extra_body)
+
+        response = client.completions.create(model=model_name,
+                                             prompt=prompt,
+                                             stream=streaming,
+                                             **kwargs)
         result = Result(id=0,
                         sampling_params=sampling_params,
                         outputs=[
@@ -192,8 +212,10 @@ def generate_async(prompt: str,
         thread_pool.futures.append(future)
         return future
 
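+    # Expose the HF tokenizer on the DuckLLM stub so it provides the
+    # tokenizer attribute of a real LLM for the accuracy tasks.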
+    tokenizer = load_hf_tokenizer(model_name)
+
     try:
-        yield DuckLLM(args, generate_async)
+        yield DuckLLM(args, tokenizer, generate_async)
     finally:
         ctx_server.terminate()
         gen_server.terminate()
@@ -394,6 +416,95 @@ def test_eagle3(self, overlap_scheduler, eagle3_one_model):
             task = GSM8K(self.MODEL_NAME)
             task.evaluate(llm)
 
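+    # Guided decoding (JSON schema) accuracy over the disaggregated
+    # OpenAI-compatible endpoints, parametrized over both grammar backends.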
+    @pytest.mark.skip_less_device_memory(32000)
+    @pytest.mark.parametrize("backend", ["xgrammar", "llguidance"])
+    def test_guided_decoding(self, backend: str, mocker):
+        mocker.patch.dict(os.environ, {"TRTLLM_XGUIDANCE_LENIENT": "1"})
+        ctx_server_config = {
+            "disable_overlap_scheduler": True,
+            "guided_decoding_backend": backend,
+            "cache_transceiver_config": {
+                "backend": "default"
+            }
+        }
+        gen_server_config = {
+            "guided_decoding_backend": backend,
+            "cache_transceiver_config": {
+                "backend": "default"
+            }
+        }
+        disaggregated_server_config = {
+            "hostname": "localhost",
+            "port": 8000,
+            "backend": "pytorch",
+            "context_servers": {
+                "num_instances": 1,
+                "urls": ["localhost:8001"]
+            },
+            "generation_servers": {
+                "num_instances": 1,
+                "urls": ["localhost:8002"]
+            }
+        }
+        with launch_disaggregated_llm(disaggregated_server_config,
+                                      ctx_server_config, gen_server_config,
+                                      self.MODEL_PATH) as llm:
+            task = JsonModeEval(self.MODEL_NAME)
+            task.evaluate(llm)
+
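+    # Same guided decoding check, but with Eagle3 two-model speculative
+    # decoding enabled on both the context and generation servers.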
+    @pytest.mark.skip_less_device_memory(32000)
+    @pytest.mark.parametrize("backend", ["xgrammar", "llguidance"])
+    def test_guided_decoding_with_eagle3(self, backend: str, mocker):
+        mocker.patch.dict(os.environ, {"TRTLLM_XGUIDANCE_LENIENT": "1"})
+        speculative_decoding_config = {
+            "decoding_type": "Eagle",
+            "max_draft_len": 3,
+            "speculative_model_dir":
+            f"{llm_models_root()}/EAGLE3-LLaMA3.1-Instruct-8B",
+            "eagle3_one_model": False
+        }
+
+        ctx_server_config = {
+            "disable_overlap_scheduler": True,
+            "speculative_config": speculative_decoding_config,
+            "kv_cache_config": {
+                "free_gpu_memory_fraction": 0.8,
+            },
+            "guided_decoding_backend": backend,
+            "cache_transceiver_config": {
+                "backend": "default"
+            }
+        }
+        gen_server_config = {
+            "disable_overlap_scheduler": True,
+            "speculative_config": speculative_decoding_config,
+            "kv_cache_config": {
+                "free_gpu_memory_fraction": 0.8,
+            },
+            "guided_decoding_backend": backend,
+            "cache_transceiver_config": {
+                "backend": "default"
+            }
+        }
+        disaggregated_server_config = {
+            "hostname": "localhost",
+            "port": 8000,
+            "backend": "pytorch",
+            "context_servers": {
+                "num_instances": 1,
+                "urls": ["localhost:8001"]
+            },
+            "generation_servers": {
+                "num_instances": 1,
+                "urls": ["localhost:8002"]
+            }
+        }
+        with launch_disaggregated_llm(disaggregated_server_config,
+                                      ctx_server_config, gen_server_config,
+                                      self.MODEL_PATH) as llm:
+            task = JsonModeEval(self.MODEL_NAME)
+            task.evaluate(llm)
+
     @pytest.mark.skip_less_device(2)
     @pytest.mark.parametrize("tp,pp", [(1, 2), (2, 1), (2, 2)],
                              ids=["tp1pp2", "tp2pp1", "tp2pp2"])