Skip to content

Commit

Permalink
CB-28369: Add fallback to HTTP if target saltboot is not listening on…
Browse files Browse the repository at this point in the history
… the HTTPS port
  • Loading branch information
szabolcs-horvath committed Jan 30, 2025
1 parent e7a9c29 commit ea5ae18
Show file tree
Hide file tree
Showing 4 changed files with 454 additions and 54 deletions.
69 changes: 43 additions & 26 deletions saltboot/distributor.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,14 @@ import (
"bytes"
"crypto/tls"
"encoding/json"
"errors"
"io/ioutil"
"log"
"net/http"
"strconv"
"strings"
"sync"
"syscall"

"fmt"
"io"
Expand All @@ -18,16 +20,16 @@ import (
"github.com/hortonworks/salt-bootstrap/saltboot/model"
)

func determineProtocol() string {
if HttpsEnabled() {
func determineProtocol(httpsEnabled bool) string {
if httpsEnabled {
return "https://"
} else {
return "http://"
}
}

func getHttpClient() *http.Client {
if HttpsEnabled() {
func getHttpClient(httpsEnabled bool) *http.Client {
if httpsEnabled {
return &http.Client{
Transport: &http.Transport{
TLSClientConfig: &tls.Config{
Expand All @@ -40,57 +42,71 @@ func getHttpClient() *http.Client {
}
}

func sendRequestWithFallback(httpClient *http.Client, request *http.Request, httpsEnabled bool) (string, *http.Response, error) {
resp, err := httpClient.Do(request)
if httpsEnabled && err != nil && errors.Is(err, syscall.ECONNREFUSED) {
log.Printf("[sendRequestWithFallback] Could not reach the target using HTTPS. Falling back to HTTP.")
newRequest := request.Clone(request.Context())
newRequest.URL.Scheme = "http"
newRequest.URL.Host = newRequest.URL.Hostname() + ":" + strconv.Itoa(DetermineHttpPort())
resp, err = httpClient.Do(newRequest)
return newRequest.URL.Host, resp, err
}
return request.URL.Host, resp, err
}

func DistributeRequest(clients []string, endpoint, user, pass string, requestBody RequestBody) <-chan model.Response {
protocol := determineProtocol()
httpsEnabled := HttpsEnabled()
protocol := determineProtocol(httpsEnabled)
var wg sync.WaitGroup
wg.Add(len(clients))
c := make(chan model.Response, len(clients))

for idx, client := range clients {
go func(client string, index int) {
defer wg.Done()
log.Printf("[distribute] Send request to client: %s", client)
log.Printf("[DistributeRequest] Send request to client: %s", client)

var clientAddr string
if strings.Contains(client, ":") {
clientAddr = client
} else {
clientAddr = client + ":" + strconv.Itoa(DetermineBootstrapPort())
clientAddr = client + ":" + strconv.Itoa(DetermineBootstrapPort(httpsEnabled))
}

var req *http.Request
if len(requestBody.Signature) > 0 {
indexString := strconv.Itoa(index)
log.Printf("[distribute] Send signed request to client: %s with index: %s", client, indexString)
log.Printf("[DistributeRequest] Send signed request to client: %s with index: %s", client, indexString)
req, _ = http.NewRequest("POST", protocol+clientAddr+endpoint+"?index="+indexString, bytes.NewBufferString(requestBody.SignedPayload))
req.Header.Set(SIGNATURE, requestBody.Signature)
} else {
log.Printf("[distribute] Send plain request to client: %s", client)
log.Printf("[DistributeRequest] Send plain request to client: %s", client)
req, _ = http.NewRequest("POST", protocol+clientAddr+endpoint, bytes.NewBuffer(requestBody.PlainPayload))
}
req.Header.Set("Content-Type", "application/json")
req.SetBasicAuth(user, pass)

httpClient := getHttpClient()
resp, err := httpClient.Do(req)
httpClient := getHttpClient(httpsEnabled)
respHost, resp, err := sendRequestWithFallback(httpClient, req, httpsEnabled)
if err != nil {
log.Printf("[distribute] [ERROR] Failed to send request to: %s, error: %s", client, err.Error())
c <- model.Response{StatusCode: http.StatusInternalServerError, ErrorText: err.Error(), Address: client}
log.Printf("[DistributeRequest] [ERROR] Failed to send request to: %s, error: %s", client, err.Error())
c <- model.Response{StatusCode: http.StatusInternalServerError, ErrorText: err.Error(), Address: respHost}
return
}

body, _ := ioutil.ReadAll(resp.Body)
decoder := json.NewDecoder(strings.NewReader(string(body)))
var response model.Response
if err := decoder.Decode(&response); err != nil {
log.Printf("[distribute] [ERROR] Failed to decode response, error: %s", err.Error())
log.Printf("[DistributeRequest] [ERROR] Failed to decode response, error: %s", err.Error())
}
response.Address = client
response.Address = respHost

if response.StatusCode == 0 {
response.StatusCode = resp.StatusCode
}
log.Printf("[distribute] Request to: %s result: %s", client, response.String())
log.Printf("[DistributeRequest] Request to: %s result: %s", client, response.String())
c <- response
defer closeIt(resp.Body)
}(client, idx)
Expand All @@ -105,7 +121,8 @@ func DistributeRequest(clients []string, endpoint, user, pass string, requestBod
func DistributeFileUploadRequest(endpoint string, user string, pass string, targets []string, path string,
permissions string, file multipart.File, header *multipart.FileHeader, signature string) <-chan model.Response {

protocol := determineProtocol()
httpsEnabled := HttpsEnabled()
protocol := determineProtocol(httpsEnabled)
var wg sync.WaitGroup
wg.Add(len(targets))
c := make(chan model.Response, len(targets))
Expand Down Expand Up @@ -145,31 +162,31 @@ func DistributeFileUploadRequest(endpoint string, user string, pass string, targ
if strings.Contains(target, ":") {
targetAddress = target
} else {
targetAddress = target + ":" + strconv.Itoa(DetermineBootstrapPort())
targetAddress = target + ":" + strconv.Itoa(DetermineBootstrapPort(httpsEnabled))
}

req, err := http.NewRequest("POST", protocol+targetAddress+endpoint, bytes.NewReader(fileContent))
req.Header.Set(SIGNATURE, signature)
req.Header.Set("Content-Type", bodyWriter.FormDataContentType())
req.SetBasicAuth(user, pass)

httpClient := getHttpClient()
resp, err := httpClient.Do(req)
httpClient := getHttpClient(httpsEnabled)
respHost, resp, err := sendRequestWithFallback(httpClient, req, httpsEnabled)
if err != nil {
log.Printf("[DistributeFileUploadRequest] [ERROR] Failed to send request to: %s, error: %s", target, err.Error())
c <- model.Response{StatusCode: http.StatusInternalServerError, ErrorText: err.Error(), Address: target}
log.Printf("[DistributeFileUploadRequest] [ERROR] Failed to send request to: %s, error: %s", respHost, err.Error())
c <- model.Response{StatusCode: http.StatusInternalServerError, ErrorText: err.Error(), Address: respHost}
return
}

body, _ := ioutil.ReadAll(resp.Body)
defer closeIt(resp.Body)
if resp.StatusCode != http.StatusCreated {
log.Printf("[DistributeFileUploadRequest] Error response from: %s, error: %s", target, body)
c <- model.Response{StatusCode: resp.StatusCode, ErrorText: string(body), Address: target}
log.Printf("[DistributeFileUploadRequest] Error response from: %s, error: %s", respHost, body)
c <- model.Response{StatusCode: resp.StatusCode, ErrorText: string(body), Address: respHost}
return
} else {
log.Printf("[DistributeFileUploadRequest] Request to: %s result: %s", target, body)
c <- model.Response{StatusCode: http.StatusCreated, Status: string(body), Address: target}
log.Printf("[DistributeFileUploadRequest] Request to: %s result: %s", respHost, body)
c <- model.Response{StatusCode: http.StatusCreated, Status: string(body), Address: respHost}
}
}(target, i)
}
Expand Down
Loading

0 comments on commit ea5ae18

Please sign in to comment.