
Commit c758ce7 ("Update files")
Parent: b62066a

File tree: 3 files changed, +82 / -32 lines

petagraph/configs/config_petagraph_dev.yaml

Lines changed: 3 additions & 3 deletions
@@ -1,6 +1,6 @@
 checkpoints:
   checkpoint_interval: 100
-  checkpoints_path: /users/burgerm/petagraph/logs/dev/checkpoints
+  checkpoints_path: /users/burgerm/petagraph/logs/transcriptomics/base_ntp/checkpoints
   checkpoints_path_is_shared_file_system: true
   resume_checkpoint_path: null
   save_initial_state: false
@@ -9,7 +9,7 @@ data_stages:
     dataset: null # Custom dataloader will be used
     num_loading_workers: 0
     seed: 42
-    sequence_files_path: "/users/burgerm/petagraph/resources/training_sets/dev_wgs_fungi_2022.csv"
+    sequence_files_path: "/users/burgerm/petagraph/resources/training_sets/transcriptomics/eukaryota_transcriptomics_500_400_2024-10-03_09-33-14.csv"
     all_sequences_resources_path: "/users/burgerm/petagraph/resources"
     prefetch_buffer_seq_size: 2048
     name: Stable Training Stage
@@ -90,5 +90,5 @@ tokens:
   limit_val_batches: 0
   micro_batch_size: 128
   sequence_length: 1024
-  train_steps: 1000
+  train_steps: 2000
   val_check_interval: -1
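For scale, the new train_steps value together with the batch settings already in this file implies roughly the following token budget per data-parallel rank. A back-of-the-envelope sketch in Python; it ignores padding and any sequence-packing effects:

train_steps = 2_000
micro_batch_size = 128
sequence_length = 1_024

# Tokens seen per data-parallel rank over the whole run.
tokens = train_steps * micro_batch_size * sequence_length
print(f"{tokens:,}")  # 262,144,000 tokens, about 262M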

src/nanotron/data/petagraph_dataset.py

Lines changed: 71 additions & 28 deletions
@@ -522,9 +522,10 @@ def __iter__(self) -> dict[str, np.ndarray]:
 
 
 
-
-
-
+STRING_COMPLEMENT_MAP = {
+    "A": "T", "C": "G", "G": "C", "T": "A", "a": "t", "c": "g", "g": "c", "t": "a",
+    "N": "N", "n": "n",
+}
 
 
 class PetaGraphStreamDatasetV2(torch.utils.data.IterableDataset):
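For context, a complement map like this one yields the reverse complement of a read in a single pass over the string. A minimal sketch; the reverse_complement helper is illustrative, not part of this commit:

def reverse_complement(seq: str) -> str:
    # Walk the sequence back to front, substituting each base with its
    # Watson-Crick complement; case and "N" bases are preserved by the map.
    return "".join(STRING_COMPLEMENT_MAP[base] for base in reversed(seq))

assert reverse_complement("ACGTn") == "nACGT"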
@@ -553,7 +554,8 @@ def __init__(self,
                  rank: int = 0,
                  packed: bool = False,
                  sampling_seq_len_inflection: int = 1024,
-                 reverse_probability: float = 0.0
+                 reverse_probability: float = 0.0,
+                 build_graph: bool = False
                  ):
 
         self.maxlen = maxlen
@@ -573,6 +575,13 @@ def __init__(self,
         self.logging_func(f"[PetaGraphStreamDataset] Sampling Seq. Len. Inflection: {self.sampling_seq_len_inflection}")
         if self.reverse_probability > 0.0:
             self.logging_func(f"[PetaGraphStreamDataset] Reverse Probability: {self.reverse_probability}")
+            self.logging_func(f"[PetaGraphStreamDataset] Computing reverse complements for some sequences.")
+
+        self.build_graph = build_graph
+        if self.build_graph:
+            self.logging_func(f"[PetaGraphStreamDataset] Building sequence graph and sampling random walks to increase seq. length")
+        else:
+            self.logging_func(f"[PetaGraphStreamDataset] Not building sequence graph")
 
         self.VOCAB = vocabulary
         self._pad_token_id = self.VOCAB["PAD"]
@@ -789,8 +798,13 @@ def random_walk_graph_sequences(graph, sequences, k_mer: int = 31) -> list[str]:
 
     def length_sampling_filter(self, sequence: str) -> bool:
         seq_len = len(sequence)
+
+        # Keep all sequences above the inflection point
         if seq_len >= self.sampling_seq_len_inflection:
             return True
+
+        # Below the inflection point we sample sequences
+        # with a probability that is proportional to the sequence length
         else:
             prob = np.random.rand()
             if prob < (seq_len / self.sampling_seq_len_inflection):
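In other words, sequences at or above the inflection length always pass, while a shorter sequence survives with probability proportional to its length, biasing the stream toward longer reads without discarding short ones outright. A standalone restatement for intuition; the function and argument names are illustrative:

import numpy as np

def keep_sequence(seq_len: int, inflection: int = 1024) -> bool:
    if seq_len >= inflection:
        return True  # always keep sequences at or above the inflection point
    # e.g. a 256-base read passes about 25% of the time at inflection=1024
    return np.random.rand() < (seq_len / inflection)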
@@ -799,7 +813,7 @@ def length_sampling_filter(self, sequence: str) -> bool:
                 return False
 
 
-    def fasta_parsing_func(self, input_data: Tuple[str, bytes]):
+    def fasta_parsing_func(self, input_data: Tuple[str, bytes]) -> deque[tuple[str, ...]]:
         """Parse the fasta data and return the sequences
 
         Parameters
@@ -809,30 +823,39 @@ def fasta_parsing_func(self, input_data: Tuple[str, bytes]):
         """
         path, data = input_data
         if data is None:
-            return [("", "")]
+            return deque([(path, "")])
 
         sequences = []
         decoded_lines = data.decode()
         sequences = [str(s.seq) for s in SeqIO.parse(StringIO(decoded_lines), "fasta")]
+        loaded_length = len(sequences)
 
         # Following DNA-BERTv2: https://arxiv.org/pdf/2306.15006
         # Zhou et al.: "We exclude all sequences with N and retain only sequences that consist of A, T, C, and G.
         sequences = [s for s in sequences if set(s).issubset(ALPHABET)]
+        after_alphabet_filter_length = len(sequences)
+
+        if self.build_graph:
+            # Chop sequences in preparation for graph traversal
+            sequences = [self.chop_at_first_repeated_kmer(s, k=KMER_LENGTH) for s in sequences]
 
-        # Chop sequences in preparation for graph traversal
-        sequences = [self.chop_at_first_repeated_kmer(s, k=KMER_LENGTH) for s in sequences]
+            # Construct sequence graph and perform random walks
+            sequences_arr = np.array(sequences)
+            sequence_graph = self.find_overlaps_and_build_graph(sequences_arr, k_mer=KMER_LENGTH)
+            random_walk_sequences = self.random_walk_graph_sequences(sequence_graph, sequences_arr, k_mer=KMER_LENGTH)
+            sequences = random_walk_sequences
 
-        # Construct sequence graph and perform random walks
-        sequences_arr = np.array(sequences)
-        sequence_graph = self.find_overlaps_and_build_graph(sequences_arr, k_mer=KMER_LENGTH)
-        random_walk_sequences = self.random_walk_graph_sequences(sequence_graph, sequences_arr, k_mer=KMER_LENGTH)
+        # Sample sequences for training based on length
+        keep_sequences = [(path, s) for s in filter(self.length_sampling_filter, sequences)]
+        after_length_filter_length = len(keep_sequences)
 
-        # Sample sequences for training
-        keep_sequences = [(path, s) for s in filter(self.length_sampling_filter, random_walk_sequences)]
+        # Log how many sequences were parsed
+        log_msg = f"[PetaGraphStreamDataset:{self.rank}] Parsed {loaded_length} > {after_alphabet_filter_length} > {after_length_filter_length} sequences from {path}"
+        log_rank(log_msg, logger=self.logger, level=logging.INFO, rank=self.rank)
 
         # Test outputs
         if len(keep_sequences) == 0:
-            return [("", "")]
+            return deque([(path, "")])
 
         assert isinstance(keep_sequences, list)
         assert isinstance(keep_sequences[0], tuple) and len(keep_sequences[0]) == 2
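The graph branch appears here only through its call sites (find_overlaps_and_build_graph and random_walk_graph_sequences are defined elsewhere in the class), but the underlying idea is that chopped sequences are linked when one's k-mer suffix matches another's k-mer prefix, and random walks over those links stitch short contigs into longer training sequences. A hypothetical sketch of such an overlap test, assuming the same KMER_LENGTH constant; the real implementation may differ:

KMER_LENGTH = 31  # mirrors the default k_mer value in the diff

def kmer_overlap(a: str, b: str, k: int = KMER_LENGTH) -> bool:
    # Hypothetical edge rule: an edge a -> b exists when the last k
    # bases of a equal the first k bases of b.
    return len(a) >= k and len(b) >= k and a[-k:] == b[:k]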
@@ -841,7 +864,7 @@ def fasta_parsing_func(self, input_data: Tuple[str, bytes]):
         # Shuffle the sequences
         random.shuffle(keep_sequences)
 
-        return keep_sequences
+        return deque(keep_sequences)
 
     def crop_maxlen(self, input_sequence: str, maxlen: int = None):
         # path, input_sequence = input_data
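Returning a deque rather than a list matches how generate() now consumes the parsed sequences: popping from the front of a list is O(n) because every remaining element shifts, while deque.popleft() is O(1). A small illustration:

from collections import deque

seqs = deque([("sample.fa", "ACGT"), ("sample.fa", "GGCA")])
path, seq = seqs.popleft()  # O(1); the previous list.pop(0) was O(n)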
@@ -864,9 +887,10 @@ def tokenize_and_pad(self, input_sequence: str, apply_pad: bool = True):
         tokenized_sequence.append(self._eos_token_id) # end with EOS token
         tokenized_sequence = np.array(tokenized_sequence, dtype=np.int32)
 
-        if self.reverse_probability > 0.0:
-            if np.random.rand() < self.reverse_probability:
-                tokenized_sequence = tokenized_sequence[::-1]
+        # No longer done here, done in the `generate` method, 5th Feb 2025
+        # if self.reverse_probability > 0.0:
+        #     if np.random.rand() < self.reverse_probability:
+        #         tokenized_sequence = tokenized_sequence[::-1]
 
         # Pad the sequence
         if apply_pad and len(tokenized_sequence) < maxlen:
@@ -880,7 +904,8 @@ def tokenize_and_pad(self, input_sequence: str, apply_pad: bool = True):
 
     def generate(self):
         current_tokens = None
-        current_sequences = []
+        current_sequences = deque()
+        last_reversed = False
         while True:
             try:
 
@@ -906,14 +931,25 @@ def generate(self):
 
                 current_sequences = self.fasta_parsing_func((source_path, decompressed_data))
 
-                # Remove the first sequence
-                source_path, text_raw = current_sequences.pop(0)
-                if text_raw is None or len(text_raw) == 0:
-                    continue
 
-                # Log the consumed sequences
-                with self.num_consumed_sequences.get_lock():
-                    self.num_consumed_sequences.value += 1
+                # We're performing reverse complementation augmentations
+                if self.reverse_probability > 0.0:
+                    # Apply the augmentation at random and only once per sequence
+                    if np.random.rand() < self.reverse_probability and not last_reversed:
+                        last_reversed = True
+                        # Just read the first sequence, but don't pop it
+                        # Next iteration will read the same sequence again
+                        source_path, text_raw = current_sequences[0]
+                        text_raw = "".join([STRING_COMPLEMENT_MAP[base] for base in text_raw[::-1]])
+
+                    else:
+                        last_reversed = False
+                        source_path, text_raw = current_sequences.popleft()
+
+                # No rev. comp. augmentations
+                else:
+                    source_path, text_raw = current_sequences.popleft()
+
 
                 # Log the consumed files
                 if self.log_directory is not None:
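Read on its own, this branch is a peek-then-pop scheme: when the augmentation fires, the head of the deque is read in place and yielded as its reverse complement, and the last_reversed flag forces the next iteration to pop that same sequence in forward orientation, so each sequence is emitted once forward plus at most once reversed. A self-contained sketch of the same idea as a standalone generator (illustrative, not the class method itself; COMPLEMENT is a trimmed copy of the STRING_COMPLEMENT_MAP added above):

from collections import deque
import numpy as np

COMPLEMENT = {"A": "T", "C": "G", "G": "C", "T": "A", "N": "N"}

def augmented_stream(sequences: deque, p: float):
    """Yield each sequence forward once and, with probability p,
    its reverse complement immediately before the forward copy."""
    last_reversed = False
    while sequences:
        if p > 0.0 and np.random.rand() < p and not last_reversed:
            last_reversed = True
            seq = sequences[0]  # peek only; popped on the next iteration
            yield "".join(COMPLEMENT[b] for b in seq[::-1])
        else:
            last_reversed = False
            yield sequences.popleft()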
@@ -930,6 +966,13 @@ def generate(self):
                     self.logging_func(f"Epoch {self.current_epoch} completed")
                     self.consumed_files = set()
 
+                if text_raw is None or len(text_raw) == 0:
+                    continue
+
+                # Log the consumed sequences
+                with self.num_consumed_sequences.get_lock():
+                    self.num_consumed_sequences.value += 1
+
             except StopIteration as e:
                 self.logger.warning(f"Reached end of dataset: {e}")
 
@@ -975,7 +1018,7 @@ def generate(self):
             else:
                 # Check the last token of the current sequence
                 # is an EOS token or BOS token (if reverse_probability > 0.0)
-                assert current_tokens[-1] == self._eos_token_id or (self.reverse_probability > 0.0 and current_tokens[-1] == self._bos_token_id)
+                assert current_tokens[-1] == self._eos_token_id
             current_tokens = np.concatenate([current_tokens, new_tokens])
 
             if len(current_tokens) >= self.maxlen:

src/nanotron/optim/gradient_accumulator.py

Lines changed: 8 additions & 1 deletion
@@ -291,7 +291,14 @@ def load_state_dict(self, state_dict: Dict[str, torch.Tensor]):
 
         with torch.inference_mode():
             for name, elt in self.parameters.items():
-                elt["fp32"].copy_(state_dict[name])
+                param = state_dict[name]
+                if len(param) != len(elt["fp32"]):
+                    logger.warning(
+                        f"Expected {name} to have the same size as {elt['fp32'].size()}, but got {param.size()}"
+                    )
+                    elt["fp32"].copy_(param[: len(elt["fp32"])])
+                else:
+                    elt["fp32"].copy_(param)
 
 
 @dataclasses.dataclass
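For intuition, the new branch makes load_state_dict tolerant of checkpoints whose flattened fp32 shards are longer than the current buffers: it warns, then copies only the leading slice that fits. A minimal sketch with hypothetical tensors (note that a checkpoint tensor shorter than the buffer would still raise inside copy_, since only the source is sliced):

import torch

ckpt = torch.arange(10, dtype=torch.float32)  # flattened fp32 shard from the checkpoint
fp32 = torch.zeros(8)                         # current accumulator buffer

if len(ckpt) != len(fp32):
    fp32.copy_(ckpt[: len(fp32)])  # keep only the leading slice that fits
else:
    fp32.copy_(ckpt)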
