@@ -2069,6 +2069,7 @@ void RecordThreadState::add_out_param(uint32_t slot, uint32_t vtype) {
2069
2069
struct DisabledThreadState : ThreadState {
2070
2070
ThreadState *m_internal;
2071
2071
JitBackend m_recording_backend;
2072
+ bool m_raised = false ;
2072
2073
DisabledThreadState (ThreadState *internal,
2073
2074
JitBackend recording_backend)
2074
2075
: m_internal(internal), m_recording_backend(recording_backend){
@@ -2096,76 +2097,89 @@ struct DisabledThreadState : ThreadState {
2096
2097
this ->scope = internal->scope ;
2097
2098
}
2098
2099
2099
- void raise_exception () {
2100
- const char *backend =
2101
- m_internal->backend == JitBackend::CUDA ? " CUDA" : " LLVM" ;
2102
- const char *recording_backend =
2103
- m_recording_backend == JitBackend::CUDA ? " CUDA" : " LLVM" ;
2104
- jitc_raise (
2105
- " The frozen function is being recorded for the %s backend, but "
2106
- " you tried to execute an operation for the %s backend, this is "
2107
- " not permitted. It might indicate that you specified the wrong "
2108
- " backend or the wrong backend was inferred from the inputs." ,
2109
- recording_backend, backend);
2100
+ /* *
2101
+ * Record that an exception has been thrown, similar to the function in
2102
+ * ``RecordThreadState``.
2103
+ */
2104
+ void record_exception () {
2105
+ m_raised = true ;
2110
2106
};
2111
2107
2112
- void barrier () override { raise_exception (); };
2108
+ /* *
2109
+ * Actually throws the exception, if any was thrown during recording.
2110
+ */
2111
+ void rethrow_exception () {
2112
+ if (m_raised) {
2113
+ const char *backend =
2114
+ m_internal->backend == JitBackend::CUDA ? " CUDA" : " LLVM" ;
2115
+ const char *recording_backend =
2116
+ m_recording_backend == JitBackend::CUDA ? " CUDA" : " LLVM" ;
2117
+ jitc_raise (
2118
+ " The frozen function is being recorded for the %s backend, but "
2119
+ " you tried to execute an operation for the %s backend, this is "
2120
+ " not permitted. It might indicate that you specified the wrong "
2121
+ " backend or the wrong backend was inferred from the inputs." ,
2122
+ recording_backend, backend);
2123
+ }
2124
+ }
2125
+
2126
+ void barrier () override { record_exception (); };
2113
2127
Task *launch (Kernel /* kernel*/ , KernelKey * /* key*/ , XXH128_hash_t /* hash*/ ,
2114
2128
uint32_t /* size*/ , std::vector<void *> * /* kernel_params*/ ,
2115
2129
const std::vector<uint32_t > * /* kernel_param_ids*/ ) override {
2116
- raise_exception ();
2130
+ record_exception ();
2117
2131
return nullptr ;
2118
2132
};
2119
2133
void memset_async (void * /* ptr*/ , uint32_t /* size*/ , uint32_t /* isize*/ ,
2120
2134
const void * /* src*/ ) override {
2121
- raise_exception ();
2135
+ record_exception ();
2122
2136
};
2123
2137
uint32_t compress (const uint8_t * /* in*/ , uint32_t /* size*/ ,
2124
2138
uint32_t * /* out*/ ) override {
2125
- raise_exception ();
2139
+ record_exception ();
2126
2140
return 0 ;
2127
2141
};
2128
2142
uint32_t mkperm (const uint32_t * /* values*/ , uint32_t /* size*/ ,
2129
2143
uint32_t /* bucket_count*/ , uint32_t * /* perm*/ ,
2130
2144
uint32_t * /* offsets*/ ) override {
2131
- raise_exception ();
2145
+ record_exception ();
2132
2146
return 0 ;
2133
2147
};
2134
2148
void memcpy (void * /* dst*/ , const void * /* src*/ ,
2135
2149
size_t /* size*/ ) override {
2136
- raise_exception ();
2150
+ record_exception ();
2137
2151
};
2138
2152
void memcpy_async (void * /* dst*/ , const void * /* src*/ ,
2139
2153
size_t /* size*/ ) override {
2140
- raise_exception ();
2154
+ record_exception ();
2141
2155
};
2142
2156
void block_reduce (VarType /* vt*/ , ReduceOp /* op*/ , uint32_t /* size*/ ,
2143
2157
uint32_t /* block_size*/ , const void * /* in*/ ,
2144
2158
void * /* out*/ ) override {
2145
- raise_exception ();
2159
+ record_exception ();
2146
2160
};
2147
2161
void block_prefix_reduce (VarType /* vt*/ , ReduceOp /* op*/ , uint32_t /* size*/ ,
2148
2162
uint32_t /* block_size*/ , bool /* exclusive*/ ,
2149
2163
bool /* reverse*/ , const void * /* in*/ ,
2150
2164
void * /* out*/ ) override {
2151
- raise_exception ();
2165
+ record_exception ();
2152
2166
};
2153
2167
void reduce_dot (VarType /* type*/ , const void * /* ptr_1*/ ,
2154
2168
const void * /* ptr_2*/ , uint32_t /* size*/ ,
2155
2169
void * /* out*/ ) override {
2156
- raise_exception ();
2170
+ record_exception ();
2157
2171
};
2158
2172
void poke (void * /* dst*/ , const void * /* src*/ ,
2159
2173
uint32_t /* size*/ ) override {
2160
- raise_exception ();
2174
+ record_exception ();
2161
2175
};
2162
2176
void aggregate (void * /* dst*/ , AggregationEntry * /* agg*/ ,
2163
2177
uint32_t /* size*/ ) override {
2164
- raise_exception ();
2178
+ record_exception ();
2165
2179
};
2166
2180
void enqueue_host_func (void (* /* callback*/ )(void *),
2167
2181
void * /* payload*/ ) override {
2168
- raise_exception ();
2182
+ record_exception ();
2169
2183
};
2170
2184
void notify_expand (uint32_t /* index*/ ) override {};
2171
2185
void reduce_expanded (VarType /* vt*/ , ReduceOp /* reduce_op*/ ,
@@ -2183,10 +2197,11 @@ void disable_thread_state(ThreadState **ts, JitBackend recording_backend) {
2183
2197
void enable_thread_state (ThreadState **ts) {
2184
2198
if (!*ts)
2185
2199
return ;
2186
- if (DisabledThreadState *dst = dynamic_cast <DisabledThreadState *>(*ts);
2187
- dst != nullptr ) {
2188
- *ts = dst->m_internal ;
2189
- delete dst;
2200
+ if (DisabledThreadState *dts = dynamic_cast <DisabledThreadState *>(*ts);
2201
+ dts != nullptr ) {
2202
+ *ts = dts->m_internal ;
2203
+ dts->rethrow_exception ();
2204
+ delete dts;
2190
2205
}else {
2191
2206
jitc_fail (" Tried to enable a ThreadState that was not disabled." );
2192
2207
}
0 commit comments