-
Notifications
You must be signed in to change notification settings - Fork 266
Expand file tree
/
Copy pathutils.py
More file actions
58 lines (52 loc) · 1.81 KB
/
utils.py
File metadata and controls
58 lines (52 loc) · 1.81 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
from __future__ import annotations
from dataclasses import dataclass
"""Internal constants."""
# Names accepted for the ``sequence_model`` parameter (checked in ``validate``).
_SEQUENCE_MODELS: tuple[str, ...] = ("hash2vec", "word2vec")

# Names accepted for the ``rw_model`` parameter (checked in ``validate``).
# Trailing comma keeps this a one-element tuple rather than a bare string.
_RW_MODELS: tuple[str, ...] = (
    "rw_with_restart",
)

# Decay function names for hash2vec; presumably the allowed values of
# ``hash2vec_decay_function`` -- confirm against the consumer.
_HASH2VEC_DECAY_FUNCTIONS: tuple[str, ...] = ("gaussian", "constant")
@dataclass
class _RandomWalksEmbeddingsParameters:
    """Internal container holding every tunable parameter with its default.

    The field order is part of the interface: the generated ``__init__``
    accepts positional arguments, so fields must not be reordered.

    Construction performs no checks; call :meth:`validate` to verify the
    string-valued selectors against the module-level tuples of supported
    names.
    """

    # General.
    use_edge_direction: bool = False

    # Fields prefixed ``rw_``: random-walk settings.
    rw_model: str = "rw_with_restart"  # must be one of _RW_MODELS
    rw_max_nbrs: int = 50
    rw_num_walks_per_node: int = 5
    rw_batch_size: int = 10
    rw_num_batches: int = 5
    rw_seed: int = 42
    rw_restart_probability: float = 0.1
    rw_temporary_prefix: str = ""
    rw_cached_walks: str = ""

    # Which sequence model to use; must be one of _SEQUENCE_MODELS.
    sequence_model: str = "hash2vec"

    # Fields prefixed ``hash2vec_``: hash2vec settings.
    hash2vec_context_size: int = 5
    hash2vec_num_partitions: int = 5
    hash2vec_embeddings_dim: int = 512
    hash2vec_decay_function: str = "gaussian"  # one of _HASH2VEC_DECAY_FUNCTIONS
    hash2vec_gaussian_sigma: float = 1.0
    hash2vec_hashing_seed: int = 42
    hash2vec_sign_seed: int = 18
    hash2vec_do_l2_norm: bool = True
    hash2vec_safe_l2: bool = True

    # Fields prefixed ``word2vec_``: word2vec settings.
    word2vec_max_iter: int = 1
    word2vec_embeddings_dim: int = 100
    word2vec_window_size: int = 5
    word2vec_num_partitions: int = 1
    word2vec_min_count: int = 5
    word2vec_max_sentence_length: int = 1000
    word2vec_seed: int = 42
    word2vec_step_size: float = 0.025

    # Fields prefixed ``aggregate_neighbors``: neighbor-aggregation settings.
    aggregate_neighbors: bool = True
    aggregate_neighbors_max_nbrs: int = 50
    aggregate_neighbors_seed: int = 42

    # General.
    clean_up_after_run: bool = False

    def validate(self) -> None:
        """Check that the string-valued selectors name supported options.

        Raises:
            ValueError: if ``sequence_model`` is not in ``_SEQUENCE_MODELS``;
                if ``rw_model`` is not in ``_RW_MODELS``; or if
                ``sequence_model`` is ``"hash2vec"`` and
                ``hash2vec_decay_function`` is not in
                ``_HASH2VEC_DECAY_FUNCTIONS``.
        """
        if self.sequence_model not in _SEQUENCE_MODELS:
            raise ValueError(f"supported seq2vec models are {_SEQUENCE_MODELS}")
        if self.rw_model not in _RW_MODELS:
            raise ValueError(f"supported RW models are {_RW_MODELS}")
        # The decay function is only checked when hash2vec is the selected
        # model, so a value left over for an unselected model cannot make
        # validation fail.
        if (
            self.sequence_model == "hash2vec"
            and self.hash2vec_decay_function not in _HASH2VEC_DECAY_FUNCTIONS
        ):
            raise ValueError(
                f"supported hash2vec decay functions are {_HASH2VEC_DECAY_FUNCTIONS}"
            )