Merged
7 changes: 7 additions & 0 deletions docarray/typing/tensor/torch_tensor.py
@@ -271,3 +271,10 @@ def _docarray_from_ndarray(cls: Type[T], value: np.ndarray) -> T:
    def _docarray_to_ndarray(self) -> np.ndarray:
        """cast itself to a numpy array"""
        return self.detach().cpu().numpy()

    def new_empty(self, *args, **kwargs):

Member

Can we copy the full signature of the original method?

Member Author

If it is always the same, I do not think it makes sense; this is way more maintainable.

Member

But this breaks everything related to mypy and PyCharm features. In DocArray v2 we always repeat the full signature of the function.

Member Author

But the mypy check passes.

Member Author

The method on TorchEmbedding does the same.
This is a method that no one should use.
It is easier to forget to update this method than anything else.

Member

okay okay

Member

mypy passes because it does not look at it. It will only check the method if there are type hints.
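
For illustration only, the trade-off in this thread can be sketched with the standard library. The classes `Base`, `Forwarding`, and `Explicit` below are hypothetical stand-ins, not DocArray or Torch code, and the parameter names on `Base.new_empty` are assumptions made up for the example. The sketch shows what introspection-based tooling can see for a `*args, **kwargs` override compared with a repeated signature.

```python
import inspect


class Base:
    def new_empty(self, size, *, dtype=None, device=None):
        """Stand-in for a library method with a rich signature (hypothetical)."""
        return (size, dtype, device)


class Forwarding(Base):
    # Same signature shape as the override in this PR: everything is passed through.
    def new_empty(self, *args, **kwargs):
        return super().new_empty(*args, **kwargs)


class Explicit(Base):
    # The alternative requested in review: the signature is repeated by hand.
    def new_empty(self, size, *, dtype=None, device=None):
        return super().new_empty(size, dtype=dtype, device=device)


forwarding_params = list(inspect.signature(Forwarding.new_empty).parameters)
explicit_params = list(inspect.signature(Explicit.new_empty).parameters)

print(forwarding_params)  # ['self', 'args', 'kwargs']
print(explicit_params)    # ['self', 'size', 'dtype', 'device']
```

Both overrides behave the same at runtime. The forwarding version never drifts out of date, while the explicit version exposes parameter names to tooling but has to be updated whenever the parent signature changes.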

        """
        This method enables the deepcopy of `TorchTensor` by returning another instance of this subclass.
        If this function is not implemented, the deepcopy will throw a RuntimeError from Torch.
        """
        return self.__class__(*args, **kwargs)
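
A minimal pure-Python sketch of the pattern the docstring describes, under a stated assumption: the base class's `__deepcopy__` allocates the copy through `new_empty` and rejects a result of a different type. `ToyTensor`, `BrokenSubclass`, and `FixedSubclass` are hypothetical names, and this is a toy model, not Torch's actual implementation.

```python
import copy


class ToyTensor:
    """Hypothetical stand-in for a tensor base class; not Torch's real code."""

    def __init__(self, values=()):
        self.values = list(values)

    def new_empty(self, *args, **kwargs):
        # The base implementation always builds the base type.
        return ToyTensor(*args, **kwargs)

    def __deepcopy__(self, memo):
        new = self.new_empty()
        # Assumed check: the copy must have the same type as the original.
        if type(new) is not type(self):
            raise RuntimeError("new_empty() returned a different type")
        new.values = copy.deepcopy(self.values, memo)
        return new


class BrokenSubclass(ToyTensor):
    pass  # inherits new_empty, which returns a plain ToyTensor


class FixedSubclass(ToyTensor):
    def new_empty(self, *args, **kwargs):
        # Same idea as the method added in this PR.
        return self.__class__(*args, **kwargs)


fixed_copy = copy.deepcopy(FixedSubclass([1, 2, 3]))

try:
    copy.deepcopy(BrokenSubclass([1, 2, 3]))
    broken_error = None
except RuntimeError as exc:
    broken_error = exc
```

In this toy model, `FixedSubclass` deep-copies to another `FixedSubclass`, while `BrokenSubclass` raises a `RuntimeError`, which mirrors the failure mode the docstring mentions.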
16 changes: 16 additions & 0 deletions tests/units/typing/tensor/test_torch_tensor.py
@@ -187,6 +187,22 @@ class MMdoc(BaseDoc):
    assert not (doc.embedding == doc_copy.embedding).all()


def test_deepcopy_tensor():
    from docarray import BaseDoc

    class MMdoc(BaseDoc):
        embedding: TorchTensor

    doc = MMdoc(embedding=torch.randn(32))
    doc_copy = doc.copy(deep=True)

    assert doc.embedding.data_ptr() != doc_copy.embedding.data_ptr()
    assert (doc.embedding == doc_copy.embedding).all()

    doc_copy.embedding = torch.randn(32)
    assert not (doc.embedding == doc_copy.embedding).all()


@pytest.mark.parametrize('requires_grad', [True, False])
def test_json_serialization(requires_grad):
    orig_doc = MyDoc(tens=torch.rand(10, requires_grad=requires_grad))