diff --git a/client-state-machine.go b/client-state-machine.go index d8464a1..afaa46e 100644 --- a/client-state-machine.go +++ b/client-state-machine.go @@ -102,6 +102,27 @@ func (state clientStateStart) Next(hr handshakeMessageReader) (HandshakeState, [ alpn = &ALPNExtension{Protocols: state.Opts.NextProtos} } + // Raw public keys + var cct *ClientCertTypeExtension + var sct *ServerCertTypeExtension + if state.Config.AllowRawPublicKeys || !state.Config.ForbidCertificates { + cct = &ClientCertTypeExtension{HandshakeType: HandshakeTypeClientHello} + sct = &ServerCertTypeExtension{HandshakeType: HandshakeTypeClientHello} + + if state.Config.AllowRawPublicKeys { + cct.CertificateTypes = append(cct.CertificateTypes, CertificateTypeRawPublicKey) + sct.CertificateTypes = append(sct.CertificateTypes, CertificateTypeRawPublicKey) + } + + if !state.Config.ForbidCertificates { + cct.CertificateTypes = append(cct.CertificateTypes, CertificateTypeX509) + sct.CertificateTypes = append(sct.CertificateTypes, CertificateTypeX509) + } + } + if state.Config.AllowRawPublicKeys || state.Config.ForbidCertificates { + logf(logTypeHandshake, "[ClientStateStart] AllowRawPublicKeys only") + } + // Construct base ClientHello ch := &ClientHelloBody{ LegacyVersion: wireVersion(state.hsCtx.hIn), @@ -113,7 +134,7 @@ func (state clientStateStart) Next(hr handshakeMessageReader) (HandshakeState, [ return nil, nil, AlertInternalError } for _, ext := range []ExtensionBody{&sv, &sni, &ks, &sg, &sa} { - err := ch.Extensions.Add(ext) + err = ch.Extensions.Add(ext) if err != nil { logf(logTypeHandshake, "[ClientStateStart] Error adding extension type=[%v] [%v]", ext.Type(), err) return nil, nil, AlertInternalError @@ -128,6 +149,18 @@ func (state clientStateStart) Next(hr handshakeMessageReader) (HandshakeState, [ return nil, nil, AlertInternalError } } + if cct != nil { + err = ch.Extensions.Add(cct) + if err != nil { + logf(logTypeHandshake, "[ClientStateStart] Error adding ClientCertType extension [%v]", err) + return nil, nil, AlertInternalError + } + err = ch.Extensions.Add(sct) + if err != nil { + logf(logTypeHandshake, "[ClientStateStart] Error adding ServerCertType extension [%v]", err) + return nil, nil, AlertInternalError + } + } if state.cookie != nil { err := ch.Extensions.Add(&CookieExtension{Cookie: state.cookie}) if err != nil { @@ -157,6 +190,7 @@ func (state clientStateStart) Next(hr handshakeMessageReader) (HandshakeState, [ // Handle PSK and EarlyData just before transmitting, so that we can // calculate the PSK binder value var psk *PreSharedKeyExtension + var certWithExternPSK *CertWithExternPSKExtension var ed *EarlyDataExtension var offeredPSK PreSharedKey var earlyHash crypto.Hash @@ -215,6 +249,12 @@ func (state clientStateStart) Next(hr handshakeMessageReader) (HandshakeState, [ {Binder: bytes.Repeat([]byte{0x00}, params.Hash.Size())}, }, } + + if state.Config.CertWithExternPSK == true { + logf(logTypeHandshake, "Adding CertWithExternPSKExtension") + certWithExternPSK = &CertWithExternPSKExtension{} + ch.Extensions.Add(certWithExternPSK) + } ch.Extensions.Add(psk) // Compute the binder key @@ -443,11 +483,13 @@ func (state clientStateWaitSH) Next(hr handshakeMessageReader) (HandshakeState, // Do PSK or key agreement depending on extensions serverPSK := PreSharedKeyExtension{HandshakeType: HandshakeTypeServerHello} serverKeyShare := KeyShareExtension{HandshakeType: HandshakeTypeServerHello} + certWithExternPSK := CertWithExternPSKExtension{} foundExts, err := sh.Extensions.Parse( []ExtensionBody{ &serverPSK, &serverKeyShare, + &certWithExternPSK, }) if err != nil { logf(logTypeHandshake, "[ClientWaitSH] Error processing extensions [%v]", err) @@ -456,6 +498,12 @@ func (state clientStateWaitSH) Next(hr handshakeMessageReader) (HandshakeState, if foundExts[ExtensionTypePreSharedKey] && (serverPSK.SelectedIdentity == 0) { state.Params.UsingPSK = true + logf(logTypeHandshake, "[ClientStateWaitSH] UsingPSK") + } + + if foundExts[ExtensionTypeCertWithExternPSK] { + state.Params.UsingCertWithExternPSK = true + logf(logTypeHandshake, "[ClientStateWaitSH] Found ExtensionTypeCertWithExternPSK") } var dhSecret []byte @@ -590,11 +638,15 @@ func (state clientStateWaitEE) Next(hr handshakeMessageReader) (HandshakeState, serverALPN := &ALPNExtension{} serverEarlyData := &EarlyDataExtension{} + serverClientCertType := &ClientCertTypeExtension{HandshakeType: HandshakeTypeEncryptedExtensions} + serverServerCertType := &ServerCertTypeExtension{HandshakeType: HandshakeTypeEncryptedExtensions} foundExts, err := ee.Extensions.Parse( []ExtensionBody{ serverALPN, serverEarlyData, + serverClientCertType, + serverServerCertType, }) if err != nil { logf(logTypeHandshake, "[ClientStateWaitEE] Error decoding extensions: %v", err) @@ -607,6 +659,15 @@ func (state clientStateWaitEE) Next(hr handshakeMessageReader) (HandshakeState, state.Params.NextProto = serverALPN.Protocols[0] } + if foundExts[ExtensionTypeClientCertType] && + foundExts[ExtensionTypeServerCertType] && + state.Config.AllowRawPublicKeys { + foundClient := serverClientCertType.CertificateTypes[0] == CertificateTypeRawPublicKey + foundServer := serverServerCertType.CertificateTypes[0] == CertificateTypeRawPublicKey + state.Params.UsingRawPublicKeys = foundClient && foundServer + logf(logTypeHandshake, "[ClientStateWaitEE] UsingRawPublicKeys") + } + state.handshakeHash.Write(hm.Marshal()) toSend := []HandshakeAction{} @@ -617,7 +678,8 @@ func (state clientStateWaitEE) Next(hr handshakeMessageReader) (HandshakeState, KeySet: makeTrafficKeys(state.cryptoParams, state.clientHandshakeTrafficSecret)}) } - if state.Params.UsingPSK { + if state.Params.UsingPSK && !state.Params.UsingCertWithExternPSK { + logf(logTypeHandshake, "[ClientStateWaitEE] UsingPSK, not UsingCertWithExternPSK") logf(logTypeHandshake, "[ClientStateWaitEE] -> [ClientStateWaitFinished]") nextState := clientStateWaitFinished{ Params: state.Params, @@ -818,44 +880,72 @@ func (state clientStateWaitCV) Next(hr handshakeMessageReader) (HandshakeState, hcv := state.handshakeHash.Sum(nil) logf(logTypeHandshake, "Handshake Hash to be verified: [%d] %x", len(hcv), hcv) - serverPublicKey := state.serverCertificate.CertificateList[0].CertData.PublicKey - if err := certVerify.Verify(serverPublicKey, hcv); err != nil { - logf(logTypeHandshake, "[ClientStateWaitCV] Server signature failed to verify") - return nil, nil, AlertHandshakeFailure - } - - certs := make([]*x509.Certificate, len(state.serverCertificate.CertificateList)) - rawCerts := make([][]byte, len(state.serverCertificate.CertificateList)) + certs := make([][]byte, len(state.serverCertificate.CertificateList)) for i, certEntry := range state.serverCertificate.CertificateList { certs[i] = certEntry.CertData - rawCerts[i] = certEntry.CertData.Raw } - var verifiedChains [][]*x509.Certificate - if !state.Config.InsecureSkipVerify { - opts := x509.VerifyOptions{ - Roots: state.Config.RootCAs, - CurrentTime: state.Config.time(), - DNSName: state.Config.ServerName, - Intermediates: x509.NewCertPool(), + var err error + var serverPublicKey crypto.PublicKey + var eeCert *x509.Certificate + if !state.Params.UsingRawPublicKeys { + eeCert, err = x509.ParseCertificate(certs[0]) + if err != nil { + logf(logTypeHandshake, "[ClientStateWaitCV] Unable to parse client cert: %v", err) } - for i, cert := range certs { - if i == 0 { - continue - } - opts.Intermediates.AddCert(cert) + serverPublicKey = eeCert.PublicKey + } else { + serverPublicKey, err = unmarshalSigningKey(certs[0]) + if state.Params.UsingPSK { + logf(logTypeHandshake, "[ClientStateWaitCV] UsingRawPublicKeys and PSK") + } else { + logf(logTypeHandshake, "[ClientStateWaitCV] UsingRawPublicKeys") } - var err error - verifiedChains, err = certs[0].Verify(opts) if err != nil { - logf(logTypeHandshake, "[ClientStateWaitCV] Certificate verification failed: %s", err) - return nil, nil, AlertBadCertificate + logf(logTypeHandshake, "[ClientStateWaitCV] Unable to parse raw public key: %v", err) + } + } + + if err := certVerify.Verify(serverPublicKey, hcv); err != nil { + logf(logTypeHandshake, "[ClientStateWaitCV] Server signature failed to verify") + return nil, nil, AlertHandshakeFailure + } + + var verifiedChains [][]*x509.Certificate + if !state.Params.UsingRawPublicKeys { + if !state.Config.InsecureSkipVerify { + opts := x509.VerifyOptions{ + Roots: state.Config.RootCAs, + CurrentTime: state.Config.time(), + DNSName: state.Config.ServerName, + Intermediates: x509.NewCertPool(), + } + + for i, cert := range certs { + if i == 0 { + continue + } + + caCert, err := x509.ParseCertificate(cert) + if err != nil { + logf(logTypeHandshake, "[ClientStateWaitCV] Error parsing server chain: %v", err) + return nil, nil, AlertDecodeError + } + + opts.Intermediates.AddCert(caCert) + } + var err error + verifiedChains, err = eeCert.Verify(opts) + if err != nil { + logf(logTypeHandshake, "[ClientStateWaitCV] Certificate verification failed: %s", err) + return nil, nil, AlertBadCertificate + } } } if state.Config.VerifyPeerCertificate != nil { - if err := state.Config.VerifyPeerCertificate(rawCerts, verifiedChains); err != nil { + if err := state.Config.VerifyPeerCertificate(certs, verifiedChains); err != nil { logf(logTypeHandshake, "[ClientStateWaitCV] Application rejected server certificate: %s", err) return nil, nil, AlertBadCertificate } @@ -888,7 +978,7 @@ type clientStateWaitFinished struct { certificates []*Certificate serverCertificateRequest *CertificateRequestBody - peerCertificates []*x509.Certificate + peerCertificates [][]byte verifiedChains [][]*x509.Certificate masterSecret []byte @@ -1000,12 +1090,23 @@ func (state clientStateWaitFinished) Next(hr handshakeMessageReader) (HandshakeS state.handshakeHash.Write(certm.Marshal()) } else { // Create and send Certificate, CertificateVerify - certificate := &CertificateBody{ - CertificateList: make([]CertificateEntry, len(cert.Chain)), - } - for i, entry := range cert.Chain { - certificate.CertificateList[i] = CertificateEntry{CertData: entry} + var certList []CertificateEntry + if !state.Params.UsingRawPublicKeys { + certList = make([]CertificateEntry, len(cert.Chain)) + for i, entry := range cert.Chain { + certList[i] = CertificateEntry{CertData: entry.Raw} + } + } else { + certData, err := marshalSigningKey(cert.PrivateKey.Public()) + if err != nil { + logf(logTypeHandshake, "[ClientStateWaitFinished] Unable to marshal raw public key [%v]", err) + return nil, nil, AlertInternalError + } + + certList = []CertificateEntry{{CertData: certData}} } + + certificate := &CertificateBody{CertificateList: certList} certm, err := state.hsCtx.hOut.HandshakeMessageFromBody(certificate) if err != nil { logf(logTypeHandshake, "[ClientStateWaitFinished] Error marshaling Certificate [%v]", err) @@ -1015,6 +1116,7 @@ func (state clientStateWaitFinished) Next(hr handshakeMessageReader) (HandshakeS toSend = append(toSend, QueueHandshakeMessage{certm}) state.handshakeHash.Write(certm.Marshal()) + // Create and send CertificateVerify hcv := state.handshakeHash.Sum(nil) logf(logTypeHandshake, "Handshake Hash to be verified: [%d] %x", len(hcv), hcv) diff --git a/common.go b/common.go index 4d6477a..6242a01 100644 --- a/common.go +++ b/common.go @@ -125,6 +125,9 @@ const ( ExtensionTypeCookie ExtensionType = 44 ExtensionTypePSKKeyExchangeModes ExtensionType = 45 ExtensionTypeTicketEarlyDataInfo ExtensionType = 46 + ExtensionTypeClientCertType ExtensionType = 19 + ExtensionTypeServerCertType ExtensionType = 20 + ExtensionTypeCertWithExternPSK ExtensionType = 33 ) // enum {...} NamedGroup @@ -164,6 +167,19 @@ const ( KeyUpdateRequested KeyUpdateRequest = 1 ) +/* +0 X.509 Y [RFC6091] +1 OpenPGP_RESERVED N [RFC6091][RFC8446] Used in TLS versions prior to 1.3. +2 Raw Public Key Y [RFC7250] +3 1609Dot2 N +*/ +type CertificateType uint8 + +const ( + CertificateTypeX509 CertificateType = 0 + CertificateTypeRawPublicKey CertificateType = 2 +) + type State uint8 const ( diff --git a/common_test.go b/common_test.go index bbf8629..7250ded 100644 --- a/common_test.go +++ b/common_test.go @@ -2,6 +2,8 @@ package mint import ( "bytes" + "crypto" + "encoding/base64" "encoding/hex" "fmt" "reflect" @@ -18,6 +20,22 @@ func unhex(h string) []byte { return b } +func computeEpskid(der []byte) []byte { + b64 := base64.StdEncoding.EncodeToString(der) + logf(logTypeVerbose, "Base64 DER of SubjectPublicKeyInfo %s", b64) + + extract := HkdfExtract(crypto.SHA256, nil, der) + b64 = base64.StdEncoding.EncodeToString(extract) + logf(logTypeVerbose, "Base64 of HKDF-Extract %s", b64) + + epskid := HkdfExpand(crypto.SHA256, extract, []byte("tls13-bspsk-identity"), 32) + b64 = base64.StdEncoding.EncodeToString(epskid) + logf(logTypeVerbose, "Base64 of HKDF-Expand (epskid) %s", b64) + + return epskid + +} + func assertTrue(t *testing.T, test bool, msg string) { t.Helper() prefix := string("") diff --git a/conn.go b/conn.go index ccf8dc1..c59cf91 100644 --- a/conn.go +++ b/conn.go @@ -16,6 +16,7 @@ import ( type Certificate struct { Chain []*x509.Certificate PrivateKey crypto.Signer + PublicKey crypto.PublicKey } type PreSharedKey struct { @@ -120,14 +121,20 @@ type Config struct { // the verifiedChains argument will always be nil. VerifyPeerCertificate func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error - CipherSuites []CipherSuite - Groups []NamedGroup - SignatureSchemes []SignatureScheme - NextProtos []string - PSKs PreSharedKeyCache - PSKModes []PSKKeyExchangeMode - NonBlocking bool - UseDTLS bool + CipherSuites []CipherSuite + Groups []NamedGroup + SignatureSchemes []SignatureScheme + NextProtos []string + PSKs PreSharedKeyCache + PSKModes []PSKKeyExchangeMode + NonBlocking bool + UseDTLS bool + CertWithExternPSK bool + + // These bools are arranged in opposite directions so that their default + // values reflect the correct default semantics (certs yes, raw keys no) + AllowRawPublicKeys bool + ForbidCertificates bool RecordLayer RecordLayerFactory @@ -250,7 +257,7 @@ var ( type ConnectionState struct { HandshakeState State CipherSuite CipherSuiteParams // cipher suite in use (TLS_RSA_WITH_RC4_128_SHA, ...) - PeerCertificates []*x509.Certificate // certificate chain presented by remote peer + PeerCertificates [][]byte // certificate chain presented by remote peer VerifiedChains [][]*x509.Certificate // verified chains built from PeerCertificates NextProto string // Selected ALPN proto UsingPSK bool // Are we using PSK. diff --git a/conn_test.go b/conn_test.go index 5c8314a..4bf0384 100644 --- a/conn_test.go +++ b/conn_test.go @@ -8,6 +8,7 @@ import ( "crypto/rand" "crypto/x509" "crypto/x509/pkix" + "encoding/hex" "errors" "fmt" "io" @@ -176,10 +177,12 @@ var ( certificates, clientCertificates []*Certificate clientName, serverName string - psk PreSharedKey - psks *PSKMapCache + psk PreSharedKey + hkdfPsk PreSharedKey + psks *PSKMapCache + hkdfPsks *PSKMapCache - basicConfig, dtlsConfig, nbConfig, nbDTLSConfig, hrrConfig, alpnConfig, pskConfig, pskDTLSConfig, pskECDHEConfig, pskDHEConfig, resumptionConfig, ffdhConfig, x25519Config *Config + basicConfig, dtlsConfig, nbConfig, nbDTLSConfig, hrrConfig, alpnConfig, rawConfig, pskConfig, pskDTLSConfig, pskECDHEConfig, pskDHEConfig, certWithExternPSKConfig, rawWithExternPSKConfig, rawWithExternPSKFromHKDFConfig, resumptionConfig, ffdhConfig, x25519Config *Config ) func init() { @@ -197,12 +200,26 @@ func init() { panic(err) } + derPublicKey, err := marshalSigningKey(clientKey.Public()) + if err != nil { + panic(err) + } + + epskid := computeEpskid(derPublicKey) + psk = PreSharedKey{ CipherSuite: TLS_AES_128_GCM_SHA256, IsResumption: false, Identity: []byte{0, 1, 2, 3}, Key: []byte{4, 5, 6, 7}, } + + hkdfPsk = PreSharedKey{ + CipherSuite: TLS_AES_128_GCM_SHA256, + IsResumption: false, + Identity: epskid, + Key: derPublicKey, + } certificates = []*Certificate{ { Chain: []*x509.Certificate{serverCert}, @@ -220,6 +237,11 @@ func init() { "00010203": psk, } + hkdfPsks = &PSKMapCache{ + serverName: hkdfPsk, + hex.EncodeToString(epskid): hkdfPsk, + } + basicConfig = &Config{ ServerName: serverName, Certificates: certificates, @@ -262,6 +284,13 @@ func init() { InsecureSkipVerify: true, } + rawConfig = &Config{ + ServerName: serverName, + Certificates: certificates, + AllowRawPublicKeys: true, + InsecureSkipVerify: true, + } + pskConfig = &Config{ ServerName: serverName, CipherSuites: []CipherSuite{TLS_AES_128_GCM_SHA256}, @@ -297,6 +326,35 @@ func init() { InsecureSkipVerify: true, } + certWithExternPSKConfig = &Config{ + ServerName: serverName, + Certificates: certificates, + PSKs: psks, + CertWithExternPSK: true, + InsecureSkipVerify: true, + } + + rawWithExternPSKConfig = &Config{ + ServerName: serverName, + Certificates: certificates, + PSKs: psks, + AllowRawPublicKeys: true, + ForbidCertificates: true, + CertWithExternPSK: true, + InsecureSkipVerify: true, + } + + rawWithExternPSKFromHKDFConfig = &Config{ + ServerName: serverName, + Certificates: certificates, + PSKs: hkdfPsks, + AllowRawPublicKeys: true, + ForbidCertificates: true, + CertWithExternPSK: true, + InsecureSkipVerify: true, + RequireClientAuth: true, + } + resumptionConfig = &Config{ ServerName: serverName, Certificates: certificates, @@ -357,11 +415,13 @@ func checkConsistency(t *testing.T, client *Conn, server *Conn) { func testConnInner(t *testing.T, name string, p testInstanceState) { // Configs array: - configs := map[string]*Config{"basic config": basicConfig, - "HRR": hrrConfig, - "ALPN": alpnConfig, - "FFDH": ffdhConfig, - "x25519": x25519Config, + configs := map[string]*Config{ + "basic config": basicConfig, + "HRR": hrrConfig, + "ALPN": alpnConfig, + "RawPK": rawConfig, + "FFDH": ffdhConfig, + "x25519": x25519Config, } c := configs[p["config"]] @@ -400,6 +460,7 @@ func TestBasicFlows(t *testing.T) { "basic config", "HRR", "ALPN", + "RawPK", "FFDH", "x25519", }, @@ -766,6 +827,32 @@ func TestPSKFlows(t *testing.T) { assertTrue(t, client.state.Params.UsingPSK, "Session did not use the provided PSK") } } +func TestRFC8773Flows(t *testing.T) { + for _, conf := range []*Config{certWithExternPSKConfig, rawWithExternPSKConfig, rawWithExternPSKFromHKDFConfig} { + cConn, sConn := pipe() + + client := Client(cConn, conf) + server := Server(sConn, conf) + + var clientAlert, serverAlert Alert + + done := make(chan bool) + go func(t *testing.T) { + serverAlert = server.Handshake() + assertEquals(t, serverAlert, AlertNoAlert) + done <- true + }(t) + + clientAlert = client.Handshake() + assertEquals(t, clientAlert, AlertNoAlert) + + <-done + + checkConsistency(t, client, server) + + assertTrue(t, client.state.Params.UsingPSK, "Session did not use the provided PSK") + } +} func TestNonBlockingReadBeforeConnected(t *testing.T) { conn := Client(&bufferedConn{}, &Config{NonBlocking: true}) @@ -1232,9 +1319,9 @@ func TestConnectionState(t *testing.T) { serverCS := server.ConnectionState() assertEquals(t, clientCS.CipherSuite.Suite, configClient.CipherSuites[0]) assertDeepEquals(t, clientCS.VerifiedChains, [][]*x509.Certificate{{serverCert}}) - assertDeepEquals(t, clientCS.PeerCertificates, []*x509.Certificate{serverCert}) + assertDeepEquals(t, clientCS.PeerCertificates, [][]byte{serverCert.Raw}) assertEquals(t, serverCS.CipherSuite.Suite, serverConfig.CipherSuites[0]) - assertDeepEquals(t, serverCS.PeerCertificates, []*x509.Certificate{clientCert}) + assertDeepEquals(t, serverCS.PeerCertificates, [][]byte{clientCert.Raw}) } func TestDTLS(t *testing.T) { diff --git a/crypto.go b/crypto.go index 4fa59a2..4e9140a 100644 --- a/crypto.go +++ b/crypto.go @@ -328,6 +328,31 @@ func newSigningKey(sig SignatureScheme) (crypto.Signer, error) { } } +func marshalSigningKey(pub crypto.PublicKey) ([]byte, error) { + switch pub.(type) { + case *rsa.PublicKey: + return x509.MarshalPKIXPublicKey(pub) + + case *ecdsa.PublicKey: + return x509.MarshalPKIXPublicKey(pub) + + // TODO support Ed25519 + + default: + return nil, fmt.Errorf("tls.marshalsigningkey: Unsupported public key type") + } +} + +func unmarshalSigningKey(data []byte) (interface{}, error) { + pub, err := x509.ParsePKIXPublicKey(data) + if err == nil { + return pub, err + } + + // TODO attempt to parse Ed25519 + return nil, err +} + // XXX(rlb): Copied from crypto/x509 type ecdsaSignature struct { R, S *big.Int diff --git a/extensions.go b/extensions.go index 07cb16c..fd9e79a 100644 --- a/extensions.go +++ b/extensions.go @@ -624,3 +624,116 @@ func (c CookieExtension) Marshal() ([]byte, error) { func (c *CookieExtension) Unmarshal(data []byte) (int, error) { return syntax.Unmarshal(data, c) } + +// struct { +// select(ClientOrServerExtension) { +// case client: +// CertificateType client_certificate_types<1..2^8-1>; +// case server: +// CertificateType client_certificate_type; +// } +// } ClientCertTypeExtension; +type ClientCertTypeExtension struct { + HandshakeType HandshakeType `tls:"none"` + CertificateTypes []CertificateType `tls:"head=1,min=1"` +} + +func (cct ClientCertTypeExtension) Type() ExtensionType { + return ExtensionTypeClientCertType +} + +func (cct ClientCertTypeExtension) Marshal() ([]byte, error) { + switch cct.HandshakeType { + case HandshakeTypeClientHello: + return syntax.Marshal(cct) + + case HandshakeTypeEncryptedExtensions: + return syntax.Marshal(uint8(cct.CertificateTypes[0])) + } + + return nil, fmt.Errorf("tls.certificate_types: Handshake type not allowed") +} + +func (cct *ClientCertTypeExtension) Unmarshal(data []byte) (int, error) { + switch cct.HandshakeType { + case HandshakeTypeClientHello: + return syntax.Unmarshal(data, cct) + + case HandshakeTypeEncryptedExtensions: + var val uint8 + n, err := syntax.Unmarshal(data, &val) + if err != nil { + return 0, err + } + + cct.CertificateTypes = []CertificateType{CertificateType(val)} + return n, err + } + + return 0, fmt.Errorf("tls.certificate_types: Handshake type not allowed") +} + +// struct { +// select(ClientOrServerExtension) { +// case client: +// CertificateType client_certificate_types<1..2^8-1>; +// case server: +// CertificateType client_certificate_type; +// } +// } ServerCertTypeExtension; +type ServerCertTypeExtension struct { + HandshakeType HandshakeType `tls:"none"` + CertificateTypes []CertificateType `tls:"head=1,min=1"` +} + +func (sct ServerCertTypeExtension) Type() ExtensionType { + return ExtensionTypeServerCertType +} + +func (sct ServerCertTypeExtension) Marshal() ([]byte, error) { + switch sct.HandshakeType { + case HandshakeTypeClientHello: + return syntax.Marshal(sct) + + case HandshakeTypeEncryptedExtensions: + return syntax.Marshal(uint8(sct.CertificateTypes[0])) + } + + return nil, fmt.Errorf("tls.certificate_types: Handshake type not allowed") +} + +func (sct *ServerCertTypeExtension) Unmarshal(data []byte) (int, error) { + switch sct.HandshakeType { + case HandshakeTypeClientHello: + return syntax.Unmarshal(data, sct) + + case HandshakeTypeEncryptedExtensions: + var val uint8 + n, err := syntax.Unmarshal(data, &val) + if err != nil { + return 0, err + } + + sct.CertificateTypes = []CertificateType{CertificateType(val)} + return n, err + } + + return 0, fmt.Errorf("tls.certificate_types: Handshake type not allowed") +} + +// struct { +// } CertWithExternPSK; + +type CertWithExternPSKExtension struct{} + +func (certepsk CertWithExternPSKExtension) Type() ExtensionType { + return ExtensionTypeCertWithExternPSK +} + +func (certepsk CertWithExternPSKExtension) Marshal() ([]byte, error) { + return []byte{}, nil +} + +func (certepsk *CertWithExternPSKExtension) Unmarshal(data []byte) (int, error) { + return 0, nil +} diff --git a/handshake-messages.go b/handshake-messages.go index e8e3aef..a96c6d9 100644 --- a/handshake-messages.go +++ b/handshake-messages.go @@ -3,7 +3,6 @@ package mint import ( "bytes" "crypto" - "crypto/x509" "encoding/binary" "fmt" @@ -269,23 +268,13 @@ func (ee *EncryptedExtensionsBody) Unmarshal(data []byte) (int, error) { // CertificateEntry certificate_list<0..2^24-1>; // } Certificate; type CertificateEntry struct { - CertData *x509.Certificate - Extensions ExtensionList -} - -type CertificateBody struct { - CertificateRequestContext []byte - CertificateList []CertificateEntry -} - -type certificateEntryInner struct { CertData []byte `tls:"head=3,min=1"` Extensions ExtensionList `tls:"head=2"` } -type certificateBodyInner struct { - CertificateRequestContext []byte `tls:"head=1"` - CertificateList []certificateEntryInner `tls:"head=3"` +type CertificateBody struct { + CertificateRequestContext []byte `tls:"head=1"` + CertificateList []CertificateEntry `tls:"head=3"` } func (c CertificateBody) Type() HandshakeType { @@ -293,41 +282,11 @@ func (c CertificateBody) Type() HandshakeType { } func (c CertificateBody) Marshal() ([]byte, error) { - inner := certificateBodyInner{ - CertificateRequestContext: c.CertificateRequestContext, - CertificateList: make([]certificateEntryInner, len(c.CertificateList)), - } - - for i, entry := range c.CertificateList { - inner.CertificateList[i] = certificateEntryInner{ - CertData: entry.CertData.Raw, - Extensions: entry.Extensions, - } - } - - return syntax.Marshal(inner) + return syntax.Marshal(c) } func (c *CertificateBody) Unmarshal(data []byte) (int, error) { - inner := certificateBodyInner{} - read, err := syntax.Unmarshal(data, &inner) - if err != nil { - return read, err - } - - c.CertificateRequestContext = inner.CertificateRequestContext - c.CertificateList = make([]CertificateEntry, len(inner.CertificateList)) - - for i, entry := range inner.CertificateList { - c.CertificateList[i].CertData, err = x509.ParseCertificate(entry.CertData) - if err != nil { - return 0, fmt.Errorf("tls:certificate: Certificate failed to parse: %v", err) - } - - c.CertificateList[i].Extensions = entry.Extensions - } - - return read, nil + return syntax.Unmarshal(data, c) } // struct { diff --git a/handshake-messages_test.go b/handshake-messages_test.go index 0f387d6..22cd98c 100644 --- a/handshake-messages_test.go +++ b/handshake-messages_test.go @@ -153,11 +153,11 @@ var ( CertificateRequestContext: []byte{0, 0, 0, 0}, CertificateList: []CertificateEntry{ { - CertData: cert1, + CertData: cert1.Raw, Extensions: extListValidIn, }, { - CertData: cert2, + CertData: cert2.Raw, Extensions: extListValidIn, }, }, @@ -166,7 +166,7 @@ var ( CertificateRequestContext: []byte{0, 0, 0, 0}, CertificateList: []CertificateEntry{ { - CertData: cert1, + CertData: cert1.Raw, Extensions: extListSingleTooLongIn, }, }, @@ -462,11 +462,11 @@ func TestCertificateMarshalUnmarshal(t *testing.T) { certValidIn.CertificateRequestContext = originalContext // Test marshal failure on no raw certa - originalRaw := cert1.Raw - cert1.Raw = []byte{} + originalCert := certValidIn.CertificateList[0].CertData + certValidIn.CertificateList[0].CertData = nil out, err = certValidIn.Marshal() assertError(t, err, "Marshaled a Certificate with an empty cert") - cert1.Raw = originalRaw + certValidIn.CertificateList[0].CertData = originalCert // Test marshal failure on extension list marshal failure out, err = certOverflowIn.Marshal() @@ -500,12 +500,6 @@ func TestCertificateMarshalUnmarshal(t *testing.T) { _, err = cert.Unmarshal(certValid) assertError(t, err, "Unmarshaled a Certificate with truncated certificates") certValid[8] ^= 0xFF - - // Test unmarshal failure on malformed certificate - certValid[11] ^= 0xFF // Clobber first octet of first cert - _, err = cert.Unmarshal(certValid) - assertError(t, err, "Unmarshaled a Certificate with truncated certificates") - certValid[11] ^= 0xFF } func TestCertificateVerifyMarshalUnmarshal(t *testing.T) { diff --git a/log.go b/log.go index 2fba90d..50da076 100644 --- a/log.go +++ b/log.go @@ -9,6 +9,8 @@ import ( // We use this environment variable to control logging. It should be a // comma-separated list of log tags (see below) or "*" to enable all logging. +// Usage example: +// $ MINT_LOG=* go test -v const logConfigVar = "MINT_LOG" // Pre-defined log types diff --git a/server-state-machine.go b/server-state-machine.go index e392a5d..de99937 100644 --- a/server-state-machine.go +++ b/server-state-machine.go @@ -2,6 +2,7 @@ package mint import ( "bytes" + "crypto" "crypto/x509" "fmt" "hash" @@ -118,10 +119,13 @@ func (state serverStateStart) Next(hr handshakeMessageReader) (HandshakeState, [ signatureAlgorithms := new(SignatureAlgorithmsExtension) clientKeyShares := &KeyShareExtension{HandshakeType: HandshakeTypeClientHello} clientPSK := &PreSharedKeyExtension{HandshakeType: HandshakeTypeClientHello} + certWithExternPSK := &CertWithExternPSKExtension{} clientEarlyData := &EarlyDataExtension{} clientALPN := new(ALPNExtension) clientPSKModes := new(PSKKeyExchangeModesExtension) clientCookie := new(CookieExtension) + clientClientCertType := &ClientCertTypeExtension{HandshakeType: HandshakeTypeClientHello} + clientServerCertType := &ServerCertTypeExtension{HandshakeType: HandshakeTypeClientHello} // Handle external extensions. if state.Config.ExtensionHandler != nil { @@ -141,9 +145,12 @@ func (state serverStateStart) Next(hr handshakeMessageReader) (HandshakeState, [ clientEarlyData, clientKeyShares, clientPSK, + certWithExternPSK, clientALPN, clientPSKModes, clientCookie, + clientClientCertType, + clientServerCertType, }) if err != nil { @@ -206,6 +213,11 @@ func (state serverStateStart) Next(hr handshakeMessageReader) (HandshakeState, [ return nil, nil, AlertIllegalParameter } + if foundExts[ExtensionTypeCertWithExternPSK] { + logf(logTypeHandshake, "[ServerStateStart] Found ExtensionTypeCertWithExternPSK") + connParams.UsingCertWithExternPSK = true + } + // Figure out if we can do DH canDoDH, dhGroup, dhPublic, dhSecret := DHNegotiation(clientKeyShares.Shares, state.Config.Groups) @@ -335,6 +347,8 @@ func (state serverStateStart) Next(hr handshakeMessageReader) (HandshakeState, [ pskSecret = psk.Key } else { psk = nil + } + if (connParams.UsingPSK && connParams.UsingCertWithExternPSK) || !connParams.UsingPSK { // If we're not using a PSK mode, then we need to have certain extensions if !(foundExts[ExtensionTypeServerName] && @@ -344,6 +358,31 @@ func (state serverStateStart) Next(hr handshakeMessageReader) (HandshakeState, [ return nil, nil, AlertMissingExtension } + // Figure out if we're going to use a certificate or a raw public key + // TODO(rlb): Move this to a negotiation function? + if foundExts[ExtensionTypeClientCertType] && + foundExts[ExtensionTypeServerCertType] && + state.Config.AllowRawPublicKeys { + foundClient := false + for _, t := range clientClientCertType.CertificateTypes { + if t == CertificateTypeRawPublicKey { + foundClient = true + break + } + } + + foundServer := false + for _, t := range clientServerCertType.CertificateTypes { + if t == CertificateTypeRawPublicKey { + foundServer = true + break + } + } + + // Only support symmetric usage of raw public keys for now + connParams.UsingRawPublicKeys = foundClient && foundServer + } + // Select a certificate name := string(*serverName) var err error @@ -512,6 +551,14 @@ func (state serverStateNegotiated) Next(_ handshakeMessageReader) (HandshakeStat } } + if state.Params.UsingCertWithExternPSK { + err := sh.Extensions.Add(&CertWithExternPSKExtension{}) + if err != nil { + logf(logTypeHandshake, "[ServerStateNegotiated] Error adding CertWithExternPSK extension [%v]", err) + return nil, nil, AlertInternalError + } + } + // Run the external extension handler. if state.Config.ExtensionHandler != nil { err := state.Config.ExtensionHandler.Send(HandshakeTypeServerHello, &sh.Extensions) @@ -591,6 +638,27 @@ func (state serverStateNegotiated) Next(_ handshakeMessageReader) (HandshakeStat return nil, nil, AlertInternalError } } + if state.Params.UsingRawPublicKeys { + logf(logTypeHandshake, "[server] sending ClientCertType extension") + err = eeList.Add(&ClientCertTypeExtension{ + HandshakeType: HandshakeTypeEncryptedExtensions, + CertificateTypes: []CertificateType{CertificateTypeRawPublicKey}, + }) + if err != nil { + logf(logTypeHandshake, "[ServerStateNegotiated] Error adding CCT to EncryptedExtensions [%v]", err) + return nil, nil, AlertInternalError + } + + logf(logTypeHandshake, "[server] sending ServerCertType extension") + err = eeList.Add(&ServerCertTypeExtension{ + HandshakeType: HandshakeTypeEncryptedExtensions, + CertificateTypes: []CertificateType{CertificateTypeRawPublicKey}, + }) + if err != nil { + logf(logTypeHandshake, "[ServerStateNegotiated] Error adding SCT to EncryptedExtensions [%v]", err) + return nil, nil, AlertInternalError + } + } ee := &EncryptedExtensionsBody{eeList} // Run the external extension handler. @@ -617,7 +685,7 @@ func (state serverStateNegotiated) Next(_ handshakeMessageReader) (HandshakeStat } // Authenticate with a certificate if required - if !state.Params.UsingPSK { + if (state.Params.UsingPSK && state.Params.UsingCertWithExternPSK) || !state.Params.UsingPSK { // Send a CertificateRequest message if we want client auth if state.Config.RequireClientAuth { state.Params.UsingClientAuth = true @@ -643,13 +711,24 @@ func (state serverStateNegotiated) Next(_ handshakeMessageReader) (HandshakeStat handshakeHash.Write(crm.Marshal()) } - // Create and send Certificate, CertificateVerify - certificate := &CertificateBody{ - CertificateList: make([]CertificateEntry, len(state.cert.Chain)), - } - for i, entry := range state.cert.Chain { - certificate.CertificateList[i] = CertificateEntry{CertData: entry} + // Create and send Certificate + var certList []CertificateEntry + if !state.Params.UsingRawPublicKeys { + certList = make([]CertificateEntry, len(state.cert.Chain)) + for i, entry := range state.cert.Chain { + certList[i] = CertificateEntry{CertData: entry.Raw} + } + } else { + certData, err := marshalSigningKey(state.cert.PrivateKey.Public()) + if err != nil { + logf(logTypeHandshake, "[ServerStateNegotiated] Unable to marshal raw public key [%v]", err) + return nil, nil, AlertInternalError + } + + certList = []CertificateEntry{{CertData: certData}} } + + certificate := &CertificateBody{CertificateList: certList} certm, err := state.hsCtx.hOut.HandshakeMessageFromBody(certificate) if err != nil { logf(logTypeHandshake, "[ServerStateNegotiated] Error marshaling Certificate [%v]", err) @@ -659,6 +738,7 @@ func (state serverStateNegotiated) Next(_ handshakeMessageReader) (HandshakeStat toSend = append(toSend, QueueHandshakeMessage{certm}) handshakeHash.Write(certm.Marshal()) + // Create and send CertificateVerify certificateVerify := &CertificateVerifyBody{Algorithm: state.certScheme} logf(logTypeHandshake, "Creating CertVerify: %04x %v", state.certScheme, params.Hash) @@ -1045,22 +1125,35 @@ func (state serverStateWaitCV) Next(hr handshakeMessageReader) (HandshakeState, certVerify := &CertificateVerifyBody{} if err := safeUnmarshal(certVerify, hm.body); err != nil { - logf(logTypeHandshake, "[ServerStateWaitCert] Error decoding message %v", err) + logf(logTypeHandshake, "[ServerStateWaitCV] Error decoding message %v", err) return nil, nil, AlertDecodeError } - rawCerts := make([][]byte, len(state.clientCertificate.CertificateList)) - certs := make([]*x509.Certificate, len(state.clientCertificate.CertificateList)) + certs := make([][]byte, len(state.clientCertificate.CertificateList)) for i, certEntry := range state.clientCertificate.CertificateList { certs[i] = certEntry.CertData - rawCerts[i] = certEntry.CertData.Raw } // Verify client signature over handshake hash + var err error + var clientPublicKey crypto.PublicKey + if !state.Params.UsingRawPublicKeys { + cert, err := x509.ParseCertificate(certs[0]) + if err != nil { + logf(logTypeHandshake, "[ServerStateWaitCV] Unable to parse client cert: %v", err) + } + + clientPublicKey = cert.PublicKey + } else { + clientPublicKey, err = unmarshalSigningKey(certs[0]) + if err != nil { + logf(logTypeHandshake, "[ServerStateWaitCV] Unable to parse raw public key: %v", err) + } + } + hcv := state.handshakeHash.Sum(nil) logf(logTypeHandshake, "Handshake Hash to be verified: [%d] %x", len(hcv), hcv) - clientPublicKey := state.clientCertificate.CertificateList[0].CertData.PublicKey if err := certVerify.Verify(clientPublicKey, hcv); err != nil { logf(logTypeHandshake, "[ServerStateWaitCV] Failure in client auth verification [%v]", err) return nil, nil, AlertHandshakeFailure @@ -1068,7 +1161,7 @@ func (state serverStateWaitCV) Next(hr handshakeMessageReader) (HandshakeState, if state.Config.VerifyPeerCertificate != nil { // TODO(#171): pass in the verified chains, once we support different client auth types - if err := state.Config.VerifyPeerCertificate(rawCerts, nil); err != nil { + if err := state.Config.VerifyPeerCertificate(certs, nil); err != nil { logf(logTypeHandshake, "[ServerStateWaitCV] Application rejected client certificate: %s", err) return nil, nil, AlertBadCertificate } @@ -1101,7 +1194,7 @@ type serverStateWaitFinished struct { masterSecret []byte clientHandshakeTrafficSecret []byte - peerCertificates []*x509.Certificate + peerCertificates [][]byte verifiedChains [][]*x509.Certificate handshakeHash hash.Hash diff --git a/state-machine.go b/state-machine.go index f0635c8..339c0ba 100644 --- a/state-machine.go +++ b/state-machine.go @@ -57,10 +57,12 @@ type ConnectionOptions struct { type ConnectionParameters struct { UsingPSK bool UsingDH bool + UsingRawPublicKeys bool ClientSendingEarlyData bool UsingEarlyData bool RejectedEarlyData bool UsingClientAuth bool + UsingCertWithExternPSK bool CipherSuite CipherSuite ServerName string @@ -97,7 +99,7 @@ type stateConnected struct { clientTrafficSecret []byte serverTrafficSecret []byte exporterSecret []byte - peerCertificates []*x509.Certificate + peerCertificates [][]byte verifiedChains [][]*x509.Certificate }