Skip to content

Commit

Permalink
Bumped pydantic to v2 (#10)
Browse files Browse the repository at this point in the history
* bumped pydantic to v2

Signed-off-by: Tiago Santana <[email protected]>

* add pydantic.v1 compatibility in pydantic tools and validatefunction

Signed-off-by: Tiago Santana <[email protected]>

* replace deprecated parse_obj() with model_validate()

Signed-off-by: Tiago Santana <[email protected]>

* replace deprecated parse_raw() with model_validate_json()

Signed-off-by: Tiago Santana <[email protected]>

* replace deprecated update_forward_refs() with model_rebuild()

Signed-off-by: Tiago Santana <[email protected]>

* replace deprecated .json() with .model_dump_json()

Signed-off-by: Tiago Santana <[email protected]>

* replace missing deprecated parse_raw() with model_validate_json()

Signed-off-by: Tiago Santana <[email protected]>

* fix import ValidationError

Signed-off-by: Tiago Santana <[email protected]>

* remove None default value in raw_value

Signed-off-by: Tiago Santana <[email protected]>

* changed pydantic version

Signed-off-by: Tiago Santana <[email protected]>

* reverse changes in defaults url in test config

Signed-off-by: Tiago Santana <[email protected]>

* fix payload arg in MessagePayload

Signed-off-by: Tiago Santana <[email protected]>

* add missing fix payload arg in app

Signed-off-by: Tiago Santana <[email protected]>

* replace parse_obj_as with TypeAdapter. fix import v1 validationerror

Signed-off-by: Tiago Santana <[email protected]>

* replaced validatedFunction with v2 validate_call. replaced missing parse_raw_as with validate_json

Signed-off-by: Tiago Santana <[email protected]>

* bumped version

Signed-off-by: Tiago Santana <[email protected]>

* changed tabulate version

Signed-off-by: Tiago Santana <[email protected]>

* fix mkdocs build by set version of mkdocs and add mognet request class description

Signed-off-by: Tiago Santana <[email protected]>

---------

Signed-off-by: Tiago Santana <[email protected]>
  • Loading branch information
SantanaTiago authored Aug 7, 2024
1 parent 77002ce commit b171610
Show file tree
Hide file tree
Showing 16 changed files with 1,180 additions and 824 deletions.
4 changes: 2 additions & 2 deletions demo/test/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def test_full_a_to_z():
assert submit_response.ok

# 2. Get the job to wait for...
job = Job.parse_obj(submit_response.json())
job = Job.model_validate(submit_response.json())

# 3. Wait for it, through polling...
while True:
Expand All @@ -33,7 +33,7 @@ def test_full_a_to_z():

assert result_response.ok

result = UploadJobResult.parse_obj(result_response.json())
result = UploadJobResult.model_validate(result_response.json())

if result.job_status not in READY_STATES:
time.sleep(1)
Expand Down
2 changes: 1 addition & 1 deletion docs/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ from mognet import App, task, Context, AppConfig
#
# Note that we also don't specify the URLs to the Redis and RabbitMQ themselves,
# opting instead for the defaults.
config = AppConfig.parse_obj({
config = AppConfig.model_validate({
"result_backend": {
# Assumes a Redis on localhost:6379 with no credentials
"redis": {}
Expand Down
14 changes: 7 additions & 7 deletions mognet/app/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,14 +194,14 @@ async def get_current_status_of_nodes(
payload=MessagePayload(
id=str(request.id),
kind="Query",
payload=request,
payload=request.model_dump(),
)
)

try:
async for response in responses:
try:
yield StatusResponseMessage.parse_obj(response)
yield StatusResponseMessage.model_validate(response)
except asyncio.CancelledError:
break
except Exception as exc: # pylint: disable=broad-except
Expand Down Expand Up @@ -260,7 +260,7 @@ async def submit(self, req: "Request", context: Optional[Context] = None) -> Res
payload = MessagePayload(
id=str(req.id),
kind="Request",
payload=req,
payload=req.model_dump(),
priority=req.priority,
)

Expand Down Expand Up @@ -407,7 +407,7 @@ async def revoke(self, request_id: uuid.UUID, *, force: bool = False) -> Result:
payload = MessagePayload(
id=str(uuid.uuid4()),
kind=Revoke.MESSAGE_KIND,
payload=Revoke(id=request_id),
payload=Revoke(id=request_id).model_dump(),
)

await self.broker.send_control_message(payload)
Expand Down Expand Up @@ -682,7 +682,7 @@ async def _process_control_message(self, msg: IncomingMessagePayload):

try:
if msg.kind == Revoke.MESSAGE_KIND:
abort = Revoke.parse_obj(msg.payload)
abort = Revoke.model_validate(msg.payload)

_log.debug("Received request to revoke request id=%r", abort.id)

Expand All @@ -704,7 +704,7 @@ async def _process_control_message(self, msg: IncomingMessagePayload):
return

if msg.kind == "Query":
query = QueryRequestMessage.parse_obj(msg.payload)
query = QueryRequestMessage.model_validate(msg.payload)

if query.name == "Status":
# Get the status of this worker and reply to the incoming message
Expand All @@ -723,7 +723,7 @@ async def _process_control_message(self, msg: IncomingMessagePayload):
)

payload = MessagePayload(
id=str(reply.id), kind=reply.kind, payload=reply
id=str(reply.id), kind=reply.kind, payload=reply.model_dump()
)

return await self.broker.send_reply(msg, payload)
Expand Down
2 changes: 1 addition & 1 deletion mognet/app/app_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ class AppConfig(BaseModel):
@classmethod
def from_file(cls, file_path: str) -> "AppConfig":
with open(file_path, "r", encoding="utf-8") as config_file:
return cls.parse_raw(config_file.read())
return cls.model_validate_json(config_file.read())

# Maximum number of attempts to connect
max_reconnect_retries: int = 5
Expand Down
10 changes: 5 additions & 5 deletions mognet/backend/redis_result_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from typing import Any, AnyStr, Dict, Iterable, List, Optional, Set
from uuid import UUID

from pydantic.tools import parse_raw_as
from pydantic import TypeAdapter
from redis.asyncio import Redis, from_url
from redis.exceptions import ConnectionError, TimeoutError

Expand Down Expand Up @@ -142,7 +142,7 @@ async def get_or_create(self, result_id: UUID) -> Result:
return self._decode_result(value)

def _encode_result_value(self, value: ResultValueHolder) -> Dict[str, bytes]:
contents = value.json().encode()
contents = value.model_dump_json().encode()
encoding = b"null"

if self.config.redis.result_value_encoding == Encoding.GZIP:
Expand All @@ -164,7 +164,7 @@ def _decode_result_value(self, encoded: Dict[bytes, bytes]) -> ResultValueHolder
if encoded.get(b"content_type") != _json_bytes("application/json"):
raise ValueError(f"Unknown content_type={encoded.get(b'content_type')!r}")

return ResultValueHolder.parse_raw(contents, content_type="application/json")
return ResultValueHolder.model_validate_json(contents)

@_retry
async def set(self, result_id: UUID, result: Result):
Expand Down Expand Up @@ -322,7 +322,7 @@ async def waiter():
while True:
raw_state = await shield(self._redis.hget(key, "state")) or b"null"

state = parse_raw_as(t, raw_state)
state = TypeAdapter(t).validate_json(raw_state)

if state is None:
raise ResultValueLost(result_id)
Expand Down Expand Up @@ -432,7 +432,7 @@ def _decode_result(self, json_dict: Dict[bytes, bytes]) -> Result:


def _encode_result(result: Result) -> Dict[str, bytes]:
json_dict: dict = json.loads(result.json())
json_dict: dict = json.loads(result.model_dump_json())
return _dict_to_json_dict(json_dict)


Expand Down
10 changes: 5 additions & 5 deletions mognet/broker/amqp_broker.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ async def send_task_message(self, queue: str, payload: MessagePayload):
amqp_queue = self._task_queue_name(queue)

msg = Message(
body=payload.json().encode(),
body=payload.model_dump_json().encode(),
content_type="application/json",
content_encoding="utf-8",
priority=payload.priority,
Expand Down Expand Up @@ -201,7 +201,7 @@ async def consume_control_queue(
@_retry
async def send_control_message(self, payload: MessagePayload):
msg = Message(
body=payload.json().encode(),
body=payload.model_dump_json().encode(),
content_type="application/json",
content_encoding="utf-8",
message_id=payload.id,
Expand All @@ -226,7 +226,7 @@ async def _send_query_message(self, payload: MessagePayload):
await callback_queue.bind(self._direct_exchange)

msg = Message(
body=payload.json().encode(),
body=payload.model_dump_json().encode(),
content_type="application/json",
content_encoding="utf-8",
message_id=payload.id,
Expand Down Expand Up @@ -257,7 +257,7 @@ async def send_query_message(
msg = _AmqpIncomingMessagePayload(
broker=self, incoming_message=message, **contents
)
yield QueryResponseMessage.parse_obj(msg.payload)
yield QueryResponseMessage.model_validate(msg.payload)
finally:
if callback_queue is not None:
await callback_queue.delete()
Expand Down Expand Up @@ -358,7 +358,7 @@ async def send_reply(self, message: IncomingMessagePayload, reply: MessagePayloa
raise ValueError("Message has no reply_to set")

msg = Message(
body=reply.json().encode(),
body=reply.model_dump_json().encode(),
content_type="application/json",
content_encoding="utf-8",
message_id=reply.id,
Expand Down
2 changes: 1 addition & 1 deletion mognet/broker/memory_broker.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ async def send_query_message(

while True:
response = await q.get()
yield QueryResponseMessage.parse_obj(response.payload)
yield QueryResponseMessage.model_validate(response.payload)
finally:
self._callback_queues.pop(queue_id, None)

Expand Down
2 changes: 1 addition & 1 deletion mognet/cli/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ async def read_status():
typer.echo(tabulate.tabulate(table_data, headers=table_headers))

elif format == "json":
typer.echo(report.json(indent=json_indent, ensure_ascii=False))
typer.echo(report.model_dump_json(indent=json_indent))

if not poll:
break
Expand Down
2 changes: 1 addition & 1 deletion mognet/cli/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ def build_tree(n: ResultTree, parent: Optional[ResultTree] = None):
t.show()

if format == "json":
print(tree.json(indent=json_indent, ensure_ascii=False))
print(tree.model_dump_json(indent=json_indent))

if not poll:
break
Expand Down
4 changes: 2 additions & 2 deletions mognet/exceptions/task_exceptions.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from pydantic import BaseModel
from pydantic.error_wrappers import ValidationError
from pydantic import ValidationError
from typing import Any, Dict, List, Tuple, Union

# Taken from pydantic.error_wrappers
Expand Down Expand Up @@ -34,4 +34,4 @@ def __init__(self, errors: List[InvalidErrorInfo]) -> None:

@classmethod
def from_validation_error(cls, validation_error: ValidationError):
return cls([InvalidErrorInfo.parse_obj(e) for e in validation_error.errors()])
return cls([InvalidErrorInfo.model_validate(e) for e in validation_error.errors()])
19 changes: 9 additions & 10 deletions mognet/model/result.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,9 @@
)

from datetime import datetime, timedelta
from pydantic import BaseModel, TypeAdapter
from pydantic.fields import PrivateAttr
from pydantic.main import BaseModel
from mognet.tools.dates import now_utc
from pydantic.tools import parse_obj_as
from uuid import UUID

if TYPE_CHECKING:
Expand Down Expand Up @@ -75,7 +74,7 @@ def deserialize(self) -> Any:
if self.value_type is not None:
cls = _get_attr(self.value_type)

value = parse_obj_as(cls, self.raw_value)
value = TypeAdapter(cls).validate_python(self.raw_value)
else:
value = self.raw_value

Expand Down Expand Up @@ -126,8 +125,8 @@ class _ExceptionInfo(BaseModel):

traceback: str

raw_data: Optional[str]
raw_data_encoding: Optional[str]
raw_data: Optional[str] = None
raw_data_encoding: Optional[str] = None

@classmethod
def from_exception(cls, exception: BaseException):
Expand Down Expand Up @@ -182,13 +181,13 @@ class Result(BaseModel):

parent_id: Optional[UUID] = None

created: Optional[datetime]
started: Optional[datetime]
finished: Optional[datetime]
created: Optional[datetime] = None
started: Optional[datetime] = None
finished: Optional[datetime] = None

node_id: Optional[str]
node_id: Optional[str] = None

request_kwargs_repr: Optional[str]
request_kwargs_repr: Optional[str] = None

_backend: "BaseResultBackend" = PrivateAttr()
_children: Optional[ResultChildren] = PrivateAttr()
Expand Down
2 changes: 1 addition & 1 deletion mognet/model/result_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,4 +27,4 @@ def dict(self, **kwargs):
}


ResultTree.update_forward_refs()
ResultTree.model_rebuild()
13 changes: 8 additions & 5 deletions mognet/primitives/request.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,21 @@
from datetime import timedelta, datetime
from typing import Any, Dict, Generic, List, Optional, TypeVar, Union
from uuid import UUID, uuid4
from pydantic import conint
from pydantic import BaseModel, Field

from pydantic.fields import Field
from pydantic.generics import GenericModel
from typing_extensions import Annotated


TReturn = TypeVar("TReturn")

Priority = conint(ge=0, le=10)
Priority = Annotated[int, Field(ge=0, le=10)]


class Request(GenericModel, Generic[TReturn]):
class Request(BaseModel, Generic[TReturn]):
"""
Represents the Mognet request.
"""
id: UUID = Field(default_factory=uuid4)
name: str

Expand Down
Loading

0 comments on commit b171610

Please sign in to comment.