Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions aten/src/ATen/native/cudnn/Conv.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -375,6 +375,7 @@ struct Workspace {
}
Workspace(const Workspace&) = delete;
Workspace(Workspace&&) = default;
Workspace& operator=(Workspace&&) = default;
~Workspace() {
if (data) {
THCudaFree(globalContext().lazyInitCUDA(), data);
Expand Down
2 changes: 1 addition & 1 deletion test/expect/TestJit.test_input_pruning.expect
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,6 @@ graph(%0 : Double(5, 5)
%3 : Double(5, 5) = add[alpha={1}](%0, %1)
---------------- stage 1 ----------------
%6 : Double(5, 5) = mul(%4, %1)
%7 : Double(5, 5) = add[alpha={1}](%6, %5)
%7 : Double(5, 5) = add[alpha={1}](%5, %6)
return (%2, %3, %7);
}
33 changes: 20 additions & 13 deletions torch/csrc/autograd/engine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include <unordered_set>
#include <typeinfo>
#include <sstream>
#include <queue>
#include <TH/TH.h>

#ifdef WITH_CUDA
Expand Down Expand Up @@ -51,13 +52,19 @@ struct FunctionTask {
, inputs(std::move(inputs)) {}
};

// Orders FunctionTasks inside ReadyQueue's priority_queue so that the task
// whose Function was created most recently (largest `time`) is popped first.
//
// Sentinel tasks carry fn == nullptr (see thread_main, which pushes
// FunctionTask(base, nullptr, InputBuffer(0)) to signal completion), so the
// comparator must never dereference `fn` unconditionally. Sentinels are
// given the highest priority so a waiting worker observes completion
// promptly.
struct CompareFunctionTaskTime {
  bool operator()(FunctionTask const & t1, FunctionTask const & t2) const {
    // std::priority_queue is a max-heap: returning true means t1 has
    // LOWER priority than t2 (t2 pops first).
    if (!t1.fn) {
      return false;  // t1 is a sentinel: never lower priority than anything
    }
    if (!t2.fn) {
      return true;   // t2 is a sentinel: t1 yields to it
    }
    return t1.fn->time < t2.fn->time;
  }
};

// Per-device queue of ready-to-run backward tasks, shared between producers
// (Engine::evaluate_function / Engine::execute) and one worker thread.
// Tasks are ordered by CompareFunctionTaskTime so that the most recently
// created Function runs first.
struct ReadyQueue {
  std::priority_queue<FunctionTask, std::vector<FunctionTask>, CompareFunctionTaskTime> heap;
  std::condition_variable not_empty;  // signalled on every push
  std::mutex mutex;                   // guards `heap`

  // Enqueues a task (incrementing base->outstanding_tasks under the lock)
  // and wakes one waiting worker.
  void push(FunctionTask item);
  // Blocks until a task is available, then removes and returns the
  // highest-priority one.
  FunctionTask pop();
};

struct GraphTask {
Expand Down Expand Up @@ -114,19 +121,19 @@ struct GraphTask {
, owner(NO_DEVICE) {}
};

// Adds `item` to the ready queue and wakes one waiting worker.
// outstanding_tasks is incremented under the queue mutex so the decrement
// in thread_main (synchronized via the same mutex plus a release fence)
// observes a consistent count.
auto ReadyQueue::push(FunctionTask item) -> void {
  {
    std::lock_guard<std::mutex> lock(mutex);
    ++item.base->outstanding_tasks;
    heap.push(std::move(item));
  }
  // Notify outside the critical section so the woken thread does not
  // immediately block on the mutex we still hold.
  not_empty.notify_one();
}

// Blocks until the queue is non-empty, then removes and returns the
// highest-priority task (per CompareFunctionTaskTime).
auto ReadyQueue::pop() -> FunctionTask {
  std::unique_lock<std::mutex> lock(mutex);
  not_empty.wait(lock, [this]{ return !heap.empty(); });
  // priority_queue::top() returns a const reference; the const_cast lets us
  // move out of the element we are about to pop instead of copying it
  // (FunctionTask owns an InputBuffer, which is move-only). This is safe
  // because the element is removed immediately, before any further heap
  // operation could observe its moved-from state.
  auto task = std::move(const_cast<FunctionTask&>(heap.top()));
  heap.pop();
  return task;
}

Expand Down Expand Up @@ -160,7 +167,7 @@ auto Engine::thread_init(int device) -> void {
auto Engine::thread_main(GraphTask *graph_task) -> void {
auto queue = ready_queues[worker_device + 1];
while (!graph_task || graph_task->outstanding_tasks > 0) {
FunctionTask task = queue->pop_back();
FunctionTask task = queue->pop();
if (task.fn && !task.base->has_error.load()) {
GradMode::set_enabled(task.base->grad_mode);
try {
Expand Down Expand Up @@ -189,7 +196,7 @@ auto Engine::thread_main(GraphTask *graph_task) -> void {
if (--task.base->outstanding_tasks == 0) {
// Synchronize outstanding_tasks with queue mutex
std::atomic_thread_fence(std::memory_order_release);
ready_queue(base_owner).push_front(FunctionTask(task.base, nullptr, InputBuffer(0)));
ready_queue(base_owner).push(FunctionTask(task.base, nullptr, InputBuffer(0)));
}
}
}
Expand Down Expand Up @@ -297,7 +304,7 @@ auto Engine::evaluate_function(FunctionTask& task) -> void {
input_buffer.add(input_nr, std::move(output));
if (is_ready) {
auto& queue = ready_queue(input_buffer.device());
queue.push_front(FunctionTask(task.base, next_fn, std::move(input_buffer)));
queue.push(FunctionTask(task.base, next_fn, std::move(input_buffer)));
} else {
not_ready.emplace(next_fn.get(), std::move(input_buffer));
}
Expand All @@ -307,7 +314,7 @@ auto Engine::evaluate_function(FunctionTask& task) -> void {
input_buffer.add(input_nr, std::move(output));
if (is_ready) {
auto& queue = ready_queue(input_buffer.device());
queue.push_front(FunctionTask(task.base, next_fn, std::move(input_buffer)));
queue.push(FunctionTask(task.base, next_fn, std::move(input_buffer)));
not_ready.erase(not_ready_it);
}
}
Expand Down Expand Up @@ -370,7 +377,7 @@ auto Engine::execute(const function_list& input_roots,
if (!outputs.empty()) {
graph_task.init_to_execute(*graph_root, outputs);
}
ready_queue(-1).push_front(FunctionTask(&graph_task, std::move(graph_root), InputBuffer(0)));
ready_queue(-1).push(FunctionTask(&graph_task, std::move(graph_root), InputBuffer(0)));

// Not a worker
if (worker_device == NO_DEVICE) {
Expand Down
2 changes: 2 additions & 0 deletions torch/csrc/autograd/function.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@

namespace torch { namespace autograd {

thread_local uint64_t Function::function_counter = 0;

template<typename T>
auto makeFlags(const T &inputs) -> FunctionFlags {
int num_inputs = inputs.size();
Expand Down
5 changes: 5 additions & 0 deletions torch/csrc/autograd/function.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,11 @@ struct FunctionFlags {
};

struct Function : std::enable_shared_from_this<Function> {
static thread_local uint64_t function_counter;

Function()
: num_inputs(0)
, time(function_counter++)
, next_functions()
, pre_hooks()
, post_hooks()
Expand All @@ -60,6 +63,7 @@ struct Function : std::enable_shared_from_this<Function> {

Function(FunctionFlags&& flags)
: num_inputs(0)
, time(function_counter++)
, next_functions(std::move(flags.next_functions))
, pre_hooks()
, post_hooks()
Expand Down Expand Up @@ -152,6 +156,7 @@ struct Function : std::enable_shared_from_this<Function> {
const variable_list& inputs, const variable_list& outputs);

int num_inputs;
uint64_t time;
function_list next_functions;
std::vector<std::shared_ptr<FunctionPreHook>> pre_hooks;
std::vector<std::shared_ptr<FunctionPostHook>> post_hooks;
Expand Down
1 change: 1 addition & 0 deletions torch/csrc/autograd/input_buffer.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ struct InputBuffer {
: buffer(size) {}
InputBuffer(const InputBuffer& other) = delete;
InputBuffer(InputBuffer&& other) = default;
InputBuffer& operator=(InputBuffer&& other) = default;

// Accumulates the variable at a specified index.
void add(size_t idx, Variable var);
Expand Down