-
Notifications
You must be signed in to change notification settings - Fork 16
Expand file tree
/
Copy pathdata.py
More file actions
35 lines (28 loc) · 1.01 KB
/
data.py
File metadata and controls
35 lines (28 loc) · 1.01 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
import torch
from datasets import load_dataset
from transformers import default_data_collator
def _preprocess(tokenizer, examples, max_length=128):
return tokenizer(
examples["text"], padding="max_length", truncation=True, max_length=max_length
)
def get_dataset(dataset_name, subset, split, size=None):
    """Load one split of a dataset, optionally limited to ``size`` examples.

    Args:
        dataset_name: First positional argument passed to ``load_dataset``.
        subset: Second positional argument passed to ``load_dataset``.
        split: Key used to select the split from the loaded dataset.
        size: When ``None``, the split is loaded without streaming. Otherwise
            the dataset is opened with ``streaming=True`` and ``take(size)``
            is applied to the selected split.

    Returns:
        The selected split, or the result of ``take(size)`` on the streamed
        split when ``size`` is given.
    """
    if size is not None:
        # Stream so that only the requested number of examples is taken.
        streamed_split = load_dataset(dataset_name, subset, streaming=True)[split]
        return streamed_split.take(size)

    return load_dataset(dataset_name, subset)[split]
def get_dataloader(dataset, tokenizer, batch_size, num_workers=4, max_length=128):
    """Tokenize ``dataset`` and wrap it in a ``torch`` ``DataLoader``.

    Args:
        dataset: Object exposing ``map``; must contain the columns ``"text"``,
            ``"timestamp"`` and ``"url"``, which are removed after mapping.
        tokenizer: Tokenizer callable handed to ``_preprocess``.
        batch_size: Used both as the ``map`` batch size and the loader batch
            size.
        num_workers: Accepted for interface compatibility but currently not
            forwarded; the loader is always built with ``num_workers=0``.
        max_length: Maximum sequence length handed to ``_preprocess``.

    Returns:
        A ``torch.utils.data.DataLoader`` over the tokenized dataset that
        collates with ``default_data_collator``.
    """

    def tokenize_batch(batch):
        return _preprocess(tokenizer, batch, max_length)

    tokenized = dataset.map(
        tokenize_batch,
        batched=True,
        batch_size=batch_size,
        remove_columns=["text", "timestamp", "url"],
    )

    # NOTE(review): ``num_workers`` is ignored here and loading stays in the
    # main process -- confirm whether this is deliberate before forwarding it.
    return torch.utils.data.DataLoader(
        tokenized,
        batch_size=batch_size,
        num_workers=0,
        collate_fn=default_data_collator,
    )