-
Notifications
You must be signed in to change notification settings - Fork 5
Expand file tree
/
Copy pathoptimizer.py
More file actions
51 lines (36 loc) · 1.66 KB
/
optimizer.py
File metadata and controls
51 lines (36 loc) · 1.66 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import torch
import torch.nn as nn
import torch.nn.functional as F
import smplx
import numpy as np
from scipy.spatial import cKDTree
from params import ParamsDiffusion
from loss import loss_contact, loss_floor
from utils import smplx_utils
def callback_factory(optimizer, guidenet, *, floor_weight=0.2,
                     root_contact_weight=0.0001, contact_weight=1.0,
                     floor_dims=3, root_dims=9):
    """Build a guidance callback returning weighted loss gradients w.r.t. a sample.

    The returned ``callback(x, cond, label, t)`` decodes ``x`` and ``cond``
    with ``smplx_utils.decode``, runs both through ``smplx_utils.smplx``, asks
    ``guidenet`` for a contact signature, evaluates ``loss_contact`` and
    ``loss_floor``, and returns a weighted combination of their gradients with
    respect to ``x`` (no momentum is applied).

    Args:
        optimizer: Not used by this function; accepted so existing callers
            that pass it keep working.
        guidenet: Callable as ``guidenet(x_smplx, cond_smplx, label, t)``
            returning ``(signature_pred, x_segmentation_pred,
            cond_segmentation_pred)`` and exposing ``sig2sigmark`` and
            ``sigmark2sigidx``.
        floor_weight: Scale of the floor-loss gradient, applied to columns
            ``[:floor_dims]`` of ``x``.
        root_contact_weight: Scale of the contact-loss gradient on columns
            ``[:root_dims]``.
        contact_weight: Scale of the contact-loss gradient on columns
            ``[root_dims:]``.
        floor_dims: Number of leading columns that receive the floor gradient.
        root_dims: Column index splitting the down-weighted leading part of
            the contact gradient from the rest.

    Returns:
        A callable ``callback(x, cond, label, t) -> (grad, outdict)`` where
        ``grad`` has the same shape as ``x`` and ``outdict`` is
        ``{'sigidx': sigidx}`` with the index derived from ``guidenet``.
    """
    # NOTE(review): presumably columns 0:3 of the encoded sample are root
    # translation and 3:9 root orientation in smplx_utils' encoding, which is
    # why those columns get special weights -- confirm against smplx_utils.

    def callback(x, cond, label, t):
        # Guidance may be invoked from a no_grad sampling loop, so gradient
        # tracking is re-enabled locally.
        with torch.enable_grad():
            x = x.detach().requires_grad_(True)
            x_params = smplx_utils.decode(x, return_dict=True)
            cond_params = smplx_utils.decode(cond, return_dict=True)
            x_smplx = smplx_utils.smplx(**x_params)
            cond_smplx = smplx_utils.smplx(**cond_params)

            signature_pred, x_segmentation_pred, cond_segmentation_pred = guidenet(
                x_smplx, cond_smplx, label, t)
            sigidx = guidenet.sigmark2sigidx(
                guidenet.sig2sigmark(signature_pred, x_segmentation_pred,
                                     cond_segmentation_pred))

            l_cnt = loss_contact(x_smplx, cond_smplx, sigidx)
            l_flr = loss_floor(x_smplx)

            # The graph is needed again for the floor gradient, so keep it
            # for the first call only; the last call lets autograd free it.
            grad_cnt = torch.autograd.grad(l_cnt, x, retain_graph=True)[0]
            grad_flr = torch.autograd.grad(l_flr, x)[0]

        # The leading floor_dims columns receive both the floor gradient and
        # the (down-weighted) contact gradient; the slices overlap on purpose.
        grad = torch.zeros_like(grad_cnt)
        grad[:, :floor_dims] += floor_weight * grad_flr[:, :floor_dims]
        grad[:, :root_dims] += root_contact_weight * grad_cnt[:, :root_dims]
        grad[:, root_dims:] += contact_weight * grad_cnt[:, root_dims:]

        outdict = {
            'sigidx': sigidx
        }
        return grad, outdict

    return callback