Skip to content

Commit 9780801

Browse files
Add support for customized database name with --using (#290)
This commit adds support for specifying database name as part of --using argument. The database name now can be specified after the URL as ,DbName . The database from bak file will be restored as the database name specified.
1 parent db1147f commit 9780801

File tree

3 files changed

+135
-16
lines changed

3 files changed

+135
-16
lines changed

cmd/modern/root/install/mssql-base.go

Lines changed: 69 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,12 @@ package install
55

66
import (
77
"fmt"
8+
"net/url"
9+
"path"
10+
"path/filepath"
11+
"runtime"
12+
"strings"
13+
814
"github.com/microsoft/go-sqlcmd/cmd/modern/sqlconfig"
915
"github.com/microsoft/go-sqlcmd/internal/cmdparser"
1016
"github.com/microsoft/go-sqlcmd/internal/config"
@@ -15,10 +21,6 @@ import (
1521
"github.com/microsoft/go-sqlcmd/internal/secret"
1622
"github.com/microsoft/go-sqlcmd/internal/sql"
1723
"github.com/spf13/viper"
18-
"net/url"
19-
"path/filepath"
20-
"runtime"
21-
"strings"
2224
)
2325

2426
// MssqlBase provide base support for installing SQL Server and all of its
@@ -395,7 +397,8 @@ func (c *MssqlBase) createContainer(imageName string, contextName string) {
395397

396398
func (c *MssqlBase) validateUsingUrlExists() {
397399
output := c.Cmd.Output()
398-
u, err := url.Parse(c.usingDatabaseUrl)
400+
databaseUrl := extractUrl(c.usingDatabaseUrl)
401+
u, err := url.Parse(databaseUrl)
399402
c.CheckErr(err)
400403

401404
if u.Scheme != "http" && u.Scheme != "https" {
@@ -406,9 +409,17 @@ func (c *MssqlBase) validateUsingUrlExists() {
406409
"%q is not a valid URL for --using flag", c.usingDatabaseUrl)
407410
}
408411

412+
if u.Path == "" {
413+
output.FatalfWithHints(
414+
[]string{
415+
"--using URL must have a path to .bak file",
416+
},
417+
"%q is not a valid URL for --using flag", c.usingDatabaseUrl)
418+
}
419+
409420
// At the moment we only support attaching .bak files, but we should
410421
// support .bacpacs and .mdfs in the future
411-
if _, file := filepath.Split(c.usingDatabaseUrl); filepath.Ext(file) != ".bak" {
422+
if _, file := filepath.Split(u.Path); filepath.Ext(file) != ".bak" {
412423
output.FatalfWithHints(
413424
[]string{
414425
"--using file URL must be a .bak file",
@@ -417,7 +428,7 @@ func (c *MssqlBase) validateUsingUrlExists() {
417428
}
418429

419430
// Verify the url actually exists, and early exit if it doesn't
420-
urlExists(c.usingDatabaseUrl, output)
431+
urlExists(databaseUrl, output)
421432
}
422433

423434
func (c *MssqlBase) query(commandText string) {
@@ -468,31 +479,73 @@ CHECK_POLICY=OFF`
468479
}
469480
}
470481

482+
func getDbNameAsIdentifier(dbName string) string {
483+
escapedDbNAme := strings.ReplaceAll(dbName, "'", "''")
484+
return strings.ReplaceAll(escapedDbNAme, "]", "]]")
485+
}
486+
487+
func getDbNameAsNonIdentifier(dbName string) string {
488+
return strings.ReplaceAll(dbName, "]", "]]")
489+
}
490+
491+
//parseDbName returns the databaseName from --using arg
492+
// It sets database name to the specified database name
493+
// or in absence of it, it is set to the filename without
494+
// extension.
495+
func parseDbName(usingDbUrl string) string {
496+
u, _ := url.Parse(usingDbUrl)
497+
dbToken := path.Base(u.Path)
498+
if dbToken != "." && dbToken != "/" {
499+
lastIdx := strings.LastIndex(dbToken, ".bak")
500+
if lastIdx != -1 {
501+
//Get file name without extension
502+
fileName := dbToken[0:lastIdx]
503+
lastIdx += 5
504+
if lastIdx >= len(dbToken) {
505+
return fileName
506+
}
507+
//Return database name if it was specified
508+
return dbToken[lastIdx:]
509+
}
510+
}
511+
return ""
512+
}
513+
514+
func extractUrl(usingArg string) string {
515+
urlEndIdx := strings.LastIndex(usingArg, ".bak")
516+
if urlEndIdx != -1 {
517+
return usingArg[0:(urlEndIdx + 4)]
518+
}
519+
return usingArg
520+
}
521+
471522
func (c *MssqlBase) downloadAndRestoreDb(
472523
controller *container.Controller,
473524
containerId string,
474525
userName string,
475526
) {
476527
output := c.Cmd.Output()
528+
databaseName := parseDbName(c.usingDatabaseUrl)
529+
databaseUrl := extractUrl(c.usingDatabaseUrl)
477530

478-
u, err := url.Parse(c.usingDatabaseUrl)
479-
c.CheckErr(err)
480-
_, file := filepath.Split(c.usingDatabaseUrl)
481-
fileNameWithNoExt := strings.TrimSuffix(file, filepath.Ext(file))
531+
_, file := filepath.Split(databaseUrl)
482532

483533
// Download file from URL into container
484-
output.Infof("Downloading %s from %s", file, u.Hostname())
534+
output.Infof("Downloading %s", file)
485535

486536
temporaryFolder := "/var/opt/mssql/backup"
487537

488538
controller.DownloadFile(
489539
containerId,
490-
c.usingDatabaseUrl,
540+
databaseUrl,
491541
temporaryFolder,
492542
)
493543

494544
// Restore database from file
495-
output.Infof("Restoring database %s", fileNameWithNoExt)
545+
output.Infof("Restoring database %s", databaseName)
546+
547+
dbNameAsIdentifier := getDbNameAsIdentifier(databaseName)
548+
dbNameAsNonIdentifier := getDbNameAsNonIdentifier(databaseName)
496549

497550
text := `SET NOCOUNT ON;
498551
@@ -535,12 +588,12 @@ WHERE IsPresent = 1
535588
SET @sql = SUBSTRING(@sql, 1, LEN(@sql)-1)
536589
EXEC(@sql)`
537590

538-
c.query(fmt.Sprintf(text, temporaryFolder, file, fileNameWithNoExt, temporaryFolder, file))
591+
c.query(fmt.Sprintf(text, temporaryFolder, file, dbNameAsIdentifier, temporaryFolder, file))
539592

540593
alterDefaultDb := fmt.Sprintf(
541594
"ALTER LOGIN [%s] WITH DEFAULT_DATABASE = [%s]",
542595
userName,
543-
fileNameWithNoExt)
596+
dbNameAsNonIdentifier)
544597
c.query(alterDefaultDb)
545598
}
546599

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
package install
2+
3+
import (
4+
"testing"
5+
6+
"github.com/stretchr/testify/assert"
7+
)
8+
9+
func TestGetDbNameIfExists(t *testing.T) {
10+
11+
type test struct {
12+
input string
13+
expectedIdentifierOp string
14+
expectedNonIdentifierOp string
15+
}
16+
17+
tests := []test{
18+
// Positive Testcases
19+
// Database name specified
20+
{"https://example.com/my%20random%20bac%27kp%5Dack.bak,myDbName", "myDbName", "myDbName"},
21+
{"https://example.com/my%20random%20bac%27kp%5Dack.bak,myDb Name", "myDb Name", "myDb Name"},
22+
{"https://example.com/my%20random%20bac%27kp%5Dack.bak,myDb Na,me", "myDb Na,me", "myDb Na,me"},
23+
{"https://example.com/my%20random%20bac%27kp%5Dack.bak,[myDb Na,me]", "[myDb Na,me]]", "[myDb Na,me]]"},
24+
{"https://example.com/my%20random%20bac%27kp%5Dack.bak,[myDb Na'me]", "[myDb Na''me]]", "[myDb Na'me]]"},
25+
{"https://example.com/my%20random%20bac%27kp%5Dack.bak,[myDb ,Nam,e]", "[myDb ,Nam,e]]", "[myDb ,Nam,e]]"},
26+
27+
// Delimiter between filename and databaseName is part of the filename
28+
// Decoded filename: my random .bak bac'kp]ack.bak
29+
{"https://example.com/my%20random%20.bak%20bac%27kp%5Dack.bak,[myDb ,Nam,e]", "[myDb ,Nam,e]]", "[myDb ,Nam,e]]"},
30+
31+
// Database name not specified
32+
{"https://example.com/my%20random%20.bak%20bac%27kp%5Dack.bak", "my random .bak bac''kp]]ack", "my random .bak bac'kp]]ack"},
33+
{"https://example.com/my%20random%20.bak%20bac%27kp%5Dack.bak,", "my random .bak bac''kp]]ack", "my random .bak bac'kp]]ack"},
34+
35+
//Negative Testcases
36+
{"https://example.com,myDbName", "", ""},
37+
}
38+
39+
for _, testcase := range tests {
40+
dbname := parseDbName(testcase.input)
41+
dbnameAsIdentifier := getDbNameAsIdentifier(dbname)
42+
dbnameAsNonIdentifier := getDbNameAsNonIdentifier(dbname)
43+
assert.Equal(t, testcase.expectedIdentifierOp, dbnameAsIdentifier, "Unexpected database name as identifier")
44+
assert.Equal(t, testcase.expectedNonIdentifierOp, dbnameAsNonIdentifier, "Unexpected database name as non-identifier")
45+
}
46+
}
47+
48+
func TestExtractUrl(t *testing.T) {
49+
type test struct {
50+
inputURL string
51+
expectedURL string
52+
}
53+
54+
tests := []test{
55+
{"https://example.com/testdb.bak,myDbName", "https://example.com/testdb.bak"},
56+
{"https://example.com/testdb.bak", "https://example.com/testdb.bak"},
57+
{"https://example.com,", "https://example.com,"},
58+
}
59+
60+
for _, testcase := range tests {
61+
assert.Equal(t, testcase.expectedURL, extractUrl(testcase.inputURL), "Extracted URL does not match expected URL")
62+
}
63+
}

cmd/modern/root/install/mssql.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,9 @@ func (c *Mssql) DefineCommand(...cmdparser.CommandOptions) {
3636
{
3737
Description: "Create SQL Server, download and attach AdventureWorks sample database",
3838
Steps: []string{"sqlcmd create mssql --using https://aka.ms/AdventureWorksLT.bak"}},
39+
{
40+
Description: "Create SQL Server, download and attach AdventureWorks sample database with different database name",
41+
Steps: []string{"sqlcmd create mssql --using https://aka.ms/AdventureWorksLT.bak,adventureworks"}},
3942
{
4043
Description: "Create SQL Server with an empty user database",
4144
Steps: []string{"sqlcmd create mssql --user-database db1"}},

0 commit comments

Comments
 (0)