Skip to content

Commit 295b366

Browse files
committed
comply with optimized generator
1 parent 997760f commit 295b366

File tree

1 file changed

+10
-15
lines changed

1 file changed

+10
-15
lines changed

afy/predictor_local.py

Lines changed: 10 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import torch
33
import yaml
44
from modules.keypoint_detector import KPDetector
5-
from modules.generator import OcclusionAwareGenerator
5+
from modules.generator_optim import OcclusionAwareGenerator
66
from sync_batchnorm import DataParallelWithCallback
77
import numpy as np
88
import face_alignment
@@ -45,9 +45,6 @@ def load_checkpoints(self):
4545
generator.load_state_dict(checkpoint['generator'])
4646
kp_detector.load_state_dict(checkpoint['kp_detector'])
4747

48-
generator = DataParallelWithCallback(generator)
49-
kp_detector = DataParallelWithCallback(kp_detector)
50-
5148
generator.eval()
5249
kp_detector.eval()
5350

@@ -60,6 +57,14 @@ def set_source_image(self, source_image):
6057
self.source = to_tensor(source_image).to(self.device)
6158
self.kp_source = self.kp_detector(self.source)
6259

60+
if self.enc_downscale > 1:
61+
h, w = int(self.source.shape[2] / self.enc_downscale), int(self.source.shape[3] / self.enc_downscale)
62+
source_enc = torch.nn.functional.interpolate(self.source, size=(h, w), mode='bilinear')
63+
else:
64+
source_enc = self.source
65+
66+
self.generator.encode_source(source_enc)
67+
6368
def predict(self, driving_frame):
6469
with torch.no_grad():
6570
driving = to_tensor(driving_frame).to(self.device)
@@ -74,17 +79,7 @@ def predict(self, driving_frame):
7479
kp_driving_initial=self.kp_driving_initial, use_relative_movement=self.relative,
7580
use_relative_jacobian=self.relative, adapt_movement_scale=self.adapt_movement_scale)
7681

77-
if self.enc_downscale > 1:
78-
h, w = int(source.shape[2] / self.enc_downscale), int(source.shape[3] / self.enc_downscale)
79-
source_enc = torch.nn.functional.interpolate(source, size=(h, w), mode='bilinear')
80-
else:
81-
source_enc = None
82-
83-
try:
84-
out = self.generator(self.source, kp_source=self.kp_source, kp_driving=kp_norm, source_image_enc=source_enc, optim_ret=True)
85-
except TypeError:
86-
Once('\n*** Please update FOMM:\ncd fomm\ngit pull\n')
87-
out = self.generator(self.source, kp_source=self.kp_source, kp_driving=kp_norm)
82+
out = self.generator(self.source, kp_source=self.kp_source, kp_driving=kp_norm)
8883

8984
out = np.transpose(out['prediction'].data.cpu().numpy(), [0, 2, 3, 1])[0]
9085
out = (np.clip(out, 0, 1) * 255).astype(np.uint8)

0 commit comments

Comments
 (0)