-
Notifications
You must be signed in to change notification settings - Fork 7
Expand file tree
/
Copy pathinference.py
More file actions
87 lines (67 loc) · 2.82 KB
/
inference.py
File metadata and controls
87 lines (67 loc) · 2.82 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import os
import sys
import pprint
import argparse
import numpy as np
import torch
# NOTE(review): this `datasets` name is immediately shadowed by the
# `from torchdrug import ... datasets ...` line below — looks unused; confirm
# before removing.
from torchvision import datasets
# NOTE(review): this `dataset` name is shadowed by `from diffpack import ...
# dataset ...` below — confirm whether this import is needed at all.
from torch_geometric.data import dataset
import torchdrug
from torch import nn
from torchdrug.patch import patch
from torchdrug import core, datasets, tasks, models, layers
from torchdrug.utils import comm
# Replace torch.nn.Module with torchdrug's patched variant so torchdrug's
# machinery hooks into every nn.Module defined afterwards.
patch(nn, "Module", nn._Module)
# Make the repository root importable when this file is run as a script.
sys.path.append(os.path.dirname(os.path.dirname(__file__)))
from diffpack import util, dataset, task
from diffpack.engine import DiffusionEngine
def parse_args():
    """Parse command-line options for the inference script.

    Unknown arguments are tolerated (``parse_known_args``) and discarded.
    The output directory is canonicalized (``~`` expanded, symlinks and
    relative components resolved) before being returned.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("-c", "--config", help="yaml configuration file",
                        default="config/inference.yaml")
    parser.add_argument("--seed", help="random seed", type=int, default=0)
    parser.add_argument("-o", "--output_dir", help="output directory", default="output")
    parser.add_argument("-f", "--pdb_files", help="list of pdb files", nargs='*', default=[])
    known, _ignored = parser.parse_known_args()
    # Canonicalize the output path in one step.
    known.output_dir = os.path.realpath(os.path.expanduser(known.output_dir))
    return known
def set_seed(seed):
    """Seed all RNGs (Python hashing, NumPy, PyTorch CPU and CUDA).

    The PyTorch CPU seed is offset by the distributed rank so that each
    worker draws a distinct but deterministic stream.

    Fix: the original called ``torch.manual_seed(seed + comm.get_rank())``
    and then ``torch.manual_seed(seed)`` again, so the second call silently
    overrode the per-rank offset. The duplicate call is removed.

    Parameters
    ----------
    seed : int
        Base random seed.
    """
    torch.manual_seed(seed + comm.get_rank())
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
    # Trade cuDNN autotuning speed for bitwise-deterministic kernels.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
def build_solver(cfg, logger):
    """Instantiate the task from config and wrap it in a DiffusionEngine.

    Optionally restores a full solver checkpoint (``cfg.checkpoint``) and/or
    task weights only (``cfg.model_checkpoint``, loaded non-strictly).

    Fix: the local variable was named ``task``, shadowing the ``task`` module
    imported from ``diffpack`` at file scope; renamed to ``packing_task``.

    Parameters
    ----------
    cfg : config dict-like with ``task``, ``engine`` and optional
        ``checkpoint`` / ``model_checkpoint`` / ``load_optimizer`` entries.
    logger : logger used for rank-0 reporting.

    Returns
    -------
    DiffusionEngine
    """
    packing_task = core.Configurable.load_config_dict(cfg.task)
    # Build solver; dataset/optimizer/scheduler slots are unused at inference.
    solver = DiffusionEngine(packing_task, None, None, None, None, None, **cfg.engine)
    if "checkpoint" in cfg:
        solver.load(cfg.checkpoint, load_optimizer=cfg.get("load_optimizer", False))
    if "model_checkpoint" in cfg:
        model_checkpoint = os.path.expanduser(cfg.model_checkpoint)
        # Load on CPU first; strict=False tolerates head/key mismatches.
        model_dict = torch.load(model_checkpoint, map_location=torch.device('cpu'))["model"]
        missing_keys, unexpected_keys = packing_task.load_state_dict(model_dict, strict=False)
    # Report the trainable parameter count once (rank 0 only).
    if comm.get_rank() == 0:
        logger.warning("#parameter: %d" % sum(p.numel() for p in packing_task.parameters() if p.requires_grad))
    return solver
if __name__ == "__main__":
    # Script entry point: parse CLI args, load config, seed RNGs, build the
    # solver, then run generation over the requested PDB files.
    args = parse_args()
    # Resolve the config path to an absolute path up front.
    args.config = os.path.realpath(args.config)
    cfg = util.load_config(args.config)
    # Inject the PDB files given on the command line into the test-set config.
    cfg.test_set.pdb_files = args.pdb_files
    set_seed(args.seed)
    logger = util.get_root_logger()
    # Log configuration once on rank 0 to avoid duplicate output in
    # distributed runs.
    if comm.get_rank() == 0:
        logger.warning("Config file: %s" % args.config)
        logger.warning(pprint.pformat(cfg))
        logger.warning("Output dir: %s" % args.output_dir)
    solver = build_solver(cfg, logger)
    test_set = core.Configurable.load_config_dict(cfg.test_set)
    # Run inference and write generated structures under the output directory.
    solver.generate(test_set=test_set, path=args.output_dir)