Conversation

@gchanan
Contributor

@gchanan gchanan commented Dec 13, 2017

Main changes in this PR:

  1. Added a TH_APPLY-style templatized function for CPU apply calls (currently only the 2- and 3-tensor-argument versions are supported, but more are easy to add). It is basically identical to TH_APPLY, except that it uses ATen functions and the API is a template instead of a macro. The template takes an operation that is applied to the data (plus an indicator to signal early termination), so callers no longer need to know that x_data is a pointer to the current data location of x; see the sketch after this list.

  2. Refactors the ATen dispatch code to easily generate dispatch code for different subsets of the scalar types. This is in preference to the template_scalar path, which requires a valid specialization for each scalar type. Valid specializations are particularly annoying with CUDA because you most likely can't put the specializations in a header, so you need to write some sort of for-all-scalar-type macro to get the correct specializations. Currently, we only generate dispatch_all (all scalar types; the equivalent existed already) and dispatch_cpu_floating_types (which is used by standard_gamma).

  3. Implements standard_gamma using the above changes as a proof of concept (an arbitrary choice; it was simply the most recently committed apply macro). The forward is bound via Declarations.yaml, the backward via the Apply template, and the two are hooked together in derivatives.yaml. This removes the need to touch TH at all going forward, which means one can write idiomatic C++ instead of the TH-style macros (e.g. TH_MATH_NAME).
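
To make the pattern concrete, here is a condensed sketch of how (1)-(3) fit together, stitched from the _standard_gamma_grad pieces quoted in the review below. The element-wise math and headers are elided, and note that the early-termination flag was dropped later in review; this is a sketch, not the exact final code.

template <typename Scalar>
struct StandardGammaGradOp {
  // Called once per element triple; no raw data pointers at the call site.
  void operator()(Scalar& ret_val, const Scalar& self_val, const Scalar& alpha_val, bool& early_exit) {
    ret_val = standard_gamma_grad_one(self_val, alpha_val);  // element-wise math defined elsewhere in the diff
  }

  static void apply(Tensor& ret, const Tensor& self, const Tensor& alpha) {
    StandardGammaGradOp<Scalar> op;
    CPU_tensor_apply3<Scalar, StandardGammaGradOp<Scalar>>(ret, self, alpha, op);
  }
};

// Only the CPU floating-point scalar types get an instantiation here.
Tensor _standard_gamma_grad(const Tensor& self, const Tensor& alpha) {
  Tensor ret = self.type().tensor(self.sizes());
  dispatch_cpu_floating_types<StandardGammaGradOp>(self.type(), "_standard_gamma_grad", ret, self, alpha);
  return ret;
}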

@pytorchbot
Collaborator

@gchanan, thanks for your PR! We identified @zdevito to be a potential reviewer.

@gchanan
Contributor Author

gchanan commented Dec 13, 2017

I have a CUDA version implemented as well, but this seemed like a sensible place to split up the PR.

const Type& the_type = self.type();
dispatch_cpu_floating_types<StandardGammaGradOp>(the_type, "_standard_gamma_grad", ret, self, alpha);
return ret;
}

- arg: THTensor* output
output: True
- arg: THGenerator* generator
default: THPGenerator_TH_CData(THPDefaultGenerator)

@gchanan
Contributor Author

gchanan commented Dec 14, 2017

@pytorchbot retest this please.

@ezyang
Contributor

ezyang commented Dec 14, 2017

@pytorchbot retest this please

@ezyang
Contributor

ezyang commented Dec 14, 2017

@pytorchbot retest this please

ret_val = standard_gamma_grad_one(self_val, alpha_val);
}

static void apply(Tensor& ret, const Tensor& self, const Tensor& alpha) {

template <typename Scalar>
struct StandardGammaGradOp {
void operator()(Scalar& ret_val, const Scalar& self_val, const Scalar &alpha_val, bool& early_exit)

densities = ['Dense', 'Sparse']

# scalar_name, c_type, accreal, th_scalar_type, is_floating_type

* loops.
*/

static inline void check_correct_backend(const Tensor &t, unsigned int pos) {

check_correct_backend(t3, 3);
}

#define __ATH_TENSOR_APPLYX_PREAMBLE(TYPE, ATENSOR, DIM, ALLOW_CONTIGUOUS) \

}

template <typename ScalarType, typename Op>
void CPU_tensor_apply3_dim(Tensor &tensor1, Tensor& tensor2, Tensor& tensor3, int64_t dim, Op op) {

Contributor

@zdevito zdevito left a comment

Awesome! We have apply in ATen! I listed some ways I think we can make the API simpler, let me know what you think.


static void apply(Tensor& ret, const Tensor& self, const Tensor& alpha) {
StandardGammaGradOp<Scalar> op;
CPU_tensor_apply3<Scalar, StandardGammaGradOp<Scalar>>(ret, self, alpha, op);

Tensor _standard_gamma_grad(const Tensor& self, const Tensor& alpha) {
Tensor ret = self.type().tensor(self.sizes());
dispatch_cpu_floating_types<StandardGammaGradOp>(self.type(), "_standard_gamma_grad", ret, self, alpha);

@gchanan
Contributor Author

gchanan commented Dec 15, 2017

CC @fritzo you may be interested in this.

CPU_tensor_apply2_dim<ScalarType, Op>(tensor1, tensor2, -1, op);
}

template <typename ScalarType, typename Op>

__ATH_TENSOR_APPLYX_UPDATE_COUNTERS(tensor3, 0)
}
if(tensor1_counter != NULL)
delete [] tensor1_counter;

/** Computes the reparameterized gradient -(d/dalpha cdf(x;alpha)) / pdf(x;alpha)
for random number x drawn from a standard Gamma distribution Gamma(alpha).
*/
template <typename Scalar>

// TODO Replace this with more accurate digamma().
template <typename Scalar>
static inline Scalar digamma_one(Scalar x) {

template <typename Scalar>
struct StandardGammaGradOp {
void operator()(Scalar& ret_val, const Scalar& self_val, const Scalar &alpha_val, bool& early_exit)

- name: zeros # fallthrough

- name: _standard_gamma(Tensor self, Generator generator)
self: grad * output._standard_gamma_grad(self)

if not isinstance(alpha, Variable):
    return torch._C._standard_gamma(alpha)
return _StandardGamma.apply(alpha)
return alpha._standard_gamma()

@fritzo
Collaborator

fritzo commented Dec 15, 2017

@gchanan Thanks, maybe we can use this for standard_gamma_grad(). I might also try incorporating it into dirichlet_grad() in #4117.

@gchanan
Contributor Author

gchanan commented Dec 15, 2017

On early_exit (I can't comment directly because I've changed that code): to be fair, it doesn't have to expose fully serial semantics; the apply function could enforce it in a thread-safe way (it would just be a suggested early exit at that point). But I'll just get rid of it for now, since we aren't going to implement something like that in this PR.
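
For reference, a minimal hypothetical sketch of what a thread-safe "suggested" early exit could look like; nothing like this is added in this PR, and apply_with_suggested_early_exit and its signature are made up for illustration only.

#include <atomic>
#include <cstdint>

// Hypothetical helper, not part of this PR: each worker polls a shared flag
// before handling its next element, so a stop request is honored promptly
// but not instantaneously (a "suggested" early exit).
template <typename Scalar, typename Op>
void apply_with_suggested_early_exit(Scalar* data, int64_t n, Op op) {
  std::atomic<bool> stop{false};
  #pragma omp parallel for
  for (int64_t i = 0; i < n; ++i) {
    if (stop.load(std::memory_order_relaxed)) continue;  // skip remaining work once a stop was requested
    bool request_exit = false;
    op(data[i], request_exit);
    if (request_exit) stop.store(true, std::memory_order_relaxed);
  }
}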

Member

@colesbury colesbury left a comment

lgtm

/** Computes the reparameterized gradient -(d/dalpha cdf(x;alpha)) / pdf(x;alpha)
for random number x drawn from a standard Gamma distribution Gamma(alpha).
*/
template <typename CScalar>

@gchanan
Contributor Author

gchanan commented Dec 18, 2017

Test failure doesn't look related; I'm going to merge this because I've run into a bunch of merge conflicts over the last few days.

I think the only review comment left is from @zdevito: "(2) We use templated functions rather than classes. The reason for the classes was partial specialization. If you need partial specialization, then just write your own switch statement." Let me know if you want changes here and I'll make them in a future commit.
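
For context on that point: class templates can be partially specialized while function templates cannot, which is the usual reason to reach for a functor here. A generic illustration with made-up names, not code from this PR:

struct CPUTag {};
struct CUDATag {};

// Primary class template.
template <typename Scalar, typename Backend>
struct MyOp {
  void operator()(Scalar& out, const Scalar& in) const { out = in; }
};

// A class template can be partially specialized: fix the backend, keep Scalar generic.
template <typename Scalar>
struct MyOp<Scalar, CPUTag> {
  void operator()(Scalar& out, const Scalar& in) const { out = in + Scalar(1); }
};

// A function template cannot be partially specialized; getting the same effect
// with free functions means full specializations, overloads, or a hand-written switch.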

@gchanan gchanan merged commit 0876bab into pytorch:master Dec 18, 2017
@alicanb
Collaborator

alicanb commented Dec 22, 2017

Guys, @fritzo and I think this breaks test_gamma_sample_grad in test_distributions. (RuntimeError: VariableType::_standard_gamma_grad NYI)

@gchanan
Contributor Author

gchanan commented Dec 22, 2017

@alicanb I'll take a look.

@fritzo
Collaborator

fritzo commented Jan 9, 2018

@gchanan This PR removes _standard_gamma_grad() from the torch._C module. How can I access this function in Python now, e.g. for unit testing? Thanks!

@gchanan
Contributor Author

gchanan commented Jan 9, 2018

On tensors or variables?

@fritzo
Collaborator

fritzo commented Jan 9, 2018

On tensors. I see it is a Variable method now?

@fritzo
Collaborator

fritzo commented Jan 9, 2018

Well that works, I'll just use it on Variables. Thanks.

@gchanan
Contributor Author

gchanan commented Jan 9, 2018

Yah, since we are planning on merging Variables and tensors I didn't spend the extra effort to make it available on tensors (and it wasn't being used on tensors anyway).
