Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions docarray/typing/tensor/torch_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,16 @@ def __torch_function__(cls, func, types, args=(), kwargs=None):
)
return super().__torch_function__(func, types_, args, kwargs)

def __deepcopy__(self, memo):
    """
    Custom implementation of deepcopy for TorchTensor to avoid storage sharing issues.

    The copy owns freshly allocated memory, keeps the class of the original
    and its `requires_grad` flag, and carries no autograd history.

    :param memo: dictionary used by `copy.deepcopy` to map the `id` of every
        object already copied during the current pass to its copy
    :return: an independent tensor of the same class holding the same values
    """
    # Follow the deepcopy protocol even when this method is invoked directly:
    # hand back the copy that was already produced for this object, if any.
    memo_key = id(self)
    if memo_key in memo:
        return memo[memo_key]

    # `clone()` alone is differentiable, so the result would stay attached to
    # the autograd graph of the original. Detach first so that the copy is a
    # standalone tensor backed by its own memory.
    new_tensor = self.detach().clone()
    # Set the class to the custom TorchTensor class
    new_tensor.__class__ = self.__class__
    # `detach()` drops the flag, restore it so the copy behaves like the original
    new_tensor.requires_grad_(self.requires_grad)

    memo[memo_key] = new_tensor
    return new_tensor

@classmethod
def _docarray_from_ndarray(cls: Type[T], value: np.ndarray) -> T:
"""Create a `tensor from a numpy array
Expand Down
18 changes: 18 additions & 0 deletions tests/integrations/typing/test_torch_tensor.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import torch
from docarray.typing.tensor.torch_tensor import TorchTensor
import copy

from docarray import BaseDoc
from docarray.typing import TorchEmbedding, TorchTensor
Expand All @@ -25,3 +27,19 @@ class MyDocument(BaseDoc):
assert isinstance(d.embedding, TorchEmbedding)
assert isinstance(d.embedding, torch.Tensor)
assert (d.embedding == torch.zeros((128,))).all()


def test_torchtensor_deepcopy():
    """Deep-copying a TorchTensor must give an equal tensor with its own storage."""
    # Setup
    original_tensor_float = TorchTensor(torch.rand(10))
    original_tensor_int = TorchTensor(torch.randint(0, 100, (10,)))

    # Exercise
    copied_tensor_float = copy.deepcopy(original_tensor_float)
    copied_tensor_int = copy.deepcopy(original_tensor_int)

    # Verify
    pairs = (
        (original_tensor_float, copied_tensor_float),
        (original_tensor_int, copied_tensor_int),
    )
    for original, copied in pairs:
        assert torch.equal(original, copied)
        assert original is not copied
        # The copy must keep the custom tensor class
        assert isinstance(copied, TorchTensor)
        # Two distinct Python objects can still point at the same memory
        # (every view is its own object), so compare the data pointers too.
        assert original.data_ptr() != copied.data_ptr()
Comment on lines +43 to +45

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this test is not enough. You should rather do

assert original_tensor_float.data_ptr() != copied_tensor_float.data_ptr()

Indeed, in PyTorch each view of a tensor has a different Python id, but what matters is the underlying storage.

Good PR otherwise !!

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay let me fix that..!

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this test is not enough. You should rather do

assert original_tensor_float.data_ptr() != copied_tensor_float.data_ptr()

Indeed, in PyTorch each view of a tensor has a different Python id, but what matters is the underlying storage.

Good PR otherwise !!

Hi @samsja
I accidentally deleted the repo from my local machine, so I had to clone it again. Because of that I am unable to push changes to this merged pull request, so to refactor the test I would need to create another PR.
Should I create another PR for that?

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes you would have to create another PR for this