Skip to content

Commit 70936f0

Browse files
fix distribution trace bug + move into interpreter
1 parent 14594de commit 70936f0

File tree

5 files changed

+11
-20
lines changed

5 files changed

+11
-20
lines changed

src/lmql/runtime/interpreter.py

+4
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from collections import namedtuple
44
from dataclasses import dataclass
55
from typing import Any, Dict, Optional, List, Union, NamedTuple, Tuple, Set
6+
from lmql.runtime.postprocessing.conditional_prob import ConditionalDistributionPostprocessor
67
import numpy as np
78
import warnings
89

@@ -1126,6 +1127,9 @@ async def debug_out(decoder_step):
11261127
# set decoder step +1, for all stats logging that happens in postprocessing
11271128
self.decoder_step += 1
11281129

1130+
# applies distribution postprocessor if required
1131+
results = await (ConditionalDistributionPostprocessor(self).process(results))
1132+
11291133
# check if a certificate was requested
11301134
if self.certificate != False:
11311135
active_tracer().event("lmql.LMQLResult", results, skip_none=True)

src/lmql/runtime/lmql_runtime.py

+1-10
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,9 @@
77
from typing import Any, Dict, Optional
88

99
from lmql.ops.ops import *
10+
from lmql.runtime.context import Context
1011
from lmql.runtime.langchain import chain, call_sync
1112
from lmql.runtime.output_writer import silent
12-
from lmql.runtime.postprocessing.conditional_prob import \
13-
ConditionalDistributionPostprocessor
1413
from lmql.runtime.postprocessing.group_by import GroupByPostprocessor
1514
from lmql.api.inspect import is_query
1615
from lmql.runtime.formatting import format, tag
@@ -232,14 +231,6 @@ async def __acall__(self, *args, **kwargs):
232231
finally:
233232
if PromptInterpreter.main == interpreter:
234233
PromptInterpreter.main = None
235-
236-
# applies distribution postprocessor if required
237-
results = await (ConditionalDistributionPostprocessor(interpreter).process(results))
238-
239-
# apply remaining postprocessors
240-
if self.postprocessors is not None:
241-
for postprocessor in self.postprocessors:
242-
results = await postprocessor.process(results, self.output_writer)
243234

244235
interpreter.print_stats()
245236
interpreter.dcmodel.close()

src/lmql/runtime/postprocessing/conditional_prob.py

+2-9
Original file line numberDiff line numberDiff line change
@@ -29,17 +29,10 @@ async def score(self, prompt: str, values, dcmodel: dc.DcModel):
2929

3030
async def process(self, results):
3131
model: dc.DcModel = self.interpreter.dcmodel
32-
# optional unpacker for singular results
33-
unpack = lambda v: v
34-
35-
# unpack singular results after processing
36-
if type(results) is not list:
37-
results = [results]
38-
unpack = lambda v: v[0]
3932

4033
# check if distribution is required
4134
if not any(r is not None and hasattr(r, "distribution_variable") and r.distribution_variable is not None for r in results):
42-
return unpack(results)
35+
return results
4336

4437
if len(results) > 1:
4538
if "top1_distribution" in self.interpreter.decoder_kwargs and self.interpreter.decoder_kwargs["top1_distribution"]:
@@ -77,4 +70,4 @@ async def process(self, results):
7770
result.variables[f"P({distribution_variable})"] = [(value, prob) for value, prob, _ in distribution]
7871
result.variables[f"log P({distribution_variable})"] = [(value, prob) for value, prob, _ in log_distribution]
7972

80-
return unpack(results[0])
73+
return results

src/lmql/runtime/tracing/tracer.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,9 @@ def active_tracer() -> Tracer:
176176
different tracers in sub-queries.
177177
"""
178178
_ensure_tracer()
179-
assert len(_tracer.get()) > 0, "No tracer set in this context"
179+
if len(_tracer.get()) == 0:
180+
warnings.warn("An LMQL tracer was requested in a context without active tracer. This indicates that some internal LLM calls may not be traced correctly.")
181+
return NullTracer("null")
180182
return _tracer.get()[-1]
181183

182184
def set_tracer(tracer):

src/lmql/tests/test_sample_queries.py

+1
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@ async def main():
7373
print(error_buffer.getvalue())
7474
print(e)
7575
print(termcolor.colored("[FAIL]", "red"), f"({time.time() - s:.2f}s)")
76+
sys.exit(1)
7677

7778

7879
if __name__ == "__main__":

0 commit comments

Comments
 (0)