Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rework initialization of constants & class variables #15333

Merged
merged 12 commits into from
Jan 20, 2025
6 changes: 3 additions & 3 deletions src/compiler/crystal/codegen/class_var.cr
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@ class Crystal::CodeGenVisitor
initialized_flag_name = class_var_global_initialized_name(class_var)
initialized_flag = @main_mod.globals[initialized_flag_name]?
unless initialized_flag
initialized_flag = @main_mod.globals.add(@main_llvm_context.int1, initialized_flag_name)
initialized_flag.initializer = @main_llvm_context.int1.const_int(0)
initialized_flag = @main_mod.globals.add(@main_llvm_context.int8, initialized_flag_name)
initialized_flag.initializer = @main_llvm_context.int8.const_int(0)
initialized_flag.linkage = LLVM::Linkage::Internal if @single_module
initialized_flag.thread_local = true if class_var.thread_local?
end
Expand Down Expand Up @@ -61,7 +61,7 @@ class Crystal::CodeGenVisitor
initialized_flag_name = class_var_global_initialized_name(class_var)
initialized_flag = @llvm_mod.globals[initialized_flag_name]?
unless initialized_flag
initialized_flag = @llvm_mod.globals.add(llvm_context.int1, initialized_flag_name)
initialized_flag = @llvm_mod.globals.add(llvm_context.int8, initialized_flag_name)
initialized_flag.thread_local = true if class_var.thread_local?
end
end
Expand Down
4 changes: 2 additions & 2 deletions src/compiler/crystal/codegen/const.cr
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,8 @@ class Crystal::CodeGenVisitor
initialized_flag_name = const.initialized_llvm_name
initialized_flag = @main_mod.globals[initialized_flag_name]?
unless initialized_flag
initialized_flag = @main_mod.globals.add(@main_llvm_context.int1, initialized_flag_name)
initialized_flag.initializer = @main_llvm_context.int1.const_int(0)
initialized_flag = @main_mod.globals.add(@main_llvm_context.int8, initialized_flag_name)
initialized_flag.initializer = @main_llvm_context.int8.const_int(0)
initialized_flag.linkage = LLVM::Linkage::Internal if @single_module
end
initialized_flag
Expand Down
53 changes: 35 additions & 18 deletions src/compiler/crystal/codegen/once.cr
Original file line number Diff line number Diff line change
Expand Up @@ -7,31 +7,48 @@ class Crystal::CodeGenVisitor
if once_init_fun = typed_fun?(@main_mod, ONCE_INIT)
once_init_fun = check_main_fun ONCE_INIT, once_init_fun

once_state_global = @main_mod.globals.add(once_init_fun.type.return_type, ONCE_STATE)
once_state_global.linkage = LLVM::Linkage::Internal if @single_module
once_state_global.initializer = once_init_fun.type.return_type.null

state = call once_init_fun
store state, once_state_global
if once_init_fun.type.return_type.void?
call once_init_fun
else
# legacy (kept for backward compatibility): the compiler must save the
# state returned by __crystal_once_init
once_state_global = @main_mod.globals.add(once_init_fun.type.return_type, ONCE_STATE)
once_state_global.linkage = LLVM::Linkage::Internal if @single_module
once_state_global.initializer = once_init_fun.type.return_type.null

state = call once_init_fun
store state, once_state_global
end
end
end

def run_once(flag, func : LLVMTypedFunction)
once_fun = main_fun(ONCE)
once_init_fun = main_fun(ONCE_INIT)

# both of these should be Void*
once_state_type = once_init_fun.type.return_type
once_initializer_type = once_fun.func.params.last.type
once_fun_params = once_fun.func.params
once_initializer_type = once_fun_params.last.type # must be Void*
initializer = pointer_cast(func.func.to_value, once_initializer_type)

once_state_global = @llvm_mod.globals[ONCE_STATE]? || begin
global = @llvm_mod.globals.add(once_state_type, ONCE_STATE)
global.linkage = LLVM::Linkage::External
global
if once_fun_params.size == 2
args = [flag, initializer]
else
# legacy (kept for backward compatibility): the compiler must pass the
# state returned by __crystal_once_init to __crystal_once as the first
# argument
once_init_fun = main_fun(ONCE_INIT)
once_state_type = once_init_fun.type.return_type # must be Void*

once_state_global = @llvm_mod.globals[ONCE_STATE]? || begin
global = @llvm_mod.globals.add(once_state_type, ONCE_STATE)
global.linkage = LLVM::Linkage::External
global
end

state = load(once_state_type, once_state_global)
# cast Int8* to Bool* (required for LLVM 14 and below)
bool = bit_cast(flag, @llvm_context.int1.pointer)
ysbaddaden marked this conversation as resolved.
Show resolved Hide resolved
args = [state, bool, initializer]
end

