Skip to content

Commit 96239dd

Browse files
yf225 authored and soumith committed
Add mutex for CPU RNG and move TH to C++ (#4041)
* Add mutex for CPU RNG * move more things to cpp to make cuda build work * fix mutex bug on OS X * try to fix cuda9 half .x bug * try to fix windows error * create THGeneratorState as separate field * fix mutex issues
1 parent ca5071d commit 96239dd

31 files changed

+189
-141
lines changed

aten/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -245,6 +245,7 @@ ENDIF(C_AVX_FOUND)
245245
IF(C_AVX2_FOUND)
246246
MESSAGE(STATUS "AVX2 Found")
247247
SET(CMAKE_C_FLAGS "-DUSE_AVX2 ${CMAKE_C_FLAGS}")
248+
SET(CMAKE_CXX_FLAGS "-DUSE_AVX2 ${CMAKE_CXX_FLAGS}")
248249
ENDIF(C_AVX2_FOUND)
249250

250251
CHECK_C_SOURCE_RUNS("

aten/src/ATen/CMakeLists.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,9 @@ ENDIF(C_AVX_FOUND)
3535

3636
IF(C_AVX2_FOUND)
3737
IF(MSVC)
38-
SET_SOURCE_FILES_PROPERTIES(${PROJECT_SOURCE_DIR}/src/TH/vector/AVX2.c PROPERTIES COMPILE_FLAGS "${MSVC_OPT_FLAG}/arch:AVX2 ${C_AVX2_FLAGS}")
38+
SET_SOURCE_FILES_PROPERTIES(${PROJECT_SOURCE_DIR}/src/TH/vector/AVX2.cpp PROPERTIES COMPILE_FLAGS "${MSVC_OPT_FLAG}/arch:AVX2 ${C_AVX2_FLAGS}")
3939
ELSE(MSVC)
40-
SET_SOURCE_FILES_PROPERTIES(${PROJECT_SOURCE_DIR}/src/TH/vector/AVX2.c PROPERTIES COMPILE_FLAGS "-O3 ${C_AVX2_FLAGS}")
40+
SET_SOURCE_FILES_PROPERTIES(${PROJECT_SOURCE_DIR}/src/TH/vector/AVX2.cpp PROPERTIES COMPILE_FLAGS "-O3 ${C_AVX2_FLAGS}")
4141
ENDIF(MSVC)
4242
ENDIF(C_AVX2_FOUND)
4343

aten/src/TH/CMakeLists.txt

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ IF(C_AVX_FOUND)
1717
ENDIF(C_AVX_FOUND)
1818

1919
IF(C_AVX2_FOUND)
20-
LIST(APPEND extra_src ${CMAKE_CURRENT_SOURCE_DIR}/vector/AVX2.c)
20+
LIST(APPEND extra_src ${CMAKE_CURRENT_SOURCE_DIR}/vector/AVX2.cpp)
2121
ENDIF(C_AVX2_FOUND)
2222

2323
SET(hdr
@@ -30,16 +30,16 @@ set(ATen_CPU_SRCS ${ATen_CPU_SRCS}
3030
${CMAKE_CURRENT_SOURCE_DIR}/THAllocator.c
3131
${CMAKE_CURRENT_SOURCE_DIR}/THSize.c
3232
${CMAKE_CURRENT_SOURCE_DIR}/THStorage.c
33-
${CMAKE_CURRENT_SOURCE_DIR}/THTensor.c
33+
${CMAKE_CURRENT_SOURCE_DIR}/THTensor.cpp
3434
${CMAKE_CURRENT_SOURCE_DIR}/THBlas.c
35-
${CMAKE_CURRENT_SOURCE_DIR}/THLapack.c
35+
${CMAKE_CURRENT_SOURCE_DIR}/THLapack.cpp
3636
${CMAKE_CURRENT_SOURCE_DIR}/THLogAdd.c
37-
${CMAKE_CURRENT_SOURCE_DIR}/THRandom.c
37+
${CMAKE_CURRENT_SOURCE_DIR}/THRandom.cpp
3838
${CMAKE_CURRENT_SOURCE_DIR}/THFile.c
3939
${CMAKE_CURRENT_SOURCE_DIR}/THDiskFile.c
4040
${CMAKE_CURRENT_SOURCE_DIR}/THMemoryFile.c
4141
${CMAKE_CURRENT_SOURCE_DIR}/THAtomic.c
42-
${CMAKE_CURRENT_SOURCE_DIR}/THVector.c
42+
${CMAKE_CURRENT_SOURCE_DIR}/THVector.cpp
4343
${extra_src}
4444
PARENT_SCOPE)
4545
######################################################
@@ -90,29 +90,30 @@ INSTALL(FILES
9090
INSTALL(FILES
9191
vector/AVX.h
9292
vector/AVX2.h
93+
vector/avx_mathfun.h
9394
DESTINATION "${ATEN_INSTALL_INCLUDE_SUBDIR}/TH/vector")
9495

9596
INSTALL(FILES
9697
generic/THBlas.c
9798
generic/THBlas.h
98-
generic/THLapack.c
99+
generic/THLapack.cpp
99100
generic/THLapack.h
100101
generic/THStorage.c
101102
generic/THStorage.h
102103
generic/THStorageCopy.c
103104
generic/THStorageCopy.h
104-
generic/THTensor.c
105+
generic/THTensor.cpp
105106
generic/THTensor.h
106-
generic/THTensorConv.c
107+
generic/THTensorConv.cpp
107108
generic/THTensorConv.h
108109
generic/THTensorCopy.c
109110
generic/THTensorCopy.h
110111
generic/THTensorLapack.c
111112
generic/THTensorLapack.h
112113
generic/THTensorMath.c
113114
generic/THTensorMath.h
114-
generic/THTensorRandom.c
115+
generic/THTensorRandom.cpp
115116
generic/THTensorRandom.h
116-
generic/THVectorDispatch.c
117+
generic/THVectorDispatch.cpp
117118
generic/THVector.h
118119
DESTINATION "${ATEN_INSTALL_INCLUDE_SUBDIR}/TH/generic")
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
#include "THLapack.h"
22

3-
#include "generic/THLapack.c"
3+
#include "generic/THLapack.cpp"
44
#include "THGenerateFloatTypes.h"
Lines changed: 38 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,12 @@
1313
/* Creates (unseeded) new generator*/
1414
static THGenerator* THGenerator_newUnseeded()
1515
{
16-
THGenerator *self = THAlloc(sizeof(THGenerator));
16+
THGenerator *self = (THGenerator *)THAlloc(sizeof(THGenerator));
1717
memset(self, 0, sizeof(THGenerator));
18-
self->left = 1;
19-
self->seeded = 0;
20-
self->normal_is_valid = 0;
18+
self->gen_state.left = 1;
19+
self->gen_state.seeded = 0;
20+
self->gen_state.normal_is_valid = 0;
21+
new (&self->mutex) std::mutex();
2122
return self;
2223
}
2324

@@ -31,24 +32,31 @@ THGenerator* THGenerator_new()
3132

3233
THGenerator* THGenerator_copy(THGenerator *self, THGenerator *from)
3334
{
34-
memcpy(self, from, sizeof(THGenerator));
35+
THGeneratorState_copy(&self->gen_state, &from->gen_state);
3536
return self;
3637
}
3738

3839
void THGenerator_free(THGenerator *self)
3940
{
41+
self->mutex.~mutex();
4042
THFree(self);
4143
}
4244

43-
int THGenerator_isValid(THGenerator *_generator)
45+
int THGeneratorState_isValid(THGeneratorState *_gen_state)
4446
{
45-
if ((_generator->seeded == 1) &&
46-
(_generator->left > 0 && _generator->left <= n) && (_generator->next <= n))
47+
if ((_gen_state->seeded == 1) &&
48+
(_gen_state->left > 0 && _gen_state->left <= n) && (_gen_state->next <= n))
4749
return 1;
4850

4951
return 0;
5052
}
5153

54+
THGeneratorState* THGeneratorState_copy(THGeneratorState *self, THGeneratorState *from)
55+
{
56+
memcpy(self, from, sizeof(THGeneratorState));
57+
return self;
58+
}
59+
5260
#ifndef _WIN32
5361
static uint64_t readURandomLong()
5462
{
@@ -146,41 +154,41 @@ void THRandom_manualSeed(THGenerator *_generator, uint64_t the_seed_)
146154
THGenerator_copy(_generator, blank);
147155
THGenerator_free(blank);
148156

149-
_generator->the_initial_seed = the_seed_;
150-
_generator->state[0] = _generator->the_initial_seed & 0xffffffffUL;
157+
_generator->gen_state.the_initial_seed = the_seed_;
158+
_generator->gen_state.state[0] = _generator->gen_state.the_initial_seed & 0xffffffffUL;
151159
for(j = 1; j < n; j++)
152160
{
153-
_generator->state[j] = (1812433253UL * (_generator->state[j-1] ^ (_generator->state[j-1] >> 30)) + j);
161+
_generator->gen_state.state[j] = (1812433253UL * (_generator->gen_state.state[j-1] ^ (_generator->gen_state.state[j-1] >> 30)) + j);
154162
/* See Knuth TAOCP Vol2. 3rd Ed. P.106 for multiplier. */
155163
/* In the previous versions, mSBs of the seed affect */
156164
/* only mSBs of the array state[]. */
157165
/* 2002/01/09 modified by makoto matsumoto */
158-
_generator->state[j] &= 0xffffffffUL; /* for >32 bit machines */
166+
_generator->gen_state.state[j] &= 0xffffffffUL; /* for >32 bit machines */
159167
}
160-
_generator->left = 1;
161-
_generator->seeded = 1;
168+
_generator->gen_state.left = 1;
169+
_generator->gen_state.seeded = 1;
162170
}
163171

164172
uint64_t THRandom_initialSeed(THGenerator *_generator)
165173
{
166-
return _generator->the_initial_seed;
174+
return _generator->gen_state.the_initial_seed;
167175
}
168176

169177
void THRandom_nextState(THGenerator *_generator)
170178
{
171-
uint64_t *p = _generator->state;
179+
uint64_t *p = _generator->gen_state.state;
172180
int j;
173181

174-
_generator->left = n;
175-
_generator->next = 0;
182+
_generator->gen_state.left = n;
183+
_generator->gen_state.next = 0;
176184

177185
for(j = n-m+1; --j; p++)
178186
*p = p[m] ^ TWIST(p[0], p[1]);
179187

180188
for(j = m; --j; p++)
181189
*p = p[m-n] ^ TWIST(p[0], p[1]);
182190

183-
*p = p[m-n] ^ TWIST(p[0], _generator->state[0]);
191+
*p = p[m-n] ^ TWIST(p[0], _generator->gen_state.state[0]);
184192
}
185193

186194
// TODO: this only returns 32-bits of randomness but as a uint64_t. This is
@@ -190,9 +198,9 @@ uint64_t THRandom_random(THGenerator *_generator)
190198
{
191199
uint64_t y;
192200

193-
if (--(_generator->left) == 0)
201+
if (--(_generator->gen_state.left) == 0)
194202
THRandom_nextState(_generator);
195-
y = *(_generator->state + (_generator->next)++);
203+
y = *(_generator->gen_state.state + (_generator->gen_state.next)++);
196204

197205
/* Tempering */
198206
y ^= (y >> 11);
@@ -260,20 +268,20 @@ double THRandom_normal(THGenerator *_generator, double mean, double stdv)
260268
THArgCheck(stdv > 0, 2, "standard deviation must be strictly positive");
261269

262270
/* This is known as the Box-Muller method */
263-
if(!_generator->normal_is_valid)
271+
if(!_generator->gen_state.normal_is_valid)
264272
{
265-
_generator->normal_x = uniform_double(_generator);
266-
_generator->normal_y = uniform_double(_generator);
267-
_generator->normal_rho = sqrt(-2. * log(1.0-_generator->normal_y));
268-
_generator->normal_is_valid = 1;
273+
_generator->gen_state.normal_x = uniform_double(_generator);
274+
_generator->gen_state.normal_y = uniform_double(_generator);
275+
_generator->gen_state.normal_rho = sqrt(-2. * log(1.0-_generator->gen_state.normal_y));
276+
_generator->gen_state.normal_is_valid = 1;
269277
}
270278
else
271-
_generator->normal_is_valid = 0;
279+
_generator->gen_state.normal_is_valid = 0;
272280

273-
if(_generator->normal_is_valid)
274-
return _generator->normal_rho*cos(2.*M_PI*_generator->normal_x)*stdv+mean;
281+
if(_generator->gen_state.normal_is_valid)
282+
return _generator->gen_state.normal_rho*cos(2.*M_PI*_generator->gen_state.normal_x)*stdv+mean;
275283
else
276-
return _generator->normal_rho*sin(2.*M_PI*_generator->normal_x)*stdv+mean;
284+
return _generator->gen_state.normal_rho*sin(2.*M_PI*_generator->gen_state.normal_x)*stdv+mean;
277285
}
278286

279287
double THRandom_exponential(THGenerator *_generator, double lambda)

aten/src/TH/THRandom.h

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,23 +3,32 @@
33

44
#include "THGeneral.h"
55

6+
#include <mutex>
7+
68
#define _MERSENNE_STATE_N 624
79
#define _MERSENNE_STATE_M 397
8-
/* A THGenerator contains all the state required for a single random number stream */
9-
typedef struct THGenerator {
10+
11+
typedef struct THGeneratorState {
1012
/* The initial seed. */
1113
uint64_t the_initial_seed;
1214
int left; /* = 1; */
1315
int seeded; /* = 0; */
1416
uint64_t next;
1517
uint64_t state[_MERSENNE_STATE_N]; /* the array for the state vector */
18+
1619
/********************************/
1720

1821
/* For normal distribution */
1922
double normal_x;
2023
double normal_y;
2124
double normal_rho;
2225
int normal_is_valid; /* = 0; */
26+
} THGeneratorState;
27+
28+
/* A THGenerator contains all the state required for a single random number stream */
29+
typedef struct THGenerator {
30+
std::mutex mutex; /* mutex for using this generator */
31+
THGeneratorState gen_state;
2332
} THGenerator;
2433

2534
#define torch_Generator "torch.Generator"
@@ -29,8 +38,11 @@ TH_API THGenerator * THGenerator_new(void);
2938
TH_API THGenerator * THGenerator_copy(THGenerator *self, THGenerator *from);
3039
TH_API void THGenerator_free(THGenerator *gen);
3140

32-
/* Checks if given generator is valid */
33-
TH_API int THGenerator_isValid(THGenerator *_generator);
41+
/* Checks if given generator state is valid */
42+
TH_API int THGeneratorState_isValid(THGeneratorState *_gen_state);
43+
44+
/* Manipulate THGeneratorState objects */
45+
TH_API THGeneratorState * THGeneratorState_copy(THGeneratorState *self, THGeneratorState *from);
3446

3547
/* Initializes the random number generator from /dev/urandom (or on Windows
3648
platforms with the current time (granularity: seconds)) and returns the seed. */
Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,10 @@
99
#include "THTensorDimApply.h"
1010
#include "THMath.h"
1111

12-
#include "generic/THTensor.c"
12+
#include "generic/THTensor.cpp"
1313
#include "THGenerateAllTypes.h"
1414

15-
#include "generic/THTensor.c"
15+
#include "generic/THTensor.cpp"
1616
#include "THGenerateHalfType.h"
1717

1818
#include "generic/THTensorCopy.c"
@@ -21,13 +21,13 @@
2121
#include "generic/THTensorCopy.c"
2222
#include "THGenerateHalfType.h"
2323

24-
#include "generic/THTensorRandom.c"
24+
#include "generic/THTensorRandom.cpp"
2525
#include "THGenerateAllTypes.h"
2626

2727
#include "generic/THTensorMath.c"
2828
#include "THGenerateAllTypes.h"
2929

30-
#include "generic/THTensorConv.c"
30+
#include "generic/THTensorConv.cpp"
3131
#include "THGenerateAllTypes.h"
3232

3333
#include "generic/THTensorLapack.c"
Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,8 @@
2323
#include "vector/AVX2.h"
2424
#endif
2525

26-
#include "generic/THVectorDefault.c"
26+
#include "generic/THVectorDefault.cpp"
2727
#include "THGenerateAllTypes.h"
2828

29-
#include "generic/THVectorDispatch.c"
29+
#include "generic/THVectorDispatch.cpp"
3030
#include "THGenerateAllTypes.h"
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
#ifndef TH_GENERIC_FILE
2-
#define TH_GENERIC_FILE "generic/THLapack.c"
2+
#define TH_GENERIC_FILE "generic/THLapack.cpp"
33
#else
44

55

0 commit comments

Comments
 (0)