
Commit 34f4a4b

<minor> add attn_quant for opensora, the cross_attn still buggy
1 parent cb00633

File tree: 9 files changed, +647 -62 lines changed

examples/opensora1.2/configs/config.yaml

Lines changed: 25 additions & 1 deletion
@@ -1,6 +1,7 @@
 model:
   model_id: "opensora"
   model_type: 'opensora' # ['sd','sdxl']
+  remain_fp_regex: embedder|adaLN_modulation|t_block
 calib_data:
 weight:
   n_bits: 8
@@ -10,5 +11,28 @@ act:
   n_bits: 8
   group: 'token' # DIRTY: it is not used, forced as "token-wise"
   sym: True
-remain_fp_regex: embedder|adaLN_modulation|t_block
+attn:
+  qk:
+    n_bits: 8
+    reorder_file_path:
+  v:
+    n_bits: 8
+  attn_map: # V*attn_map_post_softmax quantized
+    n_bits: 8
+    sym: True
+    group: 'row'
+    # level_2: True
+    # int8_scale: False
+    # mixed_precision_cfg_path: './visualization/attn_map_mixed_precision.pt'
+cross_attn:
+  q:
+    n_bits: 8
+    sym: True
+  kv:
+    n_bits: 8
+    sym: True
+  # attn_map:
+  #   n_bits: 8
+  #   sym: True
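A minimal sketch of how a `remain_fp_regex` like the one this commit moves under `model:` is typically consumed: module names matching the pattern are kept in floating point and skipped by the quantizer. The helper and module names below are hypothetical illustrations, not the repository's actual code.

```python
import re

# Pattern from the config above; matching modules stay in floating point.
REMAIN_FP_REGEX = r"embedder|adaLN_modulation|t_block"

def keeps_fp(module_name: str, pattern: str = REMAIN_FP_REGEX) -> bool:
    """Return True if the named module should be left unquantized."""
    return re.search(pattern, module_name) is not None

# Hypothetical module names, for illustration only.
module_names = [
    "blocks.0.attn.qkv",
    "x_embedder.proj",
    "blocks.0.adaLN_modulation.1",
    "t_block.linear",
]
quantized = [m for m in module_names if not keeps_fp(m)]
```

With this filter, only `blocks.0.attn.qkv` would be quantized; the embedder, adaLN modulation, and t_block layers remain in full precision.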

examples/opensora1.2/configs/sample.py

Lines changed: 3 additions & 3 deletions
@@ -5,8 +5,8 @@
 frame_interval = 1
 save_fps = 24
 ptq_config='./configs/config.yaml'
-save_dir = "./logs/int8_linear"
-seed = 1024
+save_dir = "./logs/attn_int8_naive"
+seed = 114514
 batch_size = 1
 multi_resolution = "STDiT2"
 dtype = "bf16"
@@ -18,7 +18,7 @@
     # from_pretrained="/share/public/zhuhongyu/hpcai-tech/OpenSora-STDiT-v3",
     from_pretrained="/home/zhaotianchen/models/hpcai-tech/OpenSora-STDiT-v3",
     qk_norm=True,
-    enable_flash_attn=True,
+    enable_flash_attn=False,
     enable_layernorm_kernel=False, # didnot install apex
 )
 vae = dict(
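The new `attn_map` config (`n_bits: 8`, `sym: True`, `group: 'row'`) describes symmetric 8-bit fake quantization applied per row of the post-softmax attention map. A minimal sketch of that scheme, assuming one shared scale per row; this is an illustration, not the repository's implementation:

```python
def fake_quant_row(row, n_bits=8):
    """Symmetric per-row fake quantization: quantize to int, then dequantize.

    One scale is computed per row (group: 'row'); values are rounded to the
    nearest of 2**n_bits symmetric levels and mapped back to float.
    """
    qmax = 2 ** (n_bits - 1) - 1           # 127 for int8
    scale = max(abs(v) for v in row) / qmax or 1.0  # avoid div-by-zero rows
    return [round(v / scale) * scale for v in row]

attn_row = [0.70, 0.20, 0.08, 0.02]        # post-softmax attention weights
deq = fake_quant_row(attn_row)
err = max(abs(a - b) for a, b in zip(attn_row, deq))
```

The worst-case error is bounded by half the row's scale (about `max(|row|) / 254` for int8), which is why a per-row group keeps small attention weights from being crushed by one large entry elsewhere in the tensor.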

0 commit comments
