Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
"""Add sender_address column and index

Revision ID: a18051177947
Revises: dfc5f95e4fe6
Create Date: 2025-07-29 14:28:52.871778

"""
from alembic import op
import sqlalchemy as sa

# revision identifiers, used by Alembic.
revision = 'dfc5f95e4fe6'
down_revision = '8ece21fbeb47'
branch_labels = None
depends_on = None


def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.add_column('pending_messages', sa.Column('content_address', sa.String(), sa.Computed("content->>'address'", persisted=True), nullable=True))
op.create_index(op.f('ix_pending_messages_content_address'), 'pending_messages', ['content_address'], unique=False)
op.create_index('ix_pending_messages_content_address_attempt', 'pending_messages', ['content_address', 'next_attempt'], unique=False)
# ### end Alembic commands ###


def downgrade() -> None:
op.drop_index('ix_pending_messages_content_address_attempt', table_name='pending_messages')
op.drop_index(op.f('ix_pending_messages_content_address'), table_name='pending_messages')
op.drop_column('pending_messages', 'content_address')
9 changes: 9 additions & 0 deletions src/aleph/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,11 @@ def get_defaults():
# Maximum number of chain/sync events processed at the same time.
"max_concurrency": 20,
},
"message_workers": {
# Number of message worker processes to start
"count": 5,
"message_count": 40, # number of message to fetch by worker
},
"cron": {
# Interval between cron job trackers runs, expressed in hours.
"period": 0.5, # 30 mins
Expand Down Expand Up @@ -198,6 +203,10 @@ def get_defaults():
"pending_message_exchange": "aleph-pending-messages",
# Name of the RabbitMQ exchange used for sync/message events (input of the TX processor).
"pending_tx_exchange": "aleph-pending-txs",
# Name of RabbotMQ exchange used for message processing
"message_processing_exchange": "aleph.processing",
# Name of RabbotMQ exchange used for result of message processing
"message_result_exchange": "aleph.results",
},
"redis": {
# Hostname of the Redis service.
Expand Down
146 changes: 140 additions & 6 deletions src/aleph/db/accessors/pending_messages.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import datetime as dt
from typing import Any, Collection, Dict, Iterable, Optional, Sequence
from typing import Any, Collection, Dict, Iterable, List, Optional, Sequence, Set, Tuple

from aleph_message.models import Chain
from sqlalchemy import delete, func, select, update
from sqlalchemy.orm import selectinload
from sqlalchemy import delete, func, select, text, update
from sqlalchemy.orm import selectinload, undefer
from sqlalchemy.sql import Update

from aleph.db.models import ChainTxDb, PendingMessageDb
Expand Down Expand Up @@ -79,10 +79,20 @@ async def get_pending_messages(
async def get_pending_message(
session: AsyncDbSession, pending_message_id: int
) -> Optional[PendingMessageDb]:
select_stmt = select(PendingMessageDb).where(
PendingMessageDb.id == pending_message_id
stmt = (
select(PendingMessageDb)
.where(PendingMessageDb.id == pending_message_id)
.options(selectinload(PendingMessageDb.tx), undefer("*"))
.execution_options(populate_existing=True)
)
return (await session.execute(select_stmt)).scalar_one_or_none()

result = await session.execute(stmt)
pending = result.scalar_one_or_none()

if pending is not None:
await session.refresh(pending, attribute_names=None)

return pending


async def count_pending_messages(
Expand Down Expand Up @@ -134,3 +144,127 @@ async def delete_pending_message(
await session.execute(
delete(PendingMessageDb).where(PendingMessageDb.id == pending_message.id)
)


async def get_next_pending_messages_from_different_senders(
session: AsyncDbSession,
current_time: dt.datetime,
fetched: bool = True,
exclude_item_hashes: Optional[Set[str]] = None,
exclude_addresses: Optional[Set[str]] = None,
limit: int = 40,
) -> List[PendingMessageDb]:
"""
Optimized query using content_address and indexed sorting.
"""

sql_parts = [
"SELECT DISTINCT ON (content_address) *",
"FROM pending_messages",
"WHERE next_attempt <= :current_time",
"AND fetched = :fetched",
"AND content IS NOT NULL",
"AND content_address IS NOT NULL",
]

params = {
"current_time": current_time,
"fetched": fetched,
"limit": limit,
}

if exclude_item_hashes:
hash_keys = []
for i, h in enumerate(exclude_item_hashes):
key = f"exclude_hash_{i}"
hash_keys.append(f":{key}")
params[key] = h
sql_parts.append(f"AND item_hash NOT IN ({', '.join(hash_keys)})")

if exclude_addresses:
addr_keys = []
for i, a in enumerate(exclude_addresses):
key = f"exclude_addr_{i}"
addr_keys.append(f":{key}")
params[key] = a
sql_parts.append(f"AND content_address NOT IN ({', '.join(addr_keys)})")

sql_parts.append("ORDER BY content_address, next_attempt")
sql_parts.append("LIMIT :limit")

stmt = (
select(PendingMessageDb)
.from_statement(text("\n".join(sql_parts)))
.params(**params)
)
result = await session.execute(stmt)
return result.scalars().all()


async def get_sender_with_pending_batch(
session,
batch_size: int,
exclude_addresses: Set[str],
exclude_item_hashes: Set[str],
current_time: dt.datetime,
candidate_senders: Optional[Set[str]] = None,
) -> Optional[Tuple[str, List[PendingMessageDb]]]:
"""
Finds the best sender to process a batch from.
Priority: sender with most pending messages, then oldest pending message.
"""

conditions = [
PendingMessageDb.next_attempt <= current_time,
PendingMessageDb.fetched.is_(True),
PendingMessageDb.content.isnot(None),
PendingMessageDb.content_address.isnot(None),
~PendingMessageDb.content_address.in_(exclude_addresses),
~PendingMessageDb.item_hash.in_(exclude_item_hashes),
]

if candidate_senders:
conditions.append(PendingMessageDb.content_address.in_(candidate_senders))

# Step 1: Find sender with most pending messages, then oldest attempt
subquery = (
select(
PendingMessageDb.content_address,
func.count().label("msg_count"),
func.min(PendingMessageDb.next_attempt).label("oldest_attempt"),
)
.where(*conditions)
.group_by(PendingMessageDb.content_address)
.order_by(
func.count().desc(), # Most messages
func.min(PendingMessageDb.next_attempt).asc(), # Oldest message
)
.limit(1)
.subquery()
)

sender_result = await session.execute(select(subquery.c.content_address))
row = sender_result.first()
if not row:
return None

sender = row[0]

# Step 2: Fetch batch of messages from that sender
messages_query = (
select(PendingMessageDb)
.where(
PendingMessageDb.content_address == sender,
PendingMessageDb.next_attempt <= current_time,
PendingMessageDb.fetched.is_(True),
PendingMessageDb.content.isnot(None),
~PendingMessageDb.item_hash.in_(exclude_item_hashes),
)
.order_by(PendingMessageDb.next_attempt.asc())
.limit(batch_size)
)

result = await session.execute(messages_query)
messages = result.scalars().all()

return sender, messages
9 changes: 9 additions & 0 deletions src/aleph/db/models/pending_messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
Boolean,
CheckConstraint,
Column,
Computed,
ForeignKey,
Index,
Integer,
Expand Down Expand Up @@ -69,11 +70,19 @@ class PendingMessageDb(Base):
fetched: bool = Column(Boolean, nullable=False)
origin: Optional[str] = Column(String, nullable=True, default=MessageOrigin.P2P)

content_address: Optional[str] = Column(
String, Computed("content->>'address'", persisted=True), index=True
)
__table_args__ = (
CheckConstraint(
"signature is not null or not check_message",
name="signature_not_null_if_check_message",
),
Index(
"ix_pending_messages_content_address_attempt",
"content_address",
"next_attempt",
),
UniqueConstraint("sender", "item_hash", "signature", name="uq_pending_message"),
)

Expand Down
12 changes: 9 additions & 3 deletions src/aleph/handlers/message_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
from aleph.handlers.content.post import PostMessageHandler
from aleph.handlers.content.store import StoreMessageHandler
from aleph.handlers.content.vm import VmMessageHandler
from aleph.schemas.api.messages import PendingMessage, format_message
from aleph.schemas.pending_messages import parse_message
from aleph.storage import StorageService
from aleph.toolkit.timestamp import timestamp_to_datetime
Expand Down Expand Up @@ -414,7 +415,10 @@ async def process(
existing_message=existing_message,
pending_message=pending_message,
)
return ProcessedMessage(message=existing_message, is_confirmation=True)
# We parse to dict since it's will pass on rabbitmq (at this points we don't need anymore to have DB objects)
return ProcessedMessage(
message=format_message(existing_message), is_confirmation=True
)

# Note: Check if message is already forgotten (and confirm it)
# this is to avoid race conditions when a confirmation arrives after the FORGET message has been preocessed
Expand All @@ -428,7 +432,9 @@ async def process(
pending_message=pending_message,
)
return RejectedMessage(
pending_message=pending_message,
pending_message=PendingMessage.model_validate(
pending_message.to_dict()
),
error_code=ErrorCode.FORGOTTEN_DUPLICATE,
)

Expand Down Expand Up @@ -456,7 +462,7 @@ async def process(
await content_handler.process(session=session, messages=[message])

return ProcessedMessage(
message=message,
message=format_message(message),
origin=(
MessageOrigin(pending_message.origin)
if pending_message.origin
Expand Down
20 changes: 20 additions & 0 deletions src/aleph/jobs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import Coroutine, List

from aleph.jobs.fetch_pending_messages import fetch_pending_messages_subprocess
from aleph.jobs.message_worker import message_worker_subprocess
from aleph.jobs.process_pending_messages import (
fetch_and_process_messages_task,
pending_messages_subprocess,
Expand Down Expand Up @@ -38,6 +39,25 @@ def start_jobs(
target=pending_txs_subprocess,
args=(config_values,),
)

num_workers = (
config.aleph.jobs.message_workers.count.value
if hasattr(config.aleph.jobs, "message_workers")
and hasattr(config.aleph.jobs.message_workers, "count")
else 5
)
LOGGER.info(f"Starting {num_workers} message worker processes")
worker_processes = []
for i in range(num_workers):
worker_id = f"worker-{i+1}"
wp = Process(
target=message_worker_subprocess,
args=(config_values, worker_id),
)
worker_processes.append(wp)
wp.start()
LOGGER.info(f"Started message worker {worker_id}")

p1.start()
p2.start()
p3.start()
Expand Down
18 changes: 18 additions & 0 deletions src/aleph/jobs/fetch_pending_messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from aleph.chains.signature_verifier import SignatureVerifier
from aleph.db.accessors.pending_messages import (
get_next_pending_messages,
get_pending_message,
make_pending_message_fetched_statement,
)
from aleph.db.connection import make_async_engine, make_async_session_factory
Expand Down Expand Up @@ -54,6 +55,10 @@ def __init__(

async def fetch_pending_message(self, pending_message: PendingMessageDb):
async with self.session_factory() as session:
# Store ID before any potential session operations to ensure we can access it
pending_message_id = pending_message.id
item_hash = pending_message.item_hash

try:
message = await self.message_handler.verify_message(
pending_message=pending_message
Expand All @@ -69,12 +74,25 @@ async def fetch_pending_message(self, pending_message: PendingMessageDb):
except Exception as e:
await session.rollback()

# Query the message again after rollback

pending_message = await get_pending_message(
session, pending_message_id=pending_message_id
)

if pending_message is None:
LOGGER.error(
f"Could not retrieve pending message {item_hash} with ID {pending_message_id} after rollback"
)
return None

_ = await self.handle_processing_error(
session=session,
pending_message=pending_message,
exception=e,
)
await session.commit()

return None

async def fetch_pending_messages(
Expand Down
Loading
Loading