
Commit e3d4fa6

Fix continuous batching tests (#42012)
* Fix continuous batching tests
* make fixup
1 parent dd4e048 · commit e3d4fa6

File tree: 1 file changed, +9 −9 lines


tests/generation/test_continuous_batching.py

Lines changed: 9 additions & 9 deletions
@@ -350,9 +350,9 @@ def test_streaming_request(self) -> None:

         messages = [{"content": "What is the Transformers library known for?", "role": "user"}]

-        inputs = tokenizer.apply_chat_template(messages, return_tensors="pt", add_generation_prompt=True).to(
-            model.device
-        )[0]
+        inputs = tokenizer.apply_chat_template(
+            messages, return_tensors="pt", add_generation_prompt=True, return_dict=False
+        ).to(model.device)[0]

         request_id = manager.add_request(inputs, max_new_tokens=max_new_tokens, streaming=True)

@@ -382,9 +382,9 @@ def test_non_streaming_request(self) -> None:

         messages = [{"content": "What is the Transformers library known for?", "role": "user"}]

-        inputs = tokenizer.apply_chat_template(messages, return_tensors="pt", add_generation_prompt=True).to(
-            model.device
-        )[0]
+        inputs = tokenizer.apply_chat_template(
+            messages, return_tensors="pt", add_generation_prompt=True, return_dict=False
+        ).to(model.device)[0]

         request_id = manager.add_request(inputs, max_new_tokens=max_new_tokens, streaming=False)

@@ -409,9 +409,9 @@ def test_streaming_and_non_streaming_requests_can_alternate(self) -> None:

         messages = [{"content": "What is the Transformers library known for?", "role": "user"}]

-        inputs = tokenizer.apply_chat_template(messages, return_tensors="pt", add_generation_prompt=True).to(
-            model.device
-        )[0]
+        inputs = tokenizer.apply_chat_template(
+            messages, return_tensors="pt", add_generation_prompt=True, return_dict=False
+        ).to(model.device)[0]

         # Non-streaming request
         request_id = manager.add_request(inputs, max_new_tokens=max_new_tokens, streaming=False)
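
For context on the fix: apply_chat_template can return its output either as a plain tensor or as a BatchEncoding mapping, controlled by return_dict. The tests now pin return_dict=False explicitly, presumably to keep tensor output regardless of the library default, so that .to(model.device)[0] still yields a 1-D sequence of token IDs. Below is a minimal sketch of the new call pattern; the checkpoint name is illustrative and not taken from this commit.

    from transformers import AutoTokenizer

    # Illustrative checkpoint, not from this commit; any chat model would do.
    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")

    messages = [{"content": "What is the Transformers library known for?", "role": "user"}]

    # return_dict=False makes apply_chat_template return a (1, seq_len) tensor
    # of input IDs rather than a BatchEncoding, so indexing with [0] extracts
    # the 1-D prompt that the tests pass to manager.add_request.
    inputs = tokenizer.apply_chat_template(
        messages, return_tensors="pt", add_generation_prompt=True, return_dict=False
    )[0]

Reflowing the call also moves .to(model.device) after the closing parenthesis, so the device transfer applies to the returned tensor in a single chained expression.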

0 commit comments
