Skip to content

Commit 0a5893a

Browse files
committed
fix: v1.0 version fallback checks if v1.0 is supported by client
Signed-off-by: Pierre-Henri Symoneaux <[email protected]>
1 parent cb4fafd commit 0a5893a

File tree

2 files changed

+79
-13
lines changed

2 files changed

+79
-13
lines changed

kmipclient/client.go

+6-3
Original file line numberDiff line numberDiff line change
@@ -294,6 +294,7 @@ func (c *Client) CloneCtx(ctx context.Context) (*Client, error) {
294294
dialer: c.dialer,
295295
middlewares: slices.Clone(c.middlewares),
296296
conn: stream,
297+
addr: c.addr,
297298
}, nil
298299
}
299300

@@ -379,10 +380,12 @@ func (c *Client) negotiateVersion(ctx context.Context) error {
379380
return errors.New("Unexpected batch item count")
380381
}
381382
bi := resp.BatchItem[0]
382-
if bi.ResultStatus == kmip.ResultStatusOperationFailed &&
383-
(bi.ResultReason == kmip.ResultReasonOperationNotSupported /*|| bi.ResultReason == kmip.ReasonInvalidMessage && bi.Operation == 0x00*/) {
383+
if bi.ResultStatus == kmip.ResultStatusOperationFailed && bi.ResultReason == kmip.ResultReasonOperationNotSupported {
384384
// If the discover opertion is not supported, then fallbacks to kmip v1.0
385-
// TODO: Check that v1.0 is in the client's supported version list and return an error if not
385+
// but also check that v1.0 is in the client's supported version list and return an error if not.
386+
if !slices.Contains(c.supportedVersions, kmip.V1_0) {
387+
return errors.New("Protocol version negotiation failed. No common version found")
388+
}
386389
c.version = &kmip.V1_0
387390
return nil
388391
}

kmipclient/client_test.go

+73-10
Original file line numberDiff line numberDiff line change
@@ -2,28 +2,21 @@ package kmipclient_test
22

33
import (
44
"context"
5+
"os"
56
"sync"
67
"testing"
78
"time"
89

910
"github.com/ovh/kmip-go"
11+
"github.com/ovh/kmip-go/kmipclient"
1012
"github.com/ovh/kmip-go/kmipserver"
1113
"github.com/ovh/kmip-go/kmiptest"
1214
"github.com/ovh/kmip-go/payloads"
15+
"github.com/ovh/kmip-go/ttlv"
1316

1417
"github.com/stretchr/testify/require"
1518
)
1619

17-
// func testClientRequest[Req, Resp kmip.OperationPayload](t *testing.T, tf func(*kmipclient.Client) *kmipclient.Executor[Req, Resp], f func(*testing.T, Req) (Resp, error)) (Resp, error) {
18-
// mux := kmipserver.NewBatchExecutor()
19-
// client := kmiptest.NewClientAndServer(t, mux)
20-
// req := tf(client)
21-
// mux.Route(req.RequestPayload().Operation(), kmipserver.HandleFunc(func(ctx context.Context, pl Req) (Resp, error) {
22-
// return f(t, pl)
23-
// }))
24-
// return req.Exec()
25-
// }
26-
2720
func TestRequest_ContextTimeout(t *testing.T) {
2821
mux := kmipserver.NewBatchExecutor()
2922
client := kmiptest.NewClientAndServer(t, mux)
@@ -215,3 +208,73 @@ func TestClone(t *testing.T) {
215208
_, err = client3.Request(context.Background(), &payloads.DiscoverVersionsRequestPayload{})
216209
require.NoError(t, err)
217210
}
211+
212+
func TestVersionNegociation(t *testing.T) {
213+
router := kmipserver.NewBatchExecutor()
214+
router.Route(kmip.OperationDiscoverVersions, kmipserver.HandleFunc(func(ctx context.Context, pl *payloads.DiscoverVersionsRequestPayload) (*payloads.DiscoverVersionsResponsePayload, error) {
215+
return &payloads.DiscoverVersionsResponsePayload{
216+
ProtocolVersion: []kmip.ProtocolVersion{
217+
kmip.V1_3, kmip.V1_2,
218+
},
219+
}, nil
220+
}))
221+
addr, ca := kmiptest.NewServer(t, router)
222+
client, err := kmipclient.Dial(
223+
addr,
224+
kmipclient.WithRootCAPem([]byte(ca)),
225+
kmipclient.WithMiddlewares(
226+
kmipclient.DebugMiddleware(os.Stderr, ttlv.MarshalXML),
227+
),
228+
kmipclient.WithKmipVersions(kmip.V1_2, kmip.V1_3),
229+
)
230+
require.NoError(t, err)
231+
require.NotNil(t, client)
232+
require.EqualValues(t, client.Version(), kmip.V1_3)
233+
}
234+
235+
func TestVersionNegociation_NoCommon(t *testing.T) {
236+
router := kmipserver.NewBatchExecutor()
237+
router.Route(kmip.OperationDiscoverVersions, kmipserver.HandleFunc(func(ctx context.Context, pl *payloads.DiscoverVersionsRequestPayload) (*payloads.DiscoverVersionsResponsePayload, error) {
238+
return &payloads.DiscoverVersionsResponsePayload{
239+
ProtocolVersion: []kmip.ProtocolVersion{},
240+
}, nil
241+
}))
242+
addr, ca := kmiptest.NewServer(t, router)
243+
client, err := kmipclient.Dial(
244+
addr,
245+
kmipclient.WithRootCAPem([]byte(ca)),
246+
kmipclient.WithMiddlewares(
247+
kmipclient.DebugMiddleware(os.Stderr, ttlv.MarshalXML),
248+
),
249+
kmipclient.WithKmipVersions(kmip.V1_1, kmip.V1_2),
250+
)
251+
require.Error(t, err)
252+
require.Nil(t, client)
253+
}
254+
255+
func TestVersionNegociation_v1_0_Fallback(t *testing.T) {
256+
router := kmipserver.NewBatchExecutor()
257+
router.Route(kmip.OperationDiscoverVersions, kmipserver.HandleFunc(func(ctx context.Context, pl *payloads.DiscoverVersionsRequestPayload) (*payloads.DiscoverVersionsResponsePayload, error) {
258+
return nil, kmipserver.ErrOperationNotSupported
259+
}))
260+
client := kmiptest.NewClientAndServer(t, router)
261+
require.EqualValues(t, client.Version(), kmip.V1_0)
262+
}
263+
264+
func TestVersionNegociation_v1_0_Fallback_unsupported(t *testing.T) {
265+
router := kmipserver.NewBatchExecutor()
266+
router.Route(kmip.OperationDiscoverVersions, kmipserver.HandleFunc(func(ctx context.Context, pl *payloads.DiscoverVersionsRequestPayload) (*payloads.DiscoverVersionsResponsePayload, error) {
267+
return nil, kmipserver.ErrOperationNotSupported
268+
}))
269+
addr, ca := kmiptest.NewServer(t, router)
270+
client, err := kmipclient.Dial(
271+
addr,
272+
kmipclient.WithRootCAPem([]byte(ca)),
273+
kmipclient.WithMiddlewares(
274+
kmipclient.DebugMiddleware(os.Stderr, ttlv.MarshalXML),
275+
),
276+
kmipclient.WithKmipVersions(kmip.V1_3, kmip.V1_4),
277+
)
278+
require.Error(t, err)
279+
require.Nil(t, client)
280+
}

0 commit comments

Comments
 (0)