state = load(once_state_type, once_state_global)
initializer = pointer_cast(func.func.to_value, once_initializer_type)
call once_fun, [state, flag, initializer]
call once_fun, args
end
end
182 changes: 142 additions & 40 deletions src/crystal/once.cr
Original file line number Diff line number Diff line change
@@ -1,54 +1,156 @@
# This file defines the functions `__crystal_once_init` and `__crystal_once` expected
# by the compiler. `__crystal_once` is called each time a constant or class variable
# has to be initialized and is its responsibility to verify the initializer is executed
# only once. `__crystal_once_init` is executed only once at the beginning of the program
# and the result is passed on each call to `__crystal_once`.

# This implementation uses an array to store the initialization flag pointers for each value
# to find infinite loops and raise an error. In multithread mode a mutex is used to
# avoid race conditions between threads.

# :nodoc:
class Crystal::OnceState
@rec = [] of Bool*

def once(flag : Bool*, initializer : Void*)
unless flag.value
if @rec.includes?(flag)
raise "Recursion while initializing class variables and/or constants"
# This file defines two functions expected by the compiler:
#
# - `__crystal_once_init`: executed only once at the beginning of the program
# and, for the legacy implementation, the result is passed on each call to
# `__crystal_once`.
#
# - `__crystal_once`: called each time a constant or class variable has to be
# initialized and is its responsibility to verify the initializer is executed
# only once and to fail on recursion.

# In multithread mode a mutex is used to avoid race conditions between threads.
#
# On Win32, `Crystal::System::FileDescriptor#@@reader_thread` spawns a new
# thread even without the `preview_mt` flag, and the thread can also reference
# Crystal constants, leading to race conditions, so we always enable the mutex.

module Crystal
# :nodoc:
@[AlwaysInline]
def self.once_unreachable : NoReturn
x = uninitialized NoReturn
x
end
end
ysbaddaden marked this conversation as resolved.
Show resolved Hide resolved

{% if compare_versions(Crystal::VERSION, "1.16.0-dev") >= 0 %}
# This implementation uses an enum over the initialization flag pointer for
# each value to find infinite loops and raise an error.

module Crystal
# :nodoc:
enum OnceState : Int8
Processing = -1
Uninitialized = 0
Initialized = 1
end

Comment on lines +23 to +27
Copy link
Contributor

@BlobCodes BlobCodes Jan 16, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is very nit-picky, but in my v2 branch, I replaced this trinary enum with a Int8 so all comparisons are done with 0 (== Initialized becomes > 0).

Comparing with 0 results in fewer (or smaller) assembly instructions on most CPUs
The only arch I'm really familiar with is risc-v, so here's an example in risc-v:

# enum comparison `return if state == Initialized`
  li t0, 1         # load 1 into t0
  bne t0, a0, init # initialize if needed
  ret              # early return (or normal program flow if inlined)
init:
  # stuff

# comparison `return if state <= 0`
  blez a0, init # initialize if needed
  ret           # early-return (or normal program flow if inlined)
init:
  # stuff

Of course this is micro-optimization, but if this is inlined into every const access, it could be noticable.

Copy link
Contributor Author

@ysbaddaden ysbaddaden Jan 16, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

EDIT: I fixed the examples below... I stupidly used a UInt8 🤦

On x86_64 the only difference is in the jump instruction:

; flag.value == 1
cmpb   $0x1,(%rdi)
jne    39 <foo+0x9>

; flag.value > 0
cmpb   $0x0,(%rdi)
jle    39 <foo+0x9>

Same for ARM32:

; flag.value == 1
ldrb    r0, [r0]
cmp     r0, #1
moveq   pc, lr

; flag.value > 0
ldrb    r0, [r0]
cmp     r0, #1
movge   pc, lr

But AArch64 indeed only needs one instruction instead of two for the equality check: And same for AArch64:

; flag.value == 1
ldrb    w8, [x0]
cmp     w8, #0x1
b.ne    40 <foo+0x10>

; flag.value > 0
ldrb    w8, [x0]
cmp     w8, #0x1
b.lt    40 <*foo+0x10>

NOTE: I'm not fluent in the assembly of each arch. I compiled a tiny program with --cross-compile --target=... then used objdump --disassemble from a crosschain build of binutils to compare the LLVM generated assembly.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

; flag.value > 0
cmpb   $0x0,(%rdi)
je     39 <foo+0x9>

To me that looks like an unsigned comparison. So it only checks x != 0.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I just wrote a commit, but didn't push it (yet): I'm wondering if we'd really benefit from the change in practice 🤔

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I tested again, it's actually a jle!

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@BlobCodes I run my tests again and updated my comment above.

