diff --git a/client-state-machine.go b/client-state-machine.go index d8464a1..1fbc512 100644 --- a/client-state-machine.go +++ b/client-state-machine.go @@ -96,12 +96,6 @@ func (state clientStateStart) Next(hr handshakeMessageReader) (HandshakeState, [ state.Params.ServerName = state.Opts.ServerName - // Application Layer Protocol Negotiation - var alpn *ALPNExtension - if (state.Opts.NextProtos != nil) && (len(state.Opts.NextProtos) > 0) { - alpn = &ALPNExtension{Protocols: state.Opts.NextProtos} - } - // Construct base ClientHello ch := &ClientHelloBody{ LegacyVersion: wireVersion(state.hsCtx.hIn), @@ -119,15 +113,42 @@ func (state clientStateStart) Next(hr handshakeMessageReader) (HandshakeState, [ return nil, nil, AlertInternalError } } - // XXX: These optional extensions can't be folded into the above because Go - // interface-typed values are never reported as nil - if alpn != nil { + + // Application Layer Protocol Negotiation + if (state.Opts.NextProtos != nil) && (len(state.Opts.NextProtos) > 0) { + alpn := &ALPNExtension{Protocols: state.Opts.NextProtos} err := ch.Extensions.Add(alpn) if err != nil { logf(logTypeHandshake, "[ClientStateStart] Error adding ALPN extension [%v]", err) return nil, nil, AlertInternalError } } + + // Certificate type negotiation (to enable raw public keys) + if state.Config.AllowRawPublicKeys { + cct := &ClientCertTypeExtension{HandshakeType: HandshakeTypeClientHello} + sct := &ServerCertTypeExtension{HandshakeType: HandshakeTypeClientHello} + + cct.CertificateTypes = append(cct.CertificateTypes, CertificateTypeRawPublicKey) + sct.CertificateTypes = append(sct.CertificateTypes, CertificateTypeRawPublicKey) + + if !state.Config.ForbidX509 { + cct.CertificateTypes = append(cct.CertificateTypes, CertificateTypeX509) + sct.CertificateTypes = append(sct.CertificateTypes, CertificateTypeX509) + } + + err = ch.Extensions.Add(cct) + if err != nil { + logf(logTypeHandshake, "[ClientStateStart] Error adding ClientCertType extension [%v]", err) + return nil, nil, AlertInternalError + } + err = ch.Extensions.Add(sct) + if err != nil { + logf(logTypeHandshake, "[ClientStateStart] Error adding ServerCertType extension [%v]", err) + return nil, nil, AlertInternalError + } + } + if state.cookie != nil { err := ch.Extensions.Add(&CookieExtension{Cookie: state.cookie}) if err != nil { @@ -590,11 +611,15 @@ func (state clientStateWaitEE) Next(hr handshakeMessageReader) (HandshakeState, serverALPN := &ALPNExtension{} serverEarlyData := &EarlyDataExtension{} + serverClientCertType := &ClientCertTypeExtension{HandshakeType: HandshakeTypeEncryptedExtensions} + serverServerCertType := &ServerCertTypeExtension{HandshakeType: HandshakeTypeEncryptedExtensions} foundExts, err := ee.Extensions.Parse( []ExtensionBody{ serverALPN, serverEarlyData, + serverClientCertType, + serverServerCertType, }) if err != nil { logf(logTypeHandshake, "[ClientStateWaitEE] Error decoding extensions: %v", err) @@ -607,6 +632,26 @@ func (state clientStateWaitEE) Next(hr handshakeMessageReader) (HandshakeState, state.Params.NextProto = serverALPN.Protocols[0] } + if foundExts[ExtensionTypeServerCertType] { + certType := serverServerCertType.CertificateTypes[0] + if !CertificateTypeValid(certType, state.Config.AllowRawPublicKeys, state.Config.ForbidX509) { + logf(logTypeHandshake, "[ClientStateWaitEE] Server sent illegal certificate type: %v", certType) + return nil, nil, AlertHandshakeFailure + } + + state.Params.ServerCertType = certType + } + + if foundExts[ExtensionTypeClientCertType] { + certType := serverClientCertType.CertificateTypes[0] + if !CertificateTypeValid(certType, state.Config.AllowRawPublicKeys, state.Config.ForbidX509) { + logf(logTypeHandshake, "[ClientStateWaitEE] Client sent illegal certificate type: %v", certType) + return nil, nil, AlertHandshakeFailure + } + + state.Params.ClientCertType = certType + } + state.handshakeHash.Write(hm.Marshal()) toSend := []HandshakeAction{} @@ -818,44 +863,71 @@ func (state clientStateWaitCV) Next(hr handshakeMessageReader) (HandshakeState, hcv := state.handshakeHash.Sum(nil) logf(logTypeHandshake, "Handshake Hash to be verified: [%d] %x", len(hcv), hcv) - serverPublicKey := state.serverCertificate.CertificateList[0].CertData.PublicKey + certs := make([][]byte, len(state.serverCertificate.CertificateList)) + for i, certEntry := range state.serverCertificate.CertificateList { + certs[i] = certEntry.CertData + } + + var err error + var serverPublicKey crypto.PublicKey + var eeCert *x509.Certificate + switch state.Params.ServerCertType { + case CertificateTypeX509: + eeCert, err = x509.ParseCertificate(certs[0]) + if err != nil { + logf(logTypeHandshake, "[ClientStateWaitCV] Unable to parse client cert: %v", err) + return nil, nil, AlertDecodeError + } + + serverPublicKey = eeCert.PublicKey + + case CertificateTypeRawPublicKey: + serverPublicKey, err = unmarshalSigningKey(certs[0]) + if err != nil { + logf(logTypeHandshake, "[ClientStateWaitCV] Unable to parse raw public key: %v", err) + return nil, nil, AlertDecodeError + } + } + if err := certVerify.Verify(serverPublicKey, hcv); err != nil { logf(logTypeHandshake, "[ClientStateWaitCV] Server signature failed to verify") return nil, nil, AlertHandshakeFailure } - certs := make([]*x509.Certificate, len(state.serverCertificate.CertificateList)) - rawCerts := make([][]byte, len(state.serverCertificate.CertificateList)) - for i, certEntry := range state.serverCertificate.CertificateList { - certs[i] = certEntry.CertData - rawCerts[i] = certEntry.CertData.Raw - } - var verifiedChains [][]*x509.Certificate - if !state.Config.InsecureSkipVerify { - opts := x509.VerifyOptions{ - Roots: state.Config.RootCAs, - CurrentTime: state.Config.time(), - DNSName: state.Config.ServerName, - Intermediates: x509.NewCertPool(), - } + if state.Params.ServerCertType == CertificateTypeX509 { + if !state.Config.InsecureSkipVerify { + opts := x509.VerifyOptions{ + Roots: state.Config.RootCAs, + CurrentTime: state.Config.time(), + DNSName: state.Config.ServerName, + Intermediates: x509.NewCertPool(), + } + + for i, cert := range certs { + if i == 0 { + continue + } + + caCert, err := x509.ParseCertificate(cert) + if err != nil { + logf(logTypeHandshake, "[ClientStateWaitCV] Error parsing server chain: %v", err) + return nil, nil, AlertDecodeError + } - for i, cert := range certs { - if i == 0 { - continue + opts.Intermediates.AddCert(caCert) + } + var err error + verifiedChains, err = eeCert.Verify(opts) + if err != nil { + logf(logTypeHandshake, "[ClientStateWaitCV] Certificate verification failed: %s", err) + return nil, nil, AlertBadCertificate } - opts.Intermediates.AddCert(cert) - } - var err error - verifiedChains, err = certs[0].Verify(opts) - if err != nil { - logf(logTypeHandshake, "[ClientStateWaitCV] Certificate verification failed: %s", err) - return nil, nil, AlertBadCertificate } } if state.Config.VerifyPeerCertificate != nil { - if err := state.Config.VerifyPeerCertificate(rawCerts, verifiedChains); err != nil { + if err := state.Config.VerifyPeerCertificate(certs, verifiedChains); err != nil { logf(logTypeHandshake, "[ClientStateWaitCV] Application rejected server certificate: %s", err) return nil, nil, AlertBadCertificate } @@ -888,7 +960,7 @@ type clientStateWaitFinished struct { certificates []*Certificate serverCertificateRequest *CertificateRequestBody - peerCertificates []*x509.Certificate + peerCertificates [][]byte verifiedChains [][]*x509.Certificate masterSecret []byte @@ -1000,12 +1072,24 @@ func (state clientStateWaitFinished) Next(hr handshakeMessageReader) (HandshakeS state.handshakeHash.Write(certm.Marshal()) } else { // Create and send Certificate, CertificateVerify - certificate := &CertificateBody{ - CertificateList: make([]CertificateEntry, len(cert.Chain)), - } - for i, entry := range cert.Chain { - certificate.CertificateList[i] = CertificateEntry{CertData: entry} + var certList []CertificateEntry + switch state.Params.ClientCertType { + case CertificateTypeX509: + certList = make([]CertificateEntry, len(cert.Chain)) + for i, entry := range cert.Chain { + certList[i] = CertificateEntry{CertData: entry.Raw} + } + case CertificateTypeRawPublicKey: + certData, err := marshalSigningKey(cert.PrivateKey.Public()) + if err != nil { + logf(logTypeHandshake, "[ClientStateWaitFinished] Unable to marshal raw public key [%v]", err) + return nil, nil, AlertInternalError + } + + certList = []CertificateEntry{{CertData: certData}} } + + certificate := &CertificateBody{CertificateList: certList} certm, err := state.hsCtx.hOut.HandshakeMessageFromBody(certificate) if err != nil { logf(logTypeHandshake, "[ClientStateWaitFinished] Error marshaling Certificate [%v]", err) @@ -1015,6 +1099,7 @@ func (state clientStateWaitFinished) Next(hr handshakeMessageReader) (HandshakeS toSend = append(toSend, QueueHandshakeMessage{certm}) state.handshakeHash.Write(certm.Marshal()) + // Create and send CertificateVerify hcv := state.handshakeHash.Sum(nil) logf(logTypeHandshake, "Handshake Hash to be verified: [%d] %x", len(hcv), hcv) diff --git a/common.go b/common.go index 4d6477a..eceee1a 100644 --- a/common.go +++ b/common.go @@ -125,6 +125,8 @@ const ( ExtensionTypeCookie ExtensionType = 44 ExtensionTypePSKKeyExchangeModes ExtensionType = 45 ExtensionTypeTicketEarlyDataInfo ExtensionType = 46 + ExtensionTypeClientCertType ExtensionType = 19 + ExtensionTypeServerCertType ExtensionType = 20 ) // enum {...} NamedGroup @@ -164,6 +166,19 @@ const ( KeyUpdateRequested KeyUpdateRequest = 1 ) +/* +0 X.509 Y [RFC6091] +1 OpenPGP_RESERVED N [RFC6091][RFC8446] Used in TLS versions prior to 1.3. +2 Raw Public Key Y [RFC7250] +3 1609Dot2 N +*/ +type CertificateType uint8 + +const ( + CertificateTypeX509 CertificateType = 0 + CertificateTypeRawPublicKey CertificateType = 2 +) + type State uint8 const ( diff --git a/conn.go b/conn.go index ccf8dc1..c8fa1bb 100644 --- a/conn.go +++ b/conn.go @@ -16,6 +16,7 @@ import ( type Certificate struct { Chain []*x509.Certificate PrivateKey crypto.Signer + PublicKey crypto.PublicKey } type PreSharedKey struct { @@ -129,6 +130,12 @@ type Config struct { NonBlocking bool UseDTLS bool + // These bools are arranged in opposite directions so that their default + // values reflect the correct default semantics (certs yes, raw keys no). + // ForbidX509 is only meaningful if AllowRawPublicKeys is true. + AllowRawPublicKeys bool + ForbidX509 bool + RecordLayer RecordLayerFactory // The same config object can be shared among different connections, so it @@ -198,15 +205,25 @@ func (c *Config) Init(isClient bool) error { return nil } +func (c *Config) certTypeValid() bool { + // ForbidX509 can only be set when AllowRawPublicKeys is also set + // Note that this is equivalent to: + // ForbidX509 => AllowRawPublicKeys + return !c.ForbidX509 || c.AllowRawPublicKeys +} + func (c *Config) ValidForServer() bool { - return (reflect.ValueOf(c.PSKs).IsValid() && c.PSKs.Size() > 0) || - (len(c.Certificates) > 0 && - len(c.Certificates[0].Chain) > 0 && - c.Certificates[0].PrivateKey != nil) + // The server must have either PSKs or certificates + havePSK := reflect.ValueOf(c.PSKs).IsValid() && c.PSKs.Size() > 0 + haveCert := len(c.Certificates) > 0 && + len(c.Certificates[0].Chain) > 0 && + c.Certificates[0].PrivateKey != nil + + return (havePSK || haveCert) && c.certTypeValid() } func (c *Config) ValidForClient() bool { - return len(c.ServerName) > 0 + return len(c.ServerName) > 0 && c.certTypeValid() } func (c *Config) time() time.Time { @@ -250,7 +267,7 @@ var ( type ConnectionState struct { HandshakeState State CipherSuite CipherSuiteParams // cipher suite in use (TLS_RSA_WITH_RC4_128_SHA, ...) - PeerCertificates []*x509.Certificate // certificate chain presented by remote peer + PeerCertificates [][]byte // certificate chain presented by remote peer VerifiedChains [][]*x509.Certificate // verified chains built from PeerCertificates NextProto string // Selected ALPN proto UsingPSK bool // Are we using PSK. diff --git a/conn_test.go b/conn_test.go index 5c8314a..cfdd42e 100644 --- a/conn_test.go +++ b/conn_test.go @@ -179,7 +179,7 @@ var ( psk PreSharedKey psks *PSKMapCache - basicConfig, dtlsConfig, nbConfig, nbDTLSConfig, hrrConfig, alpnConfig, pskConfig, pskDTLSConfig, pskECDHEConfig, pskDHEConfig, resumptionConfig, ffdhConfig, x25519Config *Config + basicConfig, dtlsConfig, nbConfig, nbDTLSConfig, hrrConfig, alpnConfig, rawConfig, pskConfig, pskDTLSConfig, pskECDHEConfig, pskDHEConfig, resumptionConfig, ffdhConfig, x25519Config *Config ) func init() { @@ -262,6 +262,13 @@ func init() { InsecureSkipVerify: true, } + rawConfig = &Config{ + ServerName: serverName, + Certificates: certificates, + AllowRawPublicKeys: true, + InsecureSkipVerify: true, + } + pskConfig = &Config{ ServerName: serverName, CipherSuites: []CipherSuite{TLS_AES_128_GCM_SHA256}, @@ -357,11 +364,13 @@ func checkConsistency(t *testing.T, client *Conn, server *Conn) { func testConnInner(t *testing.T, name string, p testInstanceState) { // Configs array: - configs := map[string]*Config{"basic config": basicConfig, - "HRR": hrrConfig, - "ALPN": alpnConfig, - "FFDH": ffdhConfig, - "x25519": x25519Config, + configs := map[string]*Config{ + "basic config": basicConfig, + "HRR": hrrConfig, + "ALPN": alpnConfig, + "RawPK": rawConfig, + "FFDH": ffdhConfig, + "x25519": x25519Config, } c := configs[p["config"]] @@ -400,6 +409,7 @@ func TestBasicFlows(t *testing.T) { "basic config", "HRR", "ALPN", + "RawPK", "FFDH", "x25519", }, @@ -1232,9 +1242,9 @@ func TestConnectionState(t *testing.T) { serverCS := server.ConnectionState() assertEquals(t, clientCS.CipherSuite.Suite, configClient.CipherSuites[0]) assertDeepEquals(t, clientCS.VerifiedChains, [][]*x509.Certificate{{serverCert}}) - assertDeepEquals(t, clientCS.PeerCertificates, []*x509.Certificate{serverCert}) + assertDeepEquals(t, clientCS.PeerCertificates, [][]byte{serverCert.Raw}) assertEquals(t, serverCS.CipherSuite.Suite, serverConfig.CipherSuites[0]) - assertDeepEquals(t, serverCS.PeerCertificates, []*x509.Certificate{clientCert}) + assertDeepEquals(t, serverCS.PeerCertificates, [][]byte{clientCert.Raw}) } func TestDTLS(t *testing.T) { diff --git a/crypto.go b/crypto.go index 4fa59a2..4e9140a 100644 --- a/crypto.go +++ b/crypto.go @@ -328,6 +328,31 @@ func newSigningKey(sig SignatureScheme) (crypto.Signer, error) { } } +func marshalSigningKey(pub crypto.PublicKey) ([]byte, error) { + switch pub.(type) { + case *rsa.PublicKey: + return x509.MarshalPKIXPublicKey(pub) + + case *ecdsa.PublicKey: + return x509.MarshalPKIXPublicKey(pub) + + // TODO support Ed25519 + + default: + return nil, fmt.Errorf("tls.marshalsigningkey: Unsupported public key type") + } +} + +func unmarshalSigningKey(data []byte) (interface{}, error) { + pub, err := x509.ParsePKIXPublicKey(data) + if err == nil { + return pub, err + } + + // TODO attempt to parse Ed25519 + return nil, err +} + // XXX(rlb): Copied from crypto/x509 type ecdsaSignature struct { R, S *big.Int diff --git a/extensions.go b/extensions.go index 07cb16c..629a4ba 100644 --- a/extensions.go +++ b/extensions.go @@ -624,3 +624,99 @@ func (c CookieExtension) Marshal() ([]byte, error) { func (c *CookieExtension) Unmarshal(data []byte) (int, error) { return syntax.Unmarshal(data, c) } + +// struct { +// select(ClientOrServerExtension) { +// case client: +// CertificateType client_certificate_types<1..2^8-1>; +// case server: +// CertificateType client_certificate_type; +// } +// } ClientCertTypeExtension; +type ClientCertTypeExtension struct { + HandshakeType HandshakeType `tls:"none"` + CertificateTypes []CertificateType `tls:"head=1,min=1"` +} + +func (cct ClientCertTypeExtension) Type() ExtensionType { + return ExtensionTypeClientCertType +} + +func (cct ClientCertTypeExtension) Marshal() ([]byte, error) { + switch cct.HandshakeType { + case HandshakeTypeClientHello: + return syntax.Marshal(cct) + + case HandshakeTypeEncryptedExtensions: + return syntax.Marshal(uint8(cct.CertificateTypes[0])) + } + + return nil, fmt.Errorf("tls.certificate_types: Handshake type not allowed") +} + +func (cct *ClientCertTypeExtension) Unmarshal(data []byte) (int, error) { + switch cct.HandshakeType { + case HandshakeTypeClientHello: + return syntax.Unmarshal(data, cct) + + case HandshakeTypeEncryptedExtensions: + var val uint8 + n, err := syntax.Unmarshal(data, &val) + if err != nil { + return 0, err + } + + cct.CertificateTypes = []CertificateType{CertificateType(val)} + return n, err + } + + return 0, fmt.Errorf("tls.certificate_types: Handshake type not allowed") +} + +// struct { +// select(ClientOrServerExtension) { +// case client: +// CertificateType client_certificate_types<1..2^8-1>; +// case server: +// CertificateType client_certificate_type; +// } +// } ServerCertTypeExtension; +type ServerCertTypeExtension struct { + HandshakeType HandshakeType `tls:"none"` + CertificateTypes []CertificateType `tls:"head=1,min=1"` +} + +func (sct ServerCertTypeExtension) Type() ExtensionType { + return ExtensionTypeServerCertType +} + +func (sct ServerCertTypeExtension) Marshal() ([]byte, error) { + switch sct.HandshakeType { + case HandshakeTypeClientHello: + return syntax.Marshal(sct) + + case HandshakeTypeEncryptedExtensions: + return syntax.Marshal(uint8(sct.CertificateTypes[0])) + } + + return nil, fmt.Errorf("tls.certificate_types: Handshake type not allowed") +} + +func (sct *ServerCertTypeExtension) Unmarshal(data []byte) (int, error) { + switch sct.HandshakeType { + case HandshakeTypeClientHello: + return syntax.Unmarshal(data, sct) + + case HandshakeTypeEncryptedExtensions: + var val uint8 + n, err := syntax.Unmarshal(data, &val) + if err != nil { + return 0, err + } + + sct.CertificateTypes = []CertificateType{CertificateType(val)} + return n, err + } + + return 0, fmt.Errorf("tls.certificate_types: Handshake type not allowed") +} diff --git a/handshake-messages.go b/handshake-messages.go index e8e3aef..a96c6d9 100644 --- a/handshake-messages.go +++ b/handshake-messages.go @@ -3,7 +3,6 @@ package mint import ( "bytes" "crypto" - "crypto/x509" "encoding/binary" "fmt" @@ -269,23 +268,13 @@ func (ee *EncryptedExtensionsBody) Unmarshal(data []byte) (int, error) { // CertificateEntry certificate_list<0..2^24-1>; // } Certificate; type CertificateEntry struct { - CertData *x509.Certificate - Extensions ExtensionList -} - -type CertificateBody struct { - CertificateRequestContext []byte - CertificateList []CertificateEntry -} - -type certificateEntryInner struct { CertData []byte `tls:"head=3,min=1"` Extensions ExtensionList `tls:"head=2"` } -type certificateBodyInner struct { - CertificateRequestContext []byte `tls:"head=1"` - CertificateList []certificateEntryInner `tls:"head=3"` +type CertificateBody struct { + CertificateRequestContext []byte `tls:"head=1"` + CertificateList []CertificateEntry `tls:"head=3"` } func (c CertificateBody) Type() HandshakeType { @@ -293,41 +282,11 @@ func (c CertificateBody) Type() HandshakeType { } func (c CertificateBody) Marshal() ([]byte, error) { - inner := certificateBodyInner{ - CertificateRequestContext: c.CertificateRequestContext, - CertificateList: make([]certificateEntryInner, len(c.CertificateList)), - } - - for i, entry := range c.CertificateList { - inner.CertificateList[i] = certificateEntryInner{ - CertData: entry.CertData.Raw, - Extensions: entry.Extensions, - } - } - - return syntax.Marshal(inner) + return syntax.Marshal(c) } func (c *CertificateBody) Unmarshal(data []byte) (int, error) { - inner := certificateBodyInner{} - read, err := syntax.Unmarshal(data, &inner) - if err != nil { - return read, err - } - - c.CertificateRequestContext = inner.CertificateRequestContext - c.CertificateList = make([]CertificateEntry, len(inner.CertificateList)) - - for i, entry := range inner.CertificateList { - c.CertificateList[i].CertData, err = x509.ParseCertificate(entry.CertData) - if err != nil { - return 0, fmt.Errorf("tls:certificate: Certificate failed to parse: %v", err) - } - - c.CertificateList[i].Extensions = entry.Extensions - } - - return read, nil + return syntax.Unmarshal(data, c) } // struct { diff --git a/handshake-messages_test.go b/handshake-messages_test.go index 0f387d6..22cd98c 100644 --- a/handshake-messages_test.go +++ b/handshake-messages_test.go @@ -153,11 +153,11 @@ var ( CertificateRequestContext: []byte{0, 0, 0, 0}, CertificateList: []CertificateEntry{ { - CertData: cert1, + CertData: cert1.Raw, Extensions: extListValidIn, }, { - CertData: cert2, + CertData: cert2.Raw, Extensions: extListValidIn, }, }, @@ -166,7 +166,7 @@ var ( CertificateRequestContext: []byte{0, 0, 0, 0}, CertificateList: []CertificateEntry{ { - CertData: cert1, + CertData: cert1.Raw, Extensions: extListSingleTooLongIn, }, }, @@ -462,11 +462,11 @@ func TestCertificateMarshalUnmarshal(t *testing.T) { certValidIn.CertificateRequestContext = originalContext // Test marshal failure on no raw certa - originalRaw := cert1.Raw - cert1.Raw = []byte{} + originalCert := certValidIn.CertificateList[0].CertData + certValidIn.CertificateList[0].CertData = nil out, err = certValidIn.Marshal() assertError(t, err, "Marshaled a Certificate with an empty cert") - cert1.Raw = originalRaw + certValidIn.CertificateList[0].CertData = originalCert // Test marshal failure on extension list marshal failure out, err = certOverflowIn.Marshal() @@ -500,12 +500,6 @@ func TestCertificateMarshalUnmarshal(t *testing.T) { _, err = cert.Unmarshal(certValid) assertError(t, err, "Unmarshaled a Certificate with truncated certificates") certValid[8] ^= 0xFF - - // Test unmarshal failure on malformed certificate - certValid[11] ^= 0xFF // Clobber first octet of first cert - _, err = cert.Unmarshal(certValid) - assertError(t, err, "Unmarshaled a Certificate with truncated certificates") - certValid[11] ^= 0xFF } func TestCertificateVerifyMarshalUnmarshal(t *testing.T) { diff --git a/negotiation.go b/negotiation.go index 7463953..3a131f5 100644 --- a/negotiation.go +++ b/negotiation.go @@ -216,3 +216,25 @@ func ALPNNegotiation(psk *PreSharedKey, offered, supported []string) (string, er } return "", err } + +func CertificateTypeValid(certType CertificateType, allowRawPublicKey, forbidX509 bool) bool { + switch certType { + case CertificateTypeRawPublicKey: + return allowRawPublicKey + case CertificateTypeX509: + return !forbidX509 + default: + return false + } +} + +func CertificateTypeNegotiation(offered []CertificateType, allowRawPublicKey, forbidX509 bool) *CertificateType { + for _, t := range offered { + if CertificateTypeValid(t, allowRawPublicKey, forbidX509) { + out := new(CertificateType) + *out = t + return out + } + } + return nil +} diff --git a/server-state-machine.go b/server-state-machine.go index e392a5d..4e17493 100644 --- a/server-state-machine.go +++ b/server-state-machine.go @@ -2,6 +2,7 @@ package mint import ( "bytes" + "crypto" "crypto/x509" "fmt" "hash" @@ -122,6 +123,8 @@ func (state serverStateStart) Next(hr handshakeMessageReader) (HandshakeState, [ clientALPN := new(ALPNExtension) clientPSKModes := new(PSKKeyExchangeModesExtension) clientCookie := new(CookieExtension) + clientClientCertType := &ClientCertTypeExtension{HandshakeType: HandshakeTypeClientHello} + clientServerCertType := &ServerCertTypeExtension{HandshakeType: HandshakeTypeClientHello} // Handle external extensions. if state.Config.ExtensionHandler != nil { @@ -144,6 +147,8 @@ func (state serverStateStart) Next(hr handshakeMessageReader) (HandshakeState, [ clientALPN, clientPSKModes, clientCookie, + clientClientCertType, + clientServerCertType, }) if err != nil { @@ -331,6 +336,8 @@ func (state serverStateStart) Next(hr handshakeMessageReader) (HandshakeState, [ var pskSecret []byte var cert *Certificate var certScheme SignatureScheme + var clientCertType *CertificateType + var serverCertType *CertificateType if connParams.UsingPSK { pskSecret = psk.Key } else { @@ -344,6 +351,35 @@ func (state serverStateStart) Next(hr handshakeMessageReader) (HandshakeState, [ return nil, nil, AlertMissingExtension } + // Negotiate certificate types for the client and server + if foundExts[ExtensionTypeClientCertType] && state.Config.RequireClientAuth { + clientCertType = CertificateTypeNegotiation(clientClientCertType.CertificateTypes, + state.Config.AllowRawPublicKeys, state.Config.ForbidX509) + + if clientCertType == nil { + logf(logTypeHandshake, "[ServerStateStart] Client certificate type negotiation failure") + return nil, nil, AlertHandshakeFailure + } + + connParams.ClientCertType = *clientCertType + } + + if foundExts[ExtensionTypeServerCertType] { + serverCertType = CertificateTypeNegotiation(clientServerCertType.CertificateTypes, + state.Config.AllowRawPublicKeys, state.Config.ForbidX509) + + if serverCertType == nil { + logf(logTypeHandshake, "[ServerStateStart] Client certificate type negotiation failure") + return nil, nil, AlertHandshakeFailure + } + + connParams.ServerCertType = *serverCertType + } + + // Since we have already failed in cases where `nil` is invalid, from here + // on out, clientCertType == nil or serverCertType == nil indicate whether + // those extensions are to be sent. + // Select a certificate name := string(*serverName) var err error @@ -396,6 +432,8 @@ func (state serverStateStart) Next(hr handshakeMessageReader) (HandshakeState, [ certScheme: certScheme, legacySessionId: ch.LegacySessionID, clientEarlyTrafficSecret: clientEarlyTrafficSecret, + clientCertType: clientCertType, + serverCertType: serverCertType, firstClientHello: firstClientHello, helloRetryRequest: helloRetryRequest, @@ -460,6 +498,8 @@ type serverStateNegotiated struct { firstClientHello *HandshakeMessage helloRetryRequest *HandshakeMessage clientHello *HandshakeMessage + clientCertType *CertificateType + serverCertType *CertificateType } var _ HandshakeState = &serverStateNegotiated{} @@ -583,6 +623,7 @@ func (state serverStateNegotiated) Next(_ handshakeMessageReader) (HandshakeStat return nil, nil, AlertInternalError } } + if state.Params.UsingEarlyData { logf(logTypeHandshake, "[server] sending EDI extension") err = eeList.Add(&EarlyDataExtension{}) @@ -591,6 +632,31 @@ func (state serverStateNegotiated) Next(_ handshakeMessageReader) (HandshakeStat return nil, nil, AlertInternalError } } + + if state.serverCertType != nil { + logf(logTypeHandshake, "[server] sending ServerCertType extension") + err = eeList.Add(&ServerCertTypeExtension{ + HandshakeType: HandshakeTypeEncryptedExtensions, + CertificateTypes: []CertificateType{*state.serverCertType}, + }) + if err != nil { + logf(logTypeHandshake, "[ServerStateNegotiated] Error adding SCT to EncryptedExtensions [%v]", err) + return nil, nil, AlertInternalError + } + } + + if state.Config.RequireClientAuth && state.clientCertType != nil { + logf(logTypeHandshake, "[server] sending ClientCertType extension") + err = eeList.Add(&ClientCertTypeExtension{ + HandshakeType: HandshakeTypeEncryptedExtensions, + CertificateTypes: []CertificateType{*state.clientCertType}, + }) + if err != nil { + logf(logTypeHandshake, "[ServerStateNegotiated] Error adding CCT to EncryptedExtensions [%v]", err) + return nil, nil, AlertInternalError + } + } + ee := &EncryptedExtensionsBody{eeList} // Run the external extension handler. @@ -643,13 +709,26 @@ func (state serverStateNegotiated) Next(_ handshakeMessageReader) (HandshakeStat handshakeHash.Write(crm.Marshal()) } - // Create and send Certificate, CertificateVerify - certificate := &CertificateBody{ - CertificateList: make([]CertificateEntry, len(state.cert.Chain)), - } - for i, entry := range state.cert.Chain { - certificate.CertificateList[i] = CertificateEntry{CertData: entry} + // Create and send Certificate + var certList []CertificateEntry + switch state.Params.ServerCertType { + case CertificateTypeX509: + certList = make([]CertificateEntry, len(state.cert.Chain)) + for i, entry := range state.cert.Chain { + certList[i] = CertificateEntry{CertData: entry.Raw} + } + + case CertificateTypeRawPublicKey: + certData, err := marshalSigningKey(state.cert.PrivateKey.Public()) + if err != nil { + logf(logTypeHandshake, "[ServerStateNegotiated] Unable to marshal raw public key [%v]", err) + return nil, nil, AlertInternalError + } + + certList = []CertificateEntry{{CertData: certData}} } + + certificate := &CertificateBody{CertificateList: certList} certm, err := state.hsCtx.hOut.HandshakeMessageFromBody(certificate) if err != nil { logf(logTypeHandshake, "[ServerStateNegotiated] Error marshaling Certificate [%v]", err) @@ -659,6 +738,7 @@ func (state serverStateNegotiated) Next(_ handshakeMessageReader) (HandshakeStat toSend = append(toSend, QueueHandshakeMessage{certm}) handshakeHash.Write(certm.Marshal()) + // Create and send CertificateVerify certificateVerify := &CertificateVerifyBody{Algorithm: state.certScheme} logf(logTypeHandshake, "Creating CertVerify: %04x %v", state.certScheme, params.Hash) @@ -1045,22 +1125,37 @@ func (state serverStateWaitCV) Next(hr handshakeMessageReader) (HandshakeState, certVerify := &CertificateVerifyBody{} if err := safeUnmarshal(certVerify, hm.body); err != nil { - logf(logTypeHandshake, "[ServerStateWaitCert] Error decoding message %v", err) + logf(logTypeHandshake, "[ServerStateWaitCV] Error decoding message %v", err) return nil, nil, AlertDecodeError } - rawCerts := make([][]byte, len(state.clientCertificate.CertificateList)) - certs := make([]*x509.Certificate, len(state.clientCertificate.CertificateList)) + certs := make([][]byte, len(state.clientCertificate.CertificateList)) for i, certEntry := range state.clientCertificate.CertificateList { certs[i] = certEntry.CertData - rawCerts[i] = certEntry.CertData.Raw } // Verify client signature over handshake hash + var err error + var clientPublicKey crypto.PublicKey + switch state.Params.ClientCertType { + case CertificateTypeX509: + cert, err := x509.ParseCertificate(certs[0]) + if err != nil { + logf(logTypeHandshake, "[ServerStateWaitCV] Unable to parse client cert: %v", err) + } + + clientPublicKey = cert.PublicKey + + case CertificateTypeRawPublicKey: + clientPublicKey, err = unmarshalSigningKey(certs[0]) + if err != nil { + logf(logTypeHandshake, "[ServerStateWaitCV] Unable to parse raw public key: %v", err) + } + } + hcv := state.handshakeHash.Sum(nil) logf(logTypeHandshake, "Handshake Hash to be verified: [%d] %x", len(hcv), hcv) - clientPublicKey := state.clientCertificate.CertificateList[0].CertData.PublicKey if err := certVerify.Verify(clientPublicKey, hcv); err != nil { logf(logTypeHandshake, "[ServerStateWaitCV] Failure in client auth verification [%v]", err) return nil, nil, AlertHandshakeFailure @@ -1068,7 +1163,7 @@ func (state serverStateWaitCV) Next(hr handshakeMessageReader) (HandshakeState, if state.Config.VerifyPeerCertificate != nil { // TODO(#171): pass in the verified chains, once we support different client auth types - if err := state.Config.VerifyPeerCertificate(rawCerts, nil); err != nil { + if err := state.Config.VerifyPeerCertificate(certs, nil); err != nil { logf(logTypeHandshake, "[ServerStateWaitCV] Application rejected client certificate: %s", err) return nil, nil, AlertBadCertificate } @@ -1101,7 +1196,7 @@ type serverStateWaitFinished struct { masterSecret []byte clientHandshakeTrafficSecret []byte - peerCertificates []*x509.Certificate + peerCertificates [][]byte verifiedChains [][]*x509.Certificate handshakeHash hash.Hash diff --git a/state-machine.go b/state-machine.go index f0635c8..b5001be 100644 --- a/state-machine.go +++ b/state-machine.go @@ -62,6 +62,9 @@ type ConnectionParameters struct { RejectedEarlyData bool UsingClientAuth bool + ClientCertType CertificateType + ServerCertType CertificateType + CipherSuite CipherSuite ServerName string NextProto string @@ -97,7 +100,7 @@ type stateConnected struct { clientTrafficSecret []byte serverTrafficSecret []byte exporterSecret []byte - peerCertificates []*x509.Certificate + peerCertificates [][]byte verifiedChains [][]*x509.Certificate }