Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/eval.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -817,6 +817,10 @@ void jitc_eval_impl(ThreadState *ts) {

if (v->is_array())
v->scope = 0;

#ifndef NDEBUG
state.ptr_to_variable.insert({ v->data, index });
#endif
}

uint32_t dep[4], side_effect = v->side_effect;
Expand Down
9 changes: 9 additions & 0 deletions src/internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -840,6 +840,8 @@ struct KernelHistory {

using UnusedPQ = std::priority_queue<uint32_t, std::vector<uint32_t>, std::greater<uint32_t>>;

using PointerMap = tsl::robin_map<const void *, uint32_t, PointerHasher>;

/// Records the full JIT compiler state (most frequently two used entries at top)
struct State {
/// Must be held to access members of this data structure
Expand Down Expand Up @@ -927,6 +929,13 @@ struct State {
uint32_t optix_default_sbt_index = 0;
#endif

#ifndef NDEBUG
/// Mapping from pointers that are managed by variables to their variable
/// indices. This is used for debugging purposes in frozen functions.
PointerMap ptr_to_variable;
#endif


State() {
variables.resize(1);
extra.resize(1);
Expand Down
78 changes: 52 additions & 26 deletions src/record_ts.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -668,6 +668,11 @@ Task *RecordThreadState::launch(Kernel kernel, KernelKey *key,
}
}
pause_scope pause(this);
// Forward the OptiX SBT and pipeline to the internal thread state.
#if defined(DRJIT_ENABLE_OPTIX)
m_internal->optix_pipeline = optix_pipeline;
m_internal->optix_sbt = optix_sbt;
#endif
return m_internal->launch(kernel, key, hash, size, params, nullptr,
kernel_history_entry);
}
Expand Down Expand Up @@ -875,13 +880,6 @@ void RecordThreadState::record_launch(
}

m_recording.operations.push_back(op);

// Re-assign optix specific variables to internal thread state since they
// might have changed
#if defined(DRJIT_ENABLE_OPTIX)
m_internal->optix_pipeline = optix_pipeline;
m_internal->optix_sbt = optix_sbt;
#endif
}

int Recording::replay_launch(Operation &op) {
Expand Down Expand Up @@ -2034,34 +2032,46 @@ void RecordThreadState::enqueue_host_func(void (*callback)(void *),
(void) callback; (void) payload;
}

void Recording::validate() {
void Recording::validate(uint32_t scope) {
for (uint32_t i = 0; i < recorded_variables.size(); i++) {
RecordedVariable &rv = recorded_variables[i];
if (rv.state == RecordedVarState::Uninitialized) {
Operation &last_op = operations[rv.last_op];
#ifndef NDEBUG
uint32_t index = 0;
const char *scope_string = "before";
auto it = state.ptr_to_variable.find(rv.ptr);
if (it != state.ptr_to_variable.end()) {
index = it->second;
Variable *var = jitc_var(index);
if (var->scope >= scope)
scope_string = "inside";
}
if (last_op.type == OpType::Aggregate) {
jitc_raise(
"validate(): The frozen function included a virtual "
"function call involving variable s%u <%p>, last used by "
"operation o%u. Dr.Jit would normally traverse a registry "
"of all relevant object instances in order to collect "
"their member variables. However, when recording this "
"frozen function, this traversal was skipped because no "
"such object instance was found in the function's inputs. "
"You can trigger traversal by including the relevant "
"objects in the function input, or by specifying them "
"using the state_fn argument. Alternatively, this error "
"might be caused by a nested "
"function call involving Variable r%u at slot s%u <%p>"
"which created %s the frozen function and was last used by "
"operation o%u. Dr.Jit would "
"normally traverse a registry of all relevant object "
"instances in order to collect their member variables. "
"However, when recording this frozen function, this "
"traversal was skipped because no such object instance was "
"found in the function's inputs. You can trigger traversal "
"by including the relevant objects in the function input, "
"or by specifying them using the state_fn argument. "
"Alternatively, this error might be caused by a nested "
"virtual function call.",
i, rv.ptr, rv.last_op);
index, i, rv.ptr, scope_string, rv.last_op);
} else
jitc_raise(
"validate(): Variable at slot s%u <%p> was used by %s operation "
"o%u but left in an uninitialized state! This indicates "
"that the associated variable was used, but not traversed "
"as part of the frozen function input.",
i, rv.ptr, op_type_name[(uint32_t) last_op.type], rv.last_op);
"validate(): Variable r%u at slot s%u <%p> which was "
"created %s the frozen function and was last used by %s "
"operation o%u but left in an uninitialized state! This "
"indicates that the associated variable was used, but not "
"traversed as part of the frozen function input.",
index, i, rv.ptr, scope_string,
op_type_name[(uint32_t) last_op.type], rv.last_op);
#else
if (last_op.type == OpType::Aggregate) {
jitc_raise(
Expand Down Expand Up @@ -2301,14 +2311,29 @@ void RecordThreadState::add_param(AccessInfo info) {
jitc_log(LogLevel::Debug, " -> param s%u", info.slot);

RecordedVariable &rv = m_recording.recorded_variables[info.slot];
if (info.test_uninit && rv.state == RecordedVarState::Uninitialized)
if (info.test_uninit && rv.state == RecordedVarState::Uninitialized){
#ifndef NDEBUG
uint32_t index = 0;
auto it = state.ptr_to_variable.find(rv.ptr);
if (it != state.ptr_to_variable.end())
index = it->second;
jitc_raise("record(): Variable r%u at slot s%u was read by "
"operation o%u, but it had not yet been initialized! "
"This can occur if the variable was not part of "
"the input but is used by a recorded operation, for "
"example if it was not specified as a member in a "
"DRJIT_STRUCT but used in the frozen function.",
index, info.slot, (uint32_t) m_recording.operations.size());
#else
jitc_raise("record(): Variable at slot s%u was read by "
"operation o%u, but it had not yet been initialized! "
"This can occur if the variable was not part of "
"the input but is used by a recorded operation, for "
"example if it was not specified as a member in a "
"DRJIT_STRUCT but used in the frozen function.",
info.slot, (uint32_t) m_recording.operations.size());
#endif
}

if (info.vtype == VarType::Void)
info.vtype = rv.type;
Expand Down Expand Up @@ -2544,6 +2569,7 @@ Recording *jitc_freeze_stop(JitBackend backend, const uint32_t *outputs,
dynamic_cast<RecordThreadState *>(thread_state(backend));
rts != nullptr) {
ThreadState *internal = rts->m_internal;
uint32_t scope = internal->scope;

// Perform reassignments to internal thread-state of possibly changed
// variables
Expand Down Expand Up @@ -2571,7 +2597,7 @@ Recording *jitc_freeze_stop(JitBackend backend, const uint32_t *outputs,
}
Recording *recording = new Recording(std::move(rts->m_recording));
try{
recording->validate();
recording->validate(scope);
} catch (const std::exception &) {
recording->destroy();
throw;
Expand Down
2 changes: 1 addition & 1 deletion src/record_ts.h
Original file line number Diff line number Diff line change
Expand Up @@ -321,7 +321,7 @@ struct Recording {

/// This function is called after recording and checks that the recording is
/// valid i.e. that no variables where left uninitialized.
void validate();
void validate(uint32_t scope);
/// Checks if all recorded kernels are still in the kernel cache. This might
/// occur when calling dr.kernel_cache_flush between recording the function
/// and replaying it.
Expand Down
21 changes: 20 additions & 1 deletion src/var.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -309,8 +309,22 @@ JIT_NOINLINE void jitc_var_free(uint32_t index, Variable *v) noexcept {

if (v->is_evaluated()) {
// Release memory referenced by this variable
if (!v->retain_data)
if (!v->retain_data) {
jitc_free(v->data);

#ifndef NDEBUG
// This warning should never be thrown, except if we forgot to
// populate the mapping
if (!state.ptr_to_variable.contains(v->data))
jitc_log(
LogLevel::Warn,
"Pointr <%p> was mangaged by variable r%u, but this "
"was not recorded in the pointer to variable map!",
v->data, index);

state.ptr_to_variable.erase(v->data);
#endif
}
} else {
// Unevaluated variable, drop from CSE cache
jitc_lvn_drop(index, v);
Expand Down Expand Up @@ -784,6 +798,11 @@ uint32_t jitc_var_new(Variable &v, bool disable_lvn) {
jitc_sanitation_checkpoint();
#endif

#ifndef NDEBUG
if (v.is_evaluated())
state.ptr_to_variable.insert({ v.data, index });
#endif

return index;
}

Expand Down