Skip to content

Commit d1f349e

Browse files
committed
Evolog Modules: Add parsing for list collection aggregate
1 parent 8bf77b3 commit d1f349e

File tree

5 files changed

+93
-3
lines changed

5 files changed

+93
-3
lines changed

alpha-api/src/main/java/at/ac/tuwien/kr/alpha/api/programs/atoms/AggregateAtom.java

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,8 @@ enum AggregateFunctionSymbol {
2424
COUNT,
2525
MAX,
2626
MIN,
27-
SUM
27+
SUM,
28+
LIST
2829
}
2930

3031
ComparisonOperator getLowerBoundOperator();
@@ -44,6 +45,11 @@ enum AggregateFunctionSymbol {
4445
@Override
4546
AggregateLiteral toLiteral(boolean positive);
4647

48+
@Override
49+
default AggregateLiteral toLiteral() {
50+
return toLiteral(true);
51+
}
52+
4753
interface AggregateElement {
4854

4955
List<Term> getElementTerms();

alpha-core/src/main/antlr/at/ac/tuwien/kr/alpha/core/antlr/ASPCore2.g4

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,13 @@ choice_elements : choice_element (SEMICOLON choice_elements)?;
3434

3535
choice_element : classical_literal (COLON naf_literals?)?;
3636

37-
aggregate : NAF? (lt=term lop=binop)? aggregate_function CURLY_OPEN aggregate_elements CURLY_CLOSE (uop=binop ut=term)?;
37+
aggregate : (classic_aggregate | list_aggregate);
38+
39+
list_aggregate: term EQUAL AGGREGATE_LIST CURLY_OPEN list_comprehension CURLY_CLOSE;
40+
41+
list_comprehension: term COLON naf_literals; // Note: Term is expected to be a function term or basic_term
42+
43+
classic_aggregate: NAF? (lt=term lop=binop)? aggregate_function CURLY_OPEN aggregate_elements CURLY_CLOSE (uop=binop ut=term)?;
3844

3945
aggregate_elements : aggregate_element (SEMICOLON aggregate_elements)?;
4046

alpha-core/src/main/antlr/at/ac/tuwien/kr/alpha/core/antlr/ASPLexer.g4

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ AGGREGATE_COUNT : '#count';
4040
AGGREGATE_MAX : '#max';
4141
AGGREGATE_MIN : '#min';
4242
AGGREGATE_SUM : '#sum';
43+
AGGREGATE_LIST : '#list';
4344

4445
DIRECTIVE_ENUM : 'enumeration_predicate_is';
4546

alpha-core/src/main/java/at/ac/tuwien/kr/alpha/core/parser/ParseTreeVisitor.java

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -388,7 +388,19 @@ public Set<Literal> visitBody(ASPCore2Parser.BodyContext ctx) {
388388

389389
@Override
390390
public AggregateLiteral visitAggregate(ASPCore2Parser.AggregateContext ctx) {
391-
// aggregate : NAF? (lt=term lop=binop)? aggregate_function CURLY_OPEN aggregate_elements CURLY_CLOSE (uop=binop ut=term)?;
391+
// aggregate : (classic_aggregate | list_aggregate);
392+
if (ctx.classic_aggregate() != null) {
393+
return visitClassic_aggregate(ctx.classic_aggregate());
394+
} else if (ctx.list_aggregate() != null) {
395+
return visitList_aggregate(ctx.list_aggregate());
396+
} else {
397+
throw notSupported(ctx);
398+
}
399+
}
400+
401+
@Override
402+
public AggregateLiteral visitClassic_aggregate(ASPCore2Parser.Classic_aggregateContext ctx) {
403+
// classic_aggregate: NAF? (lt=term lop=binop)? aggregate_function CURLY_OPEN aggregate_elements CURLY_CLOSE (uop=binop ut=term)?;
392404
boolean isPositive = ctx.NAF() == null;
393405
Term lt = null;
394406
ComparisonOperator lop = null;
@@ -407,6 +419,23 @@ public AggregateLiteral visitAggregate(ASPCore2Parser.AggregateContext ctx) {
407419
return Atoms.newAggregateAtom(lop, lt, uop, ut, aggregateFunction, aggregateElements).toLiteral(isPositive);
408420
}
409421

422+
@Override
423+
public AggregateLiteral visitList_aggregate(ASPCore2Parser.List_aggregateContext ctx) {
424+
// list_aggregate: term EQUAL AGGREGATE_LIST CURLY_OPEN list_comprehension CURLY_CLOSE;
425+
Term listResultTerm = (Term) visit(ctx.term());
426+
ImmutablePair<Term, List<Literal>> listComprehension = visitList_comprehension(ctx.list_comprehension());
427+
return Atoms.newAggregateAtom(ComparisonOperators.EQ, listResultTerm, AggregateAtom.AggregateFunctionSymbol.LIST,
428+
List.of(Atoms.newAggregateElement(List.of(listComprehension.left), listComprehension.right))).toLiteral();
429+
}
430+
431+
@Override
432+
public ImmutablePair<Term, List<Literal>> visitList_comprehension(ASPCore2Parser.List_comprehensionContext ctx) {
433+
// list_comprehension: term COLON naf_literals;
434+
Term elementTerm = (Term) visit(ctx.term());
435+
List<Literal> elementSelectors = visitNaf_literals(ctx.naf_literals());
436+
return ImmutablePair.of(elementTerm, elementSelectors);
437+
}
438+
410439
@Override
411440
public List<AggregateAtom.AggregateElement> visitAggregate_elements(ASPCore2Parser.Aggregate_elementsContext ctx) {
412441
// aggregate_elements : aggregate_element (SEMICOLON aggregate_elements)?;

alpha-core/src/test/java/at/ac/tuwien/kr/alpha/core/parser/ParserTest.java

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,10 @@ public class ParserTest {
106106

107107
private static final String MODULE_LITERAL_NO_OUTPUT_WITH_NUM_ANSWER_SETS = "a(X) :- #something{4}[X].";
108108

109+
private static final String LIST_AGGREGATE = "stuff_list(LST) :- LST = #list{X : stuff(X)}.";
110+
111+
private static final String LIST_AGGREGATE_TUPLE = "stuff_list(LST) :- LST = #list{stuff_tuple(X,Y) : stuff(X,Y)}.";
112+
109113
private final ProgramParserImpl parser = new ProgramParserImpl();
110114

111115
@Test
@@ -499,4 +503,48 @@ public void moduleLiteralNoOutputWithNumAnswerSets() {
499503
assertEquals(4, moduleLiteral.getAtom().getInstantiationMode().requestedAnswerSets().get());
500504
}
501505

506+
@Test
507+
public void listAggregate() {
508+
InputProgram prog = parser.parse(LIST_AGGREGATE);
509+
assertEquals(1, prog.getRules().size());
510+
Rule<?> rule = prog.getRules().get(0);
511+
assertEquals(1, rule.getBody().size());
512+
assertEquals(1, rule.getBody().stream().filter(lit -> lit instanceof AggregateLiteral).count());
513+
AggregateLiteral aggregateLiteral = (AggregateLiteral) rule.getBody().stream().filter(lit -> lit instanceof AggregateLiteral).findFirst().get();
514+
AggregateAtom aggregateAtom = aggregateLiteral.getAtom();
515+
assertEquals(ComparisonOperators.EQ, aggregateAtom.getLowerBoundOperator());
516+
assertEquals(Terms.newVariable("LST"), aggregateAtom.getLowerBoundTerm());
517+
assertEquals(AggregateAtom.AggregateFunctionSymbol.LIST, aggregateAtom.getAggregateFunction());
518+
assertEquals(1, aggregateAtom.getAggregateElements().size());
519+
AggregateAtom.AggregateElement aggregateElement = aggregateAtom.getAggregateElements().get(0);
520+
assertEquals(1, aggregateElement.getElementTerms().size());
521+
Term elementTerm = aggregateElement.getElementTerms().get(0);
522+
assertEquals(Terms.newVariable("X"), elementTerm);
523+
assertEquals(1, aggregateElement.getElementLiterals().size());
524+
Literal elementLiteral = aggregateElement.getElementLiterals().get(0);
525+
assertEquals(Atoms.newBasicAtom(Predicates.getPredicate("stuff", 1), Terms.newVariable("X")).toLiteral(), elementLiteral);
526+
}
527+
528+
@Test
529+
public void listAggregateWithTuples() {
530+
InputProgram prog = parser.parse(LIST_AGGREGATE_TUPLE);
531+
assertEquals(1, prog.getRules().size());
532+
Rule<?> rule = prog.getRules().get(0);
533+
assertEquals(1, rule.getBody().size());
534+
assertEquals(1, rule.getBody().stream().filter(lit -> lit instanceof AggregateLiteral).count());
535+
AggregateLiteral aggregateLiteral = (AggregateLiteral) rule.getBody().stream().filter(lit -> lit instanceof AggregateLiteral).findFirst().get();
536+
AggregateAtom aggregateAtom = aggregateLiteral.getAtom();
537+
assertEquals(ComparisonOperators.EQ, aggregateAtom.getLowerBoundOperator());
538+
assertEquals(Terms.newVariable("LST"), aggregateAtom.getLowerBoundTerm());
539+
assertEquals(AggregateAtom.AggregateFunctionSymbol.LIST, aggregateAtom.getAggregateFunction());
540+
assertEquals(1, aggregateAtom.getAggregateElements().size());
541+
AggregateAtom.AggregateElement aggregateElement = aggregateAtom.getAggregateElements().get(0);
542+
assertEquals(1, aggregateElement.getElementTerms().size());
543+
Term elementTerm = aggregateElement.getElementTerms().get(0);
544+
assertEquals(Terms.newFunctionTerm("stuff_tuple", Terms.newVariable("X"), Terms.newVariable("Y")), elementTerm);
545+
assertEquals(1, aggregateElement.getElementLiterals().size());
546+
Literal elementLiteral = aggregateElement.getElementLiterals().get(0);
547+
assertEquals(Atoms.newBasicAtom(Predicates.getPredicate("stuff", 2), Terms.newVariable("X"), Terms.newVariable("Y")).toLiteral(), elementLiteral);
548+
}
549+
502550
}

0 commit comments

Comments
 (0)