Skip to content

Commit 75ee48f

Browse files
authored
Add SourcesToFilter support for network-blackhole-port fault (#4408)
1 parent 1e6b153 commit 75ee48f

File tree

9 files changed

+261
-26
lines changed

9 files changed

+261
-26
lines changed

agent/handlers/task_server_setup_test.go

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3806,8 +3806,6 @@ func TestRegisterStartBlackholePortFaultHandler(t *testing.T) {
38063806
cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte{}, nil),
38073807
exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(cmdExec),
38083808
cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte{}, nil),
3809-
exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(cmdExec),
3810-
cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte{}, nil),
38113809
)
38123810
}
38133811
tcs := generateCommonNetworkFaultInjectionTestCases("start blackhole port", "running", setExecExpectations, happyBlackHolePortReqBody)

agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/fault/v1/handlers/handlers.go

Lines changed: 20 additions & 8 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/fault/v1/types/types.go

Lines changed: 22 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/tmds/server.go

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

ecs-agent/tmds/handlers/fault/v1/handlers/handlers.go

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -120,12 +120,19 @@ func (h *FaultHandler) StartNetworkBlackholePort() func(http.ResponseWriter, *ht
120120
if err != nil {
121121
return
122122
}
123+
123124
// Validate the fault request
124125
err = validateRequest(w, request, requestType)
125126
if err != nil {
126127
return
127128
}
128129

130+
if aws.StringValue(request.TrafficType) == types.TrafficTypeEgress &&
131+
aws.Uint16Value(request.Port) == tmds.PortForTasks {
132+
// Add TMDS IP to SouresToFilter so that access to TMDS is not blocked for the task
133+
request.AddSourceToFilterIfNotAlready(tmds.IPForTasks)
134+
}
135+
129136
// Obtain the task metadata via the endpoint container ID
130137
taskMetadata, err := validateTaskMetadata(w, h.AgentState, requestType, r)
131138
if err != nil {
@@ -154,7 +161,8 @@ func (h *FaultHandler) StartNetworkBlackholePort() func(http.ResponseWriter, *ht
154161
insertTable = "OUTPUT"
155162
}
156163

157-
_, cmdErr := h.startNetworkBlackholePort(ctxWithTimeout, aws.StringValue(request.Protocol), port, chainName,
164+
_, cmdErr := h.startNetworkBlackholePort(ctxWithTimeout, aws.StringValue(request.Protocol),
165+
port, aws.StringValueSlice(request.SourcesToFilter), chainName,
158166
networkMode, networkNSPath, insertTable, taskArn)
159167
if err := ctxWithTimeout.Err(); errors.Is(err, context.DeadlineExceeded) {
160168
statusCode = http.StatusInternalServerError
@@ -187,7 +195,10 @@ func (h *FaultHandler) StartNetworkBlackholePort() func(http.ResponseWriter, *ht
187195
// 2. Creates a new chain via `iptables -N <chain>` (the chain name is in the form of "<trafficType>-<protocol>-<port>")
188196
// 3. Appends a new rule to the newly created chain via `iptables -A <chain> -p <protocol> --dport <port> -j DROP`
189197
// 4. Inserts the newly created chain into the built-in INPUT/OUTPUT table
190-
func (h *FaultHandler) startNetworkBlackholePort(ctx context.Context, protocol, port, chain, networkMode, netNs, insertTable, taskArn string) (string, error) {
198+
func (h *FaultHandler) startNetworkBlackholePort(
199+
ctx context.Context, protocol, port string, sourcesToFilter []string,
200+
chain, networkMode, netNs, insertTable, taskArn string,
201+
) (string, error) {
191202
running, cmdOutput, err := h.checkNetworkBlackHolePort(ctx, protocol, port, chain, networkMode, netNs, taskArn)
192203
if err != nil {
193204
return cmdOutput, err
@@ -246,12 +257,13 @@ func (h *FaultHandler) startNetworkBlackholePort(ctx context.Context, protocol,
246257
return "", nil
247258
}
248259

249-
// Add a rule to accept all traffic to TMDS
250-
protectTMDSRuleCmdString := nsenterPrefix + fmt.Sprintf(iptablesAppendChainRuleCmd,
251-
requestTimeoutSeconds, chain, protocol, tmds.IPForTasks, tmds.PortForTasks,
252-
acceptTarget)
253-
if out, err := execRuleChangeCommand(protectTMDSRuleCmdString); err != nil {
254-
return out, err
260+
for _, sourceToFilter := range sourcesToFilter {
261+
filterRuleCmdString := nsenterPrefix + fmt.Sprintf(iptablesAppendChainRuleCmd,
262+
requestTimeoutSeconds, chain, protocol, sourceToFilter, port,
263+
acceptTarget)
264+
if out, err := execRuleChangeCommand(filterRuleCmdString); err != nil {
265+
return out, err
266+
}
255267
}
256268

257269
// Add a rule to drop all traffic to the port that the fault targets

ecs-agent/tmds/handlers/fault/v1/handlers/handlers_test.go

Lines changed: 132 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ import (
3535
mock_state "github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/v4/state/mocks"
3636
"github.com/aws/amazon-ecs-agent/ecs-agent/tmds/utils/netconfig"
3737
mock_execwrapper "github.com/aws/amazon-ecs-agent/ecs-agent/utils/execwrapper/mocks"
38+
"github.com/aws/aws-sdk-go/aws"
3839

3940
"github.com/golang/mock/gomock"
4041
"github.com/gorilla/mux"
@@ -521,8 +522,6 @@ func generateStartBlackHolePortFaultTestCases() []networkFaultInjectionTestCase
521522
cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte{}, nil),
522523
exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(cmdExec),
523524
cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte{}, nil),
524-
exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(cmdExec),
525-
cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte{}, nil),
526525
)
527526
},
528527
},
@@ -556,8 +555,6 @@ func generateStartBlackHolePortFaultTestCases() []networkFaultInjectionTestCase
556555
cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte{}, nil),
557556
exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(cmdExec),
558557
cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte{}, nil),
559-
exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(cmdExec),
560-
cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte{}, nil),
561558
)
562559
},
563560
},
@@ -663,6 +660,137 @@ func generateStartBlackHolePortFaultTestCases() []networkFaultInjectionTestCase
663660
)
664661
},
665662
},
663+
{
664+
name: "SourcesToFilter validation failure",
665+
expectedStatusCode: 400,
666+
requestBody: map[string]interface{}{
667+
"Port": port,
668+
"Protocol": protocol,
669+
"TrafficType": trafficType,
670+
"SourcesToFilter": aws.StringSlice([]string{"1.2.3.4", "bad"}),
671+
},
672+
expectedResponseBody: types.NewNetworkFaultInjectionErrorResponse("invalid value bad for parameter SourcesToFilter"),
673+
},
674+
{
675+
name: "TMDS IP is added to SourcesToFilter if needed",
676+
requestBody: map[string]interface{}{
677+
"Port": 80,
678+
"Protocol": protocol,
679+
"TrafficType": "egress",
680+
},
681+
expectedStatusCode: 200,
682+
expectedResponseBody: types.NewNetworkFaultInjectionSuccessResponse("running"),
683+
setAgentStateExpectations: func(agentState *mock_state.MockAgentState, netConfigClient *netconfig.NetworkConfigClient) {
684+
agentState.EXPECT().GetTaskMetadataWithTaskNetworkConfig(endpointId, netConfigClient).
685+
Return(happyTaskResponse, nil).
686+
Times(1)
687+
},
688+
setExecExpectations: func(exec *mock_execwrapper.MockExec, ctrl *gomock.Controller) {
689+
ctx, cancel := context.WithTimeout(context.Background(), ctxTimeoutDuration)
690+
cmdExec := mock_execwrapper.NewMockCmd(ctrl)
691+
gomock.InOrder(
692+
exec.EXPECT().NewExecContextWithTimeout(gomock.Any(), gomock.Any()).Times(1).Return(ctx, cancel),
693+
exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(cmdExec),
694+
cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte(iptablesChainNotFoundError), errors.New("exit status 1")),
695+
exec.EXPECT().ConvertToExitError(gomock.Any()).Times(1).Return(nil, true),
696+
exec.EXPECT().GetExitCode(gomock.Any()).Times(1).Return(1),
697+
exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(cmdExec),
698+
cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte{}, nil),
699+
exec.EXPECT().CommandContext(gomock.Any(),
700+
"nsenter", "--net=/some/path", "iptables", "-w", "5", "-A", "egress-tcp-80",
701+
"-p", "tcp", "-d", "169.254.170.2", "--dport", "80", "-j", "ACCEPT",
702+
).Times(1).Return(cmdExec),
703+
cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte{}, nil),
704+
exec.EXPECT().CommandContext(gomock.Any(),
705+
"nsenter", "--net=/some/path", "iptables", "-w", "5", "-A", "egress-tcp-80",
706+
"-p", "tcp", "-d", "0.0.0.0/0", "--dport", "80", "-j", "DROP",
707+
).Times(1).Return(cmdExec),
708+
cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte{}, nil),
709+
exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(cmdExec),
710+
cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte{}, nil),
711+
)
712+
},
713+
},
714+
{
715+
name: "Sources to filter are filtered",
716+
requestBody: map[string]interface{}{
717+
"Port": 443,
718+
"Protocol": "udp",
719+
"TrafficType": "ingress",
720+
"SourcesToFilter": []string{"1.2.3.4/20", "8.8.8.8"},
721+
},
722+
expectedStatusCode: 200,
723+
expectedResponseBody: types.NewNetworkFaultInjectionSuccessResponse("running"),
724+
setAgentStateExpectations: func(agentState *mock_state.MockAgentState, netConfigClient *netconfig.NetworkConfigClient) {
725+
agentState.EXPECT().GetTaskMetadataWithTaskNetworkConfig(endpointId, netConfigClient).
726+
Return(happyTaskResponse, nil).
727+
Times(1)
728+
},
729+
setExecExpectations: func(exec *mock_execwrapper.MockExec, ctrl *gomock.Controller) {
730+
ctx, cancel := context.WithTimeout(context.Background(), ctxTimeoutDuration)
731+
cmdExec := mock_execwrapper.NewMockCmd(ctrl)
732+
gomock.InOrder(
733+
exec.EXPECT().NewExecContextWithTimeout(gomock.Any(), gomock.Any()).Times(1).Return(ctx, cancel),
734+
exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(cmdExec),
735+
cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte(iptablesChainNotFoundError), errors.New("exit status 1")),
736+
exec.EXPECT().ConvertToExitError(gomock.Any()).Times(1).Return(nil, true),
737+
exec.EXPECT().GetExitCode(gomock.Any()).Times(1).Return(1),
738+
exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(cmdExec),
739+
cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte{}, nil),
740+
exec.EXPECT().CommandContext(gomock.Any(),
741+
"nsenter", "--net=/some/path", "iptables", "-w", "5", "-A", "ingress-udp-443",
742+
"-p", "udp", "-d", "1.2.3.4/20", "--dport", "443", "-j", "ACCEPT",
743+
).Times(1).Return(cmdExec),
744+
cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte{}, nil),
745+
exec.EXPECT().CommandContext(gomock.Any(),
746+
"nsenter", "--net=/some/path", "iptables", "-w", "5", "-A", "ingress-udp-443",
747+
"-p", "udp", "-d", "8.8.8.8", "--dport", "443", "-j", "ACCEPT",
748+
).Times(1).Return(cmdExec),
749+
cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte{}, nil),
750+
exec.EXPECT().CommandContext(gomock.Any(),
751+
"nsenter", "--net=/some/path", "iptables", "-w", "5", "-A", "ingress-udp-443",
752+
"-p", "udp", "-d", "0.0.0.0/0", "--dport", "443", "-j", "DROP",
753+
).Times(1).Return(cmdExec),
754+
cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte{}, nil),
755+
exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(cmdExec),
756+
cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte{}, nil),
757+
)
758+
},
759+
},
760+
{
761+
name: "Error when filtering a source",
762+
expectedStatusCode: 500,
763+
requestBody: map[string]interface{}{
764+
"Port": 443,
765+
"Protocol": "udp",
766+
"TrafficType": "ingress",
767+
"SourcesToFilter": []string{"1.2.3.4/20"},
768+
},
769+
expectedResponseBody: types.NewNetworkFaultInjectionErrorResponse(internalError),
770+
setAgentStateExpectations: func(agentState *mock_state.MockAgentState, netConfigClient *netconfig.NetworkConfigClient) {
771+
agentState.EXPECT().GetTaskMetadataWithTaskNetworkConfig(endpointId, netConfigClient).
772+
Return(happyTaskResponse, nil).
773+
Times(1)
774+
},
775+
setExecExpectations: func(exec *mock_execwrapper.MockExec, ctrl *gomock.Controller) {
776+
ctx, cancel := context.WithTimeout(context.Background(), ctxTimeoutDuration)
777+
cmdExec := mock_execwrapper.NewMockCmd(ctrl)
778+
gomock.InOrder(
779+
exec.EXPECT().NewExecContextWithTimeout(gomock.Any(), gomock.Any()).Times(1).Return(ctx, cancel),
780+
exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(cmdExec),
781+
cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte(iptablesChainNotFoundError), errors.New("exit status 1")),
782+
exec.EXPECT().ConvertToExitError(gomock.Any()).Times(1).Return(nil, true),
783+
exec.EXPECT().GetExitCode(gomock.Any()).Times(1).Return(1),
784+
exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(cmdExec),
785+
cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte{}, nil),
786+
exec.EXPECT().CommandContext(gomock.Any(),
787+
"nsenter", "--net=/some/path", "iptables", "-w", "5", "-A", "ingress-udp-443",
788+
"-p", "udp", "-d", "1.2.3.4/20", "--dport", "443", "-j", "ACCEPT",
789+
).Times(1).Return(cmdExec),
790+
cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte(internalError), errors.New("exit status 1")),
791+
)
792+
},
793+
},
666794
}
667795

668796
return append(tcs, commonTcs...)

0 commit comments

Comments
 (0)