I have successfully trained the C-ViViT. But when I move on to the step of training Phenaki, I always hit an error at logits[mask_token_mask] on code line 665 of phenaki_pytorch.py. I have checked the shapes of mask_token_mask and logits, and it seems they should fit. I am not sure what is going on, or whether the problem could be caused by my GPU machine. It would be much appreciated if someone could help me. I have attached my training script below.

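From what I understand, boolean-mask indexing such as logits[mask_token_mask] requires the mask to match the leading dimensions of the tensor exactly, so even an off-by-one in sequence length raises an IndexError. A minimal standalone demonstration of that failure mode (the shapes and names here are my own illustration, not taken from phenaki_pytorch):

import torch

logits = torch.randn(2, 10, 8000)                  # (batch, seq_len, num_tokens)

good_mask = torch.zeros(2, 10, dtype = torch.bool) # matches logits' leading dims
good_mask[:, :3] = True
print(logits[good_mask].shape)                     # torch.Size([6, 8000])

bad_mask = torch.zeros(2, 11, dtype = torch.bool)  # seq_len disagrees with logits
try:
    logits[bad_mask]
except IndexError as err:
    print(err)                                     # "shape of the mask ... does not match ..."

Here is my training script: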
import torch
from phenaki_pytorch import CViViT, MaskGit, Phenaki, TokenCritic
cvivit = CViViT(
    dim = 512,
    codebook_size = 65536,
    image_size = 256,        # video with rectangular screen allowed
    patch_size = 32,
    temporal_patch_size = 2,
    spatial_depth = 4,
    temporal_depth = 4,
    dim_head = 64,
    heads = 8
)
#'/raid/camca/yl463/results/vae.9000.pt'
cvivit.load('/home/local/PARTNERS/yl463/Robot/results/vae.0.pt')
print("----------------------------load weights successfully!----------------------------")
maskgit = MaskGit(
    num_tokens = 8000,
    max_seq_len = 1024,
    dim = 512,
    dim_context = 768,
    depth = 6,
)
critic = TokenCritic(       # note: defined here but not passed to Phenaki below
    num_tokens = 5000,
    max_seq_len = 1024,
    dim = 512,
    dim_context = 768,
    depth = 6,
    has_cross_attn = True
)
phenaki = Phenaki(
    cvivit = cvivit,
    maskgit = maskgit
).cuda()
videos = torch.randn(3, 3, 11, 256, 256).cuda() # (batch, channels, frames, height, width)
mask = torch.ones((3, 11)).bool().cuda() # [optional] (batch, frames) - allows for co-training videos of different lengths as well as video and images in the same batch
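# [my addition, for clarity] with variable-length videos, the mask marks the valid frames;
# e.g. hypothetical clips of 11, 8, and 5 frames in the same batch would be masked like:
#   varlen_mask = torch.zeros((3, 11), dtype = torch.bool).cuda()
#   varlen_mask[0, :11] = True
#   varlen_mask[1, :8]  = True
#   varlen_mask[2, :5]  = True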
texts = [
    'a whale breaching from afar',
    'young girl blowing out candles on her birthday cake',
    'fireworks with blue and green sparkles'
]
loss = phenaki(videos, texts = texts, video_frame_mask = mask)
print("----------------------------Training----------------------------")
loss.backward()
# do the above for many steps, then ...
video = phenaki.sample(texts = 'a squirrel examines an acorn', num_frames = 17, cond_scale = 5.) # (1, 3, 17, 256, 256)
# so in the paper, they do not really achieve 2 minutes of coherent video
# at each new scene with new text conditioning, they condition on the previous K frames
# you can easily achieve this with this framework like so
video_prime = video[:, :, -3:] # (1, 3, 3, 256, 256) # say K = 3
video_next = phenaki.sample(texts = 'a cat watches the squirrel from afar', prime_frames = video_prime, num_frames = 14) # (1, 3, 14, 256, 256)
# the total video
entire_video = torch.cat((video, video_next), dim = 2) # (1, 3, 17 + 14, 256, 256)
# and so on ...
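One thing I plan to try is a shape check right before the failing line, to confirm whether the mask really matches logits. A sketch of that debugging aid (this helper is my own addition, not part of phenaki_pytorch):

def check_mask_indexing(logits, mask_token_mask):
    # print both shapes just before the indexing that crashes
    print('logits:         ', tuple(logits.shape))
    print('mask_token_mask:', tuple(mask_token_mask.shape))
    # boolean-mask indexing needs an exact match on the leading dims
    assert mask_token_mask.dtype == torch.bool
    assert mask_token_mask.shape == logits.shape[:mask_token_mask.ndim], \
        'mask must match the leading dimensions of logits exactly'
    return logits[mask_token_mask]

I also notice that my maskgit num_tokens (8000) does not match cvivit's codebook_size (65536) or the critic's num_tokens (5000); I am not sure whether these are required to agree, but it seems worth ruling out.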