EasyTeaCache is an easy-to-use library that applies TeaCache to speed up DiT models, with automatic parameter solving included. Different sequence lengths can use their own dedicated parameters to minimize the loss introduced by caching.
You can use TeaCache to speed up your DiT models and easily solve the TeaCache parameters that match your model weights.
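For background, TeaCache decides per denoising step whether the transformer blocks can be skipped: it measures how much the timestep-modulated input has changed since the last fully computed step, rescales that change with model-specific polynomial coefficients, and skips (reusing a cached residual) while the accumulated change stays below a threshold. The snippet below is only a generic sketch of that idea, not EasyTeaCache's internals; the function name and the coefficients are assumptions:

import torch

def should_skip(t_mod, prev_t_mod, accumulated, threshold, coeffs):
    # relative L1 change of the timestep-modulated input vs. the last computed step
    rel_l1 = ((t_mod - prev_t_mod).abs().mean() / (prev_t_mod.abs().mean() + 1e-8)).item()
    # rescale with model-specific polynomial coefficients (placeholders here) and accumulate
    accumulated += sum(c * rel_l1 ** i for i, c in enumerate(reversed(coeffs)))
    if accumulated < threshold:
        return True, accumulated   # small change: skip the blocks, reuse the cached residual
    return False, 0.0              # large change: run the blocks, reset the accumulator

Fitting those coefficients by hand is presumably the tedious part that the automatic parameter solving mentioned above takes care of.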
If you already have solved params, you can use this mode:
from EasyTeaCache import TeaCache
teacache = TeaCache(
    min_skip_step=2,    # the first step TeaCache can skip is index==1 (0-indexed)
    max_skip_step=48,   # the last step TeaCache can skip is index==48 (0-indexed)
    threshold=0.04,
    model_keys=["mymodel", "function"],       # any strings that identify your model weights; arbitrary nesting depth is supported
    cache_path="config/teacache/cache.json",  # load the solved config from here
)
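The forward snippet below just reads a `teacache` variable; how the instance reaches forward() is up to you. One simple wiring (purely an assumption, not an EasyTeaCache API) is to attach it to the transformer module and look it up inside forward():

# hypothetical wiring; `pipe` is your own diffusers-style pipeline object
pipe.transformer.teacache = teacache

# then, at the top of transformer.forward():
teacache = getattr(self, "teacache", None)   # None simply disables caching below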
# in transformer.forward
skip_blocks = False
if teacache is not None:
    skip_blocks = teacache.check(
        step=timestep_index,
        t_mod=timestep_proj,
        sequence_length=hidden_states.size(1),
    )
if skip_blocks:
    # skip the transformer blocks and reuse the cached result
    hidden_states = teacache.update(timestep_proj, hidden_states)
else:
    input_hidden_states = hidden_states
    # 4. Transformer blocks
    for block in self.blocks:
        hidden_states = block(
            hidden_states,
            encoder_hidden_states,
            timestep_proj,
            rotary_emb,
            encoder_hidden_states_mask,
        )
    if teacache is not None:
        # record the ground-truth input/output pair for this step
        teacache.store_truth(
            step=timestep_index,
            t_mod=timestep_proj,
            input_latent=input_hidden_states,
            output_latent=hidden_states,
            sequence_length=hidden_states.size(1),
        )

If you do not have solved params yet, use this mode to solve them and save the config. You cannot use sp-parallel in this mode.
from EasyTeaCache import TeaCache
teacache = TeaCache(
    min_skip_step=0,
    max_skip_step=-1,
    threshold=0.04,
    model_keys=["mymodel", "function"],       # any strings that identify your model weights; arbitrary nesting depth is supported
    cache_path="config/teacache/cache.json",  # the solved config is saved here
    speedup_mode=False,                       # solve mode: params are solved and saved, not loaded
)
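In solve mode the forward pass records ground truth through store_truth, so you have to actually run inference with your model. Since params can be sequence-length specific, it seems sensible to run the pipeline once per resolution/length you plan to use; the loop below is only a hedged sketch, and `pipe`, the prompt, and the sizes are assumptions:

# hypothetical calibration runs with your own pipeline
for height, width in [(480, 832), (720, 1280)]:
    pipe(prompt="a calibration prompt", height=height, width=width)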
# in transformer.forward, same as in speedup mode
skip_blocks = False
if teacache is not None:
    skip_blocks = teacache.check(
        step=timestep_index,
        t_mod=timestep_proj,
        sequence_length=hidden_states.size(1),
    )
if skip_blocks:
    hidden_states = teacache.update(timestep_proj, hidden_states)
else:
    input_hidden_states = hidden_states
    # 4. Transformer blocks
    for block in self.blocks:
        hidden_states = block(
            hidden_states,
            encoder_hidden_states,
            timestep_proj,
            rotary_emb,
            encoder_hidden_states_mask,
        )
    if teacache is not None:
        teacache.store_truth(
            step=timestep_index,
            t_mod=timestep_proj,
            input_latent=input_hidden_states,
            output_latent=hidden_states,
            sequence_length=hidden_states.size(1),
        )