Skip to content

Commit 3565af2

Browse files
committed
Update on "[CPU] Refactor Random Number Generators in ATen"
[CPU] Refactor Random Number Generators in ATen gh-metadata: pytorch pytorch 21364 gh/syed-ahmed/14/head
2 parents c6522dd + 8baab52 commit 3565af2

File tree

3 files changed

+2
-3
lines changed

3 files changed

+2
-3
lines changed

aten/src/ATen/core/DistributionsHelper.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,7 @@ struct bernoulli_distribution {
166166
p = p_in;
167167
}
168168

169-
inline T operator()(at::CPUGenerator* generator) {
169+
inline int operator()(at::CPUGenerator* generator) {
170170
uniform_real_distribution<T> uniform(0.0, 1.0);
171171
return uniform(generator) <= p;
172172
}

aten/src/ATen/core/MT19937RNGEngine.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@ namespace at {
1414

1515
constexpr int MERSENNE_STATE_N = 624;
1616
constexpr int MERSENNE_STATE_M = 397;
17-
constexpr int INIT_KEY_MULTIPLIER = 3;
1817
constexpr uint32_t MATRIX_A = 0x9908b0df;
1918
constexpr uint32_t UMASK = 0x80000000;
2019
constexpr uint32_t LMASK = 0x7fffffff;

aten/src/ATen/native/cpu/UnaryOpsKernel.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@ void bernoulli_mkl_kernel(Tensor &self, const double p, Generator* gen) {
126126
{
127127
// See Note [Acquire lock when using random generators]
128128
std::lock_guard<std::mutex> lock(generator->mutex_);
129-
seed = generator->random64();
129+
seed = generator->random();
130130
}
131131
int64_t n = self.numel();
132132
bool contig = self.is_contiguous();

0 commit comments

Comments
 (0)