Skip to content

Conversation

@n0gu-furiosa
Copy link
Contributor

What does this PR do?

This PR optimizes to_py_obj by adding early returns for python-native numeric scalars and 1D lists/tuples of numbers. This avoids unnecessary recursive conversions, which can significantly impact performance of decode().

Fixes #36872

In the provided example from #36872, the runtime decreased from approximately 11 seconds to 0.8 seconds.

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@Rocketknight1 @ArthurZucker

@github-actions github-actions bot marked this pull request as draft March 21, 2025 12:58
@github-actions
Copy link
Contributor

Hi 👋, thank you for opening this pull request! The pull request is converted to draft by default. When it is ready for review, please click the Ready for review button (at the bottom of the PR page).

@n0gu-furiosa n0gu-furiosa marked this pull request as ready for review March 21, 2025 13:03
@Rocketknight1
Copy link
Member

This seems like a good idea, but I worry that there might be some edge cases where it fails because it only tests the first element!

One idea would be to use a conversion like np.array() as a test, without actually returning its output. For example, if:

  • np.array() succeeds
  • The returned array has 1 dimension
  • The returned array has int/float dtype

Then we can guarantee that the input list/tuple was a flat list of numbers and just return the original list. It would guarantee correctness without needing to recursively call to_py_obj(). WDYT?

@n0gu-furiosa
Copy link
Contributor Author

@Rocketknight1 Great idea. Thanks for the suggestion!

I ran some benchmarks comparing different approaches (using the same example code, but with an increased number of iterations). Here are the results:

  • Baseline before this PR: ~29.5s per 3000 iterations
  • Initial optimization (63bfca8): ~2.3s per 3000 iterations
  • Initial optimization without the obj[0] bypass hack (code as follows): ~3.0s per 3000 iterations
    elif isinstance(obj, (list, tuple)):
        return [to_py_obj(o) for o in obj]
    This version was tested to serve as a baseline for approaches that traverse all elements.
  • And finally, the suggested np.array approach (4d4baa7): ~2.7s

As you see, the np.array-based check seems to offer a nice balance between type safety and performance.

Also, instead of checking whether the array has 1 dimension and returning obj (a), I opted to return arr.tolist() regardless of the array's dimension (b). This allows the same optimization to apply to native multi-dimensional python lists as well. I benchmarked both options using the same example code, and the results were not significantly different (~2.47s for (a) and ~2.49s for (b)). Since this test used a 1D array - which favors (a) - I believe (b) is generally a more flexible and equally performant option.

Let me know if you have any feedback or further suggestions.

Copy link
Member

@Rocketknight1 Rocketknight1 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, this seems good now! I made one small suggestion, so let me know what you think and then we can merge this.

"""
Convert a TensorFlow tensor, PyTorch tensor, Numpy array or python list to a python list.
"""
if is_py_number(obj):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if is_py_number(obj):
if isinstance(obj, (int, float)):

Since we're only using the function once, we can just inline it directly here. I think it's understandable!

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you merge this you should also delete is_py_number()

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Applied in a04e338. Also added some test code for to_py_obj here.

@Rocketknight1
Copy link
Member

Also, one more thought: I think this should still work even if obj is a list/tuple containing e.g. Torch/TF arrays. However, we should be careful around that case, since I think np.array() will convert lists of those too.

@Rocketknight1
Copy link
Member

This looks good to me now! Ping me whenever you're ready for me to merge it @n0gu-furiosa

@n0gu-furiosa
Copy link
Contributor Author

@Rocketknight1 Everything’s ready on my end. Please feel free to merge whenever you get a chance. Thanks in advance!

@ArthurZucker
Copy link
Collaborator

thanks 🤗

@ArthurZucker ArthurZucker merged commit d1eafe8 into huggingface:main Mar 27, 2025
18 checks passed
@n0gu-furiosa n0gu-furiosa deleted the n0gu-fix branch March 27, 2025 13:17
zucchini-nlp pushed a commit to zucchini-nlp/transformers that referenced this pull request May 14, 2025
…gingface#36885)

* Optimize to_py_obj for python-native numeric lists and scalars

* Fix bug that tuple is not converted to list

* Try np.array for more robust type checking

* Apply review and add tests for to_py_obj
soghomon-b pushed a commit to soghomon-b/transformers that referenced this pull request Aug 24, 2025
…gingface#36885)

* Optimize to_py_obj for python-native numeric lists and scalars

* Fix bug that tuple is not converted to list

* Try np.array for more robust type checking

* Apply review and add tests for to_py_obj
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Optimize tokenizer.decode() Performance for List[int] Inputs

3 participants