Skip to content

Commit 0dbc477

Browse files
committed
add centered back
1 parent 5018abf commit 0dbc477

File tree

1 file changed

+4
-1
lines changed

1 file changed

+4
-1
lines changed

src/diffusers/models/unet_sde_score_estimation.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -229,6 +229,7 @@ def __init__(
229229
self,
230230
image_size=1024,
231231
num_channels=3,
232+
centered=False,
232233
attn_resolutions=(16,),
233234
ch_mult=(1, 2, 4, 8, 16, 32, 32, 32),
234235
conditional=True,
@@ -253,6 +254,7 @@ def __init__(
253254
self.register_to_config(
254255
image_size=image_size,
255256
num_channels=num_channels,
257+
centered=centered,
256258
attn_resolutions=attn_resolutions,
257259
ch_mult=ch_mult,
258260
conditional=conditional,
@@ -457,7 +459,8 @@ def forward(self, x, timesteps, sigmas=None):
457459
temb = None
458460

459461
# If input data is in [0, 1]
460-
x = 2 * x - 1.0
462+
if not self.config.centered:
463+
x = 2 * x - 1.0
461464

462465
# Downsampling block
463466
input_pyramid = None

0 commit comments

Comments
 (0)