|
1 | 1 | package at.ac.tuwien.kr.alpha.core.programs.transformation;
|
2 | 2 |
|
3 | 3 | import at.ac.tuwien.kr.alpha.api.Alpha;
|
| 4 | +import at.ac.tuwien.kr.alpha.api.AnswerSet; |
| 5 | +import at.ac.tuwien.kr.alpha.api.common.fixedinterpretations.PredicateInterpretation; |
| 6 | +import at.ac.tuwien.kr.alpha.api.programs.Predicate; |
4 | 7 | import at.ac.tuwien.kr.alpha.api.programs.NormalProgram;
|
| 8 | +import at.ac.tuwien.kr.alpha.api.programs.atoms.Atom; |
| 9 | +import at.ac.tuwien.kr.alpha.api.programs.atoms.BasicAtom; |
5 | 10 | import at.ac.tuwien.kr.alpha.api.programs.atoms.ExternalAtom;
|
6 | 11 | import at.ac.tuwien.kr.alpha.api.programs.atoms.ModuleAtom;
|
7 | 12 | import at.ac.tuwien.kr.alpha.api.programs.literals.Literal;
|
8 | 13 | import at.ac.tuwien.kr.alpha.api.programs.literals.ModuleLiteral;
|
9 | 14 | import at.ac.tuwien.kr.alpha.api.programs.modules.Module;
|
10 | 15 | import at.ac.tuwien.kr.alpha.api.programs.rules.NormalRule;
|
| 16 | +import at.ac.tuwien.kr.alpha.api.programs.rules.Rule; |
11 | 17 | import at.ac.tuwien.kr.alpha.api.programs.rules.heads.NormalHead;
|
| 18 | +import at.ac.tuwien.kr.alpha.api.programs.terms.Term; |
| 19 | +import at.ac.tuwien.kr.alpha.commons.programs.Programs; |
| 20 | +import at.ac.tuwien.kr.alpha.commons.programs.atoms.Atoms; |
12 | 21 | import at.ac.tuwien.kr.alpha.commons.programs.rules.Rules;
|
| 22 | +import org.apache.commons.collections4.ListUtils; |
| 23 | +import org.apache.commons.collections4.SetUtils; |
13 | 24 |
|
| 25 | +import java.util.Collections; |
14 | 26 | import java.util.List;
|
15 | 27 | import java.util.Map;
|
16 | 28 | import java.util.Set;
|
17 | 29 | import java.util.function.Function;
|
18 | 30 | import java.util.stream.Collectors;
|
| 31 | +import java.util.stream.Stream; |
19 | 32 |
|
20 | 33 | /**
|
21 | 34 | * Program transformation that translates {@link at.ac.tuwien.kr.alpha.api.programs.literals.ModuleLiteral}s into
|
@@ -57,17 +70,57 @@ private NormalRule linkModuleAtoms(NormalRule rule, Map<String, Module> moduleTa
|
57 | 70 | return Rules.newNormalRule(newHead, newBody);
|
58 | 71 | }
|
59 | 72 |
|
60 |
| - private ExternalAtom translateModuleAtom(ModuleAtom moduleAtom, Map<String, Module> moduleTable) { |
61 |
| - if (!moduleTable.containsKey(moduleAtom.getModuleName())) { |
62 |
| - throw new IllegalArgumentException("Module " + moduleAtom.getModuleName() + " not found in module table."); |
| 73 | + private ExternalAtom translateModuleAtom(ModuleAtom atom, Map<String, Module> moduleTable) { |
| 74 | + if (!moduleTable.containsKey(atom.getModuleName())) { |
| 75 | + throw new IllegalArgumentException("Module " + atom.getModuleName() + " not found in module table."); |
63 | 76 | }
|
64 |
| - Module implementationModule = moduleTable.get(moduleAtom.getModuleName()); |
65 |
| - //implementationModule. |
66 |
| - return null; |
| 77 | + Module definition = moduleTable.get(atom.getModuleName()); |
| 78 | + // verify inputs |
| 79 | + Predicate inputSpec = definition.getInputSpec(); |
| 80 | + if (atom.getInput().size() != inputSpec.getArity()) { |
| 81 | + throw new IllegalArgumentException("Module " + atom.getModuleName() + " expects " + inputSpec.getArity() + " inputs, but " + atom.getInput().size() + " were given."); |
| 82 | + } |
| 83 | + NormalProgram normalizedImplementation = moduleRunner.normalizeProgram(definition.getImplementation()); |
| 84 | + // verify outputs |
| 85 | + Set<Predicate> outputSpec = definition.getOutputSpec(); |
| 86 | + int expectedOutputTerms; |
| 87 | + if (outputSpec.isEmpty()) { |
| 88 | + expectedOutputTerms = calculateOutputPredicates(normalizedImplementation).size(); |
| 89 | + } else { |
| 90 | + expectedOutputTerms = outputSpec.size(); |
| 91 | + } |
| 92 | + if (atom.getOutput().size() != expectedOutputTerms) { |
| 93 | + throw new IllegalArgumentException("Module " + atom.getModuleName() + " expects " + outputSpec.size() + " outputs, but " + atom.getOutput().size() + " were given."); |
| 94 | + } |
| 95 | + // create the actual interpretation |
| 96 | + PredicateInterpretation interpretation = terms -> { |
| 97 | + BasicAtom inputAtom = Atoms.newBasicAtom(inputSpec, terms); |
| 98 | + NormalProgram program = Programs.newNormalProgram(normalizedImplementation.getRules(), |
| 99 | + ListUtils.union(List.of(inputAtom), normalizedImplementation.getFacts()), normalizedImplementation.getInlineDirectives()); |
| 100 | + java.util.function.Predicate<Predicate> filter = outputSpec.isEmpty() ? p -> true : outputSpec::contains; |
| 101 | + Stream<AnswerSet> answerSets = moduleRunner.solve(program, filter); |
| 102 | + if (atom.getInstantiationMode().requestedAnswerSets().isPresent()) { |
| 103 | + answerSets = answerSets.limit(atom.getInstantiationMode().requestedAnswerSets().get()); |
| 104 | + } |
| 105 | + return answerSets.map(ModuleLinker::answerSetToTerms).collect(Collectors.toSet()); |
| 106 | + }; |
| 107 | + return Atoms.newExternalAtom(atom.getPredicate(), interpretation, atom.getInput(), atom.getOutput()); |
67 | 108 | }
|
68 | 109 |
|
69 | 110 | private static boolean containsModuleAtom(NormalRule rule) {
|
70 | 111 | return rule.getBody().stream().anyMatch(literal -> literal instanceof ModuleLiteral);
|
71 | 112 | }
|
72 | 113 |
|
| 114 | + private static Set<Predicate> calculateOutputPredicates(NormalProgram program) { |
| 115 | + return SetUtils.union(program.getFacts().stream().map(Atom::getPredicate).collect(Collectors.toSet()), |
| 116 | + program.getRules().stream() |
| 117 | + .filter(java.util.function.Predicate.not(Rule::isConstraint)) |
| 118 | + .map(Rule::getHead).map(NormalHead::getAtom).map(Atom::getPredicate) |
| 119 | + .collect(Collectors.toSet())); |
| 120 | + } |
| 121 | + |
| 122 | + private static List<Term> answerSetToTerms(AnswerSet answerSet) { |
| 123 | + return Collections.emptyList(); // TODO |
| 124 | + } |
| 125 | + |
73 | 126 | }
|
0 commit comments