@@ -1858,71 +1858,6 @@ bool Recording::check_kernel_cache() {
1858
1858
return true ;
1859
1859
}
1860
1860
1861
- void jitc_freeze_start (JitBackend backend, const uint32_t *inputs,
1862
- uint32_t n_inputs) {
1863
-
1864
- if (jitc_flags () & (uint32_t ) JitFlag::FreezingScope)
1865
- jitc_fail (" Tried to record a thread_state while inside another "
1866
- " FreezingScope!" );
1867
-
1868
- // Increment scope, can be used to track missing inputs
1869
- jitc_new_scope (backend);
1870
-
1871
- ThreadState *ts_ = thread_state (backend);
1872
- RecordThreadState *record_ts = new RecordThreadState (ts_);
1873
-
1874
- if (backend == JitBackend::CUDA)
1875
- thread_state_cuda = record_ts;
1876
- else
1877
- thread_state_llvm = record_ts;
1878
-
1879
- for (uint32_t i = 0 ; i < n_inputs; ++i)
1880
- record_ts->add_input (inputs[i]);
1881
-
1882
- jitc_set_flag (JitFlag::FreezingScope, true );
1883
- }
1884
- Recording *jitc_freeze_stop (JitBackend backend, const uint32_t *outputs,
1885
- uint32_t n_outputs) {
1886
- if (RecordThreadState *rts =
1887
- dynamic_cast <RecordThreadState *>(thread_state (backend));
1888
- rts != nullptr ) {
1889
- ThreadState *internal = rts->m_internal ;
1890
-
1891
- // Perform reassignments to internal thread-state of possibly changed
1892
- // variables
1893
- internal->scope = rts->scope ;
1894
-
1895
- jitc_assert (rts->record_stack .empty (),
1896
- " Kernel recording ended while still recording loop!" );
1897
-
1898
- jitc_set_flag (JitFlag::FreezingScope, false );
1899
- if (rts->m_exception ) {
1900
- std::rethrow_exception (rts->m_exception );
1901
- }
1902
-
1903
- for (uint32_t i = 0 ; i < n_outputs; ++i) {
1904
- rts->add_output (outputs[i]);
1905
- }
1906
-
1907
- if (backend == JitBackend::CUDA) {
1908
- thread_state_cuda = internal;
1909
- } else {
1910
- thread_state_llvm = internal;
1911
- }
1912
- Recording *recording = new Recording (std::move (rts->m_recording ));
1913
- recording->validate ();
1914
- delete rts;
1915
-
1916
- return recording;
1917
- } else {
1918
- jitc_fail (
1919
- " jit_record_stop(): Tried to stop recording a thread state "
1920
- " for backend %u, while no recording was started for this backend. "
1921
- " Try to start the recording with jit_record_start." ,
1922
- (uint32_t ) backend);
1923
- }
1924
- }
1925
-
1926
1861
/* *
1927
1862
* \brief
1928
1863
* This captures the offset buffer of a vcall in a kernel.
@@ -2126,6 +2061,207 @@ void RecordThreadState::add_out_param(uint32_t slot, uint32_t vtype) {
2126
2061
add_out_param (slot, (VarType) vtype);
2127
2062
}
2128
2063
2064
+ /* *
2065
+ * \brief A simple wrapper around a ThreadState, that disables all its
2066
+ * operations. This is used to prevent operations to one ThreadState being
2067
+ * executed while a different thread state is recorded.
2068
+ */
2069
+ struct DisabledThreadState : ThreadState {
2070
+ ThreadState *m_internal;
2071
+ JitBackend m_recording_backend;
2072
+ DisabledThreadState (ThreadState *internal,
2073
+ JitBackend recording_backend)
2074
+ : m_internal(internal), m_recording_backend(recording_backend){
2075
+ this ->context = internal->context ;
2076
+ this ->stream = internal->stream ;
2077
+ this ->event = internal->event ;
2078
+ this ->sync_stream_event = internal->sync_stream_event ;
2079
+ this ->device = internal->device ;
2080
+ this ->compute_capability = internal->compute_capability ;
2081
+ this ->ptx_version = internal->ptx_version ;
2082
+ this ->memory_pool = internal->memory_pool ;
2083
+
2084
+ this ->backend = internal->backend ;
2085
+ this ->scope = internal->scope ;
2086
+ this ->call_self_value = internal->call_self_value ;
2087
+ this ->call_self_index = internal->call_self_index ;
2088
+
2089
+ #if defined(DRJIT_ENABLE_OPTIX)
2090
+ this ->optix_pipeline = internal->optix_pipeline ;
2091
+ this ->optix_sbt = internal->optix_sbt ;
2092
+ #endif
2093
+
2094
+ this ->m_internal = internal;
2095
+
2096
+ this ->scope = internal->scope ;
2097
+ }
2098
+
2099
+ void raise_exception () {
2100
+ const char *backend =
2101
+ m_internal->backend == JitBackend::CUDA ? " CUDA" : " LLVM" ;
2102
+ const char *recording_backend =
2103
+ m_recording_backend == JitBackend::CUDA ? " CUDA" : " LLVM" ;
2104
+ jitc_raise (
2105
+ " The frozen function is being recorded for the %s backend, but "
2106
+ " you tried to execute an operation for the %s backend, this is "
2107
+ " not permitted. It might indicate that you specified the wrong "
2108
+ " backend or the wrong backend was inferred from the inputs." ,
2109
+ recording_backend, backend);
2110
+ };
2111
+
2112
+ void barrier () override { raise_exception (); };
2113
+ Task *launch (Kernel /* kernel*/ , KernelKey * /* key*/ , XXH128_hash_t /* hash*/ ,
2114
+ uint32_t /* size*/ , std::vector<void *> * /* kernel_params*/ ,
2115
+ const std::vector<uint32_t > * /* kernel_param_ids*/ ) override {
2116
+ raise_exception ();
2117
+ return nullptr ;
2118
+ };
2119
+ void memset_async (void * /* ptr*/ , uint32_t /* size*/ , uint32_t /* isize*/ ,
2120
+ const void * /* src*/ ) override {
2121
+ raise_exception ();
2122
+ };
2123
+ uint32_t compress (const uint8_t * /* in*/ , uint32_t /* size*/ ,
2124
+ uint32_t * /* out*/ ) override {
2125
+ raise_exception ();
2126
+ return 0 ;
2127
+ };
2128
+ uint32_t mkperm (const uint32_t * /* values*/ , uint32_t /* size*/ ,
2129
+ uint32_t /* bucket_count*/ , uint32_t * /* perm*/ ,
2130
+ uint32_t * /* offsets*/ ) override {
2131
+ raise_exception ();
2132
+ return 0 ;
2133
+ };
2134
+ void memcpy (void * /* dst*/ , const void * /* src*/ ,
2135
+ size_t /* size*/ ) override {
2136
+ raise_exception ();
2137
+ };
2138
+ void memcpy_async (void * /* dst*/ , const void * /* src*/ ,
2139
+ size_t /* size*/ ) override {
2140
+ raise_exception ();
2141
+ };
2142
+ void block_reduce (VarType /* vt*/ , ReduceOp /* op*/ , uint32_t /* size*/ ,
2143
+ uint32_t /* block_size*/ , const void * /* in*/ ,
2144
+ void * /* out*/ ) override {
2145
+ raise_exception ();
2146
+ };
2147
+ void block_prefix_reduce (VarType /* vt*/ , ReduceOp /* op*/ , uint32_t /* size*/ ,
2148
+ uint32_t /* block_size*/ , bool /* exclusive*/ ,
2149
+ bool /* reverse*/ , const void * /* in*/ ,
2150
+ void * /* out*/ ) override {
2151
+ raise_exception ();
2152
+ };
2153
+ void reduce_dot (VarType /* type*/ , const void * /* ptr_1*/ ,
2154
+ const void * /* ptr_2*/ , uint32_t /* size*/ ,
2155
+ void * /* out*/ ) override {
2156
+ raise_exception ();
2157
+ };
2158
+ void poke (void * /* dst*/ , const void * /* src*/ ,
2159
+ uint32_t /* size*/ ) override {
2160
+ raise_exception ();
2161
+ };
2162
+ void aggregate (void * /* dst*/ , AggregationEntry * /* agg*/ ,
2163
+ uint32_t /* size*/ ) override {
2164
+ raise_exception ();
2165
+ };
2166
+ void enqueue_host_func (void (* /* callback*/ )(void *),
2167
+ void * /* payload*/ ) override {
2168
+ raise_exception ();
2169
+ };
2170
+ void notify_expand (uint32_t /* index*/ ) override {};
2171
+ void reduce_expanded (VarType /* vt*/ , ReduceOp /* reduce_op*/ ,
2172
+ void * /* data*/ , uint32_t /* exp*/ ,
2173
+ uint32_t /* size*/ ) override {};
2174
+ void notify_free (const void * /* ptr*/ ) override {};
2175
+ };
2176
+
2177
+ void disable_thread_state (ThreadState **ts, JitBackend recording_backend) {
2178
+ if (!*ts)
2179
+ return ;
2180
+ *ts = new DisabledThreadState (*ts, recording_backend);
2181
+ }
2182
+
2183
+ void enable_thread_state (ThreadState **ts) {
2184
+ if (!*ts)
2185
+ return ;
2186
+ if (DisabledThreadState *dst = dynamic_cast <DisabledThreadState *>(*ts);
2187
+ dst != nullptr ) {
2188
+ *ts = dst->m_internal ;
2189
+ delete dst;
2190
+ }else {
2191
+ jitc_fail (" Tried to enable a ThreadState that was not disabled." );
2192
+ }
2193
+ }
2194
+
2195
+ void jitc_freeze_start (JitBackend backend, const uint32_t *inputs,
2196
+ uint32_t n_inputs) {
2197
+
2198
+ if (jitc_flags () & (uint32_t ) JitFlag::FreezingScope)
2199
+ jitc_fail (" Tried to record a thread_state while inside another "
2200
+ " FreezingScope!" );
2201
+
2202
+ // Increment scope, can be used to track missing inputs
2203
+ jitc_new_scope (backend);
2204
+
2205
+ ThreadState *ts_ = thread_state (backend);
2206
+ RecordThreadState *record_ts = new RecordThreadState (ts_);
2207
+
2208
+ if (backend == JitBackend::CUDA) {
2209
+ thread_state_cuda = record_ts;
2210
+ disable_thread_state (&thread_state_llvm, backend);
2211
+ } else {
2212
+ thread_state_llvm = record_ts;
2213
+ disable_thread_state (&thread_state_cuda, backend);
2214
+ }
2215
+
2216
+ for (uint32_t i = 0 ; i < n_inputs; ++i)
2217
+ record_ts->add_input (inputs[i]);
2218
+
2219
+ jitc_set_flag (JitFlag::FreezingScope, true );
2220
+ }
2221
+ Recording *jitc_freeze_stop (JitBackend backend, const uint32_t *outputs,
2222
+ uint32_t n_outputs) {
2223
+ if (RecordThreadState *rts =
2224
+ dynamic_cast <RecordThreadState *>(thread_state (backend));
2225
+ rts != nullptr ) {
2226
+ ThreadState *internal = rts->m_internal ;
2227
+
2228
+ // Perform reassignments to internal thread-state of possibly changed
2229
+ // variables
2230
+ internal->scope = rts->scope ;
2231
+
2232
+ jitc_assert (rts->record_stack .empty (),
2233
+ " Kernel recording ended while still recording loop!" );
2234
+
2235
+ jitc_set_flag (JitFlag::FreezingScope, false );
2236
+ if (rts->m_exception ) {
2237
+ std::rethrow_exception (rts->m_exception );
2238
+ }
2239
+
2240
+ for (uint32_t i = 0 ; i < n_outputs; ++i) {
2241
+ rts->add_output (outputs[i]);
2242
+ }
2243
+
2244
+ if (backend == JitBackend::CUDA) {
2245
+ thread_state_cuda = internal;
2246
+ enable_thread_state (&thread_state_llvm);
2247
+ } else {
2248
+ thread_state_llvm = internal;
2249
+ enable_thread_state (&thread_state_cuda);
2250
+ }
2251
+ Recording *recording = new Recording (std::move (rts->m_recording ));
2252
+ recording->validate ();
2253
+ delete rts;
2254
+
2255
+ return recording;
2256
+ } else {
2257
+ jitc_fail (
2258
+ " jit_record_stop(): Tried to stop recording a thread state "
2259
+ " for backend %u, while no recording was started for this backend. "
2260
+ " Try to start the recording with jit_record_start." ,
2261
+ (uint32_t ) backend);
2262
+ }
2263
+ }
2264
+
2129
2265
void jitc_freeze_abort (JitBackend backend) {
2130
2266
if (RecordThreadState *rts =
2131
2267
dynamic_cast <RecordThreadState *>(thread_state (backend));
@@ -2139,8 +2275,10 @@ void jitc_freeze_abort(JitBackend backend) {
2139
2275
2140
2276
if (backend == JitBackend::CUDA) {
2141
2277
thread_state_cuda = internal;
2278
+ enable_thread_state (&thread_state_llvm);
2142
2279
} else {
2143
2280
thread_state_llvm = internal;
2281
+ enable_thread_state (&thread_state_cuda);
2144
2282
}
2145
2283
2146
2284
delete rts;
0 commit comments