
Support python slicing with tensor inputs. #165074

Closed
laithsakka wants to merge 14 commits into gh/laithsakka/311/base from gh/laithsakka/311/head

Conversation


@laithsakka laithsakka commented Oct 9, 2025

Stack from ghstack (oldest at bottom):

When the slice bound is a tensor, we decompose it into an .item() call and pass the resulting unbacked symbol to the slice, avoiding a data-dependent error (DDE).
The diff also fixes an existing bug in codegen_dynamic_slice_size in the cpp wrapper: a +1 should be -1, making it match the Python codegen.
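For context, a minimal pure-Python sketch of the slice-size computation the cpp-wrapper codegen has to match. This mirrors Python's clamping rules for a step-1 slice; `slice_size` is a hypothetical helper written for illustration, not the actual `codegen_dynamic_slice_size` code:

```python
def slice_size(size: int, start: int, end: int) -> int:
    """Length of x[start:end] for a sequence of the given size (step == 1),
    following Python's clamping rules for negative and out-of-range bounds."""
    if start < 0:
        start += size
    start = min(max(start, 0), size)
    if end < 0:
        end += size
    end = min(max(end, 0), size)
    return max(end - start, 0)

# Sanity-check against Python's own slicing.
xs = list(range(10))
for start, end in [(0, 4), (2, -1), (-3, 100), (8, 3)]:
    assert slice_size(len(xs), start, end) == len(xs[start:end])
```

An off-by-one in this clamping (the +1 vs. -1 mentioned above) produces sizes that disagree with what Python slicing actually returns.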

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov @coconutruben @Lucaskabela

@pytorch-bot

pytorch-bot bot commented Oct 9, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/165074

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 New Failure

As of commit 5c835c5 with merge base e595136:

NEW FAILURE - The following job has failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

laithsakka added a commit that referenced this pull request Oct 9, 2025
@laithsakka laithsakka changed the title from "WIP Support python slicing with data depedennt inptu tensors maybe" to "[WIP] Support python slicing with tensor inputs." Oct 9, 2025
@laithsakka laithsakka marked this pull request as draft October 9, 2025 19:05
Allow things like:
```python
#!/usr/bin/env python
import torch

print("="*60)
print("Testing tensor slicing with torch.compile")
print("="*60)

# Test 1: Simple eager mode
print("\n1. Eager mode test:")
x = torch.randn(10)
idx = torch.tensor(4)
result = x[:idx]
print(f"   x[:idx] where idx=4: result.shape = {result.shape}")
assert result.shape[0] == 4
print("   ✓ Eager mode works!")

# Test 2: With torch.compile
print("\n2. Compiled mode test:")
def slice_fn(x, idx):
    return x[:idx]

try:
    compiled_fn = torch.compile(slice_fn)
    x = torch.randn(10)
    idx = torch.tensor(4)
    result = compiled_fn(x, idx)
    print(f"   Compiled x[:idx] where idx=4: result.shape = {result.shape}")
    assert result.shape[0] == 4
    print("   ✓ Compiled mode works!")
except Exception as e:
    print(f"   ✗ Compiled mode failed: {e}")
    import traceback
    traceback.print_exc()

# Test 3: With dynamic slicing from sum
print("\n3. Dynamic slicing with sum:")
def dynamic_slice_fn(x, lengths):
    idx = lengths.sum()
    return x[:idx]

try:
    compiled_fn = torch.compile(dynamic_slice_fn)
    x = torch.randn(10)
    lengths = torch.tensor([1, 1, 1, 1])
    result = compiled_fn(x, lengths)
    print(f"   Compiled x[:lengths.sum()] where sum=4: result.shape = {result.shape}")
    assert result.shape[0] == 4
    print("   ✓ Dynamic slicing works!")
except Exception as e:
    print(f"   ✗ Dynamic slicing failed: {e}")
    import traceback
    traceback.print_exc()

print("\n" + "="*60)
print("SUMMARY: Check results above")
print("="*60)

```
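The decomposition described above can be sketched in plain Python; `ScalarBox` below is a hypothetical stand-in for a 0-dim tensor (in the real compiler, `.item()` traces into an unbacked SymInt rather than a concrete int, which is what avoids the DDE):

```python
class ScalarBox:
    """Hypothetical stand-in for a 0-dim tensor holding one integer."""
    def __init__(self, value: int):
        self.value = value

    def item(self) -> int:
        # In the compiler this materializes as an unbacked symbol,
        # not a concrete int; here it is just the plain value.
        return self.value

def sliced(x, idx):
    # x[:idx] with a tensor bound is rewritten to: u = idx.item(); x[:u]
    u = idx.item()
    return x[:u]

print(sliced(list(range(10)), ScalarBox(4)))  # [0, 1, 2, 3]
```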


[ghstack-poisoned]
laithsakka added a commit that referenced this pull request Oct 15, 2025
laithsakka added a commit that referenced this pull request Oct 17, 2025
laithsakka added a commit that referenced this pull request Oct 17, 2025
laithsakka added a commit that referenced this pull request Oct 17, 2025
laithsakka added a commit that referenced this pull request Oct 17, 2025
@laithsakka
Contributor Author

Great, I just need to add unit tests!


@Lucaskabela Lucaskabela left a comment


Generally this looks good to me, but it needs some cleanup of the asserts, and I would like to see a test with dynamic=True to make sure the checks before the .item() calls have reasonable behavior.


@Lucaskabela Lucaskabela left a comment


See comments; a few minor changes are suggested, but this looks good otherwise.

@laithsakka
Contributor Author

laithsakka commented Oct 25, 2025

Generally this looks good to me, but it needs some cleanup of the asserts, and I would like to see a test with dynamic=True to make sure the checks before the .item() calls have reasonable behavior.

Ah, I see. I will just remove those checks as you suggested; size will probably be None when the input tensor is dynamic.

laithsakka added a commit that referenced this pull request Oct 27, 2025
@Lucaskabela
Contributor

Test failures seem legitimate - feel free to re-ping once fixed


@Lucaskabela Lucaskabela left a comment


Thanks for addressing the feedback and fixing the test failures; the DTensor failure seems unrelated to me, so approving.

@laithsakka
Contributor Author

@pytorchbot merge -i

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Oct 28, 2025
@pytorchmergebot
Collaborator

Merge started

Your change will be merged while ignoring the following 1 checks: pull / linux-jammy-py3.10-gcc11 / test (distributed, 2, 2, linux.2xlarge)

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

@pytorchmergebot
Collaborator

Merge failed

Reason: 1 jobs have failed, first few of them are: trunk / linux-jammy-cuda12.8-py3.10-gcc11 / test (default, 2, 5, lf.linux.g6.4xlarge.experimental.nvidia.gpu)


@laithsakka
Contributor Author

@pytorchbot merge -i

@pytorchmergebot
Collaborator

Merge started

Your change will be merged while ignoring the following 1 checks: pull / linux-jammy-py3.10-gcc11 / test (distributed, 2, 2, linux.2xlarge)

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.


Labels

ciflow/inductor, ciflow/trunk, keep-going, Merged, module: dynamo, module: inductor, release notes: inductor (aoti), topic: not user facing
