Support python slicing with tensor inputs.#165074
Support python slicing with tensor inputs. #165074 — laithsakka wants to merge 14 commits into gh/laithsakka/311/base from
Conversation
[ghstack-poisoned]
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/165074
Note: Links to docs will display an error until the docs builds have been completed. ❌ 1 New Failure. As of commit 5c835c5 with merge base e595136 (NEW FAILURE — the following job has failed:
This comment was automatically generated by Dr. CI and updates every 15 minutes. |
allow things like
```
#!/usr/bin/env python
import torch
print("="*60)
print("Testing tensor slicing with torch.compile")
print("="*60)
# Test 1: Simple eager mode
print("\n1. Eager mode test:")
x = torch.randn(10)
idx = torch.tensor(4)
result = x[:idx]
print(f" x[:idx] where idx=4: result.shape = {result.shape}")
assert result.shape[0] == 4
print(" ✓ Eager mode works!")
# Test 2: With torch.compile
print("\n2. Compiled mode test:")
def slice_fn(x, idx):
return x[:idx]
try:
compiled_fn = torch.compile(slice_fn)
x = torch.randn(10)
idx = torch.tensor(4)
result = compiled_fn(x, idx)
print(f" Compiled x[:idx] where idx=4: result.shape = {result.shape}")
assert result.shape[0] == 4
print(" ✓ Compiled mode works!")
except Exception as e:
print(f" ✗ Compiled mode failed: {e}")
import traceback
traceback.print_exc()
# Test 3: With dynamic slicing from sum
print("\n3. Dynamic slicing with sum:")
def dynamic_slice_fn(x, lengths):
idx = lengths.sum()
return x[:idx]
try:
compiled_fn = torch.compile(dynamic_slice_fn)
x = torch.randn(10)
lengths = torch.tensor([1, 1, 1, 1])
result = compiled_fn(x, lengths)
print(f" Compiled x[:lengths.sum()] where sum=4: result.shape = {result.shape}")
assert result.shape[0] == 4
print(" ✓ Dynamic slicing works!")
except Exception as e:
print(f" ✗ Dynamic slicing failed: {e}")
import traceback
traceback.print_exc()
print("\n" + "="*60)
print("SUMMARY: Check results above")
print("="*60)
```
[ghstack-poisoned]
allow things like
```
#!/usr/bin/env python
import torch
print("="*60)
print("Testing tensor slicing with torch.compile")
print("="*60)
# Test 1: Simple eager mode
print("\n1. Eager mode test:")
x = torch.randn(10)
idx = torch.tensor(4)
result = x[:idx]
print(f" x[:idx] where idx=4: result.shape = {result.shape}")
assert result.shape[0] == 4
print(" ✓ Eager mode works!")
# Test 2: With torch.compile
print("\n2. Compiled mode test:")
def slice_fn(x, idx):
return x[:idx]
try:
compiled_fn = torch.compile(slice_fn)
x = torch.randn(10)
idx = torch.tensor(4)
result = compiled_fn(x, idx)
print(f" Compiled x[:idx] where idx=4: result.shape = {result.shape}")
assert result.shape[0] == 4
print(" ✓ Compiled mode works!")
except Exception as e:
print(f" ✗ Compiled mode failed: {e}")
import traceback
traceback.print_exc()
# Test 3: With dynamic slicing from sum
print("\n3. Dynamic slicing with sum:")
def dynamic_slice_fn(x, lengths):
idx = lengths.sum()
return x[:idx]
try:
compiled_fn = torch.compile(dynamic_slice_fn)
x = torch.randn(10)
lengths = torch.tensor([1, 1, 1, 1])
result = compiled_fn(x, lengths)
print(f" Compiled x[:lengths.sum()] where sum=4: result.shape = {result.shape}")
assert result.shape[0] == 4
print(" ✓ Dynamic slicing works!")
except Exception as e:
print(f" ✗ Dynamic slicing failed: {e}")
import traceback
traceback.print_exc()
print("\n" + "="*60)
print("SUMMARY: Check results above")
print("="*60)
```
[ghstack-poisoned]
allow things like
```
#!/usr/bin/env python
import torch
print("="*60)
print("Testing tensor slicing with torch.compile")
print("="*60)
# Test 1: Simple eager mode
print("\n1. Eager mode test:")
x = torch.randn(10)
idx = torch.tensor(4)
result = x[:idx]
print(f" x[:idx] where idx=4: result.shape = {result.shape}")
assert result.shape[0] == 4
print(" ✓ Eager mode works!")
# Test 2: With torch.compile
print("\n2. Compiled mode test:")
def slice_fn(x, idx):
return x[:idx]
try:
compiled_fn = torch.compile(slice_fn)
x = torch.randn(10)
idx = torch.tensor(4)
result = compiled_fn(x, idx)
print(f" Compiled x[:idx] where idx=4: result.shape = {result.shape}")
assert result.shape[0] == 4
print(" ✓ Compiled mode works!")
except Exception as e:
print(f" ✗ Compiled mode failed: {e}")
import traceback
traceback.print_exc()
# Test 3: With dynamic slicing from sum
print("\n3. Dynamic slicing with sum:")
def dynamic_slice_fn(x, lengths):
idx = lengths.sum()
return x[:idx]
try:
compiled_fn = torch.compile(dynamic_slice_fn)
x = torch.randn(10)
lengths = torch.tensor([1, 1, 1, 1])
result = compiled_fn(x, lengths)
print(f" Compiled x[:lengths.sum()] where sum=4: result.shape = {result.shape}")
assert result.shape[0] == 4
print(" ✓ Dynamic slicing works!")
except Exception as e:
print(f" ✗ Dynamic slicing failed: {e}")
import traceback
traceback.print_exc()
print("\n" + "="*60)
print("SUMMARY: Check results above")
print("="*60)
```
[ghstack-poisoned]
allow things like
```
#!/usr/bin/env python
import torch
print("="*60)
print("Testing tensor slicing with torch.compile")
print("="*60)
# Test 1: Simple eager mode
print("\n1. Eager mode test:")
x = torch.randn(10)
idx = torch.tensor(4)
result = x[:idx]
print(f" x[:idx] where idx=4: result.shape = {result.shape}")
assert result.shape[0] == 4
print(" ✓ Eager mode works!")
# Test 2: With torch.compile
print("\n2. Compiled mode test:")
def slice_fn(x, idx):
return x[:idx]
try:
compiled_fn = torch.compile(slice_fn)
x = torch.randn(10)
idx = torch.tensor(4)
result = compiled_fn(x, idx)
print(f" Compiled x[:idx] where idx=4: result.shape = {result.shape}")
assert result.shape[0] == 4
print(" ✓ Compiled mode works!")
except Exception as e:
print(f" ✗ Compiled mode failed: {e}")
import traceback
traceback.print_exc()
# Test 3: With dynamic slicing from sum
print("\n3. Dynamic slicing with sum:")
def dynamic_slice_fn(x, lengths):
idx = lengths.sum()
return x[:idx]
try:
compiled_fn = torch.compile(dynamic_slice_fn)
x = torch.randn(10)
lengths = torch.tensor([1, 1, 1, 1])
result = compiled_fn(x, lengths)
print(f" Compiled x[:lengths.sum()] where sum=4: result.shape = {result.shape}")
assert result.shape[0] == 4
print(" ✓ Dynamic slicing works!")
except Exception as e:
print(f" ✗ Dynamic slicing failed: {e}")
import traceback
traceback.print_exc()
print("\n" + "="*60)
print("SUMMARY: Check results above")
print("="*60)
```
[ghstack-poisoned]
allow things like
```
#!/usr/bin/env python
import torch
print("="*60)
print("Testing tensor slicing with torch.compile")
print("="*60)
# Test 1: Simple eager mode
print("\n1. Eager mode test:")
x = torch.randn(10)
idx = torch.tensor(4)
result = x[:idx]
print(f" x[:idx] where idx=4: result.shape = {result.shape}")
assert result.shape[0] == 4
print(" ✓ Eager mode works!")
# Test 2: With torch.compile
print("\n2. Compiled mode test:")
def slice_fn(x, idx):
return x[:idx]
try:
compiled_fn = torch.compile(slice_fn)
x = torch.randn(10)
idx = torch.tensor(4)
result = compiled_fn(x, idx)
print(f" Compiled x[:idx] where idx=4: result.shape = {result.shape}")
assert result.shape[0] == 4
print(" ✓ Compiled mode works!")
except Exception as e:
print(f" ✗ Compiled mode failed: {e}")
import traceback
traceback.print_exc()
# Test 3: With dynamic slicing from sum
print("\n3. Dynamic slicing with sum:")
def dynamic_slice_fn(x, lengths):
idx = lengths.sum()
return x[:idx]
try:
compiled_fn = torch.compile(dynamic_slice_fn)
x = torch.randn(10)
lengths = torch.tensor([1, 1, 1, 1])
result = compiled_fn(x, lengths)
print(f" Compiled x[:lengths.sum()] where sum=4: result.shape = {result.shape}")
assert result.shape[0] == 4
print(" ✓ Dynamic slicing works!")
except Exception as e:
print(f" ✗ Dynamic slicing failed: {e}")
import traceback
traceback.print_exc()
print("\n" + "="*60)
print("SUMMARY: Check results above")
print("="*60)
```
[ghstack-poisoned]
|
Great — I just need to add unit tests! |
allow things like
```
#!/usr/bin/env python
import torch
print("="*60)
print("Testing tensor slicing with torch.compile")
print("="*60)
# Test 1: Simple eager mode
print("\n1. Eager mode test:")
x = torch.randn(10)
idx = torch.tensor(4)
result = x[:idx]
print(f" x[:idx] where idx=4: result.shape = {result.shape}")
assert result.shape[0] == 4
print(" ✓ Eager mode works!")
# Test 2: With torch.compile
print("\n2. Compiled mode test:")
def slice_fn(x, idx):
return x[:idx]
try:
compiled_fn = torch.compile(slice_fn)
x = torch.randn(10)
idx = torch.tensor(4)
result = compiled_fn(x, idx)
print(f" Compiled x[:idx] where idx=4: result.shape = {result.shape}")
assert result.shape[0] == 4
print(" ✓ Compiled mode works!")
except Exception as e:
print(f" ✗ Compiled mode failed: {e}")
import traceback
traceback.print_exc()
# Test 3: With dynamic slicing from sum
print("\n3. Dynamic slicing with sum:")
def dynamic_slice_fn(x, lengths):
idx = lengths.sum()
return x[:idx]
try:
compiled_fn = torch.compile(dynamic_slice_fn)
x = torch.randn(10)
lengths = torch.tensor([1, 1, 1, 1])
result = compiled_fn(x, lengths)
print(f" Compiled x[:lengths.sum()] where sum=4: result.shape = {result.shape}")
assert result.shape[0] == 4
print(" ✓ Dynamic slicing works!")
except Exception as e:
print(f" ✗ Dynamic slicing failed: {e}")
import traceback
traceback.print_exc()
print("\n" + "="*60)
print("SUMMARY: Check results above")
print("="*60)
```
[ghstack-poisoned]
Lucaskabela
left a comment
There was a problem hiding this comment.
Generally this looks good, to me but needs some cleanup of asserts and I would like to see a test with dynamic=True to make sure the checks before .item() calls have reasonable behavior
Lucaskabela
left a comment
There was a problem hiding this comment.
See comments - few minor changes suggested but this looks good otherwise
Ah, I see — I will just remove those checks as you suggested; `size` will probably be `None` when the input tensor is dynamic. |
When the slice is a tensor, we decompose it into a .item() call and pass the resulting unbacked symbol to the slice to avoid a data-dependent error (DDE). cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames Lucaskabela [ghstack-poisoned]
|
Test failures seem legitimate - feel free to re-ping once fixed |
When the slice is a tensor, we decompose it into a .item() call and pass the resulting unbacked symbol to the slice to avoid a data-dependent error (DDE). cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames Lucaskabela [ghstack-poisoned]
Lucaskabela
left a comment
There was a problem hiding this comment.
Thanks for addressing feedback and fixing the test failures - the dtensor failure seems unrelated to me, so approving
|
@pytorchbot merge -i |
Merge startedYour change will be merged while ignoring the following 1 checks: pull / linux-jammy-py3.10-gcc11 / test (distributed, 2, 2, linux.2xlarge) Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team |
Merge failedReason: 1 jobs have failed, first few of them are: trunk / linux-jammy-cuda12.8-py3.10-gcc11 / test (default, 2, 5, lf.linux.g6.4xlarge.experimental.nvidia.gpu) Details for Dev Infra teamRaised by workflow job |
|
@pytorchbot merge -i |
Merge startedYour change will be merged while ignoring the following 1 checks: pull / linux-jammy-py3.10-gcc11 / test (distributed, 2, 2, linux.2xlarge) Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team |
Stack from ghstack (oldest at bottom):
When the slice is a tensor, we decompose it into a .item() call and pass the resulting unbacked symbol to the slice to avoid a data-dependent error (DDE).
The diff also fixes an existing bug in codegen_dynamic_slice_size in the cpp wrapper: a +1 should be a -1, making it match the
Python codegen.
cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov @coconutruben @Lucaskabela