@@ -1858,71 +1858,6 @@ bool Recording::check_kernel_cache() {
1858
1858
return true ;
1859
1859
}
1860
1860
1861
- void jitc_freeze_start (JitBackend backend, const uint32_t *inputs,
1862
- uint32_t n_inputs) {
1863
-
1864
- if (jitc_flags () & (uint32_t ) JitFlag::FreezingScope)
1865
- jitc_fail (" Tried to record a thread_state while inside another "
1866
- " FreezingScope!" );
1867
-
1868
- // Increment scope, can be used to track missing inputs
1869
- jitc_new_scope (backend);
1870
-
1871
- ThreadState *ts_ = thread_state (backend);
1872
- RecordThreadState *record_ts = new RecordThreadState (ts_);
1873
-
1874
- if (backend == JitBackend::CUDA)
1875
- thread_state_cuda = record_ts;
1876
- else
1877
- thread_state_llvm = record_ts;
1878
-
1879
- for (uint32_t i = 0 ; i < n_inputs; ++i)
1880
- record_ts->add_input (inputs[i]);
1881
-
1882
- jitc_set_flag (JitFlag::FreezingScope, true );
1883
- }
1884
- Recording *jitc_freeze_stop (JitBackend backend, const uint32_t *outputs,
1885
- uint32_t n_outputs) {
1886
- if (RecordThreadState *rts =
1887
- dynamic_cast <RecordThreadState *>(thread_state (backend));
1888
- rts != nullptr ) {
1889
- ThreadState *internal = rts->m_internal ;
1890
-
1891
- // Perform reassignments to internal thread-state of possibly changed
1892
- // variables
1893
- internal->scope = rts->scope ;
1894
-
1895
- jitc_assert (rts->record_stack .empty (),
1896
- " Kernel recording ended while still recording loop!" );
1897
-
1898
- jitc_set_flag (JitFlag::FreezingScope, false );
1899
- if (rts->m_exception ) {
1900
- std::rethrow_exception (rts->m_exception );
1901
- }
1902
-
1903
- for (uint32_t i = 0 ; i < n_outputs; ++i) {
1904
- rts->add_output (outputs[i]);
1905
- }
1906
-
1907
- if (backend == JitBackend::CUDA) {
1908
- thread_state_cuda = internal;
1909
- } else {
1910
- thread_state_llvm = internal;
1911
- }
1912
- Recording *recording = new Recording (std::move (rts->m_recording ));
1913
- recording->validate ();
1914
- delete rts;
1915
-
1916
- return recording;
1917
- } else {
1918
- jitc_fail (
1919
- " jit_record_stop(): Tried to stop recording a thread state "
1920
- " for backend %u, while no recording was started for this backend. "
1921
- " Try to start the recording with jit_record_start." ,
1922
- (uint32_t ) backend);
1923
- }
1924
- }
1925
-
1926
1861
/* *
1927
1862
* \brief
1928
1863
* This captures the offset buffer of a vcall in a kernel.
@@ -2126,6 +2061,222 @@ void RecordThreadState::add_out_param(uint32_t slot, uint32_t vtype) {
2126
2061
add_out_param (slot, (VarType) vtype);
2127
2062
}
2128
2063
2064
+ /* *
2065
+ * \brief A simple wrapper around a ThreadState, that disables all its
2066
+ * operations. This is used to prevent operations to one ThreadState being
2067
+ * executed while a different thread state is recorded.
2068
+ */
2069
+ struct DisabledThreadState : ThreadState {
2070
+ ThreadState *m_internal;
2071
+ JitBackend m_recording_backend;
2072
+ bool m_raised = false ;
2073
+ DisabledThreadState (ThreadState *internal,
2074
+ JitBackend recording_backend)
2075
+ : m_internal(internal), m_recording_backend(recording_backend){
2076
+ this ->context = internal->context ;
2077
+ this ->stream = internal->stream ;
2078
+ this ->event = internal->event ;
2079
+ this ->sync_stream_event = internal->sync_stream_event ;
2080
+ this ->device = internal->device ;
2081
+ this ->compute_capability = internal->compute_capability ;
2082
+ this ->ptx_version = internal->ptx_version ;
2083
+ this ->memory_pool = internal->memory_pool ;
2084
+
2085
+ this ->backend = internal->backend ;
2086
+ this ->scope = internal->scope ;
2087
+ this ->call_self_value = internal->call_self_value ;
2088
+ this ->call_self_index = internal->call_self_index ;
2089
+
2090
+ #if defined(DRJIT_ENABLE_OPTIX)
2091
+ this ->optix_pipeline = internal->optix_pipeline ;
2092
+ this ->optix_sbt = internal->optix_sbt ;
2093
+ #endif
2094
+
2095
+ this ->m_internal = internal;
2096
+
2097
+ this ->scope = internal->scope ;
2098
+ }
2099
+
2100
+ /* *
2101
+ * Record that an exception has been thrown, similar to the function in
2102
+ * ``RecordThreadState``.
2103
+ */
2104
+ void record_exception () {
2105
+ m_raised = true ;
2106
+ };
2107
+
2108
+ /* *
2109
+ * Actually throws the exception, if any was thrown during recording.
2110
+ */
2111
+ void rethrow_exception () {
2112
+ if (m_raised) {
2113
+ const char *backend =
2114
+ m_internal->backend == JitBackend::CUDA ? " CUDA" : " LLVM" ;
2115
+ const char *recording_backend =
2116
+ m_recording_backend == JitBackend::CUDA ? " CUDA" : " LLVM" ;
2117
+ jitc_raise (
2118
+ " The frozen function is being recorded for the %s backend, but "
2119
+ " you tried to execute an operation for the %s backend, this is "
2120
+ " not permitted. It might indicate that you specified the wrong "
2121
+ " backend or the wrong backend was inferred from the inputs." ,
2122
+ recording_backend, backend);
2123
+ }
2124
+ }
2125
+
2126
+ void barrier () override { record_exception (); };
2127
+ Task *launch (Kernel /* kernel*/ , KernelKey * /* key*/ , XXH128_hash_t /* hash*/ ,
2128
+ uint32_t /* size*/ , std::vector<void *> * /* kernel_params*/ ,
2129
+ const std::vector<uint32_t > * /* kernel_param_ids*/ ) override {
2130
+ record_exception ();
2131
+ return nullptr ;
2132
+ };
2133
+ void memset_async (void * /* ptr*/ , uint32_t /* size*/ , uint32_t /* isize*/ ,
2134
+ const void * /* src*/ ) override {
2135
+ record_exception ();
2136
+ };
2137
+ uint32_t compress (const uint8_t * /* in*/ , uint32_t /* size*/ ,
2138
+ uint32_t * /* out*/ ) override {
2139
+ record_exception ();
2140
+ return 0 ;
2141
+ };
2142
+ uint32_t mkperm (const uint32_t * /* values*/ , uint32_t /* size*/ ,
2143
+ uint32_t /* bucket_count*/ , uint32_t * /* perm*/ ,
2144
+ uint32_t * /* offsets*/ ) override {
2145
+ record_exception ();
2146
+ return 0 ;
2147
+ };
2148
+ void memcpy (void * /* dst*/ , const void * /* src*/ ,
2149
+ size_t /* size*/ ) override {
2150
+ record_exception ();
2151
+ };
2152
+ void memcpy_async (void * /* dst*/ , const void * /* src*/ ,
2153
+ size_t /* size*/ ) override {
2154
+ record_exception ();
2155
+ };
2156
+ void block_reduce (VarType /* vt*/ , ReduceOp /* op*/ , uint32_t /* size*/ ,
2157
+ uint32_t /* block_size*/ , const void * /* in*/ ,
2158
+ void * /* out*/ ) override {
2159
+ record_exception ();
2160
+ };
2161
+ void block_prefix_reduce (VarType /* vt*/ , ReduceOp /* op*/ , uint32_t /* size*/ ,
2162
+ uint32_t /* block_size*/ , bool /* exclusive*/ ,
2163
+ bool /* reverse*/ , const void * /* in*/ ,
2164
+ void * /* out*/ ) override {
2165
+ record_exception ();
2166
+ };
2167
+ void reduce_dot (VarType /* type*/ , const void * /* ptr_1*/ ,
2168
+ const void * /* ptr_2*/ , uint32_t /* size*/ ,
2169
+ void * /* out*/ ) override {
2170
+ record_exception ();
2171
+ };
2172
+ void poke (void * /* dst*/ , const void * /* src*/ ,
2173
+ uint32_t /* size*/ ) override {
2174
+ record_exception ();
2175
+ };
2176
+ void aggregate (void * /* dst*/ , AggregationEntry * /* agg*/ ,
2177
+ uint32_t /* size*/ ) override {
2178
+ record_exception ();
2179
+ };
2180
+ void enqueue_host_func (void (* /* callback*/ )(void *),
2181
+ void * /* payload*/ ) override {
2182
+ record_exception ();
2183
+ };
2184
+ void notify_expand (uint32_t /* index*/ ) override {};
2185
+ void reduce_expanded (VarType /* vt*/ , ReduceOp /* reduce_op*/ ,
2186
+ void * /* data*/ , uint32_t /* exp*/ ,
2187
+ uint32_t /* size*/ ) override {};
2188
+ void notify_free (const void * /* ptr*/ ) override {};
2189
+ };
2190
+
2191
+ void set_disabled_thread_state (ThreadState **ts, JitBackend recording_backend) {
2192
+ if (!*ts)
2193
+ return ;
2194
+ *ts = new DisabledThreadState (*ts, recording_backend);
2195
+ }
2196
+
2197
+ void unset_disabled_thread_state (ThreadState **ts) {
2198
+ if (!*ts)
2199
+ return ;
2200
+ if (DisabledThreadState *dts = dynamic_cast <DisabledThreadState *>(*ts);
2201
+ dts != nullptr ) {
2202
+ *ts = dts->m_internal ;
2203
+ dts->rethrow_exception ();
2204
+ delete dts;
2205
+ }else {
2206
+ jitc_fail (" Tried to enable a ThreadState that was not disabled." );
2207
+ }
2208
+ }
2209
+
2210
+ void jitc_freeze_start (JitBackend backend, const uint32_t *inputs,
2211
+ uint32_t n_inputs) {
2212
+
2213
+ if (jitc_flags () & (uint32_t ) JitFlag::FreezingScope)
2214
+ jitc_fail (" Tried to record a thread_state while inside another "
2215
+ " FreezingScope!" );
2216
+
2217
+ // Increment scope, can be used to track missing inputs
2218
+ jitc_new_scope (backend);
2219
+
2220
+ ThreadState *ts_ = thread_state (backend);
2221
+ RecordThreadState *record_ts = new RecordThreadState (ts_);
2222
+
2223
+ if (backend == JitBackend::CUDA) {
2224
+ thread_state_cuda = record_ts;
2225
+ set_disabled_thread_state (&thread_state_llvm, backend);
2226
+ } else {
2227
+ thread_state_llvm = record_ts;
2228
+ set_disabled_thread_state (&thread_state_cuda, backend);
2229
+ }
2230
+
2231
+ for (uint32_t i = 0 ; i < n_inputs; ++i)
2232
+ record_ts->add_input (inputs[i]);
2233
+
2234
+ jitc_set_flag (JitFlag::FreezingScope, true );
2235
+ }
2236
+ Recording *jitc_freeze_stop (JitBackend backend, const uint32_t *outputs,
2237
+ uint32_t n_outputs) {
2238
+ if (RecordThreadState *rts =
2239
+ dynamic_cast <RecordThreadState *>(thread_state (backend));
2240
+ rts != nullptr ) {
2241
+ ThreadState *internal = rts->m_internal ;
2242
+
2243
+ // Perform reassignments to internal thread-state of possibly changed
2244
+ // variables
2245
+ internal->scope = rts->scope ;
2246
+
2247
+ jitc_assert (rts->record_stack .empty (),
2248
+ " Kernel recording ended while still recording loop!" );
2249
+
2250
+ jitc_set_flag (JitFlag::FreezingScope, false );
2251
+ if (rts->m_exception ) {
2252
+ std::rethrow_exception (rts->m_exception );
2253
+ }
2254
+
2255
+ for (uint32_t i = 0 ; i < n_outputs; ++i) {
2256
+ rts->add_output (outputs[i]);
2257
+ }
2258
+
2259
+ if (backend == JitBackend::CUDA) {
2260
+ thread_state_cuda = internal;
2261
+ unset_disabled_thread_state (&thread_state_llvm);
2262
+ } else {
2263
+ thread_state_llvm = internal;
2264
+ unset_disabled_thread_state (&thread_state_cuda);
2265
+ }
2266
+ Recording *recording = new Recording (std::move (rts->m_recording ));
2267
+ recording->validate ();
2268
+ delete rts;
2269
+
2270
+ return recording;
2271
+ } else {
2272
+ jitc_fail (
2273
+ " jit_record_stop(): Tried to stop recording a thread state "
2274
+ " for backend %u, while no recording was started for this backend. "
2275
+ " Try to start the recording with jit_record_start." ,
2276
+ (uint32_t ) backend);
2277
+ }
2278
+ }
2279
+
2129
2280
void jitc_freeze_abort (JitBackend backend) {
2130
2281
if (RecordThreadState *rts =
2131
2282
dynamic_cast <RecordThreadState *>(thread_state (backend));
@@ -2139,8 +2290,10 @@ void jitc_freeze_abort(JitBackend backend) {
2139
2290
2140
2291
if (backend == JitBackend::CUDA) {
2141
2292
thread_state_cuda = internal;
2293
+ unset_disabled_thread_state (&thread_state_llvm);
2142
2294
} else {
2143
2295
thread_state_llvm = internal;
2296
+ unset_disabled_thread_state (&thread_state_cuda);
2144
2297
}
2145
2298
2146
2299
delete rts;
0 commit comments