@@ -5,6 +5,12 @@ package install
5
5
6
6
import (
7
7
"fmt"
8
+ "net/url"
9
+ "path"
10
+ "path/filepath"
11
+ "runtime"
12
+ "strings"
13
+
8
14
"github.com/microsoft/go-sqlcmd/cmd/modern/sqlconfig"
9
15
"github.com/microsoft/go-sqlcmd/internal/cmdparser"
10
16
"github.com/microsoft/go-sqlcmd/internal/config"
@@ -15,10 +21,6 @@ import (
15
21
"github.com/microsoft/go-sqlcmd/internal/secret"
16
22
"github.com/microsoft/go-sqlcmd/internal/sql"
17
23
"github.com/spf13/viper"
18
- "net/url"
19
- "path/filepath"
20
- "runtime"
21
- "strings"
22
24
)
23
25
24
26
// MssqlBase provide base support for installing SQL Server and all of its
@@ -395,7 +397,8 @@ func (c *MssqlBase) createContainer(imageName string, contextName string) {
395
397
396
398
func (c * MssqlBase ) validateUsingUrlExists () {
397
399
output := c .Cmd .Output ()
398
- u , err := url .Parse (c .usingDatabaseUrl )
400
+ databaseUrl := extractUrl (c .usingDatabaseUrl )
401
+ u , err := url .Parse (databaseUrl )
399
402
c .CheckErr (err )
400
403
401
404
if u .Scheme != "http" && u .Scheme != "https" {
@@ -406,9 +409,17 @@ func (c *MssqlBase) validateUsingUrlExists() {
406
409
"%q is not a valid URL for --using flag" , c .usingDatabaseUrl )
407
410
}
408
411
412
+ if u .Path == "" {
413
+ output .FatalfWithHints (
414
+ []string {
415
+ "--using URL must have a path to .bak file" ,
416
+ },
417
+ "%q is not a valid URL for --using flag" , c .usingDatabaseUrl )
418
+ }
419
+
409
420
// At the moment we only support attaching .bak files, but we should
410
421
// support .bacpacs and .mdfs in the future
411
- if _ , file := filepath .Split (c . usingDatabaseUrl ); filepath .Ext (file ) != ".bak" {
422
+ if _ , file := filepath .Split (u . Path ); filepath .Ext (file ) != ".bak" {
412
423
output .FatalfWithHints (
413
424
[]string {
414
425
"--using file URL must be a .bak file" ,
@@ -417,7 +428,7 @@ func (c *MssqlBase) validateUsingUrlExists() {
417
428
}
418
429
419
430
// Verify the url actually exists, and early exit if it doesn't
420
- urlExists (c . usingDatabaseUrl , output )
431
+ urlExists (databaseUrl , output )
421
432
}
422
433
423
434
func (c * MssqlBase ) query (commandText string ) {
@@ -468,31 +479,73 @@ CHECK_POLICY=OFF`
468
479
}
469
480
}
470
481
482
+ func getDbNameAsIdentifier (dbName string ) string {
483
+ escapedDbNAme := strings .ReplaceAll (dbName , "'" , "''" )
484
+ return strings .ReplaceAll (escapedDbNAme , "]" , "]]" )
485
+ }
486
+
487
+ func getDbNameAsNonIdentifier (dbName string ) string {
488
+ return strings .ReplaceAll (dbName , "]" , "]]" )
489
+ }
490
+
491
+ //parseDbName returns the databaseName from --using arg
492
+ // It sets database name to the specified database name
493
+ // or in absence of it, it is set to the filename without
494
+ // extension.
495
+ func parseDbName (usingDbUrl string ) string {
496
+ u , _ := url .Parse (usingDbUrl )
497
+ dbToken := path .Base (u .Path )
498
+ if dbToken != "." && dbToken != "/" {
499
+ lastIdx := strings .LastIndex (dbToken , ".bak" )
500
+ if lastIdx != - 1 {
501
+ //Get file name without extension
502
+ fileName := dbToken [0 :lastIdx ]
503
+ lastIdx += 5
504
+ if lastIdx >= len (dbToken ) {
505
+ return fileName
506
+ }
507
+ //Return database name if it was specified
508
+ return dbToken [lastIdx :]
509
+ }
510
+ }
511
+ return ""
512
+ }
513
+
514
+ func extractUrl (usingArg string ) string {
515
+ urlEndIdx := strings .LastIndex (usingArg , ".bak" )
516
+ if urlEndIdx != - 1 {
517
+ return usingArg [0 :(urlEndIdx + 4 )]
518
+ }
519
+ return usingArg
520
+ }
521
+
471
522
func (c * MssqlBase ) downloadAndRestoreDb (
472
523
controller * container.Controller ,
473
524
containerId string ,
474
525
userName string ,
475
526
) {
476
527
output := c .Cmd .Output ()
528
+ databaseName := parseDbName (c .usingDatabaseUrl )
529
+ databaseUrl := extractUrl (c .usingDatabaseUrl )
477
530
478
- u , err := url .Parse (c .usingDatabaseUrl )
479
- c .CheckErr (err )
480
- _ , file := filepath .Split (c .usingDatabaseUrl )
481
- fileNameWithNoExt := strings .TrimSuffix (file , filepath .Ext (file ))
531
+ _ , file := filepath .Split (databaseUrl )
482
532
483
533
// Download file from URL into container
484
- output .Infof ("Downloading %s from %s " , file , u . Hostname () )
534
+ output .Infof ("Downloading %s" , file )
485
535
486
536
temporaryFolder := "/var/opt/mssql/backup"
487
537
488
538
controller .DownloadFile (
489
539
containerId ,
490
- c . usingDatabaseUrl ,
540
+ databaseUrl ,
491
541
temporaryFolder ,
492
542
)
493
543
494
544
// Restore database from file
495
- output .Infof ("Restoring database %s" , fileNameWithNoExt )
545
+ output .Infof ("Restoring database %s" , databaseName )
546
+
547
+ dbNameAsIdentifier := getDbNameAsIdentifier (databaseName )
548
+ dbNameAsNonIdentifier := getDbNameAsNonIdentifier (databaseName )
496
549
497
550
text := `SET NOCOUNT ON;
498
551
@@ -535,12 +588,12 @@ WHERE IsPresent = 1
535
588
SET @sql = SUBSTRING(@sql, 1, LEN(@sql)-1)
536
589
EXEC(@sql)`
537
590
538
- c .query (fmt .Sprintf (text , temporaryFolder , file , fileNameWithNoExt , temporaryFolder , file ))
591
+ c .query (fmt .Sprintf (text , temporaryFolder , file , dbNameAsIdentifier , temporaryFolder , file ))
539
592
540
593
alterDefaultDb := fmt .Sprintf (
541
594
"ALTER LOGIN [%s] WITH DEFAULT_DATABASE = [%s]" ,
542
595
userName ,
543
- fileNameWithNoExt )
596
+ dbNameAsNonIdentifier )
544
597
c .query (alterDefaultDb )
545
598
}
546
599
0 commit comments