Skip to content

Commit 23928dd

Browse files
committed
修改template生成代码,适配gobatis
1 parent dedba6e commit 23928dd

File tree

3 files changed

+164
-150
lines changed

3 files changed

+164
-150
lines changed

common/formatter.go

Lines changed: 24 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -5,25 +5,35 @@
55

66
package common
77

8-
type KeyWordFormatter func(string) string
8+
const (
9+
MysqlKeywordEscapeChar = "`"
10+
PostgresKeywordEscapeChar = `"`
911

10-
var KwFormatter = MysqlKeyWordFormatter
11-
var kwFormatterMap = map[string]KeyWordFormatter{
12-
"mysql": MysqlKeyWordFormatter,
13-
"postgres": PostgresKeyWordFormatter,
14-
}
12+
MysqlEscapeKeywordEscapeChar = "`"
13+
PostgresEscapeKeywordEscapeChar = `\"`
14+
)
15+
16+
var KeywordEscapeChar = MysqlKeywordEscapeChar
17+
var EscapeKeywordEscapeChar = MysqlEscapeKeywordEscapeChar
18+
19+
type KeywordFormatter func(string) string
20+
21+
var KwFormatter = CommonKeywordFormatter
1522

16-
func MysqlKeyWordFormatter(src string) string {
17-
return "`" + src + "`"
23+
func CommonKeywordFormatter(src string) string {
24+
return KeywordEscapeChar + src + KeywordEscapeChar
1825
}
1926

20-
func PostgresKeyWordFormatter(src string) string {
21-
return `"` + src + `"`
27+
func CommonEscapeKeywordFormatter(src string) string {
28+
return EscapeKeywordEscapeChar + src + EscapeKeywordEscapeChar
2229
}
2330

24-
func SelectKeyWordFormatter(driver string) {
25-
v := kwFormatterMap[driver]
26-
if v != nil {
27-
KwFormatter = v
31+
func SelectKeywordFormatter(driver string) {
32+
if driver == "postgres" {
33+
KeywordEscapeChar = PostgresKeywordEscapeChar
34+
EscapeKeywordEscapeChar = PostgresEscapeKeywordEscapeChar
35+
} else {
36+
KeywordEscapeChar = MysqlKeywordEscapeChar
37+
EscapeKeywordEscapeChar = MysqlEscapeKeywordEscapeChar
2838
}
2939
}

gen_template.go

Lines changed: 139 additions & 135 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,11 @@ import (
1313
"github.com/xfali/gobatis-cmd/common"
1414
"github.com/xfali/gobatis-cmd/io"
1515
"strings"
16-
"time"
16+
"time"
1717
)
1818

1919
func genTemplate(config Config, tableName string, model []common.ModelInfo) {
20-
common.SelectKeyWordFormatter(config.Driver)
20+
common.SelectKeywordFormatter(config.Driver)
2121
targetDir := config.Path + "template/"
2222
if !io.IsPathExists(targetDir) {
2323
io.Mkdir(targetDir)
@@ -34,20 +34,34 @@ func genTemplate(config Config, tableName string, model []common.ModelInfo) {
3434

3535
func buildTmplMapper(builder *strings.Builder, config Config, tableName string, model []common.ModelInfo) {
3636
modelName := common.TableName2ModelName(tableName)
37-
columns := formatXmlColumns(tableName, model)
38-
tableName = common.KwFormatter(tableName)
37+
columns := formatXmlColumns(tableName, model)
38+
tableName = common.KwFormatter(tableName)
3939

40-
builder.WriteString("{{/*This file was generated by xfali/gobatis-cmd at*/}}")
41-
builder.WriteString(common.Newline())
42-
builder.WriteString(fmt.Sprintf("{{/*%s*/}}", time.Now().String()))
43-
builder.WriteString(common.Newline())
44-
builder.WriteString(common.Newline())
40+
builder.WriteString("{{/*This file was generated by xfali/gobatis-cmd at*/}}")
41+
builder.WriteString(common.Newline())
42+
builder.WriteString(fmt.Sprintf("{{/*%s*/}}", time.Now().String()))
43+
builder.WriteString(common.Newline())
44+
builder.WriteString(common.Newline())
4545

46-
//select
46+
//select
4747
builder.WriteString(fmt.Sprintf(`{{define "select%s"}}`, modelName))
4848
builder.WriteString(common.Newline())
4949

50-
builder.WriteString(fmt.Sprintf(`SELECT %s FROM %s`, columns, tableName))
50+
builder.WriteString(fmt.Sprintf(`SELECT %s FROM %s`, columns, tableName))
51+
builder.WriteString(common.Newline())
52+
53+
builder.WriteString(genTmplWhere(modelName, model))
54+
builder.WriteString(common.Newline())
55+
56+
builder.WriteString(`{{end}}`)
57+
builder.WriteString(common.Newline())
58+
builder.WriteString(common.Newline())
59+
60+
//select count
61+
builder.WriteString(fmt.Sprintf(`{{define "select%sCount"}}`, modelName))
62+
builder.WriteString(common.Newline())
63+
64+
builder.WriteString(fmt.Sprintf(`SELECT COUNT(*) FROM %s`, tableName))
5165
builder.WriteString(common.Newline())
5266

5367
builder.WriteString(genTmplWhere(modelName, model))
@@ -57,164 +71,154 @@ func buildTmplMapper(builder *strings.Builder, config Config, tableName string,
5771
builder.WriteString(common.Newline())
5872
builder.WriteString(common.Newline())
5973

60-
//insert
61-
builder.WriteString(fmt.Sprintf(`{{define "insert%s"}}`, modelName))
62-
builder.WriteString(common.Newline())
74+
//insert
75+
builder.WriteString(fmt.Sprintf(`{{define "insert%s"}}`, modelName))
76+
builder.WriteString(common.Newline())
6377

64-
builder.WriteString(fmt.Sprintf(`INSERT INTO %s(%s)`, tableName, columns))
65-
builder.WriteString(common.Newline())
78+
builder.WriteString(fmt.Sprintf(`INSERT INTO %s(%s)`, tableName, columns))
79+
builder.WriteString(common.Newline())
6680

67-
builder.WriteString("VALUES(")
68-
builder.WriteString(common.Newline())
81+
builder.WriteString("VALUES(")
82+
builder.WriteString(common.Newline())
6983

70-
builder.WriteString(genTmplValues(modelName, model))
84+
builder.WriteString(genTmplValues(modelName, model))
7185

72-
builder.WriteString(")")
73-
builder.WriteString(common.Newline())
86+
builder.WriteString(")")
87+
builder.WriteString(common.Newline())
7488

75-
builder.WriteString(`{{end}}`)
76-
builder.WriteString(common.Newline())
77-
builder.WriteString(common.Newline())
89+
builder.WriteString(`{{end}}`)
90+
builder.WriteString(common.Newline())
91+
builder.WriteString(common.Newline())
7892

79-
//insertBatch
80-
builder.WriteString(fmt.Sprintf(`{{define "insertBatch%s"}}`, modelName))
81-
builder.WriteString(common.Newline())
93+
//insertBatch
94+
builder.WriteString(fmt.Sprintf(`{{define "insertBatch%s"}}`, modelName))
95+
builder.WriteString(common.Newline())
8296

83-
builder.WriteString(`{{$size := len . | add -1}}`)
84-
builder.WriteString(common.Newline())
97+
builder.WriteString(`{{$size := len . | add -1}}`)
98+
builder.WriteString(common.Newline())
8599

86-
builder.WriteString(fmt.Sprintf(`INSERT INTO %s(%s)`, tableName, columns))
87-
builder.WriteString(common.Newline())
100+
builder.WriteString(fmt.Sprintf(`INSERT INTO %s(%s)`, tableName, columns))
101+
builder.WriteString(common.Newline())
88102

89-
builder.WriteString("VALUES {{range $i, $v := .}}")
90-
builder.WriteString(common.Newline())
103+
builder.WriteString("VALUES {{range $i, $v := .}}")
104+
builder.WriteString(common.Newline())
91105

92-
builder.WriteString(genTmplRangeValues(modelName, model))
106+
builder.WriteString(genTmplRangeValues(modelName, model))
93107

94-
builder.WriteString(`{{end}}`)
95-
builder.WriteString(common.Newline())
108+
builder.WriteString(`{{end}}`)
109+
builder.WriteString(common.Newline())
96110

97-
builder.WriteString(`{{end}}`)
98-
builder.WriteString(common.Newline())
99-
builder.WriteString(common.Newline())
111+
builder.WriteString(`{{end}}`)
112+
builder.WriteString(common.Newline())
113+
builder.WriteString(common.Newline())
100114

101-
//update
102-
builder.WriteString(fmt.Sprintf(`{{define "update%s"}}`, modelName))
103-
builder.WriteString(common.Newline())
115+
//update
116+
builder.WriteString(fmt.Sprintf(`{{define "update%s"}}`, modelName))
117+
builder.WriteString(common.Newline())
104118

105-
builder.WriteString(fmt.Sprintf(`UPDATE %s`, tableName))
106-
builder.WriteString(common.Newline())
119+
builder.WriteString(fmt.Sprintf(`UPDATE %s`, tableName))
120+
builder.WriteString(common.Newline())
107121

108-
setStr, index := genTmplSet(modelName, model)
109-
builder.WriteString(setStr)
110-
builder.WriteString(common.Newline())
122+
setStr, index := genTmplSet(modelName, model)
123+
builder.WriteString(setStr)
124+
builder.WriteString(common.Newline())
111125

112-
if index != -1 {
113-
builder.WriteString(genTmplWhere(modelName, model[index:index+1]))
114-
builder.WriteString(common.Newline())
115-
}
126+
if index != -1 {
127+
builder.WriteString(genTmplWhere(modelName, model[index:index+1]))
128+
builder.WriteString(common.Newline())
129+
}
116130

117-
builder.WriteString(`{{end}}`)
118-
builder.WriteString(common.Newline())
119-
builder.WriteString(common.Newline())
131+
builder.WriteString(`{{end}}`)
132+
builder.WriteString(common.Newline())
133+
builder.WriteString(common.Newline())
120134

121-
//delete
122-
builder.WriteString(fmt.Sprintf(`{{define "delete%s"}}`, modelName))
123-
builder.WriteString(common.Newline())
135+
//delete
136+
builder.WriteString(fmt.Sprintf(`{{define "delete%s"}}`, modelName))
137+
builder.WriteString(common.Newline())
124138

125-
builder.WriteString(fmt.Sprintf(`DELETE FROM %s`, tableName))
126-
builder.WriteString(common.Newline())
139+
builder.WriteString(fmt.Sprintf(`DELETE FROM %s`, tableName))
140+
builder.WriteString(common.Newline())
127141

128-
builder.WriteString(genTmplWhere(modelName, model))
129-
builder.WriteString(common.Newline())
142+
builder.WriteString(genTmplWhere(modelName, model))
143+
builder.WriteString(common.Newline())
130144

131-
builder.WriteString(`{{end}}`)
132-
builder.WriteString(common.Newline())
133-
builder.WriteString(common.Newline())
145+
builder.WriteString(`{{end}}`)
146+
builder.WriteString(common.Newline())
147+
builder.WriteString(common.Newline())
134148
}
135149

136150
func genTmplWhere(modelName string, model []common.ModelInfo) string {
137-
builder := strings.Builder{}
138-
139-
builder.WriteString("{{")
140-
for i := range model {
141-
field := common.Column2Modelfield(model[i].ColumnName)
142-
if i == 0 {
143-
builder.WriteString(fmt.Sprintf(`where (ne .%s %s) "AND" "%s" .%s ""`, field, getTmplCond(model[i].DataType), model[i].ColumnName, field))
144-
} else {
145-
builder.WriteString(fmt.Sprintf(` | where (ne .%s %s) "AND" "%s" .%s`, field, getTmplCond(model[i].DataType), model[i].ColumnName, field))
146-
}
147-
}
148-
builder.WriteString("}}")
149-
150-
return builder.String()
151+
builder := strings.Builder{}
152+
153+
builder.WriteString("{{")
154+
for i := range model {
155+
field := common.Column2Modelfield(model[i].ColumnName)
156+
if i == 0 {
157+
builder.WriteString(fmt.Sprintf(`where .%s "AND" "%s = " (arg .%s) ""`, field, common.CommonEscapeKeywordFormatter(model[i].ColumnName), field))
158+
} else {
159+
builder.WriteString(fmt.Sprintf(` | where .%s "AND" "%s = " (arg .%s)`, field, common.CommonEscapeKeywordFormatter(model[i].ColumnName), field))
160+
}
161+
}
162+
builder.WriteString("}}")
163+
164+
return builder.String()
151165
}
152166

153167
func genTmplSet(modelName string, model []common.ModelInfo) (string, int) {
154-
builder := strings.Builder{}
155-
156-
index := -1
157-
builder.WriteString("{{")
158-
for i := range model {
159-
field := common.Column2Modelfield(model[i].ColumnName)
160-
if i == 0 {
161-
builder.WriteString(fmt.Sprintf(`set (ne .%s %s) "%s" .%s ""`, field, getTmplCond(model[i].DataType), model[i].ColumnName, field))
162-
} else {
163-
builder.WriteString(fmt.Sprintf(` | set (ne .%s %s) "%s" .%s`, field, getTmplCond(model[i].DataType), model[i].ColumnName, field))
164-
}
165-
if strings.ToUpper(model[i].ColumnKey) == "PRI" {
166-
index = i
167-
continue
168-
}
169-
}
170-
builder.WriteString("}}")
171-
172-
return builder.String(), index
168+
builder := strings.Builder{}
169+
170+
index := -1
171+
builder.WriteString("{{")
172+
for i := range model {
173+
field := common.Column2Modelfield(model[i].ColumnName)
174+
if i == 0 {
175+
builder.WriteString(fmt.Sprintf(`set .%s "%s = " (arg .%s) ""`, field, common.CommonEscapeKeywordFormatter(model[i].ColumnName), field))
176+
} else {
177+
builder.WriteString(fmt.Sprintf(` | set .%s "%s = " (arg .%s)`, field, common.CommonEscapeKeywordFormatter(model[i].ColumnName), field))
178+
}
179+
if strings.ToUpper(model[i].ColumnKey) == "PRI" {
180+
index = i
181+
continue
182+
}
183+
}
184+
builder.WriteString("}}")
185+
186+
return builder.String(), index
173187
}
174188

175189
func genTmplValues(modelName string, model []common.ModelInfo) string {
176-
builder := strings.Builder{}
177-
178-
size := len(model)
179-
for i := range model {
180-
if sqlType2GoMap[model[i].DataType] == "string" {
181-
builder.WriteString(fmt.Sprintf("'{{.%s}}'", common.Column2Modelfield(model[i].ColumnName)))
182-
} else {
183-
builder.WriteString(fmt.Sprintf("{{.%s}}", common.Column2Modelfield(model[i].ColumnName)))
184-
}
185-
186-
size--
187-
if size > 0 {
188-
builder.WriteString(", ")
189-
}
190-
}
191-
192-
return builder.String()
190+
builder := strings.Builder{}
191+
192+
size := len(model)
193+
for i := range model {
194+
builder.WriteString(fmt.Sprintf("{{arg .%s}}", common.Column2Modelfield(model[i].ColumnName)))
195+
size--
196+
if size > 0 {
197+
builder.WriteString(", ")
198+
}
199+
}
200+
201+
return builder.String()
193202
}
194203

195204
func genTmplRangeValues(modelName string, model []common.ModelInfo) string {
196-
builder := strings.Builder{}
197-
198-
builder.WriteString(`(`)
199-
size := len(model)
200-
for i := range model {
201-
if sqlType2GoMap[model[i].DataType] == "string" {
202-
builder.WriteString(fmt.Sprintf("'{{$v.%s}}'", common.Column2Modelfield(model[i].ColumnName)))
203-
} else {
204-
builder.WriteString(fmt.Sprintf("{{$v.%s}}", common.Column2Modelfield(model[i].ColumnName)))
205-
}
206-
207-
size--
208-
if size > 0 {
209-
builder.WriteString(", ")
210-
}
211-
}
212-
213-
builder.WriteString(`){{if lt $i $size}},{{end}}`)
214-
builder.WriteString(common.Newline())
215-
return builder.String()
205+
builder := strings.Builder{}
206+
207+
builder.WriteString(`(`)
208+
size := len(model)
209+
for i := range model {
210+
builder.WriteString(fmt.Sprintf("{{arg $v.%s}}", common.Column2Modelfield(model[i].ColumnName)))
211+
size--
212+
if size > 0 {
213+
builder.WriteString(", ")
214+
}
215+
}
216+
217+
builder.WriteString(`){{if lt $i $size}},{{end}}`)
218+
builder.WriteString(common.Newline())
219+
return builder.String()
216220
}
217221

218222
func getTmplCond(ctype string) string {
219-
return sqlType2IfCondMap[ctype]
223+
return sqlType2IfCondMap[ctype]
220224
}

gen_xml.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ import (
1717
)
1818

1919
func genXml(config Config, tableName string, model []common.ModelInfo) {
20-
common.SelectKeyWordFormatter(config.Driver)
20+
common.SelectKeywordFormatter(config.Driver)
2121
if config.MapperFile == "xml" {
2222
xmlDir := config.Path + "xml/"
2323
if !io.IsPathExists(xmlDir) {

0 commit comments

Comments
 (0)