1- #!/usr/bin/env python3
21"""
32Unit tests for chunked logits functionality in TensorRT-LLM.
43
@@ -30,7 +29,7 @@ def chunked_request():
3029 sampling_config = SamplingConfig (),
3130 is_streaming = False ,
3231 return_generation_logits = True ,
33- use_chunked_logits = True ,
32+ use_chunked_generation_logits = True ,
3433 logits_chunk_size = 4 )
3534
3635
@@ -43,7 +42,7 @@ def non_chunked_request():
4342 sampling_config = SamplingConfig (),
4443 is_streaming = False ,
4544 return_generation_logits = True ,
46- use_chunked_logits = False )
45+ use_chunked_generation_logits = False )
4746
4847
4948# Test parameters
@@ -62,13 +61,13 @@ def test_initialization(self):
6261 storage = LogitsStorage (seq_length = 10 ,
6362 use_device_memory = True ,
6463 should_exclude_last = False ,
65- use_chunked_logits = False ,
64+ use_chunked_generation_logits = False ,
6665 chunk_size = 8 )
6766
6867 assert storage .seq_length == 10
6968 assert storage .use_device_memory is True
7069 assert storage ._should_exclude_last is False
71- assert storage .use_chunked_logits is False
70+ assert storage .use_chunked_generation_logits is False
7271 assert storage .chunk_size == 8
7372 assert storage ._logits_indices == []
7473 assert storage .beam_width == - 1
@@ -77,10 +76,10 @@ def test_initialization(self):
7776 def test_initialization_chunked_mode (self ):
7877 """Test LogitsStorage initialization in chunked mode"""
7978 storage = LogitsStorage (seq_length = 10 ,
80- use_chunked_logits = True ,
79+ use_chunked_generation_logits = True ,
8180 chunk_size = 4 )
8281
83- assert storage .use_chunked_logits is True
82+ assert storage .use_chunked_generation_logits is True
8483 assert storage .chunk_size == 4
8584 assert hasattr (storage , '_device_fragments' )
8685 assert hasattr (storage , '_current_position' )
@@ -89,23 +88,23 @@ def test_initialization_chunked_mode(self):
8988
9089 def test_append_3d_logits (self , sample_logits ):
9190 """Test appending 3D logits"""
92- storage = LogitsStorage (seq_length = 10 , use_chunked_logits = False )
91+ storage = LogitsStorage (seq_length = 10 , use_chunked_generation_logits = False )
9392 storage .append (sample_logits )
9493
9594 assert storage .beam_width == 1
9695 assert storage .vocab_size == 1000
9796
9897 def test_append_invalid_shape (self ):
9998 """Test appending logits with invalid shape"""
100- storage = LogitsStorage (seq_length = 10 , use_chunked_logits = False )
99+ storage = LogitsStorage (seq_length = 10 , use_chunked_generation_logits = False )
101100
102101 with pytest .raises (AssertionError ):
103102 storage .append (torch .randn (1000 )) # 1D - should fail
104103
105104 def test_append_chunked_mode_streaming (self , sample_logits ):
106105 """Test append behavior in chunked streaming mode"""
107106 storage = LogitsStorage (seq_length = 10 ,
108- use_chunked_logits = True ,
107+ use_chunked_generation_logits = True ,
109108 chunk_size = 1 )
110109 storage .append (sample_logits )
111110
@@ -116,7 +115,7 @@ def test_append_chunked_mode_streaming(self, sample_logits):
116115 def test_append_chunked_mode_non_streaming (self , sample_logits ):
117116 """Test append behavior in chunked non-streaming mode"""
118117 storage = LogitsStorage (seq_length = 10 ,
119- use_chunked_logits = True ,
118+ use_chunked_generation_logits = True ,
120119 chunk_size = 2 )
121120
122121 # Add first fragment
@@ -131,7 +130,7 @@ def test_append_chunked_mode_non_streaming(self, sample_logits):
131130 def test_finalize_transfer_chunked_mode (self , sample_logits ):
132131 """Test finalize_transfer in chunked mode"""
133132 storage = LogitsStorage (seq_length = 10 ,
134- use_chunked_logits = True ,
133+ use_chunked_generation_logits = True ,
135134 chunk_size = 5 )
136135 storage .append (sample_logits )
137136
@@ -145,14 +144,14 @@ def test_finalize_transfer_chunked_mode(self, sample_logits):
145144
146145 def test_finalize_transfer_non_chunked_mode (self ):
147146 """Test finalize_transfer in non-chunked mode (should be no-op)"""
148- storage = LogitsStorage (seq_length = 10 , use_chunked_logits = False )
147+ storage = LogitsStorage (seq_length = 10 , use_chunked_generation_logits = False )
149148
150149 # Should not raise any errors
151150 storage .finalize_transfer ()
152151
153152 def test_storage_overflow (self , sample_logits ):
154153 """Test storage overflow handling"""
155- storage = LogitsStorage (seq_length = 2 , use_chunked_logits = False )
154+ storage = LogitsStorage (seq_length = 2 , use_chunked_generation_logits = False )
156155 storage .append (sample_logits )
157156 storage .append (sample_logits )
158157
@@ -173,7 +172,7 @@ def test_initialization(self):
173172 return_context_logits = True ,
174173 return_generation_logits = True ,
175174 exclude_last_generation_logits = False ,
176- use_chunked_logits = True ,
175+ use_chunked_generation_logits = True ,
177176 chunk_size = 4 )
178177
179178 assert result ._streaming is False
@@ -198,7 +197,7 @@ def test_post_processing_transfer(self, sample_logits):
198197 result = PyResult (prompt_len = 5 ,
199198 max_new_tokens = 10 ,
200199 return_generation_logits = True ,
201- use_chunked_logits = True )
200+ use_chunked_generation_logits = True )
202201
203202 result .append_generation_logits (sample_logits )
204203 result .post_processing_transfer ()
@@ -210,7 +209,8 @@ def test_context_generation_logits_property(self, sample_logits):
210209 result = PyResult (prompt_len = 5 ,
211210 max_new_tokens = 10 ,
212211 return_context_logits = True ,
213- use_chunked_logits = False )
212+ return_generation_logits = True ,
213+ use_chunked_generation_logits = False )
214214
215215 result .append_context_logits (sample_logits )
216216 context_logits = result .context_logits
@@ -225,20 +225,6 @@ def test_context_generation_logits_property(self, sample_logits):
225225 assert generation_logits .shape == (1 , 1 , 1000
226226 ) # Should transpose dimensions
227227
228- def test_generation_logits_property_streaming (self , sample_logits ):
229- """Test generation_logits property in streaming mode"""
230- result = PyResult (prompt_len = 5 ,
231- max_new_tokens = 10 ,
232- return_generation_logits = True ,
233- use_chunked_logits = False ,
234- streaming = True )
235-
236- result .append_generation_logits (sample_logits )
237- generation_logits = result .generation_logits
238-
239- assert generation_logits is not None
240- assert generation_logits .shape == (1 , 1 , 1000 )
241-
242228
243229class TestLlmRequest :
244230 """Unit tests for LlmRequest class"""
@@ -278,7 +264,7 @@ def test_chunked_vs_non_chunked_equivalence(self, sample_logits):
278264 sampling_config = SamplingConfig (),
279265 is_streaming = False ,
280266 return_generation_logits = True ,
281- use_chunked_logits = True ,
267+ use_chunked_generation_logits = True ,
282268 logits_chunk_size = 2 )
283269
284270 # Create non-chunked request
@@ -288,7 +274,7 @@ def test_chunked_vs_non_chunked_equivalence(self, sample_logits):
288274 sampling_config = SamplingConfig (),
289275 is_streaming = False ,
290276 return_generation_logits = True ,
291- use_chunked_logits = False )
277+ use_chunked_generation_logits = False )
292278
293279 # Add same logits to both
294280 for _ in range (5 ):
@@ -319,7 +305,7 @@ def test_streaming_vs_non_streaming_behavior(self, sample_logits):
319305 sampling_config = SamplingConfig (),
320306 is_streaming = True ,
321307 return_generation_logits = True ,
322- use_chunked_logits = True ,
308+ use_chunked_generation_logits = True ,
323309 logits_chunk_size = 3 )
324310
325311 # Create non-streaming request
@@ -329,7 +315,7 @@ def test_streaming_vs_non_streaming_behavior(self, sample_logits):
329315 sampling_config = SamplingConfig (),
330316 is_streaming = False ,
331317 return_generation_logits = True ,
332- use_chunked_logits = True ,
318+ use_chunked_generation_logits = True ,
333319 logits_chunk_size = 3 )
334320
335321 # Add logits one by one
@@ -375,7 +361,7 @@ def test_memory_management(self, sample_logits):
375361 sampling_config = SamplingConfig (),
376362 is_streaming = False ,
377363 return_generation_logits = True ,
378- use_chunked_logits = True ,
364+ use_chunked_generation_logits = True ,
379365 logits_chunk_size = 2 ,
380366 return_logits_device_memory = False # Use host memory
381367 )
@@ -402,7 +388,7 @@ def test_large_sequence_handling(self):
402388 sampling_config = SamplingConfig (),
403389 is_streaming = False ,
404390 return_generation_logits = True ,
405- use_chunked_logits = True ,
391+ use_chunked_generation_logits = True ,
406392 logits_chunk_size = 10 )
407393
408394 # Add many logits
@@ -447,7 +433,7 @@ def get_memory_usage():
447433 sampling_config = SamplingConfig (),
448434 is_streaming = False ,
449435 return_generation_logits = True ,
450- use_chunked_logits = True ,
436+ use_chunked_generation_logits = True ,
451437 logits_chunk_size = 5 ,
452438 return_logits_device_memory = False )
453439
@@ -464,7 +450,7 @@ def get_memory_usage():
464450 sampling_config = SamplingConfig (),
465451 is_streaming = False ,
466452 return_generation_logits = True ,
467- use_chunked_logits = False ,
453+ use_chunked_generation_logits = False ,
468454 return_logits_device_memory = False )
469455
470456 for _ in range (50 ):
0 commit comments