
Hitting an error when training Phenaki #41

@levyisthebest

Description


I have successfully trained the C-ViViT. But when I go on to the step of training Phenaki, I always hit an error at `logits[mask_token_mask]` on line 665 of phenaki_pytorch.py. I have checked the shapes of mask_token_mask and logits, and they seem to match. I am not sure what is going on, or whether the problem is caused by the GPU machine. It would be much appreciated if someone could help me. I have attached the error below.
[error screenshot attachment; not recoverable from the page]
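For reference, `logits[mask_token_mask]` is plain boolean-mask indexing, so the usual causes of a failure there are a mask whose shape drifts from the logits' leading dimensions, or token ids outside the vocabulary (on GPU this often surfaces as a device-side assert rather than a clear Python error). A minimal NumPy sketch of the indexing behavior (assumed analogous to the PyTorch semantics; the shapes and names here are illustrative, not taken from phenaki_pytorch):

```python
import numpy as np

# logits for a batch of 2 sequences of length 4 over a vocab of 8 tokens
logits = np.random.randn(2, 4, 8)

# boolean mask marking which positions were replaced by the mask token;
# it must match the leading (batch, seq) dims of logits exactly
mask_token_mask = np.array([[True, False, True, False],
                            [False, True, False, True]])

# boolean-mask indexing keeps only the masked positions,
# flattening them into one leading axis
masked_logits = logits[mask_token_mask]
print(masked_logits.shape)  # (4, 8): 4 masked positions, vocab of 8
```

One thing worth double-checking in the reproduction below: `MaskGit` is built with `num_tokens = 8000` while the C-ViViT uses `codebook_size = 65536`, so codebook indices above 8000 would be out of range for the MaskGit embedding, which can raise exactly at an indexing step like this.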
```python
import torch
from phenaki_pytorch import CViViT, MaskGit, Phenaki, TokenCritic

cvivit = CViViT(
    dim = 512,
    codebook_size = 65536,
    image_size = 256, # video with rectangular screen allowed
    patch_size = 32,
    temporal_patch_size = 2,
    spatial_depth = 4,
    temporal_depth = 4,
    dim_head = 64,
    heads = 8
)

# '/raid/camca/yl463/results/vae.9000.pt'
cvivit.load('/home/local/PARTNERS/yl463/Robot/results/vae.0.pt')
print("----------------------------load weights successfully!----------------------------")

maskgit = MaskGit(
    num_tokens = 8000, # note: differs from the C-ViViT codebook_size above
    max_seq_len = 1024,
    dim = 512,
    dim_context = 768,
    depth = 6,
)

critic = TokenCritic(
    num_tokens = 5000,
    max_seq_len = 1024,
    dim = 512,
    dim_context = 768,
    depth = 6,
    has_cross_attn = True
)

phenaki = Phenaki(
    cvivit = cvivit,
    maskgit = maskgit
).cuda()

videos = torch.randn(3, 3, 11, 256, 256).cuda() # (batch, channels, frames, height, width)
mask = torch.ones((3, 11)).bool().cuda() # [optional] (batch, frames) - allows for co-training videos of different lengths as well as video and images in the same batch

texts = [
    'a whale breaching from afar',
    'young girl blowing out candles on her birthday cake',
    'fireworks with blue and green sparkles'
]

loss = phenaki(videos, texts = texts, video_frame_mask = mask)
print("----------------------------Training----------------------------")
loss.backward()
```

do the above for many steps, then ...
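"Many steps" here is an ordinary optimization loop around the forward/backward pass above. A minimal sketch of the shape such a loop might take (the optimizer choice and the stand-in model below are assumptions for illustration, not code from the repo; in practice `model` would be the `phenaki` instance and the loss call would be `phenaki(videos, texts = texts, video_frame_mask = mask)`):

```python
import torch
from torch.optim import Adam

# stand-in for the Phenaki model: any nn.Module whose forward yields a scalar loss
model = torch.nn.Linear(8, 1)
optim = Adam(model.parameters(), lr = 3e-4)

for step in range(100):
    x = torch.randn(16, 8)          # stand-in for a batch of videos + texts
    loss = model(x).pow(2).mean()   # stand-in for the Phenaki training loss
    optim.zero_grad()
    loss.backward()
    optim.step()
```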

```python
video = phenaki.sample(texts = 'a squirrel examines an acorn', num_frames = 17, cond_scale = 5.) # (1, 3, 17, 256, 256)
```

So in the paper, they do not actually generate 2 minutes of coherent video in one pass. At each new scene with new text conditioning, they condition on the previous K frames. You can easily achieve this with this framework, like so:

```python
video_prime = video[:, :, -3:] # (1, 3, 3, 256, 256) # say K = 3

video_next = phenaki.sample(texts = 'a cat watches the squirrel from afar', prime_frames = video_prime, num_frames = 14) # (1, 3, 14, 256, 256)
```

The total video:

```python
entire_video = torch.cat((video, video_next), dim = 2) # (1, 3, 17 + 14, 256, 256)
```

and so on...
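The frame bookkeeping for chaining scenes this way can be checked with dummy arrays (NumPy stand-ins for the sampled clips, using K = 3 and the 256×256 resolution from the config above; the actual sampling is done by `phenaki.sample` as shown):

```python
import numpy as np

K = 3
video = np.zeros((1, 3, 17, 256, 256))       # first sampled clip: 17 frames
video_prime = video[:, :, -K:]               # last K frames prime the next scene
assert video_prime.shape == (1, 3, 3, 256, 256)

video_next = np.zeros((1, 3, 14, 256, 256))  # clip sampled with prime_frames
entire_video = np.concatenate((video, video_next), axis = 2)
print(entire_video.shape)  # (1, 3, 31, 256, 256): 17 + 14 frames total
```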
