Skip to content

Commit 2a5431b

Browse files
Added method to detect recording frozen functions with the wrong backend
1 parent 04698e3 commit 2a5431b

File tree

1 file changed

+203
-65
lines changed

1 file changed

+203
-65
lines changed

src/record_ts.cpp

+203-65
Original file line numberDiff line numberDiff line change
@@ -1858,71 +1858,6 @@ bool Recording::check_kernel_cache() {
18581858
return true;
18591859
}
18601860

1861-
void jitc_freeze_start(JitBackend backend, const uint32_t *inputs,
1862-
uint32_t n_inputs) {
1863-
1864-
if (jitc_flags() & (uint32_t) JitFlag::FreezingScope)
1865-
jitc_fail("Tried to record a thread_state while inside another "
1866-
"FreezingScope!");
1867-
1868-
// Increment scope, can be used to track missing inputs
1869-
jitc_new_scope(backend);
1870-
1871-
ThreadState *ts_ = thread_state(backend);
1872-
RecordThreadState *record_ts = new RecordThreadState(ts_);
1873-
1874-
if (backend == JitBackend::CUDA)
1875-
thread_state_cuda = record_ts;
1876-
else
1877-
thread_state_llvm = record_ts;
1878-
1879-
for (uint32_t i = 0; i < n_inputs; ++i)
1880-
record_ts->add_input(inputs[i]);
1881-
1882-
jitc_set_flag(JitFlag::FreezingScope, true);
1883-
}
1884-
Recording *jitc_freeze_stop(JitBackend backend, const uint32_t *outputs,
1885-
uint32_t n_outputs) {
1886-
if (RecordThreadState *rts =
1887-
dynamic_cast<RecordThreadState *>(thread_state(backend));
1888-
rts != nullptr) {
1889-
ThreadState *internal = rts->m_internal;
1890-
1891-
// Perform reassignments to internal thread-state of possibly changed
1892-
// variables
1893-
internal->scope = rts->scope;
1894-
1895-
jitc_assert(rts->record_stack.empty(),
1896-
"Kernel recording ended while still recording loop!");
1897-
1898-
jitc_set_flag(JitFlag::FreezingScope, false);
1899-
if (rts->m_exception) {
1900-
std::rethrow_exception(rts->m_exception);
1901-
}
1902-
1903-
for (uint32_t i = 0; i < n_outputs; ++i) {
1904-
rts->add_output(outputs[i]);
1905-
}
1906-
1907-
if (backend == JitBackend::CUDA) {
1908-
thread_state_cuda = internal;
1909-
} else {
1910-
thread_state_llvm = internal;
1911-
}
1912-
Recording *recording = new Recording(std::move(rts->m_recording));
1913-
recording->validate();
1914-
delete rts;
1915-
1916-
return recording;
1917-
} else {
1918-
jitc_fail(
1919-
"jit_record_stop(): Tried to stop recording a thread state "
1920-
"for backend %u, while no recording was started for this backend. "
1921-
"Try to start the recording with jit_record_start.",
1922-
(uint32_t) backend);
1923-
}
1924-
}
1925-
19261861
/**
19271862
* \brief
19281863
* This captures the offset buffer of a vcall in a kernel.
@@ -2126,6 +2061,207 @@ void RecordThreadState::add_out_param(uint32_t slot, uint32_t vtype) {
21262061
add_out_param(slot, (VarType) vtype);
21272062
}
21282063

2064+
/**
2065+
* \brief A simple wrapper around a ThreadState, that disables all its
2066+
* operations. This is used to prevent operations to one ThreadState being
2067+
* executed while a different thread state is recorded.
2068+
*/
2069+
struct DisabledThreadState : ThreadState {
2070+
ThreadState *m_internal;
2071+
JitBackend m_recording_backend;
2072+
DisabledThreadState(ThreadState *internal,
2073+
JitBackend recording_backend)
2074+
: m_internal(internal), m_recording_backend(recording_backend){
2075+
this->context = internal->context;
2076+
this->stream = internal->stream;
2077+
this->event = internal->event;
2078+
this->sync_stream_event = internal->sync_stream_event;
2079+
this->device = internal->device;
2080+
this->compute_capability = internal->compute_capability;
2081+
this->ptx_version = internal->ptx_version;
2082+
this->memory_pool = internal->memory_pool;
2083+
2084+
this->backend = internal->backend;
2085+
this->scope = internal->scope;
2086+
this->call_self_value = internal->call_self_value;
2087+
this->call_self_index = internal->call_self_index;
2088+
2089+
#if defined(DRJIT_ENABLE_OPTIX)
2090+
this->optix_pipeline = internal->optix_pipeline;
2091+
this->optix_sbt = internal->optix_sbt;
2092+
#endif
2093+
2094+
this->m_internal = internal;
2095+
2096+
this->scope = internal->scope;
2097+
}
2098+
2099+
void raise_exception() {
2100+
const char *backend =
2101+
m_internal->backend == JitBackend::CUDA ? "CUDA" : "LLVM";
2102+
const char *recording_backend =
2103+
m_recording_backend == JitBackend::CUDA ? "CUDA" : "LLVM";
2104+
jitc_raise(
2105+
"The frozen function is being recorded for the %s backend, but "
2106+
"you tried to execute an operation for the %s backend, this is "
2107+
"not permitted. It might indicate that you specified the wrong "
2108+
"backend or the wrong backend was inferred from the inputs.",
2109+
recording_backend, backend);
2110+
};
2111+
2112+
void barrier() override { raise_exception(); };
2113+
Task *launch(Kernel /*kernel*/, KernelKey * /*key*/, XXH128_hash_t /*hash*/,
2114+
uint32_t /*size*/, std::vector<void *> * /*kernel_params*/,
2115+
const std::vector<uint32_t> * /*kernel_param_ids*/) override {
2116+
raise_exception();
2117+
return nullptr;
2118+
};
2119+
void memset_async(void * /*ptr*/, uint32_t /*size*/, uint32_t /*isize*/,
2120+
const void * /*src*/) override {
2121+
raise_exception();
2122+
};
2123+
uint32_t compress(const uint8_t * /*in*/, uint32_t /*size*/,
2124+
uint32_t * /*out*/) override {
2125+
raise_exception();
2126+
return 0;
2127+
};
2128+
uint32_t mkperm(const uint32_t * /*values*/, uint32_t /*size*/,
2129+
uint32_t /*bucket_count*/, uint32_t * /*perm*/,
2130+
uint32_t * /*offsets*/) override {
2131+
raise_exception();
2132+
return 0;
2133+
};
2134+
void memcpy(void * /*dst*/, const void * /*src*/,
2135+
size_t /*size*/) override {
2136+
raise_exception();
2137+
};
2138+
void memcpy_async(void * /*dst*/, const void * /*src*/,
2139+
size_t /*size*/) override {
2140+
raise_exception();
2141+
};
2142+
void block_reduce(VarType /*vt*/, ReduceOp /*op*/, uint32_t /*size*/,
2143+
uint32_t /*block_size*/, const void * /*in*/,
2144+
void * /*out*/) override {
2145+
raise_exception();
2146+
};
2147+
void block_prefix_reduce(VarType /*vt*/, ReduceOp /*op*/, uint32_t /*size*/,
2148+
uint32_t /*block_size*/, bool /*exclusive*/,
2149+
bool /*reverse*/, const void * /*in*/,
2150+
void * /*out*/) override {
2151+
raise_exception();
2152+
};
2153+
void reduce_dot(VarType /*type*/, const void * /*ptr_1*/,
2154+
const void * /*ptr_2*/, uint32_t /*size*/,
2155+
void * /*out*/) override {
2156+
raise_exception();
2157+
};
2158+
void poke(void * /*dst*/, const void * /*src*/,
2159+
uint32_t /*size*/) override {
2160+
raise_exception();
2161+
};
2162+
void aggregate(void * /*dst*/, AggregationEntry * /*agg*/,
2163+
uint32_t /*size*/) override {
2164+
raise_exception();
2165+
};
2166+
void enqueue_host_func(void (* /*callback*/)(void *),
2167+
void * /*payload*/) override {
2168+
raise_exception();
2169+
};
2170+
void notify_expand(uint32_t /*index*/) override {};
2171+
void reduce_expanded(VarType /*vt*/, ReduceOp /*reduce_op*/,
2172+
void * /*data*/, uint32_t /*exp*/,
2173+
uint32_t /*size*/) override {};
2174+
void notify_free(const void * /*ptr*/) override {};
2175+
};
2176+
2177+
void disable_thread_state(ThreadState **ts, JitBackend recording_backend) {
2178+
if (!*ts)
2179+
return;
2180+
*ts = new DisabledThreadState(*ts, recording_backend);
2181+
}
2182+
2183+
void enable_thread_state(ThreadState **ts) {
2184+
if (!*ts)
2185+
return;
2186+
if (DisabledThreadState *dst = dynamic_cast<DisabledThreadState *>(*ts);
2187+
dst != nullptr) {
2188+
*ts = dst->m_internal;
2189+
delete dst;
2190+
}else{
2191+
jitc_fail("Tried to enable a ThreadState that was not disabled.");
2192+
}
2193+
}
2194+
2195+
void jitc_freeze_start(JitBackend backend, const uint32_t *inputs,
2196+
uint32_t n_inputs) {
2197+
2198+
if (jitc_flags() & (uint32_t) JitFlag::FreezingScope)
2199+
jitc_fail("Tried to record a thread_state while inside another "
2200+
"FreezingScope!");
2201+
2202+
// Increment scope, can be used to track missing inputs
2203+
jitc_new_scope(backend);
2204+
2205+
ThreadState *ts_ = thread_state(backend);
2206+
RecordThreadState *record_ts = new RecordThreadState(ts_);
2207+
2208+
if (backend == JitBackend::CUDA) {
2209+
thread_state_cuda = record_ts;
2210+
disable_thread_state(&thread_state_llvm, backend);
2211+
} else {
2212+
thread_state_llvm = record_ts;
2213+
disable_thread_state(&thread_state_cuda, backend);
2214+
}
2215+
2216+
for (uint32_t i = 0; i < n_inputs; ++i)
2217+
record_ts->add_input(inputs[i]);
2218+
2219+
jitc_set_flag(JitFlag::FreezingScope, true);
2220+
}
2221+
Recording *jitc_freeze_stop(JitBackend backend, const uint32_t *outputs,
2222+
uint32_t n_outputs) {
2223+
if (RecordThreadState *rts =
2224+
dynamic_cast<RecordThreadState *>(thread_state(backend));
2225+
rts != nullptr) {
2226+
ThreadState *internal = rts->m_internal;
2227+
2228+
// Perform reassignments to internal thread-state of possibly changed
2229+
// variables
2230+
internal->scope = rts->scope;
2231+
2232+
jitc_assert(rts->record_stack.empty(),
2233+
"Kernel recording ended while still recording loop!");
2234+
2235+
jitc_set_flag(JitFlag::FreezingScope, false);
2236+
if (rts->m_exception) {
2237+
std::rethrow_exception(rts->m_exception);
2238+
}
2239+
2240+
for (uint32_t i = 0; i < n_outputs; ++i) {
2241+
rts->add_output(outputs[i]);
2242+
}
2243+
2244+
if (backend == JitBackend::CUDA) {
2245+
thread_state_cuda = internal;
2246+
enable_thread_state(&thread_state_llvm);
2247+
} else {
2248+
thread_state_llvm = internal;
2249+
enable_thread_state(&thread_state_cuda);
2250+
}
2251+
Recording *recording = new Recording(std::move(rts->m_recording));
2252+
recording->validate();
2253+
delete rts;
2254+
2255+
return recording;
2256+
} else {
2257+
jitc_fail(
2258+
"jit_record_stop(): Tried to stop recording a thread state "
2259+
"for backend %u, while no recording was started for this backend. "
2260+
"Try to start the recording with jit_record_start.",
2261+
(uint32_t) backend);
2262+
}
2263+
}
2264+
21292265
void jitc_freeze_abort(JitBackend backend) {
21302266
if (RecordThreadState *rts =
21312267
dynamic_cast<RecordThreadState *>(thread_state(backend));
@@ -2139,8 +2275,10 @@ void jitc_freeze_abort(JitBackend backend) {
21392275

21402276
if (backend == JitBackend::CUDA) {
21412277
thread_state_cuda = internal;
2278+
enable_thread_state(&thread_state_llvm);
21422279
} else {
21432280
thread_state_llvm = internal;
2281+
enable_thread_state(&thread_state_cuda);
21442282
}
21452283

21462284
delete rts;

0 commit comments

Comments
 (0)