Skip to content

Conversation

@zou3519
Copy link
Contributor

@zou3519 zou3519 commented Mar 13, 2018

Fixes #5741

The only operation that really benefits from this right now is tensor.mean().

cc @gchanan @colesbury

Test Plan

python test/test_autograd.py

@gchanan
Copy link
Contributor

gchanan commented Mar 13, 2018

looks good, mind showing the generated code?

@zou3519
Copy link
Contributor Author

zou3519 commented Mar 13, 2018

MeanBackward1 struct:

struct MeanBackward1 : public TraceableFunction {
  using TraceableFunction::TraceableFunction;
  variable_list apply(const variable_list& grads) override;
  std::string name() override { return "MeanBackward1"; }
  void release_variables() override {

  }

  std::vector<int64_t> self_sizes;
  int64_t self_numel;

};

forward:

Tensor VariableType::mean(const Tensor & self) const {
  profiler::RecordFunction profiler("mean");
  auto& self_ = unpack(self, "self", 0);
  std::shared_ptr<MeanBackward1> grad_fn;
  if (compute_requires_grad( self )) {
    grad_fn = std::make_shared<MeanBackward1>();
    grad_fn->set_next_edges(collect_next_edges( self ));
    grad_fn->self_sizes = self.sizes();
    grad_fn->self_numel = self.numel();
  }
  jit::tracer::PreTraceInfo trace_info;
  if (jit::tracer::isTracing( self )) {
    trace_info = jit::tracer::preRecordTrace( "mean", { self } );

  }
  auto result = as_variable(baseType->mean(self_));
  set_history(result, grad_fn);
  if (trace_info.state != nullptr) {
    jit::tracer::postRecordTrace( trace_info,  { result } );
  }
  return result;
}

backward:

 variable_list MeanBackward1::apply(const variable_list& grads) {
   IndexRangeGenerator gen;
   auto self_ix = gen.range(1);
   variable_list grad_inputs(gen.size());
   auto& grad = grads[0];
   if (should_compute_output({ self_ix })) {
     auto grad_result = grad.expand(self_sizes) / self_numel;
     copy_range(grad_inputs, self_ix, grad_result);
   }
   return grad_inputs;
 }

@soumith soumith merged commit 11444a7 into pytorch:master Mar 13, 2018
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants