Skip to content

Commit

Permalink
feat: computed fields support for functions and one to many relations…
Browse files Browse the repository at this point in the history
…hip lookups (#1711)
  • Loading branch information
davenewza authored Jan 30, 2025
1 parent 47c627d commit 2f8644b
Show file tree
Hide file tree
Showing 41 changed files with 1,339 additions and 135 deletions.
2 changes: 1 addition & 1 deletion .vscode/launch.json
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
"-d",
"${input:directory}",
"--pattern",
"${input:pattern}"
"${input:pattern}",
]
},
{
Expand Down
9 changes: 9 additions & 0 deletions expressions/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ var messageConverters = []errorConverter{
undeclaredVariableReference,
unrecognisedToken,
mismatchedInput,
noFunctionOverload,
}

type errorConverter struct {
Expand Down Expand Up @@ -81,6 +82,14 @@ var mismatchedInput = errorConverter{
},
}

var noFunctionOverload = errorConverter{
Regex: `found no matching overload for '(.+)' applied to '\((.+)\)'`,
Construct: func(expectedReturnType *types.Type, values []string) string {
// We should provide more context here for each function (i.e. arguments and supported types)
return fmt.Sprintf("%s not supported as an argument for the function '%s'", mapOperator(values[1]), values[0])
},
}

func mapOperator(op string) string {
switch op {
case operators.In:
Expand Down
24 changes: 24 additions & 0 deletions expressions/options/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -434,6 +434,30 @@ func WithArithmeticOperators() expressions.Option {
}
}

func WithFunctions() expressions.Option {
return func(p *expressions.Parser) error {
typeParamA := cel.TypeParamType("A")
var err error
p.CelEnv, err = p.CelEnv.Extend(
cel.Function(typing.FunctionCount, cel.Overload("count", []*types.Type{typeParamA}, typing.Number)),
cel.Function(typing.FunctionSum, cel.Overload("sum_decimal", []*types.Type{typing.DecimalArray}, typing.Decimal)),
cel.Function(typing.FunctionSum, cel.Overload("sum_number", []*types.Type{typing.NumberArray}, typing.Number)),
cel.Function(typing.FunctionAvg, cel.Overload("avg_decimal", []*types.Type{typing.DecimalArray}, typing.Decimal)),
cel.Function(typing.FunctionAvg, cel.Overload("avg_number", []*types.Type{typing.NumberArray}, typing.Number)),
cel.Function(typing.FunctionMin, cel.Overload("min_decimal", []*types.Type{typing.DecimalArray}, typing.Decimal)),
cel.Function(typing.FunctionMin, cel.Overload("min_number", []*types.Type{typing.NumberArray}, typing.Number)),
cel.Function(typing.FunctionMax, cel.Overload("max_decimal", []*types.Type{typing.DecimalArray}, typing.Decimal)),
cel.Function(typing.FunctionMax, cel.Overload("max_number", []*types.Type{typing.NumberArray}, typing.Number)),
cel.Function(typing.FunctionMedian, cel.Overload("median_decimal", []*types.Type{typing.DecimalArray}, typing.Decimal)),
cel.Function(typing.FunctionMedian, cel.Overload("median_number", []*types.Type{typing.NumberArray}, typing.Number)))
if err != nil {
return err
}

return err
}
}

// WithReturnTypeAssertion will check that the expression evaluates to a specific type
func WithReturnTypeAssertion(returnType string, asArray bool) expressions.Option {
return func(p *expressions.Parser) error {
Expand Down
8 changes: 8 additions & 0 deletions expressions/resolve/field_lookups.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,14 @@ func (v *fieldLookupsGen) EndTerm(parenthesis bool) error {
return nil
}

func (v *fieldLookupsGen) StartFunction(name string) error {
return nil
}

func (v *fieldLookupsGen) EndFunction() error {
return nil
}

func (v *fieldLookupsGen) VisitAnd() error {
return nil
}
Expand Down
8 changes: 8 additions & 0 deletions expressions/resolve/ident.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,14 @@ func (v *identGen) EndTerm(parenthesis bool) error {
return nil
}

func (v *identGen) StartFunction(name string) error {
return nil
}

func (v *identGen) EndFunction() error {
return nil
}

func (v *identGen) VisitAnd() error {
return ErrExpressionNotValidIdent
}
Expand Down
8 changes: 8 additions & 0 deletions expressions/resolve/ident_array.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,14 @@ func (v *identArrayGen) EndTerm(parenthesis bool) error {
return nil
}

func (v *identArrayGen) StartFunction(name string) error {
return nil
}

func (v *identArrayGen) EndFunction() error {
return nil
}

func (v *identArrayGen) VisitAnd() error {
return ErrExpressionNotValidIdentArray
}
Expand Down
8 changes: 8 additions & 0 deletions expressions/resolve/operands.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,14 @@ func (v *operandsResolver) EndTerm(parenthesis bool) error {
return nil
}

func (v *operandsResolver) StartFunction(name string) error {
return nil
}

func (v *operandsResolver) EndFunction() error {
return nil
}

func (v *operandsResolver) VisitAnd() error {
return nil
}
Expand Down
38 changes: 37 additions & 1 deletion expressions/resolve/visitor.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,10 @@ type Visitor[T any] interface {
StartTerm(nested bool) error
// EndTerm is called when a term is finished
EndTerm(nested bool) error
// StartFunction is called when a function is started
StartFunction(name string) error
// EndFunction is called when a function is finished
EndFunction() error
// VisitAnd is called when an 'and' operator is visited between conditions
VisitAnd() error
// VisitAnd is called when an 'or' operator is visited between conditions
Expand Down Expand Up @@ -172,12 +176,44 @@ func (w *CelVisitor[T]) callExpr(expr *exprpb.Expr) error {

err = w.binaryCall(expr)
default:
return errors.New("function calls not supported yet")
err = w.functionCall(expr)
}

return err
}

func (w *CelVisitor[T]) functionCall(expr *exprpb.Expr) error {
c := expr.GetCallExpr()
fun := c.GetFunction()
target := c.GetTarget()
args := c.GetArgs()

err := w.visitor.StartFunction(fun)
if err != nil {
return err
}

if target != nil {
err := w.eval(target, false, false)
if err != nil {
return err
}
}
for _, arg := range args {
err := w.eval(arg, false, false)
if err != nil {
return err
}
}

err = w.visitor.EndFunction()
if err != nil {
return err
}

return nil
}

func (w *CelVisitor[T]) binaryCall(expr *exprpb.Expr) error {
c := expr.GetCallExpr()
op := c.GetFunction()
Expand Down
9 changes: 9 additions & 0 deletions expressions/typing/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,15 @@ var (
DateArray = cel.OpaqueType(fmt.Sprintf("%s[]", parser.FieldTypeDate))
)

var (
FunctionSum = "SUM"
FunctionCount = "COUNT"
FunctionAvg = "AVG"
FunctionMedian = "MEDIAN"
FunctionMin = "MIN"
FunctionMax = "MAX"
)

var (
Role = cel.OpaqueType("_Role")
)
Expand Down
40 changes: 40 additions & 0 deletions integration/testdata/computed_fields_circular_ref/schema.keel
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
model Account {
fields {
balance Decimal @computed(SUM(account.transactions.total))
transactions Transaction[]
standardTransactionFee Decimal
}
actions {
create createAccount() with (standardTransactionFee) {
@permission(expression: true)
}
list listAccounts() {
@permission(expression: true)
}
get getAccount(id) {
@permission(expression: true)
}
}
}

model Transaction {
fields {
account Account
amount Decimal
fee Decimal? @computed(transaction.account.standardTransactionFee)
total Decimal? @computed(transaction.amount + transaction.fee)
}
actions {
create createTransaction() with (account.id, amount) {
@permission(expression: true)
}
get getTransaction(id) {
@permission(expression: true)
}
list listTransactions() {
@permission(expression: true)
}
}
}


19 changes: 19 additions & 0 deletions integration/testdata/computed_fields_circular_ref/tests.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
import { test, expect, beforeEach } from "vitest";
import { models, resetDatabase, actions } from "@teamkeel/testing";

beforeEach(resetDatabase);

test("computed fields - circular reference", async () => {
const account = await models.account.create({ standardTransactionFee: 10 });
const transaction1 = await models.transaction.create({
accountId: account.id,
amount: 100,
});
const transaction2 = await models.transaction.create({
accountId: account.id,
amount: 200,
});

const getAccount = await models.account.findOne({ id: account.id });
expect(getAccount!.balance).toBe(320);
});
50 changes: 50 additions & 0 deletions integration/testdata/computed_fields_one_to_many/schema.keel
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
model Invoice {
fields {
items Item[]
shipping Number
total Decimal @computed(SUM(invoice.items.product.price) + invoice.shipping)
}
actions {
get getInvoice(id) {
@permission(expression: true)
}
create createInvoice() with (shipping, items.product.id?) {
@permission(expression: true)
}
list listInvoices() {
@permission(expression: true)
}
}
}

model Item {
fields {
invoice Invoice
product Product
}
actions {
get getItem(id) {
@permission(expression: true)
}
create createItem() with (product.id, invoice.id) {
@permission(expression: true)
}
}
}

model Product {
fields {
price Decimal
}
actions {
get getProduct(id) {
@permission(expression: true)
}
create createProduct() with (price) {
@permission(expression: true)
}
list listProducts() {
@permission(expression: true)
}
}
}
Loading

0 comments on commit 2f8644b

Please sign in to comment.