Skip to content

Commit 0c7bcb0

Browse files
Deferred exception handling for disabled thread state
1 parent 2a5431b commit 0c7bcb0

File tree

1 file changed

+43
-28
lines changed

1 file changed

+43
-28
lines changed

src/record_ts.cpp

+43-28
Original file line numberDiff line numberDiff line change
@@ -2069,6 +2069,7 @@ void RecordThreadState::add_out_param(uint32_t slot, uint32_t vtype) {
20692069
struct DisabledThreadState : ThreadState {
20702070
ThreadState *m_internal;
20712071
JitBackend m_recording_backend;
2072+
bool m_raised = false;
20722073
DisabledThreadState(ThreadState *internal,
20732074
JitBackend recording_backend)
20742075
: m_internal(internal), m_recording_backend(recording_backend){
@@ -2096,76 +2097,89 @@ struct DisabledThreadState : ThreadState {
20962097
this->scope = internal->scope;
20972098
}
20982099

2099-
void raise_exception() {
2100-
const char *backend =
2101-
m_internal->backend == JitBackend::CUDA ? "CUDA" : "LLVM";
2102-
const char *recording_backend =
2103-
m_recording_backend == JitBackend::CUDA ? "CUDA" : "LLVM";
2104-
jitc_raise(
2105-
"The frozen function is being recorded for the %s backend, but "
2106-
"you tried to execute an operation for the %s backend, this is "
2107-
"not permitted. It might indicate that you specified the wrong "
2108-
"backend or the wrong backend was inferred from the inputs.",
2109-
recording_backend, backend);
2100+
/**
2101+
* Record that an exception has been thrown, similar to the function in
2102+
* ``RecordThreadState``.
2103+
*/
2104+
void record_exception() {
2105+
m_raised = true;
21102106
};
21112107

2112-
void barrier() override { raise_exception(); };
2108+
/**
2109+
* Actually throws the exception, if any was thrown during recording.
2110+
*/
2111+
void rethrow_exception() {
2112+
if (m_raised) {
2113+
const char *backend =
2114+
m_internal->backend == JitBackend::CUDA ? "CUDA" : "LLVM";
2115+
const char *recording_backend =
2116+
m_recording_backend == JitBackend::CUDA ? "CUDA" : "LLVM";
2117+
jitc_raise(
2118+
"The frozen function is being recorded for the %s backend, but "
2119+
"you tried to execute an operation for the %s backend, this is "
2120+
"not permitted. It might indicate that you specified the wrong "
2121+
"backend or the wrong backend was inferred from the inputs.",
2122+
recording_backend, backend);
2123+
}
2124+
}
2125+
2126+
void barrier() override { record_exception(); };
21132127
Task *launch(Kernel /*kernel*/, KernelKey * /*key*/, XXH128_hash_t /*hash*/,
21142128
uint32_t /*size*/, std::vector<void *> * /*kernel_params*/,
21152129
const std::vector<uint32_t> * /*kernel_param_ids*/) override {
2116-
raise_exception();
2130+
record_exception();
21172131
return nullptr;
21182132
};
21192133
void memset_async(void * /*ptr*/, uint32_t /*size*/, uint32_t /*isize*/,
21202134
const void * /*src*/) override {
2121-
raise_exception();
2135+
record_exception();
21222136
};
21232137
uint32_t compress(const uint8_t * /*in*/, uint32_t /*size*/,
21242138
uint32_t * /*out*/) override {
2125-
raise_exception();
2139+
record_exception();
21262140
return 0;
21272141
};
21282142
uint32_t mkperm(const uint32_t * /*values*/, uint32_t /*size*/,
21292143
uint32_t /*bucket_count*/, uint32_t * /*perm*/,
21302144
uint32_t * /*offsets*/) override {
2131-
raise_exception();
2145+
record_exception();
21322146
return 0;
21332147
};
21342148
void memcpy(void * /*dst*/, const void * /*src*/,
21352149
size_t /*size*/) override {
2136-
raise_exception();
2150+
record_exception();
21372151
};
21382152
void memcpy_async(void * /*dst*/, const void * /*src*/,
21392153
size_t /*size*/) override {
2140-
raise_exception();
2154+
record_exception();
21412155
};
21422156
void block_reduce(VarType /*vt*/, ReduceOp /*op*/, uint32_t /*size*/,
21432157
uint32_t /*block_size*/, const void * /*in*/,
21442158
void * /*out*/) override {
2145-
raise_exception();
2159+
record_exception();
21462160
};
21472161
void block_prefix_reduce(VarType /*vt*/, ReduceOp /*op*/, uint32_t /*size*/,
21482162
uint32_t /*block_size*/, bool /*exclusive*/,
21492163
bool /*reverse*/, const void * /*in*/,
21502164
void * /*out*/) override {
2151-
raise_exception();
2165+
record_exception();
21522166
};
21532167
void reduce_dot(VarType /*type*/, const void * /*ptr_1*/,
21542168
const void * /*ptr_2*/, uint32_t /*size*/,
21552169
void * /*out*/) override {
2156-
raise_exception();
2170+
record_exception();
21572171
};
21582172
void poke(void * /*dst*/, const void * /*src*/,
21592173
uint32_t /*size*/) override {
2160-
raise_exception();
2174+
record_exception();
21612175
};
21622176
void aggregate(void * /*dst*/, AggregationEntry * /*agg*/,
21632177
uint32_t /*size*/) override {
2164-
raise_exception();
2178+
record_exception();
21652179
};
21662180
void enqueue_host_func(void (* /*callback*/)(void *),
21672181
void * /*payload*/) override {
2168-
raise_exception();
2182+
record_exception();
21692183
};
21702184
void notify_expand(uint32_t /*index*/) override {};
21712185
void reduce_expanded(VarType /*vt*/, ReduceOp /*reduce_op*/,
@@ -2183,10 +2197,11 @@ void disable_thread_state(ThreadState **ts, JitBackend recording_backend) {
21832197
void enable_thread_state(ThreadState **ts) {
21842198
if (!*ts)
21852199
return;
2186-
if (DisabledThreadState *dst = dynamic_cast<DisabledThreadState *>(*ts);
2187-
dst != nullptr) {
2188-
*ts = dst->m_internal;
2189-
delete dst;
2200+
if (DisabledThreadState *dts = dynamic_cast<DisabledThreadState *>(*ts);
2201+
dts != nullptr) {
2202+
*ts = dts->m_internal;
2203+
dts->rethrow_exception();
2204+
delete dts;
21902205
}else{
21912206
jitc_fail("Tried to enable a ThreadState that was not disabled.");
21922207
}

0 commit comments

Comments
 (0)