We can probably do better in manually written assembly, but LLVM actually generates the same assembly for all 3 architectures. The only difference stands in the jump instruction 🤷

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's see in a follow up if we can squeeze even more performance with this idea.

{% if flag?(:preview_mt) || flag?(:win32) %}
@@once_mutex = uninitialized Mutex

ysbaddaden marked this conversation as resolved.
Show resolved Hide resolved
# :nodoc:
def self.once_mutex : Mutex
@@once_mutex
end

# :nodoc:
def self.once_mutex=(@@once_mutex : Mutex)
end
@rec << flag
{% end %}

Proc(Nil).new(initializer, Pointer(Void).null).call
flag.value = true
# :nodoc:
# Using @[NoInline] so LLVM optimizes for the hot path (var already
# initialized).
@[NoInline]
def self.once(flag : OnceState*, initializer : Void*) : Nil
{% if flag?(:preview_mt) || flag?(:win32) %}
Crystal.once_mutex.synchronize { once_exec(flag, initializer) }
{% else %}
once_exec(flag, initializer)
{% end %}

@rec.pop
# safety check, and allows to safely call `#once_unreachable` in
# `__crystal_once`
unless flag.value.initialized?
System.print_error "BUG: failed to initialize constant or class variable\n"
LibC._exit(1)
end
end

private def self.once_exec(flag : OnceState*, initializer : Void*) : Nil
case flag.value
in .initialized?
return
in .uninitialized?
flag.value = :processing
Proc(Nil).new(initializer, Pointer(Void).null).call
flag.value = :initialized
in .processing?
raise "Recursion while initializing class variables and/or constants"
end
end
end

# on Win32, `Crystal::System::FileDescriptor#@@reader_thread` spawns a new
# thread even without the `preview_mt` flag, and the thread can also reference
# Crystal constants, leading to race conditions, so we always enable the mutex
# TODO: can this be improved?
{% if flag?(:preview_mt) || flag?(:win32) %}
@mutex = Mutex.new(:reentrant)
# :nodoc:
fun __crystal_once_init : Nil
{% if flag?(:preview_mt) || flag?(:win32) %}
Crystal.once_mutex = Mutex.new(:reentrant)
{% end %}
end

# :nodoc:
#
# Using `@[AlwaysInline]` allows LLVM to optimize const accesses. Since this
# is a `fun` the function will still appear in the symbol table, though it
# will never be called.
@[AlwaysInline]
fun __crystal_once(flag : Crystal::OnceState*, initializer : Void*) : Nil
return if flag.value.initialized?

Crystal.once(flag, initializer)

# tell LLVM that it can optimize away repeated `__crystal_once` calls for
# this global (e.g. repeated access to constant in a single funtion);
# this is truly unreachable otherwise `Crystal.once` would have panicked
Crystal.once_unreachable unless flag.value.initialized?
end
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

praise: This is pretty darn clever 👏 Kudos to @BlobCodes

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I second that: thanks a lot @BlobCodes 🙇

{% else %}
# This implementation uses a global array to store the initialization flag
# pointers for each value to find infinite loops and raise an error.

# :nodoc:
class Crystal::OnceState
@rec = [] of Bool*

@[NoInline]
def once(flag : Bool*, initializer : Void*)
unless flag.value
@mutex.synchronize do
previous_def
if @rec.includes?(flag)
raise "Recursion while initializing class variables and/or constants"
end
@rec << flag

Proc(Nil).new(initializer, Pointer(Void).null).call
flag.value = true

@rec.pop
end
end
{% end %}
end

# :nodoc:
fun __crystal_once_init : Void*
Crystal::OnceState.new.as(Void*)
end
{% if flag?(:preview_mt) || flag?(:win32) %}
@mutex = Mutex.new(:reentrant)

# :nodoc:
fun __crystal_once(state : Void*, flag : Bool*, initializer : Void*)
state.as(Crystal::OnceState).once(flag, initializer)
end
@[NoInline]
def once(flag : Bool*, initializer : Void*)
unless flag.value
@mutex.synchronize do
previous_def
end
end
end
{% end %}
ysbaddaden marked this conversation as resolved.
Show resolved Hide resolved
end

# :nodoc:
fun __crystal_once_init : Void*
Crystal::OnceState.new.as(Void*)
end

# :nodoc:
@[AlwaysInline]
fun __crystal_once(state : Void*, flag : Bool*, initializer : Void*)
return if flag.value
state.as(Crystal::OnceState).once(flag, initializer)
Crystal.once_unreachable unless flag.value
end
{% end %}
Loading