4 changes: 0 additions & 4 deletions docs/source/data_utils.md
@@ -32,10 +32,6 @@

[[autodoc]] maybe_unpair_preference_dataset

## pack_examples

[[autodoc]] pack_examples

## pack_dataset

[[autodoc]] pack_dataset
10 changes: 8 additions & 2 deletions docs/source/reducing_memory_usage.md
@@ -77,10 +77,16 @@ This technique applies only to SFT.
Packing, introduced in [Raffel et al., 2020](https://huggingface.co/papers/1910.10683), addresses these issues by grouping sequences instead of truncating. It concatenates and splits dataset sequences into the desired lengths.

<div class="flex justify-center">
<img src="https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/packing.png" alt="Packing" width="600"/>
<img src="https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/packing_2.png" alt="Packing" width="600"/>
</div>

Packing eliminates padding, preserves all sequence information, and allows for flexible sequence lengths, making it a more efficient alternative to truncation. To enable packing, use `packing=True` in the [`SFTConfig`]:
Packing reduces padding by merging several sequences into one row when possible. TRL uses a near-optimal bin-packing method (First Fit Decreasing) that preserves sequence boundaries. To enable packing, set `packing=True` in the [`SFTConfig`].

<Tip>

In TRL 0.18 and earlier, packing used a more aggressive method that reduced padding to almost nothing, but had the downside of breaking sequence continuity for a large fraction of the dataset. To revert to this strategy, set `packing_strategy="wrapped"` in [`SFTConfig`], as sketched below.

</Tip>
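
A minimal sketch of reverting to the legacy behavior described in the tip above; the only assumption beyond this release's `packing_strategy` argument is the placeholder output directory and `max_length` value:

```python
from trl import SFTConfig

training_args = SFTConfig(
    output_dir="sft-packed",       # hypothetical output directory
    packing=True,
    packing_strategy="wrapped",    # revert from the FFD default described above
    max_length=512,
)
```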

```python
from trl import SFTConfig
51 changes: 48 additions & 3 deletions tests/test_data_utils.py
@@ -439,7 +439,7 @@ def test_with_dataset(self):
self.assertEqual(dataset.to_dict(), expected_output)


class TestPackDataset(unittest.TestCase):
class TestPackDatasetWrapped(unittest.TestCase):
def test_with_dataset(self):
examples = {
"input_ids": [[1, 2, 3], [4, 5, 6, 7], [8]],
@@ -451,7 +451,7 @@ def test_with_dataset(self):
"input_ids": [[1, 2, 3], [4, 5, 6], [7, 8]],
"attention_mask": [[0, 1, 1], [0, 0, 1], [1, 1]],
}
dataset = pack_dataset(dataset, seq_length)
dataset = pack_dataset(dataset, seq_length, strategy="wrapped")
self.assertEqual(dataset.to_dict(), expected_output)

def test_with_iterable_dataset(self):
@@ -465,11 +465,56 @@ def test_with_iterable_dataset(self):
"input_ids": [[1, 2, 3], [4, 5, 6], [7, 8]],
"attention_mask": [[0, 1, 1], [0, 0, 1], [1, 1]],
}
dataset = pack_dataset(dataset, seq_length)
dataset = pack_dataset(dataset, seq_length, strategy="wrapped")
num_examples = len(examples[next(iter(examples))])
self.assertEqual(next(iter(dataset.batch(batch_size=num_examples))), expected_output)


class TestPackDatasetFfd(unittest.TestCase):
def test_simple(self):
examples = {
"input_ids": [[1, 2, 3], [4, 5, 6, 7], [8]],
"attention_mask": [[0, 1, 1], [0, 0, 1, 1], [1]],
}
dataset = Dataset.from_dict(examples)
seq_length = 4
expected_output = {
"input_ids": [[4, 5, 6, 7], [1, 2, 3, 8]],
"attention_mask": [[0, 0, 1, 1], [0, 1, 1, 1]],
}
dataset = pack_dataset(dataset, seq_length, strategy="ffd")
self.assertEqual(dataset.to_dict(), expected_output)

def test_with_iterable_dataset(self):
examples = {
"input_ids": [[1, 2, 3], [4, 5, 6, 7], [8]],
"attention_mask": [[0, 1, 1], [0, 0, 1, 1], [1]],
}
dataset = Dataset.from_dict(examples).to_iterable_dataset()
seq_length = 4
expected_output = {
"input_ids": [[4, 5, 6, 7], [1, 2, 3, 8]],
"attention_mask": [[0, 0, 1, 1], [0, 1, 1, 1]],
}
dataset = pack_dataset(dataset, seq_length, strategy="ffd")
num_examples = len(examples[next(iter(examples))])
self.assertEqual(next(iter(dataset.batch(batch_size=num_examples))), expected_output)

def test_with_truncation(self):
examples = {
"input_ids": [[1, 2, 3, 4, 5], [6, 7], [8, 9, 10, 11], [12]],
"attention_mask": [[1, 1, 1, 1, 1], [1, 1], [1, 1, 1, 1], [1]],
}
dataset = Dataset.from_dict(examples)
seq_length = 4
expected_output = {
"input_ids": [[1, 2, 3, 4], [8, 9, 10, 11], [6, 7, 12]],
"attention_mask": [[1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1]],
}
dataset = pack_dataset(dataset, seq_length, strategy="ffd")
self.assertEqual(dataset.to_dict(), expected_output)


class TestTruncateExamples(unittest.TestCase):
def test_with_dataset(self):
examples = {
41 changes: 35 additions & 6 deletions tests/test_sft_trainer.py
@@ -19,6 +19,7 @@
import numpy as np
import torch
from datasets import Dataset, Image, Sequence, load_dataset
from parameterized import parameterized
from transformers import (
AutoModelForCausalLM,
AutoProcessor,
@@ -812,7 +813,7 @@ def test_only_train_packing(self):
per_device_train_batch_size=2,
gradient_checkpointing=True,
packing=True,
max_length=16, # make sure there is at least 1 packed sequence
max_length=128, # make sure there is at least 1 packed sequence
eval_packing=False,
report_to="none",
)
@@ -824,15 +825,15 @@ def test_eval_packing(self):
eval_dataset=self.conversational_lm_dataset["test"],
)

self.assertEqual(len(trainer.train_dataset["input_ids"]), 46) # w/ this dataset, we end up with 46 seqs
self.assertEqual(len(trainer.train_dataset["input_ids"]), 7) # w/ this dataset, we end up with 7 seqs
self.assertEqual(len(trainer.eval_dataset["input_ids"]), len(self.conversational_lm_dataset["test"]))

def test_eval_packing(self):
with tempfile.TemporaryDirectory() as tmp_dir:
training_args = SFTConfig(
output_dir=tmp_dir,
per_device_train_batch_size=2,
max_length=16, # make sure there is at least 1 packed sequence
max_length=128, # make sure there is at least 1 packed sequence
packing=True,
report_to="none",
)
@@ -843,15 +844,15 @@ def test_eval_packing(self):
eval_dataset=self.conversational_lm_dataset["test"],
)

self.assertEqual(len(trainer.train_dataset["input_ids"]), 46) # w/ this dataset, we end up with 46 seqs
self.assertEqual(len(trainer.eval_dataset["input_ids"]), 6) # w/ this dataset, we end up with 6 seqs
self.assertEqual(len(trainer.train_dataset["input_ids"]), 7) # w/ this dataset, we end up with 7 seqs
self.assertEqual(len(trainer.eval_dataset["input_ids"]), 1) # w/ this dataset, we end up with 1 seq

def test_no_packing(self):
with tempfile.TemporaryDirectory() as tmp_dir:
training_args = SFTConfig(
output_dir=tmp_dir,
per_device_train_batch_size=2,
max_length=16, # make sure there is at least 1 packed sequence
max_length=128, # make sure there is at least 1 packed sequence
packing=False,
report_to="none",
)
@@ -1229,3 +1230,31 @@ def test_train_padding_free(self):
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")

@parameterized.expand([("ffd",), ("wrapped",)])
def test_train_packing(self, packing_strategy):
# Get the dataset
dataset = load_dataset("trl-internal-testing/zen", "standard_language_modeling", split="train")

with tempfile.TemporaryDirectory() as tmp_dir:
# Initialize the trainer
training_args = SFTConfig(
output_dir=tmp_dir, packing=True, packing_strategy=packing_strategy, max_length=10, report_to="none"
)
trainer = SFTTrainer(
model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", args=training_args, train_dataset=dataset
)

# Save the initial parameters to compare them later
previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}

# Train the model
trainer.train()

# Check that the training loss is not None
self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])

# Check the params have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")
160 changes: 127 additions & 33 deletions trl/data_utils.py
@@ -12,7 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import functools
import warnings
from collections import defaultdict
from collections.abc import Sequence
from typing import Any, Callable, Optional, TypeVar, Union

@@ -465,14 +466,117 @@ def pack_examples(examples: dict[str, list[list]], seq_length: int) -> dict[str,
{'input_ids': [[1, 2], [3, 4], [5, 6], [7, 8]], 'attention_mask': [[0, 1], [1, 0], [0, 1], [1, 1]]}
```
"""
warnings.warn(
"`pack_examples` is deprecated and will be removed in version 0.20.0. Use `pack_dataset` with a dataset "
"instead.",
DeprecationWarning,
)
# Join all the values into a single list
examples = {k: sum(v, []) for k, v in examples.items()}
# Split the values into chunks of size seq_length
examples = {k: [v[i : i + seq_length] for i in range(0, len(v), seq_length)] for k, v in examples.items()}
return examples


def pack_dataset(dataset: DatasetType, seq_length: int, map_kwargs: Optional[dict[str, Any]] = None) -> DatasetType:
def _pack_ffd(examples: pa.Table, seq_length: int) -> pa.Table:
"""Pack sequences in a pyarrow Table using First Fit Decreasing strategy."""
packed_columns = []
for column in examples.columns:
if pyarrow.types.is_list(column.type) or pyarrow.types.is_large_list(column.type):
if isinstance(column, pa.ChunkedArray):
column = column.combine_chunks()
offsets, values = column.offsets, column.values
values = values[offsets[0].as_py() : offsets[-1].as_py()]

# Extract sequences using numpy for vectorized operations
offset_array = offsets.to_numpy()
starts = offset_array[:-1]
ends = offset_array[1:]
seq_lens = ends - starts

# Vectorized truncation
truncated_lens = np.minimum(seq_lens, seq_length)
truncated_ends = starts + truncated_lens

# Build (truncated_length, start, end) tuples for each sequence
sequences = list(zip(truncated_lens, starts, truncated_ends))

# Sort by length (decreasing) for First Fit Decreasing
sequences.sort(key=lambda x: x[0], reverse=True)

# First Fit bin packing: index open bins by their remaining capacity for fast fit lookup
bins_by_remaining = defaultdict(list) # remaining_space -> [bin_indices]
bins = [] # [(current_length, seq_indices)]

for i, (seq_len, _start, _end) in enumerate(sequences):
# Find bins with enough space using the dictionary
placed = False
for remaining in range(seq_len, seq_length + 1):
if bins_by_remaining[remaining]:
# Use the first available bin with this remaining space
bin_idx = bins_by_remaining[remaining].pop()
current_len, seq_indices = bins[bin_idx]

# Update bin
new_len = current_len + seq_len
new_remaining = seq_length - new_len
bins[bin_idx] = (new_len, seq_indices + [i])

# Update the remaining space mapping
if new_remaining > 0:
bins_by_remaining[new_remaining].append(bin_idx)

placed = True
break

# If no bin fits, create new bin
if not placed:
bin_idx = len(bins)
bins.append((seq_len, [i]))
remaining = seq_length - seq_len
if remaining > 0:
bins_by_remaining[remaining].append(bin_idx)

# Reconstruct the packed values by concatenating the selected sequences, bin by bin
values_numpy = values.to_numpy()
packed_values = []
new_offsets = [0]

for _, seq_indices in bins:
for seq_idx in seq_indices:
_, start, end = sequences[seq_idx]
packed_values.extend(values_numpy[start:end])
new_offsets.append(len(packed_values))

dtype = offsets.type.to_pandas_dtype()
new_offsets = np.array(new_offsets, dtype=dtype)
packed_values = pa.array(packed_values, type=values.type)
column = type(column).from_arrays(new_offsets, packed_values)
packed_columns.append(column)
return pa.Table.from_arrays(packed_columns, names=examples.column_names)


def _pack_wrapped(examples: pa.Table, seq_length: int) -> pa.Table:
"""Pack sequences in a pyarrow Table using a wrapped strategy."""
packed_columns = []
for column in examples.columns:
if pyarrow.types.is_list(column.type) or pyarrow.types.is_large_list(column.type):
if isinstance(column, pa.ChunkedArray):
column = column.combine_chunks()
offsets, values = column.offsets, column.values
values = values[offsets[0].as_py() : offsets[-1].as_py()]
num_elements = len(values)
dtype = offsets.type.to_pandas_dtype() # np.int32 or np.int64
offsets = np.arange(0, num_elements, seq_length, dtype=dtype)
offsets = np.concatenate((offsets, [num_elements]))
column = type(column).from_arrays(offsets, values)
packed_columns.append(column)
return pa.Table.from_arrays(packed_columns, names=examples.column_names)


def pack_dataset(
dataset: DatasetType, seq_length: int, strategy: str = "ffd", map_kwargs: Optional[dict[str, Any]] = None
) -> DatasetType:
r"""
Pack sequences in a dataset into chunks of size `seq_length`.

@@ -481,6 +585,13 @@ def pack_dataset(dataset: DatasetType, seq_length: int, map_kwargs: Optional[dic
Dataset to pack
seq_length (`int`):
Target sequence length to pack to.
strategy (`str`, *optional*, defaults to `"ffd"`):
Packing strategy to use. Can be either:

- `"ffd"` (First Fit Decreasing): Slower but preserves sequence boundaries. Sequences are never cut in the
middle.
- `"wrapped"`: Faster but more aggressive. Ignores sequence boundaries and will cut sequences in the middle
to completely fill each packed sequence with data.
map_kwargs (`dict` or `None`, *optional*, defaults to `None`):
Additional keyword arguments to pass to the dataset's map method when packing examples.

@@ -491,46 +602,29 @@ def pack_dataset(dataset: DatasetType, seq_length: int, map_kwargs: Optional[dic
Example:
```python
>>> from datasets import Dataset
>>> from trl import pack_dataset
>>> examples = {
... "input_ids": [[1, 2], [3, 4], [5, 6], [7]],
... "attention_mask": [[1, 1], [0, 1], [1, 1], [1]],
... "input_ids": [[1, 2, 3], [4, 5], [6, 7, 8], [9]],
...     "attention_mask": [[1, 1, 0], [1, 0], [1, 0, 0], [1]],
... }
>>> dataset = Dataset.from_dict(examples)
>>> packed_dataset = pack_dataset(dataset, seq_length=4)
>>> packed_dataset = pack_dataset(dataset, seq_length=4, strategy="ffd")
>>> packed_dataset[:]
{'input_ids': [[1, 2, 3, 4], [5, 6, 7]],
'attention_mask': [[1, 1, 0, 1], [1, 1, 1]]}
{'input_ids': [[1, 2, 3], [6, 7, 8, 9], [4, 5]],
 'attention_mask': [[1, 1, 0], [1, 0, 0, 1], [1, 0]]}
```
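
For comparison, a hand-traced sketch (not an executed doctest) of the same dataset packed with the `"wrapped"`
strategy, following `_pack_wrapped` above, which concatenates all tokens and splits every `seq_length` elements:

```python
>>> pack_dataset(dataset, seq_length=4, strategy="wrapped")[:]
{'input_ids': [[1, 2, 3, 4], [5, 6, 7, 8], [9]],
 'attention_mask': [[1, 1, 0, 1], [0, 1, 0, 0], [1]]}
```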
"""
if map_kwargs is None:
map_kwargs = {}
if isinstance(dataset, Dataset):
# Fast packing with pyarrow
def pack(examples):
packed_columns = []
for column in examples.columns:
if pyarrow.types.is_list(column.type) or pyarrow.types.is_large_list(column.type):
if isinstance(column, pa.ChunkedArray):
column = column.combine_chunks()
offsets, values = column.offsets, column.values
values = values[offsets[0].as_py() : offsets[-1].as_py()]
num_elements = len(values)
dtype = offsets.type.to_pandas_dtype() # np.int32 or np.int64
offsets = np.arange(0, num_elements, seq_length, dtype=dtype)
offsets = np.concatenate((offsets, [num_elements]))
column = type(column).from_arrays(offsets, values)
packed_columns.append(column)
return pa.Table.from_arrays(packed_columns, names=examples.column_names)

dataset = dataset.with_format("arrow")
dataset = dataset.map(pack, batched=True, **map_kwargs)
dataset = dataset.with_format(None)
# Fast packing with pyarrow
dataset = dataset.with_format("arrow")
if strategy == "ffd":
dataset = dataset.map(_pack_ffd, batched=True, fn_kwargs={"seq_length": seq_length}, **map_kwargs)
elif strategy == "wrapped":
dataset = dataset.map(_pack_wrapped, batched=True, fn_kwargs={"seq_length": seq_length}, **map_kwargs)
else:
dataset = dataset.map(
functools.partial(pack_examples, seq_length=seq_length),
batched=True,
**map_kwargs,
)
raise ValueError(f"Invalid packing strategy: {strategy}. Use 'ffd' or 'wrapped'.")
dataset = dataset.with_format(None)
return dataset

