Skip to content

Commit 6f80023

Browse files
lantiga authored and ezyang committed
Port ATen and JIT C++ tests to Catch2 (#5788)
This PR addresses #5648. In particular, following the discussion at #5648: - it adds Catch as a submodule (https://github.com/catchorg/Catch2) in torch/aten/utils - it ports all ATen tests to Catch - it ports torch/csrc/jit/test_jit.cpp to Catch (libtorch only, Python build is unaffected)
1 parent cf2e176 commit 6f80023

File tree

20 files changed

+758
-636
lines changed

20 files changed

+758
-636
lines changed

.gitmodules

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,3 +14,6 @@
1414
path = aten/src/ATen/cpu/tbb/tbb_remote
1515
url = https://github.com/01org/tbb
1616
branch = tbb_2018
17+
[submodule "aten/src/ATen/utils/catch"]
18+
path = aten/src/ATen/utils/catch
19+
url = https://github.com/catchorg/Catch2.git

.jenkins/build.sh

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ fi
7474

7575
if [[ "$BUILD_ENVIRONMENT" != *cuda* ]]; then
7676
echo "Testing ATen"
77-
time tools/run_aten_tests.sh
77+
( unset LD_PRELOAD; time tools/run_aten_tests.sh )
7878
fi
7979

8080
# Test C FFI plugins
@@ -99,7 +99,8 @@ fi
9999
if [[ "$BUILD_ENVIRONMENT" == *pytorch-linux-xenial-cuda9-cudnn7-py3 ]] || \
100100
[[ "$BUILD_ENVIRONMENT" == *pytorch-linux-trusty-py3.6-gcc7.2 ]]; then
101101
echo "Building libtorch with NO_PYTHON"
102-
pushd tools/cpp_build || exit 1
103-
bash build_all.sh
102+
LIBTORCH_INSTALL_PREFIX=`pwd`/../libtorch
103+
pushd tools/cpp_build
104+
bash build_all.sh "$LIBTORCH_INSTALL_PREFIX" || exit 1
104105
popd
105106
fi

.jenkins/test.sh

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,3 +65,17 @@ rm -rf ninja
6565
pushd vision
6666
time python setup.py install
6767
popd
68+
69+
if [[ "$BUILD_ENVIRONMENT" == *pytorch-linux-xenial-cuda9-cudnn7-py3 ]] || \
70+
[[ "$BUILD_ENVIRONMENT" == *pytorch-linux-trusty-py3.6-gcc7.2 ]]; then
71+
echo "Testing libtorch with NO_PYTHON"
72+
LIBTORCH_INSTALL_PREFIX=`pwd`/../libtorch
73+
pushd tools/cpp_build
74+
75+
if [[ "$BUILD_ENVIRONMENT" == *cuda* ]]; then
76+
"$LIBTORCH_INSTALL_PREFIX"/bin/test_jit
77+
else
78+
"$LIBTORCH_INSTALL_PREFIX"/bin/test_jit "[cpu]"
79+
fi
80+
popd
81+
fi

aten/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -476,6 +476,7 @@ ${CMAKE_CURRENT_SOURCE_DIR}/src/THCUNN)
476476
add_subdirectory(src/ATen)
477477
include_directories(
478478
${CMAKE_CURRENT_SOURCE_DIR}/src
479+
${CMAKE_CURRENT_SOURCE_DIR}/src/ATen/utils/catch/single_include
479480
${CMAKE_CURRENT_BINARY_DIR}/src/ATen)
480481
if(NOT NO_CUDA)
481482
include_directories(${CUDA_INCLUDE_DIRS})

aten/src/ATen/test/atest.cpp

Lines changed: 30 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,13 @@
1+
#define CATCH_CONFIG_MAIN
2+
#include "catch.hpp"
3+
14
#include "ATen/ATen.h"
2-
#include "test_assert.h"
35
#include "test_seed.h"
46

57
#include<iostream>
68
using namespace std;
79
using namespace at;
810

9-
void check(bool c) {
10-
if(!c)
11-
throw std::runtime_error("check failed.");
12-
}
13-
1411
void trace() {
1512
Tensor foo = rand(CPU(kFloat), {12,12});
1613

@@ -21,64 +18,72 @@ void trace() {
2118
for(int i = 0; i < foo_a.size(0); i++) {
2219
trace += foo_a[i][i];
2320
}
24-
cout << trace << "\n" << foo << "\n";
21+
22+
REQUIRE(Scalar(foo.trace()).toFloat() == Approx(trace));
2523
}
26-
int main() {
24+
25+
TEST_CASE( "atest", "[]" ) {
26+
2727
manual_seed(123);
2828

2929
auto foo = rand(CPU(kFloat), {12,6});
30-
ASSERT(foo.data<float>() == foo.toFloatData());
30+
REQUIRE(foo.data<float>() == foo.toFloatData());
3131

32-
cout << foo << "\n" << foo.size(0) << " " << foo.size(1) << endl;
32+
REQUIRE(foo.size(0) == 12);
33+
REQUIRE(foo.size(1) == 6);
3334

3435
foo = foo+foo*3;
3536
foo -= 4;
3637

3738
{
3839
Tensor no;
39-
ASSERT_THROWS(add_out(no,foo,foo));
40+
REQUIRE_THROWS(add_out(no,foo,foo));
4041
}
4142
Scalar a = 4;
4243

4344
float b = a.to<float>();
44-
check(b == 4);
45+
REQUIRE(b == 4);
4546

4647
foo = (foo*foo) == (foo.pow(3));
4748
foo = 2 + (foo+1);
4849
//foo = foo[3];
4950
auto foo_v = foo.accessor<uint8_t,2>();
5051

51-
cout << foo_v.size(0) << " " << foo_v.size(1) << endl;
5252
for(int i = 0; i < foo_v.size(0); i++) {
5353
for(int j = 0; j < foo_v.size(1); j++) {
54-
//cout << foo_v[i][j] << " ";
5554
foo_v[i][j]++;
5655
}
57-
//cout << "\n";
5856
}
5957

60-
61-
cout << foo << "\n";
58+
REQUIRE(foo.equal(4 * CPU(kByte).ones({12, 6})));
6259

6360
trace();
6461

6562
float data[] = { 1, 2, 3,
6663
4, 5, 6};
6764

6865
auto f = CPU(kFloat).tensorFromBlob(data, {1,2,3});
66+
auto f_a = f.accessor<float,3>();
67+
68+
REQUIRE(f_a[0][0][0] == 1.0);
69+
REQUIRE(f_a[0][1][1] == 5.0);
70+
71+
REQUIRE(f.strides()[0] == 6);
72+
REQUIRE(f.strides()[1] == 3);
73+
REQUIRE(f.strides()[2] == 1);
74+
REQUIRE(f.sizes()[0] == 1);
75+
REQUIRE(f.sizes()[1] == 2);
76+
REQUIRE(f.sizes()[2] == 3);
6977

70-
cout << f << endl;
71-
cout << f.strides() << " " << f.sizes() << endl;
72-
ASSERT_THROWS(f.resize_({3,4,5}));
78+
REQUIRE_THROWS(f.resize_({3,4,5}));
7379
{
7480
int isgone = 0;
7581
{
7682
auto f2 = CPU(kFloat).tensorFromBlob(data, {1,2,3}, [&](void*) {
7783
isgone++;
7884
});
79-
cout << f2 << endl;
8085
}
81-
check(isgone == 1);
86+
REQUIRE(isgone == 1);
8287
}
8388
{
8489
int isgone = 0;
@@ -89,9 +94,9 @@ int main() {
8994
});
9095
a_view = f2.view({3,2,1});
9196
}
92-
check(isgone == 0);
97+
REQUIRE(isgone == 0);
9398
a_view.reset();
94-
check(isgone == 1);
99+
REQUIRE(isgone == 1);
95100
}
96101

97102
if(at::hasCUDA()) {
@@ -101,9 +106,6 @@ int main() {
101106
isgone++;
102107
});
103108
}
104-
check(isgone==1);
109+
REQUIRE(isgone==1);
105110
}
106-
107-
108-
return 0;
109111
}

0 commit comments

Comments (0)