Skip to content

Commit

Permalink
Add a precompilation workload
Browse files Browse the repository at this point in the history
Other notable changes:
- Moved `create_profile()` into init.jl so we can use it in `init()` and the
  precompilation workload.
- Added an option to not capture stdin, because during precompilation it's not
  possible to redirect stdin.
  • Loading branch information
JamesWrigley committed Feb 23, 2025
1 parent 9d66d46 commit 47a2230
Show file tree
Hide file tree
Showing 7 changed files with 130 additions and 43 deletions.
4 changes: 4 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,11 @@ Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a"
MbedTLS = "739be429-bea8-5141-9913-cc70e7f3736d"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
REPL = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Sockets = "6462fe0b-24de-5631-8697-dd941f90decc"
SoftGlobalScope = "b85f4697-e234-5449-a836-ec8e2f98b302"
UUIDs = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"
ZMQ = "c2297ded-f4af-51ae-bb23-16f91089e4e1"
Expand All @@ -29,9 +31,11 @@ Logging = "1"
Markdown = "1"
MbedTLS = "0.5,0.6,0.7,1"
Pkg = "1"
PrecompileTools = "1.2.1"
Printf = "1"
REPL = "1"
Random = "1"
Sockets = "1"
SoftGlobalScope = "1"
UUIDs = "1"
ZMQ = "1.3"
Expand Down
11 changes: 9 additions & 2 deletions src/IJulia.jl
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ end
n::Int = 0
capture_stdout::Bool = true
capture_stderr::Bool = !IJULIA_DEBUG
capture_stdin::Bool = true

postexecute_hooks::Vector{Function} = Function[]
preexecute_hooks::Vector{Function} = Function[]
Expand Down Expand Up @@ -168,10 +169,15 @@ function Base.close(kernel::Kernel)
close(kernel.read_stderr[])
wait(kernel.watch_stderr_task[])
end
redirect_stdin(orig_stdin[])
if kernel.capture_stdin
redirect_stdin(orig_stdin[])
end

# Reset the logger so that @log statements work and pop the InlineDisplay
Logging.global_logger(orig_logger[])
if isassigned(orig_logger)
# orig_logger seems to not be set during precompilation
Logging.global_logger(orig_logger[])
end
popdisplay()

# Close all sockets
Expand Down Expand Up @@ -456,5 +462,6 @@ include("execute_request.jl")
include("handlers.jl")
include("heartbeat.jl")
include("inline.jl")
include("precompile.jl")

end # IJulia
46 changes: 35 additions & 11 deletions src/init.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import Random: seed!
import Sockets
import Logging
import Logging: AbstractLogger, ConsoleLogger

Expand Down Expand Up @@ -30,6 +31,35 @@ end
const minirepl = Ref{MiniREPL}()
end

function getports(port_hint, n)
ports = Int[]

for i in 1:n
port, server = Sockets.listenany(Sockets.localhost, port_hint)
close(server)
push!(ports, port)
port_hint = port + 1
end

return ports
end

function create_profile(port_hint=8080; key=uuid4())
ports = getports(port_hint, 5)

Dict(
"transport" => "tcp",
"ip" => "127.0.0.1",
"control_port" => ports[1],
"shell_port" => ports[2],
"stdin_port" => ports[3],
"hb_port" => ports[4],
"iopub_port" => ports[5],
"signature_scheme" => "hmac-sha256",
"key" => key
)
end

