diff --git a/src/Trax.Api.GraphQL/Configuration/GraphQLConfiguration.cs b/src/Trax.Api.GraphQL/Configuration/GraphQLConfiguration.cs new file mode 100644 index 0000000..d7e92ad --- /dev/null +++ b/src/Trax.Api.GraphQL/Configuration/GraphQLConfiguration.cs @@ -0,0 +1,22 @@ +namespace Trax.Api.GraphQL.Configuration; + +/// +/// Holds the resolved configuration for the Trax GraphQL schema, +/// including discovered query model registrations. +/// +public class GraphQLConfiguration +{ + public IReadOnlyList ModelRegistrations { get; } + + /// + /// Tracks which namespace base types and namespace fields have been registered + /// across type modules to prevent duplicate registrations. Populated at runtime + /// by TrainTypeModule and QueryModelTypeModule. + /// + internal HashSet RegisteredNamespaceTypes { get; } = new(StringComparer.Ordinal); + + public GraphQLConfiguration(IReadOnlyList modelRegistrations) + { + ModelRegistrations = modelRegistrations; + } +} diff --git a/src/Trax.Api.GraphQL/Configuration/QueryModelRegistration.cs b/src/Trax.Api.GraphQL/Configuration/QueryModelRegistration.cs new file mode 100644 index 0000000..40976ef --- /dev/null +++ b/src/Trax.Api.GraphQL/Configuration/QueryModelRegistration.cs @@ -0,0 +1,13 @@ +using Trax.Effect.Attributes; + +namespace Trax.Api.GraphQL.Configuration; + +/// +/// Represents a discovered entity type marked with +/// and its owning DbContext type. +/// +public record QueryModelRegistration( + Type EntityType, + Type DbContextType, + TraxQueryModelAttribute Attribute +); diff --git a/src/Trax.Api.GraphQL/Configuration/TraxGraphQLBuilder/TraxGraphQLBuilder.Build.cs b/src/Trax.Api.GraphQL/Configuration/TraxGraphQLBuilder/TraxGraphQLBuilder.Build.cs new file mode 100644 index 0000000..f17a3d1 --- /dev/null +++ b/src/Trax.Api.GraphQL/Configuration/TraxGraphQLBuilder/TraxGraphQLBuilder.Build.cs @@ -0,0 +1,35 @@ +using System.Reflection; +using Microsoft.EntityFrameworkCore; +using Trax.Effect.Attributes; + +namespace Trax.Api.GraphQL.Configuration.TraxGraphQLBuilder; + +public partial class TraxGraphQLBuilder +{ + internal GraphQLConfiguration Build() + { + var modelRegistrations = new List(); + + foreach (var dbContextType in DbContextTypes) + { + var dbSetProps = dbContextType + .GetProperties(BindingFlags.Public | BindingFlags.Instance) + .Where(p => + p.PropertyType.IsGenericType + && p.PropertyType.GetGenericTypeDefinition() == typeof(DbSet<>) + ); + + foreach (var prop in dbSetProps) + { + var entityType = prop.PropertyType.GetGenericArguments()[0]; + var attr = entityType.GetCustomAttribute(); + if (attr is null) + continue; + + modelRegistrations.Add(new QueryModelRegistration(entityType, dbContextType, attr)); + } + } + + return new GraphQLConfiguration(modelRegistrations); + } +} diff --git a/src/Trax.Api.GraphQL/Configuration/TraxGraphQLBuilder/TraxGraphQLBuilder.DbContext.cs b/src/Trax.Api.GraphQL/Configuration/TraxGraphQLBuilder/TraxGraphQLBuilder.DbContext.cs new file mode 100644 index 0000000..99246c5 --- /dev/null +++ b/src/Trax.Api.GraphQL/Configuration/TraxGraphQLBuilder/TraxGraphQLBuilder.DbContext.cs @@ -0,0 +1,22 @@ +using Microsoft.EntityFrameworkCore; + +namespace Trax.Api.GraphQL.Configuration.TraxGraphQLBuilder; + +public partial class TraxGraphQLBuilder +{ + /// + /// Registers a DbContext whose DbSet<T> entities marked with + /// [TraxQueryModel] will be automatically exposed as paginated, + /// filterable, sortable GraphQL queries under discover. + /// + /// + /// The DbContext type containing DbSet properties for the entities to expose. + /// Must be registered in DI (e.g. via AddDbContextFactory or AddDbContext). + /// + public TraxGraphQLBuilder AddDbContext() + where TDbContext : DbContext + { + DbContextTypes.Add(typeof(TDbContext)); + return this; + } +} diff --git a/src/Trax.Api.GraphQL/Configuration/TraxGraphQLBuilder/TraxGraphQLBuilder.cs b/src/Trax.Api.GraphQL/Configuration/TraxGraphQLBuilder/TraxGraphQLBuilder.cs new file mode 100644 index 0000000..e0c3e5b --- /dev/null +++ b/src/Trax.Api.GraphQL/Configuration/TraxGraphQLBuilder/TraxGraphQLBuilder.cs @@ -0,0 +1,21 @@ +using System.ComponentModel; +using Microsoft.Extensions.DependencyInjection; + +namespace Trax.Api.GraphQL.Configuration.TraxGraphQLBuilder; + +/// +/// Builder for configuring the Trax GraphQL schema, including DbContext-based +/// model query registration. +/// +public partial class TraxGraphQLBuilder +{ + [EditorBrowsable(EditorBrowsableState.Never)] + internal IServiceCollection Services { get; } + + internal List DbContextTypes { get; } = []; + + public TraxGraphQLBuilder(IServiceCollection services) + { + Services = services; + } +} diff --git a/src/Trax.Api.GraphQL/Extensions/GraphQLServiceExtensions.cs b/src/Trax.Api.GraphQL/Extensions/GraphQLServiceExtensions.cs index 709adc8..d207a29 100644 --- a/src/Trax.Api.GraphQL/Extensions/GraphQLServiceExtensions.cs +++ b/src/Trax.Api.GraphQL/Extensions/GraphQLServiceExtensions.cs @@ -1,7 +1,10 @@ +using HotChocolate.Data; using HotChocolate.Types; using Microsoft.AspNetCore.Builder; using Microsoft.Extensions.DependencyInjection; using Trax.Api.Extensions; +using Trax.Api.GraphQL.Configuration; +using Trax.Api.GraphQL.Configuration.TraxGraphQLBuilder; using Trax.Api.GraphQL.Errors; using Trax.Api.GraphQL.Hooks; using Trax.Api.GraphQL.Mutations; @@ -20,11 +23,19 @@ public static class GraphQLServiceExtensions private const string SchemaName = "trax"; /// - /// Registers the Trax GraphQL schema on a named HotChocolate server ("trax"). - /// This avoids conflicts with a consumer's own default GraphQL schema. - /// Only trains annotated with [TraxQuery] or [TraxMutation] get typed operations generated. + /// Registers the Trax GraphQL schema on a named HotChocolate server ("trax") + /// with support for configuring DbContext-based model queries. /// - public static IServiceCollection AddTraxGraphQL(this IServiceCollection services) + /// + /// + /// services.AddTraxGraphQL(graphql => graphql + /// .AddDbContext<GameDbContext>()); + /// + /// + public static IServiceCollection AddTraxGraphQL( + this IServiceCollection services, + Func configure + ) { if (!services.Any(sd => sd.ServiceType == typeof(TraxMarker))) throw new InvalidOperationException( @@ -32,6 +43,11 @@ public static IServiceCollection AddTraxGraphQL(this IServiceCollection services + "Call services.AddTrax(trax => ...) before services.AddTraxGraphQL()." ); + var builder = new TraxGraphQLBuilder(services); + configure(builder); + var config = builder.Build(); + services.AddSingleton(config); + services.AddTraxApi(); services.AddSingleton(); services.AddTransient(); @@ -40,7 +56,8 @@ public static IServiceCollection AddTraxGraphQL(this IServiceCollection services .AddSingleton(sp => sp.GetRequiredService() ); - services + + var graphqlBuilder = services .AddGraphQLServer(SchemaName) .AddQueryType() .AddMutationType() @@ -52,6 +69,34 @@ public static IServiceCollection AddTraxGraphQL(this IServiceCollection services .AddErrorFilter() .AddInMemorySubscriptions(); + if (config.ModelRegistrations.Count > 0) + { + services.AddSingleton(); + graphqlBuilder.AddTypeModule(); + + // Register DiscoverQueries base type and discover field on RootQuery. + // TrainTypeModule will skip creating these when it detects model registrations. + graphqlBuilder.AddType(new ObjectType()); + graphqlBuilder.AddTypeExtension( + new ObjectTypeExtension(d => + { + d.Name("RootQuery"); + d.Field("discover") + .Type>() + .Resolve(_ => new DiscoverQueries()); + }) + ); + + if (config.ModelRegistrations.Any(r => r.Attribute.Filtering)) + graphqlBuilder.AddFiltering(); + + if (config.ModelRegistrations.Any(r => r.Attribute.Sorting)) + graphqlBuilder.AddSorting(); + + if (config.ModelRegistrations.Any(r => r.Attribute.Projection)) + graphqlBuilder.AddProjections(); + } + // If a broadcaster receiver is registered (via UseBroadcaster()), // wire up the GraphQL handler so remote lifecycle events are forwarded // to HotChocolate subscriptions. @@ -63,6 +108,14 @@ public static IServiceCollection AddTraxGraphQL(this IServiceCollection services return services; } + /// + /// Registers the Trax GraphQL schema on a named HotChocolate server ("trax"). + /// This avoids conflicts with a consumer's own default GraphQL schema. + /// Only trains annotated with [TraxQuery] or [TraxMutation] get typed operations generated. + /// + public static IServiceCollection AddTraxGraphQL(this IServiceCollection services) => + services.AddTraxGraphQL(builder => builder); + /// /// Maps the Trax GraphQL endpoint at the specified route prefix. /// Uses a named schema so it coexists with other HotChocolate schemas diff --git a/src/Trax.Api.GraphQL/Trax.Api.GraphQL.csproj b/src/Trax.Api.GraphQL/Trax.Api.GraphQL.csproj index 1785326..0ebf091 100644 --- a/src/Trax.Api.GraphQL/Trax.Api.GraphQL.csproj +++ b/src/Trax.Api.GraphQL/Trax.Api.GraphQL.csproj @@ -18,6 +18,8 @@ - + + + diff --git a/src/Trax.Api.GraphQL/TypeModules/QueryModelTypeModule.cs b/src/Trax.Api.GraphQL/TypeModules/QueryModelTypeModule.cs new file mode 100644 index 0000000..a768f6a --- /dev/null +++ b/src/Trax.Api.GraphQL/TypeModules/QueryModelTypeModule.cs @@ -0,0 +1,198 @@ +using HotChocolate.Data; +using HotChocolate.Execution.Configuration; +using HotChocolate.Language; +using HotChocolate.Types; +using HotChocolate.Types.Descriptors; +using HotChocolate.Types.Pagination; +using Microsoft.EntityFrameworkCore; +using Microsoft.Extensions.DependencyInjection; +using Trax.Api.GraphQL.Configuration; +using Trax.Api.GraphQL.Queries; +using Trax.Effect.Attributes; + +namespace Trax.Api.GraphQL.TypeModules; + +/// +/// A HotChocolate TypeModule that dynamically generates GraphQL query fields +/// for entities marked with [TraxQueryModel]. Each entity gets a query +/// field under discover with optional cursor pagination, filtering, +/// sorting, and projection based on the attribute configuration. +/// +public class QueryModelTypeModule(GraphQLConfiguration configuration) : TypeModule +{ + /// + /// Discovers all registered query model entities and generates the GraphQL schema types: + /// - ObjectType for each unique entity type + /// - ObjectTypeExtension on "DiscoverQueries" to add query fields + /// + public override ValueTask> CreateTypesAsync( + IDescriptorContext context, + CancellationToken cancellationToken + ) + { + var types = new List(); + var registrations = configuration.ModelRegistrations; + + if (registrations.Count == 0) + return new(types); + + var usedEntityTypes = new HashSet(); + foreach (var reg in registrations) + { + if (usedEntityTypes.Add(reg.EntityType)) + { + var objectType = (ITypeSystemMember) + Activator.CreateInstance(typeof(ObjectType<>).MakeGenericType(reg.EntityType))!; + types.Add(objectType); + } + } + + // Group model registrations by namespace + var byNamespace = registrations.GroupBy(r => r.Attribute.Namespace); + + foreach (var group in byNamespace) + { + if (group.Key is null) + { + // No namespace — add fields directly to DiscoverQueries + types.Add( + new ObjectTypeExtension(d => + { + d.Name("DiscoverQueries"); + foreach (var reg in group) + AddModelQueryField(d, reg); + }) + ); + } + else + { + // Namespace — create/extend intermediate type + var nsTypeName = TrainTypeModule.NamespaceTypeName(group.Key, "DiscoverQueries"); + var nsFieldName = TrainTypeModule.CamelCase(group.Key); + + // Register the base ObjectType for this namespace (only once across modules) + if (configuration.RegisteredNamespaceTypes.Add(nsTypeName)) + { + types.Add(new ObjectType(d => d.Name(nsTypeName))); + } + + // Add fields to the namespace type + types.Add( + new ObjectTypeExtension(d => + { + d.Name(nsTypeName); + foreach (var reg in group) + AddModelQueryField(d, reg); + }) + ); + + // Add the namespace field to DiscoverQueries (only once across modules) + var nsFieldKey = $"DiscoverQueries.{nsFieldName}"; + if (configuration.RegisteredNamespaceTypes.Add(nsFieldKey)) + { + var capturedNsTypeName = nsTypeName; + types.Add( + new ObjectTypeExtension(d => + { + d.Name("DiscoverQueries"); + d.Field(nsFieldName) + .Type(new NamedTypeNode(capturedNsTypeName)) + .Resolve(_ => new object()); + }) + ); + } + } + } + + return new(types); + } + + private static readonly System.Reflection.MethodInfo ConfigureFieldMethod = + typeof(QueryModelTypeModule).GetMethod( + nameof(ConfigureField), + System.Reflection.BindingFlags.NonPublic | System.Reflection.BindingFlags.Static + )!; + + private static void AddModelQueryField( + IObjectTypeDescriptor descriptor, + QueryModelRegistration reg + ) + { + var fieldName = reg.Attribute.Name ?? DeriveModelName(reg.EntityType.Name); + + var field = descriptor.Field(fieldName); + + if (reg.Attribute.Description is not null) + field.Description(reg.Attribute.Description); + + if (reg.Attribute.DeprecationReason is not null) + field.Deprecated(reg.Attribute.DeprecationReason); + + // Delegate to a generic method so HotChocolate gets properly typed + // delegates for projection, filtering, sorting, and the resolver. + ConfigureFieldMethod + .MakeGenericMethod(reg.EntityType) + .Invoke(null, [field, reg.DbContextType, reg.Attribute]); + } + + private static void ConfigureField( + IObjectFieldDescriptor field, + Type dbContextType, + TraxQueryModelAttribute attr + ) + where TEntity : class + { + // Apply features in the correct middleware pipeline order: + // Paging > Projection > Filtering > Sorting + if (attr.Paging) + { + field.UsePaging>( + options: new PagingOptions { IncludeTotalCount = true } + ); + } + + if (attr.Projection) + field.UseProjection(); + + if (attr.Filtering) + field.UseFiltering(); + + if (attr.Sorting) + field.UseSorting(); + + field.Resolve(ctx => + { + var dbContext = (DbContext)ctx.Services.GetRequiredService(dbContextType); + return dbContext.Set(); + }); + } + + /// + /// Derives a pluralized camelCase GraphQL field name from a class name. + /// e.g. "Player" → "players", "Match" → "matches", "Category" → "categories" + /// + internal static string DeriveModelName(string typeName) + { + var plural = Pluralize(typeName); + return char.ToLowerInvariant(plural[0]) + plural[1..]; + } + + internal static string Pluralize(string name) + { + if ( + name.EndsWith("s", StringComparison.Ordinal) + || name.EndsWith("x", StringComparison.Ordinal) + || name.EndsWith("z", StringComparison.Ordinal) + || name.EndsWith("ch", StringComparison.Ordinal) + || name.EndsWith("sh", StringComparison.Ordinal) + ) + return name + "es"; + + if (name.EndsWith("y", StringComparison.Ordinal) && name.Length > 1 && !IsVowel(name[^2])) + return name[..^1] + "ies"; + + return name + "s"; + } + + private static bool IsVowel(char c) => "aeiouAEIOU".Contains(c); +} diff --git a/src/Trax.Api.GraphQL/TypeModules/TrainTypeModule.cs b/src/Trax.Api.GraphQL/TypeModules/TrainTypeModule.cs index bad026f..6e3fa1e 100644 --- a/src/Trax.Api.GraphQL/TypeModules/TrainTypeModule.cs +++ b/src/Trax.Api.GraphQL/TypeModules/TrainTypeModule.cs @@ -1,6 +1,8 @@ using HotChocolate.Execution.Configuration; +using HotChocolate.Language; using HotChocolate.Types; using HotChocolate.Types.Descriptors; +using Trax.Api.GraphQL.Configuration; using Trax.Api.GraphQL.Mutations; using Trax.Api.GraphQL.Queries; using Trax.Effect.Attributes; @@ -14,7 +16,10 @@ namespace Trax.Api.GraphQL.TypeModules; /// from discovered train registrations. Each train marked with [TraxMutation]/[TraxQuery] attributes /// gets a corresponding mutation/query field wired to the TrainExecutionService. /// -public partial class TrainTypeModule(ITrainDiscoveryService discoveryService) : TypeModule +public partial class TrainTypeModule( + ITrainDiscoveryService discoveryService, + GraphQLConfiguration? graphQLConfiguration = null +) : TypeModule { /// /// Discovers all registered trains and generates the GraphQL schema types: @@ -96,14 +101,6 @@ CancellationToken cancellationToken if (mutationFields.Count > 0) { types.Add(new ObjectType()); - types.Add( - new ObjectTypeExtension(d => - { - d.Name("DispatchMutations"); - foreach (var (reg, name) in mutationFields) - AddMutationField(d, reg, name); - }) - ); types.Add( new ObjectTypeExtension(d => { @@ -113,35 +110,107 @@ CancellationToken cancellationToken .Resolve(_ => new DispatchMutations()); }) ); + + AddGroupedFields(types, mutationFields, "DispatchMutations", AddMutationField); } - // Register DiscoverQueries type + extend RootQuery with a "discover" field, - // but only when there are query trains. + // Register DiscoverQueries type + extend RootQuery with a "discover" field. + // When query model registrations exist, the base type and discover field are + // already registered in GraphQLServiceExtensions — only add the field extension. + var discoverBaseRegisteredExternally = graphQLConfiguration?.ModelRegistrations.Count > 0; + if (queryFields.Count > 0) { - types.Add(new ObjectType()); - types.Add( - new ObjectTypeExtension(d => - { - d.Name("DiscoverQueries"); - foreach (var (reg, name) in queryFields) - AddQueryField(d, reg, name); - }) - ); - types.Add( - new ObjectTypeExtension(d => - { - d.Name("RootQuery"); - d.Field("discover") - .Type>() - .Resolve(_ => new DiscoverQueries()); - }) - ); + if (!discoverBaseRegisteredExternally) + { + types.Add(new ObjectType()); + types.Add( + new ObjectTypeExtension(d => + { + d.Name("RootQuery"); + d.Field("discover") + .Type>() + .Resolve(_ => new DiscoverQueries()); + }) + ); + } + + AddGroupedFields(types, queryFields, "DiscoverQueries", AddQueryField); } return new ValueTask>(types); } + /// + /// Groups fields by namespace and creates the appropriate type extensions. + /// Fields with no namespace go directly on the parent type. Fields with a namespace + /// get an intermediate ObjectType (e.g. "AlertsDiscoverQueries") and a field on the + /// parent type pointing to it. + /// + private void AddGroupedFields( + List types, + List<(TrainRegistration Registration, string TrainName)> fields, + string parentTypeName, + Action addField + ) + { + var byNamespace = fields.GroupBy(f => f.Registration.GraphQLNamespace); + + foreach (var group in byNamespace) + { + if (group.Key is null) + { + // No namespace — add fields directly to the parent type + types.Add( + new ObjectTypeExtension(d => + { + d.Name(parentTypeName); + foreach (var (reg, name) in group) + addField(d, reg, name); + }) + ); + } + else + { + // Namespace — create intermediate type and add fields to it + var nsTypeName = NamespaceTypeName(group.Key, parentTypeName); + var nsFieldName = CamelCase(group.Key); + + // Register the base ObjectType for this namespace (only once across modules) + if (graphQLConfiguration?.RegisteredNamespaceTypes.Add(nsTypeName) ?? true) + { + types.Add(new ObjectType(d => d.Name(nsTypeName))); + } + + // Add fields to the namespace type + types.Add( + new ObjectTypeExtension(d => + { + d.Name(nsTypeName); + foreach (var (reg, name) in group) + addField(d, reg, name); + }) + ); + + // Add the namespace field to the parent type (only once across modules) + var nsFieldKey = $"{parentTypeName}.{nsFieldName}"; + if (graphQLConfiguration?.RegisteredNamespaceTypes.Add(nsFieldKey) ?? true) + { + var capturedNsTypeName = nsTypeName; + types.Add( + new ObjectTypeExtension(d => + { + d.Name(parentTypeName); + d.Field(nsFieldName) + .Type(new NamedTypeNode(capturedNsTypeName)) + .Resolve(_ => new object()); + }) + ); + } + } + } + } + /// /// Builds the ExecutionMode enum type with RUN and QUEUE values. /// @@ -227,4 +296,23 @@ private static string DeriveTrainName(string serviceTypeName) return name; } + + /// + /// Builds the HotChocolate type name for a namespace group. + /// e.g. ("alerts", "DiscoverQueries") → "AlertsDiscoverQueries" + /// + internal static string NamespaceTypeName(string ns, string parentTypeName) => + PascalCase(ns) + parentTypeName; + + /// + /// Capitalizes the first character of a string. + /// + internal static string PascalCase(string value) => + string.IsNullOrEmpty(value) ? value : char.ToUpperInvariant(value[0]) + value[1..]; + + /// + /// Lowercases the first character of a string. + /// + internal static string CamelCase(string value) => + string.IsNullOrEmpty(value) ? value : char.ToLowerInvariant(value[0]) + value[1..]; } diff --git a/tests/Trax.Api.Tests/QueryModelTypeModuleTests.cs b/tests/Trax.Api.Tests/QueryModelTypeModuleTests.cs new file mode 100644 index 0000000..b6e5921 --- /dev/null +++ b/tests/Trax.Api.Tests/QueryModelTypeModuleTests.cs @@ -0,0 +1,571 @@ +using FluentAssertions; +using HotChocolate.Types; +using Microsoft.EntityFrameworkCore; +using Trax.Api.GraphQL.Configuration; +using Trax.Api.GraphQL.Configuration.TraxGraphQLBuilder; +using Trax.Api.GraphQL.Queries; +using Trax.Api.GraphQL.TypeModules; +using Trax.Effect.Attributes; + +namespace Trax.Api.Tests; + +[TestFixture] +public class QueryModelTypeModuleTests +{ + #region DeriveModelName — Pluralization + + [Test] + public void DeriveModelName_SimpleClass_PluralizesCorrectly() + { + QueryModelTypeModule.DeriveModelName("Player").Should().Be("players"); + } + + [Test] + public void DeriveModelName_EndsWithS_AddsEs() + { + QueryModelTypeModule.DeriveModelName("Address").Should().Be("addresses"); + } + + [Test] + public void DeriveModelName_EndsWithX_AddsEs() + { + QueryModelTypeModule.DeriveModelName("Box").Should().Be("boxes"); + } + + [Test] + public void DeriveModelName_EndsWithZ_AddsEs() + { + QueryModelTypeModule.DeriveModelName("Quiz").Should().Be("quizes"); + } + + [Test] + public void DeriveModelName_EndsWithCh_AddsEs() + { + QueryModelTypeModule.DeriveModelName("Match").Should().Be("matches"); + } + + [Test] + public void DeriveModelName_EndsWithSh_AddsEs() + { + QueryModelTypeModule.DeriveModelName("Crash").Should().Be("crashes"); + } + + [Test] + public void DeriveModelName_EndsWithConsonantY_ChangesToIes() + { + QueryModelTypeModule.DeriveModelName("Category").Should().Be("categories"); + } + + [Test] + public void DeriveModelName_EndsWithVowelY_AddsS() + { + QueryModelTypeModule.DeriveModelName("Key").Should().Be("keys"); + } + + [Test] + public void DeriveModelName_PascalCase_CamelCasesResult() + { + QueryModelTypeModule.DeriveModelName("MatchResult").Should().Be("matchResults"); + } + + [Test] + public void DeriveModelName_SingleChar_PluralizesCorrectly() + { + QueryModelTypeModule.DeriveModelName("A").Should().Be("as"); + } + + #endregion + + #region Pluralize Edge Cases + + [Test] + public void Pluralize_EndsWithS_AddsEs() + { + QueryModelTypeModule.Pluralize("Bus").Should().Be("Buses"); + } + + [Test] + public void Pluralize_SimpleWord_AddsS() + { + QueryModelTypeModule.Pluralize("Player").Should().Be("Players"); + } + + [Test] + public void Pluralize_EndsWithAy_AddsS() + { + QueryModelTypeModule.Pluralize("Day").Should().Be("Days"); + } + + #endregion + + #region Builder — DbContext Scanning + + [Test] + public void Build_NoDbContexts_ReturnsEmptyRegistrations() + { + var builder = new TraxGraphQLBuilder( + new Microsoft.Extensions.DependencyInjection.ServiceCollection() + ); + var config = builder.Build(); + + config.ModelRegistrations.Should().BeEmpty(); + } + + [Test] + public void Build_DbContextWithAttributedEntities_DiscoversCorrectly() + { + var builder = new TraxGraphQLBuilder( + new Microsoft.Extensions.DependencyInjection.ServiceCollection() + ); + builder.AddDbContext(); + var config = builder.Build(); + + config.ModelRegistrations.Should().HaveCount(1); + config.ModelRegistrations[0].EntityType.Should().Be(typeof(TestPlayer)); + config.ModelRegistrations[0].DbContextType.Should().Be(typeof(TestDbContext)); + config.ModelRegistrations[0].Attribute.Description.Should().Be("Test players"); + } + + [Test] + public void Build_DbContextWithMixedEntities_OnlyDiscoversAttributed() + { + var builder = new TraxGraphQLBuilder( + new Microsoft.Extensions.DependencyInjection.ServiceCollection() + ); + builder.AddDbContext(); + var config = builder.Build(); + + config.ModelRegistrations.Should().HaveCount(1); + config.ModelRegistrations[0].EntityType.Should().Be(typeof(TestPlayer)); + } + + [Test] + public void Build_AttributeWithFeatureToggles_PreservesSettings() + { + var builder = new TraxGraphQLBuilder( + new Microsoft.Extensions.DependencyInjection.ServiceCollection() + ); + builder.AddDbContext(); + var config = builder.Build(); + + config.ModelRegistrations.Should().HaveCount(1); + var attr = config.ModelRegistrations[0].Attribute; + attr.Paging.Should().BeTrue(); + attr.Filtering.Should().BeFalse(); + attr.Sorting.Should().BeFalse(); + attr.Projection.Should().BeTrue(); + } + + [Test] + public void Build_AttributeDefaults_AllFeaturesEnabled() + { + var builder = new TraxGraphQLBuilder( + new Microsoft.Extensions.DependencyInjection.ServiceCollection() + ); + builder.AddDbContext(); + var config = builder.Build(); + + var attr = config.ModelRegistrations[0].Attribute; + attr.Paging.Should().BeTrue(); + attr.Filtering.Should().BeTrue(); + attr.Sorting.Should().BeTrue(); + attr.Projection.Should().BeTrue(); + } + + [Test] + public void Build_MultipleDbContexts_DiscoversFromAll() + { + var builder = new TraxGraphQLBuilder( + new Microsoft.Extensions.DependencyInjection.ServiceCollection() + ); + builder.AddDbContext(); + builder.AddDbContext(); + var config = builder.Build(); + + config.ModelRegistrations.Should().HaveCount(2); + config + .ModelRegistrations.Select(r => r.EntityType) + .Should() + .Contain([typeof(TestPlayer), typeof(TestItem)]); + } + + #endregion + + #region CreateTypesAsync — Type Generation + + [Test] + public async Task CreateTypesAsync_NoRegistrations_ReturnsEmpty() + { + var config = new GraphQLConfiguration([]); + var module = new QueryModelTypeModule(config); + + var types = await module.CreateTypesAsync(null!, CancellationToken.None); + + types.Should().BeEmpty(); + } + + [Test] + public async Task CreateTypesAsync_WithRegistrations_CreatesObjectTypeAndExtension() + { + var config = new GraphQLConfiguration([ + new QueryModelRegistration( + typeof(TestPlayer), + typeof(TestDbContext), + new TraxQueryModelAttribute { Description = "Test players" } + ), + ]); + var module = new QueryModelTypeModule(config); + + var types = await module.CreateTypesAsync(null!, CancellationToken.None); + + // ObjectType + ObjectTypeExtension on DiscoverQueries + types.Should().HaveCount(2); + + types + .Should() + .ContainSingle(t => + t.GetType().IsGenericType + && t.GetType().GetGenericTypeDefinition() == typeof(ObjectType<>) + && t.GetType().GetGenericArguments()[0] == typeof(TestPlayer) + ); + + types.OfType().Should().HaveCount(1); + } + + [Test] + public async Task CreateTypesAsync_DuplicateEntityTypes_RegistersTypeOnce() + { + var config = new GraphQLConfiguration([ + new QueryModelRegistration( + typeof(TestPlayer), + typeof(TestDbContext), + new TraxQueryModelAttribute { Name = "players1" } + ), + new QueryModelRegistration( + typeof(TestPlayer), + typeof(TestDbContext), + new TraxQueryModelAttribute { Name = "players2" } + ), + ]); + var module = new QueryModelTypeModule(config); + + var types = await module.CreateTypesAsync(null!, CancellationToken.None); + + // Only 1 ObjectType despite 2 registrations + types + .Where(t => + t.GetType().IsGenericType + && t.GetType().GetGenericTypeDefinition() == typeof(ObjectType<>) + ) + .Should() + .HaveCount(1); + } + + #endregion + + #region TrainTypeModule Coordination + + [Test] + public async Task TrainTypeModule_WithModelRegistrations_SkipsDiscoverQueriesBaseType() + { + var discovery = new StubDiscoveryService([ + new Trax.Mediator.Services.TrainDiscovery.TrainRegistration + { + ServiceType = typeof(IStubTrain), + ImplementationType = typeof(StubTrain), + InputType = typeof(StubInput), + OutputType = typeof(StubOutput), + Lifetime = Microsoft.Extensions.DependencyInjection.ServiceLifetime.Scoped, + ServiceTypeName = "LookupTrain", + ImplementationTypeName = "LookupTrain", + InputTypeName = nameof(StubInput), + OutputTypeName = nameof(StubOutput), + RequiredPolicies = [], + RequiredRoles = [], + IsQuery = true, + IsMutation = false, + IsBroadcastEnabled = false, + GraphQLName = "Lookup", + GraphQLOperations = Trax.Effect.Attributes.GraphQLOperation.Run, + IsRemote = false, + }, + ]); + + var config = new GraphQLConfiguration([ + new QueryModelRegistration( + typeof(TestPlayer), + typeof(TestDbContext), + new TraxQueryModelAttribute() + ), + ]); + + var module = new Trax.Api.GraphQL.TypeModules.TrainTypeModule(discovery, config); + var types = await module.CreateTypesAsync(null!, CancellationToken.None); + + // Should NOT contain ObjectType (base type registered externally) + types + .Should() + .NotContain(t => + t.GetType().IsGenericType + && t.GetType().GetGenericTypeDefinition() == typeof(ObjectType<>) + && t.GetType().GetGenericArguments()[0] == typeof(DiscoverQueries) + ); + + // Should still have the field extension on DiscoverQueries + types.OfType().Should().HaveCount(1); + } + + [Test] + public async Task TrainTypeModule_WithoutModelRegistrations_CreatesDiscoverQueriesBaseType() + { + var discovery = new StubDiscoveryService([ + new Trax.Mediator.Services.TrainDiscovery.TrainRegistration + { + ServiceType = typeof(IStubTrain), + ImplementationType = typeof(StubTrain), + InputType = typeof(StubInput), + OutputType = typeof(StubOutput), + Lifetime = Microsoft.Extensions.DependencyInjection.ServiceLifetime.Scoped, + ServiceTypeName = "LookupTrain", + ImplementationTypeName = "LookupTrain", + InputTypeName = nameof(StubInput), + OutputTypeName = nameof(StubOutput), + RequiredPolicies = [], + RequiredRoles = [], + IsQuery = true, + IsMutation = false, + IsBroadcastEnabled = false, + GraphQLName = "Lookup", + GraphQLOperations = Trax.Effect.Attributes.GraphQLOperation.Run, + IsRemote = false, + }, + ]); + + var module = new Trax.Api.GraphQL.TypeModules.TrainTypeModule(discovery); + var types = await module.CreateTypesAsync(null!, CancellationToken.None); + + // SHOULD contain ObjectType (no model registrations, so we own it) + types + .Should() + .ContainSingle(t => + t.GetType().IsGenericType + && t.GetType().GetGenericTypeDefinition() == typeof(ObjectType<>) + && t.GetType().GetGenericArguments()[0] == typeof(DiscoverQueries) + ); + + // DiscoverQueries field extension + RootQuery discover extension + types.OfType().Should().HaveCount(2); + } + + #endregion + + #region Attribute Property Tests + + [Test] + public void TraxQueryModelAttribute_DefaultProperties_FeaturesAllTrue() + { + var attr = new TraxQueryModelAttribute(); + attr.Name.Should().BeNull(); + attr.Description.Should().BeNull(); + attr.DeprecationReason.Should().BeNull(); + attr.Paging.Should().BeTrue(); + attr.Filtering.Should().BeTrue(); + attr.Sorting.Should().BeTrue(); + attr.Projection.Should().BeTrue(); + } + + [Test] + public void TraxQueryModelAttribute_WithInitProperties_SetsCorrectly() + { + var attr = new TraxQueryModelAttribute + { + Name = "allPlayers", + Description = "All players", + DeprecationReason = "Use v2", + Paging = false, + Filtering = false, + Sorting = true, + Projection = false, + }; + + attr.Name.Should().Be("allPlayers"); + attr.Description.Should().Be("All players"); + attr.DeprecationReason.Should().Be("Use v2"); + attr.Paging.Should().BeFalse(); + attr.Filtering.Should().BeFalse(); + attr.Sorting.Should().BeTrue(); + attr.Projection.Should().BeFalse(); + } + + #endregion + + #region Namespace Grouping + + [Test] + public async Task CreateTypesAsync_ModelWithNamespace_CreatesNamespaceTypeAndExtension() + { + var config = new GraphQLConfiguration([ + new QueryModelRegistration( + typeof(TestPlayer), + typeof(TestDbContext), + new TraxQueryModelAttribute { Description = "Test players", Namespace = "game" } + ), + ]); + var module = new QueryModelTypeModule(config); + + var types = await module.CreateTypesAsync(null!, CancellationToken.None); + + // ObjectType + ObjectType (namespace base) + 2x ObjectTypeExtension + // (namespace fields + namespace field on DiscoverQueries) + types.OfType().Should().HaveCount(2); + + // Namespace type should be tracked + config.RegisteredNamespaceTypes.Should().Contain("GameDiscoverQueries"); + config.RegisteredNamespaceTypes.Should().Contain("DiscoverQueries.game"); + } + + [Test] + public async Task CreateTypesAsync_ModelWithoutNamespace_NoNamespaceTypesCreated() + { + var config = new GraphQLConfiguration([ + new QueryModelRegistration( + typeof(TestPlayer), + typeof(TestDbContext), + new TraxQueryModelAttribute { Description = "Test players" } + ), + ]); + var module = new QueryModelTypeModule(config); + + var types = await module.CreateTypesAsync(null!, CancellationToken.None); + + // Should still have one extension directly on DiscoverQueries + types.OfType().Should().HaveCount(1); + config.RegisteredNamespaceTypes.Should().BeEmpty(); + } + + [Test] + public void TraxQueryModelAttribute_Namespace_DefaultsToNull() + { + var attr = new TraxQueryModelAttribute(); + attr.Namespace.Should().BeNull(); + } + + [Test] + public void TraxQueryModelAttribute_Namespace_SetsCorrectly() + { + var attr = new TraxQueryModelAttribute { Namespace = "game" }; + attr.Namespace.Should().Be("game"); + } + + [Test] + public async Task CreateTypesAsync_MixedNamespacedAndRootModels_BothGenerated() + { + var config = new GraphQLConfiguration([ + new QueryModelRegistration( + typeof(TestPlayer), + typeof(TestDbContext), + new TraxQueryModelAttribute { Namespace = "game" } + ), + new QueryModelRegistration( + typeof(TestItem), + typeof(SecondDbContext), + new TraxQueryModelAttribute() + ), + ]); + var module = new QueryModelTypeModule(config); + + var types = await module.CreateTypesAsync(null!, CancellationToken.None); + + // Root model extension on DiscoverQueries + namespace field extension on DiscoverQueries + // + namespace type extension (GameDiscoverQueries) + types.OfType().Should().HaveCount(3); + config.RegisteredNamespaceTypes.Should().Contain("GameDiscoverQueries"); + } + + #endregion + + #region Stubs + + [TraxQueryModel(Description = "Test players")] + public class TestPlayer + { + public int Id { get; set; } + public string Name { get; set; } = ""; + } + + public class TestIgnored + { + public int Id { get; set; } + public string Value { get; set; } = ""; + } + + [TraxQueryModel(Filtering = false, Sorting = false)] + public class ToggleEntity + { + public int Id { get; set; } + } + + [TraxQueryModel(Description = "Test items")] + public class TestItem + { + public int Id { get; set; } + public string ItemName { get; set; } = ""; + } + + public class TestDbContext : DbContext + { + public DbSet Players { get; set; } = null!; + + protected override void OnConfiguring(DbContextOptionsBuilder optionsBuilder) => + optionsBuilder.UseInMemoryDatabase("TestDb_" + Guid.NewGuid()); + } + + public class MixedDbContext : DbContext + { + public DbSet Players { get; set; } = null!; + public DbSet Ignored { get; set; } = null!; + + protected override void OnConfiguring(DbContextOptionsBuilder optionsBuilder) => + optionsBuilder.UseInMemoryDatabase("MixedDb_" + Guid.NewGuid()); + } + + public class ToggleDbContext : DbContext + { + public DbSet Toggles { get; set; } = null!; + + protected override void OnConfiguring(DbContextOptionsBuilder optionsBuilder) => + optionsBuilder.UseInMemoryDatabase("ToggleDb_" + Guid.NewGuid()); + } + + public class SecondDbContext : DbContext + { + public DbSet Items { get; set; } = null!; + + protected override void OnConfiguring(DbContextOptionsBuilder optionsBuilder) => + optionsBuilder.UseInMemoryDatabase("SecondDb_" + Guid.NewGuid()); + } + + private class StubDiscoveryService( + IReadOnlyList registrations + ) : Trax.Mediator.Services.TrainDiscovery.ITrainDiscoveryService + { + public IReadOnlyList DiscoverTrains() => + registrations; + } + + private interface IStubTrain; + + private class StubTrain; + + public record StubInput + { + public string Value { get; init; } = ""; + } + + public record StubOutput + { + public string Result { get; init; } = ""; + } + + #endregion +} diff --git a/tests/Trax.Api.Tests/TrainTypeModuleTests.cs b/tests/Trax.Api.Tests/TrainTypeModuleTests.cs index ddf448e..58909f8 100644 --- a/tests/Trax.Api.Tests/TrainTypeModuleTests.cs +++ b/tests/Trax.Api.Tests/TrainTypeModuleTests.cs @@ -2,6 +2,7 @@ using HotChocolate.Types; using LanguageExt; using Microsoft.Extensions.DependencyInjection; +using Trax.Api.GraphQL.Configuration; using Trax.Api.GraphQL.Mutations; using Trax.Api.GraphQL.Queries; using Trax.Api.GraphQL.TypeModules; @@ -805,6 +806,254 @@ public async Task CreateTypesAsync_EmptyRegistrations_NoNamespaceTypesRegistered #endregion + #region Namespace Grouping + + [Test] + public async Task CreateTypesAsync_MutationWithNamespace_CreatesNamespaceTypeAndExtension() + { + var config = new GraphQLConfiguration([]); + var discovery = new StubDiscoveryService([ + CreateRegistration( + "BanTrain", + typeof(Unit), + name: "Ban", + operations: GraphQLOperation.Run, + graphqlNamespace: "players" + ), + ]); + var module = new TrainTypeModule(discovery, config); + + var types = await module.CreateTypesAsync(null!, CancellationToken.None); + + // Should have a non-generic ObjectType for the namespace (PlayersDispatchMutations) + var nsTypes = types.Where(t => t.GetType() == typeof(ObjectType)).ToList(); + // Per-train response type + namespace base type = 2 + nsTypes.Should().HaveCount(2); + + // Should have extensions: DispatchMutations (namespace field) + RootMutation (dispatch) + // + PlayersDispatchMutations (fields) + types.OfType().Should().HaveCount(3); + } + + [Test] + public async Task CreateTypesAsync_QueryWithNamespace_CreatesNamespaceTypeAndExtension() + { + var config = new GraphQLConfiguration([]); + var discovery = new StubDiscoveryService([ + CreateRegistration( + "LookupTrain", + typeof(TypedOutput), + name: "Lookup", + isQuery: true, + operations: GraphQLOperation.Run, + graphqlNamespace: "players" + ), + ]); + var module = new TrainTypeModule(discovery, config); + + var types = await module.CreateTypesAsync(null!, CancellationToken.None); + + // Should have ObjectType + ObjectType for namespace (PlayersDiscoverQueries) + // + InputObjectType + ObjectType + // + extensions: RootQuery (discover) + DiscoverQueries (namespace field) + PlayersDiscoverQueries (fields) + types.OfType().Should().HaveCount(3); + + // Namespace base type should be registered + config.RegisteredNamespaceTypes.Should().Contain("PlayersDiscoverQueries"); + } + + [Test] + public async Task CreateTypesAsync_MultipleTrainsInSameNamespace_ShareIntermediateType() + { + var config = new GraphQLConfiguration([]); + var discovery = new StubDiscoveryService([ + CreateRegistration( + "LookupTrain", + typeof(TypedOutput), + name: "Lookup", + serviceTypeName: "LookupTrain", + isQuery: true, + operations: GraphQLOperation.Run, + graphqlNamespace: "players" + ), + CreateRegistration( + "SearchTrain", + typeof(TypedOutput2), + name: "Search", + serviceTypeName: "SearchTrain", + isQuery: true, + operations: GraphQLOperation.Run, + graphqlNamespace: "players" + ), + ]); + var module = new TrainTypeModule(discovery, config); + + var types = await module.CreateTypesAsync(null!, CancellationToken.None); + + // Only one namespace base type should be registered + config.RegisteredNamespaceTypes.Count(n => n == "PlayersDiscoverQueries").Should().Be(1); + + // Only one namespace field extension on DiscoverQueries + config.RegisteredNamespaceTypes.Should().Contain("DiscoverQueries.players"); + } + + [Test] + public async Task CreateTypesAsync_DifferentNamespaces_CreatesSeparateTypes() + { + var config = new GraphQLConfiguration([]); + var discovery = new StubDiscoveryService([ + CreateRegistration( + "LookupTrain", + typeof(TypedOutput), + name: "Lookup", + serviceTypeName: "LookupTrain", + isQuery: true, + operations: GraphQLOperation.Run, + graphqlNamespace: "players" + ), + CreateRegistration( + "AlertTrain", + typeof(TypedOutput2), + name: "Alert", + serviceTypeName: "AlertTrain", + isQuery: true, + operations: GraphQLOperation.Run, + graphqlNamespace: "alerts" + ), + ]); + var module = new TrainTypeModule(discovery, config); + + var types = await module.CreateTypesAsync(null!, CancellationToken.None); + + config.RegisteredNamespaceTypes.Should().Contain("PlayersDiscoverQueries"); + config.RegisteredNamespaceTypes.Should().Contain("AlertsDiscoverQueries"); + } + + [Test] + public async Task CreateTypesAsync_MixedNamespacedAndRoot_BothGenerated() + { + var config = new GraphQLConfiguration([]); + var discovery = new StubDiscoveryService([ + CreateRegistration( + "LookupTrain", + typeof(TypedOutput), + name: "Lookup", + serviceTypeName: "LookupTrain", + isQuery: true, + operations: GraphQLOperation.Run, + graphqlNamespace: "players" + ), + CreateRegistration( + "HealthTrain", + typeof(TypedOutput2), + name: "Health", + serviceTypeName: "HealthTrain", + isQuery: true, + operations: GraphQLOperation.Run + ), + ]); + var module = new TrainTypeModule(discovery, config); + + var types = await module.CreateTypesAsync(null!, CancellationToken.None); + + // Should have extensions for: + // - RootQuery (discover) + // - DiscoverQueries (namespace field for "players" + root field "Health") + // - PlayersDiscoverQueries (fields) + types.OfType().Should().HaveCountGreaterThanOrEqualTo(3); + + // Namespace type should exist + config.RegisteredNamespaceTypes.Should().Contain("PlayersDiscoverQueries"); + } + + [Test] + public async Task CreateTypesAsync_NullNamespace_BackwardCompatible() + { + // Existing behavior: no namespace, fields go directly on DiscoverQueries + var config = new GraphQLConfiguration([]); + var discovery = new StubDiscoveryService([ + CreateRegistration( + "LookupTrain", + typeof(TypedOutput), + name: "Lookup", + isQuery: true, + operations: GraphQLOperation.Run + ), + ]); + var module = new TrainTypeModule(discovery, config); + + var types = await module.CreateTypesAsync(null!, CancellationToken.None); + + // No namespace types should be registered + config.RegisteredNamespaceTypes.Should().BeEmpty(); + + // Should still have DiscoverQueries extension + RootQuery extension + types.OfType().Should().HaveCount(2); + } + + [Test] + public async Task CreateTypesAsync_SameNamespaceOnQueryAndMutation_CreatesSeparateNamespaceTypes() + { + var config = new GraphQLConfiguration([]); + var discovery = new StubDiscoveryService([ + CreateRegistration( + "LookupTrain", + typeof(TypedOutput), + name: "Lookup", + serviceTypeName: "LookupTrain", + isQuery: true, + operations: GraphQLOperation.Run, + graphqlNamespace: "alerts" + ), + CreateRegistration( + "CreateAlertTrain", + typeof(Unit), + name: "CreateAlert", + serviceTypeName: "CreateAlertTrain", + operations: GraphQLOperation.Run, + graphqlNamespace: "alerts" + ), + ]); + var module = new TrainTypeModule(discovery, config); + + var types = await module.CreateTypesAsync(null!, CancellationToken.None); + + // Should create separate namespace types for queries vs mutations + config.RegisteredNamespaceTypes.Should().Contain("AlertsDiscoverQueries"); + config.RegisteredNamespaceTypes.Should().Contain("AlertsDispatchMutations"); + } + + [Test] + public void NamespaceTypeName_CombinesCorrectly() + { + TrainTypeModule + .NamespaceTypeName("alerts", "DiscoverQueries") + .Should() + .Be("AlertsDiscoverQueries"); + TrainTypeModule + .NamespaceTypeName("players", "DispatchMutations") + .Should() + .Be("PlayersDispatchMutations"); + } + + [Test] + public void PascalCase_CapitalizesFirstChar() + { + TrainTypeModule.PascalCase("alerts").Should().Be("Alerts"); + TrainTypeModule.PascalCase("Players").Should().Be("Players"); + TrainTypeModule.PascalCase("a").Should().Be("A"); + } + + [Test] + public void CamelCase_LowercasesFirstChar() + { + TrainTypeModule.CamelCase("Alerts").Should().Be("alerts"); + TrainTypeModule.CamelCase("players").Should().Be("players"); + TrainTypeModule.CamelCase("A").Should().Be("a"); + } + + #endregion + #region Helpers private static List GetNonGenericObjectTypes( @@ -841,7 +1090,8 @@ private static TrainRegistration CreateRegistration( string? description = null, string? deprecationReason = null, GraphQLOperation operations = GraphQLOperation.Run | GraphQLOperation.Queue, - bool isQuery = false + bool isQuery = false, + string? graphqlNamespace = null ) { return new TrainRegistration @@ -864,6 +1114,7 @@ private static TrainRegistration CreateRegistration( GraphQLDescription = description, GraphQLDeprecationReason = deprecationReason, GraphQLOperations = operations, + GraphQLNamespace = graphqlNamespace, IsRemote = false, }; } diff --git a/tests/Trax.Api.Tests/Trax.Api.Tests.csproj b/tests/Trax.Api.Tests/Trax.Api.Tests.csproj index 2cecb33..952c386 100644 --- a/tests/Trax.Api.Tests/Trax.Api.Tests.csproj +++ b/tests/Trax.Api.Tests/Trax.Api.Tests.csproj @@ -12,7 +12,8 @@ - + +