We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 03c35c3 commit dac1da4Copy full SHA for dac1da4
1 file changed
evalplus/provider/hf.py
@@ -23,6 +23,7 @@ def __init__(
23
):
24
super().__init__(name=name, **kwargs)
25
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
26
+ self.device_map = device_map
27
28
kwargs = {
29
"device_map": device_map,
@@ -72,7 +73,7 @@ def codegen(
72
73
)
74
75
input_tokens = self.tokenizer.encode(prompt, return_tensors="pt")
- if device_map is None:
76
+ if self.device_map is None:
77
input_tokens.to(self.device)
78
kwargs = {}
79
if do_sample:
0 commit comments