diff --git a/kausal_common b/kausal_common index 9a7084f8..cce778e1 160000 --- a/kausal_common +++ b/kausal_common @@ -1 +1 @@ -Subproject commit 9a7084f830d8cccbfe4fa9f41ce437301d09506d +Subproject commit cce778e1db05d217c078a9fc357ca807ad9c9be8 diff --git a/paths/schema.py b/paths/schema.py index cc80cecd..b7e71c47 100644 --- a/paths/schema.py +++ b/paths/schema.py @@ -27,7 +27,7 @@ from nodes.schema import Mutations as NodesMutations, Query as NodesQuery from pages.schema import Query as PagesQuery from params.schema import Mutations as ParamsMutations, Query as ParamsQuery, types as params_types -from users.schema import Mutations as UsersMutations, Query as UsersQuery +from users.schema import Query as UsersQuery if TYPE_CHECKING: from kausal_common.graphene import GQLInfo @@ -86,7 +86,7 @@ def resolve_unit(root: Query, info: GQLInfo, value: str) -> Unit: return unit -class Mutations(ParamsMutations, NodesMutations, FrameworksMutations, UsersMutations): +class Mutations(ParamsMutations, NodesMutations, FrameworksMutations): pass @@ -114,7 +114,7 @@ class SBNode: @sb.type(name='Query') -class SBQuery: +class SBQuery: # FIXME this does not seem to have any effect at the moment @sb.field def node(self, info: SBInfo, id: str) -> SBNode: context = info.context.instance.context @@ -122,28 +122,21 @@ def node(self, info: SBInfo, id: str) -> SBNode: return SBNode(id=cast(sb.ID, node.id)) -SB_MUTATION_TYPES: list[type] = [] -if test_mode_enabled(): - SB_MUTATION_TYPES.append(TestModeMutations) - -SBMutation: type | None = None -if SB_MUTATION_TYPES: - SBMutation = merge_types('Mutation', tuple(SB_MUTATION_TYPES)) - - -def generate_strawberry_schema() -> sb.Schema: +def generate_strawberry_schema(query: type, mutation: type | None = None) -> sb.Schema: from kausal_common.strawberry.registry import strawberry_types sb_schema = sb.Schema( - query=SBQuery, mutation=SBMutation, types=strawberry_types, directives=[context_directive] + # TODO: Add DjangoOptimizerExtension? + # https://strawberry.rocks/docs/django/guide/optimizer + query=query, mutation=mutation, types=strawberry_types, directives=[context_directive] ) return sb_schema -def generate_schema() -> tuple[sb.Schema, CombinedSchema]: +def generate_schema(sb_query: type, sb_mutation: type | None = None) -> tuple[sb.Schema, CombinedSchema]: # We generate the Strawberry schema just to be able to utilize the # resolved GraphQL types directly in the Graphene schema. - sb_schema = generate_strawberry_schema() + sb_schema = generate_strawberry_schema(sb_query, sb_mutation) schema = CombinedSchema( sb_schema=sb_schema, @@ -155,4 +148,4 @@ def generate_schema() -> tuple[sb.Schema, CombinedSchema]: return sb_schema, schema -sb_schema, schema = generate_schema() +sb_schema, schema = generate_schema(SBQuery) diff --git a/paths/schema_test_mode.py b/paths/schema_test_mode.py new file mode 100644 index 00000000..854c78ca --- /dev/null +++ b/paths/schema_test_mode.py @@ -0,0 +1,18 @@ +from __future__ import annotations + +from strawberry.tools import merge_types + +from kausal_common.deployment import test_mode_enabled +from kausal_common.testing.schema import TestModeMutations + +from paths.schema import SBQuery, generate_schema + +SB_MUTATION_TYPES: list[type] = [] +if test_mode_enabled(): + SB_MUTATION_TYPES.append(TestModeMutations) + +SBMutation: type | None = None +if SB_MUTATION_TYPES: + SBMutation = merge_types('Mutation', tuple(SB_MUTATION_TYPES)) + +sb_schema, schema = generate_schema(SBQuery, SBMutation) diff --git a/paths/settings.py b/paths/settings.py index 46cbb543..ca814c89 100644 --- a/paths/settings.py +++ b/paths/settings.py @@ -15,7 +15,7 @@ from corsheaders.defaults import default_headers as default_cors_headers from kausal_common import ENV_SCHEMA as COMMON_ENV_SCHEMA, register_settings as register_common_settings -from kausal_common.deployment import set_secret_file_vars +from kausal_common.deployment import set_secret_file_vars, test_mode_enabled from kausal_common.deployment.http import get_allowed_cors_headers from kausal_common.sentry.init import init_sentry @@ -290,8 +290,12 @@ SESSION_COOKIE_SAMESITE = 'None' SESSION_COOKIE_SECURE = True +if test_mode_enabled(): + GRAPHENE_SCHEMA = f'{PROJECT_NAME}.schema_test_mode.schema' +else: + GRAPHENE_SCHEMA = f'{PROJECT_NAME}.schema.schema' GRAPHENE = { - 'SCHEMA': f'{PROJECT_NAME}.schema.schema', + 'SCHEMA': GRAPHENE_SCHEMA, } GRAPPLE = { 'APPS': ['pages'], diff --git a/users/schema.py b/users/schema.py index cc19ba9f..70ee9790 100644 --- a/users/schema.py +++ b/users/schema.py @@ -1,66 +1,52 @@ from __future__ import annotations -from typing import TYPE_CHECKING +import dataclasses import graphene -from graphene_django import DjangoObjectType -from graphql import GraphQLError +import strawberry +from strawberry.types.field import StrawberryField -from .models import User +from kausal_common.strawberry.registry import register_strawberry_type -if TYPE_CHECKING: - from collections.abc import Sequence +from frameworks.roles import FrameworkRoleDef +from users.models import User # noqa: TC001 - from kausal_common.graphene import GQLInfo +# Instead of defining a new class for the Strawberry type UserFrameworkRole and copying the fields of FrameworkRoleDef, +# register FrameworkRoleDef with a different name. This way we don't need to juggle data around in +# `UserType.framework_roles()`. +strawberry.type(FrameworkRoleDef, name='UserFrameworkRole') +register_strawberry_type(FrameworkRoleDef) - from frameworks.roles import FrameworkRoleDef +@register_strawberry_type +@strawberry.type +class UserType: + id: int + email: str + first_name: str + last_name: str -class UserFrameworkRole(graphene.ObjectType): - framework_id = graphene.ID(required=True) - role_id = graphene.String(required=False) - org_slug = graphene.String(required=False) - org_id = graphene.String(required=False) + _user: strawberry.Private[User] + def __init__(self, user: User): + proper_fields = [ + field.name for field in dataclasses.fields(self) + if not isinstance(field, StrawberryField) and field.name != '_user' + ] + for field in proper_fields: + setattr(self, field, getattr(user, field)) + self._user = user -class UserType(DjangoObjectType): - framework_roles = graphene.List(graphene.NonNull(UserFrameworkRole)) - - class Meta: - model = User - fields = ('id', 'email', 'first_name', 'last_name') - - @staticmethod - def resolve_framework_roles(root: User, info: GQLInfo) -> Sequence[FrameworkRoleDef]: - return root.extra.framework_roles + @strawberry.field + def framework_roles(self) -> list[FrameworkRoleDef]: + return list(self._user.extra.framework_roles) class Query(graphene.ObjectType): me = graphene.Field(UserType) - def resolve_me(self, info): + def resolve_me(self, info) -> UserType | None: user = info.context.user if user.is_authenticated: - return user + return UserType(user) return None - - -class RegisterUser(graphene.Mutation): - class Arguments: - email = graphene.String(required=True) - password = graphene.String(required=True) - - user = graphene.Field(UserType) - - def mutate(self, info: GQLInfo, email: str, password: str): - email = email.strip().lower() - if User.objects.filter(email=email).exists(): - raise GraphQLError("User with email already exists", nodes=info.field_nodes) - user = User(email=email) - user.set_password(password) - user.save() - return RegisterUser(user=user) - - -class Mutations(graphene.ObjectType): - register_user = RegisterUser.Field()