diff --git a/src/bot/handler/abstractHandler.py b/src/bot/handler/abstractHandler.py index c52118d..d686924 100755 --- a/src/bot/handler/abstractHandler.py +++ b/src/bot/handler/abstractHandler.py @@ -7,6 +7,7 @@ from telegram.update import Update from bot.message.inboundMessage import InboundMessage from bot.message.replier import Replier from exception.actionNotAllowedException import ActionNotAllowedException +from exception.invalidActionException import InvalidActionException from exception.invalidArgumentException import InvalidArgumentException from logger import Logger @@ -14,6 +15,7 @@ from logger import Logger class AbstractHandler: bot_handler: Handler inbound: InboundMessage + action: str @abstractmethod def handle(self, update: Update, context: CallbackContext) -> None: @@ -25,7 +27,8 @@ class AbstractHandler: self.inbound = InboundMessage.create(update, context, group_specific) self.handle(update, context) - except (ActionNotAllowedException, InvalidArgumentException) as e: + Logger.action(self.inbound, self.action) + except (InvalidActionException, InvalidArgumentException, ActionNotAllowedException) as e: Replier.markdown(update, str(e)) except Exception as e: Logger.exception(e) diff --git a/src/bot/handler/everyoneHandler.py b/src/bot/handler/everyoneHandler.py index 36ac603..a31feac 100755 --- a/src/bot/handler/everyoneHandler.py +++ b/src/bot/handler/everyoneHandler.py @@ -5,26 +5,28 @@ from telegram.update import Update from bot.handler.abstractHandler import AbstractHandler from bot.message.replier import Replier from config.contents import mention_failed +from exception.invalidActionException import InvalidActionException from exception.notFoundException import NotFoundException -from logger import Logger +from repository.chatRepository import ChatRepository from repository.userRepository import UserRepository from utils.messageBuilder import MessageBuilder class EveryoneHandler(AbstractHandler): bot_handler: CommandHandler + chat_repository: ChatRepository user_repository: UserRepository action: str = 'everyone' def __init__(self) -> None: self.bot_handler = CommandHandler(self.action, self.wrap) + self.chat_repository = ChatRepository() self.user_repository = UserRepository() def handle(self, update: Update, context: CallbackContext) -> None: try: - users = self.user_repository.get_all_for_chat(self.inbound.chat_id) + users = self.chat_repository.get_users_for_group(self.inbound.chat_id, self.inbound.group_name) Replier.markdown(update, MessageBuilder.mention_message(users)) - Logger.action(self.inbound, self.action) - except NotFoundException: - Replier.markdown(update, mention_failed) + except NotFoundException as e: + raise InvalidActionException(mention_failed) from e diff --git a/src/bot/handler/groupsHandler.py b/src/bot/handler/groupsHandler.py index deff1c4..9190b45 100644 --- a/src/bot/handler/groupsHandler.py +++ b/src/bot/handler/groupsHandler.py @@ -5,29 +5,30 @@ from telegram.update import Update from bot.handler.abstractHandler import AbstractHandler from bot.message.replier import Replier from config.contents import no_groups +from exception.invalidActionException import InvalidActionException from exception.notFoundException import NotFoundException -from logger import Logger -from repository.groupRepository import GroupRepository +from repository.chatRepository import ChatRepository from utils.messageBuilder import MessageBuilder class GroupsHandler(AbstractHandler): bot_handler: CommandHandler - group_repository: GroupRepository + chat_repository: ChatRepository action: str = 'groups' def __init__(self) -> None: self.bot_handler = CommandHandler(self.action, self.wrap) - self.group_repository = GroupRepository() + self.chat_repository = ChatRepository() def handle(self, update: Update, context: CallbackContext) -> None: try: - groups = self.group_repository.get_by_chat_id(self.inbound.chat_id) - Replier.html(update, MessageBuilder.group_message(groups)) + chat = self.chat_repository.get(self.inbound.chat_id) + if not chat.groups: + raise NotFoundException - Logger.action(self.inbound, self.action) + Replier.html(update, MessageBuilder.group_message(chat.groups)) except NotFoundException: - Replier.markdown(update, no_groups) + raise InvalidActionException(no_groups) def is_group_specific(self) -> bool: return False diff --git a/src/bot/handler/inlineQueryHandler.py b/src/bot/handler/inlineQueryHandler.py index f2a5d88..baa9cd3 100644 --- a/src/bot/handler/inlineQueryHandler.py +++ b/src/bot/handler/inlineQueryHandler.py @@ -6,7 +6,7 @@ from telegram.inline.inputtextmessagecontent import InputTextMessageContent from telegram.update import Update from bot.handler.abstractHandler import AbstractHandler -from entity.group import Group +from bot.message.inboundMessage import InboundMessage from exception.actionNotAllowedException import ActionNotAllowedException from validator.accessValidator import AccessValidator @@ -24,8 +24,8 @@ class InlineQueryHandler(AbstractHandler): update.inline_query.answer([]) return - group_display = update.inline_query.query or Group.default_name - group = '' if group_display == Group.default_name else group_display + group_display = update.inline_query.query or InboundMessage.default_group + group = '' if group_display == InboundMessage.default_group else group_display results = [ InlineQueryResultArticle( diff --git a/src/bot/handler/joinHandler.py b/src/bot/handler/joinHandler.py index d6f2067..92a8a69 100755 --- a/src/bot/handler/joinHandler.py +++ b/src/bot/handler/joinHandler.py @@ -5,8 +5,8 @@ from telegram.update import Update from bot.handler.abstractHandler import AbstractHandler from bot.message.replier import Replier from config.contents import joined, not_joined -from exception.notFoundException import NotFoundException -from logger import Logger +from exception.invalidActionException import InvalidActionException +from repository.chatRepository import ChatRepository from repository.userRepository import UserRepository @@ -18,18 +18,17 @@ class JoinHandler(AbstractHandler): def __init__(self) -> None: self.bot_handler = CommandHandler(self.action, self.wrap) self.user_repository = UserRepository() + self.chat_repository = ChatRepository() def handle(self, update: Update, context: CallbackContext) -> None: - try: - user = self.user_repository.get_by_id(self.inbound.user_id) + user = self.user_repository.provide(self.inbound) + chat = self.chat_repository.provide(self.inbound) + users = chat.groups.get(self.inbound.group_name) - if user.is_in_chat(self.inbound.chat_id): - return Replier.markdown(update, Replier.interpolate(not_joined, self.inbound)) + if user.user_id in users: + raise InvalidActionException(Replier.interpolate(not_joined, self.inbound)) - user.add_to_chat(self.inbound.chat_id) - self.user_repository.save(user) - except NotFoundException: - self.user_repository.save_by_inbound_message(self.inbound) + users.append(user.user_id) + self.chat_repository.save(chat) Replier.markdown(update, Replier.interpolate(joined, self.inbound)) - Logger.action(self.inbound, self.action) diff --git a/src/bot/handler/leaveHandler.py b/src/bot/handler/leaveHandler.py index e031596..2c29b4a 100755 --- a/src/bot/handler/leaveHandler.py +++ b/src/bot/handler/leaveHandler.py @@ -5,27 +5,34 @@ from telegram.update import Update from bot.handler.abstractHandler import AbstractHandler from bot.message.replier import Replier from config.contents import left, not_left -from exception.notFoundException import NotFoundException -from logger import Logger +from exception.invalidActionException import InvalidActionException from repository.userRepository import UserRepository +from repository.chatRepository import ChatRepository class LeaveHandler(AbstractHandler): bot_handler: CommandHandler user_repository: UserRepository + chat_repository: ChatRepository action: str = 'leave' def __init__(self) -> None: self.bot_handler = CommandHandler(self.action, self.wrap) self.user_repository = UserRepository() + self.chat_repository = ChatRepository() def handle(self, update: Update, context: CallbackContext) -> None: - try: - user = self.user_repository.get_by_id_and_chat_id(self.inbound.user_id, self.inbound.chat_id) - user.remove_from_chat(self.inbound.chat_id) - self.user_repository.save(user) + user = self.user_repository.provide(self.inbound) + chat = self.chat_repository.provide(self.inbound) + group = chat.groups.get(self.inbound.group_name) - Replier.markdown(update, Replier.interpolate(left, self.inbound)) - Logger.action(self.inbound, self.action) - except NotFoundException: - return Replier.markdown(update, Replier.interpolate(not_left, self.inbound)) + if user.user_id not in group: + raise InvalidActionException(Replier.interpolate(not_left, self.inbound)) + + group.remove(user.user_id) + if not group: + chat.groups.pop(self.inbound.group_name) + + self.chat_repository.save(chat) + + Replier.markdown(update, Replier.interpolate(left, self.inbound)) diff --git a/src/bot/message/inboundMessage.py b/src/bot/message/inboundMessage.py index 697f7c6..c0f5505 100644 --- a/src/bot/message/inboundMessage.py +++ b/src/bot/message/inboundMessage.py @@ -6,7 +6,6 @@ import names from telegram.ext.callbackcontext import CallbackContext from telegram.update import Update -from entity.group import Group from validator.accessValidator import AccessValidator from validator.groupNameValidator import GroupNameValidator @@ -18,22 +17,21 @@ class InboundMessage: group_name: str username: str + default_group: str = 'default' + @staticmethod def create(update: Update, context: CallbackContext, group_specific: bool) -> InboundMessage: user_id = str(update.effective_user.id) AccessValidator.validate(user_id) chat_id = str(update.effective_chat.id) - group_name = Group.default_name + group_name = InboundMessage.default_group if context.args and context.args[0] and group_specific: group_name = str(context.args[0]).lower() GroupNameValidator.validate(group_name) - if group_name is not Group.default_name: - chat_id += f'~{group_name}' - username = update.effective_user.username or update.effective_user.first_name if not username: diff --git a/src/database/client.py b/src/database/client.py index 31aa468..37d9f4e 100755 --- a/src/database/client.py +++ b/src/database/client.py @@ -30,10 +30,11 @@ class Client(metaclass=Singleton): def find_many(self, collection: str, filter: dict) -> dict: return self.database.get_collection(collection).find(filter) - def update_one(self, collection: str, filter: dict, data: dict) -> None: + def save(self, collection: str, filter: dict, data: dict) -> None: self.database.get_collection(collection).update_one( filter, - {"$set": data} + {"$set": data}, + upsert=True ) def aggregate(self, collection, pipeline: list): diff --git a/src/entity/chat.py b/src/entity/chat.py new file mode 100644 index 0000000..3d70978 --- /dev/null +++ b/src/entity/chat.py @@ -0,0 +1,32 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Iterable + +from bot.message.inboundMessage import InboundMessage + + +@dataclass +class Chat: + chat_id: str + groups: dict + + mongo_chat_id_index: str = '_id' + mongo_groups_index: str = 'groups' + + def to_mongo_document(self) -> dict: + return { + self.mongo_chat_id_index: self.chat_id, + self.mongo_groups_index: self.groups + } + + @staticmethod + def from_mongo_document(mongo_document: dict) -> Chat: + return Chat( + mongo_document[Chat.mongo_chat_id_index], + mongo_document[Chat.mongo_groups_index] + ) + + @staticmethod + def from_inbound_message(inbound: InboundMessage) -> Chat: + return Chat(inbound.chat_id, {inbound.group_name: []}) diff --git a/src/entity/group.py b/src/entity/group.py deleted file mode 100644 index 113202d..0000000 --- a/src/entity/group.py +++ /dev/null @@ -1,12 +0,0 @@ -from __future__ import annotations - -from dataclasses import dataclass - - -@dataclass -class Group: - chat_id: str - group_name: str - users_count: int - - default_name: str = 'default' diff --git a/src/entity/user.py b/src/entity/user.py index 75611db..c8b58b4 100644 --- a/src/entity/user.py +++ b/src/entity/user.py @@ -1,40 +1,31 @@ from __future__ import annotations from dataclasses import dataclass -from typing import List + +from bot.message.inboundMessage import InboundMessage @dataclass class User: user_id: str username: str - chats: List[str] - collection: str = 'users' - id_index: str = '_id' - chats_index: str = 'chats' - username_index: str = 'username' - - def is_in_chat(self, chat_id: str) -> bool: - return chat_id in self.chats - - def add_to_chat(self, chat_id: str) -> None: - self.chats.append(chat_id) - - def remove_from_chat(self, chat_id: str) -> None: - if chat_id in self.chats: - self.chats.remove(chat_id) + mongo_user_id_index: str = '_id' + mongo_username_index: str = 'username' def to_mongo_document(self) -> dict: return { - self.username_index: self.username, - self.chats_index: self.chats + self.mongo_user_id_index: self.user_id, + self.mongo_username_index: self.username } @staticmethod def from_mongo_document(mongo_document: dict) -> User: return User( - mongo_document[User.id_index], - mongo_document[User.username_index], - mongo_document[User.chats_index] + mongo_document[User.mongo_user_id_index], + mongo_document[User.mongo_username_index] ) + + @staticmethod + def from_inbound_message(inbound: InboundMessage) -> User: + return User(inbound.user_id, inbound.username) diff --git a/src/exception/invalidActionException.py b/src/exception/invalidActionException.py new file mode 100644 index 0000000..e7df135 --- /dev/null +++ b/src/exception/invalidActionException.py @@ -0,0 +1,2 @@ +class InvalidActionException(Exception): + pass \ No newline at end of file diff --git a/src/repository/abstractRepository.py b/src/repository/abstractRepository.py new file mode 100644 index 0000000..320caed --- /dev/null +++ b/src/repository/abstractRepository.py @@ -0,0 +1,9 @@ +from database.client import Client + + +class AbstractRepository: + collection_name: str + database_client: Client + + def __init__(self): + self.database_client = Client() diff --git a/src/repository/chatRepository.py b/src/repository/chatRepository.py new file mode 100644 index 0000000..c8786f3 --- /dev/null +++ b/src/repository/chatRepository.py @@ -0,0 +1,55 @@ +from typing import Iterable + +from bot.message.inboundMessage import InboundMessage +from entity.chat import Chat +from entity.user import User +from exception.notFoundException import NotFoundException +from repository.abstractRepository import AbstractRepository +from repository.userRepository import UserRepository + + +class ChatRepository(AbstractRepository): + collection_name: str = 'chats' + user_repository: UserRepository + + def __init__(self): + super().__init__() + self.user_repository = UserRepository() + + def provide(self, inbound: InboundMessage) -> Chat: + try: + chat = self.get(inbound.chat_id) + if not chat.groups.get(inbound.group_name): + chat.groups[inbound.group_name] = [] + except NotFoundException: + chat = Chat.from_inbound_message(inbound) + + return chat + + def get(self, chat_id: str) -> Chat: + chat = self.database_client.find_one( + self.collection_name, + { + Chat.mongo_chat_id_index: chat_id + } + ) + + if not chat: + raise NotFoundException + + return Chat.from_mongo_document(chat) + + def get_users_for_group(self, chat_id: str, group: str) -> Iterable[User]: + chat = self.get(chat_id) + if not chat.groups.get(group): + raise NotFoundException + + return [self.user_repository.get(user_id) for user_id in chat.groups.get(group)] + + def save(self, chat: Chat) -> None: + self.database_client.save( + self.collection_name, + {Chat.mongo_chat_id_index: chat.chat_id}, + chat.to_mongo_document() + ) + diff --git a/src/repository/groupRepository.py b/src/repository/groupRepository.py deleted file mode 100644 index 004683f..0000000 --- a/src/repository/groupRepository.py +++ /dev/null @@ -1,56 +0,0 @@ -import re -from typing import Iterable - -from database.client import Client -from entity.group import Group -from entity.user import User -from exception.notFoundException import NotFoundException - - -class GroupRepository: - client: Client - - count: str = 'count' - - def __init__(self) -> None: - self.client = Client() - - def get_by_chat_id(self, chat_id: str) -> Iterable[Group]: - groups = self.client.aggregate( - User.collection, - [ - {"$unwind": f'${User.chats_index}'}, - { - "$match": { - User.chats_index: {"$regex": re.compile(f'^{chat_id}.*$')}, - }, - }, - { - "$group": { - "_id": { - "$last": {"$split": [f'${User.chats_index}', "~"]}, - }, - self.count: {"$count": {}}, - }, - }, - { - "$sort": {'_id': 1} - } - ] - ) - - result = [] - for group in groups: - group_name = group['_id'] - - if group_name == chat_id: - group_name = Group.default_name - - result.append( - Group(chat_id, group_name, group[self.count]) - ) - - if not result: - raise NotFoundException - - return result diff --git a/src/repository/userRepository.py b/src/repository/userRepository.py index 0ea3716..60dff49 100644 --- a/src/repository/userRepository.py +++ b/src/repository/userRepository.py @@ -1,74 +1,43 @@ -from typing import Iterable - from bot.message.inboundMessage import InboundMessage -from database.client import Client from entity.user import User from exception.notFoundException import NotFoundException +from repository.abstractRepository import AbstractRepository -class UserRepository: - client: Client +class UserRepository(AbstractRepository): + collection_name: str = 'users' - def __init__(self) -> None: - self.client = Client() + def __init__(self): + super().__init__() - def get_by_id(self, user_id: str) -> User: - user = self.client.find_one( - User.collection, + def provide(self, inbound: InboundMessage) -> User: + user = User.from_inbound_message(inbound) + + try: + entity = self.get(user.user_id) + if entity != user: + self.save(user) + except NotFoundException: + self.save(user) + + return user + + def get(self, user_id: str) -> User: + user = self.database_client.find_one( + self.collection_name, { - User.id_index: user_id + User.mongo_user_id_index: user_id } ) if not user: - raise NotFoundException(f'Could not find user with "{user_id}" id') - - return User( - user[User.id_index], - user[User.username_index], - user[User.chats_index] - ) - - def get_by_id_and_chat_id(self, user_id: str, chat_id: str) -> User: - user = self.get_by_id(user_id) - - if not user.is_in_chat(chat_id): raise NotFoundException - return user + return User.from_mongo_document(user) def save(self, user: User) -> None: - self.client.update_one( - User.collection, - {User.id_index: user.user_id}, + self.database_client.save( + self.collection_name, + {User.mongo_user_id_index: user.user_id}, user.to_mongo_document() ) - - def save_by_inbound_message(self, inbound_message: InboundMessage) -> None: - self.client.insert_one( - User.collection, - { - User.id_index: inbound_message.user_id, - User.username_index: inbound_message.username, - User.chats_index: [inbound_message.chat_id] - } - ) - - def get_all_for_chat(self, chat_id: str) -> Iterable[User]: - result = [] - users = self.client.find_many( - User.collection, - { - User.chats_index: { - "$in": [chat_id] - } - } - ) - - for record in users: - result.append(User.from_mongo_document(record)) - - if not result: - raise NotFoundException - - return result diff --git a/src/utils/messageBuilder.py b/src/utils/messageBuilder.py index 396664e..4cfaf52 100644 --- a/src/utils/messageBuilder.py +++ b/src/utils/messageBuilder.py @@ -3,16 +3,16 @@ from typing import Iterable from prettytable import prettytable from telegram.utils.helpers import mention_markdown -from entity.group import Group from entity.user import User class MessageBuilder: @staticmethod - def group_message(groups: Iterable[Group]) -> str: + def group_message(groups: dict) -> str: table = prettytable.PrettyTable(['Name', 'Members']) - table.add_rows([[record.group_name, record.users_count] for record in groups]) + for group in groups: + table.add_row([group, len(groups[group])]) return f'
{str(table)}
'