Skip to content

Commit 3dbd62b

Browse files
Implemented DisabledThreadState to detect frozen function backend missmatches
1 parent aa64245 commit 3dbd62b

File tree

1 file changed

+218
-65
lines changed

1 file changed

+218
-65
lines changed

src/record_ts.cpp

+218-65
Original file line numberDiff line numberDiff line change
@@ -1858,71 +1858,6 @@ bool Recording::check_kernel_cache() {
18581858
return true;
18591859
}
18601860

1861-
void jitc_freeze_start(JitBackend backend, const uint32_t *inputs,
1862-
uint32_t n_inputs) {
1863-
1864-
if (jitc_flags() & (uint32_t) JitFlag::FreezingScope)
1865-
jitc_fail("Tried to record a thread_state while inside another "
1866-
"FreezingScope!");
1867-
1868-
// Increment scope, can be used to track missing inputs
1869-
jitc_new_scope(backend);
1870-
1871-
ThreadState *ts_ = thread_state(backend);
1872-
RecordThreadState *record_ts = new RecordThreadState(ts_);
1873-
1874-
if (backend == JitBackend::CUDA)
1875-
thread_state_cuda = record_ts;
1876-
else
1877-
thread_state_llvm = record_ts;
1878-
1879-
for (uint32_t i = 0; i < n_inputs; ++i)
1880-
record_ts->add_input(inputs[i]);
1881-
1882-
jitc_set_flag(JitFlag::FreezingScope, true);
1883-
}
1884-
Recording *jitc_freeze_stop(JitBackend backend, const uint32_t *outputs,
1885-
uint32_t n_outputs) {
1886-
if (RecordThreadState *rts =
1887-
dynamic_cast<RecordThreadState *>(thread_state(backend));
1888-
rts != nullptr) {
1889-
ThreadState *internal = rts->m_internal;
1890-
1891-
// Perform reassignments to internal thread-state of possibly changed
1892-
// variables
1893-
internal->scope = rts->scope;
1894-
1895-
jitc_assert(rts->record_stack.empty(),
1896-
"Kernel recording ended while still recording loop!");
1897-
1898-
jitc_set_flag(JitFlag::FreezingScope, false);
1899-
if (rts->m_exception) {
1900-
std::rethrow_exception(rts->m_exception);
1901-
}
1902-
1903-
for (uint32_t i = 0; i < n_outputs; ++i) {
1904-
rts->add_output(outputs[i]);
1905-
}
1906-
1907-
if (backend == JitBackend::CUDA) {
1908-
thread_state_cuda = internal;
1909-
} else {
1910-
thread_state_llvm = internal;
1911-
}
1912-
Recording *recording = new Recording(std::move(rts->m_recording));
1913-
recording->validate();
1914-
delete rts;
1915-
1916-
return recording;
1917-
} else {
1918-
jitc_fail(
1919-
"jit_record_stop(): Tried to stop recording a thread state "
1920-
"for backend %u, while no recording was started for this backend. "
1921-
"Try to start the recording with jit_record_start.",
1922-
(uint32_t) backend);
1923-
}
1924-
}
1925-
19261861
/**
19271862
* \brief
19281863
* This captures the offset buffer of a vcall in a kernel.
@@ -2126,6 +2061,222 @@ void RecordThreadState::add_out_param(uint32_t slot, uint32_t vtype) {
21262061
add_out_param(slot, (VarType) vtype);
21272062
}
21282063

2064+
/**
2065+
* \brief A simple wrapper around a ThreadState, that disables all its
2066+
* operations. This is used to prevent operations to one ThreadState being
2067+
* executed while a different thread state is recorded.
2068+
*/
2069+
struct DisabledThreadState : ThreadState {
2070+
ThreadState *m_internal;
2071+
JitBackend m_recording_backend;
2072+
bool m_raised = false;
2073+
DisabledThreadState(ThreadState *internal,
2074+
JitBackend recording_backend)
2075+
: m_internal(internal), m_recording_backend(recording_backend){
2076+
this->context = internal->context;
2077+
this->stream = internal->stream;
2078+
this->event = internal->event;
2079+
this->sync_stream_event = internal->sync_stream_event;
2080+
this->device = internal->device;
2081+
this->compute_capability = internal->compute_capability;
2082+
this->ptx_version = internal->ptx_version;
2083+
this->memory_pool = internal->memory_pool;
2084+
2085+
this->backend = internal->backend;
2086+
this->scope = internal->scope;
2087+
this->call_self_value = internal->call_self_value;
2088+
this->call_self_index = internal->call_self_index;
2089+
2090+
#if defined(DRJIT_ENABLE_OPTIX)
2091+
this->optix_pipeline = internal->optix_pipeline;
2092+
this->optix_sbt = internal->optix_sbt;
2093+
#endif
2094+
2095+
this->m_internal = internal;
2096+
2097+
this->scope = internal->scope;
2098+
}
2099+
2100+
/**
2101+
* Record that an exception has been thrown, similar to the function in
2102+
* ``RecordThreadState``.
2103+
*/
2104+
void record_exception() {
2105+
m_raised = true;
2106+
};
2107+
2108+
/**
2109+
* Actually throws the exception, if any was thrown during recording.
2110+
*/
2111+
void rethrow_exception() {
2112+
if (m_raised) {
2113+
const char *backend =
2114+
m_internal->backend == JitBackend::CUDA ? "CUDA" : "LLVM";
2115+
const char *recording_backend =
2116+
m_recording_backend == JitBackend::CUDA ? "CUDA" : "LLVM";
2117+
jitc_raise(
2118+
"The frozen function is being recorded for the %s backend, but "
2119+
"you tried to execute an operation for the %s backend, this is "
2120+
"not permitted. It might indicate that you specified the wrong "
2121+
"backend or the wrong backend was inferred from the inputs.",
2122+
recording_backend, backend);
2123+
}
2124+
}
2125+
2126+
void barrier() override { record_exception(); };
2127+
Task *launch(Kernel /*kernel*/, KernelKey * /*key*/, XXH128_hash_t /*hash*/,
2128+
uint32_t /*size*/, std::vector<void *> * /*kernel_params*/,
2129+
const std::vector<uint32_t> * /*kernel_param_ids*/) override {
2130+
record_exception();
2131+
return nullptr;
2132+
};
2133+
void memset_async(void * /*ptr*/, uint32_t /*size*/, uint32_t /*isize*/,
2134+
const void * /*src*/) override {
2135+
record_exception();
2136+
};
2137+
uint32_t compress(const uint8_t * /*in*/, uint32_t /*size*/,
2138+
uint32_t * /*out*/) override {
2139+
record_exception();
2140+
return 0;
2141+
};
2142+
uint32_t mkperm(const uint32_t * /*values*/, uint32_t /*size*/,
2143+
uint32_t /*bucket_count*/, uint32_t * /*perm*/,
2144+
uint32_t * /*offsets*/) override {
2145+
record_exception();
2146+
return 0;
2147+
};
2148+
void memcpy(void * /*dst*/, const void * /*src*/,
2149+
size_t /*size*/) override {
2150+
record_exception();
2151+
};
2152+
void memcpy_async(void * /*dst*/, const void * /*src*/,
2153+
size_t /*size*/) override {
2154+
record_exception();
2155+
};
2156+
void block_reduce(VarType /*vt*/, ReduceOp /*op*/, uint32_t /*size*/,
2157+
uint32_t /*block_size*/, const void * /*in*/,
2158+
void * /*out*/) override {
2159+
record_exception();
2160+
};
2161+
void block_prefix_reduce(VarType /*vt*/, ReduceOp /*op*/, uint32_t /*size*/,
2162+
uint32_t /*block_size*/, bool /*exclusive*/,
2163+
bool /*reverse*/, const void * /*in*/,
2164+
void * /*out*/) override {
2165+
record_exception();
2166+
};
2167+
void reduce_dot(VarType /*type*/, const void * /*ptr_1*/,
2168+
const void * /*ptr_2*/, uint32_t /*size*/,
2169+
void * /*out*/) override {
2170+
record_exception();
2171+
};
2172+
void poke(void * /*dst*/, const void * /*src*/,
2173+
uint32_t /*size*/) override {
2174+
record_exception();
2175+
};
2176+
void aggregate(void * /*dst*/, AggregationEntry * /*agg*/,
2177+
uint32_t /*size*/) override {
2178+
record_exception();
2179+
};
2180+
void enqueue_host_func(void (* /*callback*/)(void *),
2181+
void * /*payload*/) override {
2182+
record_exception();
2183+
};
2184+
void notify_expand(uint32_t /*index*/) override {};
2185+
void reduce_expanded(VarType /*vt*/, ReduceOp /*reduce_op*/,
2186+
void * /*data*/, uint32_t /*exp*/,
2187+
uint32_t /*size*/) override {};
2188+
void notify_free(const void * /*ptr*/) override {};
2189+
};
2190+
2191+
void set_disabled_thread_state(ThreadState **ts, JitBackend recording_backend) {
2192+
if (!*ts)
2193+
return;
2194+
*ts = new DisabledThreadState(*ts, recording_backend);
2195+
}
2196+
2197+
void unset_disabled_thread_state(ThreadState **ts) {
2198+
if (!*ts)
2199+
return;
2200+
if (DisabledThreadState *dts = dynamic_cast<DisabledThreadState *>(*ts);
2201+
dts != nullptr) {
2202+
*ts = dts->m_internal;
2203+
dts->rethrow_exception();
2204+
delete dts;
2205+
}else{
2206+
jitc_fail("Tried to enable a ThreadState that was not disabled.");
2207+
}
2208+
}
2209+
2210+
void jitc_freeze_start(JitBackend backend, const uint32_t *inputs,
2211+
uint32_t n_inputs) {
2212+
2213+
if (jitc_flags() & (uint32_t) JitFlag::FreezingScope)
2214+
jitc_fail("Tried to record a thread_state while inside another "
2215+
"FreezingScope!");
2216+
2217+
// Increment scope, can be used to track missing inputs
2218+
jitc_new_scope(backend);
2219+
2220+
ThreadState *ts_ = thread_state(backend);
2221+
RecordThreadState *record_ts = new RecordThreadState(ts_);
2222+
2223+
if (backend == JitBackend::CUDA) {
2224+
thread_state_cuda = record_ts;
2225+
set_disabled_thread_state(&thread_state_llvm, backend);
2226+
} else {
2227+
thread_state_llvm = record_ts;
2228+
set_disabled_thread_state(&thread_state_cuda, backend);
2229+
}
2230+
2231+
for (uint32_t i = 0; i < n_inputs; ++i)
2232+
record_ts->add_input(inputs[i]);
2233+
2234+
jitc_set_flag(JitFlag::FreezingScope, true);
2235+
}
2236+
Recording *jitc_freeze_stop(JitBackend backend, const uint32_t *outputs,
2237+
uint32_t n_outputs) {
2238+
if (RecordThreadState *rts =
2239+
dynamic_cast<RecordThreadState *>(thread_state(backend));
2240+
rts != nullptr) {
2241+
ThreadState *internal = rts->m_internal;
2242+
2243+
// Perform reassignments to internal thread-state of possibly changed
2244+
// variables
2245+
internal->scope = rts->scope;
2246+
2247+
jitc_assert(rts->record_stack.empty(),
2248+
"Kernel recording ended while still recording loop!");
2249+
2250+
jitc_set_flag(JitFlag::FreezingScope, false);
2251+
if (rts->m_exception) {
2252+
std::rethrow_exception(rts->m_exception);
2253+
}
2254+
2255+
for (uint32_t i = 0; i < n_outputs; ++i) {
2256+
rts->add_output(outputs[i]);
2257+
}
2258+
2259+
if (backend == JitBackend::CUDA) {
2260+
thread_state_cuda = internal;
2261+
unset_disabled_thread_state(&thread_state_llvm);
2262+
} else {
2263+
thread_state_llvm = internal;
2264+
unset_disabled_thread_state(&thread_state_cuda);
2265+
}
2266+
Recording *recording = new Recording(std::move(rts->m_recording));
2267+
recording->validate();
2268+
delete rts;
2269+
2270+
return recording;
2271+
} else {
2272+
jitc_fail(
2273+
"jit_record_stop(): Tried to stop recording a thread state "
2274+
"for backend %u, while no recording was started for this backend. "
2275+
"Try to start the recording with jit_record_start.",
2276+
(uint32_t) backend);
2277+
}
2278+
}
2279+
21292280
void jitc_freeze_abort(JitBackend backend) {
21302281
if (RecordThreadState *rts =
21312282
dynamic_cast<RecordThreadState *>(thread_state(backend));
@@ -2139,8 +2290,10 @@ void jitc_freeze_abort(JitBackend backend) {
21392290

21402291
if (backend == JitBackend::CUDA) {
21412292
thread_state_cuda = internal;
2293+
unset_disabled_thread_state(&thread_state_llvm);
21422294
} else {
21432295
thread_state_llvm = internal;
2296+
unset_disabled_thread_state(&thread_state_cuda);
21442297
}
21452298

21462299
delete rts;

0 commit comments

Comments
 (0)