2 changes: 1 addition & 1 deletion aten/src/ATen/CPUApplyUtils.h
@@ -1,7 +1,7 @@
#pragma once

#include <sstream>
#include "ATen/Check.h"
#include "ATen/TensorUtils.h"

namespace at {

9 changes: 8 additions & 1 deletion aten/src/ATen/Check.cpp → aten/src/ATen/TensorUtils.cpp
@@ -1,5 +1,5 @@
#include "ATen/Config.h"
#include "ATen/Check.h"
#include "ATen/TensorUtils.h"

#include "ATen/ATen.h"

@@ -210,4 +210,11 @@ void checkBackend(CheckedFrom c, ArrayRef<Tensor> tensors, at::Backend backend)
}
}

void * maybe_data_ptr(const Tensor& tensor) {
return tensor.defined() ? (void *)tensor.data_ptr() : nullptr;
}

void * maybe_data_ptr(const TensorArg& tensor) {
return tensor->defined() ? (void *)tensor->data_ptr() : nullptr;
}
}
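
A minimal usage sketch for the new helper (illustrative only, not part of the diff; running_mean here is a hypothetical optional tensor):

    // Assuming running_mean is a possibly-undefined at::Tensor.
    void* before = running_mean.defined() ? (void*)running_mean.data_ptr() : nullptr;
    void* after = at::maybe_data_ptr(running_mean);  // same result, one helper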
16 changes: 11 additions & 5 deletions aten/src/ATen/Check.h → aten/src/ATen/TensorUtils.h
@@ -4,14 +4,14 @@
#include "ATen/TensorGeometry.h"
#include "ATen/Utils.h"

// This file contains utility functions for checking that arguments
// make sense. This is particularly useful for native functions,
// which do NO argument checking by default.
//
// It's NOT in Utils.h, because this file has a dep on Tensor.h
// These functions are NOT in Utils.h, because this file has a dep on Tensor.h

namespace at {

// The following are utility functions for checking that arguments
// make sense. These are particularly useful for native functions,
// which do NO argument checking by default.

struct TensorArg {
Tensor tensor;
const char* name;
@@ -72,4 +72,10 @@ void checkAllDefined(CheckedFrom c, at::ArrayRef<TensorArg> t);

// FixMe: does TensorArg slow things down?
void checkBackend(CheckedFrom c, at::ArrayRef<Tensor> t, at::Backend backend);

// Methods for getting data_ptr if tensor is defined
void * maybe_data_ptr(const Tensor& tensor);
void * maybe_data_ptr(const TensorArg& tensor);

}
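
For context, a sketch of how a native function might combine these checks (a hedged example; my_op and the third TensorArg field are assumptions, since the full struct definition is collapsed above):

    // Hypothetical native function using the checking utilities.
    at::Tensor my_op(const at::Tensor& input, const at::Tensor& weight) {
      at::CheckedFrom c = "my_op";
      at::TensorArg input_arg{input, "input", 1};   // argument position assumed
      at::TensorArg weight_arg{weight, "weight", 2};
      at::checkAllDefined(c, {input_arg, weight_arg});
      at::checkDimRange(c, input_arg, 2, 6 /* exclusive */);
      return input;  // real work would go here
    }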

2 changes: 1 addition & 1 deletion aten/src/ATen/cuda/CUDAApplyUtils.cuh
@@ -1,7 +1,7 @@
#pragma once

#include "detail/IndexUtils.cuh"
#include "ATen/Check.h"
#include "ATen/TensorUtils.h"

//
// This file contains pointwise operation functions and kernels that
2 changes: 1 addition & 1 deletion aten/src/ATen/cudnn/Descriptors.h
@@ -4,7 +4,7 @@

#include "cudnn-wrapper.h"
#include <ATen/ATen.h>
#include <ATen/Check.h>
#include <ATen/TensorUtils.h>

#if CUDNN_VERSION < 7000

18 changes: 14 additions & 4 deletions aten/src/ATen/native/BatchNorm.cpp
@@ -21,12 +21,20 @@ namespace {

Tensor batch_norm(
const Tensor& input, const Tensor& weight /* optional */, const Tensor& bias /* optional */,
const Tensor& running_mean, const Tensor& running_var,
const Tensor& running_mean /* optional */, const Tensor& running_var /* optional */,
bool training, double momentum, double eps, bool cudnn_enabled) {

auto num_features = input.sizes()[1];
check_dims_match_num_input_features("running_mean", num_features, running_mean.numel());
check_dims_match_num_input_features("running_var", num_features, running_var.numel());
if (running_mean.defined()) {
check_dims_match_num_input_features("running_mean", num_features, running_mean.numel());
} else if (!training) {
throw std::runtime_error("running_mean must be defined in evaluation mode");
}
if (running_var.defined()) {
check_dims_match_num_input_features("running_var", num_features, running_var.numel());
} else if (!training) {
throw std::runtime_error("running_var must be defined in evaluation mode");
}
if (weight.defined()) {
check_dims_match_num_input_features("weight", num_features, weight.numel());
}
@@ -38,8 +46,10 @@ Tensor batch_norm(
#if AT_CUDNN_ENABLED()
use_cudnn = (input.type().is_cuda()
&& (input.type().scalarType() != at::kHalf
|| weight.type().scalarType() == at::kFloat)
|| weight.type().scalarType() == at::kFloat)
&& weight.defined() && bias.defined()
&& ((running_mean.defined() && running_var.defined())
|| (!running_mean.defined() && !running_var.defined() && training))
&& input.size(0) <= 131070
&& cudnn_enabled && CUDNN_VERSION >= 5110L);
#endif
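With running_mean and running_var now optional, a call like the following becomes legal in training mode; in evaluation mode the same call now throws. A hedged sketch (the factory call assumes the ATen API of this era; shapes are illustrative):

    #include <ATen/ATen.h>

    at::Tensor input = at::CPU(at::kFloat).randn({8, 4, 5, 5});
    at::Tensor undefined;  // stands in for the now-optional tensors
    // Batch statistics are computed on the fly; no running buffers are updated.
    at::Tensor out = at::batch_norm(
        input, /*weight=*/undefined, /*bias=*/undefined,
        /*running_mean=*/undefined, /*running_var=*/undefined,
        /*training=*/true, /*momentum=*/0.1, /*eps=*/1e-5,
        /*cudnn_enabled=*/false);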
2 changes: 1 addition & 1 deletion aten/src/ATen/native/Embedding.cpp
@@ -1,5 +1,5 @@
#include "ATen/ATen.h"
#include "ATen/Check.h"
#include "ATen/TensorUtils.h"
#include "ATen/NativeFunctions.h"

#include <cstring>
2 changes: 1 addition & 1 deletion aten/src/ATen/native/EmbeddingBag.cpp
@@ -1,5 +1,5 @@
#include "ATen/ATen.h"
#include "ATen/Check.h"
#include "ATen/TensorUtils.h"
#include "ATen/NativeFunctions.h"

#include <cstring>
2 changes: 1 addition & 1 deletion aten/src/ATen/native/Pooling.cpp
@@ -1,5 +1,5 @@
#include "ATen/ATen.h"
#include "ATen/Check.h"
#include "ATen/TensorUtils.h"
#include "ATen/NativeFunctions.h"

#include <sstream>
2 changes: 1 addition & 1 deletion aten/src/ATen/native/cuda/Embedding.cu
@@ -1,5 +1,5 @@
#include "ATen/ATen.h"
#include "ATen/Check.h"
#include "ATen/TensorUtils.h"
#include "ATen/Dispatch.h"
#include "ATen/NativeFunctions.h"

2 changes: 1 addition & 1 deletion aten/src/ATen/native/cuda/EmbeddingBag.cu
@@ -1,5 +1,5 @@
#include "ATen/ATen.h"
#include "ATen/Check.h"
#include "ATen/TensorUtils.h"
#include "ATen/Dispatch.h"
#include "ATen/NativeFunctions.h"

2 changes: 1 addition & 1 deletion aten/src/ATen/native/cudnn/AffineGridGenerator.cpp
@@ -30,7 +30,7 @@ Tensor cudnn_affine_grid_generator_backward(
#include <ATen/cudnn/Types.h>
#include <ATen/cudnn/Utils.h>

#include <ATen/Check.h>
#include <ATen/TensorUtils.h>

namespace at { namespace native {

27 changes: 14 additions & 13 deletions aten/src/ATen/native/cudnn/BatchNorm.cpp
@@ -31,7 +31,7 @@ std::tuple<Tensor, Tensor, Tensor> cudnn_batch_norm_backward(
#include <ATen/cudnn/Types.h>
#include <ATen/cudnn/Utils.h>

#include <ATen/Check.h>
#include <ATen/TensorUtils.h>

namespace at { namespace native {

@@ -60,7 +60,10 @@ std::tuple<Tensor, Tensor, Tensor> cudnn_batch_norm(
CheckedFrom c = "cudnn_batch_norm";
setCuDNNStreamToCurrent();

checkAllDefined(c, {input, weight, bias, running_mean, running_var});
checkAllDefined(c, {input, weight, bias});
if (!training) {
checkAllDefined(c, {running_mean, running_var});
}
checkAllSameGPU(c, {input, weight, bias, running_mean, running_var});
if (input->type().scalarType() == ScalarType::Half) {
checkScalarType(c, weight, ScalarType::Float);
@@ -73,7 +76,9 @@ std::tuple<Tensor, Tensor, Tensor> cudnn_batch_norm(
checkDimRange(c, input, 2, 6 /* exclusive */);
auto num_features = input->size(1);
for (auto t : {weight, bias, running_mean, running_var}) {
checkNumel(c, t, num_features);
if (t->defined()) {
checkNumel(c, t, num_features);
}
}

cudnnBatchNormMode_t mode;
@@ -97,16 +102,12 @@ std::tuple<Tensor, Tensor, Tensor> cudnn_batch_norm(

Constant one(dataType, 1);
Constant zero(dataType, 0);

// Though technically we only need to allocate this for training,
// (1) THNN batch normalization expects non-undefined tensors for
// backwards (which we will pass these to, if !training, because
// CuDNN backwards with !training doesn't gradcheck), and
// (2) These are pretty small tensors, no big deal.
Tensor save_mean = running_mean_t.type().tensor(running_mean_t.sizes());
Tensor save_var = running_var_t.type().tensor(running_var_t.sizes());
Tensor save_mean, save_var;

if (training) {
int64_t num_features = input_t.size(1);
save_mean = weight_t.type().tensor({ num_features });
save_var = weight_t.type().tensor({ num_features });
CUDNN_CHECK(cudnnBatchNormalizationForwardTraining(
handle, mode, &one, &zero,
idesc.desc(), input->data_ptr(),
Expand All @@ -115,8 +116,8 @@ std::tuple<Tensor, Tensor, Tensor> cudnn_batch_norm(
weight->data_ptr(),
bias->data_ptr(),
exponential_average_factor,
running_mean->data_ptr(),
running_var->data_ptr(),
at::maybe_data_ptr(running_mean),
at::maybe_data_ptr(running_var),
epsilon,
save_mean.data_ptr(),
save_var.data_ptr()));
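A reading aid, not part of the patch: per the schema below, cudnn_batch_norm returns (output, save_mean, save_var), and after this change the saved statistics are only allocated when training. A hypothetical caller:

    at::Tensor output, save_mean, save_var;
    std::tie(output, save_mean, save_var) = at::cudnn_batch_norm(
        input, weight, bias, running_mean, running_var,
        /*training=*/true, /*exponential_average_factor=*/0.1,
        /*epsilon=*/1e-5);
    // save_mean / save_var have shape {num_features} only when training;
    // they are the tensors later consumed by cudnn_batch_norm_backward.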
2 changes: 1 addition & 1 deletion aten/src/ATen/native/cudnn/Conv.cpp
@@ -80,7 +80,7 @@ std::tuple<at::Tensor,at::Tensor,at::Tensor> cudnn_convolution_transpose_backwar
#include <ATen/cudnn/Types.h>
#include <ATen/cudnn/Utils.h>

#include <ATen/Check.h>
#include <ATen/TensorUtils.h>

#include <functional>
#include <iterator>
2 changes: 1 addition & 1 deletion aten/src/ATen/native/cudnn/GridSampler.cpp
@@ -27,7 +27,7 @@ std::tuple<Tensor, Tensor> cudnn_grid_sampler_backward(
#include <ATen/cudnn/Types.h>
#include <ATen/cudnn/Utils.h>

#include <ATen/Check.h>
#include <ATen/TensorUtils.h>

// TODO: descriptor checking

6 changes: 3 additions & 3 deletions aten/src/ATen/native/native_functions.yaml
@@ -22,7 +22,7 @@
- func: addr_out(Tensor result, Tensor self, Tensor vec1, Tensor vec2, *, Scalar beta=1, Scalar alpha=1) -> Tensor
variants: function

- func: batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor running_mean, Tensor running_var, bool training, double momentum, double eps, bool cudnn_enabled) -> Tensor
- func: batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, double momentum, double eps, bool cudnn_enabled) -> Tensor
variants: function

- func: bernoulli_(Tensor self, Tensor p, Generator* generator=nullptr) -> Tensor
@@ -91,11 +91,11 @@
name: grad_theta
variants: function

- func: cudnn_batch_norm(Tensor input, Tensor weight, Tensor? bias, Tensor running_mean, Tensor running_var, bool training, double exponential_average_factor, double epsilon) -> (Tensor, Tensor, Tensor)
- func: cudnn_batch_norm(Tensor input, Tensor weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, double exponential_average_factor, double epsilon) -> (Tensor, Tensor, Tensor)
variants: function

# NB: You can only use this if you used cudnn_batch_norm training=True
- func: cudnn_batch_norm_backward(Tensor input, Tensor grad_output, Tensor weight, Tensor running_mean, Tensor running_var, Tensor? save_mean, Tensor? save_var, double epsilon) -> (Tensor, Tensor, Tensor)
- func: cudnn_batch_norm_backward(Tensor input, Tensor grad_output, Tensor weight, Tensor? running_mean, Tensor? running_var, Tensor? save_mean, Tensor? save_var, double epsilon) -> (Tensor, Tensor, Tensor)
variants: function

- func: cudnn_convolution(Tensor self, Tensor weight, Tensor? bias, IntList padding, IntList stride, IntList dilation, int64_t groups, bool benchmark, bool deterministic) -> Tensor
12 changes: 8 additions & 4 deletions aten/src/THCUNN/BatchNormalization.cu
@@ -141,8 +141,8 @@ template <typename Dtype, typename Acctype, typename DeviceTensor1, typename Dev
__global__ void BatchNormalizationUpdateOutputInference_kernel(
const DeviceTensor3 input,
DeviceTensor3 output,
DeviceTensor1 runningMean,
DeviceTensor1 runningVar,
const DeviceTensor1 runningMean,
const DeviceTensor1 runningVar,
const DeviceTensor1 weight,
const DeviceTensor1 bias,
Acctype epsilon) {
@@ -196,8 +196,12 @@ __global__ void BatchNormalizationUpdateOutput_kernel(
Acctype unbiasedVar = varN / (N - 1);
saveMean[plane] = ScalarConvert<Acctype, Dtype>::to(mean);
saveStd[plane] = ScalarConvert<Acctype, Dtype>::to(invStd);
runningMean[plane] = ScalarConvert<Acctype, Dtype>::to((1 - momentum) * runningMean[plane] + momentum * mean);
runningVar[plane] = ScalarConvert<Acctype, Dtype>::to((1 - momentum) * runningVar[plane] + momentum * unbiasedVar);
if (runningMean.data() != NULL) {
runningMean[plane] = ScalarConvert<Acctype, Dtype>::to((1 - momentum) * runningMean[plane] + momentum * mean);
}
if (runningVar.data() != NULL) {
runningVar[plane] = ScalarConvert<Acctype, Dtype>::to((1 - momentum) * runningVar[plane] + momentum * unbiasedVar);
}
}

// Write normalized and update the output
5 changes: 3 additions & 2 deletions aten/src/THCUNN/generic/BatchNormalization.cu
@@ -39,8 +39,9 @@ void THNN_(BatchNormalization_updateOutput)(

THCTensor_(resizeAs)(state, output_, input_);
if (train) {
THCTensor_(resizeAs)(state, saveMean_, runningMean_);
THCTensor_(resizeAs)(state, saveStd_, runningVar_);
int64_t nInput = THCTensor_(size)(state, input_, 1);
THCTensor_(resize1d)(state, saveMean_, nInput);
THCTensor_(resize1d)(state, saveStd_, nInput);
}
DeviceTensor3 input = devicetensor<3>(state, input_);
DeviceTensor3 output = devicetensor<3>(state, output_);
12 changes: 6 additions & 6 deletions aten/src/THCUNN/generic/THCUNN.h
@@ -36,8 +36,8 @@ TH_API void THNN_(BatchNormalization_updateOutput)(
THCTensor *output_,
THCTensor *weight_, // [OPTIONAL]
THCTensor *bias_, // [OPTIONAL]
THCTensor *runningMean_,
THCTensor *runningVar_,
THCTensor *runningMean_, // [OPTIONAL] if train
THCTensor *runningVar_, // [OPTIONAL] if train
THCTensor *saveMean_,
THCTensor *saveStd_,
bool train,
@@ -52,10 +52,10 @@ TH_API void THNN_(BatchNormalization_backward)(
THCTensor *gradWeight_, // [OPTIONAL]
THCTensor *gradBias_, // [OPTIONAL]
THCTensor *weight_, // [OPTIONAL]
THCTensor *runningMean_,
THCTensor *runningVar_,
THCTensor *saveMean_,
THCTensor *saveStd_,
THCTensor *runningMean_, // [OPTIONAL] if train
THCTensor *runningVar_, // [OPTIONAL] if train
THCTensor *saveMean_, // [OPTIONAL] if !train
THCTensor *saveStd_, // [OPTIONAL] if !train
bool train,
double scale,
double eps);
19 changes: 11 additions & 8 deletions aten/src/THNN/generic/BatchNormalization.c
@@ -15,8 +15,8 @@ void THNN_(BatchNormalization_updateOutput)(
ptrdiff_t n = THTensor_(nElement)(input) / nInput;

if (train) {
THTensor_(resizeAs)(save_mean, running_mean);
THTensor_(resizeAs)(save_std, running_var);
THTensor_(resize1d)(save_mean, nInput);
THTensor_(resize1d)(save_std, nInput);
}

#pragma omp parallel for
Expand Down Expand Up @@ -47,12 +47,15 @@ void THNN_(BatchNormalization_updateOutput)(
THTensor_(set1d)(save_std, f, (real) invstd);

// update running averages
THTensor_(set1d)(running_mean, f,
(real) (momentum * mean + (1 - momentum) * THTensor_(get1d)(running_mean, f)));

accreal unbiased_var = sum / (n - 1);
THTensor_(set1d)(running_var, f,
(real) (momentum * unbiased_var + (1 - momentum) * THTensor_(get1d)(running_var, f)));
if (running_mean) {
THTensor_(set1d)(running_mean, f,
(real) (momentum * mean + (1 - momentum) * THTensor_(get1d)(running_mean, f)));
}
if (running_var) {
accreal unbiased_var = sum / (n - 1);
THTensor_(set1d)(running_var, f,
(real) (momentum * unbiased_var + (1 - momentum) * THTensor_(get1d)(running_var, f)));
}
} else {
mean = THTensor_(get1d)(running_mean, f);
invstd = 1 / sqrt(THTensor_(get1d)(running_var, f) + eps);
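For reference, the update that both the CPU and CUDA paths now guard is the standard exponential moving average, with the variance term unbiased by n - 1:

    running_mean <- momentum * batch_mean + (1 - momentum) * running_mean
    running_var  <- momentum * batch_var  + (1 - momentum) * running_var
        where batch_var = (sum of squared deviations from batch_mean) / (n - 1)

When the running buffers are absent, save_mean / save_std are still written in training, so the backward pass can work from the saved batch statistics.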
12 changes: 6 additions & 6 deletions aten/src/THNN/generic/THNN.h
@@ -814,8 +814,8 @@ TH_API void THNN_(BatchNormalization_updateOutput)(
THTensor *output,
THTensor *weight, // [OPTIONAL]
THTensor *bias, // [OPTIONAL]
THTensor *running_mean,
THTensor *running_var,
THTensor *running_mean, // [OPTIONAL] if train
THTensor *running_var, // [OPTIONAL] if train
THTensor *save_mean,
THTensor *save_std,
bool train,
@@ -829,10 +829,10 @@ TH_API void THNN_(BatchNormalization_backward)(
THTensor *gradWeight, // [OPTIONAL]
THTensor *gradBias, // [OPTIONAL]
THTensor *weight, // [OPTIONAL]
THTensor *running_mean,
THTensor *running_var,
THTensor *save_mean,
THTensor *save_std,
THTensor *running_mean, // [OPTIONAL] if train
THTensor *running_var, // [OPTIONAL] if train
THTensor *save_mean, // [OPTIONAL] if !train
THTensor *save_std, // [OPTIONAL] if !train
bool train,
double scale,
double eps);