Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 18 additions & 13 deletions src/classes/baseprocessor.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import logging

from common.config import ServerConfig
from common.static import Static
from common.tools import Tools
from common.opcodes import Opcodes
from common.proto_tcp import MobileProto
from common.proto_web import WebProto
from common.opcodes import Opcodes
from common.static import Static
from common.tools import Tools


class BaseProcessor:
def __init__(self, db_pool=None, clients=None, send_event=None, type="socket"):
Expand All @@ -21,7 +23,7 @@ def __init__(self, db_pool=None, clients=None, send_event=None, type="socket"):
self.event = send_event
self.logger = logging.getLogger(__name__)

self.type = type
self.type = "mobile" if type == "socket" else type

if type == "socket":
self.proto = MobileProto()
Expand All @@ -31,23 +33,26 @@ def __init__(self, db_pool=None, clients=None, send_event=None, type="socket"):
async def _send(self, writer, packet):
try:
# Если объектом является вебсокет, то используем функцию send для отправки
if hasattr(writer, 'send'):
if hasattr(writer, "send"):
await writer.send(packet)
else: # В ином случае отправляем как в обычный сокет
else: # В ином случае отправляем как в обычный сокет
writer.write(packet)
await writer.drain()
except Exception:
pass

async def _send_error(self, seq, opcode, error_type, writer):
payload = self.static.ERROR_TYPES.get(error_type, {
"localizedMessage": "Неизвестная ошибка",
"error": "unknown.error",
"message": "Unknown error",
"title": "Неизвестная ошибка"
})
payload = self.static.ERROR_TYPES.get(
error_type,
{
"localizedMessage": "Неизвестная ошибка",
"error": "unknown.error",
"message": "Unknown error",
"title": "Неизвестная ошибка",
},
)

packet = self.proto.pack_packet(
cmd=self.proto.CMD_ERR, seq=seq, opcode=opcode, payload=payload
)
await self._send(writer, packet)
await self._send(writer, packet)
2 changes: 1 addition & 1 deletion src/common/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,4 +50,4 @@ def __init__(self):
origins = [x.strip() for x in os.getenv("origins", "").split(",") if x.strip()] if os.getenv("origins") else None

