Skip to content

Conversation

@zou3519
Copy link
Contributor

@zou3519 zou3519 commented Mar 23, 2018

Fixes #5739. The CUDA path for torch.cat was missing a check for the
case where all input tensors are empty.

cc @ssnl could you take a look at this one?

Fixes pytorch#5739. The CUDA path for `torch.cat` was missing a check for the
case where all input tensors are empty.
dtype = torch.float32

x = torch.randn((4, 3, 32, 32), dtype=dtype)
empty = torch.randn((0,), dtype=dtype)

This comment was marked as off-topic.

This comment was marked as off-topic.


// If all inputs are empty tensors, return an empty tensor
if (notEmptyTensor == NULL) {
return;

This comment was marked as off-topic.

This comment was marked as off-topic.

This comment was marked as off-topic.

Copy link
Collaborator

@ssnl ssnl left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM. Thanks for the explanations @zou3519 !

@soumith soumith merged commit 9923701 into pytorch:master Mar 24, 2018
sighingnow added a commit to sighingnow/pytorch that referenced this pull request Mar 25, 2018
* upstream/master: (663 commits)
  Fix "command not found" error in perf test (pytorch#5982)
  add pip mkl-devel to the error message when mkl is found but mkl headers are not (pytorch#5984)
  Support batch LowerCholeskyTransform (pytorch#5980)
  Linearly interpolating upsampling fix (pytorch#5927)
  Store perf numbers in S3 (pytorch#5951)
  Modidy setup docs for Windows (pytorch#5981)
  Group Normalization (pytorch#5968)
  [distributions] Implement Power transform (pytorch#5976)
  Disable TestBottleneck test_cuda on Windows (pytorch#5977)
  Fix crash when cat-ing empty cuda tensors (pytorch#5971)
  Update no_unions flag for nanopb gen and update ONNX proto files (pytorch#5972)
  Expose gradients w.r.t. input & weight for conv1d, conv2d, conv3d in Python (pytorch#5408)
  Fixed non-determinate preprocessing on DataLoader (pytorch#4640)
  add AVX2 implementation for sigmoid function (pytorch#5010)
  Implement torch.util.bottleneck (pytorch#5216)
  Remove pragma once from cpp file (pytorch#5965)
  fix mvn docs (pytorch#5967)
  Fix incorrect rendering of Tensor.index_*_ doc examples. (pytorch#5969)
  Implement range for loop in script (pytorch#5827)
  Add windows doc (pytorch#5859)
  ...

# Conflicts:
#	aten/src/TH/generic/THTensorMath.c
#	torch/_tensor_docs.py
#	torch/csrc/generic/methods/TensorCompare.cwrap
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants