Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions telegram/ext/basepersistence.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,12 @@ class BasePersistence(object):
must overwrite :meth:`get_conversations` and :meth:`update_conversation`.
* :meth:`flush` will be called when the bot is shutdown.

Note:
It may be benifitial to check if data has changed, before persisting it. Therefore
:meth:`get_chat_data`, :meth:`get_user_data` and :meth:`get_conversations` should *not*
return the data stored in the instance of your persistence class but rather a deep copy
of it.

Attributes:
store_user_data (:obj:`bool`): Optional, Whether user_data should be saved by this
persistence class.
Expand Down
50 changes: 38 additions & 12 deletions telegram/ext/dictpersistence.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
except ImportError:
import json
from collections import defaultdict
from copy import deepcopy
from telegram.ext import BasePersistence


Expand All @@ -36,22 +37,26 @@ class DictPersistence(BasePersistence):
persistence class.
store_chat_data (:obj:`bool`): Whether chat_data should be saved by this
persistence class.
on_update (:obj:`bool`): Optional. When ``True`` will only save to file, if data has
changed. When ``False`` will save to file on every update. Default is ``False``.

Args:
store_user_data (:obj:`bool`, optional): Whether user_data should be saved by this
persistence class. Default is ``True``.
store_chat_data (:obj:`bool`, optional): Whether user_data should be saved by this
store_chat_data (:obj:`bool`, optional): Whether chat_data should be saved by this
persistence class. Default is ``True``.
user_data_json (:obj:`str`, optional): Json string that will be used to reconstruct
user_data on creating this persistence. Default is ``""``.
chat_data_json (:obj:`str`, optional): Json string that will be used to reconstruct
chat_data on creating this persistence. Default is ``""``.
conversations_json (:obj:`str`, optional): Json string that will be used to reconstruct
conversation on creating this persistence. Default is ``""``.
on_update (:obj:`bool`, optional): When ``True`` will only save to file, if data has
changed. When ``False`` will save to file on every update. Default is ``False``.
"""

def __init__(self, store_user_data=True, store_chat_data=True, user_data_json='',
chat_data_json='', conversations_json=''):
chat_data_json='', conversations_json='', on_update=False):
self.store_user_data = store_user_data
self.store_chat_data = store_chat_data
self._user_data = None
Expand All @@ -60,6 +65,7 @@ def __init__(self, store_user_data=True, store_chat_data=True, user_data_json=''
self._user_data_json = None
self._chat_data_json = None
self._conversations_json = None
self.on_update = on_update
if user_data_json:
try:
self._user_data = decode_user_chat_data_from_json(user_data_json)
Expand Down Expand Up @@ -129,7 +135,10 @@ def get_user_data(self):
pass
else:
self._user_data = defaultdict(dict)
return self.user_data.copy()
if self.on_update:
return deepcopy(self.user_data)
else:
return self.user_data

def get_chat_data(self):
"""Returns the chat_data created from the ``chat_data_json`` or an empty defaultdict.
Expand All @@ -141,7 +150,10 @@ def get_chat_data(self):
pass
else:
self._chat_data = defaultdict(dict)
return self.chat_data.copy()
if self.on_update:
return deepcopy(self.chat_data)
else:
return self.chat_data

