
Commit fe0661f

Revert "tsne vis: change the model & embeddings"

This reverts commit edb4eb5.

1 parent: edb4eb5

File tree: 1 file changed (+9, −20)

etc/compute_embeddings.py

Lines changed: 9 additions & 20 deletions
@@ -3,7 +3,6 @@
 
 import numpy as np
 import torch
-import torch.nn.functional as F
 import sklearn.manifold
 import transformers
 
@@ -14,40 +13,30 @@ def parse_arguments():
     parser.add_argument("json", default=False, help="the path the json containing all papers.")
     parser.add_argument("outpath", default=False, help="the target path of the visualizations papers.")
     parser.add_argument("--seed", default=0, help="The seed for TSNE.", type=int)
-    parser.add_argument("--model", default='sentence-transformers/all-MiniLM-L6-v2', help="Name of the HF model")
-
     return parser.parse_args()
 
-
-def mean_pooling(token_embeddings, attention_mask):
-    """ Mean Pooling, takes attention mask into account for correct averaging"""
-    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
-    return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
-
 
 if __name__ == "__main__":
     args = parse_arguments()
-    tokenizer = transformers.AutoTokenizer.from_pretrained(args.model)
-    model = transformers.AutoModel.from_pretrained(args.model)
+    tokenizer = transformers.AutoTokenizer.from_pretrained("deepset/sentence_bert")
+    model = transformers.AutoModel.from_pretrained("deepset/sentence_bert")
     model.eval()
 
     with open(args.json) as f:
         data = json.load(f)
 
     print(f"Num papers: {len(data)}")
 
-    corpus = []
+    all_embeddings = []
     for paper_info in data:
-        corpus.append(tokenizer.sep_token.join([paper_info['title'], paper_info['abstract']]))
-
-    encoded_corpus = tokenizer(corpus, padding=True, truncation=True, return_tensors='pt')
-    with torch.no_grad():
-        hidden_states = model(**encoded_corpus).last_hidden_state
-
-    corpus_embeddings = mean_pooling(hidden_states, encoded_corpus['attention_mask'])
-    corpus_embeddings = F.normalize(corpus_embeddings, p=2, dim=1)
+        with torch.no_grad():
+            token_ids = torch.tensor([tokenizer.encode(paper_info["abstract"])][:512])
+            hidden_states, _ = model(token_ids)[-2:]
+            all_embeddings.append(hidden_states.mean(0).mean(0).numpy())
 
     np.random.seed(args.seed)
-    out = sklearn.manifold.TSNE(n_components=2, metric="cosine").fit_transform(corpus_embeddings)
+    all_embeddings = np.array(all_embeddings)
+    out = sklearn.manifold.TSNE(n_components=2, metric="cosine").fit_transform(all_embeddings)
 
     for i, paper_info in enumerate(data):
         paper_info['tsne_embedding'] = out[i].tolist()
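
For context, here is a minimal standalone sketch of the pipeline this revert restores: one mean-of-hidden-states embedding per abstract from deepset/sentence_bert, followed by a 2-D t-SNE projection. It is an approximation, not the commit's exact code: the input path is hardcoded (the script takes it as a CLI argument), return_dict=True is an assumption to make the forward call explicit on current transformers releases, and the [:512] truncation is applied to the token sequence itself, whereas in the committed code that slice falls on the outer one-element batch list, so abstracts longer than 512 tokens would reach the model untruncated.

# Hedged sketch, not the commit's exact code: the path and return_dict usage
# are assumptions; everything else mirrors the restored etc/compute_embeddings.py.
import json

import numpy as np
import sklearn.manifold
import torch
import transformers

tokenizer = transformers.AutoTokenizer.from_pretrained("deepset/sentence_bert")
model = transformers.AutoModel.from_pretrained("deepset/sentence_bert")
model.eval()

with open("papers.json") as f:  # hypothetical path; the real script reads args.json
    data = json.load(f)

all_embeddings = []
for paper_info in data:
    with torch.no_grad():
        # Truncate the token sequence (not the batch list) to BERT's 512-position limit.
        ids = tokenizer.encode(paper_info["abstract"])[:512]
        token_ids = torch.tensor([ids])  # shape: (1, seq_len)
        hidden_states = model(token_ids, return_dict=True).last_hidden_state
        # Mean over the batch dim, then over tokens: one vector per paper.
        all_embeddings.append(hidden_states.mean(0).mean(0).numpy())

np.random.seed(0)  # the script exposes this as --seed
out = sklearn.manifold.TSNE(n_components=2, metric="cosine").fit_transform(
    np.array(all_embeddings)
)
for i, paper_info in enumerate(data):
    paper_info["tsne_embedding"] = out[i].tolist()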
