Skip to content

Commit

Permalink
Support --let (#49)
Browse files Browse the repository at this point in the history
  • Loading branch information
dveeden authored Nov 15, 2023
1 parent 6fe2160 commit fc3c1e6
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 0 deletions.
38 changes: 38 additions & 0 deletions src/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -372,6 +372,11 @@ func (t *tester) Run() error {
case Q_ERROR:
t.expectedErrs = strings.Split(strings.TrimSpace(s), ",")
case Q_ECHO:
varSearch := regexp.MustCompile("\\$([A-Za-z0-9_]+)( |$)")
s := varSearch.ReplaceAllStringFunc(s, func(s string) string {
return os.Getenv(varSearch.FindStringSubmatch(s)[1])
})

t.buf.WriteString(s)
t.buf.WriteString("\n")
case Q_QUERY:
Expand Down Expand Up @@ -431,6 +436,30 @@ func (t *tester) Run() error {
q.Query = q.Query[:len(q.Query)-1]
}
t.disconnect(q.Query)
case Q_LET:
q.Query = strings.TrimSpace(q.Query)
eqIdx := strings.Index(q.Query, "=")
if eqIdx > 1 {
start := 0
if q.Query[0] == '$' {
start = 1
}
varName := strings.TrimSpace(q.Query[start:eqIdx])
varValue := strings.TrimSpace(q.Query[eqIdx+1:])
varSearch := regexp.MustCompile("`(.*)`")
varValue = varSearch.ReplaceAllStringFunc(varValue, func(s string) string {
s = strings.Trim(s, "`")
r, err := t.executeStmtString(s)
if err != nil {
log.WithFields(log.Fields{
"query": s, "line": q.Line},
).Error("failed to perform let query")
return ""
}
return r
})
os.Setenv(varName, varValue)
}
case Q_REMOVE_FILE:
err = os.Remove(strings.TrimSpace(q.Query))
if err != nil {
Expand Down Expand Up @@ -818,6 +847,15 @@ func (t *tester) executeStmt(query string) error {
return nil
}

func (t *tester) executeStmtString(query string) (string, error) {
var result string
err := t.mdb.QueryRow(query).Scan(&result)
if err != nil {
return "", err
}
return result, nil
}

func (t *tester) openResult() error {
if record {
return nil
Expand Down
4 changes: 4 additions & 0 deletions src/query_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,4 +51,8 @@ func TestParseQueryies(t *testing.T) {
_, err := ParseQueries(query{Query: sql, Line: 1})
assertEqual(t, err, ErrInvalidCommand, fmt.Sprintf("Expected: %v, got %v", ErrInvalidCommand, err))

sql = "--let $foo=`SELECT 1`"
if q, err := ParseQueries(query{Query: sql, Line: 1}); err == nil {
assertEqual(t, q[0].tp, Q_LET, fmt.Sprintf("Expected: %d, got: %d", Q_LET, q[0].tp))
}
}

0 comments on commit fc3c1e6

Please sign in to comment.