Skip to content

Commit

Permalink
[MOREL-138] Type annotations in patterns, function declarations and e…
Browse files Browse the repository at this point in the history
…xpressions

Fixes hydromatic#138
  • Loading branch information
julianhyde committed Sep 25, 2022
1 parent 60bf996 commit a7194e4
Show file tree
Hide file tree
Showing 11 changed files with 250 additions and 22 deletions.
9 changes: 5 additions & 4 deletions docs/reference.md
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@ In Morel but not Standard ML:
In Standard ML but not in Morel:
* `word` constant
* `longid` identifier
* type annotations ("`:` *typ*") (appears in expressions, patterns, and *funmatch*)
* references (`ref` and operators `!` and `:=`)
* exceptions (`raise`, `handle`, `exception`)
* `while` loop
Expand Down Expand Up @@ -131,6 +130,7 @@ In Standard ML but not in Morel:
| '<b>(</b>' <i>exp<sub>1</sub></i> <b>;</b> ... <b>;</b> <i>exp<sub>n</sub></i> '<b>)</b>' sequence (n &ge; 2)
| <b>let</b> <i>dec</i> <b>in</b> <i>exp<sub>1</sub></i> ; ... ; <i>exp<sub>n</sub></i> <b>end</b>
local declaration (n ≥ 1)
| <i>exp</i> <b>:</b> <i>type</i> type annotation
| <i>exp<sub>1</sub></i> <b>andalso</b> <i>exp<sub>2</sub></i> conjunction
| <i>exp<sub>1</sub></i> <b>orelse</b> <i>exp<sub>2</sub></i> disjunction
| <b>if</b> <i>exp<sub>1</sub></i> <b>then</b> <i>exp<sub>2</sub></i> <b>else</b> <i>exp<sub>3</sub></i>
Expand Down Expand Up @@ -173,6 +173,7 @@ In Standard ML but not in Morel:
| '<b>(</b>' <i>pat<sub>1</sub></i> , ... , <i>pat<sub>n</sub></i> '<b>)</b>' tuple (n &ne; 1)
| <b>{</b> [ <i>patrow</i> ] <b>}</b> record
| '<b>[</b>' <i>pat<sub>1</sub></i> <b>,</b> ... <b>,</b> <i>pat<sub>n</sub></i> '<b>]</b>' list (n &ge; 0)
| <i>pat</i> <b>:</b> <i>type</i> type annotation
| <i>id</i> <b>as</b> <i>pat</i> layered
<i>patrow</i> &rarr; '<b>...</b>' wildcard
| <i>lab</i> <b>=</b> <i>pat</i> [<b>,</b> <i>patrow</i>] pattern
Expand Down Expand Up @@ -206,10 +207,10 @@ In Standard ML but not in Morel:
<i>funbind</i> &rarr; <i>funmatch</i> [ <b>and</b> <i>funmatch</i> ]*
clausal function
<i>funmatch</i> &rarr; <i>funmatchItem</i> [ '<b>|</b>' funmatchItem ]*
<i>funmatchItem</i> &rarr; [ <b>op</b> ] <i>id</i> <i>pat<sub>1</sub></i> ... <i>pat<sub>n</sub></i> <b>=</b> <i>exp</i>
<i>funmatchItem</i> &rarr; [ <b>op</b> ] <i>id</i> <i>pat<sub>1</sub></i> ... <i>pat<sub>n</sub></i> [ <b>:</b> <i>type</i> ] <b>=</b> <i>exp</i>
nonfix (n &ge; 1)
| <i>pat<sub>1</sub></i> <i>id</i> <i>pat<sub>2</sub></i> <b>=</b> <i>exp</i> infix
| '<b>(</b>' <i>pat<sub>1</sub></i> <i>id</i> <i>pat<sub>2</sub></i> '<b>)</b>' <i>pat'<sub>1</sub></i> ... <i>pat'<sub>n</sub></i> = <i>exp</i>
| <i>pat<sub>1</sub></i> <i>id</i> <i>pat<sub>2</sub></i> [ <b>:</b> <i>type</i> ] <b>=</b> <i>exp</i> infix
| '<b>(</b>' <i>pat<sub>1</sub></i> <i>id</i> <i>pat<sub>2</sub></i> '<b>)</b>' <i>pat'<sub>1</sub></i> ... <i>pat'<sub>n</sub></i> [ <b>:</b> <i>type</i> ] = <i>exp</i>
infix (n &ge; 0)
<i>datbind</i> &rarr; <i>datbindItem</i> [ <b>and</b> <i>datbindItem</i> ]*
data type
Expand Down
7 changes: 5 additions & 2 deletions src/main/java/net/hydromatic/morel/ast/Ast.java
Original file line number Diff line number Diff line change
Expand Up @@ -496,7 +496,7 @@ public static class AnnotatedExp extends Exp {
public final Exp exp;

/** Creates a type annotation. */
AnnotatedExp(Pos pos, Type type, Exp exp) {
AnnotatedExp(Pos pos, Exp exp, Type type) {
super(pos, Op.ANNOTATED_EXP);
this.type = requireNonNull(type);
this.exp = requireNonNull(exp);
Expand Down Expand Up @@ -1087,12 +1087,15 @@ AstWriter unparse(AstWriter w, int left, int right) {
public static class FunMatch extends AstNode {
public final String name;
public final List<Pat> patList;
@Nullable public final Type returnType;
public final Exp exp;

FunMatch(Pos pos, String name, ImmutableList<Pat> patList, Exp exp) {
FunMatch(Pos pos, String name, ImmutableList<Pat> patList,
@Nullable Type returnType, Exp exp) {
super(pos, Op.FUN_MATCH);
this.name = name;
this.patList = patList;
this.returnType = returnType;
this.exp = exp;
}

Expand Down
10 changes: 6 additions & 4 deletions src/main/java/net/hydromatic/morel/ast/AstBuilder.java
Original file line number Diff line number Diff line change
Expand Up @@ -362,8 +362,10 @@ public Ast.FunBind funBind(Pos pos,
}

public Ast.FunMatch funMatch(Pos pos, String name,
Iterable<? extends Ast.Pat> patList, Ast.Exp exp) {
return new Ast.FunMatch(pos, name, ImmutableList.copyOf(patList), exp);
Iterable<? extends Ast.Pat> patList, @Nullable Ast.Type returnType,
Ast.Exp exp) {
return new Ast.FunMatch(pos, name, ImmutableList.copyOf(patList),
returnType, exp);
}

public Ast.Apply apply(Ast.Exp fn, Ast.Exp arg) {
Expand All @@ -379,8 +381,8 @@ public Ast.InfixPat infixPat(Pos pos, Op op, Ast.Pat p0, Ast.Pat p1) {
return new Ast.InfixPat(pos, op, p0, p1);
}

public Ast.Exp annotatedExp(Pos pos, Ast.Type type, Ast.Exp expression) {
return new Ast.AnnotatedExp(pos, type, expression);
public Ast.Exp annotatedExp(Pos pos, Ast.Exp expression, Ast.Type type) {
return new Ast.AnnotatedExp(pos, expression, type);
}

public Ast.Exp infixCall(Pos pos, Op op, Ast.Exp a0, Ast.Exp a1) {
Expand Down
2 changes: 1 addition & 1 deletion src/main/java/net/hydromatic/morel/ast/Op.java
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ public enum Op {
FORALL_TYPE,

// annotated expression "e: t"
ANNOTATED_EXP(" : "),
ANNOTATED_EXP(" : ", 0),

TIMES(" * ", 7),
DIVIDE(" / ", 7),
Expand Down
9 changes: 5 additions & 4 deletions src/main/java/net/hydromatic/morel/ast/Shuttle.java
Original file line number Diff line number Diff line change
Expand Up @@ -73,9 +73,8 @@ protected Ast.Id visit(Ast.Id id) {
}

protected Ast.Exp visit(Ast.AnnotatedExp annotatedExp) {
return ast.annotatedExp(annotatedExp.pos,
annotatedExp.type.accept(this),
annotatedExp.exp.accept(this));
return ast.annotatedExp(annotatedExp.pos, annotatedExp.exp.accept(this),
annotatedExp.type.accept(this));
}

protected Ast.Exp visit(Ast.If ifThenElse) {
Expand Down Expand Up @@ -206,7 +205,9 @@ protected Ast.FunBind visit(Ast.FunBind funBind) {

protected Ast.FunMatch visit(Ast.FunMatch funMatch) {
return ast.funMatch(funMatch.pos, funMatch.name,
visitList(funMatch.patList), funMatch.exp.accept(this));
visitList(funMatch.patList),
funMatch.returnType == null ? null : funMatch.returnType.accept(this),
funMatch.exp.accept(this));
}

protected Ast.ValDecl visit(Ast.ValDecl valDecl) {
Expand Down
7 changes: 7 additions & 0 deletions src/main/java/net/hydromatic/morel/compile/Resolver.java
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,8 @@ private Core.Exp toCore(Ast.Exp exp) {
return core.stringLiteral((String) ((Ast.Literal) exp).value);
case UNIT_LITERAL:
return core.unitLiteral();
case ANNOTATED_EXP:
return toCore(((Ast.AnnotatedExp) exp).exp);
case ID:
return toCore((Ast.Id) exp);
case ANDALSO:
Expand Down Expand Up @@ -535,6 +537,11 @@ private Core.Pat toCore(Ast.Pat pat, Type type, Type targetType) {
final Ast.AsPat asPat = (Ast.AsPat) pat;
return core.asPat(type, asPat.id.name, nameGenerator, toCore(asPat.pat));

case ANNOTATED_PAT:
// There is no annotated pat in core, because all patterns have types.
final Ast.AnnotatedPat annotatedPat = (Ast.AnnotatedPat) pat;
return toCore(annotatedPat.pat);

case CON_PAT:
final Ast.ConPat conPat = (Ast.ConPat) pat;
return core.conPat(type, conPat.tyCon.name, toCore(conPat.pat));
Expand Down
58 changes: 53 additions & 5 deletions src/main/java/net/hydromatic/morel/compile/TypeResolver.java
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
import com.google.common.base.Suppliers;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSortedMap;
import com.google.common.collect.Iterables;
import com.google.common.collect.Lists;
import org.apache.calcite.util.Holder;
Expand All @@ -71,6 +72,7 @@
import java.util.stream.Collectors;

import static net.hydromatic.morel.ast.AstBuilder.ast;
import static net.hydromatic.morel.type.RecordType.ORDERING;
import static net.hydromatic.morel.util.Static.skip;
import static net.hydromatic.morel.util.Static.toImmutableList;

Expand Down Expand Up @@ -194,6 +196,12 @@ private Ast.Exp deduceType(TypeEnv env, Ast.Exp node, Unifier.Variable v) {
case UNIT_LITERAL:
return reg(node, v, toTerm(PrimitiveType.UNIT));

case ANNOTATED_EXP:
final Ast.AnnotatedExp annotatedExp = (Ast.AnnotatedExp) node;
final Type type = toType(annotatedExp.type, typeSystem);
deduceType(env, annotatedExp.exp, v);
return reg(node, v, toTerm(type, Subst.EMPTY));

case ANDALSO:
case ORELSE:
return infix(env, (Ast.InfixCall) node, v, PrimitiveType.BOOL);
Expand Down Expand Up @@ -795,14 +803,31 @@ private Type toType(Ast.Type type) {
final Ast.TupleType tupleType = (Ast.TupleType) type;
return typeSystem.tupleType(toTypes(tupleType.types));

case RECORD_TYPE:
final Ast.RecordType recordType = (Ast.RecordType) type;
final ImmutableSortedMap.Builder<String, Type> argNameTypes =
ImmutableSortedMap.orderedBy(ORDERING);
recordType.fieldTypes.forEach((name, t) ->
argNameTypes.put(name, toType(t)));
return typeSystem.recordType(argNameTypes.build());

case FUNCTION_TYPE:
final Ast.FunctionType functionType = (Ast.FunctionType) type;
final Type paramType = toType(functionType.paramType, typeSystem);
final Type resultType = toType(functionType.resultType, typeSystem);
return typeSystem.fnType(paramType, resultType);

case NAMED_TYPE:
final Ast.NamedType namedType = (Ast.NamedType) type;
final List<Type> typeList = toTypes(namedType.types);
if (namedType.name.equals(LIST_TY_CON) && typeList.size() == 1) {
// TODO: make 'list' a regular generic type
return typeSystem.listType(typeList.get(0));
}
final Type genericType = typeSystem.lookup(namedType.name);
if (namedType.types.isEmpty()) {
return genericType;
}
final List<Type> typeList = namedType.types.stream().map(this::toType)
.collect(toImmutableList());
return typeSystem.apply(genericType, typeList);

case TY_VAR:
Expand All @@ -811,7 +836,7 @@ private Type toType(Ast.Type type) {
name -> typeSystem.typeVariable(tyVarMap.size()));

default:
throw new AssertionError("cannot convert type " + type);
throw new AssertionError("cannot convert type " + type + " " + type.op);
}
}

Expand Down Expand Up @@ -850,22 +875,39 @@ private Ast.ValDecl toValDecl(TypeEnv env, Ast.FunDecl funDecl) {
private Ast.ValBind toValBind(TypeEnv env, Ast.FunBind funBind) {
final List<Ast.Pat> vars;
Ast.Exp exp;
Ast.Type returnType = null;
if (funBind.matchList.size() == 1) {
exp = funBind.matchList.get(0).exp;
vars = funBind.matchList.get(0).patList;
final Ast.FunMatch funMatch = funBind.matchList.get(0);
exp = funMatch.exp;
vars = funMatch.patList;
returnType = funMatch.returnType;
} else {
final List<String> varNames =
MapList.of(funBind.matchList.get(0).patList.size(),
index -> "v" + index);
vars = Lists.transform(varNames, v -> ast.idPat(Pos.ZERO, v));
final List<Ast.Match> matchList = new ArrayList<>();
Pos prevReturnTypePos = null;
for (Ast.FunMatch funMatch : funBind.matchList) {
matchList.add(
ast.match(funMatch.pos, patTuple(env, funMatch.patList),
funMatch.exp));
if (funMatch.returnType != null) {
if (returnType != null
&& !returnType.equals(funMatch.returnType)) {
throw new CompileException("parameter or result constraints of "
+ "clauses don't agree [tycon mismatch]", false,
prevReturnTypePos.plus(funMatch.pos));
}
returnType = funMatch.returnType;
prevReturnTypePos = funMatch.pos;
}
}
exp = ast.caseOf(Pos.ZERO, idTuple(varNames), matchList);
}
if (returnType != null) {
exp = ast.annotatedExp(exp.pos, exp, returnType);
}
final Pos pos = funBind.pos;
for (Ast.Pat var : Lists.reverse(vars)) {
exp = ast.fn(pos, ast.match(pos, var, exp));
Expand Down Expand Up @@ -963,6 +1005,12 @@ private Ast.Pat deducePatType(TypeEnv env, Ast.Pat pat,
deducePatType(env, asPat.pat, termMap, null, v);
return reg(pat, null, v);

case ANNOTATED_PAT:
final Ast.AnnotatedPat annotatedPat = (Ast.AnnotatedPat) pat;
final Type type = toType(annotatedPat.type, typeSystem);
deducePatType(env, annotatedPat.pat, termMap, null, v);
return reg(pat, v, toTerm(type, Subst.EMPTY));

case TUPLE_PAT:
final List<Unifier.Term> typeTerms = new ArrayList<>();
final Ast.TuplePat tuple = (Ast.TuplePat) pat;
Expand Down
14 changes: 12 additions & 2 deletions src/main/javacc/MorelParser.jj
Original file line number Diff line number Diff line change
Expand Up @@ -761,9 +761,16 @@ Exp expression1() :
Exp expression() :
{
Exp e;
Type t;
}
{
e = expression1() { return e; }
e = expression1()
(
<COLON> t = type() {
e = ast.annotatedExp(e.pos.plus(t.pos), e, t);
}
)*
{ return e; }
}

/** List of expressions "e1 as id1, e2 as id2, e3 as id3". */
Expand Down Expand Up @@ -1077,13 +1084,16 @@ void funMatch(List<FunMatch> list) :
Ast.Pat pat;
final List<Ast.Pat> patList = new ArrayList<>();
final Ast.Exp expression;
Ast.Type returnType = null;
}
{
id = identifier()
( pat = atomPat() { patList.add(pat); } )+
[ <COLON> returnType = type() ]
<EQ> expression = expression() {
list.add(
ast.funMatch(id.pos.plus(expression.pos), id.name, patList, expression));
ast.funMatch(id.pos.plus(expression.pos), id.name, patList, returnType,
expression));
}
}

Expand Down
11 changes: 11 additions & 0 deletions src/test/java/net/hydromatic/morel/MainTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,13 @@ public class MainTest {
containsString(
"Encountered \" \"rec\" \"rec \"\" at line 1, column 19."));

// : is right-associative and low precedence
ml("1 : int : int").assertParseSame();
ml("(2 : int) + 1 : int").assertParseSame();
ml("(2 : int) + (1 : int) : int").assertParseSame();
ml("((2 : int) + (1 : int)) : int")
.assertParse("(2 : int) + (1 : int) : int");

// pattern
ml("let val (x, y) = (1, 2) in x + y end").assertParseSame();
ml("let val w as (x, y) = (1, 2) in #1 w + #2 w + x + y end")
Expand Down Expand Up @@ -456,6 +463,10 @@ public class MainTest {
ml("fn x => case x of 0 => 1 | _ => 2").assertType("int -> int");
ml("fn x => case x of 0 => \"zero\" | _ => \"nonzero\"")
.assertType("int -> string");
ml("fn x: int => true").assertType("int -> bool");
ml("fn x: int * int => true").assertType("int * int -> bool");
ml("fn x: int * string => (false, #2 x)")
.assertType("int * string -> bool * string");
}

@Test void testTypeFnTuple() {
Expand Down
46 changes: 46 additions & 0 deletions src/test/resources/script/type.sml
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,52 @@
{} = ();
() = {};
(*) Expressions with type annotations
1: int;
(2, true): int * bool;
[]: int list;
(1: int) + (2: int);
1 + (2: int);
(1: int) + 2;
String.size "abc": int;
String.size ("abc": string);
String.size ("abc": string): int;
(*) Patterns with type annotations
val x: int = 1;
val y: bool = true;
val p: int * bool = (1, true);
val empty: int list = [];
(*) Function declarations with type annotations
fun f (x: int, y) = x + y;
fun f (x, y: int) = x + y;
fun f3 (e: {name: string, deptno:int}) = e.deptno;
fun hello (name: string, code: int): string = "hello!";
fun hello2 (name: string) (code : int): string = "hello!";
val hello3: string * int -> string =
fn (name, code) => "hello!";
fun l1 [] = 0 | l1 ((h: string) :: t) = 1 + (l1 t);
fun l2 [] = 0 | l2 (h :: (t: bool list)) = 1 + (l2 t);
fun countOption (NONE: string option) = 0
| countOption (SOME x) = 1;
fun countOption2 NONE: int = 0
| countOption2 (SOME x) = 1;
fun firstOrSecond (e1 :: e2 :: rest): int = e2
| firstOrSecond (e1 :: rest) = e1;
(*
sml-nj gives the following error:
stdIn:1.6-2.32 Error: parameter or result constraints of clauses don't agree [tycon mismatch]
this clause: 'Z option -> string list
previous clauses: 'Z option -> int list
in declaration:
f = (fn NONE => nil: int list
| SOME x => nil: string list)
*)
fun f NONE:int list = []
| f (SOME x):string list = [];
(*) Function with unit arg
fun one () = 1;
one ();
Expand Down
Loading

0 comments on commit a7194e4

Please sign in to comment.