"""
init(args, kernel)
Expand All @@ -48,16 +78,7 @@ function init(args, kernel, profile=nothing)
else
# generate profile and save
let port0 = 5678
merge!(kernel.profile, Dict{String,Any}(
"ip" => "127.0.0.1",
"transport" => "tcp",
"stdin_port" => port0,
"control_port" => port0+1,
"hb_port" => port0+2,
"shell_port" => port0+3,
"iopub_port" => port0+4,
"key" => uuid4()
))
merge!(kernel.profile, create_profile(port0))

Check warning on line 81 in src/init.jl

View check run for this annotation

Codecov / codecov/patch

src/init.jl#L81

Added line #L81 was not covered by tests
fname = "profile-$(getpid()).json"
kernel.connection_file = "$(pwd())/$fname"
println("connect ipython with --existing $(kernel.connection_file)")

Check warning on line 84 in src/init.jl

View check run for this annotation

Codecov / codecov/patch

src/init.jl#L83-L84

Added lines #L83 - L84 were not covered by tests
Expand Down Expand Up @@ -108,7 +129,10 @@ function init(args, kernel, profile=nothing)
kernel.read_stderr[], = redirect_stderr()
redirect_stderr(IJuliaStdio(stderr, kernel, "stderr"))
end
redirect_stdin(IJuliaStdio(stdin, kernel, "stdin"))
if kernel.capture_stdin
redirect_stdin(IJuliaStdio(stdin, kernel, "stdin"))
end

@static if VERSION < v"1.11"
minirepl[] = MiniREPL(TextDisplay(stdout))
end
Expand Down
10 changes: 10 additions & 0 deletions src/msg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,16 @@ function recv_ipython(socket, kernel)
if signature != hmac(header, parent_header, metadata, content, kernel)
error("Invalid HMAC signature") # What should we do here?
end

# Note: don't remove these lines, they're useful for creating a
# precompilation workload.
# @show idents
# @show signature
# @show header
# @show parent_header
# @show metadata
# @show content

m = Msg(idents, JSON.parse(header), JSON.parse(content), JSON.parse(parent_header), JSON.parse(metadata))
@vprintln("RECEIVED $m")
return m
Expand Down
69 changes: 69 additions & 0 deletions src/precompile.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
import PrecompileTools: @compile_workload

# This key is used by the tests and precompilation workload to keep some
# consistency in the message signatures.
const _TEST_KEY = "a0436f6c-1916-498b-8eb9-e81ab9368e84"

# How to update the precompilation workload:
# 1. Uncomment the `@show` expressions in `recv_ipython()` in msg.jl.
# 2. Copy this workload into tests/kernel.jl and update as desired:
#
# Kernel(profile; capture_stdout=false, capture_stderr=false) do kernel
# jupyter_client(profile) do client
# kernel_info(client)
# execute(client, "42")
# execute(client, "error(42)")
# end
# end
#
# 3. When the above runs it will print out the contents of the received messages
# as strings. You can copy these verbatim into the precompilation workload
# below. Note that if you modify any step of the workload you will need to
# update *all* the messages to ensure they have the right parent
# headers/signatures.
@compile_workload begin
profile = create_profile(45_000; key=_TEST_KEY)

Kernel(profile; capture_stdout=false, capture_stderr=false, capture_stdin=false) do kernel
# Connect as a client to the kernel
requests_socket = ZMQ.Socket(ZMQ.DEALER)
ip = profile["ip"]
port = profile["shell_port"]
ZMQ.connect(requests_socket, "tcp://$(ip):$(port)")

# kernel_info
idents = ["d2bd8e47-b2c9cd130d2967a19f52c1a3"]
signature = "3c4f523a0e8b80e5b3e35756d75f62d12b851e1fd67c609a9119872e911f83d2"
header = "{\"msg_id\": \"d2bd8e47-b2c9cd130d2967a19f52c1a3_3534705_0\", \"msg_type\": \"kernel_info_request\", \"username\": \"james\", \"session\": \"d2bd8e47-b2c9cd130d2967a19f52c1a3\", \"date\": \"2025-02-20T22:29:47.616834Z\", \"version\": \"5.4\"}"
parent_header = "{}"
metadata = "{}"
content = "{}"

ZMQ.send_multipart(requests_socket, [only(idents), "<IDS|MSG>", signature, header, parent_header, metadata, content])
ZMQ.recv_multipart(requests_socket, String)

# Execute `42`
idents = ["d2bd8e47-b2c9cd130d2967a19f52c1a3"]
signature = "758c034ba5efb4fd7fd5a5600f913bc634739bf6a2c1e1d87e88b008706337bc"
header = "{\"msg_id\": \"d2bd8e47-b2c9cd130d2967a19f52c1a3_3534705_1\", \"msg_type\": \"execute_request\", \"username\": \"james\", \"session\": \"d2bd8e47-b2c9cd130d2967a19f52c1a3\", \"date\": \"2025-02-20T22:29:49.835131Z\", \"version\": \"5.4\"}"
parent_header = "{}"
metadata = "{}"
content = "{\"code\": \"42\", \"silent\": false, \"store_history\": true, \"user_expressions\": {}, \"allow_stdin\": true, \"stop_on_error\": true}"

ZMQ.send_multipart(requests_socket, [only(idents), "<IDS|MSG>", signature, header, parent_header, metadata, content])
ZMQ.recv_multipart(requests_socket, String)

# Execute `error(42)`
idents = ["d2bd8e47-b2c9cd130d2967a19f52c1a3"]
signature = "953702763b65d9b0505f34ae0eb195574b9c2c65eebedbfa8476150133649801"
header = "{\"msg_id\": \"d2bd8e47-b2c9cd130d2967a19f52c1a3_3534705_2\", \"msg_type\": \"execute_request\", \"username\": \"james\", \"session\": \"d2bd8e47-b2c9cd130d2967a19f52c1a3\", \"date\": \"2025-02-20T22:29:50.320836Z\", \"version\": \"5.4\"}"
parent_header = "{}"
metadata = "{}"
content = "{\"code\": \"error(42)\", \"silent\": false, \"store_history\": true, \"user_expressions\": {}, \"allow_stdin\": true, \"stop_on_error\": true}"

ZMQ.send_multipart(requests_socket, [only(idents), "<IDS|MSG>", signature, header, parent_header, metadata, content])
ZMQ.recv_multipart(requests_socket, String)

close(requests_socket)
end
end
1 change: 1 addition & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,4 @@ JSON = "682c06a0-de6a-54ab-a142-c8b1cf79cde6"
PythonCall = "6099a3de-0909-46bc-b1f4-468b9a2dfc0d"
Sockets = "6462fe0b-24de-5631-8697-dd941f90decc"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
ZMQ = "c2297ded-f4af-51ae-bb23-16f91089e4e1"
32 changes: 2 additions & 30 deletions test/kernel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ using Test
import Sockets
import Sockets: listenany

import ZMQ
import PythonCall
import PythonCall: Py, pyimport, pyconvert, pytype, pystr

Expand All @@ -31,35 +32,6 @@ import IJulia: Kernel
import IJulia: ans, In, Out


function getports(port_hint, n)
ports = Int[]

for i in 1:n
port, server = listenany(Sockets.localhost, port_hint)
close(server)
push!(ports, port)
port_hint = port + 1
end

return ports
end

function create_profile(port_hint=8080)
ports = getports(port_hint, 5)

Dict(
"transport" => "tcp",
"ip" => "127.0.0.1",
"control_port" => ports[1],
"shell_port" => ports[2],
"stdin_port" => ports[3],
"hb_port" => ports[4],
"iopub_port" => ports[5],
"signature_scheme" => "hmac-sha256",
"key" => "a0436f6c-1916-498b-8eb9-e81ab9368e84"
)
end

function test_py_get!(get_func, result)
try
result[] = get_func(timeout=0)
Expand Down Expand Up @@ -156,7 +128,7 @@ function jupyter_client(f, profile)
end

@testset "Kernel" begin
profile = create_profile()
profile = IJulia.create_profile(; key=IJulia._TEST_KEY)
profile_kwargs = Dict([Symbol(key) => value for (key, value) in profile])
profile_kwargs[:key] = pystr(profile_kwargs[:key]).encode()

Expand Down

0 comments on commit 47a2230

Please sign in to comment.