def get_conversations(self, name):
"""Returns the conversations created from the ``conversations_json`` or an empty
Expand All @@ -154,7 +166,10 @@ def get_conversations(self, name):
pass
else:
self._conversations = {}
return self.conversations.get(name, {}).copy()
if self.on_update:
return deepcopy(self.conversations.get(name, {}))
else:
return self.conversations.get(name, {})

def update_conversation(self, name, key, new_state):
"""Will update the conversations for the given handler.
Expand All @@ -164,9 +179,12 @@ def update_conversation(self, name, key, new_state):
key (:obj:`tuple`): The key the state is changed for.
new_state (:obj:`tuple` | :obj:`any`): The new state for the given key.
"""
if self._conversations.setdefault(name, {}).get(key) == new_state:
return
self._conversations[name][key] = new_state
if self.on_update:
if self._conversations.setdefault(name, {}).get(key) == new_state:
return
self._conversations[name][key] = deepcopy(new_state)
else:
self._conversations[name][key] = new_state
self._conversations_json = None

def update_user_data(self, user_id, data):
Expand All @@ -176,8 +194,12 @@ def update_user_data(self, user_id, data):
user_id (:obj:`int`): The user the data might have been changed for.
data (:obj:`dict`): The :attr:`telegram.ext.dispatcher.user_data`[user_id].
"""
if self._user_data.get(user_id) == data:
return
if self.on_update:
if self._user_data.get(user_id) == data:
return
self._user_data[user_id] = deepcopy(data)
else:
self._user_data[user_id] = data
self._user_data[user_id] = data
self._user_data_json = None

Expand All @@ -188,7 +210,11 @@ def update_chat_data(self, chat_id, data):
chat_id (:obj:`int`): The chat the data might have been changed for.
data (:obj:`dict`): The :attr:`telegram.ext.dispatcher.chat_data`[chat_id].
"""
if self._chat_data.get(chat_id) == data:
return
if self.on_update:
if self._chat_data.get(chat_id) == data:
return
self._chat_data[chat_id] = deepcopy(data)
else:
self._chat_data[chat_id] = data
self._chat_data[chat_id] = data
self._chat_data_json = None
52 changes: 38 additions & 14 deletions telegram/ext/picklepersistence.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
"""This module contains the PicklePersistence class."""
import pickle
from collections import defaultdict
from copy import deepcopy

from telegram.ext import BasePersistence

Expand All @@ -31,14 +32,16 @@ class PicklePersistence(BasePersistence):
is false this will be used as a prefix.
store_user_data (:obj:`bool`): Optional. Whether user_data should be saved by this
persistence class.
store_chat_data (:obj:`bool`): Optional. Whether user_data should be saved by this
store_chat_data (:obj:`bool`): Optional. Whether chat_data should be saved by this
persistence class.
single_file (:obj:`bool`): Optional. When ``False`` will store 3 sperate files of
`filename_user_data`, `filename_chat_data` and `filename_conversations`. Default is
``True``.
on_flush (:obj:`bool`): Optional. When ``True`` will only save to file when :meth:`flush`
is called and keep data in memory until that happens. When False will store data on any
transaction. Default is ``False``.
on_update (:obj:`bool`): Optional. When ``True`` will only save to file, if data has
changed. When ``False`` will save to file on every update. Default is ``False``.

Args:
filename (:obj:`str`): The filename for storing the pickle files. When :attr:`single_file`
Expand All @@ -53,15 +56,18 @@ class PicklePersistence(BasePersistence):
on_flush (:obj:`bool`, optional): When ``True`` will only save to file when :meth:`flush`
is called and keep data in memory until that happens. When False will store data on any
transaction. Default is ``False``.
on_update (:obj:`bool`, optional): When ``True`` will only save to file, if data has
changed. When ``False`` will save to file on every update. Default is ``False``.
"""

def __init__(self, filename, store_user_data=True, store_chat_data=True, singe_file=True,
on_flush=False):
on_flush=False, on_update=False):
self.filename = filename
self.store_user_data = store_user_data
self.store_chat_data = store_chat_data
self.single_file = singe_file
self.on_flush = on_flush
self.on_update = on_update
self.user_data = None
self.chat_data = None
self.conversations = None
Expand Down Expand Up @@ -122,7 +128,10 @@ def get_user_data(self):
self.user_data = data
else:
self.load_singlefile()
return self.user_data.copy()
if self.on_update:
return deepcopy(self.user_data)
else:
return self.user_data

def get_chat_data(self):
"""Returns the chat_data from the pickle file if it exsists or an empty defaultdict.
Expand All @@ -142,7 +151,10 @@ def get_chat_data(self):
self.chat_data = data
else:
self.load_singlefile()
return self.chat_data.copy()
if self.on_update:
return deepcopy(self.chat_data)
else:
return self.chat_data

def get_conversations(self, name):
"""Returns the conversations from the pickle file if it exsists or an empty defaultdict.
Expand All @@ -163,7 +175,10 @@ def get_conversations(self, name):
self.conversations = data
else:
self.load_singlefile()
return self.conversations.get(name, {}).copy()
if self.on_update:
return deepcopy(self.conversations.get(name, {}))
else:
return self.conversations.get(name, {})

def update_conversation(self, name, key, new_state):
"""Will update the conversations for the given handler and depending on :attr:`on_flush`
Expand All @@ -174,9 +189,12 @@ def update_conversation(self, name, key, new_state):
key (:obj:`tuple`): The key the state is changed for.
new_state (:obj:`tuple` | :obj:`any`): The new state for the given key.
"""
if self.conversations.setdefault(name, {}).get(key) == new_state:
return
self.conversations[name][key] = new_state
if self.on_update:
if self.conversations.setdefault(name, {}).get(key) == new_state:
return
self.conversations[name][key] = deepcopy(new_state)
else:
self.conversations[name][key] = new_state
if not self.on_flush:
if not self.single_file:
filename = "{}_conversations".format(self.filename)
Expand All @@ -192,9 +210,12 @@ def update_user_data(self, user_id, data):
user_id (:obj:`int`): The user the data might have been changed for.
data (:obj:`dict`): The :attr:`telegram.ext.dispatcher.user_data`[user_id].
"""
if self.user_data.get(user_id) == data:
return
self.user_data[user_id] = data
if self.on_update:
if self.user_data.get(user_id) == data:
return
self.user_data[user_id] = deepcopy(data)
else:
self.user_data[user_id] = data
if not self.on_flush:
if not self.single_file:
filename = "{}_user_data".format(self.filename)
Expand All @@ -210,9 +231,12 @@ def update_chat_data(self, chat_id, data):
chat_id (:obj:`int`): The chat the data might have been changed for.
data (:obj:`dict`): The :attr:`telegram.ext.dispatcher.chat_data`[chat_id].
"""
if self.chat_data.get(chat_id) == data:
return
self.chat_data[chat_id] = data
if self.on_update:
if self.chat_data.get(chat_id) == data:
return
self.chat_data[chat_id] = deepcopy(data)
else:
self.chat_data[chat_id] = data
if not self.on_flush:
if not self.single_file:
filename = "{}_chat_data".format(self.filename)
Expand Down
37 changes: 22 additions & 15 deletions tests/test_persistence.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,8 @@ def pickle_persistence():
store_user_data=True,
store_chat_data=True,
singe_file=False,
on_flush=False)
on_flush=False,
on_update=True)


@pytest.fixture(scope='function')
Expand Down Expand Up @@ -598,7 +599,7 @@ def conversations_json(conversations):

class TestDictPersistence(object):
def test_no_json_given(self):
dict_persistence = DictPersistence()
dict_persistence = DictPersistence(on_update=True)
assert dict_persistence.get_user_data() == defaultdict(dict)
assert dict_persistence.get_chat_data() == defaultdict(dict)
assert dict_persistence.get_conversations('noname') == {}
Expand All @@ -608,27 +609,28 @@ def test_bad_json_string_given(self):
bad_chat_data = 'thisisnojson99900()))('
bad_conversations = 'thisisnojson99900()))('
with pytest.raises(TypeError, match='user_data'):
DictPersistence(user_data_json=bad_user_data)
DictPersistence(user_data_json=bad_user_data, on_update=True)
with pytest.raises(TypeError, match='chat_data'):
DictPersistence(chat_data_json=bad_chat_data)
DictPersistence(chat_data_json=bad_chat_data, on_update=True)
with pytest.raises(TypeError, match='conversations'):
DictPersistence(conversations_json=bad_conversations)
DictPersistence(conversations_json=bad_conversations, on_update=True)

