@@ -546,6 +546,50 @@ void cauchy_kernel_cuda(TensorIterator& iter, double median_, double sigma_, Gen
   });
 }
 
+void exponential_kernel_cuda(TensorIterator& iter, double lambda_, Generator* gen_) {
+  auto gen = check_generator<CUDAGenerator>(gen_, &globalContext().defaultGenerator(kCUDA));
+  // Note that HIP doesn't support std::nextafter in device code.
+  auto nextafter_1_0_float = std::nextafter(1.0f, 0.0f);
+  auto nextafter_1_0_double = std::nextafter(1.0, 0.0);
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(iter.dtype(), "exponential_cuda", [&] {
+    using accscalar_t = at::acc_type<scalar_t, true>;
+    auto lambda = static_cast<accscalar_t>(lambda_);
+    if (std::is_same<scalar_t, double>::value) {
+      // define lambda for exponential transformation
+      auto exponential_func = [lambda, nextafter_1_0_double] __device__ (accscalar_t rand) {
+        accscalar_t sample;
+        // curand_uniform has (0,1] bounds. log(1) is 0 and exponential excludes 0.
+        // Hence, squash the 1 to just below 1.
+        if (rand == static_cast<accscalar_t>(1.0)) {
+          sample = ::log(nextafter_1_0_double);
+        } else {
+          sample = ::log(rand);
+        }
+        return static_cast<scalar_t>(static_cast<accscalar_t>(-1.0) / lambda * sample);
+      };
+      distribution_nullary_kernel<scalar_t, accscalar_t, curand4_engine_calls/2>(iter,
+        gen,
+        [] __device__ (curandStatePhilox4_32_10_t* state) { return curand_uniform2_double(state); },
+        exponential_func);
+    } else {
+      // use __logf fast approximation for peak bandwidth
+      auto exponential_func = [lambda, nextafter_1_0_float] __device__ (accscalar_t rand) {
+        accscalar_t sample;
+        if (rand == static_cast<accscalar_t>(1.0)) {
+          sample = __logf(nextafter_1_0_float);
+        } else {
+          sample = __logf(rand);
+        }
+        return static_cast<scalar_t>(static_cast<accscalar_t>(-1.0) / lambda * sample);
+      };
+      distribution_nullary_kernel<scalar_t, accscalar_t, curand4_engine_calls>(iter,
+        gen,
+        [] __device__ (curandStatePhilox4_32_10_t* state) { return curand_uniform4(state); },
+        exponential_func);
+    }
+  });
+}
+
 Tensor& uniform_cuda_(Tensor& self, double from, double to, Generator* gen) {
   auto iter = TensorIterator::nullary_op(self);
   uniform_kernel_cuda(*iter, from, to, gen);
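The formula used in both branches is the standard inverse-CDF (inverse transform) sampling for the exponential distribution, restated here for reference:

$$F(x) = 1 - e^{-\lambda x}, \qquad x = F^{-1}(u) = -\frac{\ln(1 - u)}{\lambda}, \qquad u \in (0, 1).$$

Because $1 - U$ and $U$ are identically distributed on the unit interval, the kernel applies $x = -\ln(u)/\lambda$ directly. The cuRAND uniform generators return values in $(0, 1]$, and $u = 1$ would give $x = 0$, which the exponential distribution excludes; hence the clamp to `nextafter(1.0, 0.0)` before taking the log.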
@@ -631,4 +675,10 @@ Tensor& cauchy_cuda_(Tensor& self, double median, double sigma, Generator* gen)
   return self;
 }
 
+Tensor& exponential_cuda_(Tensor& self, double lambda, Generator* gen) {
+  auto iter = TensorIterator::nullary_op(self);
+  exponential_kernel_cuda(*iter, lambda, gen);
+  return self;
+}
+
 }} // namespace at::native
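A minimal, self-contained sketch of the same technique, assuming only the cuRAND device API (hypothetical standalone kernel and names, not part of this change; it mirrors the float branch above, including the host-side `std::nextafter` clamp):

```cuda
// Standalone illustration only -- hypothetical names, not ATen code.
#include <cmath>
#include <cstdio>
#include <curand_kernel.h>

// Draw u in (0,1] with cuRAND's Philox engine, squash u == 1 to just below 1
// (nextafter computed on the host, as in the kernel above), then apply the
// inverse-CDF transform -log(u) / lambda.
__global__ void exponential_sketch(float* out, int n, float lambda,
                                   float just_below_one, unsigned long long seed) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i >= n) return;
  curandStatePhilox4_32_10_t state;
  curand_init(seed, /*subsequence=*/i, /*offset=*/0, &state);
  float u = curand_uniform(&state);   // uniform on (0, 1]
  if (u == 1.0f) {
    u = just_below_one;               // keep log(u) strictly negative
  }
  out[i] = -__logf(u) / lambda;       // fast-math log, as in the float branch
}

int main() {
  const int n = 8;
  const float lambda = 2.0f;
  const float just_below_one = std::nextafter(1.0f, 0.0f);
  float* d_out = nullptr;
  float h_out[n];
  cudaMalloc(&d_out, n * sizeof(float));
  exponential_sketch<<<1, n>>>(d_out, n, lambda, just_below_one, 1234ULL);
  cudaMemcpy(h_out, d_out, n * sizeof(float), cudaMemcpyDeviceToHost);
  for (int i = 0; i < n; ++i) {
    printf("%f\n", h_out[i]);
  }
  cudaFree(d_out);
  return 0;
}
```

The sketch seeds one Philox state per thread for simplicity; the ATen kernel instead routes generation through `distribution_nullary_kernel`, which shares the generator's Philox state and consumes four engine calls per `curand_uniform4` (two per `curand_uniform2_double`).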