We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent fe691fe commit fdef40bCopy full SHA for fdef40b
src/diffusers/schedulers/scheduling_vq_diffusion.py
@@ -53,7 +53,7 @@ def index_to_log_onehot(x: torch.LongTensor, num_classes: int) -> torch.FloatTen
53
Log onehot vectors
54
"""
55
batch_size, vector_length = x.shape
56
- log_x = torch.FloatTensor((batch_size, num_classes, vector_length), fill_value=1e-30, device=x.device)
+ log_x = torch.full((batch_size, num_classes, vector_length), fill_value=1e-30, dtype=torch.float, device=x.device)
57
log_x.scatter_(index=x[:, None, :], src=0, dim=1)
58
return log_x
59
0 commit comments