diff --git a/src/ssl.jl b/src/ssl.jl index 3f94309..a0a9a0f 100644 --- a/src/ssl.jl +++ b/src/ssl.jl @@ -35,13 +35,22 @@ bio_clear_flags(bio::BIO) = bio_set_flags(bio, 0x00) function on_bio_stream_read(bio::BIO, out::Ptr{Cchar}, outlen::Cint) try bio_clear_flags(bio) - io = bio_get_data(bio)::TCPSocket - n = bytesavailable(io) - if n == 0 - bio_set_read_retry(bio) - return Cint(0) + io = bio_get_data(bio) + if io isa TCPSocket + n = bytesavailable(io) + if n == 0 + bio_set_read_retry(bio) + return Cint(0) + end + unsafe_read(io, out, min(UInt(n), outlen)) + else + n = bytesavailable(io) + if n == 0 + bio_set_read_retry(bio) + return Cint(0) + end + unsafe_read(io, out, min(UInt(n), outlen)) end - unsafe_read(io, out, min(UInt(n), outlen)) return Cint(min(n, outlen)) catch e # we don't want to throw a Julia exception from a C callback @@ -51,8 +60,13 @@ end function on_bio_stream_write(bio::BIO, in::Ptr{Cchar}, inlen::Cint)::Cint try - io = bio_get_data(bio)::TCPSocket - written = unsafe_write(io, in, inlen) + bio_clear_flags(bio) + io = bio_get_data(bio) + if io isa TCPSocket + written = unsafe_write(io, in, inlen) + else + written = unsafe_write(io, in, inlen) + end return Cint(written) catch e # we don't want to throw a Julia exception from a C callback @@ -373,12 +387,12 @@ end """ SSLStream. """ -mutable struct SSLStream <: IO +mutable struct SSLStream{T} <: IO ssl::SSL ssl_context::SSLContext rbio::BIO wbio::BIO - io::TCPSocket + io::T # used in `eof` where we want the call to `eof` on the underlying # socket and the SSL_peek call that processes bytes to be seen # as one "operation" @@ -395,16 +409,16 @@ mutable struct SSLStream <: IO peekbytes::Base.RefValue{Csize_t} closed::Bool - function SSLStream(ssl_context::SSLContext, io::TCPSocket) + function SSLStream(ssl_context::SSLContext, io::T) where {T <: IO} # Create a read and write BIOs. bio_read::BIO = BIO(io; finalize=false) bio_write::BIO = BIO(io; finalize=false) ssl = SSL(ssl_context, bio_read, bio_write) - return new(ssl, ssl_context, bio_read, bio_write, io, ReentrantLock(), ReentrantLock(), Ref{Csize_t}(0), Ref{Csize_t}(0), Ref{UInt8}(0x00), Ref{Csize_t}(0), false) + return new{T}(ssl, ssl_context, bio_read, bio_write, io, ReentrantLock(), ReentrantLock(), Ref{Csize_t}(0), Ref{Csize_t}(0), Ref{UInt8}(0x00), Ref{Csize_t}(0), false) end end -SSLStream(tcp::TCPSocket) = SSLStream(SSLContext(OpenSSL.TLSClientMethod()), tcp) +SSLStream(tcp::IO) = SSLStream(SSLContext(OpenSSL.TLSClientMethod()), tcp) # backwards compat Base.getproperty(ssl::SSLStream, nm::Symbol) = nm === :bio_read_stream ? ssl : getfield(ssl, nm)