Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add fiber safety to crystal/once #15370

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions src/compiler/crystal/codegen/class_var.cr
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@ class Crystal::CodeGenVisitor
initialized_flag_name = class_var_global_initialized_name(class_var)
initialized_flag = @main_mod.globals[initialized_flag_name]?
unless initialized_flag
initialized_flag = @main_mod.globals.add(@main_llvm_context.int8, initialized_flag_name)
initialized_flag.initializer = @main_llvm_context.int8.const_int(0)
initialized_flag = @main_mod.globals.add(@main_llvm_context.int1, initialized_flag_name)
initialized_flag.initializer = @main_llvm_context.int1.const_int(0)
initialized_flag.linkage = LLVM::Linkage::Internal if @single_module
initialized_flag.thread_local = true if class_var.thread_local?
end
Expand Down Expand Up @@ -61,7 +61,7 @@ class Crystal::CodeGenVisitor
initialized_flag_name = class_var_global_initialized_name(class_var)
initialized_flag = @llvm_mod.globals[initialized_flag_name]?
unless initialized_flag
initialized_flag = @llvm_mod.globals.add(llvm_context.int8, initialized_flag_name)
initialized_flag = @llvm_mod.globals.add(llvm_context.int1, initialized_flag_name)
initialized_flag.thread_local = true if class_var.thread_local?
end
end
Expand Down
4 changes: 2 additions & 2 deletions src/compiler/crystal/codegen/const.cr
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,8 @@ class Crystal::CodeGenVisitor
initialized_flag_name = const.initialized_llvm_name
initialized_flag = @main_mod.globals[initialized_flag_name]?
unless initialized_flag
initialized_flag = @main_mod.globals.add(@main_llvm_context.int8, initialized_flag_name)
initialized_flag.initializer = @main_llvm_context.int8.const_int(0)
initialized_flag = @main_mod.globals.add(@main_llvm_context.int1, initialized_flag_name)
initialized_flag.initializer = @main_llvm_context.int1.const_int(0)
initialized_flag.linkage = LLVM::Linkage::Internal if @single_module
end
initialized_flag
Expand Down
3 changes: 0 additions & 3 deletions src/compiler/crystal/codegen/once.cr
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,6 @@ class Crystal::CodeGenVisitor
end

state = load(once_state_type, once_state_global)
{% if LibLLVM::IS_LT_150 %}
flag = bit_cast(flag, @llvm_context.int1.pointer) # cast Int8* to Bool*
{% end %}
args = [state, flag, initializer]
end

Expand Down
2 changes: 1 addition & 1 deletion src/crystal/main.cr
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ module Crystal
# so we explicitly initialize their class vars, then init crystal/once
Thread.init
Fiber.init
Crystal.once_init
Crystal::Once.init
end

# :nodoc:
Expand Down
210 changes: 110 additions & 100 deletions src/crystal/once.cr
Original file line number Diff line number Diff line change
Expand Up @@ -7,138 +7,148 @@
# with older compiler releases. It is executed only once at the beginning of the
# program and, for the legacy implementation, the result is passed on each call
# to `__crystal_once`.
#
# In multithread mode a mutex is used to avoid race conditions between threads.
#
# On Win32, `Crystal::System::FileDescriptor#@@reader_thread` spawns a new
# thread even without the `preview_mt` flag, and the thread can also reference
# Crystal constants, leading to race conditions, so we always enable the mutex.

{% if compare_versions(Crystal::VERSION, "1.16.0-dev") >= 0 %}
# This implementation uses an enum over the initialization flag pointer for
# each value to find infinite loops and raise an error.

module Crystal
# :nodoc:
enum OnceState : Int8
Processing = -1
Uninitialized = 0
Initialized = 1
require "crystal/pointer_linked_list"
require "crystal/spin_lock"

module Crystal
# :nodoc:
module Once
struct Operation
include PointerLinkedList::Node

getter fiber : Fiber
getter flag : Bool*

def initialize(@flag : Bool*, @fiber : Fiber)
@waiting = PointerLinkedList(Fiber::PointerLinkedListNode).new
end

def add_waiter(node) : Nil
@waiting.push(node)
end

def resume_all : Nil
@waiting.each(&.value.enqueue)
end
end

{% if flag?(:preview_mt) || flag?(:win32) %}
@@once_mutex = uninitialized Mutex
{% end %}
@@spin = uninitialized SpinLock
@@operations = uninitialized PointerLinkedList(Operation)

# :nodoc:
def self.once_init : Nil
{% if flag?(:preview_mt) || flag?(:win32) %}
@@once_mutex = Mutex.new(:reentrant)
{% end %}
def self.init : Nil
@@spin = SpinLock.new
@@operations = PointerLinkedList(Operation).new
end

# :nodoc:
# Using @[NoInline] so LLVM optimizes for the hot path (var already
# initialized).
@[NoInline]
def self.once(flag : OnceState*, initializer : Void*) : Nil
{% if flag?(:preview_mt) || flag?(:win32) %}
@@once_mutex.synchronize { once_exec(flag, initializer) }
{% else %}
once_exec(flag, initializer)
{% end %}
protected def self.exec(flag : Bool*, &)
@@spin.lock

if flag.value
@@spin.unlock
elsif operation = processing?(flag)
check_reentrancy(operation)
wait_initializer(operation)
else
run_initializer(flag) { yield }
end

# safety check, and allows to safely call `Intrinsics.unreachable` in
# `__crystal_once`
unless flag.value.initialized?
System.print_error "BUG: failed to initialize constant or class variable\n"
LibC._exit(1)
return if flag.value

