-
Notifications
You must be signed in to change notification settings - Fork 32
Expand file tree
/
Copy pathvisualize.py
More file actions
87 lines (67 loc) · 3.03 KB
/
visualize.py
File metadata and controls
87 lines (67 loc) · 3.03 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import os
import sys
import pprint
import torch
from torchdrug import core
from torchdrug.utils import comm
# Make the repository root importable so `from nbfnet import ...` below
# resolves when this script is run directly from its own directory.
sys.path.append(os.path.dirname(os.path.dirname(__file__)))
from nbfnet import dataset, layer, model, task, util
# TSV file mapping FB15k-237 entity MIDs (e.g. "/m/0abc") to readable names,
# resolved relative to this script and normalized to an absolute path.
vocab_file = os.path.join(os.path.dirname(__file__), "../data/fb15k237_entity.txt")
vocab_file = os.path.abspath(vocab_file)
def load_vocab(dataset, vocab_path=None):
    """Build human-readable entity and relation vocabularies for a dataset.

    Parameters:
        dataset: object exposing ``entity_vocab`` (raw entity ids, e.g.
            Freebase MIDs) and ``relation_vocab`` (slash-separated relation
            paths) — presumably an FB15k-237 dataset; verify against caller.
        vocab_path (str, optional): path to a two-column TSV file mapping
            raw entity id -> readable name. Defaults to the module-level
            ``vocab_file``, preserving the original behavior.

    Returns:
        tuple: (entity_vocab, relation_vocab) lists of readable strings.

    Raises:
        KeyError: if an entity in ``dataset.entity_vocab`` is missing from
            the mapping file.
        ValueError: if a line in the mapping file does not have exactly
            two tab-separated fields.
    """
    if vocab_path is None:
        vocab_path = vocab_file
    entity_mapping = {}
    # Explicit UTF-8: entity names may contain non-ASCII characters, and the
    # platform default encoding is not reliable.
    with open(vocab_path, "r", encoding="utf-8") as fin:
        for line in fin:
            k, v = line.strip().split("\t")
            entity_mapping[k] = v
    entity_vocab = [entity_mapping[t] for t in dataset.entity_vocab]
    # Keep only the last path segment of each relation and append its index,
    # e.g. "/people/person/nationality" -> "nationality (12)".
    relation_vocab = ["%s (%d)" % (t[t.rfind("/") + 1:].replace("_", " "), i)
                      for i, t in enumerate(dataset.relation_vocab)]
    return entity_vocab, relation_vocab
def visualize_path(solver, triplet, entity_vocab, relation_vocab):
    """Log the filtered rank and interpretation paths for one test triplet.

    Evaluates the triplet in both directions (tail prediction and inverse
    head prediction), logs the rank of the true answer for each, then logs
    the weighted paths returned by ``solver.model.visualize``.

    Parameters:
        solver: solver wrapping the trained model (provides ``device``,
            ``model.predict_and_target`` and ``model.visualize``).
        triplet: tensor holding a single (head, tail, relation) id triple.
        entity_vocab (list of str): readable entity names, indexed by id.
        relation_vocab (list of str): readable relation names, indexed by id.
    """
    num_relation = len(relation_vocab)

    def render(head, tail, rel):
        # Map raw ids to readable names; relations with id >= num_relation
        # are inverse relations and get a "^(-1)" suffix.
        rel_name = relation_vocab[rel % num_relation]
        if rel >= num_relation:
            rel_name += "^(-1)"
        return entity_vocab[head], entity_vocab[tail], rel_name

    h, t, r = triplet.tolist()
    forward = torch.as_tensor([[h, t, r]], device=solver.device)
    backward = torch.as_tensor([[t, h, r + num_relation]], device=solver.device)
    solver.model.eval()
    pred, (mask, target) = solver.model.predict_and_target(forward)
    pos_pred = pred.gather(-1, target.unsqueeze(-1))
    # Filtered rank: 1 + number of masked candidates scoring >= the true answer.
    rankings = (torch.sum((pos_pred <= pred) & mask, dim=-1) + 1).squeeze(0)
    logger.warning("")
    for sample, ranking in zip((forward, backward), rankings):
        head, tail, rel = sample.squeeze(0).tolist()
        h_name, t_name, r_name = render(head, tail, rel)
        logger.warning(">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>")
        logger.warning("rank(%s | %s, %s) = %g" % (t_name, h_name, r_name, ranking))
        paths, weights = solver.model.visualize(sample)
        for path, weight in zip(paths, weights):
            hops = []
            for u, v, p in path:
                u_name, v_name, p_name = render(u, v, p)
                hops.append("<%s, %s, %s>" % (u_name, p_name, v_name))
            logger.warning("weight: %g\n\t%s" % (weight, " ->\n\t".join(hops)))
if __name__ == "__main__":
    # Parse CLI args; `unparsed` holds leftover config-context variables
    # (renamed from `vars` to avoid shadowing the builtin).
    args, unparsed = util.parse_args()
    cfg = util.load_config(args.config, context=unparsed)
    working_dir = util.create_working_directory(cfg)
    # Offset the seed by rank so distributed workers draw different samples.
    torch.manual_seed(args.seed + comm.get_rank())
    # NOTE: must stay named `logger` — visualize_path reads it as a global.
    logger = util.get_root_logger()
    logger.warning("Config file: %s" % args.config)
    logger.warning(pprint.pformat(cfg))
    # The readable-name vocabulary file only covers FB15k-237 entities.
    if cfg.dataset["class"] != "FB15k237":
        raise ValueError("Visualization is only implemented for FB15k237")
    dataset = core.Configurable.load_config_dict(cfg.dataset)
    solver = util.build_solver(cfg, dataset)
    entity_vocab, relation_vocab = load_vocab(dataset)
    # Visualize the first 500 test triplets; clamp so a smaller test set
    # does not raise an IndexError.
    for i in range(min(500, len(solver.test_set))):
        visualize_path(solver, solver.test_set[i], entity_vocab, relation_vocab)