### sms шлюз
sms_gateway_url = os.getenv("sms_gateway_url") or "http://127.0.0.1/sms-gateway"
sms_gateway_url = os.getenv("sms_gateway_url", "")
43 changes: 30 additions & 13 deletions src/common/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,10 +96,12 @@ def generate_chat(
"""Генерация чата"""
# Генерируем список участников
if isinstance(participants, dict):
result_participants = {str(k): v for k, v in participants.items()}
result_participants = {
int(k): int(v) if v is not None else 0 for k, v in participants.items()
}
else:
# assume list
result_participants = {str(participant): 0 for participant in participants}
result_participants = {int(participant): 0 for participant in participants}

result = None

Expand All @@ -125,7 +127,14 @@ def generate_chat(
# Возвращаем
return result

async def generate_chats(self, chatIds, db_pool, senderId, include_favourites=True, protocol_type='mobile'):
async def generate_chats(
self,
chatIds,
db_pool,
senderId,
include_favourites=True,
protocol_type="mobile",
):
"""Генерирует чаты для отдачи клиенту"""
# Готовый список с чатами
chats = []
Expand Down Expand Up @@ -188,12 +197,14 @@ async def generate_chats(self, chatIds, db_pool, senderId, include_favourites=Tr
)

# Получаем ID предыдущего сообщения для избранного (чат ID = senderId)
prevMessageId = await self.get_previous_message_id(senderId, db_pool, protocol_type=protocol_type)
prevMessageId = await self.get_previous_message_id(
senderId, db_pool, protocol_type=protocol_type
)

# Хардкодим в лист чатов избранное
chats.append(
self.generate_chat(
chatId if protocol_type == 'mobile' else str(chatId),
chatId if protocol_type == "mobile" else str(chatId),
senderId,
"DIALOG",
participants,
Expand Down Expand Up @@ -241,7 +252,7 @@ async def insert_message(
# Возвращаем айдишки
return int(message_id), int(last_message_id), message_time

async def get_last_message(self, chatId, db_pool, protocol_type='mobile'):
async def get_last_message(self, chatId, db_pool, protocol_type="mobile"):
"""Получение последнего сообщения в чате"""
async with db_pool.acquire() as db_connection:
async with db_connection.cursor() as cursor:
Expand All @@ -259,7 +270,9 @@ async def get_last_message(self, chatId, db_pool, protocol_type='mobile'):

# Собираем сообщение
message = {
"id": row.get("id") if protocol_type == 'mobile' else str(row.get('id')),
"id": row.get("id")
if protocol_type == "mobile"
else str(row.get("id")),
"time": int(row.get("time")),
"type": row.get("type"),
"sender": row.get("sender"),
Expand All @@ -273,7 +286,7 @@ async def get_last_message(self, chatId, db_pool, protocol_type='mobile'):
# Возвращаем
return message, int(row.get("time"))

async def get_previous_message_id(self, chatId, db_pool, protocol_type='mobile'):
async def get_previous_message_id(self, chatId, db_pool, protocol_type="mobile"):
"""Получение ID предыдущего сообщения (второго с конца) в чате."""
async with db_pool.acquire() as db_connection:
async with db_connection.cursor() as cursor:
Expand All @@ -285,10 +298,14 @@ async def get_previous_message_id(self, chatId, db_pool, protocol_type='mobile')

# Если результат есть, возвращаем его
if row:
return row.get("id") if protocol_type == 'mobile' else str(row.get('id'))
return (
row.get("id")
if protocol_type == "mobile"
else str(row.get("id"))
)

# В ином случае возвращаем 0
return 0 if protocol_type == 'mobile' else "0"
return 0 if protocol_type == "mobile" else "0"

async def get_participant_last_activity(self, chatId, participant_ids, db_pool):
"""Возвращает словарь {participant_id: last_activity_time} для участников чата."""
Expand All @@ -310,11 +327,11 @@ async def get_participant_last_activity(self, chatId, participant_ids, db_pool):
rows = await cursor.fetchall()

# Собираем список участников без времени последней активности в чате
result = {str(pid): 0 for pid in participant_ids}
result = {int(pid): 0 for pid in participant_ids}

# Обновляем для каждого участника время последней активности в чате
for row in rows:
sender = str(row["sender"])
sender = int(row["sender"])
last_time = row["last_time"]
if last_time is not None:
result[sender] = int(last_time)
Expand All @@ -330,7 +347,7 @@ async def get_chat_participants(self, chatId, db_pool):
(chatId,),
)
rows = await cursor.fetchall()
return [row["user_id"] for row in rows]
return [int(row["user_id"]) for row in rows]

async def auth_required(self, userPhone, coro, *args):
if userPhone:
Expand Down
69 changes: 67 additions & 2 deletions src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,70 @@
server_config = ServerConfig()


class SQLiteCursorCompat:
def __init__(self, connection):
self.connection = connection
self.cursor = None

async def __aenter__(self):
self.cursor = await self.connection.cursor()
return self

async def __aexit__(self, exc_type, exc, tb):
if self.cursor is not None:
await self.cursor.close()
self.cursor = None

@property
def lastrowid(self):
return None if self.cursor is None else self.cursor.lastrowid

def _normalize_query(self, query):
return query.replace("%s", "?").replace(
"UNIX_TIMESTAMP()", "CAST(strftime('%s','now') AS INTEGER)"
)

async def execute(self, query, params=()):
normalized_query = self._normalize_query(query)
if params is None:
params = ()
elif not isinstance(params, (tuple, list, dict)):
params = (params,)
await self.cursor.execute(normalized_query, params)

async def fetchone(self):
row = await self.cursor.fetchone()
if row is None:
return None
return dict(row)

async def fetchall(self):
rows = await self.cursor.fetchall()
return [dict(row) for row in rows]


class SQLiteConnectionCompat:
def __init__(self, connection):
self.connection = connection

async def __aenter__(self):
return self

async def __aexit__(self, exc_type, exc, tb):
return False

def cursor(self):
return SQLiteCursorCompat(self.connection)


class SQLitePoolCompat:
def __init__(self, connection):
self.connection = connection

def acquire(self):
return SQLiteConnectionCompat(self.connection)


async def init_db():
"""Инициализация базы данных"""

Expand All @@ -32,8 +96,9 @@ async def init_db():
elif server_config.db_type == "sqlite":
import aiosqlite

raw_db = await aiosqlite.connect(server_config.db_file)
db["acquire"] = lambda: raw_db
raw_db = await aiosqlite.connect(server_config.db_file, isolation_level=None)
raw_db.row_factory = aiosqlite.Row
db = SQLitePoolCompat(raw_db)

# Возвращаем
return db
Expand Down
Loading