import argparse
import json

import numpy as np
import sklearn.manifold
import torch
import torch.nn.functional as F
import transformers
def parse_arguments():
    """Parse command-line arguments for the paper-embedding / t-SNE script.

    Returns:
        argparse.Namespace with:
            json    -- path to the JSON file containing all papers.
            outpath -- target path for the visualization output.
            seed    -- int seed for the t-SNE projection (default 0).
    """
    # NOTE(review): the ArgumentParser construction was outside the visible
    # diff hunk; reconstructed here — confirm description against the original.
    parser = argparse.ArgumentParser(description="Embed papers and project them with t-SNE.")
    # `default=False` on positional arguments has no effect (they are required);
    # kept for fidelity with the original argument definitions.
    parser.add_argument("json", default=False, help="the path the json containing all papers.")
    parser.add_argument("outpath", default=False, help="the target path of the visualizations papers.")
    parser.add_argument("--seed", default=0, help="The seed for TSNE.", type=int)
    return parser.parse_args()
def mean_pooling(token_embeddings, attention_mask):
    """Average token embeddings over the sequence axis, ignoring padded positions.

    The attention mask is broadcast to the embedding shape so that padded
    tokens contribute nothing to the sum, and the divisor counts only the
    real (unmasked) tokens.  The 1e-9 floor guards against division by zero
    for an all-padding row.
    """
    mask = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    summed = (token_embeddings * mask).sum(dim=1)
    counts = mask.sum(dim=1).clamp(min=1e-9)
    return summed / counts
if __name__ == "__main__":
    args = parse_arguments()

    # Sentence-BERT encoder used to embed each paper abstract.
    tokenizer = transformers.AutoTokenizer.from_pretrained("deepset/sentence_bert")
    model = transformers.AutoModel.from_pretrained("deepset/sentence_bert")
    model.eval()  # disable dropout so embeddings are deterministic

    with open(args.json) as f:
        data = json.load(f)

    print(f"Num papers: {len(data)}")

    all_embeddings = []
    for paper_info in data:
        with torch.no_grad():
            # BUG FIX: the truncation must apply to the token-id list itself.
            # The original `[tokenizer.encode(...)][:512]` sliced the outer
            # one-element batch list — a no-op — so abstracts longer than the
            # model's 512-token limit crashed the forward pass.
            token_ids = torch.tensor([tokenizer.encode(paper_info["abstract"])[:512]])
            hidden_states = model(token_ids).last_hidden_state
            # Collapse the batch and sequence dimensions: one vector per paper.
            all_embeddings.append(hidden_states.mean(0).mean(0).numpy())

    # Seed numpy for a reproducible t-SNE layout, then project to 2-D.
    np.random.seed(args.seed)
    all_embeddings = np.array(all_embeddings)
    out = sklearn.manifold.TSNE(n_components=2, metric="cosine").fit_transform(all_embeddings)

    # Attach the 2-D coordinates to each paper record (in input order).
    for i, paper_info in enumerate(data):
        paper_info['tsne_embedding'] = out[i].tolist()