Skip to content

Commit 4a22981

Browse files
Add unit tests
1 parent 4152604 commit 4a22981

File tree

4 files changed

+638
-0
lines changed

4 files changed

+638
-0
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,4 +8,5 @@ internal_filesystem/SDLPointer_3
88

99
# config files etc:
1010
internal_filesystem/data
11+
internal_filesystem/sdcard
1112

tests/manual_test_nostr_asyncio.py

Lines changed: 333 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,333 @@
1+
import asyncio
2+
import json
3+
import ssl
4+
import _thread
5+
import time
6+
import unittest
7+
8+
from mpos import App, PackageManager
9+
import mpos.apps
10+
11+
from nostr.relay_manager import RelayManager
12+
from nostr.message_type import ClientMessageType
13+
from nostr.filter import Filter, Filters
14+
from nostr.event import EncryptedDirectMessage
15+
from nostr.key import PrivateKey
16+
17+
18+
# keeps a list of items
19+
# The .add() method ensures the list remains unique (via __eq__)
20+
# and sorted (via __lt__) by inserting new items in the correct position.
21+
class UniqueSortedList:
22+
def __init__(self):
23+
self._items = []
24+
25+
def add(self, item):
26+
#print(f"before add: {str(self)}")
27+
# Check if item already exists (using __eq__)
28+
if item not in self._items:
29+
# Insert item in sorted position for descending order (using __gt__)
30+
for i, existing_item in enumerate(self._items):
31+
if item > existing_item:
32+
self._items.insert(i, item)
33+
return
34+
# If item is smaller than all existing items, append it
35+
self._items.append(item)
36+
#print(f"after add: {str(self)}")
37+
38+
def __iter__(self):
39+
# Return iterator for the internal list
40+
return iter(self._items)
41+
42+
def get(self, index_nr):
43+
# Retrieve item at given index, raise IndexError if invalid
44+
try:
45+
return self._items[index_nr]
46+
except IndexError:
47+
raise IndexError("Index out of range")
48+
49+
def __len__(self):
50+
# Return the number of items for len() calls
51+
return len(self._items)
52+
53+
def __str__(self):
54+
#print("UniqueSortedList tostring called")
55+
return "\n".join(str(item) for item in self._items)
56+
57+
def __eq__(self, other):
58+
if len(self._items) != len(other):
59+
return False
60+
return all(p1 == p2 for p1, p2 in zip(self._items, other))
61+
62+
# Payment class remains unchanged
63+
class Payment:
64+
def __init__(self, epoch_time, amount_sats, comment):
65+
self.epoch_time = epoch_time
66+
self.amount_sats = amount_sats
67+
self.comment = comment
68+
69+
def __str__(self):
70+
sattext = "sats"
71+
if self.amount_sats == 1:
72+
sattext = "sat"
73+
#return f"{self.amount_sats} {sattext} @ {self.epoch_time}: {self.comment}"
74+
return f"{self.amount_sats} {sattext}: {self.comment}"
75+
76+
def __eq__(self, other):
77+
if not isinstance(other, Payment):
78+
return False
79+
return self.epoch_time == other.epoch_time and self.amount_sats == other.amount_sats and self.comment == other.comment
80+
81+
def __lt__(self, other):
82+
if not isinstance(other, Payment):
83+
return NotImplemented
84+
return (self.epoch_time, self.amount_sats, self.comment) < (other.epoch_time, other.amount_sats, other.comment)
85+
86+
def __le__(self, other):
87+
if not isinstance(other, Payment):
88+
return NotImplemented
89+
return (self.epoch_time, self.amount_sats, self.comment) <= (other.epoch_time, other.amount_sats, other.comment)
90+
91+
def __gt__(self, other):
92+
if not isinstance(other, Payment):
93+
return NotImplemented
94+
return (self.epoch_time, self.amount_sats, self.comment) > (other.epoch_time, other.amount_sats, other.comment)
95+
96+
def __ge__(self, other):
97+
if not isinstance(other, Payment):
98+
return NotImplemented
99+
return (self.epoch_time, self.amount_sats, self.comment) >= (other.epoch_time, other.amount_sats, other.comment)
100+
101+
102+
103+
class TestNostr(unittest.TestCase):
104+
105+
PAYMENTS_TO_SHOW = 5
106+
107+
keep_running = None
108+
connected = None
109+
balance = -1
110+
payment_list = []
111+
transactions_welcome = False
112+
113+
relays = [ "ws://192.168.1.16:5000/nostrrelay/test", "ws://192.168.1.16:5000/nostrclient/api/v1/relay" ]
114+
#relays = [ "ws://127.0.0.1:5000/nostrrelay/test", "ws://127.0.0.1:5000/nostrclient/api/v1/relay" ]
115+
#relays = [ "wss://relay.damus.io", "wss://nostr-pub.wellorder.net" ]
116+
#relays = [ "ws://127.0.0.1:5000/nostrrelay/test", "ws://127.0.0.1:5000/nostrclient/api/v1/relay", "wss://relay.damus.io", "wss://nostr-pub.wellorder.net" ]
117+
#relays = [ "ws://127.0.0.1:5000/nostrclient/api/v1/relay", "wss://relay.damus.io", "wss://nostr-pub.wellorder.net" ]
118+
secret = "fab0a9a11d4cf4b1d92e901a0b2c56634275e2fa1a7eb396ff1b942f95d59fd3"
119+
wallet_pubkey = "e46762afab282c324278351165122345f9983ea447b47943b052100321227571"
120+
121+
async def fetch_balance(self):
122+
if not self.keep_running:
123+
return
124+
# Create get_balance request
125+
balance_request = {
126+
"method": "get_balance",
127+
"params": {}
128+
}
129+
print(f"DEBUG: Created balance request: {balance_request}")
130+
print(f"DEBUG: Creating encrypted DM to wallet pubkey: {self.wallet_pubkey}")
131+
dm = EncryptedDirectMessage(
132+
recipient_pubkey=self.wallet_pubkey,
133+
cleartext_content=json.dumps(balance_request),
134+
kind=23194
135+
)
136+
print(f"DEBUG: Signing DM {json.dumps(dm)} with private key")
137+
self.private_key.sign_event(dm) # sign also does encryption if it's a encrypted dm
138+
print(f"DEBUG: Publishing encrypted DM")
139+
self.relay_manager.publish_event(dm)
140+
141+
def handle_new_balance(self, new_balance, fetchPaymentsIfChanged=True):
142+
if not self.keep_running or new_balance is None:
143+
return
144+
if fetchPaymentsIfChanged: # Fetching *all* payments isn't necessary if balance was changed by a payment notification
145+
print("Refreshing payments...")
146+
self.fetch_payments() # if the balance changed, then re-list transactions
147+
148+
def fetch_payments(self):
149+
if not self.keep_running:
150+
return
151+
# Create get_balance request
152+
list_transactions = {
153+
"method": "list_transactions",
154+
"params": {
155+
"limit": self.PAYMENTS_TO_SHOW
156+
}
157+
}
158+
dm = EncryptedDirectMessage(
159+
recipient_pubkey=self.wallet_pubkey,
160+
cleartext_content=json.dumps(list_transactions),
161+
kind=23194
162+
)
163+
self.private_key.sign_event(dm) # sign also does encryption if it's a encrypted dm
164+
print("\nPublishing DM to fetch payments...")
165+
self.relay_manager.publish_event(dm)
166+
self.transactions_welcome = True
167+
168+
def handle_new_payments(self, new_payments):
169+
if not self.keep_running or not self.transactions_welcome:
170+
return
171+
print("handle_new_payments")
172+
if self.payment_list != new_payments:
173+
print("new list of payments")
174+
self.payment_list = new_payments
175+
self.payments_updated_cb()
176+
177+
def payments_updated_cb(self):
178+
print("payments_updated_cb called, now closing everything!")
179+
self.keep_running = False
180+
181+
def getCommentFromTransaction(self, transaction):
182+
comment = ""
183+
try:
184+
comment = transaction["description"]
185+
json_comment = json.loads(comment)
186+
for field in json_comment:
187+
if field[0] == "text/plain":
188+
comment = field[1]
189+
break
190+
else:
191+
print("text/plain field is missing from JSON description")
192+
except Exception as e:
193+
print(f"Info: could not parse comment as JSON, this is fine, using as-is ({e})")
194+
return comment
195+
196+
197+
async def NOmainHERE(self):
198+
self.keep_running = True
199+
self.private_key = PrivateKey(bytes.fromhex(self.secret))
200+
self.relay_manager = RelayManager()
201+
for relay in self.relays:
202+
self.relay_manager.add_relay(relay)
203+
204+
print(f"DEBUG: Opening relay connections")
205+
await self.relay_manager.open_connections({"cert_reqs": ssl.CERT_NONE})
206+
self.connected = False
207+
for _ in range(20):
208+
print("Waiting for relay connection...")
209+
await asyncio.sleep(0.5)
210+
nrconnected = 0
211+
for index, relay in enumerate(self.relays):
212+
try:
213+
relay = self.relay_manager.relays[self.relays[index]]
214+
if relay.connected is True:
215+
print(f"connected: {self.relays[index]}")
216+
nrconnected += 1
217+
else:
218+
print(f"not connected: {self.relays[index]}")
219+
except Exception as e:
220+
print(f"could not find relay: {e}")
221+
break # not all of them have been initialized, skip...
222+
self.connected = ( nrconnected == len(self.relays) )
223+
if self.connected:
224+
print("All relays connected!")
225+
break
226+
if not self.connected or not self.keep_running:
227+
print(f"ERROR: could not connect to relay or not self.keep_running, aborting...")
228+
# TODO: call an error callback to notify the user
229+
return
230+
231+
# Set up subscription to receive response
232+
self.subscription_id = "micropython_nwc_" + str(round(time.time()))
233+
print(f"DEBUG: Setting up subscription with ID: {self.subscription_id}")
234+
self.filters = Filters([Filter(
235+
#event_ids=[self.subscription_id], # would be nice to filter, but not like this
236+
kinds=[23195, 23196], # NWC reponses and notifications
237+
authors=[self.wallet_pubkey],
238+
pubkey_refs=[self.private_key.public_key.hex()]
239+
)])
240+
print(f"DEBUG: Subscription filters: {self.filters.to_json_array()}")
241+
self.relay_manager.add_subscription(self.subscription_id, self.filters)
242+
print(f"DEBUG: Creating subscription request")
243+
request_message = [ClientMessageType.REQUEST, self.subscription_id]
244+
request_message.extend(self.filters.to_json_array())
245+
print(f"DEBUG: Publishing subscription request")
246+
self.relay_manager.publish_message(json.dumps(request_message))
247+
print(f"DEBUG: Published subscription request")
248+
for _ in range(4):
249+
if not self.keep_running:
250+
return
251+
print("Waiting a bit before self.fetch_balance()")
252+
await asyncio.sleep(0.5)
253+
254+
await self.fetch_balance()
255+
256+
while True:
257+
print(f"checking for incoming events...")
258+
await asyncio.sleep(1)
259+
if not self.keep_running:
260+
print("NWCWallet: not keep_running, closing connections...")
261+
await self.relay_manager.close_connections()
262+
break
263+
264+
start_time = time.ticks_ms()
265+
if self.relay_manager.message_pool.has_events():
266+
print(f"DEBUG: Event received from message pool after {time.ticks_ms()-start_time}ms")
267+
event_msg = self.relay_manager.message_pool.get_event()
268+
event_created_at = event_msg.event.created_at
269+
print(f"Received at {time.localtime()} a message with timestamp {event_created_at} after {time.ticks_ms()-start_time}ms")
270+
try:
271+
# This takes a very long time, even for short messages:
272+
decrypted_content = self.private_key.decrypt_message(
273+
event_msg.event.content,
274+
event_msg.event.public_key,
275+
)
276+
print(f"DEBUG: Decrypted content: {decrypted_content} after {time.ticks_ms()-start_time}ms")
277+
response = json.loads(decrypted_content)
278+
print(f"DEBUG: Parsed response: {response}")
279+
result = response.get("result")
280+
if result:
281+
if result.get("balance") is not None:
282+
new_balance = round(int(result["balance"]) / 1000)
283+
print(f"Got balance: {new_balance}")
284+
self.handle_new_balance(new_balance)
285+
elif result.get("transactions") is not None:
286+
print("Response contains transactions!")
287+
new_payment_list = UniqueSortedList()
288+
for transaction in result["transactions"]:
289+
amount = transaction["amount"]
290+
amount = round(amount / 1000)
291+
comment = self.getCommentFromTransaction(transaction)
292+
epoch_time = transaction["created_at"]
293+
paymentObj = Payment(epoch_time, amount, comment)
294+
new_payment_list.add(paymentObj)
295+
if len(new_payment_list) > 0:
296+
# do them all in one shot instead of one-by-one because the lv_async() isn't always chronological,
297+
# so when a long list of payments is added, it may be overwritten by a short list
298+
self.handle_new_payments(new_payment_list)
299+
else:
300+
notification = response.get("notification")
301+
if notification:
302+
amount = notification["amount"]
303+
amount = round(amount / 1000)
304+
type = notification["type"]
305+
if type == "outgoing":
306+
amount = -amount
307+
elif type == "incoming":
308+
new_balance = self.last_known_balance + amount
309+
self.handle_new_balance(new_balance, False) # don't trigger full fetch because payment info is in notification
310+
epoch_time = notification["created_at"]
311+
comment = self.getCommentFromTransaction(notification)
312+
paymentObj = Payment(epoch_time, amount, comment)
313+
self.handle_new_payment(paymentObj)
314+
else:
315+
print(f"WARNING: invalid notification type {type}, ignoring.")
316+
else:
317+
print("Unsupported response, ignoring.")
318+
except Exception as e:
319+
print(f"DEBUG: Error processing response: {e}")
320+
else:
321+
#print(f"pool has no events after {time.ticks_ms()-start_time}ms") # completes in 0-1ms
322+
pass
323+
324+
def test_it(self):
325+
print("before do_two")
326+
asyncio.run(self.do_two())
327+
print("after do_two")
328+
329+
def do_two(self):
330+
print("before await self.NOmainHERE()")
331+
await self.NOmainHERE()
332+
print("after await self.NOmainHERE()")
333+

