Skip to content

Commit 5e00a32

Browse files
committed
[feat][prompt] prompt support go template (#269)
* [feat][prompt] prompt support go template * [feat][prompt] prompt support go template * [feat][prompt] prompt support go template ut * [feat][prompt] prompt support go template ut
1 parent e23c22e commit 5e00a32

File tree

10 files changed

+698
-2
lines changed

10 files changed

+698
-2
lines changed

backend/kitex_gen/coze/loop/prompt/domain/prompt/prompt.go

Lines changed: 2 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

backend/kitex_gen/coze/loop/prompt/openapi/coze.loop.prompt.openapi.go

Lines changed: 2 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

backend/modules/prompt/application/convertor/prompt.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,8 @@ func TemplateTypeDTO2DO(dto prompt.TemplateType) entity.TemplateType {
119119
return entity.TemplateTypeNormal
120120
case prompt.TemplateTypeJinja2:
121121
return entity.TemplateTypeJinja2
122+
case prompt.TemplateTypeGoTemplate:
123+
return entity.TemplateTypeGoTemplate
122124
default:
123125
return entity.TemplateTypeNormal
124126
}

backend/modules/prompt/application/convertor/prompt_test.go

Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -665,3 +665,122 @@ func TestModelConfigExtraConversion(t *testing.T) {
665665
assert.NotNil(t, dtoBack)
666666
assert.Equal(t, extra, dtoBack.Extra)
667667
}
668+
669+
func TestTemplateTypeDTO2DO(t *testing.T) {
670+
tests := []struct {
671+
name string
672+
dto prompt.TemplateType
673+
want entity.TemplateType
674+
}{
675+
{
676+
name: "normal template type",
677+
dto: prompt.TemplateTypeNormal,
678+
want: entity.TemplateTypeNormal,
679+
},
680+
{
681+
name: "jinja2 template type",
682+
dto: prompt.TemplateTypeJinja2,
683+
want: entity.TemplateTypeJinja2,
684+
},
685+
{
686+
name: "go template type",
687+
dto: prompt.TemplateTypeGoTemplate,
688+
want: entity.TemplateTypeGoTemplate,
689+
},
690+
{
691+
name: "unknown template type defaults to normal",
692+
dto: prompt.TemplateType("unknown"),
693+
want: entity.TemplateTypeNormal,
694+
},
695+
}
696+
697+
for _, tt := range tests {
698+
t.Run(tt.name, func(t *testing.T) {
699+
t.Parallel()
700+
got := TemplateTypeDTO2DO(tt.dto)
701+
assert.Equal(t, tt.want, got)
702+
})
703+
}
704+
}
705+
706+
func TestPromptTemplateWithDifferentTypes(t *testing.T) {
707+
t.Parallel()
708+
709+
tests := []struct {
710+
name string
711+
dto *prompt.PromptTemplate
712+
want *entity.PromptTemplate
713+
}{
714+
{
715+
name: "normal template",
716+
dto: &prompt.PromptTemplate{
717+
TemplateType: ptr.Of(prompt.TemplateTypeNormal),
718+
Messages: []*prompt.Message{
719+
{
720+
Role: ptr.Of(prompt.RoleUser),
721+
Content: ptr.Of("Hello {{name}}"),
722+
},
723+
},
724+
},
725+
want: &entity.PromptTemplate{
726+
TemplateType: entity.TemplateTypeNormal,
727+
Messages: []*entity.Message{
728+
{
729+
Role: entity.RoleUser,
730+
Content: ptr.Of("Hello {{name}}"),
731+
},
732+
},
733+
},
734+
},
735+
{
736+
name: "jinja2 template",
737+
dto: &prompt.PromptTemplate{
738+
TemplateType: ptr.Of(prompt.TemplateTypeJinja2),
739+
Messages: []*prompt.Message{
740+
{
741+
Role: ptr.Of(prompt.RoleUser),
742+
Content: ptr.Of("Hello {{ name }}"),
743+
},
744+
},
745+
},
746+
want: &entity.PromptTemplate{
747+
TemplateType: entity.TemplateTypeJinja2,
748+
Messages: []*entity.Message{
749+
{
750+
Role: entity.RoleUser,
751+
Content: ptr.Of("Hello {{ name }}"),
752+
},
753+
},
754+
},
755+
},
756+
{
757+
name: "go template",
758+
dto: &prompt.PromptTemplate{
759+
TemplateType: ptr.Of(prompt.TemplateTypeGoTemplate),
760+
Messages: []*prompt.Message{
761+
{
762+
Role: ptr.Of(prompt.RoleUser),
763+
Content: ptr.Of("Hello {{.name}}"),
764+
},
765+
},
766+
},
767+
want: &entity.PromptTemplate{
768+
TemplateType: entity.TemplateTypeGoTemplate,
769+
Messages: []*entity.Message{
770+
{
771+
Role: entity.RoleUser,
772+
Content: ptr.Of("Hello {{.name}}"),
773+
},
774+
},
775+
},
776+
},
777+
}
778+
779+
for _, tt := range tests {
780+
t.Run(tt.name, func(t *testing.T) {
781+
t.Parallel()
782+
got := PromptTemplateDTO2DO(tt.dto)
783+
assert.Equal(t, tt.want, got)
784+
})
785+
}
786+
}

backend/modules/prompt/domain/entity/prompt_detail.go

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,8 +43,9 @@ type PromptTemplate struct {
4343
type TemplateType string
4444

4545
const (
46-
TemplateTypeNormal TemplateType = "normal"
47-
TemplateTypeJinja2 TemplateType = "jinja2"
46+
TemplateTypeNormal TemplateType = "normal"
47+
TemplateTypeJinja2 TemplateType = "jinja2"
48+
TemplateTypeGoTemplate TemplateType = "go_template"
4849
)
4950

5051
type Message struct {
@@ -300,6 +301,8 @@ func formatText(templateType TemplateType, templateStr string, defMap map[string
300301
}), nil
301302
case TemplateTypeJinja2:
302303
return renderJinja2Template(templateStr, defMap, valMap)
304+
case TemplateTypeGoTemplate:
305+
return renderGoTemplate(templateStr, defMap, valMap)
303306
default:
304307
return "", errorx.NewByCode(prompterr.UnsupportedTemplateTypeCode, errorx.WithExtraMsg("unknown template type: "+string(templateType)))
305308
}
@@ -316,6 +319,17 @@ func renderJinja2Template(templateStr string, defMap map[string]*VariableDef, va
316319
return template.InterpolateJinja2(templateStr, variables)
317320
}
318321

322+
// renderGoTemplate 渲染 Go Template 模板
323+
func renderGoTemplate(templateStr string, defMap map[string]*VariableDef, valMap map[string]*VariableVal) (string, error) {
324+
// 转换变量为 map[string]any 格式
325+
variables, err := convertVariablesToMap(defMap, valMap)
326+
if err != nil {
327+
return "", err
328+
}
329+
330+
return template.InterpolateGoTemplate(templateStr, variables)
331+
}
332+
319333
// convertVariablesToMap 将变量定义和变量值转换为模板引擎可用的 map
320334
func convertVariablesToMap(defMap map[string]*VariableDef, valMap map[string]*VariableVal) (map[string]any, error) {
321335
if len(defMap) == 0 || len(valMap) == 0 {

0 commit comments

Comments
 (0)