def test_invalid_json_string_given(self, pickle_persistence, bad_pickle_files):
bad_user_data = '["this", "is", "json"]'
bad_chat_data = '["this", "is", "json"]'
bad_conversations = '["this", "is", "json"]'
with pytest.raises(TypeError, match='user_data'):
DictPersistence(user_data_json=bad_user_data)
DictPersistence(user_data_json=bad_user_data, on_update=True)
with pytest.raises(TypeError, match='chat_data'):
DictPersistence(chat_data_json=bad_chat_data)
DictPersistence(chat_data_json=bad_chat_data, on_update=True)
with pytest.raises(TypeError, match='conversations'):
DictPersistence(conversations_json=bad_conversations)
DictPersistence(conversations_json=bad_conversations, on_update=True)

def test_good_json_input(self, user_data_json, chat_data_json, conversations_json):
dict_persistence = DictPersistence(user_data_json=user_data_json,
chat_data_json=chat_data_json,
conversations_json=conversations_json)
conversations_json=conversations_json,
on_update=True)
user_data = dict_persistence.get_user_data()
assert isinstance(user_data, defaultdict)
assert user_data[12345]['test1'] == 'test2'
Expand Down Expand Up @@ -658,15 +660,17 @@ def test_dict_outputs(self, user_data, user_data_json, chat_data, chat_data_json
conversations, conversations_json):
dict_persistence = DictPersistence(user_data_json=user_data_json,
chat_data_json=chat_data_json,
conversations_json=conversations_json)
conversations_json=conversations_json,
on_update=True)
assert dict_persistence.user_data == user_data
assert dict_persistence.chat_data == chat_data
assert dict_persistence.conversations == conversations

def test_json_outputs(self, user_data_json, chat_data_json, conversations_json):
dict_persistence = DictPersistence(user_data_json=user_data_json,
chat_data_json=chat_data_json,
conversations_json=conversations_json)
conversations_json=conversations_json,
on_update=True)
assert dict_persistence.user_data_json == user_data_json
assert dict_persistence.chat_data_json == chat_data_json
assert dict_persistence.conversations_json == conversations_json
Expand All @@ -675,7 +679,8 @@ def test_json_changes(self, user_data, user_data_json, chat_data, chat_data_json
conversations, conversations_json):
dict_persistence = DictPersistence(user_data_json=user_data_json,
chat_data_json=chat_data_json,
conversations_json=conversations_json)
conversations_json=conversations_json,
on_update=True)
user_data_two = user_data.copy()
user_data_two.update({4: {5: 6}})
dict_persistence.update_user_data(4, {5: 6})
Expand All @@ -699,7 +704,7 @@ def test_json_changes(self, user_data, user_data_json, chat_data, chat_data_json
conversations_two)

def test_with_handler(self, bot, update):
dict_persistence = DictPersistence()
dict_persistence = DictPersistence(on_update=True)
u = Updater(bot=bot, persistence=dict_persistence)
dp = u.dispatcher

Expand Down Expand Up @@ -727,15 +732,17 @@ def second(bot, update, user_data, chat_data):
chat_data = dict_persistence.chat_data_json
del (dict_persistence)
dict_persistence_2 = DictPersistence(user_data_json=user_data,
chat_data_json=chat_data)
chat_data_json=chat_data,
on_update=True)

u = Updater(bot=bot, persistence=dict_persistence_2)
dp = u.dispatcher
dp.add_handler(h2)
dp.process_update(update)

def test_with_conversationHandler(self, dp, update, conversations_json):
dict_persistence = DictPersistence(conversations_json=conversations_json)
dict_persistence = DictPersistence(conversations_json=conversations_json,
on_update=True)
dp.persistence = dict_persistence
NEXT, NEXT2 = range(2)

Expand Down