tests/manual_test_nwcwallet.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
import asyncio
2+
import json
3+
import ssl
4+
import _thread
5+
import time
6+
import unittest
7+
8+
from mpos import App, PackageManager
9+
import mpos.apps
10+
11+
import sys
12+
sys.path.append("apps/com.lightningpiggy.displaywallet/assets/")
13+
from wallet import NWCWallet
14+
15+
class TestNWCWallet(unittest.TestCase):
16+
17+
redraw_balance_cb_called = 0
18+
redraw_payments_cb_called = 0
19+
redraw_static_receive_code_cb_called = 0
20+
error_callback_called = 0
21+
22+
def redraw_balance_cb(self, balance=0):
23+
print(f"redraw_callback called, balance: {balance}")
24+
self.redraw_balance_cb_called += 1
25+
26+
def redraw_payments_cb(self):
27+
print(f"redraw_payments_cb called")
28+
self.redraw_payments_cb_called += 1
29+
30+
def redraw_static_receive_code_cb(self):
31+
print(f"redraw_static_receive_code_cb called")
32+
self.redraw_static_receive_code_cb_called += 1
33+
34+
def error_callback(self, error):
35+
print(f"error_callback called, error: {error}")
36+
self.error_callback_called += 1
37+
38+
def test_it(self):
39+
print("starting test")
40+
self.wallet = NWCWallet("nostr+walletconnect://e46762afab282c324278351165122345f9983ea447b47943b052100321227571?relay=ws://192.168.1.16:5000/nostrclient/api/v1/relay&secret=fab0a9a11d4cf4b1d92e901a0b2c56634275e2fa1a7eb396ff1b942f95d59fd3&lud16=test@example.com")
41+
self.wallet.start(self.redraw_balance_cb, self.redraw_payments_cb, self.redraw_static_receive_code_cb, self.error_callback)
42+
time.sleep(15)
43+
self.assertTrue(self.redraw_balance_cb_called > 0)
44+
self.assertTrue(self.redraw_payments_cb_called > 0)
45+
self.assertTrue(self.redraw_static_receive_code_cb_called > 0)
46+
self.assertTrue(self.error_callback_called == 0)
47+
print("test finished")
48+
49+

0 commit comments

Comments
 (0)