System.print_error "BUG: failed to initialize class variable or constant\n"
LibC._exit(1)
end

private def self.processing?(flag)
@@operations.each do |operation|
return operation if operation.value.flag == flag
end
end

private def self.once_exec(flag : OnceState*, initializer : Void*) : Nil
case flag.value
in .initialized?
return
in .uninitialized?
flag.value = :processing
Proc(Nil).new(initializer, Pointer(Void).null).call
flag.value = :initialized
in .processing?
private def self.check_reentrancy(operation)
if operation.value.fiber == Fiber.current
@@spin.unlock
raise "Recursion while initializing class variables and/or constants"
end
end

private def self.wait_initializer(operation)
waiting = Fiber::PointerLinkedListNode.new(Fiber.current)
operation.value.add_waiter(pointerof(waiting))
@@spin.unlock
Fiber.suspend
end

private def self.run_initializer(flag, &)
operation = Operation.new(flag, Fiber.current)
@@operations.push pointerof(operation)
@@spin.unlock

yield

@@spin.lock
flag.value = true
@@operations.delete pointerof(operation)
@@spin.unlock

operation.resume_all
end
end

# :nodoc:
#
# Using `@[AlwaysInline]` allows LLVM to optimize const accesses. Since this
# is a `fun` the function will still appear in the symbol table, though it
# will never be called.
@[AlwaysInline]
fun __crystal_once(flag : Crystal::OnceState*, initializer : Void*) : Nil
return if flag.value.initialized?

Crystal.once(flag, initializer)
# Never inlined to avoid bloating the call site with the slow-path that should
# usually not be taken.
@[NoInline]
def self.once(flag : Bool*, initializer : Void*)
Once.exec(flag, &Proc(Nil).new(initializer, Pointer(Void).null))
end

# tell LLVM that it can optimize away repeated `__crystal_once` calls for
# this global (e.g. repeated access to constant in a single funtion);
# this is truly unreachable otherwise `Crystal.once` would have panicked
Intrinsics.unreachable unless flag.value.initialized?
# :nodoc:
#
# NOTE: should also never be inlined, but that would capture the block, which
# would be a breaking change when we use this method to protect class getter
# and class property macros with lazy initialization (the block may return or
# break).
#
# TODO: consider a compile time flag to enable/disable the capture? returning
# from the block is unexpected behavior: the returned value won't be saved in
# the class variable.
def self.once(flag : Bool*, &)
Once.exec(flag) { yield } unless flag.value
end
{% else %}
# This implementation uses a global array to store the initialization flag
# pointers for each value to find infinite loops and raise an error.

module Crystal
# :nodoc:
class OnceState
@rec = [] of Bool*

@[NoInline]
def once(flag : Bool*, initializer : Void*)
unless flag.value
if @rec.includes?(flag)
raise "Recursion while initializing class variables and/or constants"
end
@rec << flag

Proc(Nil).new(initializer, Pointer(Void).null).call
flag.value = true

@rec.pop
end
end
end

{% if flag?(:preview_mt) || flag?(:win32) %}
@mutex = Mutex.new(:reentrant)

@[NoInline]
def once(flag : Bool*, initializer : Void*)
unless flag.value
@mutex.synchronize do
previous_def
end
end
end
{% end %}
end
{% if compare_versions(Crystal::VERSION, "1.16.0-dev") >= 0 %}
# :nodoc:
#
# We always inline this accessor to optimize for the fast-path (already
# initialized).
@[AlwaysInline]
fun __crystal_once(flag : Bool*, initializer : Void*)
return if flag.value
Crystal.once(flag, initializer)

# :nodoc:
def self.once_init : Nil
end
# tells LLVM to assume that the flag is true, this avoids repeated access to
# the same constant or class variable to check the flag and try to run the
# initializer (only the first access will)
Intrinsics.unreachable unless flag.value
end

{% else %}
# :nodoc:
#
# Unused. Kept for backward compatibility with older compilers.
fun __crystal_once_init : Void*
Crystal::OnceState.new.as(Void*)
Pointer(Void).null
end

# :nodoc:
@[AlwaysInline]
fun __crystal_once(state : Void*, flag : Bool*, initializer : Void*)
return if flag.value
state.as(Crystal::OnceState).once(flag, initializer)
Crystal.once(flag, initializer)
Intrinsics.unreachable unless flag.value
end
{% end %}
6 changes: 3 additions & 3 deletions src/crystal/spin_lock.cr
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@ struct Crystal::SpinLock
private UNLOCKED = 0
private LOCKED = 1

{% if flag?(:preview_mt) %}
{% if flag?(:preview_mt) || flag?(:win32) %}
@m = Atomic(Int32).new(UNLOCKED)
{% end %}

def lock
{% if flag?(:preview_mt) %}
{% if flag?(:preview_mt) || flag?(:win32) %}
while @m.swap(LOCKED, :acquire) == LOCKED
while @m.get(:relaxed) == LOCKED
Intrinsics.pause
Expand All @@ -18,7 +18,7 @@ struct Crystal::SpinLock
end

def unlock
{% if flag?(:preview_mt) %}
{% if flag?(:preview_mt) || flag?(:win32) %}
@m.set(UNLOCKED, :release)
{% end %}
end
Expand Down
2 changes: 1 addition & 1 deletion src/prelude.cr
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,10 @@
# appear in the API docs.

# This list requires ordered statements
require "crystal/once"
require "lib_c"
require "macros"
require "object"
require "crystal/once"
require "comparable"
require "exception"
require "iterable"
Expand Down