Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .jenkins/caffe2/test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ if [[ "$BUILD_ENVIRONMENT" == *onnx* ]]; then
# default pip version is too old(9.0.2), unable to support tag `manylinux2010`.
# Fix the pip error: Couldn't find a version that satisfies the requirement
sudo pip install --upgrade pip
pip install -q --user -i https://test.pypi.org/simple/ ort-nightly==1.3.1.dev202007102
pip install -q --user -i https://test.pypi.org/simple/ ort-nightly==1.4.0.dev202007311
fi
"$ROOT_DIR/scripts/onnx/test.sh"
fi
3 changes: 2 additions & 1 deletion test/onnx/test_pytorch_onnx_onnxruntime.py
Original file line number Diff line number Diff line change
Expand Up @@ -2699,7 +2699,8 @@ def forward(self, x):

class ComparisonModel(torch.nn.Module):
def forward(self, x, y):
return x.ge(0.5) & y.le(2)
a = torch.tensor([12.0])
return x.lt(1.5) & y.le(2) & x.le(1), x.gt(y), x.lt(y), a.ge(x.size(0))

x = torch.ones(2, 3, dtype=torch.int32)
y = torch.ones(2, 3, dtype=torch.float32)
Expand Down
55 changes: 33 additions & 22 deletions torch/csrc/jit/passes/onnx/scalar_type_analysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -121,30 +121,41 @@ static c10::optional<c10::ScalarType> InferExpectedScalarType(const Node* n) {
const c10::optional<c10::ScalarType> output_st =
n->output()->type()->cast<TensorType>()->scalarType();

if (typesFromScalars.size() == n->inputs().size()) {
// If all inputs are scalars, infer scalar_type by calling
// c10::promoteTypes.
if (IsComparisonOp(n->kind())) {
// For comparison ops, always promote scalar type to highest among inputs,
// regardless if that input is a tensor or scalar.
typesFromScalars.insert(
typesFromScalars.end(),
typesFromTensors.begin(),
typesFromTensors.end());
st = PromoteScalarTypes(typesFromScalars);
} else if (output_st && !IsComparisonOp(n->kind())) {
// If output scalar type is available, use that.
st = output_st;
} else if (!typesFromTensors.empty()) {
// When inputs consist of tensors and scalars. In PyTorch, scalars are
// implicitly casted to have the same scalar type as input tensors.
st = typesFromTensors[0];
if (std::any_of(
typesFromTensors.begin(),
typesFromTensors.end(),
[&st](const c10::ScalarType& type) { return type != st; })) {
std::cerr
<< "Warning: ONNX Scalar Type Analysis - Scalar types mismatch for tensor inputs of operator "
<< n->kind().toDisplayString() << ". Please report a bug to PyTorch. "
<< "The scalar type " << c10::toString(*st)
<< " of the first tensor is chosen." << std::endl;
}
} else {
// When inputs consist of only scalars.
st = PromoteScalarTypes(typesFromScalars);
if (typesFromScalars.size() == n->inputs().size()) {
// If all inputs are scalars, infer scalar_type by calling
// c10::promoteTypes.
st = PromoteScalarTypes(typesFromScalars);
} else if (output_st) {
// If output scalar type is available, use that.
st = output_st;
} else if (!typesFromTensors.empty()) {
// When inputs consist of tensors and scalars. In PyTorch, scalars are
// implicitly casted to have the same scalar type as input tensors.
st = typesFromTensors[0];
if (std::any_of(
typesFromTensors.begin(),
typesFromTensors.end(),
[&st](const c10::ScalarType& type) { return type != st; })) {
std::cerr
<< "Warning: ONNX Scalar Type Analysis - Scalar types mismatch for tensor inputs of operator "
<< n->kind().toDisplayString()
<< ". Please report a bug to PyTorch. "
<< "The scalar type " << c10::toString(*st)
<< " of the first tensor is chosen." << std::endl;
}
} else {
// When inputs consist of only scalars.
st = PromoteScalarTypes(typesFromScalars);
}
}

return st;
Expand Down