@@ -522,9 +522,10 @@ def __iter__(self) -> dict[str, np.ndarray]:
 
 
 
-
-
-
+STRING_COMPLEMENT_MAP = {
+    "A": "T", "C": "G", "G": "C", "T": "A", "a": "t", "c": "g", "g": "c", "t": "a",
+    "N": "N", "n": "n",
+}
 
 
 class PetaGraphStreamDatasetV2(torch.utils.data.IterableDataset):
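For reference, a minimal sketch (outside the diff) of the reverse-complement idiom that the new STRING_COMPLEMENT_MAP enables; the same reverse-iterate-and-translate pattern appears later in generate():

    STRING_COMPLEMENT_MAP = {
        "A": "T", "C": "G", "G": "C", "T": "A",
        "a": "t", "c": "g", "g": "c", "t": "a",
        "N": "N", "n": "n",
    }

    def reverse_complement(seq: str) -> str:
        # Translate each base while walking the sequence back to front
        return "".join(STRING_COMPLEMENT_MAP[base] for base in reversed(seq))

    assert reverse_complement("ACGTn") == "nACGT"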
@@ -553,7 +554,8 @@ def __init__(self,
                  rank: int = 0,
                  packed: bool = False,
                  sampling_seq_len_inflection: int = 1024,
-                 reverse_probability: float = 0.0
+                 reverse_probability: float = 0.0,
+                 build_graph: bool = False
                  ):
 
         self.maxlen = maxlen
@@ -573,6 +575,13 @@ def __init__(self,
         self.logging_func(f"[PetaGraphStreamDataset] Sampling Seq. Len. Inflection: {self.sampling_seq_len_inflection}")
         if self.reverse_probability > 0.0:
             self.logging_func(f"[PetaGraphStreamDataset] Reverse Probability: {self.reverse_probability}")
+            self.logging_func(f"[PetaGraphStreamDataset] Computing reverse complements for some sequences.")
+
+        self.build_graph = build_graph
+        if self.build_graph:
+            self.logging_func(f"[PetaGraphStreamDataset] Building sequence graph and sampling random walks to increase seq. length")
+        else:
+            self.logging_func(f"[PetaGraphStreamDataset] Not building sequence graph")
 
         self.VOCAB = vocabulary
         self._pad_token_id = self.VOCAB["PAD"]
@@ -789,8 +798,13 @@ def random_walk_graph_sequences(graph, sequences, k_mer: int = 31) -> list[str]:
 
     def length_sampling_filter(self, sequence: str) -> bool:
         seq_len = len(sequence)
+
+        # Keep all sequences at or above the inflection point
         if seq_len >= self.sampling_seq_len_inflection:
             return True
+
+        # Below the inflection point, keep a sequence with a probability
+        # proportional to its length
         else:
             prob = np.random.rand()
             if prob < (seq_len / self.sampling_seq_len_inflection):
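To make the sampling behaviour above concrete, here is a self-contained version of the filter (hypothetical, pulled out of the class): with the default inflection of 1024, a 256 bp sequence survives with probability 256/1024 = 0.25, and anything at or above 1024 bp is always kept.

    import numpy as np

    def length_sampling_filter(sequence: str, inflection: int = 1024) -> bool:
        seq_len = len(sequence)
        if seq_len >= inflection:
            return True  # always keep long sequences
        # Keep short sequences with probability proportional to their length
        return np.random.rand() < (seq_len / inflection)

    # Roughly 25% of 256 bp sequences should survive
    kept = sum(length_sampling_filter("A" * 256) for _ in range(10_000))
    print(kept / 10_000)  # ~0.25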
@@ -799,7 +813,7 @@ def length_sampling_filter(self, sequence: str) -> bool:
                 return False
 
 
-    def fasta_parsing_func(self, input_data: Tuple[str, bytes]):
+    def fasta_parsing_func(self, input_data: Tuple[str, bytes]) -> deque[tuple[str, ...]]:
         """Parse the fasta data and return the sequences
 
         Parameters
@@ -809,30 +823,39 @@ def fasta_parsing_func(self, input_data: Tuple[str, bytes]):
         """
         path, data = input_data
         if data is None:
-            return [("", "")]
+            return deque([(path, "")])
 
         sequences = []
         decoded_lines = data.decode()
         sequences = [str(s.seq) for s in SeqIO.parse(StringIO(decoded_lines), "fasta")]
+        loaded_length = len(sequences)
 
         # Following DNA-BERTv2: https://arxiv.org/pdf/2306.15006
         # Zhou et al.: "We exclude all sequences with N and retain only sequences that consist of A, T, C, and G."
         sequences = [s for s in sequences if set(s).issubset(ALPHABET)]
+        after_alphabet_filter_length = len(sequences)
+
+        if self.build_graph:
+            # Chop sequences in preparation for graph traversal
+            sequences = [self.chop_at_first_repeated_kmer(s, k=KMER_LENGTH) for s in sequences]
 
-        # Chop sequences in preparation for graph traversal
-        sequences = [self.chop_at_first_repeated_kmer(s, k=KMER_LENGTH) for s in sequences]
+            # Construct sequence graph and perform random walks
+            sequences_arr = np.array(sequences)
+            sequence_graph = self.find_overlaps_and_build_graph(sequences_arr, k_mer=KMER_LENGTH)
+            random_walk_sequences = self.random_walk_graph_sequences(sequence_graph, sequences_arr, k_mer=KMER_LENGTH)
+            sequences = random_walk_sequences
 
-        # Construct sequence graph and perform random walks
-        sequences_arr = np.array(sequences)
-        sequence_graph = self.find_overlaps_and_build_graph(sequences_arr, k_mer=KMER_LENGTH)
-        random_walk_sequences = self.random_walk_graph_sequences(sequence_graph, sequences_arr, k_mer=KMER_LENGTH)
+        # Sample sequences for training based on length
+        keep_sequences = [(path, s) for s in filter(self.length_sampling_filter, sequences)]
+        after_length_filter_length = len(keep_sequences)
 
-        # Sample sequences for training
-        keep_sequences = [(path, s) for s in filter(self.length_sampling_filter, random_walk_sequences)]
+        # Log how many sequences were parsed
+        log_msg = f"[PetaGraphStreamDataset:{self.rank}] Parsed {loaded_length} > {after_alphabet_filter_length} > {after_length_filter_length} sequences from {path}"
+        log_rank(log_msg, logger=self.logger, level=logging.INFO, rank=self.rank)
 
         # Test outputs
         if len(keep_sequences) == 0:
-            return [("", "")]
+            return deque([(path, "")])
 
         assert isinstance(keep_sequences, list)
         assert isinstance(keep_sequences[0], tuple) and len(keep_sequences[0]) == 2
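The DNA-BERT-2-style filter referenced in the comments is a plain set-inclusion test; a self-contained sketch, assuming ALPHABET holds the four uppercase bases:

    ALPHABET = {"A", "C", "G", "T"}

    seqs = ["ACGT", "ACGNNT", "acgt", "TTTT"]
    clean = [s for s in seqs if set(s).issubset(ALPHABET)]
    print(clean)  # ['ACGT', 'TTTT'] -- drops sequences with N or lowercase bases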
@@ -841,7 +864,7 @@ def fasta_parsing_func(self, input_data: Tuple[str, bytes]):
         # Shuffle the sequences
         random.shuffle(keep_sequences)
 
-        return keep_sequences
+        return deque(keep_sequences)
 
     def crop_maxlen(self, input_sequence: str, maxlen: int = None):
         # path, input_sequence = input_data
@@ -864,9 +887,10 @@ def tokenize_and_pad(self, input_sequence: str, apply_pad: bool = True):
         tokenized_sequence.append(self._eos_token_id)  # end with EOS token
         tokenized_sequence = np.array(tokenized_sequence, dtype=np.int32)
 
-        if self.reverse_probability > 0.0:
-            if np.random.rand() < self.reverse_probability:
-                tokenized_sequence = tokenized_sequence[::-1]
+        # No longer done here, done in the `generate` method, 5th Feb 2025
+        # if self.reverse_probability > 0.0:
+        #     if np.random.rand() < self.reverse_probability:
+        #         tokenized_sequence = tokenized_sequence[::-1]
 
         # Pad the sequence
         if apply_pad and len(tokenized_sequence) < maxlen:
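The dropped token-level reversal differed from the new string-level augmentation in two ways: it never complemented bases, and it moved the terminal EOS to the front, which is why the old assert further down had to tolerate a BOS/EOS token at the sequence end. A tiny illustration (the token IDs are made up):

    import numpy as np

    BOS, EOS = 1, 2
    tokens = np.array([BOS, 7, 8, 9, EOS], dtype=np.int32)
    print(tokens[::-1])  # [2 9 8 7 1] -- no complement, and the sequence now ends on BOS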
@@ -880,7 +904,8 @@ def tokenize_and_pad(self, input_sequence: str, apply_pad: bool = True):
 
     def generate(self):
         current_tokens = None
-        current_sequences = []
+        current_sequences = deque()
+        last_reversed = False
         while True:
             try:
 
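Switching current_sequences from a list to a deque is what makes the per-iteration removal cheap: list.pop(0) shifts every remaining element (O(n)), while deque.popleft() and indexing at either end are O(1). For example:

    from collections import deque

    items = deque([("a.fa", "ACGT"), ("a.fa", "TTTT")])
    first = items.popleft()  # O(1); the list equivalent items.pop(0) is O(n)
    peek = items[0]          # reading either end of a deque is also O(1)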
@@ -906,14 +931,25 @@ def generate(self):
 
                     current_sequences = self.fasta_parsing_func((source_path, decompressed_data))
 
-                # Remove the first sequence
-                source_path, text_raw = current_sequences.pop(0)
-                if text_raw is None or len(text_raw) == 0:
-                    continue
 
-                # Log the consumed sequences
-                with self.num_consumed_sequences.get_lock():
-                    self.num_consumed_sequences.value += 1
+                # We're performing reverse complementation augmentations
+                if self.reverse_probability > 0.0:
+                    # Apply the augmentation at random and only once per sequence
+                    if np.random.rand() < self.reverse_probability and not last_reversed:
+                        last_reversed = True
+                        # Just read the first sequence, but don't pop it;
+                        # the next iteration will read the same sequence again
+                        source_path, text_raw = current_sequences[0]
+                        text_raw = "".join([STRING_COMPLEMENT_MAP[base] for base in text_raw[::-1]])
+
+                    else:
+                        last_reversed = False
+                        source_path, text_raw = current_sequences.popleft()
+
+                # No rev. comp. augmentations
+                else:
+                    source_path, text_raw = current_sequences.popleft()
+
 
                 # Log the consumed files
                 if self.log_directory is not None:
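Net effect of the block above: when the augmentation fires, the reverse complement is emitted first (the entry is only peeked via current_sequences[0]), and the very next iteration pops and emits the forward strand, with last_reversed preventing back-to-back reversals of the same entry. A self-contained sketch of just this control flow (the names and seeded RNG are illustrative):

    from collections import deque
    import numpy as np

    STRING_COMPLEMENT_MAP = {"A": "T", "C": "G", "G": "C", "T": "A"}

    def emit(sequences: deque, reverse_probability: float, rng: np.random.Generator):
        last_reversed = False
        while sequences:
            if rng.random() < reverse_probability and not last_reversed:
                last_reversed = True
                path, seq = sequences[0]  # peek only; re-read on the next pass
                yield "".join(STRING_COMPLEMENT_MAP[b] for b in seq[::-1])
            else:
                last_reversed = False
                path, seq = sequences.popleft()  # now consume the entry
                yield seq

    rng = np.random.default_rng(0)
    print(list(emit(deque([("f.fa", "AACG")]), 1.0, rng)))
    # ['CGTT', 'AACG'] -- reverse complement first, then the forward strand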
@@ -930,6 +966,13 @@ def generate(self):
                     self.logging_func(f"Epoch {self.current_epoch} completed")
                     self.consumed_files = set()
 
+                if text_raw is None or len(text_raw) == 0:
+                    continue
+
+                # Log the consumed sequences
+                with self.num_consumed_sequences.get_lock():
+                    self.num_consumed_sequences.value += 1
+
             except StopIteration as e:
                 self.logger.warning(f"Reached end of dataset: {e}")
 
@@ -975,7 +1018,7 @@ def generate(self):
                 else:
                     # Check the last token of the current sequence
                     # is an EOS token or BOS token (if reverse_probability > 0.0)
-                    assert current_tokens[-1] == self._eos_token_id or (self.reverse_probability > 0.0 and current_tokens[-1] == self._bos_token_id)
+                    assert current_tokens[-1] == self._eos_token_id
                     current_tokens = np.concatenate([current_tokens, new_tokens])
 
                 if len(current_tokens) >= self.maxlen:
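For context on the tightened assert: generate() packs consecutive tokenized sequences until maxlen is reached, and since every tokenized sequence now ends on EOS (reversal happens on the raw string before tokenization), the carried-over buffer always ends on EOS as well. A rough sketch of the packing idea under those assumptions (the EOS and maxlen values are illustrative):

    import numpy as np

    EOS, MAXLEN = 2, 8

    def pack(token_arrays):
        # Concatenate EOS-terminated sequences, emitting fixed-size chunks
        current = None
        for new_tokens in token_arrays:
            if current is None:
                current = new_tokens
            else:
                assert current[-1] == EOS  # boundary always falls on an EOS token
                current = np.concatenate([current, new_tokens])
            while current is not None and len(current) >= MAXLEN:
                yield current[:MAXLEN]
                current = current[MAXLEN:]
                if len(current) == 0:
                    current = None

    chunks = list(pack([np.array([5, 6, 7, EOS]), np.array([8, 9, 4, 3, EOS])]))
    print(chunks)  # [array([5, 6, 7, 2, 8, 9, 4, 3])]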