diff --git a/conn_test.go b/conn_test.go index 4854410..2ad20a5 100644 --- a/conn_test.go +++ b/conn_test.go @@ -19,6 +19,7 @@ import ( ) type pipeConn struct { + label string closed bool r *bytes.Buffer w *bytes.Buffer @@ -30,6 +31,9 @@ func pipe() (client *pipeConn, server *pipeConn) { client = new(pipeConn) server = new(pipeConn) + client.label = "client" + server.label = "server" + c2s := bytes.NewBuffer(nil) server.r = c2s client.w = c2s @@ -64,6 +68,8 @@ func (p *pipeConn) Read(data []byte) (n int, err error) { } func (p *pipeConn) Write(data []byte) (n int, err error) { + logf(logTypePipe, "[%s] write: %d %x\n", p.label, len(data), data) + p.wLock.Lock() defer p.wLock.Unlock() if p.closed { diff --git a/handshake-layer.go b/handshake-layer.go index ae11cb8..8c30991 100644 --- a/handshake-layer.go +++ b/handshake-layer.go @@ -401,22 +401,70 @@ func (h *HandshakeLayer) ReadMessage() (*HandshakeMessage, error) { } func (h *HandshakeLayer) QueueMessage(hm *HandshakeMessage) error { - if h.datagram { - hm.cipher = h.conn.(*DefaultRecordLayer).cipher - h.queued = append(h.queued, hm) - return nil + hm.cipher = h.conn.(*DefaultRecordLayer).cipher + h.queued = append(h.queued, hm) + return nil +} + +func (h *HandshakeLayer) packAndWrite(hms []*HandshakeMessage) (int, error) { + buffer := []byte{} + for _, hm := range hms { + buffer = append(buffer, hm.Marshal()...) + } + + record := &TLSPlaintext{ + contentType: RecordTypeHandshake, + fragment: buffer, + } + cipher := hms[0].cipher + err := h.conn.(*DefaultRecordLayer).writeRecordWithPadding(record, cipher, 0) + if err != nil { + return 0, err } - _, err := h.WriteMessages([]*HandshakeMessage{hm}) - return err + + return len(hms), nil +} + +func (h *HandshakeLayer) packAndWriteQueue() (int, error) { + if len(h.queued) == 0 { + return 0, nil + } + + buffer := []*HandshakeMessage{} + written := 0 + for i, hm := range h.queued { + buffer = append(buffer, hm) + + if i < len(h.queued)-1 && h.queued[i+1].cipher == hm.cipher { + continue + } + + count, err := h.packAndWrite(buffer) + if err != nil { + return 0, err + } + + written += count + buffer = []*HandshakeMessage{} + } + + return written, nil } func (h *HandshakeLayer) SendQueuedMessages() (int, error) { logf(logTypeHandshake, "Sending outgoing messages") - count, err := h.WriteMessages(h.queued) + if !h.datagram { + count, err := h.packAndWriteQueue() + if err != nil { + return count, err + } + h.ClearQueuedMessages() + return count, nil } - return count, err + + return h.WriteMessages(h.queued) } func (h *HandshakeLayer) ClearQueuedMessages() { diff --git a/log.go b/log.go index 2fba90d..a2097a5 100644 --- a/log.go +++ b/log.go @@ -18,6 +18,7 @@ const ( logTypeNegotiation = "negotiation" logTypeIO = "io" logTypeFrameReader = "frame" + logTypePipe = "pipe" logTypeVerbose = "verbose" )