diff --git a/.vscode/launch.json b/.vscode/launch.json index 77aed6e66..210497ea5 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -43,7 +43,7 @@ "-d", "${input:directory}", "--pattern", - "${input:pattern}" + "${input:pattern}", ] }, { diff --git a/expressions/errors.go b/expressions/errors.go index 7e2b8f672..18f47a82a 100644 --- a/expressions/errors.go +++ b/expressions/errors.go @@ -18,6 +18,7 @@ var messageConverters = []errorConverter{ undeclaredVariableReference, unrecognisedToken, mismatchedInput, + noFunctionOverload, } type errorConverter struct { @@ -81,6 +82,14 @@ var mismatchedInput = errorConverter{ }, } +var noFunctionOverload = errorConverter{ + Regex: `found no matching overload for '(.+)' applied to '\((.+)\)'`, + Construct: func(expectedReturnType *types.Type, values []string) string { + // We should provide more context here for each function (i.e. arguments and supported types) + return fmt.Sprintf("%s not supported as an argument for the function '%s'", mapOperator(values[1]), values[0]) + }, +} + func mapOperator(op string) string { switch op { case operators.In: diff --git a/expressions/options/options.go b/expressions/options/options.go index e82fb769e..670074dc1 100644 --- a/expressions/options/options.go +++ b/expressions/options/options.go @@ -434,6 +434,30 @@ func WithArithmeticOperators() expressions.Option { } } +func WithFunctions() expressions.Option { + return func(p *expressions.Parser) error { + typeParamA := cel.TypeParamType("A") + var err error + p.CelEnv, err = p.CelEnv.Extend( + cel.Function(typing.FunctionCount, cel.Overload("count", []*types.Type{typeParamA}, typing.Number)), + cel.Function(typing.FunctionSum, cel.Overload("sum_decimal", []*types.Type{typing.DecimalArray}, typing.Decimal)), + cel.Function(typing.FunctionSum, cel.Overload("sum_number", []*types.Type{typing.NumberArray}, typing.Number)), + cel.Function(typing.FunctionAvg, cel.Overload("avg_decimal", []*types.Type{typing.DecimalArray}, typing.Decimal)), + cel.Function(typing.FunctionAvg, cel.Overload("avg_number", []*types.Type{typing.NumberArray}, typing.Number)), + cel.Function(typing.FunctionMin, cel.Overload("min_decimal", []*types.Type{typing.DecimalArray}, typing.Decimal)), + cel.Function(typing.FunctionMin, cel.Overload("min_number", []*types.Type{typing.NumberArray}, typing.Number)), + cel.Function(typing.FunctionMax, cel.Overload("max_decimal", []*types.Type{typing.DecimalArray}, typing.Decimal)), + cel.Function(typing.FunctionMax, cel.Overload("max_number", []*types.Type{typing.NumberArray}, typing.Number)), + cel.Function(typing.FunctionMedian, cel.Overload("median_decimal", []*types.Type{typing.DecimalArray}, typing.Decimal)), + cel.Function(typing.FunctionMedian, cel.Overload("median_number", []*types.Type{typing.NumberArray}, typing.Number))) + if err != nil { + return err + } + + return err + } +} + // WithReturnTypeAssertion will check that the expression evaluates to a specific type func WithReturnTypeAssertion(returnType string, asArray bool) expressions.Option { return func(p *expressions.Parser) error { diff --git a/expressions/resolve/field_lookups.go b/expressions/resolve/field_lookups.go index bef1541d0..b679646be 100644 --- a/expressions/resolve/field_lookups.go +++ b/expressions/resolve/field_lookups.go @@ -58,6 +58,14 @@ func (v *fieldLookupsGen) EndTerm(parenthesis bool) error { return nil } +func (v *fieldLookupsGen) StartFunction(name string) error { + return nil +} + +func (v *fieldLookupsGen) EndFunction() error { + return nil +} + func (v *fieldLookupsGen) VisitAnd() error { return nil } diff --git a/expressions/resolve/ident.go b/expressions/resolve/ident.go index 120b6528f..dde4d2d66 100644 --- a/expressions/resolve/ident.go +++ b/expressions/resolve/ident.go @@ -36,6 +36,14 @@ func (v *identGen) EndTerm(parenthesis bool) error { return nil } +func (v *identGen) StartFunction(name string) error { + return nil +} + +func (v *identGen) EndFunction() error { + return nil +} + func (v *identGen) VisitAnd() error { return ErrExpressionNotValidIdent } diff --git a/expressions/resolve/ident_array.go b/expressions/resolve/ident_array.go index 93f8e089d..bfd654ff3 100644 --- a/expressions/resolve/ident_array.go +++ b/expressions/resolve/ident_array.go @@ -37,6 +37,14 @@ func (v *identArrayGen) EndTerm(parenthesis bool) error { return nil } +func (v *identArrayGen) StartFunction(name string) error { + return nil +} + +func (v *identArrayGen) EndFunction() error { + return nil +} + func (v *identArrayGen) VisitAnd() error { return ErrExpressionNotValidIdentArray } diff --git a/expressions/resolve/operands.go b/expressions/resolve/operands.go index 1175ac104..09940e545 100644 --- a/expressions/resolve/operands.go +++ b/expressions/resolve/operands.go @@ -32,6 +32,14 @@ func (v *operandsResolver) EndTerm(parenthesis bool) error { return nil } +func (v *operandsResolver) StartFunction(name string) error { + return nil +} + +func (v *operandsResolver) EndFunction() error { + return nil +} + func (v *operandsResolver) VisitAnd() error { return nil } diff --git a/expressions/resolve/visitor.go b/expressions/resolve/visitor.go index ad2b5408a..5c00264f4 100644 --- a/expressions/resolve/visitor.go +++ b/expressions/resolve/visitor.go @@ -21,6 +21,10 @@ type Visitor[T any] interface { StartTerm(nested bool) error // EndTerm is called when a term is finished EndTerm(nested bool) error + // StartFunction is called when a function is started + StartFunction(name string) error + // EndFunction is called when a function is finished + EndFunction() error // VisitAnd is called when an 'and' operator is visited between conditions VisitAnd() error // VisitAnd is called when an 'or' operator is visited between conditions @@ -172,12 +176,44 @@ func (w *CelVisitor[T]) callExpr(expr *exprpb.Expr) error { err = w.binaryCall(expr) default: - return errors.New("function calls not supported yet") + err = w.functionCall(expr) } return err } +func (w *CelVisitor[T]) functionCall(expr *exprpb.Expr) error { + c := expr.GetCallExpr() + fun := c.GetFunction() + target := c.GetTarget() + args := c.GetArgs() + + err := w.visitor.StartFunction(fun) + if err != nil { + return err + } + + if target != nil { + err := w.eval(target, false, false) + if err != nil { + return err + } + } + for _, arg := range args { + err := w.eval(arg, false, false) + if err != nil { + return err + } + } + + err = w.visitor.EndFunction() + if err != nil { + return err + } + + return nil +} + func (w *CelVisitor[T]) binaryCall(expr *exprpb.Expr) error { c := expr.GetCallExpr() op := c.GetFunction() diff --git a/expressions/typing/types.go b/expressions/typing/types.go index 36ad93afe..22727b720 100644 --- a/expressions/typing/types.go +++ b/expressions/typing/types.go @@ -33,6 +33,15 @@ var ( DateArray = cel.OpaqueType(fmt.Sprintf("%s[]", parser.FieldTypeDate)) ) +var ( + FunctionSum = "SUM" + FunctionCount = "COUNT" + FunctionAvg = "AVG" + FunctionMedian = "MEDIAN" + FunctionMin = "MIN" + FunctionMax = "MAX" +) + var ( Role = cel.OpaqueType("_Role") ) diff --git a/integration/testdata/computed_fields_circular_ref/schema.keel b/integration/testdata/computed_fields_circular_ref/schema.keel new file mode 100644 index 000000000..954d3cfad --- /dev/null +++ b/integration/testdata/computed_fields_circular_ref/schema.keel @@ -0,0 +1,40 @@ +model Account { + fields { + balance Decimal @computed(SUM(account.transactions.total)) + transactions Transaction[] + standardTransactionFee Decimal + } + actions { + create createAccount() with (standardTransactionFee) { + @permission(expression: true) + } + list listAccounts() { + @permission(expression: true) + } + get getAccount(id) { + @permission(expression: true) + } + } +} + +model Transaction { + fields { + account Account + amount Decimal + fee Decimal? @computed(transaction.account.standardTransactionFee) + total Decimal? @computed(transaction.amount + transaction.fee) + } + actions { + create createTransaction() with (account.id, amount) { + @permission(expression: true) + } + get getTransaction(id) { + @permission(expression: true) + } + list listTransactions() { + @permission(expression: true) + } + } +} + + diff --git a/integration/testdata/computed_fields_circular_ref/tests.test.ts b/integration/testdata/computed_fields_circular_ref/tests.test.ts new file mode 100644 index 000000000..33beeec48 --- /dev/null +++ b/integration/testdata/computed_fields_circular_ref/tests.test.ts @@ -0,0 +1,19 @@ +import { test, expect, beforeEach } from "vitest"; +import { models, resetDatabase, actions } from "@teamkeel/testing"; + +beforeEach(resetDatabase); + +test("computed fields - circular reference", async () => { + const account = await models.account.create({ standardTransactionFee: 10 }); + const transaction1 = await models.transaction.create({ + accountId: account.id, + amount: 100, + }); + const transaction2 = await models.transaction.create({ + accountId: account.id, + amount: 200, + }); + + const getAccount = await models.account.findOne({ id: account.id }); + expect(getAccount!.balance).toBe(320); +}); diff --git a/integration/testdata/computed_fields_one_to_many/schema.keel b/integration/testdata/computed_fields_one_to_many/schema.keel new file mode 100644 index 000000000..d5223e58a --- /dev/null +++ b/integration/testdata/computed_fields_one_to_many/schema.keel @@ -0,0 +1,50 @@ +model Invoice { + fields { + items Item[] + shipping Number + total Decimal @computed(SUM(invoice.items.product.price) + invoice.shipping) + } + actions { + get getInvoice(id) { + @permission(expression: true) + } + create createInvoice() with (shipping, items.product.id?) { + @permission(expression: true) + } + list listInvoices() { + @permission(expression: true) + } + } +} + +model Item { + fields { + invoice Invoice + product Product + } + actions { + get getItem(id) { + @permission(expression: true) + } + create createItem() with (product.id, invoice.id) { + @permission(expression: true) + } + } +} + +model Product { + fields { + price Decimal + } + actions { + get getProduct(id) { + @permission(expression: true) + } + create createProduct() with (price) { + @permission(expression: true) + } + list listProducts() { + @permission(expression: true) + } + } +} diff --git a/integration/testdata/computed_fields_one_to_many/tests.test.ts b/integration/testdata/computed_fields_one_to_many/tests.test.ts new file mode 100644 index 000000000..62d0f13eb --- /dev/null +++ b/integration/testdata/computed_fields_one_to_many/tests.test.ts @@ -0,0 +1,113 @@ +import { test, expect, beforeEach } from "vitest"; +import { models, resetDatabase, actions } from "@teamkeel/testing"; + +beforeEach(resetDatabase); + +test("computed fields - one to many", async () => { + const product1 = await models.product.create({ + price: 100, + }); + + const product2 = await models.product.create({ + price: 200, + }); + + const invoiceA = await actions.createInvoice({ shipping: 5 }); + expect(invoiceA.total).toBe(5); + + const item1 = await models.item.create({ + invoiceId: invoiceA.id, + productId: product1.id, + }); + + const item2 = await models.item.create({ + invoiceId: invoiceA.id, + productId: product2.id, + }); + + const invoiceB = await actions.createInvoice({ shipping: 5 }); + expect(invoiceB.total).toBe(5); + + const item3 = await models.item.create({ + invoiceId: invoiceB.id, + productId: product2.id, + }); + + const inv1A = await models.invoice.findOne({ id: invoiceA.id }); + expect(inv1A?.total).toBe(305); + + const inv1B = await models.invoice.findOne({ id: invoiceB.id }); + expect(inv1B?.total).toBe(205); + + await models.product.update({ id: product1.id }, { price: 150 }); + + const inv2A = await models.invoice.findOne({ id: invoiceA.id }); + expect(inv2A?.total).toBe(355); + + const inv2B = await models.invoice.findOne({ id: invoiceB.id }); + expect(inv2B?.total).toBe(205); + + await models.item.delete({ id: item2.id }); + + const inv3A = await models.invoice.findOne({ id: invoiceA.id }); + expect(inv3A?.total).toBe(155); + + const inv3B = await models.invoice.findOne({ id: invoiceB.id }); + expect(inv3B?.total).toBe(205); + + const item4 = await models.item.create({ + invoiceId: invoiceA.id, + productId: product2.id, + }); + + const inv4A = await models.invoice.findOne({ id: invoiceA.id }); + expect(inv4A?.total).toBe(355); + + const inv4B = await models.invoice.findOne({ id: invoiceB.id }); + expect(inv4B?.total).toBe(205); + + await models.item.update({ id: item4.id }, { invoiceId: invoiceB.id }); + + const inv5A = await models.invoice.findOne({ id: invoiceA.id }); + expect(inv5A?.total).toBe(155); + + const inv5B = await models.invoice.findOne({ id: invoiceB.id }); + expect(inv5B?.total).toBe(405); + + await models.product.delete({ id: product2.id }); + + const inv6A = await models.invoice.findOne({ id: invoiceA.id }); + expect(inv6A?.total).toBe(155); + + const inv6B = await models.invoice.findOne({ id: invoiceB.id }); + expect(inv6B?.total).toBe(5); + + await models.invoice.update({ id: invoiceA.id }, { shipping: 10 }); + + const inv7A = await models.invoice.findOne({ id: invoiceA.id }); + expect(inv7A?.total).toBe(160); +}); + +test("computed fields - one to many - nested create", async () => { + const product1 = await models.product.create({ + price: 100, + }); + + const product2 = await models.product.create({ + price: 200, + }); + + const invoice = await actions.createInvoice({ + shipping: 5, + items: [ + { + product: { id: product1.id }, + }, + { + product: { id: product2.id }, + }, + ], + }); + + expect(invoice.total).toBe(305); +}); diff --git a/integration/testdata/computed_fields_one_to_one/schema.keel b/integration/testdata/computed_fields_one_to_one/schema.keel new file mode 100644 index 000000000..b979f2b9f --- /dev/null +++ b/integration/testdata/computed_fields_one_to_one/schema.keel @@ -0,0 +1,37 @@ +model Company { + fields { + companyProfile CompanyProfile @unique + activeEmployees Number @computed(company.companyProfile.employeeCount - company.retrenchments) + retrenchments Number @default(0) + taxNumber Number @computed(company.companyProfile.taxProfile.taxNumber) + } + actions { + create createCompany() with (companyProfile.id, retrenchments) { + @permission(expression: true) + } + } +} + +model CompanyProfile { + fields { + employeeCount Number + taxProfile TaxProfile? @unique + company Company? + } + actions { + create createCompanyProfile() with (employeeCount,taxProfile.id) { + @permission(expression: true) + } + } +} + +model TaxProfile { + fields { + taxNumber Number @unique + } + actions { + create createTaxProfile() with (taxNumber) { + @permission(expression: true) + } + } +} \ No newline at end of file diff --git a/integration/testdata/computed_fields_one_to_one/tests.test.ts b/integration/testdata/computed_fields_one_to_one/tests.test.ts new file mode 100644 index 000000000..428703169 --- /dev/null +++ b/integration/testdata/computed_fields_one_to_one/tests.test.ts @@ -0,0 +1,18 @@ +import { test, expect, beforeEach } from "vitest"; +import { models, resetDatabase, actions } from "@teamkeel/testing"; + +beforeEach(resetDatabase); + +test("computed fields - one to one", async () => { + const taxProfile = await actions.createTaxProfile({ taxNumber: 1234567890 }); + const companyProfile = await actions.createCompanyProfile({ + employeeCount: 100, + taxProfile: { id: taxProfile.id }, + }); + const company = await actions.createCompany({ + companyProfile: { id: companyProfile.id }, + retrenchments: 8, + }); + expect(company.activeEmployees).toEqual(92); + expect(company.taxNumber).toEqual(1234567890); +}); diff --git a/integration/testdata/computed_fields/schema.keel b/integration/testdata/computed_fields_same_model/schema.keel similarity index 100% rename from integration/testdata/computed_fields/schema.keel rename to integration/testdata/computed_fields_same_model/schema.keel diff --git a/integration/testdata/computed_fields/tests.test.ts b/integration/testdata/computed_fields_same_model/tests.test.ts similarity index 100% rename from integration/testdata/computed_fields/tests.test.ts rename to integration/testdata/computed_fields_same_model/tests.test.ts diff --git a/integration/testdata/real_world_invoice_system/main.test.ts b/integration/testdata/real_world_invoice_system/main.test.ts new file mode 100644 index 000000000..a0a7c9f0d --- /dev/null +++ b/integration/testdata/real_world_invoice_system/main.test.ts @@ -0,0 +1,473 @@ +import { actions, models, resetDatabase } from "@teamkeel/testing"; +import { Product, Customer, Order } from "@teamkeel/sdk"; +import { test, describe, expect, beforeEach, beforeAll } from "vitest"; + +let productLaptop: Product | null; +let productMouse: Product | null; +let productKeyboard: Product | null; +let productMonitor: Product | null; + +let johnDoe: Customer | null; +let pamSmith: Customer | null; + +let order: Order | null; + +test("purchase new products", async () => { + productLaptop = await actions.createProduct({ + name: "Laptop", + costPrice: 100, + markup: 0.2, + }); + + productMouse = await actions.createProduct({ + name: "Mouse", + costPrice: 12, + markup: 0.4, + }); + + productKeyboard = await actions.createProduct({ + name: "Keyboard", + costPrice: 18, + markup: 0.4, + }); + + productMonitor = await actions.createProduct({ + name: "Monitor", + costPrice: 50, + markup: 0.4, + }); + + expect(productLaptop.price).toBe(120); + expect(productMouse.price).toBe(16.8); + expect(productKeyboard.price).toBe(25.2); + expect(productMonitor.price).toBe(70); + + await actions.createPurchaseOrder({ + product: { id: productLaptop.id }, + quantity: 10, + }); + + await actions.createPurchaseOrder({ + product: { id: productMouse.id }, + quantity: 20, + }); + + await actions.createPurchaseOrder({ + product: { id: productKeyboard.id }, + quantity: 25, + }); + + await actions.createPurchaseOrder({ + product: { id: productMonitor.id }, + quantity: 10, + }); +}); + +test("check stock levels after purchase order", async () => { + productLaptop = await actions.getProduct({ id: productLaptop!.id }); + expect(productLaptop?.stockQuantity).toBe(10); + + productMouse = await actions.getProduct({ id: productMouse!.id }); + expect(productMouse?.stockQuantity).toBe(20); + + productKeyboard = await actions.getProduct({ id: productKeyboard!.id }); + expect(productKeyboard?.stockQuantity).toBe(25); + + productMonitor = await actions.getProduct({ id: productMonitor!.id }); + expect(productMonitor?.stockQuantity).toBe(10); +}); + +test("create customer", async () => { + johnDoe = await actions.createCustomer({ + name: "John Doe", + }); +}); + +test("check customer statistics before order", async () => { + expect(johnDoe?.totalOrders).toBe(0); + expect(johnDoe?.totalSpent).toBe(0); + expect(johnDoe?.averageOrderValue).toBe(0); + expect(johnDoe?.smallestOrder).toBe(0); + expect(johnDoe?.largestOrder).toBe(0); +}); + +test("create order for new products", async () => { + order = await actions.createOrder({ + customer: { id: johnDoe!.id }, + orderItems: [ + { product: { id: productLaptop!.id }, quantity: 2 }, + { product: { id: productMouse!.id }, quantity: 2 }, + { product: { id: productKeyboard!.id }, quantity: 1 }, + ], + }); + + expect(order.shipping).toBe(10); + expect(order.total).toBe(240 + 33.6 + 25.2 + order.shipping!); +}); + +test("check stock levels after order", async () => { + productLaptop = await actions.getProduct({ id: productLaptop!.id }); + expect(productLaptop?.stockQuantity).toBe(8); + + productMouse = await actions.getProduct({ id: productMouse!.id }); + expect(productMouse?.stockQuantity).toBe(18); + + productKeyboard = await actions.getProduct({ id: productKeyboard!.id }); + expect(productKeyboard?.stockQuantity).toBe(24); + + productMonitor = await actions.getProduct({ id: productMonitor!.id }); + expect(productMonitor?.stockQuantity).toBe(10); +}); + +test("check customer statistics after order", async () => { + johnDoe = await actions.getCustomer({ id: johnDoe!.id }); + expect(johnDoe?.totalOrders).toBe(1); + expect(johnDoe?.totalSpent).toBe(308.8); + expect(johnDoe?.averageOrderValue).toBe(308.8); + expect(johnDoe?.smallestOrder).toBe(308.8); + expect(johnDoe?.largestOrder).toBe(308.8); +}); + +test("adjust quantity in order", async () => { + const items = await actions.listOrderItems({ + where: { order: { id: { equals: order?.id } } }, + }); + + for (const item of items.results) { + if (item.productId === productMouse!.id) { + await actions.updateOrderItem({ + where: { id: item.id }, + values: { quantity: 1 }, + }); + } + } + + order = await actions.getOrder({ id: order!.id }); + + expect(order?.shipping).toBe(8); + expect(order?.total).toBe(240 + 16.8 + 25.2 + order!.shipping!); +}); + +test("check customer statistics after adjusting quantity", async () => { + johnDoe = await actions.getCustomer({ id: johnDoe!.id }); + expect(johnDoe?.totalSpent).toBe(290); + expect(johnDoe?.totalOrders).toBe(1); + expect(johnDoe?.averageOrderValue).toBe(290); + expect(johnDoe?.smallestOrder).toBe(290); + expect(johnDoe?.largestOrder).toBe(290); +}); + +test("check stock levels after adjusting quantity", async () => { + productLaptop = await actions.getProduct({ id: productLaptop!.id }); + expect(productLaptop?.stockQuantity).toBe(8); + + productMouse = await actions.getProduct({ id: productMouse!.id }); + expect(productMouse?.stockQuantity).toBe(19); + + productKeyboard = await actions.getProduct({ id: productKeyboard!.id }); + expect(productKeyboard?.stockQuantity).toBe(24); + + productMonitor = await actions.getProduct({ id: productMonitor!.id }); + expect(productMonitor?.stockQuantity).toBe(10); +}); + +test("change product in order item", async () => { + const items = await actions.listOrderItems({ + where: { order: { id: { equals: order?.id } } }, + }); + + for (const item of items.results) { + if (item.productId === productMouse!.id) { + await actions.updateOrderItem({ + where: { id: item.id }, + values: { product: { id: productMonitor!.id } }, + }); + } + } + + order = await actions.getOrder({ id: order!.id }); + + expect(order?.shipping).toBe(8); + expect(order?.total).toBe(240 + 70 + 25.2 + order!.shipping!); +}); + +test("check stock levels after adjusting product", async () => { + productLaptop = await actions.getProduct({ id: productLaptop!.id }); + expect(productLaptop?.stockQuantity).toBe(8); + + productMouse = await actions.getProduct({ id: productMouse!.id }); + expect(productMouse?.stockQuantity).toBe(20); + + productKeyboard = await actions.getProduct({ id: productKeyboard!.id }); + expect(productKeyboard?.stockQuantity).toBe(24); + + productMonitor = await actions.getProduct({ id: productMonitor!.id }); + expect(productMonitor?.stockQuantity).toBe(9); +}); + +test("check customer statistics after adjusting product", async () => { + johnDoe = await actions.getCustomer({ id: johnDoe!.id }); + expect(johnDoe?.totalSpent).toBe(343.2); + expect(johnDoe?.totalOrders).toBe(1); + expect(johnDoe?.averageOrderValue).toBe(343.2); + expect(johnDoe?.smallestOrder).toBe(343.2); + expect(johnDoe?.largestOrder).toBe(343.2); +}); + +test("create another order", async () => { + order = await actions.createOrder({ + customer: { id: johnDoe!.id }, + orderItems: [{ product: { id: productMouse!.id }, quantity: 4 }], + }); + + expect(order.shipping).toBe(8); + expect(order.total).toBe(67.2 + order.shipping!); +}); + +test("check customer statistics after adjusting product", async () => { + johnDoe = await actions.getCustomer({ id: johnDoe!.id }); + expect(johnDoe?.totalSpent).toBe(418.4); + expect(johnDoe?.totalOrders).toBe(2); + expect(johnDoe?.averageOrderValue).toBe(209.2); + expect(johnDoe?.smallestOrder).toBe(75.2); + expect(johnDoe?.largestOrder).toBe(343.2); +}); + +test("check stock levels after adjusting quantity", async () => { + productLaptop = await actions.getProduct({ id: productLaptop!.id }); + expect(productLaptop?.stockQuantity).toBe(8); + + productMouse = await actions.getProduct({ id: productMouse!.id }); + expect(productMouse?.stockQuantity).toBe(16); + + productKeyboard = await actions.getProduct({ id: productKeyboard!.id }); + expect(productKeyboard?.stockQuantity).toBe(24); + + productMonitor = await actions.getProduct({ id: productMonitor!.id }); + expect(productMonitor?.stockQuantity).toBe(9); +}); + +test("change order's customer", async () => { + pamSmith = await actions.createCustomer({ + name: "Pam Smith", + }); + + order = await actions.updateOrder({ + where: { id: order!.id }, + values: { + customer: { id: pamSmith!.id }, + }, + }); + + expect(order.shipping).toBe(8); + expect(order.total).toBe(67.2 + order.shipping!); +}); + +test("check that stock levels are the same", async () => { + productLaptop = await actions.getProduct({ id: productLaptop!.id }); + expect(productLaptop?.stockQuantity).toBe(8); + + productMouse = await actions.getProduct({ id: productMouse!.id }); + expect(productMouse?.stockQuantity).toBe(16); + + productKeyboard = await actions.getProduct({ id: productKeyboard!.id }); + expect(productKeyboard?.stockQuantity).toBe(24); + + productMonitor = await actions.getProduct({ id: productMonitor!.id }); + expect(productMonitor?.stockQuantity).toBe(9); +}); + +test("check customer statistics after adjusting product", async () => { + johnDoe = await actions.getCustomer({ id: johnDoe!.id }); + expect(johnDoe?.totalSpent).toBe(343.2); + expect(johnDoe?.totalOrders).toBe(1); + expect(johnDoe?.averageOrderValue).toBe(343.2); + + pamSmith = await actions.getCustomer({ id: pamSmith!.id }); + expect(pamSmith?.totalSpent).toBe(75.2); + expect(pamSmith?.totalOrders).toBe(1); + expect(pamSmith?.averageOrderValue).toBe(75.2); +}); + +test("fix product markup", async () => { + productLaptop = await actions.updateProduct({ + where: { id: productLaptop!.id }, + values: { markup: 0.4 }, + }); + + expect(productLaptop?.price).toBe(140); +}); + +test("check customer statistics after fixing product markup", async () => { + johnDoe = await actions.getCustomer({ id: johnDoe!.id }); + expect(johnDoe?.totalSpent).toBe(383.2); + expect(johnDoe?.totalOrders).toBe(1); + expect(johnDoe?.averageOrderValue).toBe(383.2); + + pamSmith = await actions.getCustomer({ id: pamSmith!.id }); + expect(pamSmith?.totalSpent).toBe(75.2); + expect(pamSmith?.totalOrders).toBe(1); + expect(pamSmith?.averageOrderValue).toBe(75.2); +}); + +test("delete order item", async () => { + const items = await actions.listOrderItems({ + where: { order: { id: { equals: order?.id } } }, + }); + console.log(items.results); + for (const item of items.results) { + if (item.productId === productMouse!.id) { + await actions.deleteOrderItem({ id: item!.id }); + } + } + + order = await actions.getOrder({ id: order!.id }); + + expect(order?.shipping).toBe(0); + expect(order?.total).toBe(0); +}); + +test("check customer statistics after fixing product markup", async () => { + johnDoe = await actions.getCustomer({ id: johnDoe!.id }); + expect(johnDoe?.totalSpent).toBe(383.2); + expect(johnDoe?.totalOrders).toBe(1); + expect(johnDoe?.averageOrderValue).toBe(383.2); + + pamSmith = await actions.getCustomer({ id: pamSmith!.id }); + expect(pamSmith?.totalSpent).toBe(0); + expect(pamSmith?.totalOrders).toBe(1); + expect(pamSmith?.averageOrderValue).toBe(0); +}); + +test("check that stock levels have increased", async () => { + productLaptop = await actions.getProduct({ id: productLaptop!.id }); + expect(productLaptop?.stockQuantity).toBe(8); + + productMouse = await actions.getProduct({ id: productMouse!.id }); + expect(productMouse?.stockQuantity).toBe(20); + + productKeyboard = await actions.getProduct({ id: productKeyboard!.id }); + expect(productKeyboard?.stockQuantity).toBe(24); + + productMonitor = await actions.getProduct({ id: productMonitor!.id }); + expect(productMonitor?.stockQuantity).toBe(9); +}); + +test("readd order item", async () => { + await actions.addOrderItem({ + order: { id: order!.id }, + product: { id: productMouse!.id }, + quantity: 4, + }); +}); + +test("check customer statistics after readding order item", async () => { + johnDoe = await actions.getCustomer({ id: johnDoe!.id }); + expect(johnDoe?.totalSpent).toBe(383.2); + expect(johnDoe?.totalOrders).toBe(1); + expect(johnDoe?.averageOrderValue).toBe(383.2); + + pamSmith = await actions.getCustomer({ id: pamSmith!.id }); + expect(pamSmith?.totalSpent).toBe(75.2); + expect(pamSmith?.totalOrders).toBe(1); + expect(pamSmith?.averageOrderValue).toBe(75.2); +}); + +test("check that stock levels after readding order item", async () => { + productLaptop = await actions.getProduct({ id: productLaptop!.id }); + expect(productLaptop?.stockQuantity).toBe(8); + + productMouse = await actions.getProduct({ id: productMouse!.id }); + expect(productMouse?.stockQuantity).toBe(16); + + productKeyboard = await actions.getProduct({ id: productKeyboard!.id }); + expect(productKeyboard?.stockQuantity).toBe(24); + + productMonitor = await actions.getProduct({ id: productMonitor!.id }); + expect(productMonitor?.stockQuantity).toBe(9); +}); + +test("delete order", async () => { + await actions.deleteOrder({ id: order!.id }); +}); + +test("check customer statistics after deleting order", async () => { + johnDoe = await actions.getCustomer({ id: johnDoe!.id }); + expect(johnDoe?.totalSpent).toBe(383.2); + expect(johnDoe?.totalOrders).toBe(1); + expect(johnDoe?.averageOrderValue).toBe(383.2); + + pamSmith = await actions.getCustomer({ id: pamSmith!.id }); + expect(pamSmith?.totalSpent).toBe(0); + expect(pamSmith?.totalOrders).toBe(0); + expect(pamSmith?.averageOrderValue).toBe(0); + expect(pamSmith?.smallestOrder).toBe(0); + expect(pamSmith?.largestOrder).toBe(0); +}); + +test("check that stock levels after deleting order", async () => { + productLaptop = await actions.getProduct({ id: productLaptop!.id }); + expect(productLaptop?.stockQuantity).toBe(8); + + productMouse = await actions.getProduct({ id: productMouse!.id }); + expect(productMouse?.stockQuantity).toBe(20); + + productKeyboard = await actions.getProduct({ id: productKeyboard!.id }); + expect(productKeyboard?.stockQuantity).toBe(24); + + productMonitor = await actions.getProduct({ id: productMonitor!.id }); + expect(productMonitor?.stockQuantity).toBe(9); +}); + +test("delete product", async () => { + await actions.deleteProduct({ id: productLaptop!.id }); +}); + +test("check customer statistics after deleting product", async () => { + johnDoe = await actions.getCustomer({ id: johnDoe!.id }); + expect(johnDoe?.totalSpent).toBe(99.2); + expect(johnDoe?.totalOrders).toBe(1); + expect(johnDoe?.averageOrderValue).toBe(99.2); + expect(pamSmith?.smallestOrder).toBe(0); // this is because the order actually still exists + expect(pamSmith?.largestOrder).toBe(0); + + pamSmith = await actions.getCustomer({ id: pamSmith!.id }); + expect(pamSmith?.totalSpent).toBe(0); + expect(pamSmith?.totalOrders).toBe(0); + expect(pamSmith?.averageOrderValue).toBe(0); +}); + +test("delete customers", async () => { + await actions.deleteCustomer({ id: johnDoe!.id }); + await actions.deleteCustomer({ id: pamSmith!.id }); +}); + +test("check that stock levels after deleting customers", async () => { + productMouse = await actions.getProduct({ id: productMouse!.id }); + expect(productMouse?.stockQuantity).toBe(20); + + productKeyboard = await actions.getProduct({ id: productKeyboard!.id }); + expect(productKeyboard?.stockQuantity).toBe(25); + + productMonitor = await actions.getProduct({ id: productMonitor!.id }); + expect(productMonitor?.stockQuantity).toBe(10); +}); + +test("delete purchase orders", async () => { + const result = await actions.listPurchaseOrders(); + + for (const purchaseOrder of result.results) { + await actions.deletePurchaseOrder({ id: purchaseOrder!.id }); + } +}); + +test("check that stock levels after deleting purchase orders", async () => { + productMouse = await actions.getProduct({ id: productMouse!.id }); + expect(productMouse?.stockQuantity).toBe(0); + + productKeyboard = await actions.getProduct({ id: productKeyboard!.id }); + expect(productKeyboard?.stockQuantity).toBe(0); + + productMonitor = await actions.getProduct({ id: productMonitor!.id }); + expect(productMonitor?.stockQuantity).toBe(0); +}); diff --git a/integration/testdata/real_world_invoice_system/schema.keel b/integration/testdata/real_world_invoice_system/schema.keel new file mode 100644 index 000000000..7ff240117 --- /dev/null +++ b/integration/testdata/real_world_invoice_system/schema.keel @@ -0,0 +1,91 @@ +model Customer { + fields { + name Text + orders Order[] + totalOrders Number? @computed(COUNT(customer.orders)) + totalSpent Decimal? @computed(SUM(customer.orders.total)) + averageOrderValue Decimal? @computed(AVG(customer.orders.total)) + largestOrder Decimal? @computed(MAX(customer.orders.total)) + smallestOrder Decimal? @computed(MIN(customer.orders.total)) + } + actions { + create createCustomer() with (name) + get getCustomer(id) + list listCustomers() + delete deleteCustomer(id) + } + @permission(expression: true, actions: [get, list, update, delete, create]) +} + +model Order { + fields { + customer Customer + orderItems OrderItem[] + shipping Decimal? @computed(SUM(order.orderItems.quantity) * 2) + total Decimal? @computed(SUM(order.orderItems.price) - (SUM(order.orderItems.price) / 100 * order.discountPercentage) + order.shipping) + totalExcludingVat Decimal? @computed(order.total - order.vat) + vat Decimal? @computed(order.total * 0.2) + discountPercentage Number @default(0) + } + actions { + create createOrder() with (customer.id, orderItems.product.id, orderItems.quantity) + update updateOrder(id) with (customer.id?, discountPercentage?) + get getOrder(id) + list listOrders(customer.id?) + delete deleteOrder(id) + } + @permission(expression: true, actions: [get, list, update, delete, create]) +} + +model OrderItem { + fields { + order Order + product Product + quantity Number + price Decimal? @computed(orderItem.product.price * orderItem.quantity) + } + actions { + get getOrderItem(id) + create addOrderItem() with (order.id, product.id, quantity) + list listOrderItems(order.id?) + update updateOrderItem(id) with (product.id?, quantity?) + delete deleteOrderItem(id) + } + @permission(expression: true, actions: [get, list, update, delete, create]) +} + +model Product { + fields { + name Text + price Decimal? @computed(product.costPrice + product.costPrice * product.markup) + costPrice Decimal + markup Decimal + purchases PurchaseOrder[] + orderItems OrderItem[] + stockQuantity Number? @computed(SUM(product.purchases.quantity) - SUM(product.orderItems.quantity)) + unitsSold Number? @computed(SUM(product.orderItems.quantity)) + } + + actions { + create createProduct() with (name, costPrice, markup) + update updateProduct(id) with (name?, costPrice?, markup?) + get getProduct(id) + list listProducts() + delete deleteProduct(id) + } + @permission(expression: true, actions: [get, list, update, delete, create]) +} + +model PurchaseOrder { + fields { + product Product + quantity Number + } + actions { + create createPurchaseOrder() with (product.id, quantity) + get getPurchaseOrder(id) + list listPurchaseOrders(product.id?) + delete deletePurchaseOrder(id) + } + @permission(expression: true, actions: [get, list, update, delete, create]) +} \ No newline at end of file diff --git a/migrations/computed_functions.sql b/migrations/computed_functions.sql index 662729b32..a4aaa6266 100644 --- a/migrations/computed_functions.sql +++ b/migrations/computed_functions.sql @@ -5,4 +5,4 @@ FROM WHERE routine_type = 'FUNCTION' AND - routine_schema = 'public' AND routine_name LIKE '%__comp' OR routine_name LIKE '%__comp_dep'; \ No newline at end of file + routine_schema = 'public' AND routine_name LIKE '%__comp' OR routine_name LIKE '%__exec_comp_fns' OR routine_name LIKE '%__comp_dep' OR routine_name LIKE '%__comp_dep_update'; \ No newline at end of file diff --git a/migrations/migrations.go b/migrations/migrations.go index 87fbe063d..b0fcef633 100644 --- a/migrations/migrations.go +++ b/migrations/migrations.go @@ -486,7 +486,6 @@ func computedFieldDependencies(schema *proto.Schema) (map[*proto.Field][]*depPai if i < len(ident.Fragments)-2 { currModel = schema.FindModel(currField.Type.ModelName.Value) - continue } @@ -495,7 +494,13 @@ func computedFieldDependencies(schema *proto.Schema) (map[*proto.Field][]*depPai field: currField, } - dependencies[field] = append(dependencies[field], &dep) + hasDep := lo.ContainsBy(dependencies[field], func(d *depPair) bool { + return d.field.Name == currField.Name && d.ident.String() == ident.String() + }) + + if !hasDep { + dependencies[field] = append(dependencies[field], &dep) + } } } } @@ -540,7 +545,7 @@ func computedFieldsStmts(schema *proto.Schema, existingComputedFns []*FunctionRo slices.Sort(newFns) slices.Sort(retiredFns) - // Functions to be created + // Computed functions to be created for each computed field for _, fn := range newFns { statements = append(statements, modelFns[fn]) @@ -559,7 +564,7 @@ func computedFieldsStmts(schema *proto.Schema, existingComputedFns []*FunctionRo // Functions to be dropped for _, fn := range retiredFns { - statements = append(statements, fmt.Sprintf("DROP FUNCTION %s;", fn)) + statements = append(statements, fmt.Sprintf("DROP FUNCTION \"%s\";", fn)) f := fieldFromComputedFnName(schema, fn) if f != nil { @@ -594,6 +599,8 @@ func computedFieldsStmts(schema *proto.Schema, existingComputedFns []*FunctionRo return nil, nil, err } + // For each model, we need to create a function which calls all the computed functions for fields on this model + // Order is important because computed fields can depend on each other - this is catered for for _, model := range schema.Models { modelhasChanged := false for k, v := range changedFields { @@ -637,25 +644,27 @@ func computedFieldsStmts(schema *proto.Schema, existingComputedFns []*FunctionRo // Generate SQL statements in dependency order stmts := []string{} for _, field := range sorted { - s := fmt.Sprintf("\tNEW.%s := %s(NEW);\n", strcase.ToSnake(field.Name), fieldsFns[field]) + s := fmt.Sprintf("NEW.%s := %s(NEW);\n", strcase.ToSnake(field.Name), fieldsFns[field]) stmts = append(stmts, s) } // Generate the trigger function which executes all the computed field functions for the model. execFnName := computedExecFuncName(model) - sql := fmt.Sprintf("CREATE OR REPLACE FUNCTION %s() RETURNS TRIGGER AS $$ BEGIN\n%s\tRETURN NEW;\nEND; $$ LANGUAGE plpgsql;", execFnName, strings.Join(stmts, "")) + sql := fmt.Sprintf("CREATE OR REPLACE FUNCTION \"%s\"() RETURNS TRIGGER AS $$ BEGIN\n\t%s\tRETURN NEW;\nEND; $$ LANGUAGE plpgsql;", execFnName, strings.Join(stmts, "")) // Generate the table trigger which executed the trigger function. + // This must be a BEFORE trigger because we want to return the row with its computed fields being computed. triggerName := computedTriggerName(model) - trigger := fmt.Sprintf("CREATE OR REPLACE TRIGGER %s BEFORE INSERT OR UPDATE ON \"%s\" FOR EACH ROW EXECUTE PROCEDURE %s();", triggerName, strcase.ToSnake(model.Name), execFnName) + trigger := fmt.Sprintf("CREATE OR REPLACE TRIGGER \"%s\" BEFORE INSERT OR UPDATE ON \"%s\" FOR EACH ROW EXECUTE PROCEDURE \"%s\"();", triggerName, strcase.ToSnake(model.Name), execFnName) statements = append(statements, sql) statements = append(statements, trigger) } + // For computed fields which depend on fields in other models, we need to create triggers which start from the source model and cascades + // down through the relationship until it reaches the target model (where the computed field is defined). We then perform a fake update on the + // specific rows in the target model which will then trigger the computed fns. depFns := map[string]string{} - - // For computed fields which depend on fields in other models, we perform a fake update in order to trigger the computed fns. for field, deps := range dependencies { for _, dep := range deps { // Skip this because the triggers call on the exec functions themselves @@ -663,67 +672,64 @@ func computedFieldsStmts(schema *proto.Schema, existingComputedFns []*FunctionRo continue } - fieldModel := schema.FindModel(field.ModelName) - fragments, err := actions.NormalisedFragments(schema, dep.ident.Fragments) if err != nil { return nil, nil, err } - baseQuery := actions.NewQuery(schema.FindModel(field.ModelName)) - baseQuery.Select(actions.IdField()) - - model := casing.ToCamel(fragments[0]) + currentModel := casing.ToCamel(fragments[0]) for i := 1; i < len(fragments)-1; i++ { - currentFragment := fragments[i] + baseQuery := actions.NewQuery(schema.FindModel(currentModel)) + baseQuery.Select(actions.IdField()) + + expr := strings.Join(fragments, ".") + // Get the fragment pair from the previous model to the current model + // We need to reset the first fragment to the model name and not the previous model's field name + subFragments := slices.Clone(fragments[i-1 : i+1]) + subFragments[0] = strcase.ToLowerCamel(currentModel) - if !proto.ModelHasField(schema, model, currentFragment) { - return nil, nil, fmt.Errorf("this model: %s, does not have a field of name: %s", model, currentFragment) + if !proto.ModelHasField(schema, currentModel, fragments[i]) { + return nil, nil, fmt.Errorf("this model: %s, does not have a field of name: %s", currentModel, subFragments[0]) } // We know that the current fragment is a related model because it's not the last fragment - relatedModelField := proto.FindField(schema.Models, model, currentFragment) - relatedModel := relatedModelField.Type.ModelName.Value + relatedModelField := proto.FindField(schema.Models, currentModel, fragments[i]) foreignKeyField := proto.GetForeignKeyFieldName(schema.Models, relatedModelField) - primaryKey := "id" - var leftOperand *actions.QueryOperand - var rightOperand *actions.QueryOperand + previousModel := currentModel + currentModel = relatedModelField.Type.ModelName.Value + stmt := "" + // If the relationship is a belongs to or has many, we need to update the id field on the previous model switch { case relatedModelField.IsBelongsTo(): - // In a "belongs to" the foreign key is on _this_ model - leftOperand = actions.ExpressionField(fragments[:i+1], primaryKey, false) - rightOperand = actions.ExpressionField(fragments[:i], foreignKeyField, false) + stmt += "UPDATE \"" + strcase.ToSnake(previousModel) + "\" SET id = id WHERE " + strcase.ToSnake(foreignKeyField) + " IN (NEW.id, OLD.id);" default: - // In all others the foreign key is on the _other_ model - leftOperand = actions.ExpressionField(fragments[:i+1], foreignKeyField, false) - rightOperand = actions.ExpressionField(fragments[:i], primaryKey, false) + stmt += "UPDATE \"" + strcase.ToSnake(previousModel) + "\" SET id = id WHERE id IN (NEW." + strcase.ToSnake(foreignKeyField) + ", OLD." + strcase.ToSnake(foreignKeyField) + ");" } - model = relatedModelField.Type.ModelName.Value - baseQuery.Join(relatedModel, leftOperand, rightOperand) - subQuery := baseQuery.Copy() - err := subQuery.Where(leftOperand, actions.Equals, actions.Raw("NEW.id")) - if err != nil { - return nil, nil, err - } + // Trigger function which will perform a fake update on the earlier model in the expression chain + fnName := computedDependencyFuncName(schema.FindModel(strcase.ToCamel(previousModel)), schema.FindModel(currentModel), strings.Split(expr, ".")) + sql := fmt.Sprintf("CREATE OR REPLACE FUNCTION \"%s\"() RETURNS TRIGGER AS $$\nBEGIN\n\t%s\n\tRETURN NULL;\nEND; $$ LANGUAGE plpgsql;\n", fnName, stmt) - query := actions.NewQuery(schema.FindModel(field.ModelName)) - query.AddWriteValue(actions.IdField(), actions.IdField()) - err = query.Where(actions.IdField(), actions.OneOf, actions.InlineQuery(subQuery, actions.ExpressionField(fragments[:i+1], "id", false))) - if err != nil { - return nil, nil, err - } - stmt := query.UpdateStatement(context.Background()).SqlTemplate() + // For the comp_dep function on the target field's model, we include a filter on the UPDATE trigger to only trigger if the target field has changed + whenCondition := "TRUE" + if i == len(fragments)-2 { + f := strcase.ToSnake(fragments[len(fragments)-1]) + whenCondition = fmt.Sprintf("NEW.%s <> OLD.%s", f, f) - dependencyFnName := computedDependencyFuncName(fieldModel, schema.FindModel(model), fragments[:i+1]) - sql := fmt.Sprintf("CREATE OR REPLACE FUNCTION %s() RETURNS TRIGGER AS $$ BEGIN\n\t%s;\n\tRETURN NULL;\nEND; $$ LANGUAGE plpgsql;", dependencyFnName, stmt) + if !relatedModelField.IsBelongsTo() { + updatingField := strcase.ToSnake(foreignKeyField) + whenCondition += fmt.Sprintf(" OR NEW.%s <> OLD.%s", updatingField, updatingField) + } + } - triggerName := dependencyFnName - sql += fmt.Sprintf("\nCREATE OR REPLACE TRIGGER %s AFTER INSERT OR DELETE OR UPDATE ON \"%s\" FOR EACH ROW EXECUTE PROCEDURE %s();", triggerName, strcase.ToSnake(model), dependencyFnName) + // Must be an AFTER trigger as we need the data to be written in order to perform the joins and for the computation to take into account the updated data + triggerName := fnName + sql += fmt.Sprintf("CREATE OR REPLACE TRIGGER \"%s\" AFTER INSERT OR DELETE ON \"%s\" FOR EACH ROW EXECUTE PROCEDURE \"%s\"();", triggerName, strcase.ToSnake(currentModel), fnName) + sql += fmt.Sprintf("\nCREATE OR REPLACE TRIGGER \"%s_update\" AFTER UPDATE ON \"%s\" FOR EACH ROW WHEN(%s) EXECUTE PROCEDURE \"%s\"();", triggerName, strcase.ToSnake(currentModel), whenCondition, fnName) - depFns[dependencyFnName] = sql + depFns[fnName] = sql } } } @@ -743,14 +749,15 @@ func computedFieldsStmts(schema *proto.Schema, existingComputedFns []*FunctionRo // Dependency functions and triggers to be dropped for _, fn := range retiredFns { - statements = append(statements, fmt.Sprintf("DROP TRIGGER %s ON %s;", fn, strings.Split(fn, "__")[0])) - statements = append(statements, fmt.Sprintf("DROP FUNCTION %s;", fn)) + statements = append(statements, fmt.Sprintf("DROP TRIGGER IF EXISTS \"%s\" ON \"%s\";", fn, strings.Split(fn, "__")[0])) + statements = append(statements, fmt.Sprintf("DROP TRIGGER IF EXISTS \"%s_update\" ON \"%s\";", fn, strings.Split(fn, "__")[0])) + statements = append(statements, fmt.Sprintf("DROP FUNCTION IF EXISTS \"%s\";", fn)) } // If a computed field has been added or changed, we need to recompute all existing data. // This is done by fake updating each row on the table which will cause the triggers to run. for _, model := range recompute { - sql := fmt.Sprintf("UPDATE %s SET id = id;", strcase.ToSnake(model.Name)) + sql := fmt.Sprintf("UPDATE \"%s\" SET id = id;", strcase.ToSnake(model.Name)) statements = append(statements, sql) } diff --git a/migrations/sql.go b/migrations/sql.go index e2411c804..555882ce2 100644 --- a/migrations/sql.go +++ b/migrations/sql.go @@ -253,7 +253,6 @@ func computedTriggerName(model *proto.Model) string { } func computedDependencyFuncName(model *proto.Model, dependentModel *proto.Model, fragments []string) string { - // shortened alphanumeric hash from the operand idents hash := hashOfExpression(strings.Join(fragments, ".")) return fmt.Sprintf("%s__to__%s__%s__comp_dep", strcase.ToSnake(dependentModel.Name), strcase.ToSnake(model.Name), hash) } @@ -292,7 +291,7 @@ func addComputedFieldFuncStmt(schema *proto.Schema, model *proto.Model, field *p } fn := computedFieldFuncName(field) - sql := fmt.Sprintf("CREATE FUNCTION %s(r %s) RETURNS %s AS $$ BEGIN\n\tRETURN %s;\nEND; $$ LANGUAGE plpgsql;", + sql := fmt.Sprintf("CREATE FUNCTION \"%s\"(r \"%s\") RETURNS %s AS $$ BEGIN\n\tRETURN %s;\nEND; $$ LANGUAGE plpgsql;", fn, strcase.ToSnake(model.Name), sqlType, @@ -302,11 +301,11 @@ func addComputedFieldFuncStmt(schema *proto.Schema, model *proto.Model, field *p } func dropComputedExecFunctionStmt(model *proto.Model) string { - return fmt.Sprintf("DROP FUNCTION %s__exec_comp_fns;", strcase.ToSnake(model.Name)) + return fmt.Sprintf("DROP FUNCTION \"%s__exec_comp_fns\";", strcase.ToSnake(model.Name)) } func dropComputedTriggerStmt(model *proto.Model) string { - return fmt.Sprintf("DROP TRIGGER %s__comp ON %s;", strcase.ToSnake(model.Name), strcase.ToSnake(model.Name)) + return fmt.Sprintf("DROP TRIGGER \"%s__comp\" ON \"%s\";", strcase.ToSnake(model.Name), strcase.ToSnake(model.Name)) } func fieldDefinition(field *proto.Field) (string, error) { @@ -510,7 +509,7 @@ func toSqlLiteral(value any, field *proto.Field) (string, error) { func dropColumnStmt(modelName string, fieldName string) string { output := fmt.Sprintf("ALTER TABLE %s ", Identifier(modelName)) - output += fmt.Sprintf("DROP COLUMN %s;", Identifier(fieldName)) + output += fmt.Sprintf("DROP COLUMN %s CASCADE;", Identifier(fieldName)) return output } diff --git a/migrations/testdata/computed_field_changed_expression.txt b/migrations/testdata/computed_field_changed_expression.txt index c092fe3fd..2ff84ca12 100644 --- a/migrations/testdata/computed_field_changed_expression.txt +++ b/migrations/testdata/computed_field_changed_expression.txt @@ -18,16 +18,16 @@ model Item { === -CREATE FUNCTION item__total__863346d0__comp(r item) RETURNS NUMERIC AS $$ BEGIN +CREATE FUNCTION "item__total__863346d0__comp"(r "item") RETURNS NUMERIC AS $$ BEGIN RETURN r."price" + 5; END; $$ LANGUAGE plpgsql; -DROP FUNCTION item__total__0614a79a__comp; -CREATE OR REPLACE FUNCTION item__exec_comp_fns() RETURNS TRIGGER AS $$ BEGIN +DROP FUNCTION "item__total__0614a79a__comp"; +CREATE OR REPLACE FUNCTION "item__exec_comp_fns"() RETURNS TRIGGER AS $$ BEGIN NEW.total := item__total__863346d0__comp(NEW); RETURN NEW; END; $$ LANGUAGE plpgsql; -CREATE OR REPLACE TRIGGER item__comp BEFORE INSERT OR UPDATE ON "item" FOR EACH ROW EXECUTE PROCEDURE item__exec_comp_fns(); -UPDATE item SET id = id; +CREATE OR REPLACE TRIGGER "item__comp" BEFORE INSERT OR UPDATE ON "item" FOR EACH ROW EXECUTE PROCEDURE "item__exec_comp_fns"(); +UPDATE "item" SET id = id; === diff --git a/migrations/testdata/computed_field_initial.txt b/migrations/testdata/computed_field_initial.txt index 584338b8d..5d38bd5bf 100644 --- a/migrations/testdata/computed_field_initial.txt +++ b/migrations/testdata/computed_field_initial.txt @@ -61,15 +61,15 @@ CREATE TRIGGER identity_create AFTER INSERT ON "identity" REFERENCING NEW TABLE CREATE TRIGGER identity_update AFTER UPDATE ON "identity" REFERENCING NEW TABLE AS new_table OLD TABLE AS old_table FOR EACH STATEMENT EXECUTE PROCEDURE process_audit(); CREATE TRIGGER identity_delete AFTER DELETE ON "identity" REFERENCING OLD TABLE AS old_table FOR EACH STATEMENT EXECUTE PROCEDURE process_audit(); CREATE TRIGGER identity_updated_at BEFORE UPDATE ON "identity" FOR EACH ROW EXECUTE PROCEDURE set_updated_at(); -CREATE FUNCTION item__total__0614a79a__comp(r item) RETURNS NUMERIC AS $$ BEGIN +CREATE FUNCTION "item__total__0614a79a__comp"(r "item") RETURNS NUMERIC AS $$ BEGIN RETURN r."quantity" * r."price"; END; $$ LANGUAGE plpgsql; -CREATE OR REPLACE FUNCTION item__exec_comp_fns() RETURNS TRIGGER AS $$ BEGIN +CREATE OR REPLACE FUNCTION "item__exec_comp_fns"() RETURNS TRIGGER AS $$ BEGIN NEW.total := item__total__0614a79a__comp(NEW); RETURN NEW; END; $$ LANGUAGE plpgsql; -CREATE OR REPLACE TRIGGER item__comp BEFORE INSERT OR UPDATE ON "item" FOR EACH ROW EXECUTE PROCEDURE item__exec_comp_fns(); -UPDATE item SET id = id; +CREATE OR REPLACE TRIGGER "item__comp" BEFORE INSERT OR UPDATE ON "item" FOR EACH ROW EXECUTE PROCEDURE "item__exec_comp_fns"(); +UPDATE "item" SET id = id; === diff --git a/migrations/testdata/computed_field_many_to_one.txt b/migrations/testdata/computed_field_many_to_one.txt index b05db229d..a65a8db1b 100644 --- a/migrations/testdata/computed_field_many_to_one.txt +++ b/migrations/testdata/computed_field_many_to_one.txt @@ -22,6 +22,7 @@ model Agent { } === + CREATE TABLE "agent" ( "commission" NUMERIC NOT NULL, "id" TEXT NOT NULL DEFAULT ksuid(), @@ -98,25 +99,36 @@ CREATE TRIGGER identity_create AFTER INSERT ON "identity" REFERENCING NEW TABLE CREATE TRIGGER identity_update AFTER UPDATE ON "identity" REFERENCING NEW TABLE AS new_table OLD TABLE AS old_table FOR EACH STATEMENT EXECUTE PROCEDURE process_audit(); CREATE TRIGGER identity_delete AFTER DELETE ON "identity" REFERENCING OLD TABLE AS old_table FOR EACH STATEMENT EXECUTE PROCEDURE process_audit(); CREATE TRIGGER identity_updated_at BEFORE UPDATE ON "identity" FOR EACH ROW EXECUTE PROCEDURE set_updated_at(); -CREATE FUNCTION item__total__8f543d38__comp(r item) RETURNS NUMERIC AS $$ BEGIN +CREATE FUNCTION "item__total__8f543d38__comp"(r "item") RETURNS NUMERIC AS $$ BEGIN RETURN r."quantity" * (SELECT "product"."price" FROM "product" WHERE "product"."id" IS NOT DISTINCT FROM r."product_id") + (SELECT "product$agent"."commission" FROM "product" LEFT JOIN "agent" AS "product$agent" ON "product$agent"."id" = "product"."agent_id" WHERE "product"."id" IS NOT DISTINCT FROM r."product_id"); END; $$ LANGUAGE plpgsql; -CREATE OR REPLACE FUNCTION item__exec_comp_fns() RETURNS TRIGGER AS $$ BEGIN +CREATE OR REPLACE FUNCTION "item__exec_comp_fns"() RETURNS TRIGGER AS $$ BEGIN NEW.total := item__total__8f543d38__comp(NEW); RETURN NEW; END; $$ LANGUAGE plpgsql; -CREATE OR REPLACE TRIGGER item__comp BEFORE INSERT OR UPDATE ON "item" FOR EACH ROW EXECUTE PROCEDURE item__exec_comp_fns(); -CREATE OR REPLACE FUNCTION agent__to__item__26590ccb__comp_dep() RETURNS TRIGGER AS $$ BEGIN - UPDATE "item" SET "id" = "item"."id" WHERE "item"."id" IN (SELECT "item"."id" FROM "item" JOIN "product" AS "item$product" ON "item$product"."id" = "item"."product_id" JOIN "agent" AS "item$product$agent" ON "item$product$agent"."id" = "item$product"."agent_id" WHERE "item$product$agent"."id" IS NOT DISTINCT FROM NEW.id); +CREATE OR REPLACE TRIGGER "item__comp" BEFORE INSERT OR UPDATE ON "item" FOR EACH ROW EXECUTE PROCEDURE "item__exec_comp_fns"(); +CREATE OR REPLACE FUNCTION "agent__to__product__2eb4dbe9__comp_dep"() RETURNS TRIGGER AS $$ +BEGIN + UPDATE "product" SET id = id WHERE agent_id IN (NEW.id, OLD.id); + RETURN NULL; +END; $$ LANGUAGE plpgsql; +CREATE OR REPLACE TRIGGER "agent__to__product__2eb4dbe9__comp_dep" AFTER INSERT OR DELETE ON "agent" FOR EACH ROW EXECUTE PROCEDURE "agent__to__product__2eb4dbe9__comp_dep"(); +CREATE OR REPLACE TRIGGER "agent__to__product__2eb4dbe9__comp_dep_update" AFTER UPDATE ON "agent" FOR EACH ROW WHEN(NEW.commission <> OLD.commission) EXECUTE PROCEDURE "agent__to__product__2eb4dbe9__comp_dep"(); +CREATE OR REPLACE FUNCTION "product__to__item__037dbf3a__comp_dep"() RETURNS TRIGGER AS $$ +BEGIN + UPDATE "item" SET id = id WHERE product_id IN (NEW.id, OLD.id); RETURN NULL; END; $$ LANGUAGE plpgsql; -CREATE OR REPLACE TRIGGER agent__to__item__26590ccb__comp_dep AFTER INSERT OR DELETE OR UPDATE ON "agent" FOR EACH ROW EXECUTE PROCEDURE agent__to__item__26590ccb__comp_dep(); -CREATE OR REPLACE FUNCTION product__to__item__9454cd1f__comp_dep() RETURNS TRIGGER AS $$ BEGIN - UPDATE "item" SET "id" = "item"."id" WHERE "item"."id" IN (SELECT "item"."id" FROM "item" JOIN "product" AS "item$product" ON "item$product"."id" = "item"."product_id" WHERE "item$product"."id" IS NOT DISTINCT FROM NEW.id); +CREATE OR REPLACE TRIGGER "product__to__item__037dbf3a__comp_dep" AFTER INSERT OR DELETE ON "product" FOR EACH ROW EXECUTE PROCEDURE "product__to__item__037dbf3a__comp_dep"(); +CREATE OR REPLACE TRIGGER "product__to__item__037dbf3a__comp_dep_update" AFTER UPDATE ON "product" FOR EACH ROW WHEN(NEW.price <> OLD.price) EXECUTE PROCEDURE "product__to__item__037dbf3a__comp_dep"(); +CREATE OR REPLACE FUNCTION "product__to__item__2eb4dbe9__comp_dep"() RETURNS TRIGGER AS $$ +BEGIN + UPDATE "item" SET id = id WHERE product_id IN (NEW.id, OLD.id); RETURN NULL; END; $$ LANGUAGE plpgsql; -CREATE OR REPLACE TRIGGER product__to__item__9454cd1f__comp_dep AFTER INSERT OR DELETE OR UPDATE ON "product" FOR EACH ROW EXECUTE PROCEDURE product__to__item__9454cd1f__comp_dep(); -UPDATE item SET id = id; +CREATE OR REPLACE TRIGGER "product__to__item__2eb4dbe9__comp_dep" AFTER INSERT OR DELETE ON "product" FOR EACH ROW EXECUTE PROCEDURE "product__to__item__2eb4dbe9__comp_dep"(); +CREATE OR REPLACE TRIGGER "product__to__item__2eb4dbe9__comp_dep_update" AFTER UPDATE ON "product" FOR EACH ROW WHEN(TRUE) EXECUTE PROCEDURE "product__to__item__2eb4dbe9__comp_dep"(); +UPDATE "item" SET id = id; === diff --git a/migrations/testdata/computed_field_many_to_one_changed.txt b/migrations/testdata/computed_field_many_to_one_changed.txt index c560169df..d21a46bb1 100644 --- a/migrations/testdata/computed_field_many_to_one_changed.txt +++ b/migrations/testdata/computed_field_many_to_one_changed.txt @@ -44,18 +44,22 @@ model Agent { === -CREATE FUNCTION item__total__5474c2e0__comp(r item) RETURNS NUMERIC AS $$ BEGIN +CREATE FUNCTION "item__total__5474c2e0__comp"(r "item") RETURNS NUMERIC AS $$ BEGIN RETURN r."quantity" * (SELECT "product"."price" FROM "product" WHERE "product"."id" IS NOT DISTINCT FROM r."product_id"); END; $$ LANGUAGE plpgsql; -DROP FUNCTION item__total__8f543d38__comp; -CREATE OR REPLACE FUNCTION item__exec_comp_fns() RETURNS TRIGGER AS $$ BEGIN +DROP FUNCTION "item__total__8f543d38__comp"; +CREATE OR REPLACE FUNCTION "item__exec_comp_fns"() RETURNS TRIGGER AS $$ BEGIN NEW.total := item__total__5474c2e0__comp(NEW); RETURN NEW; END; $$ LANGUAGE plpgsql; -CREATE OR REPLACE TRIGGER item__comp BEFORE INSERT OR UPDATE ON "item" FOR EACH ROW EXECUTE PROCEDURE item__exec_comp_fns(); -DROP TRIGGER agent__to__item__26590ccb__comp_dep ON agent; -DROP FUNCTION agent__to__item__26590ccb__comp_dep; -UPDATE item SET id = id; +CREATE OR REPLACE TRIGGER "item__comp" BEFORE INSERT OR UPDATE ON "item" FOR EACH ROW EXECUTE PROCEDURE "item__exec_comp_fns"(); +DROP TRIGGER IF EXISTS "agent__to__product__2eb4dbe9__comp_dep" ON "agent"; +DROP TRIGGER IF EXISTS "agent__to__product__2eb4dbe9__comp_dep_update" ON "agent"; +DROP FUNCTION IF EXISTS "agent__to__product__2eb4dbe9__comp_dep"; +DROP TRIGGER IF EXISTS "product__to__item__2eb4dbe9__comp_dep" ON "product"; +DROP TRIGGER IF EXISTS "product__to__item__2eb4dbe9__comp_dep_update" ON "product"; +DROP FUNCTION IF EXISTS "product__to__item__2eb4dbe9__comp_dep"; +UPDATE "item" SET id = id; === diff --git a/migrations/testdata/computed_field_multiple_depend.txt b/migrations/testdata/computed_field_multiple_depend.txt index a666399ef..81cb09397 100644 --- a/migrations/testdata/computed_field_multiple_depend.txt +++ b/migrations/testdata/computed_field_multiple_depend.txt @@ -21,19 +21,19 @@ model Item { === -CREATE FUNCTION item__total__0614a79a__comp(r item) RETURNS NUMERIC AS $$ BEGIN +CREATE FUNCTION "item__total__0614a79a__comp"(r "item") RETURNS NUMERIC AS $$ BEGIN RETURN r."quantity" * r."price"; END; $$ LANGUAGE plpgsql; -CREATE FUNCTION item__total_with_shipping__53d0d09b__comp(r item) RETURNS NUMERIC AS $$ BEGIN +CREATE FUNCTION "item__total_with_shipping__53d0d09b__comp"(r "item") RETURNS NUMERIC AS $$ BEGIN RETURN r."total" + 5; END; $$ LANGUAGE plpgsql; -CREATE OR REPLACE FUNCTION item__exec_comp_fns() RETURNS TRIGGER AS $$ BEGIN +CREATE OR REPLACE FUNCTION "item__exec_comp_fns"() RETURNS TRIGGER AS $$ BEGIN NEW.total := item__total__0614a79a__comp(NEW); - NEW.total_with_shipping := item__total_with_shipping__53d0d09b__comp(NEW); +NEW.total_with_shipping := item__total_with_shipping__53d0d09b__comp(NEW); RETURN NEW; END; $$ LANGUAGE plpgsql; -CREATE OR REPLACE TRIGGER item__comp BEFORE INSERT OR UPDATE ON "item" FOR EACH ROW EXECUTE PROCEDURE item__exec_comp_fns(); -UPDATE item SET id = id; +CREATE OR REPLACE TRIGGER "item__comp" BEFORE INSERT OR UPDATE ON "item" FOR EACH ROW EXECUTE PROCEDURE "item__exec_comp_fns"(); +UPDATE "item" SET id = id; === diff --git a/migrations/testdata/computed_field_removed_attr.txt b/migrations/testdata/computed_field_removed_attr.txt index b6a79ef31..94bbd43e9 100644 --- a/migrations/testdata/computed_field_removed_attr.txt +++ b/migrations/testdata/computed_field_removed_attr.txt @@ -18,9 +18,9 @@ model Item { === -DROP FUNCTION item__total__0614a79a__comp; -DROP TRIGGER item__comp ON item; -DROP FUNCTION item__exec_comp_fns; +DROP FUNCTION "item__total__0614a79a__comp"; +DROP TRIGGER "item__comp" ON "item"; +DROP FUNCTION "item__exec_comp_fns"; === diff --git a/migrations/testdata/computed_field_removed_field.txt b/migrations/testdata/computed_field_removed_field.txt index 293fdc828..c0eb1cc7b 100644 --- a/migrations/testdata/computed_field_removed_field.txt +++ b/migrations/testdata/computed_field_removed_field.txt @@ -17,10 +17,10 @@ model Item { === -ALTER TABLE "item" DROP COLUMN "total"; -DROP FUNCTION item__total__0614a79a__comp; -DROP TRIGGER item__comp ON item; -DROP FUNCTION item__exec_comp_fns; +ALTER TABLE "item" DROP COLUMN "total" CASCADE; +DROP FUNCTION "item__total__0614a79a__comp"; +DROP TRIGGER "item__comp" ON "item"; +DROP FUNCTION "item__exec_comp_fns"; === diff --git a/migrations/testdata/computed_field_renamed_field.txt b/migrations/testdata/computed_field_renamed_field.txt index d60503fe3..5b6cdc4ca 100644 --- a/migrations/testdata/computed_field_renamed_field.txt +++ b/migrations/testdata/computed_field_renamed_field.txt @@ -19,17 +19,17 @@ model Item { === ALTER TABLE "item" ADD COLUMN "new_total" NUMERIC NOT NULL; -ALTER TABLE "item" DROP COLUMN "total"; -CREATE FUNCTION item__new_total__0614a79a__comp(r item) RETURNS NUMERIC AS $$ BEGIN +ALTER TABLE "item" DROP COLUMN "total" CASCADE; +CREATE FUNCTION "item__new_total__0614a79a__comp"(r "item") RETURNS NUMERIC AS $$ BEGIN RETURN r."quantity" * r."price"; END; $$ LANGUAGE plpgsql; -DROP FUNCTION item__total__0614a79a__comp; -CREATE OR REPLACE FUNCTION item__exec_comp_fns() RETURNS TRIGGER AS $$ BEGIN +DROP FUNCTION "item__total__0614a79a__comp"; +CREATE OR REPLACE FUNCTION "item__exec_comp_fns"() RETURNS TRIGGER AS $$ BEGIN NEW.new_total := item__new_total__0614a79a__comp(NEW); RETURN NEW; END; $$ LANGUAGE plpgsql; -CREATE OR REPLACE TRIGGER item__comp BEFORE INSERT OR UPDATE ON "item" FOR EACH ROW EXECUTE PROCEDURE item__exec_comp_fns(); -UPDATE item SET id = id; +CREATE OR REPLACE TRIGGER "item__comp" BEFORE INSERT OR UPDATE ON "item" FOR EACH ROW EXECUTE PROCEDURE "item__exec_comp_fns"(); +UPDATE "item" SET id = id; === diff --git a/migrations/testdata/field_removed.txt b/migrations/testdata/field_removed.txt index 493aa1a14..ea4be7cd1 100644 --- a/migrations/testdata/field_removed.txt +++ b/migrations/testdata/field_removed.txt @@ -15,7 +15,7 @@ model Person { === -ALTER TABLE "person" DROP COLUMN "age"; +ALTER TABLE "person" DROP COLUMN "age" CASCADE; === diff --git a/migrations/testdata/field_removed_reln_hasone.txt b/migrations/testdata/field_removed_reln_hasone.txt index bbcebfb10..79c418c18 100644 --- a/migrations/testdata/field_removed_reln_hasone.txt +++ b/migrations/testdata/field_removed_reln_hasone.txt @@ -29,7 +29,7 @@ model Thing { === -ALTER TABLE "person" DROP COLUMN "favourite_thing_id"; +ALTER TABLE "person" DROP COLUMN "favourite_thing_id" CASCADE; === diff --git a/permissions/permissions.go b/permissions/permissions.go index 1586fcdbc..4ed3300fe 100644 --- a/permissions/permissions.go +++ b/permissions/permissions.go @@ -64,6 +64,15 @@ func (v *permissionGen) EndTerm(nested bool) error { return nil } + +func (v *permissionGen) StartFunction(name string) error { + return nil +} + +func (v *permissionGen) EndFunction() error { + return nil +} + func (v *permissionGen) VisitAnd() error { v.stmt.expression += " and " return nil diff --git a/runtime/actions/create.go b/runtime/actions/create.go index 0f8219a94..9f1dc7cd1 100644 --- a/runtime/actions/create.go +++ b/runtime/actions/create.go @@ -76,6 +76,19 @@ func Create(scope *Scope, input map[string]any) (res map[string]any, err error) return nil, err } + // Because of computed fields and nested creates, we need to fetch the row again to get the computed fields + query = NewQuery(scope.Model, opts...) + err = query.Where(IdField(), Equals, Value(res["id"])) + if err != nil { + return nil, err + } + query.Select(AllFields()) + statement = query.SelectStatement() + res, err = statement.ExecuteToSingle(scope.Context) + if err != nil { + return nil, err + } + // if we have any files in our results we need to transform them to the object structure required if scope.Model.HasFiles() { res, err = transformModelFileResponses(scope.Context, scope.Model, res) diff --git a/runtime/actions/generate_computed.go b/runtime/actions/generate_computed.go index 5e91a05ea..fd8a40bd2 100644 --- a/runtime/actions/generate_computed.go +++ b/runtime/actions/generate_computed.go @@ -3,10 +3,14 @@ package actions import ( "errors" "fmt" + "strings" + "github.com/emirpasic/gods/stacks/arraystack" "github.com/google/cel-go/common/operators" "github.com/iancoleman/strcase" + "github.com/teamkeel/keel/casing" "github.com/teamkeel/keel/expressions/resolve" + "github.com/teamkeel/keel/expressions/typing" "github.com/teamkeel/keel/proto" "github.com/teamkeel/keel/schema/parser" @@ -15,20 +19,22 @@ import ( // GenerateComputedFunction visits the expression and generates a SQL expression func GenerateComputedFunction(schema *proto.Schema, model *proto.Model, field *proto.Field) resolve.Visitor[string] { return &computedQueryGen{ - schema: schema, - model: model, - field: field, - sql: "", + schema: schema, + model: model, + field: field, + sql: "", + functions: arraystack.New(), } } var _ resolve.Visitor[string] = new(computedQueryGen) type computedQueryGen struct { - schema *proto.Schema - model *proto.Model - field *proto.Field - sql string + schema *proto.Schema + model *proto.Model + field *proto.Field + sql string + functions *arraystack.Stack } func (v *computedQueryGen) StartTerm(nested bool) error { @@ -45,6 +51,15 @@ func (v *computedQueryGen) EndTerm(nested bool) error { return nil } +func (v *computedQueryGen) StartFunction(name string) error { + v.functions.Push(name) + return nil +} + +func (v *computedQueryGen) EndFunction() error { + return nil +} + func (v *computedQueryGen) VisitAnd() error { v.sql += " AND " return nil @@ -105,37 +120,123 @@ func (v *computedQueryGen) VisitIdent(ident *parser.ExpressionIdent) error { model := v.schema.FindModel(strcase.ToCamel(ident.Fragments[0])) field := proto.FindField(v.schema.Models, model.Name, ident.Fragments[1]) - if len(ident.Fragments) == 2 { + normalised, err := NormalisedFragments(v.schema, ident.Fragments) + if err != nil { + return err + } + + if len(normalised) == 2 { v.sql += "r." + sqlQuote(strcase.ToSnake(field.Name)) - } else if len(ident.Fragments) > 2 { - // Join together all the tables based on the ident fragments - model = v.schema.FindModel(field.Type.ModelName.Value) - query := NewQuery(model) - err := query.AddJoinFromFragments(v.schema, ident.Fragments[1:]) + } else if len(normalised) > 2 { + isToMany, err := v.isToManyLookup(ident) if err != nil { return err } - // Select the column as specified in the last ident fragment - fieldName := ident.Fragments[len(ident.Fragments)-1] - fragments := ident.Fragments[1 : len(ident.Fragments)-1] - query.Select(ExpressionField(fragments, fieldName, false)) - - // Filter by this model's row's ID - relatedModelField := proto.FindField(v.schema.Models, v.model.Name, ident.Fragments[1]) - foreignKeyField := proto.GetForeignKeyFieldName(v.schema.Models, relatedModelField) - fk := fmt.Sprintf("r.\"%s\"", strcase.ToSnake(foreignKeyField)) - err = query.Where(IdField(), Equals, Raw(fk)) - if err != nil { - return err + if isToMany { + model = v.schema.FindModel(field.Type.ModelName.Value) + query := NewQuery(model) + + relatedModelField := proto.FindField(v.schema.Models, v.model.Name, normalised[1]) + foreignKeyField := proto.GetForeignKeyFieldName(v.schema.Models, relatedModelField) + + r := proto.FindField(v.schema.Models, v.model.Name, normalised[1]) + subFragments := normalised[1:] + subFragments[0] = strcase.ToLowerCamel(r.Type.ModelName.Value) + + err := query.AddJoinFromFragments(v.schema, subFragments) + if err != nil { + return err + } + + funcBegin, has := v.functions.Pop() + if !has { + return errors.New("no function found for 1:M lookup") + } + + fieldName := normalised[len(normalised)-1] + fragments := normalised[1 : len(normalised)-1] + + raw := "" + selectField := sqlQuote(casing.ToSnake(strings.Join(fragments, "$"))) + "." + sqlQuote(casing.ToSnake(fieldName)) + switch funcBegin { + case typing.FunctionSum: + raw += fmt.Sprintf("COALESCE(SUM(%s), 0)", selectField) + case typing.FunctionCount: + raw += fmt.Sprintf("COALESCE(COUNT(%s), 0)", selectField) + case typing.FunctionAvg: + raw += fmt.Sprintf("COALESCE(AVG(%s), 0)", selectField) + case typing.FunctionMedian: + raw += fmt.Sprintf("COALESCE(percentile_cont(0.5) WITHIN GROUP (ORDER BY %s), 0)", selectField) + case typing.FunctionMin: + raw += fmt.Sprintf("COALESCE(MIN(%s), 0)", selectField) + case typing.FunctionMax: + raw += fmt.Sprintf("COALESCE(MAX(%s), 0)", selectField) + } + + query.Select(Raw(raw)) + + // Filter by this model's row's ID + fk := fmt.Sprintf("r.\"%s\"", parser.FieldNameId) + err = query.Where(Field(foreignKeyField), Equals, Raw(fk)) + if err != nil { + return err + } + + stmt := query.SelectStatement() + v.sql += fmt.Sprintf("(%s)", stmt.SqlTemplate()) + } else { + // Join together all the tables based on the ident fragments + model = v.schema.FindModel(field.Type.ModelName.Value) + query := NewQuery(model) + err := query.AddJoinFromFragments(v.schema, normalised[1:]) + if err != nil { + return err + } + + // Select the column as specified in the last ident fragment + fieldName := normalised[len(normalised)-1] + fragments := normalised[1 : len(normalised)-1] + query.Select(ExpressionField(fragments, fieldName, false)) + + // Filter by this model's row's ID + relatedModelField := proto.FindField(v.schema.Models, v.model.Name, normalised[1]) + foreignKeyField := proto.GetForeignKeyFieldName(v.schema.Models, relatedModelField) + + fk := fmt.Sprintf("r.\"%s\"", strcase.ToSnake(foreignKeyField)) + err = query.Where(IdField(), Equals, Raw(fk)) + if err != nil { + return err + } + + stmt := query.SelectStatement() + v.sql += fmt.Sprintf("(%s)", stmt.SqlTemplate()) } - - stmt := query.SelectStatement() - v.sql += fmt.Sprintf("(%s)", stmt.SqlTemplate()) } + return nil } +func (v *computedQueryGen) isToManyLookup(idents *parser.ExpressionIdent) (bool, error) { + model := v.schema.FindModel(strcase.ToCamel(idents.Fragments[0])) + + fragments, err := NormalisedFragments(v.schema, idents.Fragments) + if err != nil { + return false, err + } + + for i := 1; i < len(fragments)-1; i++ { + currentFragment := fragments[i] + field := proto.FindField(v.schema.Models, model.Name, currentFragment) + if field.Type.Type == proto.Type_TYPE_MODEL && field.Type.Repeated { + return true, nil + } + model = v.schema.FindModel(field.Type.ModelName.Value) + } + + return false, nil +} + func (v *computedQueryGen) VisitIdentArray(idents []*parser.ExpressionIdent) error { return errors.New("ident arrays not supported in computed expressions") } diff --git a/runtime/actions/generate_computed_test.go b/runtime/actions/generate_computed_test.go index 338c3dca5..9db6c0b5d 100644 --- a/runtime/actions/generate_computed_test.go +++ b/runtime/actions/generate_computed_test.go @@ -170,6 +170,30 @@ var computedTestCases = []computedTestCase{ field: "total Decimal @computed(item.product.standardPrice * item.quantity + item.product.agent.commission)", expectedSql: `(SELECT "product"."standard_price" FROM "product" WHERE "product"."id" IS NOT DISTINCT FROM r."product_id") * r."quantity" + (SELECT "product$agent"."commission" FROM "product" LEFT JOIN "agent" AS "product$agent" ON "product$agent"."id" = "product"."agent_id" WHERE "product"."id" IS NOT DISTINCT FROM r."product_id")`, }, + { + name: "sum function", + keelSchema: ` + model Invoice { + fields { + item Item[] + #placeholder# + } + } + model Item { + fields { + invoice Invoice + product Product + } + } + model Product { + fields { + name Text + price Decimal + } + }`, + field: "total Decimal @computed(SUM(invoice.item.product.price))", + expectedSql: `(SELECT COALESCE(SUM("item$product"."price"), 0) FROM "item" LEFT JOIN "product" AS "item$product" ON "item$product"."id" = "item"."product_id" WHERE "item"."invoice_id" IS NOT DISTINCT FROM r."id")`, + }, } func TestGeneratedComputed(t *testing.T) { diff --git a/runtime/actions/generate_filter.go b/runtime/actions/generate_filter.go index 2881e86f7..91b08642a 100644 --- a/runtime/actions/generate_filter.go +++ b/runtime/actions/generate_filter.go @@ -92,6 +92,14 @@ func (v *whereQueryGen) EndTerm(nested bool) error { return nil } +func (v *whereQueryGen) StartFunction(name string) error { + return nil +} + +func (v *whereQueryGen) EndFunction() error { + return nil +} + func (v *whereQueryGen) VisitAnd() error { v.query.And() return nil diff --git a/runtime/actions/generate_select.go b/runtime/actions/generate_select.go index 809f1d6cc..c75c4901a 100644 --- a/runtime/actions/generate_select.go +++ b/runtime/actions/generate_select.go @@ -42,6 +42,14 @@ func (v *setQueryGen) EndTerm(parenthesis bool) error { return nil } +func (v *setQueryGen) StartFunction(name string) error { + return nil +} + +func (v *setQueryGen) EndFunction() error { + return nil +} + func (v *setQueryGen) VisitAnd() error { return errors.New("and operator not supported with set") } diff --git a/schema/attributes/computed.go b/schema/attributes/computed.go index 562def2d2..6db4e330c 100644 --- a/schema/attributes/computed.go +++ b/schema/attributes/computed.go @@ -15,6 +15,7 @@ func ValidateComputedExpression(schema []*parser.AST, model *parser.ModelNode, f options.WithComparisonOperators(), options.WithLogicalOperators(), options.WithArithmeticOperators(), + options.WithFunctions(), options.WithReturnTypeAssertion(field.Type.Value, field.Repeated), } diff --git a/schema/attributes/computed_test.go b/schema/attributes/computed_test.go new file mode 100644 index 000000000..8bbdfbf0c --- /dev/null +++ b/schema/attributes/computed_test.go @@ -0,0 +1,41 @@ +package attributes_test + +import ( + "testing" + + "github.com/stretchr/testify/require" + "github.com/teamkeel/keel/schema/attributes" + "github.com/teamkeel/keel/schema/query" + "github.com/teamkeel/keel/schema/reader" +) + +func TestComputed_SumFunction(t *testing.T) { + schema := parse(t, &reader.SchemaFile{FileName: "test.keel", Contents: ` + model Invoice { + fields { + items Item[] + total Decimal @computed(SUM(invoice.items.total)) + } + } + model Item { + fields { + invoice Invoice + product Product + quantity Number + total Decimal? @computed(item.quantity * item.product.price) + } + } + model Product { + fields { + price Decimal + items Item[] + } + }`}) + + model := query.Model(schema, "Invoice") + expression := model.Sections[0].Fields[1].Attributes[0].Arguments[0].Expression + + issues, err := attributes.ValidateComputedExpression(schema, model, model.Sections[0].Fields[1], expression) + require.NoError(t, err) + require.Len(t, issues, 0) +} diff --git a/schema/testdata/errors/attribute_computed_functions.keel b/schema/testdata/errors/attribute_computed_functions.keel new file mode 100644 index 000000000..5a1ccf77a --- /dev/null +++ b/schema/testdata/errors/attribute_computed_functions.keel @@ -0,0 +1,26 @@ +model Invoice { + fields { + price Decimal + items Item[] + //expect-error:37:38:AttributeExpressionError:Text[] not supported as an argument for the function 'SUM' + total1 Decimal @computed(SUM(invoice.items.product.price)) + //expect-error:37:38:AttributeExpressionError:Product[] not supported as an argument for the function 'SUM' + total2 Decimal @computed(SUM(invoice.items.product)) + //expect-error:37:38:AttributeExpressionError:Decimal not supported as an argument for the function 'SUM' + total3 Decimal @computed(SUM(invoice.price)) + } +} + +model Item { + fields { + invoice Invoice + product Product + } +} + +model Product { + fields { + name Text + price Text + } +} \ No newline at end of file