diff --git a/acorn-graphql/src/main/java/com/datasqrl/ai/converter/GraphQLSchemaConverter.java b/acorn-graphql/src/main/java/com/datasqrl/ai/converter/GraphQLSchemaConverter.java index 35a0b55..c63e2c4 100644 --- a/acorn-graphql/src/main/java/com/datasqrl/ai/converter/GraphQLSchemaConverter.java +++ b/acorn-graphql/src/main/java/com/datasqrl/ai/converter/GraphQLSchemaConverter.java @@ -9,14 +9,21 @@ import com.datasqrl.ai.tool.FunctionDefinition.Argument; import com.datasqrl.ai.tool.FunctionDefinition.Parameters; import com.datasqrl.ai.util.ErrorHandling; + +import graphql.language.BooleanValue; import graphql.language.Comment; import graphql.language.Definition; import graphql.language.Document; +import graphql.language.FloatValue; +import graphql.language.IntValue; import graphql.language.ListType; +import graphql.language.Node; +import graphql.language.NodeVisitorStub; import graphql.language.NonNullType; import graphql.language.OperationDefinition; import graphql.language.OperationDefinition.Operation; import graphql.language.SourceLocation; +import graphql.language.StringValue; import graphql.language.Type; import graphql.language.TypeName; import graphql.language.VariableDefinition; @@ -36,11 +43,15 @@ import graphql.schema.GraphQLScalarType; import graphql.schema.GraphQLSchema; import graphql.schema.GraphQLType; +import graphql.schema.InputValueWithState; import graphql.schema.idl.RuntimeWiring; import graphql.schema.idl.SchemaGenerator; import graphql.schema.idl.SchemaParser; import graphql.schema.idl.SchemaPrinter; import graphql.schema.idl.TypeDefinitionRegistry; +import graphql.util.TraversalControl; +import graphql.util.TraverserContext; + import java.lang.reflect.Field; import java.lang.reflect.Modifier; import java.util.ArrayList; @@ -53,6 +64,7 @@ import java.util.regex.Pattern; import java.util.stream.Collectors; import java.util.stream.Stream; + import lombok.Value; import lombok.extern.slf4j.Slf4j; @@ -404,7 +416,8 @@ public boolean visit( unwrappedType, argName, nestedField.getName(), - nestedField.getDescription()); + nestedField.getDescription(), + nestedField.getInputFieldDefaultValue()); String typeString = printFieldType(nestedField); queryHeader.append(argName).append(": ").append(typeString); numArgs++; @@ -422,7 +435,8 @@ public boolean visit( unwrappedType, argName, arg.getName(), - arg.getDescription()); + arg.getDescription(), + arg.getArgumentDefaultValue()); String typeString = printArgumentType(arg); queryHeader.append(argName).append(": ").append(typeString); numArgs++; @@ -463,7 +477,7 @@ private String processField( UnwrappedType unwrappedType, String argName, String originalName, - String description) { + String description, InputValueWithState defaultValue) { Argument argDef = convert(unwrappedType.type()); argDef.setDescription(description); if (numArgs > 0) queryBody.append(", "); @@ -472,9 +486,40 @@ private String processField( params.getProperties().put(argName, argDef); argName = "$" + argName; queryBody.append(originalName).append(": ").append(argName); + if(defaultValue.getValue() instanceof graphql.language.Value value ) { + convertDefaultValue(queryBody, value); + } return argName; } +private void convertDefaultValue(StringBuilder queryBody, graphql.language.Value value) { + value.accept(null, new NodeVisitorStub() { + @Override + public TraversalControl visitBooleanValue(BooleanValue node, TraverserContext context) { + queryBody.append("=").append(node.isValue()); + return TraversalControl.CONTINUE; + } + + @Override + public TraversalControl visitIntValue(IntValue node, TraverserContext context) { + queryBody.append("=").append(node.getValue()); + return TraversalControl.CONTINUE; + } + + @Override + public TraversalControl visitStringValue(StringValue node, TraverserContext context) { + queryBody.append("=\"").append(node.getValue()).append('"'); + return TraversalControl.CONTINUE; + } + + @Override + public TraversalControl visitFloatValue(FloatValue node, TraverserContext context) { + queryBody.append("=").append(node.getValue()); + return TraversalControl.CONTINUE; + } + }); +} + private String printFieldType(GraphQLInputObjectField field) { GraphQLInputObjectType type = GraphQLInputObjectType.newInputObject().name("DummyType").field(field).build();