Skip to content

Commit 537f5e6

Browse files
author
Miguel Molina
authored
Merge pull request #177 from erizocosmico/fix/implicit-fks
add implicit inverse foreign keys when necessary
2 parents 627e5c8 + 25d907f commit 537f5e6

File tree

5 files changed

+91
-6
lines changed

5 files changed

+91
-6
lines changed

generator/processor.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,9 @@ func (p *Processor) processPackage() (*Package, error) {
158158
}
159159

160160
pkg.SetModels(models)
161+
if err := pkg.addMissingRelationships(); err != nil {
162+
return nil, err
163+
}
161164
for _, ctor := range ctors {
162165
p.tryMatchConstructor(pkg, ctor)
163166
}

generator/template.go

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,10 @@ func (td *TemplateData) genFieldsTimeTruncations(buf *bytes.Buffer, fields []*Fi
8080
func (td *TemplateData) GenColumnAddresses(model *Model) string {
8181
var buf bytes.Buffer
8282
td.genFieldsColumnAddresses(&buf, model.Fields)
83+
for _, fk := range model.ImplicitFKs {
84+
buf.WriteString(fmt.Sprintf("case \"%s\":\n", fk.Name))
85+
buf.WriteString(fmt.Sprintf("return types.Nullable(kallax.VirtualColumn(\"%s\", r, new(%s))), nil\n", fk.Name, fk.Type))
86+
}
8387
return buf.String()
8488
}
8589

@@ -129,6 +133,10 @@ func (td *TemplateData) IdentifierType(f *Field) string {
129133
func (td *TemplateData) GenColumnValues(model *Model) string {
130134
var buf bytes.Buffer
131135
td.genFieldsValues(&buf, model.Fields)
136+
for _, fk := range model.ImplicitFKs {
137+
buf.WriteString(fmt.Sprintf("case \"%s\":\n", fk.Name))
138+
buf.WriteString(fmt.Sprintf("return r.Model.VirtualColumn(col), nil\n"))
139+
}
132140
return buf.String()
133141
}
134142

@@ -159,6 +167,9 @@ func (td *TemplateData) genFieldsValues(buf *bytes.Buffer, fields []*Field) {
159167
func (td *TemplateData) GenModelColumns(model *Model) string {
160168
var buf bytes.Buffer
161169
td.genFieldsColumns(&buf, model.Fields)
170+
for _, fk := range model.ImplicitFKs {
171+
buf.WriteString(fmt.Sprintf("kallax.NewSchemaField(\"%s\"),\n", fk.Name))
172+
}
162173
return buf.String()
163174
}
164175

generator/types.go

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,59 @@ func (p *Package) FindModel(name string) *Model {
153153
return p.indexedModels[name]
154154
}
155155

156+
func (p *Package) addMissingRelationships() error {
157+
for _, m := range p.Models {
158+
for _, f := range m.Fields {
159+
if f.Kind == Relationship && !f.IsInverse() {
160+
if err := p.trySetFK(f.TypeSchemaName(), f); err != nil {
161+
return err
162+
}
163+
}
164+
}
165+
}
166+
167+
return nil
168+
}
169+
170+
func (p *Package) trySetFK(model string, fk *Field) error {
171+
m := p.FindModel(model)
172+
if m == nil {
173+
return fmt.Errorf("kallax: cannot assign implicit foreign key to non-existent model %s", model)
174+
}
175+
176+
var found bool
177+
for _, f := range m.Fields {
178+
if f.Kind == Relationship {
179+
if f.ForeignKey() == fk.ForeignKey() {
180+
found = true
181+
break
182+
}
183+
} else {
184+
if f.ColumnName() == fk.ForeignKey() {
185+
found = true
186+
break
187+
}
188+
}
189+
}
190+
191+
if !found {
192+
for _, ifk := range m.ImplicitFKs {
193+
if ifk.Name == fk.ForeignKey() {
194+
found = true
195+
break
196+
}
197+
}
198+
}
199+
200+
if !found {
201+
m.ImplicitFKs = append(m.ImplicitFKs, ImplicitFK{
202+
Name: fk.ForeignKey(),
203+
Type: identifierType(fk.Model.ID),
204+
})
205+
}
206+
return nil
207+
}
208+
156209
const (
157210
// StoreNamePattern is the pattern used to name stores.
158211
StoreNamePattern = "%sStore"
@@ -182,6 +235,10 @@ type Model struct {
182235
Type string
183236
// Fields contains the list of fields in the model.
184237
Fields []*Field
238+
// ImplicitFKs contains the list of fks that are implicit based on
239+
// other models' definitions, such as foreign keys with no explicit inverse
240+
// on the related model.
241+
ImplicitFKs []ImplicitFK
185242
// ID contains the identifier field of the model.
186243
ID *Field
187244
// Events contains the list of events implemented by the model.
@@ -499,6 +556,11 @@ func relationshipsOnFields(fields []*Field) []*Field {
499556
return result
500557
}
501558

559+
type ImplicitFK struct {
560+
Name string
561+
Type string
562+
}
563+
502564
// Field is the representation of a model field.
503565
type Field struct {
504566
// Name is the field name.

tests/kallax.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -611,6 +611,8 @@ func (r *Child) ColumnAddress(col string) (interface{}, error) {
611611
return (*kallax.NumericID)(&r.ID), nil
612612
case "name":
613613
return &r.Name, nil
614+
case "parent_id":
615+
return types.Nullable(kallax.VirtualColumn("parent_id", r, new(kallax.NumericID))), nil
614616

615617
default:
616618
return nil, fmt.Errorf("kallax: invalid column in Child: %s", col)
@@ -624,6 +626,8 @@ func (r *Child) Value(col string) (interface{}, error) {
624626
return r.ID, nil
625627
case "name":
626628
return r.Name, nil
629+
case "parent_id":
630+
return r.Model.VirtualColumn(col), nil
627631

628632
default:
629633
return nil, fmt.Errorf("kallax: invalid column in Child: %s", col)
@@ -10773,6 +10777,7 @@ var Schema = &schema{
1077310777
true,
1077410778
kallax.NewSchemaField("id"),
1077510779
kallax.NewSchemaField("name"),
10780+
kallax.NewSchemaField("parent_id"),
1077610781
),
1077710782
ID: kallax.NewSchemaField("id"),
1077810783
Name: kallax.NewSchemaField("name"),

tests/store_test.go

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -278,10 +278,12 @@ func (s *StoreSuite) TestInsert_RelWithNoInverse() {
278278
s.NoError(store.Insert(p))
279279
s.NotEqual(0, p.ID)
280280

281-
var count int
282-
err := s.db.QueryRow("SELECT COUNT(*) FROM children WHERE parent_id = $1", p.ID).Scan(&count)
281+
p, err := store.FindOne(NewParentQuery().WithChildren(nil))
283282
s.NoError(err)
284-
s.Equal(3, count)
283+
s.Len(p.Children, 3)
284+
for _, c := range p.Children {
285+
s.NotEqual(int64(0), c.ID)
286+
}
285287
}
286288

287289
func (s *StoreSuite) TestInsert_RelWithNoInverseNoPtr() {
@@ -298,8 +300,10 @@ func (s *StoreSuite) TestInsert_RelWithNoInverseNoPtr() {
298300
s.NoError(store.Insert(p))
299301
s.NotEqual(0, p.ID)
300302

301-
var count int
302-
err := s.db.QueryRow("SELECT COUNT(*) FROM children WHERE parent_id = $1", p.ID).Scan(&count)
303+
p, err := store.FindOne(NewParentNoPtrQuery().WithChildren(nil))
303304
s.NoError(err)
304-
s.Equal(3, count)
305+
s.Len(p.Children, 3)
306+
for _, c := range p.Children {
307+
s.NotEqual(int64(0), c.ID)
308+
}
305309
}

0 commit comments

Comments
 (0)