Skip to content

Commit dac1da4

Browse files
authored
Update hf.py (#284)
1 parent 03c35c3 commit dac1da4

1 file changed

Lines changed: 2 additions & 1 deletion

File tree

evalplus/provider/hf.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ def __init__(
2323
):
2424
super().__init__(name=name, **kwargs)
2525
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
26+
self.device_map = device_map
2627

2728
kwargs = {
2829
"device_map": device_map,
@@ -72,7 +73,7 @@ def codegen(
7273
)
7374
)
7475
input_tokens = self.tokenizer.encode(prompt, return_tensors="pt")
75-
if device_map is None:
76+
if self.device_map is None:
7677
input_tokens.to(self.device)
7778
kwargs = {}
7879
if do_sample:

0 commit comments

Comments
 (0)