-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathcoro.cppm
More file actions
111 lines (97 loc) · 3.39 KB
/
coro.cppm
File metadata and controls
111 lines (97 loc) · 3.39 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
export module mcpplibs.llmapi:coro;
import std;
export namespace mcpplibs::llmapi {
/// Lazily-started coroutine task producing a value of type T.
///
/// The coroutine body does not run until the task is either co_awaited from
/// another coroutine or driven synchronously with get(). When awaited, the
/// awaiting coroutine is stored as the task's continuation and is resumed
/// (by symmetric transfer) only after the task reaches its final suspend
/// point, so an awaiter never observes an unfinished result and nested
/// awaits do not grow the native stack.
///
/// Move-only; the destructor destroys the owned coroutine frame.
/// Precondition for every member other than destruction/assignment: the
/// task still owns a coroutine (it has not been moved from).
template<typename T>
class [[nodiscard]] Task {
public:
    struct promise_type {
        std::optional<T> value;
        std::exception_ptr exception;
        // Coroutine to resume once this task finishes; a no-op when the task
        // is driven by get() rather than co_awaited.
        std::coroutine_handle<> continuation = std::noop_coroutine();

        // Final-suspend awaiter: keeps the frame alive (so the result can be
        // read) and hands control straight to the continuation.
        struct final_awaiter {
            bool await_ready() const noexcept { return false; }
            std::coroutine_handle<> await_suspend(std::coroutine_handle<promise_type> h) const noexcept {
                return h.promise().continuation;
            }
            void await_resume() const noexcept {}
        };

        Task get_return_object() {
            return Task{std::coroutine_handle<promise_type>::from_promise(*this)};
        }
        // Lazy start: the body only runs when awaited or when get() is called.
        std::suspend_always initial_suspend() noexcept { return {}; }
        final_awaiter final_suspend() noexcept { return {}; }
        void return_value(T val) { value = std::move(val); }
        void unhandled_exception() { exception = std::current_exception(); }
    };

private:
    std::coroutine_handle<promise_type> handle_;

    // Rethrows the exception captured by the body, or moves out its result.
    T take_result() {
        auto& promise = handle_.promise();
        if (promise.exception)
            std::rethrow_exception(promise.exception);
        return std::move(*promise.value);
    }

public:
    explicit Task(std::coroutine_handle<promise_type> h) : handle_(h) {}
    ~Task() { if (handle_) handle_.destroy(); }

    // Move only
    Task(Task&& other) noexcept : handle_(std::exchange(other.handle_, {})) {}
    Task& operator=(Task&& other) noexcept {
        if (this != &other) {
            if (handle_) handle_.destroy();
            handle_ = std::exchange(other.handle_, {});
        }
        return *this;
    }
    Task(const Task&) = delete;
    Task& operator=(const Task&) = delete;

    // Awaitable
    bool await_ready() const noexcept { return handle_.done(); }

    /// Records `awaiter` as the continuation and transfers control into this
    /// task; `awaiter` is resumed from final_suspend once the task finishes.
    std::coroutine_handle<> await_suspend(std::coroutine_handle<> awaiter) noexcept {
        handle_.promise().continuation = awaiter;
        return handle_;
    }

    T await_resume() { return take_result(); }

    /// Synchronously runs the task (if not finished yet) and returns its result.
    /// @throws any exception thrown by the coroutine body.
    /// @throws std::logic_error if the body suspended on an operation that has
    ///         not completed, so no result is available yet.
    T get() {
        if (!handle_.done()) handle_.resume();
        if (!handle_.done())
            throw std::logic_error("Task::get: coroutine suspended before completing");
        return take_result();
    }
};

// Task<void> specialization: same lazy, continuation-based protocol as
// Task<T>, but the coroutine completes without producing a value.
template<>
class [[nodiscard]] Task<void> {
public:
    struct promise_type {
        std::exception_ptr exception;
        // Coroutine to resume once this task finishes; a no-op when the task
        // is driven by get() rather than co_awaited.
        std::coroutine_handle<> continuation = std::noop_coroutine();

        // Final-suspend awaiter: keeps the frame alive and hands control
        // straight to the continuation.
        struct final_awaiter {
            bool await_ready() const noexcept { return false; }
            std::coroutine_handle<> await_suspend(std::coroutine_handle<promise_type> h) const noexcept {
                return h.promise().continuation;
            }
            void await_resume() const noexcept {}
        };

        Task get_return_object() {
            return Task{std::coroutine_handle<promise_type>::from_promise(*this)};
        }
        // Lazy start: the body only runs when awaited or when get() is called.
        std::suspend_always initial_suspend() noexcept { return {}; }
        final_awaiter final_suspend() noexcept { return {}; }
        void return_void() noexcept {}
        void unhandled_exception() { exception = std::current_exception(); }
    };

private:
    std::coroutine_handle<promise_type> handle_;

    // Rethrows the exception captured by the body, if any.
    void rethrow_if_failed() const {
        if (handle_.promise().exception)
            std::rethrow_exception(handle_.promise().exception);
    }

public:
    explicit Task(std::coroutine_handle<promise_type> h) : handle_(h) {}
    ~Task() { if (handle_) handle_.destroy(); }

    // Move only
    Task(Task&& other) noexcept : handle_(std::exchange(other.handle_, {})) {}
    Task& operator=(Task&& other) noexcept {
        if (this != &other) {
            if (handle_) handle_.destroy();
            handle_ = std::exchange(other.handle_, {});
        }
        return *this;
    }
    Task(const Task&) = delete;
    Task& operator=(const Task&) = delete;

    // Awaitable
    bool await_ready() const noexcept { return handle_.done(); }

    /// Records `awaiter` as the continuation and transfers control into this
    /// task; `awaiter` is resumed from final_suspend once the task finishes.
    std::coroutine_handle<> await_suspend(std::coroutine_handle<> awaiter) noexcept {
        handle_.promise().continuation = awaiter;
        return handle_;
    }

    void await_resume() { rethrow_if_failed(); }

    /// Synchronously runs the task (if not finished yet).
    /// @throws any exception thrown by the coroutine body.
    /// @throws std::logic_error if the body suspended on an operation that has
    ///         not completed yet.
    void get() {
        if (!handle_.done()) handle_.resume();
        if (!handle_.done())
            throw std::logic_error("Task::get: coroutine suspended before completing");
        rethrow_if_failed();
    }
};
} // namespace mcpplibs::llmapi