22import torch
33import yaml
44from modules .keypoint_detector import KPDetector
5- from modules .generator import OcclusionAwareGenerator
5+ from modules .generator_optim import OcclusionAwareGenerator
66from sync_batchnorm import DataParallelWithCallback
77import numpy as np
88import face_alignment
@@ -45,9 +45,6 @@ def load_checkpoints(self):
4545 generator .load_state_dict (checkpoint ['generator' ])
4646 kp_detector .load_state_dict (checkpoint ['kp_detector' ])
4747
48- generator = DataParallelWithCallback (generator )
49- kp_detector = DataParallelWithCallback (kp_detector )
50-
5148 generator .eval ()
5249 kp_detector .eval ()
5350
@@ -60,6 +57,14 @@ def set_source_image(self, source_image):
6057 self .source = to_tensor (source_image ).to (self .device )
6158 self .kp_source = self .kp_detector (self .source )
6259
60+ if self .enc_downscale > 1 :
61+ h , w = int (self .source .shape [2 ] / self .enc_downscale ), int (self .source .shape [3 ] / self .enc_downscale )
62+ source_enc = torch .nn .functional .interpolate (self .source , size = (h , w ), mode = 'bilinear' )
63+ else :
64+ source_enc = self .source
65+
66+ self .generator .encode_source (source_enc )
67+
6368 def predict (self , driving_frame ):
6469 with torch .no_grad ():
6570 driving = to_tensor (driving_frame ).to (self .device )
@@ -74,17 +79,7 @@ def predict(self, driving_frame):
7479 kp_driving_initial = self .kp_driving_initial , use_relative_movement = self .relative ,
7580 use_relative_jacobian = self .relative , adapt_movement_scale = self .adapt_movement_scale )
7681
77- if self .enc_downscale > 1 :
78- h , w = int (source .shape [2 ] / self .enc_downscale ), int (source .shape [3 ] / self .enc_downscale )
79- source_enc = torch .nn .functional .interpolate (source , size = (h , w ), mode = 'bilinear' )
80- else :
81- source_enc = None
82-
83- try :
84- out = self .generator (self .source , kp_source = self .kp_source , kp_driving = kp_norm , source_image_enc = source_enc , optim_ret = True )
85- except TypeError :
86- Once ('\n *** Please update FOMM:\n cd fomm\n git pull\n ' )
87- out = self .generator (self .source , kp_source = self .kp_source , kp_driving = kp_norm )
82+ out = self .generator (self .source , kp_source = self .kp_source , kp_driving = kp_norm )
8883
8984 out = np .transpose (out ['prediction' ].data .cpu ().numpy (), [0 , 2 , 3 , 1 ])[0 ]
9085 out = (np .clip (out , 0 , 1 ) * 255 ).astype (np .uint8 )
0 commit comments