diff --git a/pyproject.toml b/pyproject.toml index d3bdf19bf..32ac90992 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -96,6 +96,9 @@ browser = [ renderers = [ "renderers>=0.1.6", ] +braintrust = [ + "braintrust>=0.0.160", +] rl = [ "torch>=2.8.0,<2.9.0", "transformers>=4.56.2", @@ -230,6 +233,12 @@ invalid-method-override = "ignore" invalid-assignment = "ignore" not-iterable = "ignore" +[[tool.ty.overrides]] +include = ["verifiers/envs/experimental/braintrust_tracing/**"] + +[tool.ty.overrides.rules] +unresolved-import = "ignore" + [tool.coverage.run] source = ["verifiers"] omit = [ diff --git a/uv.lock b/uv.lock index 2fbe0a6ab..cf4dcb66c 100644 --- a/uv.lock +++ b/uv.lock @@ -1,5 +1,5 @@ version = 1 -revision = 3 +revision = 2 requires-python = ">=3.10, <3.14" resolution-markers = [ "python_full_version >= '3.13' and sys_platform == 'win32'", @@ -14,13 +14,6 @@ resolution-markers = [ "python_full_version < '3.11'", ] -[options] - -[options.exclude-newer-package] -prime-tunnel = false -prime-sandboxes = false -renderers = false - [[package]] name = "accelerate" version = "1.13.0" @@ -415,6 +408,28 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/d8/fc/923e25ac9cadfff1cd20038bcc0854d0f98061eb6bc78e42c43615f5982d/blake3-1.0.8-cp313-cp313t-win_amd64.whl", hash = "sha256:3cec94ed5676821cf371e9c9d25a41b4f3ebdb5724719b31b2749653b7cc1dfa", size = 215369, upload-time = "2025-10-14T06:46:39.054Z" }, ] +[[package]] +name = "braintrust" +version = "0.19.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "chevron" }, + { name = "exceptiongroup" }, + { name = "gitpython" }, + { name = "jsonschema" }, + { name = "packaging" }, + { name = "python-slugify" }, + { name = "requests" }, + { name = "sseclient-py" }, + { name = "tqdm" }, + { name = "typing-extensions" }, + { name = "wrapt" }, +] +sdist = { url = 
"https://files.pythonhosted.org/packages/4d/80/812c7dd5aad91bceb4341caf0d8fe62ac52c9fbad7f2d961d782e8be1d5d/braintrust-0.19.0.tar.gz", hash = "sha256:0180ed0088293e621aecdb17f2d7a6d17c8bdaf9fdf4557cc3cfcf8c243046f1", size = 582781, upload-time = "2026-05-04T16:45:07.779Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/5d/a7/f2ca3bd4559319891de1c36d3297705817996a30e642a1834ab9e4c67a6c/braintrust-0.19.0-py3-none-any.whl", hash = "sha256:717fde09d036ab2339bf73853a33d1c97e91cee56e774e20271951745e4ee972", size = 680473, upload-time = "2026-05-04T16:45:05.444Z" }, +] + [[package]] name = "cachetools" version = "7.0.6" @@ -653,6 +668,15 @@ version = "1.11.2" source = { registry = "https://pypi.org/simple" } sdist = { url = "https://files.pythonhosted.org/packages/93/09/7d04d7581ae3bb8b598017941781bceb7959dd1b13e3ebf7b6a2cd843bc9/chess-1.11.2.tar.gz", hash = "sha256:a8b43e5678fdb3000695bdaa573117ad683761e5ca38e591c4826eba6d25bb39", size = 6131385, upload-time = "2025-02-25T19:10:27.328Z" } +[[package]] +name = "chevron" +version = "0.14.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/15/1f/ca74b65b19798895d63a6e92874162f44233467c9e7c1ed8afd19016ebe9/chevron-0.14.0.tar.gz", hash = "sha256:87613aafdf6d77b6a90ff073165a61ae5086e21ad49057aa0e53681601800ebf", size = 11440, upload-time = "2021-01-02T22:47:59.233Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/52/93/342cc62a70ab727e093ed98e02a725d85b746345f05d2b5e5034649f4ec8/chevron-0.14.0-py3-none-any.whl", hash = "sha256:fbf996a709f8da2e745ef763f482ce2d311aa817d287593a5b990d6d6e4f0443", size = 11595, upload-time = "2021-01-02T22:47:57.847Z" }, +] + [[package]] name = "click" version = "8.3.3" @@ -4535,6 +4559,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/9a/22/f1925cdda983ab66fc8ec6ec8014b959262747e58bdca26a4e3d1da29d56/python_multipart-0.0.26-py3-none-any.whl", hash = 
"sha256:c0b169f8c4484c13b0dcf2ef0ec3a4adb255c4b7d18d8e420477d2b1dd03f185", size = 28847, upload-time = "2026-04-10T14:09:58.131Z" }, ] +[[package]] +name = "python-slugify" +version = "8.0.4" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "text-unidecode" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/87/c7/5e1547c44e31da50a460df93af11a535ace568ef89d7a811069ead340c4a/python-slugify-8.0.4.tar.gz", hash = "sha256:59202371d1d05b54a9e7720c5e038f928f45daaffe41dd10822f3907b937c856", size = 10921, upload-time = "2024-02-08T18:32:45.488Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a4/62/02da182e544a51a5c3ccf4b03ab79df279f9c60c5e82d5e8bec7ca26ac11/python_slugify-8.0.4-py2.py3-none-any.whl", hash = "sha256:276540b79961052b66b7d116620b36518847f52d5fd9e3a70164fc8c50faa6b8", size = 10051, upload-time = "2024-02-08T18:32:43.911Z" }, +] + [[package]] name = "pytz" version = "2026.1.post1" @@ -5508,6 +5544,14 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/f8/7f/3de5402f39890ac5660b86bcf5c03f9d855dad5c4ed764866d7b592b46fd/sse_starlette-3.3.4-py3-none-any.whl", hash = "sha256:84bb06e58939a8b38d8341f1bc9792f06c2b53f48c608dd207582b664fc8f3c1", size = 14330, upload-time = "2026-03-29T09:00:21.846Z" }, ] +[[package]] +name = "sseclient-py" +version = "1.9.0" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/4d/2e/59920f7d66b7f9932a3d83dd0ec53fab001be1e058bf582606fe414a5198/sseclient_py-1.9.0-py3-none-any.whl", hash = "sha256:340062b1587fc2880892811e2ab5b176d98ef3eee98b3672ff3a3ba1e8ed0f6f", size = 8351, upload-time = "2026-01-02T23:39:30.995Z" }, +] + [[package]] name = "stack-data" version = "0.6.3" @@ -5585,6 +5629,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/d7/c1/eb8f9debc45d3b7918a32ab756658a0904732f75e555402972246b0b8e71/tenacity-9.1.4-py3-none-any.whl", hash = 
"sha256:6095a360c919085f28c6527de529e76a06ad89b23659fa881ae0649b867a9d55", size = 28926, upload-time = "2026-02-07T10:45:32.24Z" }, ] +[[package]] +name = "text-unidecode" +version = "1.3" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/ab/e2/e9a00f0ccb71718418230718b3d900e71a5d16e701a3dae079a21e9cd8f8/text-unidecode-1.3.tar.gz", hash = "sha256:bad6603bb14d279193107714b288be206cac565dfa49aa5b105294dd5c4aab93", size = 76885, upload-time = "2019-08-30T21:36:45.405Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a6/a5/c0b6468d3824fe3fde30dbb5e1f687b291608f9473681bbf7dabbf5a87d7/text_unidecode-1.3-py2.py3-none-any.whl", hash = "sha256:1311f10e8b895935241623731c2ba64f4c455287888b18189350b67134a822e8", size = 78154, upload-time = "2019-08-30T21:37:03.543Z" }, +] + [[package]] name = "textarena" version = "0.7.4" @@ -6127,6 +6180,9 @@ dependencies = [ ] [package.optional-dependencies] +braintrust = [ + { name = "braintrust" }, +] browser = [ { name = "aiohttp" }, { name = "python-dotenv" }, @@ -6185,6 +6241,7 @@ requires-dist = [ { name = "aiohttp", marker = "extra == 'browser'", specifier = ">=3.9.0" }, { name = "aiolimiter", specifier = ">=1.2.1" }, { name = "anthropic", specifier = ">=0.78.0" }, + { name = "braintrust", marker = "extra == 'braintrust'", specifier = ">=0.0.160" }, { name = "datasets", specifier = ">=3.0.0,<4.7.0" }, { name = "deepspeed", marker = "extra == 'rl'", specifier = ">=0.17.6" }, { name = "flash-attn", marker = "extra == 'rl'", specifier = ">=2.8.3" }, @@ -6226,7 +6283,7 @@ requires-dist = [ { name = "wandb", marker = "extra == 'rl'" }, { name = "wget", specifier = ">=3.2" }, ] -provides-extras = ["browser", "openenv", "renderers", "rg", "rl", "ta"] +provides-extras = ["braintrust", "browser", "openenv", "renderers", "rg", "rl", "ta"] [package.metadata.requires-dev] dev = [ @@ -6525,6 +6582,70 @@ wheels = [ { url = 
"https://files.pythonhosted.org/packages/3f/0e/fa3b193432cfc60c93b42f3be03365f5f909d2b3ea410295cf36df739e31/widgetsnbextension-4.0.15-py3-none-any.whl", hash = "sha256:8156704e4346a571d9ce73b84bee86a29906c9abfd7223b7228a28899ccf3366", size = 2196503, upload-time = "2025-11-01T21:15:53.565Z" }, ] +[[package]] +name = "wrapt" +version = "2.1.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/2e/64/925f213fdcbb9baeb1530449ac71a4d57fc361c053d06bf78d0c5c7cd80c/wrapt-2.1.2.tar.gz", hash = "sha256:3996a67eecc2c68fd47b4e3c564405a5777367adfd9b8abb58387b63ee83b21e", size = 81678, upload-time = "2026-03-06T02:53:25.134Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/da/d2/387594fb592d027366645f3d7cc9b4d7ca7be93845fbaba6d835a912ef3c/wrapt-2.1.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:4b7a86d99a14f76facb269dc148590c01aaf47584071809a70da30555228158c", size = 60669, upload-time = "2026-03-06T02:52:40.671Z" }, + { url = "https://files.pythonhosted.org/packages/c9/18/3f373935bc5509e7ac444c8026a56762e50c1183e7061797437ca96c12ce/wrapt-2.1.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:a819e39017f95bf7aede768f75915635aa8f671f2993c036991b8d3bfe8dbb6f", size = 61603, upload-time = "2026-03-06T02:54:21.032Z" }, + { url = "https://files.pythonhosted.org/packages/c2/7a/32758ca2853b07a887a4574b74e28843919103194bb47001a304e24af62f/wrapt-2.1.2-cp310-cp310-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:5681123e60aed0e64c7d44f72bbf8b4ce45f79d81467e2c4c728629f5baf06eb", size = 113632, upload-time = "2026-03-06T02:53:54.121Z" }, + { url = "https://files.pythonhosted.org/packages/1d/d5/eeaa38f670d462e97d978b3b0d9ce06d5b91e54bebac6fbed867809216e7/wrapt-2.1.2-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:2b8b28e97a44d21836259739ae76284e180b18abbb4dcfdff07a415cf1016c3e", size = 115644, upload-time = 
"2026-03-06T02:54:53.33Z" }, + { url = "https://files.pythonhosted.org/packages/e3/09/2a41506cb17affb0bdf9d5e2129c8c19e192b388c4c01d05e1b14db23c00/wrapt-2.1.2-cp310-cp310-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:cef91c95a50596fcdc31397eb6955476f82ae8a3f5a8eabdc13611b60ee380ba", size = 112016, upload-time = "2026-03-06T02:54:43.274Z" }, + { url = "https://files.pythonhosted.org/packages/64/15/0e6c3f5e87caadc43db279724ee36979246d5194fa32fed489c73643ba59/wrapt-2.1.2-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:dad63212b168de8569b1c512f4eac4b57f2c6934b30df32d6ee9534a79f1493f", size = 114823, upload-time = "2026-03-06T02:54:29.392Z" }, + { url = "https://files.pythonhosted.org/packages/56/b2/0ad17c8248f4e57bedf44938c26ec3ee194715f812d2dbbd9d7ff4be6c06/wrapt-2.1.2-cp310-cp310-musllinux_1_2_riscv64.whl", hash = "sha256:d307aa6888d5efab2c1cde09843d48c843990be13069003184b67d426d145394", size = 111244, upload-time = "2026-03-06T02:54:02.149Z" }, + { url = "https://files.pythonhosted.org/packages/ff/04/bcdba98c26f2c6522c7c09a726d5d9229120163493620205b2f76bd13c01/wrapt-2.1.2-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:c87cf3f0c85e27b3ac7d9ad95da166bf8739ca215a8b171e8404a2d739897a45", size = 113307, upload-time = "2026-03-06T02:54:12.428Z" }, + { url = "https://files.pythonhosted.org/packages/0e/1b/5e2883c6bc14143924e465a6fc5a92d09eeabe35310842a481fb0581f832/wrapt-2.1.2-cp310-cp310-win32.whl", hash = "sha256:d1c5fea4f9fe3762e2b905fdd67df51e4be7a73b7674957af2d2ade71a5c075d", size = 57986, upload-time = "2026-03-06T02:54:26.823Z" }, + { url = "https://files.pythonhosted.org/packages/42/5a/4efc997bccadd3af5749c250b49412793bc41e13a83a486b2b54a33e240c/wrapt-2.1.2-cp310-cp310-win_amd64.whl", hash = "sha256:d8f7740e1af13dff2684e4d56fe604a7e04d6c94e737a60568d8d4238b9a0c71", size = 60336, upload-time = "2026-03-06T02:54:18Z" }, + { url = 
"https://files.pythonhosted.org/packages/c1/f5/a2bb833e20181b937e87c242645ed5d5aa9c373006b0467bfe1a35c727d0/wrapt-2.1.2-cp310-cp310-win_arm64.whl", hash = "sha256:1c6cc827c00dc839350155f316f1f8b4b0c370f52b6a19e782e2bda89600c7dc", size = 58757, upload-time = "2026-03-06T02:53:51.545Z" }, + { url = "https://files.pythonhosted.org/packages/c7/81/60c4471fce95afa5922ca09b88a25f03c93343f759aae0f31fb4412a85c7/wrapt-2.1.2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:96159a0ee2b0277d44201c3b5be479a9979cf154e8c82fa5df49586a8e7679bb", size = 60666, upload-time = "2026-03-06T02:52:58.934Z" }, + { url = "https://files.pythonhosted.org/packages/6b/be/80e80e39e7cb90b006a0eaf11c73ac3a62bbfb3068469aec15cc0bc795de/wrapt-2.1.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:98ba61833a77b747901e9012072f038795de7fc77849f1faa965464f3f87ff2d", size = 61601, upload-time = "2026-03-06T02:53:00.487Z" }, + { url = "https://files.pythonhosted.org/packages/b0/be/d7c88cd9293c859fc74b232abdc65a229bb953997995d6912fc85af18323/wrapt-2.1.2-cp311-cp311-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:767c0dbbe76cae2a60dd2b235ac0c87c9cccf4898aef8062e57bead46b5f6894", size = 114057, upload-time = "2026-03-06T02:52:44.08Z" }, + { url = "https://files.pythonhosted.org/packages/ea/25/36c04602831a4d685d45a93b3abea61eca7fe35dab6c842d6f5d570ef94a/wrapt-2.1.2-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:9c691a6bc752c0cc4711cc0c00896fcd0f116abc253609ef64ef930032821842", size = 116099, upload-time = "2026-03-06T02:54:56.74Z" }, + { url = "https://files.pythonhosted.org/packages/5c/4e/98a6eb417ef551dc277bec1253d5246b25003cf36fdf3913b65cb7657a56/wrapt-2.1.2-cp311-cp311-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:f3b7d73012ea75aee5844de58c88f44cf62d0d62711e39da5a82824a7c4626a8", size = 112457, upload-time = "2026-03-06T02:53:52.842Z" }, + { url = 
"https://files.pythonhosted.org/packages/cb/a6/a6f7186a5297cad8ec53fd7578533b28f795fdf5372368c74bd7e6e9841c/wrapt-2.1.2-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:577dff354e7acd9d411eaf4bfe76b724c89c89c8fc9b7e127ee28c5f7bcb25b6", size = 115351, upload-time = "2026-03-06T02:53:32.684Z" }, + { url = "https://files.pythonhosted.org/packages/97/6f/06e66189e721dbebd5cf20e138acc4d1150288ce118462f2fcbff92d38db/wrapt-2.1.2-cp311-cp311-musllinux_1_2_riscv64.whl", hash = "sha256:3d7b6fd105f8b24e5bd23ccf41cb1d1099796524bcc6f7fbb8fe576c44befbc9", size = 111748, upload-time = "2026-03-06T02:53:08.455Z" }, + { url = "https://files.pythonhosted.org/packages/ef/43/4808b86f499a51370fbdbdfa6cb91e9b9169e762716456471b619fca7a70/wrapt-2.1.2-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:866abdbf4612e0b34764922ef8b1c5668867610a718d3053d59e24a5e5fcfc15", size = 113783, upload-time = "2026-03-06T02:53:02.02Z" }, + { url = "https://files.pythonhosted.org/packages/91/2c/a3f28b8fa7ac2cefa01cfcaca3471f9b0460608d012b693998cd61ef43df/wrapt-2.1.2-cp311-cp311-win32.whl", hash = "sha256:5a0a0a3a882393095573344075189eb2d566e0fd205a2b6414e9997b1b800a8b", size = 57977, upload-time = "2026-03-06T02:53:27.844Z" }, + { url = "https://files.pythonhosted.org/packages/3f/c3/2b1c7bd07a27b1db885a2fab469b707bdd35bddf30a113b4917a7e2139d2/wrapt-2.1.2-cp311-cp311-win_amd64.whl", hash = "sha256:64a07a71d2730ba56f11d1a4b91f7817dc79bc134c11516b75d1921a7c6fcda1", size = 60336, upload-time = "2026-03-06T02:54:28.104Z" }, + { url = "https://files.pythonhosted.org/packages/ec/5c/76ece7b401b088daa6503d6264dd80f9a727df3e6042802de9a223084ea2/wrapt-2.1.2-cp311-cp311-win_arm64.whl", hash = "sha256:b89f095fe98bc12107f82a9f7d570dc83a0870291aeb6b1d7a7d35575f55d98a", size = 58756, upload-time = "2026-03-06T02:53:16.319Z" }, + { url = "https://files.pythonhosted.org/packages/4c/b6/1db817582c49c7fcbb7df6809d0f515af29d7c2fbf57eb44c36e98fb1492/wrapt-2.1.2-cp312-cp312-macosx_10_13_x86_64.whl", hash = 
"sha256:ff2aad9c4cda28a8f0653fc2d487596458c2a3f475e56ba02909e950a9efa6a9", size = 61255, upload-time = "2026-03-06T02:52:45.663Z" }, + { url = "https://files.pythonhosted.org/packages/a2/16/9b02a6b99c09227c93cd4b73acc3678114154ec38da53043c0ddc1fba0dc/wrapt-2.1.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:6433ea84e1cfacf32021d2a4ee909554ade7fd392caa6f7c13f1f4bf7b8e8748", size = 61848, upload-time = "2026-03-06T02:53:48.728Z" }, + { url = "https://files.pythonhosted.org/packages/af/aa/ead46a88f9ec3a432a4832dfedb84092fc35af2d0ba40cd04aea3889f247/wrapt-2.1.2-cp312-cp312-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:c20b757c268d30d6215916a5fa8461048d023865d888e437fab451139cad6c8e", size = 121433, upload-time = "2026-03-06T02:54:40.328Z" }, + { url = "https://files.pythonhosted.org/packages/3a/9f/742c7c7cdf58b59085a1ee4b6c37b013f66ac33673a7ef4aaed5e992bc33/wrapt-2.1.2-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:79847b83eb38e70d93dc392c7c5b587efe65b3e7afcc167aa8abd5d60e8761c8", size = 123013, upload-time = "2026-03-06T02:53:26.58Z" }, + { url = "https://files.pythonhosted.org/packages/e8/44/2c3dd45d53236b7ed7c646fcf212251dc19e48e599debd3926b52310fafb/wrapt-2.1.2-cp312-cp312-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:f8fba1bae256186a83d1875b2b1f4e2d1242e8fac0f58ec0d7e41b26967b965c", size = 117326, upload-time = "2026-03-06T02:53:11.547Z" }, + { url = "https://files.pythonhosted.org/packages/74/e2/b17d66abc26bd96f89dec0ecd0ef03da4a1286e6ff793839ec431b9fae57/wrapt-2.1.2-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:e3d3b35eedcf5f7d022291ecd7533321c4775f7b9cd0050a31a68499ba45757c", size = 121444, upload-time = "2026-03-06T02:54:09.5Z" }, + { url = "https://files.pythonhosted.org/packages/3c/62/e2977843fdf9f03daf1586a0ff49060b1b2fc7ff85a7ea82b6217c1ae36e/wrapt-2.1.2-cp312-cp312-musllinux_1_2_riscv64.whl", hash = 
"sha256:6f2c5390460de57fa9582bc8a1b7a6c86e1a41dfad74c5225fc07044c15cc8d1", size = 116237, upload-time = "2026-03-06T02:54:03.884Z" }, + { url = "https://files.pythonhosted.org/packages/88/dd/27fc67914e68d740bce512f11734aec08696e6b17641fef8867c00c949fc/wrapt-2.1.2-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:7dfa9f2cf65d027b951d05c662cc99ee3bd01f6e4691ed39848a7a5fffc902b2", size = 120563, upload-time = "2026-03-06T02:53:20.412Z" }, + { url = "https://files.pythonhosted.org/packages/ec/9f/b750b3692ed2ef4705cb305bd68858e73010492b80e43d2a4faa5573cbe7/wrapt-2.1.2-cp312-cp312-win32.whl", hash = "sha256:eba8155747eb2cae4a0b913d9ebd12a1db4d860fc4c829d7578c7b989bd3f2f0", size = 58198, upload-time = "2026-03-06T02:53:37.732Z" }, + { url = "https://files.pythonhosted.org/packages/8e/b2/feecfe29f28483d888d76a48f03c4c4d8afea944dbee2b0cd3380f9df032/wrapt-2.1.2-cp312-cp312-win_amd64.whl", hash = "sha256:1c51c738d7d9faa0b3601708e7e2eda9bf779e1b601dce6c77411f2a1b324a63", size = 60441, upload-time = "2026-03-06T02:52:47.138Z" }, + { url = "https://files.pythonhosted.org/packages/44/e1/e328f605d6e208547ea9fd120804fcdec68536ac748987a68c47c606eea8/wrapt-2.1.2-cp312-cp312-win_arm64.whl", hash = "sha256:c8e46ae8e4032792eb2f677dbd0d557170a8e5524d22acc55199f43efedd39bf", size = 58836, upload-time = "2026-03-06T02:53:22.053Z" }, + { url = "https://files.pythonhosted.org/packages/4c/7a/d936840735c828b38d26a854e85d5338894cda544cb7a85a9d5b8b9c4df7/wrapt-2.1.2-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:787fd6f4d67befa6fe2abdffcbd3de2d82dfc6fb8a6d850407c53332709d030b", size = 61259, upload-time = "2026-03-06T02:53:41.922Z" }, + { url = "https://files.pythonhosted.org/packages/5e/88/9a9b9a90ac8ca11c2fdb6a286cb3a1fc7dd774c00ed70929a6434f6bc634/wrapt-2.1.2-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:4bdf26e03e6d0da3f0e9422fd36bcebf7bc0eeb55fdf9c727a09abc6b9fe472e", size = 61851, upload-time = "2026-03-06T02:52:48.672Z" }, + { url = 
"https://files.pythonhosted.org/packages/03/a9/5b7d6a16fd6533fed2756900fc8fc923f678179aea62ada6d65c92718c00/wrapt-2.1.2-cp313-cp313-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:bbac24d879aa22998e87f6b3f481a5216311e7d53c7db87f189a7a0266dafffb", size = 121446, upload-time = "2026-03-06T02:54:14.013Z" }, + { url = "https://files.pythonhosted.org/packages/45/bb/34c443690c847835cfe9f892be78c533d4f32366ad2888972c094a897e39/wrapt-2.1.2-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:16997dfb9d67addc2e3f41b62a104341e80cac52f91110dece393923c0ebd5ca", size = 123056, upload-time = "2026-03-06T02:54:10.829Z" }, + { url = "https://files.pythonhosted.org/packages/93/b9/ff205f391cb708f67f41ea148545f2b53ff543a7ac293b30d178af4d2271/wrapt-2.1.2-cp313-cp313-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:162e4e2ba7542da9027821cb6e7c5e068d64f9a10b5f15512ea28e954893a267", size = 117359, upload-time = "2026-03-06T02:53:03.623Z" }, + { url = "https://files.pythonhosted.org/packages/1f/3d/1ea04d7747825119c3c9a5e0874a40b33594ada92e5649347c457d982805/wrapt-2.1.2-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:f29c827a8d9936ac320746747a016c4bc66ef639f5cd0d32df24f5eacbf9c69f", size = 121479, upload-time = "2026-03-06T02:53:45.844Z" }, + { url = "https://files.pythonhosted.org/packages/78/cc/ee3a011920c7a023b25e8df26f306b2484a531ab84ca5c96260a73de76c0/wrapt-2.1.2-cp313-cp313-musllinux_1_2_riscv64.whl", hash = "sha256:a9dd9813825f7ecb018c17fd147a01845eb330254dff86d3b5816f20f4d6aaf8", size = 116271, upload-time = "2026-03-06T02:54:46.356Z" }, + { url = "https://files.pythonhosted.org/packages/98/fd/e5ff7ded41b76d802cf1191288473e850d24ba2e39a6ec540f21ae3b57cb/wrapt-2.1.2-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:6f8dbdd3719e534860d6a78526aafc220e0241f981367018c2875178cf83a413", size = 120573, upload-time = "2026-03-06T02:52:50.163Z" }, + { url = 
"https://files.pythonhosted.org/packages/47/c5/242cae3b5b080cd09bacef0591691ba1879739050cc7c801ff35c8886b66/wrapt-2.1.2-cp313-cp313-win32.whl", hash = "sha256:5c35b5d82b16a3bc6e0a04349b606a0582bc29f573786aebe98e0c159bc48db6", size = 58205, upload-time = "2026-03-06T02:53:47.494Z" }, + { url = "https://files.pythonhosted.org/packages/12/69/c358c61e7a50f290958809b3c61ebe8b3838ea3e070d7aac9814f95a0528/wrapt-2.1.2-cp313-cp313-win_amd64.whl", hash = "sha256:f8bc1c264d8d1cf5b3560a87bbdd31131573eb25f9f9447bb6252b8d4c44a3a1", size = 60452, upload-time = "2026-03-06T02:53:30.038Z" }, + { url = "https://files.pythonhosted.org/packages/8e/66/c8a6fcfe321295fd8c0ab1bd685b5a01462a9b3aa2f597254462fc2bc975/wrapt-2.1.2-cp313-cp313-win_arm64.whl", hash = "sha256:3beb22f674550d5634642c645aba4c72a2c66fb185ae1aebe1e955fae5a13baf", size = 58842, upload-time = "2026-03-06T02:52:52.114Z" }, + { url = "https://files.pythonhosted.org/packages/da/55/9c7052c349106e0b3f17ae8db4b23a691a963c334de7f9dbd60f8f74a831/wrapt-2.1.2-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:0fc04bc8664a8bc4c8e00b37b5355cffca2535209fba1abb09ae2b7c76ddf82b", size = 63075, upload-time = "2026-03-06T02:53:19.108Z" }, + { url = "https://files.pythonhosted.org/packages/09/a8/ce7b4006f7218248dd71b7b2b732d0710845a0e49213b18faef64811ffef/wrapt-2.1.2-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:a9b9d50c9af998875a1482a038eb05755dfd6fe303a313f6a940bb53a83c3f18", size = 63719, upload-time = "2026-03-06T02:54:33.452Z" }, + { url = "https://files.pythonhosted.org/packages/e4/e5/2ca472e80b9e2b7a17f106bb8f9df1db11e62101652ce210f66935c6af67/wrapt-2.1.2-cp313-cp313t-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:2d3ff4f0024dd224290c0eabf0240f1bfc1f26363431505fb1b0283d3b08f11d", size = 152643, upload-time = "2026-03-06T02:52:42.721Z" }, + { url = 
"https://files.pythonhosted.org/packages/36/42/30f0f2cefca9d9cbf6835f544d825064570203c3e70aa873d8ae12e23791/wrapt-2.1.2-cp313-cp313t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:3278c471f4468ad544a691b31bb856374fbdefb7fee1a152153e64019379f015", size = 158805, upload-time = "2026-03-06T02:54:25.441Z" }, + { url = "https://files.pythonhosted.org/packages/bb/67/d08672f801f604889dcf58f1a0b424fe3808860ede9e03affc1876b295af/wrapt-2.1.2-cp313-cp313t-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:a8914c754d3134a3032601c6984db1c576e6abaf3fc68094bb8ab1379d75ff92", size = 145990, upload-time = "2026-03-06T02:53:57.456Z" }, + { url = "https://files.pythonhosted.org/packages/68/a7/fd371b02e73babec1de6ade596e8cd9691051058cfdadbfd62a5898f3295/wrapt-2.1.2-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:ff95d4264e55839be37bafe1536db2ab2de19da6b65f9244f01f332b5286cfbf", size = 155670, upload-time = "2026-03-06T02:54:55.309Z" }, + { url = "https://files.pythonhosted.org/packages/86/2d/9fe0095dfdb621009f40117dcebf41d7396c2c22dca6eac779f4c007b86c/wrapt-2.1.2-cp313-cp313t-musllinux_1_2_riscv64.whl", hash = "sha256:76405518ca4e1b76fbb1b9f686cff93aebae03920cc55ceeec48ff9f719c5f67", size = 144357, upload-time = "2026-03-06T02:54:24.092Z" }, + { url = "https://files.pythonhosted.org/packages/0e/b6/ec7b4a254abbe4cde9fa15c5d2cca4518f6b07d0f1b77d4ee9655e30280e/wrapt-2.1.2-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:c0be8b5a74c5824e9359b53e7e58bef71a729bacc82e16587db1c4ebc91f7c5a", size = 150269, upload-time = "2026-03-06T02:53:31.268Z" }, + { url = "https://files.pythonhosted.org/packages/6e/6b/2fabe8ebf148f4ee3c782aae86a795cc68ffe7d432ef550f234025ce0cfa/wrapt-2.1.2-cp313-cp313t-win32.whl", hash = "sha256:f01277d9a5fc1862f26f7626da9cf443bebc0abd2f303f41c5e995b15887dabd", size = 59894, upload-time = "2026-03-06T02:54:15.391Z" }, + { url = 
"https://files.pythonhosted.org/packages/ca/fb/9ba66fc2dedc936de5f8073c0217b5d4484e966d87723415cc8262c5d9c2/wrapt-2.1.2-cp313-cp313t-win_amd64.whl", hash = "sha256:84ce8f1c2104d2f6daa912b1b5b039f331febfeee74f8042ad4e04992bd95c8f", size = 63197, upload-time = "2026-03-06T02:54:41.943Z" }, + { url = "https://files.pythonhosted.org/packages/c0/1c/012d7423c95d0e337117723eb8ecf73c622ce15a97847e84cf3f8f26cd7e/wrapt-2.1.2-cp313-cp313t-win_arm64.whl", hash = "sha256:a93cd767e37faeddbe07d8fc4212d5cba660af59bdb0f6372c93faaa13e6e679", size = 60363, upload-time = "2026-03-06T02:54:48.093Z" }, + { url = "https://files.pythonhosted.org/packages/1a/c7/8528ac2dfa2c1e6708f647df7ae144ead13f0a31146f43c7264b4942bf12/wrapt-2.1.2-py3-none-any.whl", hash = "sha256:b8fd6fa2b2c4e7621808f8c62e8317f4aae56e59721ad933bac5239d913cf0e8", size = 43993, upload-time = "2026-03-06T02:53:12.905Z" }, +] + [[package]] name = "xformers" version = "0.0.32.post1" diff --git a/verifiers/envs/experimental/braintrust_tracing/__init__.py b/verifiers/envs/experimental/braintrust_tracing/__init__.py new file mode 100644 index 000000000..2e897fe06 --- /dev/null +++ b/verifiers/envs/experimental/braintrust_tracing/__init__.py @@ -0,0 +1,24 @@ +"""Experimental Braintrust tracing variants of the core environment classes. + +Usage:: + + from verifiers.envs.experimental.braintrust_tracing.stateful_tool_env import StatefulToolEnv + +These classes are drop-in replacements for their non-tracing counterparts. +Set ``BRAINTRUST_API_KEY`` and optionally ``VF_BRAINTRUST_PROJECT`` to enable +trace logging to Braintrust. 
+""" + +from verifiers.envs.experimental.braintrust_tracing.environment import Environment +from verifiers.envs.experimental.braintrust_tracing.multiturn_env import MultiTurnEnv +from verifiers.envs.experimental.braintrust_tracing.stateful_tool_env import ( + StatefulToolEnv, +) +from verifiers.envs.experimental.braintrust_tracing.tool_env import ToolEnv + +__all__ = [ + "Environment", + "MultiTurnEnv", + "ToolEnv", + "StatefulToolEnv", +] diff --git a/verifiers/envs/experimental/braintrust_tracing/braintrust_tracing.py b/verifiers/envs/experimental/braintrust_tracing/braintrust_tracing.py new file mode 100644 index 000000000..e6698a076 --- /dev/null +++ b/verifiers/envs/experimental/braintrust_tracing/braintrust_tracing.py @@ -0,0 +1,692 @@ +""" +Braintrust tracing for verifiers. + +Provides nested span traces for rollouts, model calls, tool calls, and scoring. +Each rollout becomes a trace in Braintrust with child spans showing the full +execution timeline. + +Activation: + Set BRAINTRUST_API_KEY to enable. No-op when unset. + Optionally set VF_BRAINTRUST_PROJECT to override the default project name + (default: "verifiers"). + +Span hierarchy produced per rollout: + + rollout (type=task) ← root span / trace + ├── setup_state (type=task) + ├── turn_0 (type=task) + │ ├── model_request (type=llm) + │ └── env_response (type=task) ← includes tool_call children + │ ├── tool_call:navigate (type=tool) + │ └── tool_call:computer (type=tool) + ├── turn_1 (type=task) + │ └── model_request (type=llm) + └── scoring (type=score) + +All public helpers are safe to call even when Braintrust is not configured; +they degrade to no-ops with near-zero overhead. Errors inside telemetry +code are swallowed so they never interfere with evaluation runs. 
+""" + +from __future__ import annotations + +import contextvars +import logging +import os +import sys +import threading +import time +import uuid +from typing import Any + +_log = logging.getLogger(__name__) + +# --------------------------------------------------------------------------- +# Lazy singleton +# --------------------------------------------------------------------------- + +_INSTANCE: _Tracing | None = None +_LOCK = threading.Lock() + +# Coroutine-local storage for passing the rollout span from +# _run_rollout_state → rollout() across the await boundary without +# storing mutable state on the shared Environment instance. +_pending_rollout_span: contextvars.ContextVar[Any] = contextvars.ContextVar( + "_pending_rollout_span", default=None +) + +# Run-level tags: coroutine-local storage so concurrent generate() calls +# each get their own tag set without overwriting each other. +_run_tags: contextvars.ContextVar[list[str]] = contextvars.ContextVar( + "_run_tags", default=[] +) + + +def _get() -> _Tracing: + global _INSTANCE + if _INSTANCE is None: + with _LOCK: + if _INSTANCE is None: + _INSTANCE = _Tracing() + return _INSTANCE + + +# --------------------------------------------------------------------------- +# Public helpers +# --------------------------------------------------------------------------- + + +def enabled() -> bool: + """Return True when Braintrust tracing is active.""" + return _get().enabled + + +def flush() -> None: + """Flush any buffered spans to Braintrust.""" + try: + inst = _get() + if inst.enabled and inst._logger is not None: + inst._logger.flush() + except Exception: + pass + + +def set_run_tags(tags: list[str] | None = None) -> list[str]: + """Set tags for the current eval run. + + If *tags* is ``None`` a unique tag is auto-generated from the current + timestamp (``run--``). Returns the active tag list + so callers can inspect or log it. 
+ + Uses a ``ContextVar`` so concurrent ``generate()`` calls each get their + own isolated tag set. + """ + if tags is not None: + new_tags = list(tags) + else: + short_id = uuid.uuid4().hex[:8] + new_tags = [f"run-{int(time.time())}-{short_id}"] + _run_tags.set(new_tags) + return new_tags + + +def get_run_tags() -> list[str]: + """Return the currently active run tags (empty list when unset).""" + return list(_run_tags.get()) + + +def clear_run_tags() -> None: + """Clear run tags (called after generate completes).""" + _run_tags.set([]) + + +# -- Span lifecycle helpers ------------------------------------------------ +# These return an opaque span object (or None when disabled). Callers store +# the span and later call end_span() when the phase completes. + + +def start_rollout_span( + *, + env_id: str = "", + model: str = "", + example_id: Any = "", + trajectory_id: str = "", +) -> Any: + """Start a root span for one rollout. Returns span or None.""" + try: + inst = _get() + if not inst.enabled or inst._logger is None: + return None + kwargs: dict[str, Any] = { + "name": "rollout", + "span_attributes": {"type": "task"}, + "input": {"example_id": _safe(example_id)}, + "metadata": { + "env_id": env_id, + "model": model, + "trajectory_id": trajectory_id, + }, + } + tags = _run_tags.get() + if tags: + kwargs["tags"] = list(tags) + span = inst._logger.start_span(**kwargs) + return span + except Exception: + return None + + +def start_child_span( + parent: Any, + *, + name: str, + span_type: str = "task", + input: Any = None, + metadata: dict[str, Any] | None = None, +) -> Any: + """Start a child span under *parent*. 
Returns span or None.""" + try: + if parent is None: + return None + kwargs: dict[str, Any] = { + "name": name, + "span_attributes": {"type": span_type}, + } + if input is not None: + kwargs["input"] = _safe(input) + if metadata: + kwargs["metadata"] = _safe(metadata) + return parent.start_span(**kwargs) + except Exception: + return None + + +def log_to_span( + span: Any, + *, + input: Any = None, + output: Any = None, + metadata: dict[str, Any] | None = None, + metrics: dict[str, Any] | None = None, + error: str | None = None, + scores: dict[str, float] | None = None, +) -> None: + """Log data to an existing span.""" + try: + if span is None: + return + kwargs: dict[str, Any] = {} + if input is not None: + kwargs["input"] = _safe(input) + if output is not None: + kwargs["output"] = _safe(output) + if metadata: + kwargs["metadata"] = _safe(metadata) + if metrics: + kwargs["metrics"] = _safe(metrics) + if error: + kwargs["error"] = error + if scores: + kwargs["scores"] = scores + if kwargs: + span.log(**kwargs) + except Exception: + pass + + +def end_span(span: Any) -> None: + """End (close) a span. Safe to call with None.""" + try: + if span is not None: + span.end() + except Exception: + pass + + +# -- Convenience: rollout lifecycle ---------------------------------------- + + +def rollout_started( + *, + env_id: str = "", + model: str = "", + example_id: Any = "", + trajectory_id: str = "", +) -> Any: + """Start a rollout root span. 
Returns the span.""" + return start_rollout_span( + env_id=env_id, + model=model, + example_id=example_id, + trajectory_id=trajectory_id, + ) + + +def rollout_completed( + span: Any, + *, + reward: Any = None, + num_turns: int = 0, + duration_s: float = 0.0, + stop_condition: str = "", + error: str = "", + input_tokens: float = 0.0, + output_tokens: float = 0.0, +) -> None: + """Finalize and close a rollout span.""" + try: + if span is None: + return + metrics: dict[str, Any] = { + "duration_s": duration_s, + "num_turns": num_turns, + "tokens": input_tokens + output_tokens, + "prompt_tokens": input_tokens, + "completion_tokens": output_tokens, + } + meta: dict[str, Any] = {"stop_condition": stop_condition} + scores: dict[str, float] | None = None + if reward is not None: + try: + scores = {"reward": float(reward)} + except (TypeError, ValueError): + meta["reward_raw"] = repr(reward) + log_to_span( + span, + output={"stop_condition": stop_condition, "num_turns": num_turns}, + metadata=meta, + metrics=metrics, + error=error or None, + scores=scores, + ) + end_span(span) + except Exception: + pass + + +# -- Convenience: setup ---------------------------------------------------- + + +def setup_started(parent: Any, *, env_id: str = "", trajectory_id: str = "") -> Any: + """Start a setup_state child span.""" + return start_child_span( + parent, + name="setup_state", + span_type="task", + metadata={"env_id": env_id, "trajectory_id": trajectory_id}, + ) + + +def setup_completed(span: Any, *, duration_s: float = 0.0, error: str = "") -> None: + """Finalize and close a setup span.""" + log_to_span( + span, + metrics={"duration_s": duration_s}, + error=error or None, + ) + end_span(span) + + +# -- Convenience: turns ---------------------------------------------------- + + +def turn_started( + parent: Any, + *, + turn_index: int = 0, + trajectory_id: str = "", +) -> Any: + """Start a turn child span.""" + return start_child_span( + parent, + name=f"turn_{turn_index}", + 
span_type="task", + metadata={"turn_index": turn_index, "trajectory_id": trajectory_id}, + ) + + +def turn_completed( + span: Any, + *, + duration_s: float = 0.0, + model_duration_s: float | None = None, + env_duration_s: float | None = None, + is_truncated: bool = False, + error: str = "", +) -> None: + """Finalize and close a turn span.""" + metrics: dict[str, Any] = {"duration_s": duration_s} + if model_duration_s is not None: + metrics["model_duration_s"] = model_duration_s + if env_duration_s is not None: + metrics["env_duration_s"] = env_duration_s + log_to_span( + span, + metrics=metrics, + metadata={"is_truncated": is_truncated}, + error=error or None, + ) + end_span(span) + + +# -- Convenience: model requests ------------------------------------------- + + +def model_request_span( + parent: Any, + *, + model: str = "", + turn_index: int = 0, + messages: Any = None, +) -> Any: + """Start a model_request child span (type=llm).""" + input_val = None + if messages is not None: + input_val = _safe(messages) + return start_child_span( + parent, + name="model_request", + span_type="llm", + input=input_val, + metadata={"model": model, "turn_index": turn_index}, + ) + + +def model_request_completed( + span: Any, + *, + duration_s: float = 0.0, + input_tokens: float = 0.0, + output_tokens: float = 0.0, + response: Any = None, + error: str = "", +) -> None: + """Finalize and close a model_request span.""" + metrics: dict[str, Any] = { + "duration_s": duration_s, + "prompt_tokens": input_tokens, + "completion_tokens": output_tokens, + "tokens": input_tokens + output_tokens, + } + output_val = None + if response is not None: + output_val = _safe_response(response) + log_to_span( + span, + output=output_val, + metrics=metrics, + error=error or None, + ) + end_span(span) + + +# -- Convenience: tool calls ----------------------------------------------- + + +def tool_call_started( + parent: Any, + *, + tool_name: str = "", + tool_call_id: str = "", + tool_args: Any = None, 
+) -> Any: + """Start a tool_call child span (type=tool).""" + return start_child_span( + parent, + name=f"tool_call:{tool_name}", + span_type="tool", + input=_safe(tool_args) if tool_args is not None else {"tool_name": tool_name}, + metadata={"tool_name": tool_name, "tool_call_id": tool_call_id}, + ) + + +def tool_call_completed( + span: Any, + *, + duration_s: float = 0.0, + result: Any = None, + error: str = "", +) -> None: + """Finalize and close a tool_call span.""" + log_to_span( + span, + output=_safe(result) if result is not None else None, + metrics={"duration_s": duration_s}, + error=error or None, + ) + end_span(span) + + +# -- Convenience: scoring -------------------------------------------------- + + +def scoring_started(parent: Any, *, trajectory_id: str = "") -> Any: + """Start a scoring child span (type=score).""" + return start_child_span( + parent, + name="scoring", + span_type="score", + metadata={"trajectory_id": trajectory_id}, + ) + + +def scoring_completed( + span: Any, + *, + duration_s: float = 0.0, + reward: Any = None, +) -> None: + """Finalize and close a scoring span.""" + scores: dict[str, float] | None = None + if reward is not None: + try: + scores = {"reward": float(reward)} + except (TypeError, ValueError): + pass + log_to_span( + span, + metrics={"duration_s": duration_s}, + scores=scores, + ) + end_span(span) + + +# -- Convenience: groups --------------------------------------------------- + + +def group_started( + *, + env_id: str = "", + model: str = "", + example_id: Any = "", + group_size: int = 0, +) -> Any: + """Start a root span for a group of rollouts.""" + try: + inst = _get() + if not inst.enabled or inst._logger is None: + return None + kwargs: dict[str, Any] = { + "name": "group", + "span_attributes": {"type": "task"}, + "input": {"example_id": _safe(example_id), "group_size": group_size}, + "metadata": {"env_id": env_id, "model": model}, + } + tags = _run_tags.get() + if tags: + kwargs["tags"] = list(tags) + return 
inst._logger.start_span(**kwargs) + except Exception: + return None + + +def group_completed( + span: Any, + *, + duration_s: float = 0.0, + avg_reward: float | None = None, + group_size: int = 0, +) -> None: + """Finalize and close a group span.""" + scores: dict[str, float] | None = None + if avg_reward is not None: + scores = {"avg_reward": avg_reward} + log_to_span( + span, + output={"group_size": group_size}, + metrics={"duration_s": duration_s}, + scores=scores, + ) + end_span(span) + + +# -- Convenience: generate ------------------------------------------------- + + +def generate_started( + *, + env_id: str = "", + model: str = "", + num_inputs: int = 0, +) -> Any: + """Start a root span for a generate() call.""" + try: + inst = _get() + if not inst.enabled or inst._logger is None: + return None + kwargs: dict[str, Any] = { + "name": "generate", + "span_attributes": {"type": "eval"}, + "input": {"num_inputs": num_inputs}, + "metadata": {"env_id": env_id, "model": model}, + } + tags = _run_tags.get() + if tags: + kwargs["tags"] = list(tags) + return inst._logger.start_span(**kwargs) + except Exception: + return None + + +def generate_completed( + span: Any, + *, + duration_s: float = 0.0, + num_outputs: int = 0, + avg_reward: float | None = None, +) -> None: + """Finalize and close a generate span.""" + scores: dict[str, float] | None = None + if avg_reward is not None: + scores = {"avg_reward": avg_reward} + log_to_span( + span, + output={"num_outputs": num_outputs}, + metrics={"duration_s": duration_s}, + scores=scores, + ) + end_span(span) + flush() + + +# -- Convenience: stop condition / timeout --------------------------------- + + +def stop_condition_triggered( + parent: Any, + *, + condition: str = "", + error: str = "", +) -> None: + """Log a stop condition event on the rollout span.""" + log_to_span( + parent, + metadata={"stop_condition": condition}, + error=error or None, + ) + + +def timeout_triggered( + parent: Any, + *, + timeout_seconds: float 
| None = None, +) -> None: + """Log a timeout event on the rollout span.""" + log_to_span( + parent, + metadata={"timed_out": True, "timeout_seconds": timeout_seconds}, + error="timeout", + ) + + +# --------------------------------------------------------------------------- +# Serialization helpers +# --------------------------------------------------------------------------- + + +def _safe(obj: Any) -> Any: + """Best-effort JSON-safe conversion of arbitrary objects.""" + if obj is None: + return None + if isinstance(obj, (str, int, float, bool)): + return obj + if isinstance(obj, dict): + return {str(k): _safe(v) for k, v in obj.items()} + if isinstance(obj, (list, tuple)): + return [_safe(v) for v in obj] + # Pydantic models + if hasattr(obj, "model_dump"): + try: + return obj.model_dump() + except Exception: + pass + # Fallback to repr + try: + return repr(obj)[:2000] + except Exception: + return "" + + +def _safe_response(response: Any) -> Any: + """Extract serializable data from a model response object.""" + try: + if hasattr(response, "model_dump"): + d = response.model_dump() + # Trim large fields to keep span data manageable + if isinstance(d, dict) and "choices" in d: + for choice in d.get("choices", []): + if isinstance(choice, dict) and "message" in choice: + msg = choice["message"] + if isinstance(msg, dict) and "content" in msg: + content = msg["content"] + if isinstance(content, str) and len(content) > 5000: + msg["content"] = content[:5000] + "...(truncated)" + return d + if hasattr(response, "message"): + return _safe(response.message) + return repr(response)[:2000] + except Exception: + return "" + + +# --------------------------------------------------------------------------- +# Singleton backend +# --------------------------------------------------------------------------- + + +class _Tracing: + """Lazy singleton that manages the Braintrust logger.""" + + def __init__(self) -> None: + self._logger: Any = None + self._enabled = False + 
self._api_key = os.environ.get("BRAINTRUST_API_KEY", "") + self._project = os.environ.get("VF_BRAINTRUST_PROJECT", "verifiers") + + if not self._api_key: + return + + try: + import braintrust + + self._logger = braintrust.init_logger( + project=self._project, + api_key=self._api_key, + ) + self._enabled = True + _log.info("Braintrust tracing enabled for project=%s", self._project) + except ImportError: + print( + "WARNING: BRAINTRUST_API_KEY is set but 'braintrust' package " + "is not installed. Run: pip install braintrust", + file=sys.stderr, + ) + except Exception as exc: + print( + f"WARNING: Failed to initialize Braintrust tracing: {exc}", + file=sys.stderr, + ) + + @property + def enabled(self) -> bool: + return self._enabled diff --git a/verifiers/envs/experimental/braintrust_tracing/environment.py b/verifiers/envs/experimental/braintrust_tracing/environment.py new file mode 100644 index 000000000..a8bf90ebc --- /dev/null +++ b/verifiers/envs/experimental/braintrust_tracing/environment.py @@ -0,0 +1,1505 @@ +from __future__ import annotations + +import asyncio +import atexit +import json +import logging +import multiprocessing as mp +import signal +import time +import uuid +import warnings +from abc import ABC, abstractmethod +from collections import defaultdict +from collections.abc import Mapping +from concurrent.futures import ThreadPoolExecutor +from copy import deepcopy +from multiprocessing.connection import Connection +from multiprocessing.process import BaseProcess +from pathlib import Path +from typing import ( + TYPE_CHECKING, + Any, + Awaitable, + Callable, + List, + TypeVar, + cast, + final, +) + +from verifiers.clients import Client, resolve_client +from verifiers.decorators import discover_decorated +from verifiers.serve import ZMQEnvClient +from verifiers.utils.client_utils import ( + resolve_client_config, + resolve_client_configs, +) +from verifiers.utils.eval_utils import filter_inputs +from verifiers.utils.path_utils import 
is_valid_eval_results_path +from verifiers.utils.serve_utils import get_free_port +from verifiers.utils.thread_utils import scale_executors + +if TYPE_CHECKING: + from datasets import Dataset + +import verifiers as vf +from verifiers.parsers.parser import Parser +from verifiers.rubrics.rubric import Rubric +from verifiers.serve import EnvClient +from verifiers.types import ( + ClientConfig, + DatasetBuilder, + GenerateMetadata, + GenerateOutputs, + LogCallback, + Messages, + MessageType, + ProgressCallback, + Response, + RolloutInput, + RolloutOutput, + RolloutTiming, + SamplingArgs, + StartCallback, + State, + TokenUsage, + Tool, + flatten_task_input, +) +from verifiers.utils.async_utils import ( + maybe_call_with_named_args, + maybe_retry, + maybe_semaphore, + with_sem, +) +from verifiers.utils.error_utils import ErrorChain +from verifiers.utils.message_utils import normalize_messages +from verifiers.utils.save_utils import ( + GenerateOutputsBuilder, + load_outputs, + make_dataset, + push_results_to_hf_hub, + save_metadata, + save_new_outputs, + save_outputs, + state_to_output, + validate_resume_metadata, +) +from verifiers.utils.usage_utils import StateUsageTracker, extract_usage_tokens + +import verifiers.envs.experimental.braintrust_tracing.braintrust_tracing as _bt + +_MESSAGE_TYPE_UNSET = object() + + +class Environment(ABC): + """ + Base class for all environments. 
+ """ + + def __init__( + self, + dataset: Dataset | DatasetBuilder | None = None, + eval_dataset: Dataset | DatasetBuilder | None = None, + system_prompt: str | None = None, + few_shot: Messages | None = None, + parser: Parser | None = None, + rubric: Rubric | None = None, + sampling_args: SamplingArgs | None = None, + message_type: MessageType | object = _MESSAGE_TYPE_UNSET, + tool_defs: list[Tool] | None = None, + max_workers: int = 512, + env_id: str | None = None, + env_args: dict | None = None, + map_kwargs: dict = {}, + max_seq_len: int | None = None, + score_rollouts: bool = True, + pass_threshold: float = 0.5, + **kwargs, + ): + if message_type is _MESSAGE_TYPE_UNSET: + resolved_message_type: MessageType = "chat" + else: + if message_type != "chat": + warnings.warn( + "message_type is deprecated and will be removed", + DeprecationWarning, + stacklevel=2, + ) + resolved_message_type = cast(MessageType, message_type) + self.logger = logging.getLogger(f"{__name__}.{self.__class__.__name__}") + self.message_type: MessageType = resolved_message_type + if "oai_tools" in kwargs: + raise ValueError( + "`oai_tools` is no longer supported. Use `tool_defs` with provider-agnostic " + "tool definitions: [{'name': ..., 'description': ..., 'parameters': {...}}]." + ) + self.tool_defs: list[Tool] | None = self._normalize_tool_defs(tool_defs) + self.system_prompt = system_prompt + self.few_shot = few_shot + self.parser = parser or Parser() + self.rubric = rubric or Rubric() + if self.parser.__class__ != self.rubric.parser.__class__: + self.logger.warning( + "The parser and rubric parser are different. This may cause unexpected behavior." 
+ ) + + self.env_id = env_id or "" + self.env_args = env_args or {} + self.max_seq_len = max_seq_len + self.map_kwargs = map_kwargs + + self.set_score_rollouts(score_rollouts) + self.pass_threshold = pass_threshold + + self.env_client: EnvClient | None = None + self.env_server_process: BaseProcess | None = None + self.death_pipe_writer: Connection | None = None + + # Dataset sources (builders) and built datasets + # Use get_dataset()/get_eval_dataset() for access; build_dataset() to trigger build + self.dataset: Dataset | None = None + self.eval_dataset: Dataset | None = None + + if dataset is not None: + if callable(dataset): + self.dataset_source: DatasetBuilder | None = cast( + DatasetBuilder, dataset + ) + else: + self.dataset_source = lambda ds=dataset: ds + self.build_dataset() # Eagerly build for raw datasets (backwards compat) + else: + self.dataset_source = None + + if eval_dataset is not None: + if callable(eval_dataset): + self.eval_dataset_source: DatasetBuilder | None = cast( + DatasetBuilder, eval_dataset + ) + else: + self.eval_dataset_source = lambda ds=eval_dataset: ds + self.build_eval_dataset() # Eagerly build for raw datasets (backwards compat) + else: + self.eval_dataset_source = None + + self.sampling_args = {"n": 1, "extra_body": {}} + if sampling_args is not None: + # merge extra_body if provided + cast(dict[str, Any], self.sampling_args["extra_body"]).update( + cast(dict[str, Any], sampling_args.get("extra_body", {})) + ) + # copy other keys + for key, value in sampling_args.items(): + if key != "extra_body": + self.sampling_args[key] = value + + self.max_workers = max_workers + for key, value in kwargs.items(): + setattr(self, key, value) + + if ( + self.dataset_source is None + and self.eval_dataset_source is None + and self.dataset is None + and self.eval_dataset is None + ): + raise ValueError("Either dataset or eval_dataset must be provided") + self.rollouts_per_example = None + self._stop_conditions: list[StopCondition] = [] + 
self._cleanup_handlers: list[RolloutCleanup] = [] + self._teardown_handlers: list[EnvironmentTeardown] = [] + + self.__post_init__() + + @property + def requires_group_rollouts(self) -> bool: + return self.rubric.has_group_rewards + + @property + def provides_advantages(self) -> bool: + return self.rubric.has_advantages + + @staticmethod + def _normalize_tool_defs( + tools: list[Tool] | list[dict[str, Any]] | None, + ) -> list[Tool] | None: + """Normalize tools to provider-agnostic vf.Tool objects. + + Accepts: + - vf.Tool objects + - vf.Tool-like dicts: {"name", "description", "parameters", "strict?"} + """ + if tools is None: + return None + + normalized: list[Tool] = [] + for raw_tool in tools: + if isinstance(raw_tool, Tool): + normalized.append(raw_tool) + continue + + if isinstance(raw_tool, dict): + if raw_tool.get("type") == "function" and isinstance( + raw_tool.get("function"), dict + ): + raise ValueError( + "Legacy OpenAI tool schema is no longer supported. " + "Use `tool_defs` entries in vf.Tool format: " + "{'name': ..., 'description': ..., 'parameters': {...}}." 
+ ) + + normalized.append(Tool.model_validate(raw_tool)) + + return normalized + + def __post_init__(self): + self._stop_conditions = discover_decorated(self, "stop") + self._cleanup_handlers = discover_decorated(self, "cleanup") + self._teardown_handlers = discover_decorated(self, "teardown") + + def _sync_teardown(): + try: + loop = asyncio.get_running_loop() + except RuntimeError: + asyncio.run(self._teardown()) + else: + loop.create_task(self._teardown()) + + atexit.register(_sync_teardown) + signal.signal( + signal.SIGINT, + lambda sig, frame: ( + _sync_teardown(), + signal.default_int_handler(sig, frame), + ), + ) + signal.signal(signal.SIGTERM, lambda _, __: (_sync_teardown(), exit(143))) + + def _ensure_example_id(self, dataset: Dataset) -> Dataset: + """Ensure example_id column exists and is integer type.""" + if "example_id" in dataset.column_names and not isinstance( + dataset["example_id"][0], int + ): + dataset = dataset.rename_column("example_id", "src_id") + if "example_id" not in dataset.column_names: + dataset = dataset.add_column("example_id", range(len(dataset))) + return dataset + + def _ensure_prompt( + self, + dataset: Dataset, + system_prompt: str | None = None, + few_shot: Messages | None = None, + question_key: str = "question", + answer_key: str = "answer", + map_kwargs: dict = {}, + ) -> Dataset: + """Ensure prompt column exists.""" + if "prompt" not in dataset.column_names: + + def format_prompt_fn(prompt_str: str) -> Messages: + messages = [] + if system_prompt: + messages.append({"role": "system", "content": system_prompt}) + if few_shot: + messages.extend(few_shot) + messages.append({"role": "user", "content": prompt_str}) + return messages + + if answer_key == "answer": + dataset = dataset.map( + lambda x: { + "prompt": format_prompt_fn(x[question_key]), + }, + **map_kwargs, + ) + else: + dataset = dataset.map( + lambda x: { + "prompt": format_prompt_fn(x[question_key]), + "answer": x[answer_key], + }, + **map_kwargs, + ) + + else: + 
if system_prompt is not None: + + def prepend_system_prompt(prompt: list[Any]) -> list[Any]: + assert isinstance(prompt, list), ( + f"prompt must be a list of messages when system_prompt is provided, got {type(prompt)}" + ) + # Check if a system message already exists (first message) + first = prompt[0] if prompt else None + first_role = ( + first.get("role") + if isinstance(first, dict) + else getattr(first, "role", None) + ) + if first_role == "system": + return prompt + # Prepend as a plain dict so Arrow/HuggingFace can serialize. + # Normalization to Pydantic happens later in init_state. + return [{"role": "system", "content": system_prompt}, *prompt] + + dataset = dataset.map( + lambda x: {"prompt": prepend_system_prompt(x["prompt"])}, + **map_kwargs, + ) + if few_shot is not None: + self.logger.warning( + "Dataset already has a 'prompt' column, so the provided few_shot examples will be ignored." + ) + return dataset + + def _format_dataset( + self, + dataset: Dataset, + system_prompt: str | None = None, + few_shot: Messages | None = None, + question_key: str = "question", + answer_key: str = "answer", + map_kwargs: dict = {}, + ) -> Dataset: + """ + Format dataset by creating example_id and prompt columns. + """ + if "env_id" in dataset.column_names: + dataset = dataset.remove_columns(["env_id"]) + dataset = self._ensure_example_id(dataset) + dataset = self._ensure_prompt( + dataset, system_prompt, few_shot, question_key, answer_key, map_kwargs + ) + return dataset + + def _format_completion_dataset( + self, dataset: Dataset, map_kwargs: dict = {} + ) -> Dataset: + """ + Format dataset by creating example_id. 
+ """ + if "env_id" in dataset.column_names: + dataset = dataset.remove_columns(["env_id"]) + dataset = self._ensure_example_id(dataset) + return dataset + + def _format_dataset_source(self, dataset: Dataset) -> Dataset: + """Format a dataset as chat (messages); client maps to its format at request time.""" + return self._format_dataset( + dataset, + self.system_prompt, + self.few_shot, + map_kwargs=self.map_kwargs, + ) + + def build_dataset(self) -> Dataset | None: + """Build and cache the training dataset from source if needed.""" + if self.dataset is not None: + return self.dataset + if self.dataset_source is None: + return None + built = self.dataset_source() + self.dataset = self._format_dataset_source(built) + return self.dataset + + def build_eval_dataset(self) -> Dataset | None: + """Build and cache the evaluation dataset from source if needed.""" + if self.eval_dataset is not None: + return self.eval_dataset + if self.eval_dataset_source is None: + return None + built = self.eval_dataset_source() + self.eval_dataset = self._format_dataset_source(built) + return self.eval_dataset + + @final + def get_dataset(self, n: int = -1, seed: int | None = None) -> Dataset: + self.build_dataset() + if self.dataset is None: + raise ValueError("dataset is not set") + if seed is not None: + self.dataset = self.dataset.shuffle(seed=seed) + if n > 0: + n = min(n, len(self.dataset)) + return self.dataset.select(range(n)) + return self.dataset + + @final + def get_eval_dataset(self, n: int = -1, seed: int | None = None) -> Dataset: + self.build_eval_dataset() + if self.eval_dataset is None: + self.logger.warning( + "eval_dataset is not set, falling back to train dataset" + ) + return self.get_dataset(n, seed) + if seed is not None: + self.eval_dataset = self.eval_dataset.shuffle(seed=seed) + if n > 0: + n = min(n, len(self.eval_dataset)) + return self.eval_dataset.select(range(n)) + return self.eval_dataset + + @final + def _get_usage_tracker( + self, state: State, 
create_if_missing: bool = True + ) -> StateUsageTracker | None: + tracker = state.get("usage_tracker") + if isinstance(tracker, StateUsageTracker): + return tracker + if not create_if_missing: + return None + tracker = StateUsageTracker() + state["usage_tracker"] = tracker + # Expose read-only usage in state for live inspection. + state["usage"] = tracker.usage + return tracker + + @final + def increment_state_usage( + self, + state: State, + input_tokens: int | float = 0, + output_tokens: int | float = 0, + ) -> None: + tracker = self._get_usage_tracker(state, create_if_missing=True) + assert tracker is not None + tracker.increment(input_tokens, output_tokens) + + @final + def increment_state_usage_from_response( + self, state: State, response: object + ) -> None: + tracker = self._get_usage_tracker(state, create_if_missing=True) + assert tracker is not None + tracker.increment_from_response(response) + + @final + def get_state_usage(self, state: State) -> TokenUsage | None: + tracker = self._get_usage_tracker(state, create_if_missing=False) + if tracker is not None: + return tracker.snapshot() + usage = state.get("usage") + if isinstance(usage, Mapping): + try: + return { + "input_tokens": float(usage.get("input_tokens", 0.0)), + "output_tokens": float(usage.get("output_tokens", 0.0)), + } + except (TypeError, ValueError): + return None + return None + + async def get_model_response( + self, + state: State, + prompt: Messages, + client: Client | None = None, + model: str | None = None, + tool_defs: list[Tool] | None = None, + sampling_args: SamplingArgs | None = None, + ) -> Response: + """ + Get model response for a given prompt (chat or completion). + + Uses the client abstraction layer to handle provider-specific API calls. + The vf.Client adapter handles prompt conversion, sampling arg normalization, + overlong prompt detection, and response parsing. 
+ """ + + def resolve_optional_args( + client: Client | None, + model: str | None, + tool_defs: list[Tool] | None, + sampling_args: SamplingArgs | None, + ) -> tuple[Client, str, list[Tool] | None, SamplingArgs]: + """Resolve optional arguments, fallback to state or class defaults.""" + client = client if client is not None else state["client"] + assert client is not None + model = model or state["model"] + assert model is not None + if tool_defs is None: + tool_defs = state.get("tool_defs") + if tool_defs is not None and not all( + isinstance(tool, Tool) for tool in tool_defs + ): + raise TypeError( + "tool_defs must be a list of vf.Tool objects at runtime. " + "Normalize tool dicts during state initialization." + ) + if isinstance(tool_defs, list) and len(tool_defs) == 0: + tool_defs = None + sampling_args = cast( + SamplingArgs, sampling_args or state["sampling_args"] or {} + ) + return client, model, tool_defs, sampling_args + + client, model, tool_defs, sampling_args = resolve_optional_args( + client, model, tool_defs, sampling_args + ) + + self._get_usage_tracker(state, create_if_missing=True) + + bt_parent = state.get("_bt_turn_span") or state.get("_bt_span") + bt_span = _bt.model_request_span( + bt_parent, + model=model, + turn_index=len(state.get("trajectory", [])), + messages=prompt, + ) + t0 = time.monotonic() + error_msg = "" + response = None + try: + response = await client.get_response( + prompt=prompt, + model=model, + tools=tool_defs, + sampling_args=sampling_args, + state=state, + ) + except Exception as exc: + error_msg = repr(exc)[:500] + raise + except BaseException as exc: + error_msg = repr(exc)[:500] + raise + finally: + dur = time.monotonic() - t0 + input_tok, output_tok = 0.0, 0.0 + if not error_msg and response is not None: + input_tok, output_tok = ( + float(v) for v in extract_usage_tokens(response) + ) + _bt.model_request_completed( + bt_span, + duration_s=dur, + input_tokens=input_tok, + output_tokens=output_tok, + response=response 
if not error_msg else None, + error=error_msg, + ) + self.increment_state_usage_from_response(state, response) + + return response + + @final + async def init_state( + self, + input: RolloutInput, + client: Client | ClientConfig, + model: str, + sampling_args: SamplingArgs | None = None, + ) -> State: + """ + Create initial state from dataset row. + Environment-agnostic - just stores the data. + + Creates State with input fields in "input" RolloutInput for structured access, + while State's forwarding behavior keeps direct access ergonomic. + """ + state_input = cast(RolloutInput, deepcopy(input)) + if "info" in state_input and isinstance(state_input["info"], str): + state_input["info"] = json.loads(state_input["info"]) + state_task = flatten_task_input(state_input) + state_input = cast(RolloutInput, state_task) + state = State(input=state_input) + state["task"] = state_task + + # Convert prompt to Pydantic messages + raw_prompt = state_input.get("prompt") + if isinstance(raw_prompt, (str, list)): + state["prompt"] = normalize_messages(raw_prompt, field_name="input.prompt") + + state["client"] = resolve_client(client) + state["model"] = model + state["sampling_args"] = sampling_args + state["is_completed"] = False + state["is_truncated"] = False + + # Resolve tool definitions + resolved_tool_defs: list[Tool] | list[dict[str, Any]] | None = None + info = state.get("info") + if isinstance(info, dict) and "oai_tools" in info: + raise ValueError( + "info['oai_tools'] is no longer supported. Use info['tool_defs'] with " + "provider-agnostic tool definitions: " + "[{'name': ..., 'description': ..., 'parameters': {...}}]." 
+ ) + if isinstance(info, dict) and "tool_defs" in info: + resolved_tool_defs = info["tool_defs"] + elif self.tool_defs is not None: + resolved_tool_defs = self.tool_defs + else: + resolved_tool_defs = [] + state["tool_defs"] = self._normalize_tool_defs(resolved_tool_defs) or [] + + state["trajectory"] = [] + state["completion"] = None + self._get_usage_tracker(state, create_if_missing=True) + state["trajectory_id"] = uuid.uuid4().hex + state["reward"] = None + state["metrics"] = None + state["error"] = None + state["final_env_response"] = None + state["timing"] = RolloutTiming() + return state + + @abstractmethod + async def rollout( + self, + input: RolloutInput, + client: Client, + model: str, + sampling_args: SamplingArgs | None = None, + ) -> State: + """ + Run a rollout for a given input. + """ + pass + + async def cleanup( + self, + state: State, + task: object | None = None, + resources: object | None = None, + ): + """ + Finalize rollout state and clean up rollout-local resources. + """ + for handler in self._cleanup_handlers: + await maybe_call_with_named_args( + handler, + task=task, + state=state, + env=self, + resources=resources, + ) + + async def _teardown(self): + """ + Tear down environment resources. 
+ """ + await self.rubric.teardown() + for handler in self._teardown_handlers: + await handler() + + async def _render_stop(self, state: State, condition, **kwargs) -> bool: + if await maybe_call_with_named_args( + condition, + state=state, + env=self, + **kwargs, + ): + state["is_completed"] = True + state["is_truncated"] = state.get("is_truncated", False) or any( + step.get("is_truncated", False) for step in state.get("trajectory", []) + ) + state["stop_condition"] = condition.__name__ + if state.get("stop_condition") == "has_error": + err = state["error"] + err_chain = ErrorChain(err) + self.logger.error(f"Aborted rollout due to {repr(err_chain)}") + return True + return False + + @final + async def is_completed(self, state: State, **kwargs) -> bool: + """Check stop conditions and render stop fields when one fires.""" + for condition in self._stop_conditions: + if await self._render_stop(state, condition, **kwargs): + return True + return False + + async def _run_rollout_state( + self, + input: RolloutInput, + client: Client, + model: str, + sampling_args: SamplingArgs, + ) -> State: + t0 = time.monotonic() + bt_span = _bt.rollout_started( + env_id=self.env_id, + model=model, + example_id=input.get("example_id", ""), + trajectory_id="", + ) + # Pass the rollout span to rollout() via a coroutine-local context var. + # We cannot thread it through `input` (deep-copied by init_state) or + # store it on `self` (shared across concurrent rollouts). 
+ _bt._pending_rollout_span.set(bt_span) + state = await self.rollout( + input, + client, + model, + sampling_args, + ) + + bt_score = _bt.scoring_started( + bt_span, trajectory_id=state.get("trajectory_id", "") + ) + state["timing"].scoring.start = time.time() + if self.score_rollouts: + await self.rubric.score_rollout(state) + else: + await self.rubric.dummy_score_rollout(state) + state["timing"].scoring.end = time.time() + scoring_dur = state["timing"].scoring.end - state["timing"].scoring.start + _bt.scoring_completed( + bt_score, duration_s=scoring_dur, reward=state.get("reward") + ) + + await self.rubric.cleanup(state) + + usage = self.get_state_usage(state) or {} + _bt.rollout_completed( + bt_span, + reward=state.get("reward"), + num_turns=len(state.get("trajectory", [])), + duration_s=time.monotonic() - t0, + stop_condition=state.get("stop_condition", ""), + error=repr(state["error"])[:500] if state.get("error") else "", + input_tokens=float(usage.get("input_tokens", 0)), + output_tokens=float(usage.get("output_tokens", 0)), + ) + return state + + async def _run_group_states( + self, + group_inputs: list[RolloutInput], + client: Client, + model: str, + sampling_args: SamplingArgs, + ) -> list[State]: + t0 = time.monotonic() + example_id = group_inputs[0].get("example_id", "") if group_inputs else "" + bt_group = _bt.group_started( + env_id=self.env_id, + model=model, + example_id=example_id, + group_size=len(group_inputs), + ) + + bt_rollout_spans: list[object | None] = [None for _ in group_inputs] + rollout_start_times: list[float] = [0.0] * len(group_inputs) + + async def _traced_rollout(idx: int, ri: RolloutInput) -> State: + """Wrap a single rollout with its own Braintrust span.""" + r_t0 = time.monotonic() + bt_rollout = _bt.start_child_span( + bt_group, + name="rollout", + span_type="task", + input={"example_id": _bt._safe(ri.get("example_id", ""))}, + metadata={"env_id": self.env_id, "model": model}, + ) + bt_rollout_spans[idx] = bt_rollout + 
rollout_start_times[idx] = r_t0 + _bt._pending_rollout_span.set(bt_rollout) + return await self.rollout(ri, client, model, sampling_args) + + group_states = await asyncio.gather( + *[_traced_rollout(i, inp) for i, inp in enumerate(group_inputs)] + ) + + start_scoring = time.time() + for state in group_states: + state["timing"].scoring.start = start_scoring + if self.score_rollouts: + await self.rubric.score_group(group_states) + else: + await self.rubric.dummy_score_group(group_states) + end_scoring = time.time() + for state in group_states: + state["timing"].scoring.end = end_scoring + + for state in group_states: + await self.rubric.cleanup(state) + + now = time.monotonic() + for st, bt_rollout, r_t0 in zip( + group_states, bt_rollout_spans, rollout_start_times + ): + usage = self.get_state_usage(st) or {} + _bt.rollout_completed( + bt_rollout, + reward=st.get("reward"), + num_turns=len(st.get("trajectory", [])), + duration_s=now - r_t0, + stop_condition=st.get("stop_condition", ""), + error=repr(st["error"])[:500] if st.get("error") else "", + input_tokens=float(usage.get("input_tokens", 0)), + output_tokens=float(usage.get("output_tokens", 0)), + ) + + rewards = [s.get("reward") for s in group_states if s.get("reward") is not None] + _bt.group_completed( + bt_group, + duration_s=time.monotonic() - t0, + avg_reward=sum(rewards) / len(rewards) if rewards else None, + group_size=len(group_inputs), + ) + return group_states + + @final + async def run_rollout( + self, + input: RolloutInput, + client: Client | ClientConfig, + model: str, + sampling_args: SamplingArgs, + max_retries: int = 0, + state_columns: list[str] | None = None, + env_client: EnvClient | None = None, + ) -> RolloutOutput: + """Generate and, optionally, score a rollout.""" + + resolved_client_config: ClientConfig | None = None + if isinstance(client, ClientConfig): + resolved_client_config = resolve_client_config(client) + + env_client = env_client or self.env_client + if env_client is not None: # 
in server mode + if resolved_client_config is None: + raise ValueError( + f"client must be have type ClientConfig in server mode, got {type(client)}" + ) + return await env_client.run_rollout( + input, + resolved_client_config, + model, + sampling_args, + max_retries, + state_columns, + ) + + resolved_client = resolve_client(client) + + async def run_rollout_attempt() -> State: + return await self._run_rollout_state( + input, + resolved_client, + model, + sampling_args, + ) + + state = await maybe_retry(run_rollout_attempt, max_retries=max_retries)() + output = state_to_output(state, state_columns or []) + return output + + @final + async def run_group( + self, + group_inputs: list[RolloutInput], + client: Client | ClientConfig, + model: str, + sampling_args: SamplingArgs, + max_retries: int = 0, + state_columns: list[str] | None = None, + env_client: EnvClient | None = None, + **kwargs, + ) -> list[RolloutOutput]: + """Generate and, optionally, score one group.""" + + resolved_client_config: ClientConfig | None = None + if isinstance(client, ClientConfig): + resolved_client_config = resolve_client_config(client) + + env_client = env_client or self.env_client + if env_client is not None: # in server mode + if resolved_client_config is None: + raise ValueError( + f"client must be have type ClientConfig in server mode, got {type(client)}" + ) + return await env_client.run_group( + group_inputs, + resolved_client_config, + model, + sampling_args, + max_retries, + state_columns, + ) + + resolved_client = resolve_client(client) + + async def run_group_attempt() -> list[State]: + return await self._run_group_states( + group_inputs, + resolved_client, + model, + sampling_args, + ) + + group_states = await maybe_retry(run_group_attempt, max_retries=max_retries)() + outputs = [ + state_to_output(state, state_columns or []) for state in group_states + ] + return outputs + + async def generate( + self, + inputs: Dataset | List[RolloutInput], + client: Client | ClientConfig, + 
model: str, + sampling_args: SamplingArgs | None = None, + max_concurrent: int = -1, + results_path: Path | None = None, + state_columns: list[str] | None = None, + save_results: bool = False, + push_to_hf_hub: bool = False, + hf_hub_dataset_name: str | None = None, + independent_scoring: bool = False, + max_retries: int = 0, + on_start: StartCallback | None = None, + on_progress: ProgressCallback | list[ProgressCallback] | None = None, + on_log: LogCallback | None = None, + ) -> GenerateOutputs: + """ + Generate rollouts for a set of inputs. + + Args: + client: Can be a single AsyncOpenAI client or a ClientConfig. + on_progress: Progress callback(s). None uses the default tqdm progress bar. + A single callback replaces the default. A list of callbacks runs + alongside the default. + """ + from datasets import Dataset + from tqdm import tqdm + + pbar: tqdm | None = None + + def default_on_start( + raw_inputs: list[RolloutInput], + filtered_inputs: list[RolloutInput] | list[list[RolloutInput]], + ) -> None: + """Initializes the progress bar from the raw inputs.""" + nonlocal pbar + + total_rollouts = len(raw_inputs) + total_groups = len(set([i["example_id"] for i in raw_inputs])) + rollouts_per_example = ( + total_rollouts // total_groups if total_groups > 0 else 0 + ) + + if ( + isinstance(filtered_inputs, list) + and filtered_inputs + and isinstance(filtered_inputs[0], list) + ): + remaining_rollouts = sum(len(g) for g in filtered_inputs) + else: + remaining_rollouts = len(filtered_inputs) + saved_rollouts = total_rollouts - remaining_rollouts + + if filtered_inputs: + if isinstance(filtered_inputs[0], list): + pbar_total = total_groups + pbar_initial = saved_rollouts // rollouts_per_example + pbar_desc = f"Processing {total_groups} groups ({total_rollouts} total rollouts)" + else: + pbar_total = total_rollouts + pbar_initial = saved_rollouts + pbar_desc = f"Processing {total_rollouts} rollouts" + + pbar = tqdm( + total=pbar_total, + initial=pbar_initial, + 
desc=pbar_desc, + postfix=dict(reward="?"), + ) + + def default_on_progress( + all_outputs: list[RolloutOutput], + new_outputs: list[RolloutOutput], + new_metadata: GenerateMetadata, + ) -> None: + """Updates the progress bar from the new outputs.""" + nonlocal pbar + if pbar is not None: + pbar.update(1) + pbar.set_postfix(reward=new_metadata.get("avg_reward")) + + def default_on_log(message: str) -> None: + """Logs using the environment logger.""" + self.logger.info(message) + + on_start = on_start or cast(StartCallback, default_on_start) + extra_on_progress: list[ProgressCallback] = [] + if isinstance(on_progress, list): + extra_on_progress = cast(list[ProgressCallback], on_progress) + elif on_progress is not None: + extra_on_progress = [on_progress] + + def default_on_progress(*a, **kw): + None + + on_log = on_log or cast(LogCallback, default_on_log) + + if isinstance(inputs, Dataset): + raw_inputs = cast(list[RolloutInput], inputs.to_list()) + elif isinstance(inputs, list): + raw_inputs = inputs + + # set up semaphores + sem = await maybe_semaphore(max_concurrent) + + # set up sampling args + default_sampling_args = deepcopy(self.sampling_args) + if sampling_args is not None: + default_sampling_args.update(sampling_args) + sampling_args = default_sampling_args + + # initialize outputs builder + total_rollouts = len(raw_inputs) + num_examples = len(set([i["example_id"] for i in raw_inputs])) + rollouts_per_example = total_rollouts // num_examples if num_examples > 0 else 0 + builder = GenerateOutputsBuilder( + env_id=self.env_id, + env_args=self.env_args, + model=model, + client=client, + num_examples=num_examples, + rollouts_per_example=rollouts_per_example, + state_columns=state_columns, + sampling_args=sampling_args, + results_path=results_path, + pass_threshold=self.pass_threshold, + ) + + single_client: Client | None = None + endpoint_client_configs: list[ClientConfig] = [] + endpoint_client_idx = 0 + if isinstance(client, ClientConfig): + 
endpoint_client_configs = resolve_client_configs(client) + else: + # Raw async-client path + single_client = client + + local_endpoint_clients: list[Client] = [] + + def get_client_for_group() -> Client | ClientConfig: + """Get next client in round-robin order or return the single client.""" + nonlocal endpoint_client_idx + if self.env_client is not None and endpoint_client_configs: + config = endpoint_client_configs[ + endpoint_client_idx % len(endpoint_client_configs) + ] + endpoint_client_idx += 1 + return config + if local_endpoint_clients: + local_client = local_endpoint_clients[ + endpoint_client_idx % len(local_endpoint_clients) + ] + endpoint_client_idx += 1 + return local_client + assert single_client is not None + return single_client + + if _bt.enabled(): + run_tags = _bt.set_run_tags() + if run_tags: + logging.getLogger(__name__).info("Braintrust run tag: %s", run_tags[0]) + + try: + if self.env_client is None and endpoint_client_configs: + for endpoint_config in endpoint_client_configs: + local_endpoint_clients.append(resolve_client(endpoint_config)) + + # load existing results if available + if results_path is not None and is_valid_eval_results_path(results_path): + validate_resume_metadata( + results_path=results_path, + env_id=self.env_id, + model=model, + num_examples=num_examples, + rollouts_per_example=rollouts_per_example, + ) + on_log(f"Resuming evaluation from {results_path}") + outputs = load_outputs(results_path) + builder.add_outputs(outputs) + filtered_inputs = filter_inputs( + raw_inputs, outputs, rollouts_per_example + ) + if not filtered_inputs: + on_log( + "No remaining rollouts to evaluate, returning completed outputs" + ) + return builder.build(sort_by_example_id=True) + on_log( + f"Found {len(outputs)} completed rollout(s), {len(filtered_inputs)} remaining rollout(s)" + ) + else: + filtered_inputs = raw_inputs + + if save_results: + on_log(f"Saving results to {builder.results_path}") + + tasks: dict[asyncio.Task, int] = {} + try: + 
# create tasks based on mode + if independent_scoring: + on_start(raw_inputs, filtered_inputs) + for i, rollout_input in enumerate(filtered_inputs): + task = asyncio.create_task( + with_sem( + sem, + self.run_rollout( + rollout_input, + get_client_for_group(), + model, + sampling_args, + max_retries=max_retries, + state_columns=state_columns, + ), + ), + ) + tasks[task] = i + else: + group_inputs: dict[int, list[RolloutInput]] = defaultdict(list) + for rollout_input in filtered_inputs: + example_id = rollout_input["example_id"] + group_inputs[example_id].append(rollout_input) + filtered_group_inputs = list(group_inputs.values()) + on_start(raw_inputs, filtered_group_inputs) + + for i, group_input in enumerate(filtered_group_inputs): + # For grouped scoring, keep each group on one endpoint so + # rollouts in the same group can benefit from shared KV cache. + group_client = get_client_for_group() + task = asyncio.create_task( + with_sem( + sem, + self.run_group( + group_input, + group_client, + model, + sampling_args, + max_retries=max_retries, + state_columns=state_columns, + ), + ), + ) + tasks[task] = i + + for coro in asyncio.as_completed(tasks.keys()): + result = await coro + + # normalize: independent_scoring returns RolloutOutput, group returns list[RolloutOutput] + new_outputs = [result] if independent_scoring else result + builder.add_outputs(new_outputs) + metadata = builder.build_metadata() + + default_on_progress(builder.outputs, new_outputs, metadata) + for cb in extra_on_progress: + cb(builder.outputs, new_outputs, metadata) + + # incrementally save outputs (offloaded to thread to avoid blocking the event loop) + if save_results: + await asyncio.to_thread( + save_new_outputs, new_outputs, builder.results_path + ) + await asyncio.to_thread( + save_metadata, metadata, builder.results_path + ) + finally: + # cancel all outstanding tasks and await their completion + pending = [task for task in tasks.keys() if not task.done()] + if pending: + for task in 
pending: + task.cancel() + await asyncio.gather(*pending, return_exceptions=True) + + # build final results (sorted by example_id for deterministic ordering) + results = builder.build(sort_by_example_id=True) + + # save if requested + if save_results: + await asyncio.to_thread( + save_outputs, results["outputs"], builder.results_path + ) + await asyncio.to_thread( + save_metadata, results["metadata"], builder.results_path + ) + if push_to_hf_hub: + push_results_to_hf_hub(results, hf_hub_dataset_name) + if on_log is not None: + on_log( + f"Saved final results to {results['metadata']['path_to_save']}" + ) + + return results + finally: + _bt.clear_run_tags() + _bt.flush() + if pbar is not None: + pbar.close() + if local_endpoint_clients: + await asyncio.gather( + *(client.close() for client in local_endpoint_clients), + return_exceptions=True, + ) + + def generate_sync( + self, + inputs: Dataset | List[RolloutInput], + client: Client | ClientConfig, + **kwargs, + ) -> GenerateOutputs: + coro = self.generate( + inputs, + client=client, + **kwargs, + ) + # check if we're in existing event loop (e.g. 
Jupyter) + try: + loop = asyncio.get_running_loop() + import nest_asyncio + + nest_asyncio.apply() + return loop.run_until_complete(coro) + except RuntimeError: + pass + + # script case: create new loop and executor + executor = ThreadPoolExecutor(max_workers=self.max_workers) + loop = asyncio.new_event_loop() + try: + loop.set_default_executor(executor) + asyncio.set_event_loop(loop) + return loop.run_until_complete(coro) + finally: + loop.close() + asyncio.set_event_loop(None) + # shutdown the executor to prevent thread leaks + executor.shutdown(wait=False) + + # evaluation + def _get_eval_inputs( + self, num_examples: int = -1, rollouts_per_example: int = 1 + ) -> List[RolloutInput]: + # get_eval_dataset handles fallback to train dataset if no eval source exists + inputs = self.get_eval_dataset(n=num_examples) + assert inputs is not None, "No dataset found" + if rollouts_per_example > 1: + inputs = inputs.repeat(rollouts_per_example) + return inputs.to_list() + + async def evaluate( + self, + client: Client | ClientConfig, + model: str, + sampling_args: SamplingArgs | None = None, + num_examples: int = -1, + rollouts_per_example: int = 1, + max_concurrent: int = -1, + results_path: Path | None = None, + state_columns: list[str] | None = None, + save_results: bool = False, + push_to_hf_hub: bool = False, + hf_hub_dataset_name: str | None = None, + independent_scoring: bool = False, + max_retries: int = 0, + on_start: StartCallback | None = None, + on_progress: ProgressCallback | list[ProgressCallback] | None = None, + on_log: LogCallback | None = None, + **kwargs, + ) -> GenerateOutputs: + """ + Evaluate model on the Environment evaluation dataset. + + Args: + on_progress: Progress callback(s). None uses the default tqdm progress bar. + A single callback replaces the default. A list of callbacks runs + alongside the default. 
+ """ + inputs = self._get_eval_inputs(num_examples, rollouts_per_example) + return await self.generate( + inputs, + client=client, + model=model, + sampling_args=sampling_args, + max_concurrent=max_concurrent, + results_path=results_path, + state_columns=state_columns, + save_results=save_results, + push_to_hf_hub=push_to_hf_hub, + hf_hub_dataset_name=hf_hub_dataset_name, + independent_scoring=independent_scoring, + max_retries=max_retries, + on_start=on_start, + on_progress=on_progress, + on_log=on_log, + **kwargs, + ) + + def evaluate_sync( + self, + client: Client | ClientConfig, + model: str, + sampling_args: SamplingArgs | None = None, + num_examples: int = -1, + rollouts_per_example: int = 1, + max_concurrent: int = -1, + results_path: Path | None = None, + state_columns: list[str] | None = None, + save_results: bool = False, + push_to_hf_hub: bool = False, + hf_hub_dataset_name: str | None = None, + independent_scoring: bool = False, + max_retries: int = 0, + ) -> GenerateOutputs: + """ + Evaluate model on the Environment evaluation dataset synchronously. + """ + inputs = self._get_eval_inputs(num_examples, rollouts_per_example) + return self.generate_sync( + inputs, + client=client, + model=model, + sampling_args=sampling_args, + max_concurrent=max_concurrent, + results_path=results_path, + state_columns=state_columns, + save_results=save_results, + push_to_hf_hub=push_to_hf_hub, + hf_hub_dataset_name=hf_hub_dataset_name, + independent_scoring=independent_scoring, + max_retries=max_retries, + ) + + # setters for use by trainers + def set_kwargs(self, **kwargs) -> None: + """ + Set environment attributes, using setter methods when available. + + For each kwarg, checks if a `set_{key}` method exists and calls it, + otherwise falls back to setattr. This ensures proper propagation for + attributes like `score_rollouts` in EnvGroup. 
+ """ + for key, value in kwargs.items(): + setter_name = f"set_{key}" + setter = getattr(self, setter_name, None) + if setter is not None and callable(setter): + setter(value) + else: + setattr(self, key, value) + + def add_rubric(self, rubric: Rubric) -> None: + if self.rubric is None: + self.rubric = rubric + elif isinstance(self.rubric, vf.RubricGroup): + self.rubric.rubrics.append(rubric) + else: + self.rubric = vf.RubricGroup(rubrics=[self.rubric, rubric]) + + def set_concurrency(self, concurrency: int) -> None: + """Set concurrency and scale all registered thread-pool executors. + + Each executor applies its own scaling function to map concurrency + to max_workers (default 1:1). + """ + self.concurrency = concurrency + scale_executors(concurrency=concurrency) + + def set_max_seq_len(self, max_seq_len: int | None) -> None: + """Set the maximum sequence length for this environment.""" + self.max_seq_len = max_seq_len + + def set_score_rollouts(self, score_rollouts: bool) -> None: + """Set the score rollouts flag for this environment.""" + self.score_rollouts = score_rollouts + + async def start_server( + self, + address: str | None = None, + extra_env_kwargs: dict[str, Any] | None = None, + num_workers: int = 1, + # logging configs + log_level: str | None = None, + log_dir: str | None = None, + console_logging: bool = True, + # health check configs + health_check_interval: float = 1.0, # 1s + startup_timeout: float = 600.0, # 10m + recovery_timeout: float = 600.0, # 10m + ) -> None: + """Start a ZMQ server process for this environment. + + Spawns a :class:`ZMQEnvServer` (router + *num_workers* worker + processes, default 1). + + .. warning:: + This method is subject to change. External users should avoid + depending on it directly. + """ + from verifiers.serve import ZMQEnvServer + + address = address or f"tcp://127.0.0.1:{get_free_port()}" + extra_env_kwargs = extra_env_kwargs or {} + + # Death pipe: parent keeps writer, children monitor reader. 
+ # When the parent dies (even SIGKILL), the OS closes the writer end + # and children get EOF → clean shutdown. + death_pipe_reader, self.death_pipe_writer = mp.Pipe(duplex=False) + + # Use spawn to avoid inheriting file descriptors (e.g. sockets) from + # the parent process, which has caused hangs when multiple env server + # subprocesses share the same fds. + ctx = mp.get_context("spawn") + self.env_server_process = ctx.Process( + target=ZMQEnvServer.run_server, + args=( + self.env_id, + self.env_args, + extra_env_kwargs, + log_level, + log_dir, + console_logging, + ), + kwargs=dict( + address=address, + num_workers=num_workers, + death_pipe=death_pipe_reader, + ), + daemon=False, + ) + self.env_server_process.start() + # Close the reader in the parent — only children should hold it. + death_pipe_reader.close() + self.env_client = ZMQEnvClient( + address=address, + health_check_interval=health_check_interval, + startup_timeout=startup_timeout, + recovery_timeout=recovery_timeout, + name=self.env_id, + ) + await self.env_client.wait_for_server_startup() + + async def stop_server(self) -> None: + """Stop the ZMQ server process for this environment. + + .. warning:: + This method is subject to change. External users should avoid + depending on it directly. 
+ """ + if self.env_client is not None: + await self.env_client.close() + self.env_client = None + if self.death_pipe_writer is not None: + self.death_pipe_writer.close() + self.death_pipe_writer = None + if self.env_server_process is not None: + from verifiers.utils.process_utils import terminate_process + + terminate_process(self.env_server_process) + self.env_server_process = None + + make_dataset = staticmethod(make_dataset) + + +_EnvT = TypeVar("_EnvT", bound=Environment) +StopCondition = Callable[[State], Awaitable[bool]] +RolloutCleanup = Callable[[State], Awaitable[None]] +EnvironmentTeardown = Callable[[], Awaitable[None]] diff --git a/verifiers/envs/experimental/braintrust_tracing/multiturn_env.py b/verifiers/envs/experimental/braintrust_tracing/multiturn_env.py new file mode 100644 index 000000000..fe1b708a8 --- /dev/null +++ b/verifiers/envs/experimental/braintrust_tracing/multiturn_env.py @@ -0,0 +1,273 @@ +import asyncio +import logging +import time +from abc import abstractmethod +from typing import final + +import verifiers as vf +import verifiers.envs.experimental.braintrust_tracing.braintrust_tracing as _bt +from verifiers.envs.experimental.braintrust_tracing.environment import Environment +from verifiers.clients import Client +from verifiers.types import ( + Messages, + Response, + RolloutInput, + SamplingArgs, + State, + TimeSpan, + TrajectoryStep, +) +from verifiers.utils.message_utils import ( + concat_messages, + maybe_normalize_messages, +) +from verifiers.utils.response_utils import ( + parse_response_message, + parse_response_tokens, +) + +logger = logging.getLogger(__name__) + + +class MultiTurnMonitorRubric(vf.Rubric): + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.add_metric(self.num_turns) + + async def num_turns(self, state: State) -> int: + return len(state["trajectory"]) + + +class MultiTurnEnv(Environment): + def __init__( + self, + max_turns: int = -1, + timeout_seconds: float | None = None, + **kwargs, + ): + 
super().__init__(**kwargs) + self.max_turns = max_turns + self.timeout_seconds = timeout_seconds + self.max_total_completion_tokens: int = -1 + + self.add_rubric(MultiTurnMonitorRubric()) + + def set_max_total_completion_tokens(self, max_total_completion_tokens: int) -> None: + """Set the maximum total completion tokens for this environment.""" + self.max_total_completion_tokens = max_total_completion_tokens + + @abstractmethod + async def env_response( + self, messages: Messages, state: State, **kwargs + ) -> Messages: + """ + Generate a response from the environment. + """ + pass + + @vf.stop(priority=100) # always check for errors first + async def has_error(self, state: State, **kwargs) -> bool: + return state.get("error") is not None + + @vf.stop + async def prompt_too_long(self, state: State) -> bool: + return state.get("prompt_too_long", False) + + @vf.stop + async def max_turns_reached(self, state: State) -> bool: + return len(state["trajectory"]) >= self.max_turns and self.max_turns > 0 + + def mark_timed_out(self, state: State) -> None: + state["timed_out"] = True + state["is_completed"] = True + state["stop_condition"] = "timeout_reached" + + @vf.stop + async def max_total_completion_tokens_reached(self, state: State) -> bool: + if self.max_total_completion_tokens <= 0: + return False + usage = self.get_state_usage(state) + if usage is None: + return False + return usage["output_tokens"] >= self.max_total_completion_tokens + + @vf.stop + async def has_final_env_response(self, state: State) -> bool: + """Check if env_response signaled termination via final_env_response.""" + return state.get("final_env_response") is not None + + async def setup_state(self, state: State) -> State | None: + """Override to add environment-specific state fields. 
Mutate state in place.""" + return state + + async def get_prompt_messages(self, state: State) -> Messages: + """Override for rollouts with non-linear message sequences.""" + if len(state["trajectory"]) == 0: + return state["prompt"] + prev_turn_prompt = state["trajectory"][-1]["prompt"] + prev_turn_completion = state["trajectory"][-1]["completion"] + messages = concat_messages([prev_turn_prompt, prev_turn_completion]) + env_response = await self.env_response(messages, state) + env_response = maybe_normalize_messages(env_response, field_name="env_response") + return concat_messages([messages, env_response]) + + async def render_completion(self, state: State): + """Override for rollouts with non-linear message sequences.""" + if len(state["trajectory"]) == 0: + state["completion"] = [] + return + last_prompt = state["trajectory"][-1]["prompt"] + last_completion = state["trajectory"][-1]["completion"] + full_conversation = concat_messages([last_prompt, last_completion]) + if state.get("final_env_response"): + final_resp = state["final_env_response"] + final_resp = maybe_normalize_messages( + final_resp, field_name="final_env_response" + ) + full_conversation = concat_messages([full_conversation, final_resp]) + prompt_messages = state["prompt"] + state["completion"] = full_conversation[len(prompt_messages) :] + + @vf.cleanup(priority=100) + async def render_state(self, state: State) -> None: + """Render core rollout fields before user cleanup handlers run.""" + state["timing"].generation.end = time.time() + await self.render_completion(state) + + async def add_trajectory_step(self, state: State, trajectory_step: TrajectoryStep): + """Override to set intermediate rewards, advantages, or extra metadata.""" + state["trajectory"].append(trajectory_step) + + async def _finalize_rollout(self, state: State) -> None: + """Finalize rollout state and run cleanup handlers exactly once.""" + await self.cleanup(state) + + async def add_model_response( + self, + state: State, + 
prompt_messages: Messages, + response: Response, + ): + completion_messages = await parse_response_message(response) + tokens = await parse_response_tokens(response, self.max_seq_len) + response_is_truncated = response.message.is_truncated or False + is_truncated = response_is_truncated or ( + tokens is not None and bool(tokens.get("is_truncated")) + ) + trajectory_step = TrajectoryStep( + prompt=prompt_messages, + completion=completion_messages, + response=response, + tokens=tokens, + reward=None, + advantage=None, + is_truncated=is_truncated, + trajectory_id=state["trajectory_id"], + extras={}, + ) + await self.add_trajectory_step(state, trajectory_step) + + @final + async def rollout( + self, + input: RolloutInput, + client: Client, + model: str, + sampling_args: SamplingArgs | None = None, + ) -> State: + state = await self.init_state(input, client, model, sampling_args) + _env_id = getattr(self, "env_id", "") + + # Pick up the rollout span from the coroutine-local context var + # (set by _run_rollout_state) and attach it to state so all child + # spans (setup, turns, model requests) nest under the rollout root. 
+ bt_rollout = _bt._pending_rollout_span.get(None) + if bt_rollout is not None: + state["_bt_span"] = bt_rollout + _bt._pending_rollout_span.set(None) + + async def rollout_loop() -> None: + nonlocal state, bt_rollout + state["timing"].generation.start = time.time() + state["timing"].setup.start = time.time() + bt_setup = _bt.setup_started( + bt_rollout, + env_id=_env_id, + trajectory_id=state.get("trajectory_id", ""), + ) + try: + setup_state = await self.setup_state(state) + if setup_state is not None: + state = setup_state + except vf.Error as e: + state["error"] = e + finally: + state["timing"].setup.end = time.time() + _bt.setup_completed( + bt_setup, + duration_s=state["timing"].setup.end - state["timing"].setup.start, + error=repr(state["error"])[:500] if state.get("error") else "", + ) + while not await self.is_completed(state): + turn_t0 = time.monotonic() + turn_idx = len(state["trajectory"]) + bt_turn = _bt.turn_started( + bt_rollout, + turn_index=turn_idx, + trajectory_id=state.get("trajectory_id", ""), + ) + # Store on state so get_model_response can nest under this turn + state["_bt_turn_span"] = bt_turn + turn_err = "" + env_dur: float | None = None + model_dur: float | None = None + try: + timing = state["timing"] + start_time = time.time() + prompt_messages = await self.get_prompt_messages(state) + end_time = time.time() + # First iteration has no preceding env_response; skip recording. 
+ if state["trajectory"]: + timing.env.spans.append( + TimeSpan(start=start_time, end=end_time) + ) + env_dur = end_time - start_time + + prompt_messages = maybe_normalize_messages( + prompt_messages, field_name="prompt_messages" + ) + if state.get("final_env_response") is not None: + continue + + start_time = time.time() + response = await self.get_model_response(state, prompt_messages) + end_time = time.time() + model_dur = end_time - start_time + timing.model.spans.append(TimeSpan(start=start_time, end=end_time)) + await self.add_model_response(state, prompt_messages, response) + except vf.Error as e: + turn_err = repr(e)[:500] + if isinstance(e, vf.OverlongPromptError): + state["prompt_too_long"] = True + state["is_truncated"] = True + else: + state["error"] = e + finally: + _bt.turn_completed( + bt_turn, + duration_s=time.monotonic() - turn_t0, + model_duration_s=model_dur, + env_duration_s=env_dur, + is_truncated=state.get("is_truncated", False), + error=turn_err, + ) + state.pop("_bt_turn_span", None) + + try: + await asyncio.wait_for(rollout_loop(), timeout=self.timeout_seconds) + except asyncio.TimeoutError: + self.mark_timed_out(state) + _bt.timeout_triggered(bt_rollout, timeout_seconds=self.timeout_seconds) + finally: + await self._finalize_rollout(state) + return state diff --git a/verifiers/envs/experimental/braintrust_tracing/stateful_tool_env.py b/verifiers/envs/experimental/braintrust_tracing/stateful_tool_env.py new file mode 100644 index 000000000..7c3f514c6 --- /dev/null +++ b/verifiers/envs/experimental/braintrust_tracing/stateful_tool_env.py @@ -0,0 +1,180 @@ +import inspect +import json +from abc import abstractmethod +from typing import Callable, cast + +import verifiers as vf +from verifiers.envs.experimental.braintrust_tracing.tool_env import ToolEnv +from verifiers.types import Tool, ToolMessage +from verifiers.utils.tool_utils import convert_func_to_tool_def + + +def filter_signature(func, args_to_skip): + """Return a wrapper with 
def filter_signature(func, args_to_skip):
    """Return a call-through wrapper whose *visible* signature hides arguments.

    The wrapper is used only for tool-schema generation: its ``__signature__``
    and ``__annotations__`` omit every name in ``args_to_skip`` (and ``self``),
    while calls are forwarded to ``func`` unchanged. The original function is
    never mutated; with an empty ``args_to_skip`` the function itself is
    returned untouched.
    """
    if not args_to_skip:
        return func

    hidden = set(args_to_skip)
    original_sig = inspect.signature(func)
    visible_params = [
        param
        for name, param in original_sig.parameters.items()
        if name not in hidden and name != "self"
    ]

    def wrapper(*args, **kwargs):
        # Pure pass-through: schema filtering must not affect call behavior.
        return func(*args, **kwargs)

    wrapper.__name__ = getattr(func, "__name__", "unknown")
    wrapper.__doc__ = getattr(func, "__doc__", None)
    wrapper.__signature__ = original_sig.replace(parameters=visible_params)
    wrapper.__annotations__ = {
        name: annotation
        for name, annotation in getattr(func, "__annotations__", {}).items()
        if name not in hidden
    }
    return wrapper
+ """ + self.tools.append(tool) + tool_def = convert_func_to_tool_def(filter_signature(tool, args_to_skip)) + params = tool_def.parameters + for arg in args_to_skip: + if ( + "properties" in params + and isinstance(params["properties"], dict) + and arg in params["properties"] + ): + arg_properties = cast(dict[str, dict], params["properties"]).pop(arg) + if "$ref" in arg_properties: + refs = arg_properties["$ref"] + ref_type = refs.split("/")[-1] + if "$defs" in params and ref_type in cast(dict, params["$defs"]): + params["$defs"].pop(ref_type) # type: ignore + if ( + "required" in params + and isinstance(params["required"], list) + and arg in params["required"] + ): + cast(list[str], params["required"]).remove(arg) + if "$defs" in params and not params["$defs"]: + params.pop("$defs") + if self.tool_defs is None: + self.tool_defs = [] + self.tool_defs.append(tool_def) + tool_name = getattr(tool, "__name__", tool.__class__.__name__) + self.tool_map[tool_name] = tool + self.skipped_args[tool_name] = args_to_skip + self.tool_monitor_rubric.add_tool_metric(tool_name) + + def remove_tool(self, tool: Callable): + self.tools.remove(tool) + tool_name = getattr(tool, "__name__", tool.__class__.__name__) + self.tool_defs = [ + tool_def for tool_def in self.tool_defs if tool_def.name != tool_name + ] + self.tool_map.pop(tool_name) + self.skipped_args.pop(tool_name) + self.tool_monitor_rubric.remove_tool_metric(tool_name) + + @abstractmethod + def update_tool_args( + self, + tool_name: str, + tool_args: dict, + messages: vf.Messages, + state: vf.State, + **kwargs, + ) -> dict: + """Update tool arguments and/or state (in-place) based on messages and state.""" + pass + + async def env_response( + self, messages: vf.Messages, state: vf.State, **kwargs + ) -> vf.Messages: + assert isinstance(messages, list) + last_msg = cast(vf.AssistantMessage, messages[-1]) + assert last_msg.tool_calls is not None + tool_messages = [] + for tool_call in last_msg.tool_calls: + tool_call_id = 
tool_call.id + try: + tool_name: str = tool_call.name + parsed_args = json.loads(tool_call.arguments) + if not isinstance(parsed_args, dict): + raise ValueError( + f"Expected tool arguments to be a dict, got {type(parsed_args).__name__}: {parsed_args}" + ) + tool_args: dict = parsed_args + except Exception as e: + if self._should_stop_for_error(e): + raise vf.ToolParseError from e + tool_messages.append( + ToolMessage( + role="tool", + content=self.error_formatter(e), + tool_call_id=tool_call_id, + ) + ) + continue + + tool_args = self.update_tool_args( + tool_name, tool_args, messages, state, **kwargs + ) + try: + tool_message = await self.call_tool( + tool_name, tool_args, tool_call_id, state=state + ) + tool_messages.append(tool_message) + except Exception as e: + if self._should_stop_for_error(e): + raise vf.ToolCallError from e + tool_messages.append( + ToolMessage( + role="tool", + content=self.error_formatter(e), + tool_call_id=tool_call_id, + ) + ) + + return tool_messages diff --git a/verifiers/envs/experimental/braintrust_tracing/tool_env.py b/verifiers/envs/experimental/braintrust_tracing/tool_env.py new file mode 100644 index 000000000..2fc572ae2 --- /dev/null +++ b/verifiers/envs/experimental/braintrust_tracing/tool_env.py @@ -0,0 +1,212 @@ +import json +import time +from typing import Callable, cast + +import verifiers as vf +import verifiers.envs.experimental.braintrust_tracing.braintrust_tracing as _bt +from verifiers.envs.experimental.braintrust_tracing.multiturn_env import MultiTurnEnv +from verifiers.types import AssistantMessage, Messages, ToolCall, ToolMessage +from verifiers.utils.async_utils import maybe_await +from verifiers.utils.tool_utils import ( + convert_func_to_tool_def, + is_valid_tool_content_parts, +) + + +class ToolMonitorRubric(vf.Rubric): + def __init__(self, tool_names: list[str] | None = None, **kwargs): + super().__init__(**kwargs) + + self.tool_names = list(tool_names) if tool_names else [] + + # add tool metrics + 
class ToolMonitorRubric(vf.Rubric):
    """Rubric that records tool-usage metrics: a total count plus one
    per-tool call counter for every registered tool name."""

    def __init__(self, tool_names: list[str] | None = None, **kwargs):
        super().__init__(**kwargs)

        self.tool_names = list(tool_names) if tool_names else []

        # add tool metrics
        self.add_metric(self.total_tool_calls)
        for tool_name in self.tool_names:
            self.add_metric(self.get_tool_call_count_func(tool_name))

    def add_tool_metric(self, tool_name: str):
        """Register a per-tool counter; no-op if the tool is already tracked."""
        if tool_name not in self.tool_names:
            self.tool_names.append(tool_name)
            self.add_metric(self.get_tool_call_count_func(tool_name))

    def remove_tool_metric(self, tool_name: str):
        """Drop a per-tool counter (and its weight) by its generated name."""
        if tool_name in self.tool_names:
            self.tool_names.remove(tool_name)
            metric_name = f"{tool_name}_calls"
            for i, func in enumerate(self.funcs):
                if func.__name__ == metric_name:
                    self.funcs.pop(i)
                    self.weights.pop(i)
                    break

    async def total_tool_calls(self, completion: Messages) -> float:
        """Count the total number of tool calls."""
        total = 0
        assert isinstance(completion, list)
        for msg in completion:
            if msg.role != "assistant" or not hasattr(msg, "tool_calls"):
                continue
            tool_calls = msg.tool_calls
            if isinstance(tool_calls, list):
                total += len(tool_calls)
        return float(total)

    def get_tool_call_count_func(self, tool_name: str) -> Callable:
        """Create a metric that counts calls to a specific tool."""

        async def tool_call_count_func(completion: Messages) -> int:
            count = 0
            assert isinstance(completion, list)
            for msg in completion:
                if not isinstance(msg, AssistantMessage):
                    continue
                tool_calls = msg.tool_calls
                if not isinstance(tool_calls, list):
                    continue
                for tool_call in tool_calls:
                    if isinstance(tool_call, ToolCall) and tool_call.name == tool_name:
                        count += 1

            return count

        tool_call_count_func.__name__ = f"{tool_name}_calls"
        # Fix: the original left a literal "{tool_name}" placeholder in the
        # docstring; interpolate it so introspection shows the real tool name.
        tool_call_count_func.__doc__ = f"Count calls to {tool_name} tool."
        return tool_call_count_func
class ToolEnv(MultiTurnEnv):
    """Multi-turn environment that dispatches assistant tool calls to Python
    callables, with Braintrust span instrumentation around each call."""

    def __init__(
        self,
        tools: list[Callable] | None = None,
        max_turns: int = 10,
        error_formatter: Callable[[Exception], str] = lambda e: f"{e}",
        stop_errors: list[type[Exception]] | None = None,
        **kwargs,
    ):
        self.tools = tools or []
        self.max_turns = max_turns
        self.error_formatter = error_formatter
        self.stop_errors: list[type[Exception]] = stop_errors or []
        self.tool_defs = [convert_func_to_tool_def(tool) for tool in self.tools]
        self.tool_map = {
            getattr(tool, "__name__", tool.__class__.__name__): tool
            for tool in self.tools
        }
        super().__init__(tool_defs=self.tool_defs, max_turns=max_turns, **kwargs)

        self.tool_monitor_rubric = ToolMonitorRubric(
            tool_names=list(self.tool_map.keys())
        )
        self.add_rubric(self.tool_monitor_rubric)

    def _should_stop_for_error(self, err: Exception) -> bool:
        """Check if error is in stop_errors."""
        return any(isinstance(err, err_type) for err_type in self.stop_errors)

    def add_tool(self, tool: Callable):
        """Register a tool callable, its schema, and its usage metric."""
        self.tools.append(tool)
        if self.tool_defs is None:
            self.tool_defs = []
        self.tool_defs.append(convert_func_to_tool_def(tool))
        tool_name = getattr(tool, "__name__", tool.__class__.__name__)
        self.tool_map[tool_name] = tool
        self.tool_monitor_rubric.add_tool_metric(tool_name)

    def remove_tool(self, tool: Callable):
        """Remove a tool and its schema/metric.

        Fix: filter tool_defs by name instead of `list.remove` on a freshly
        rebuilt def object, which relied on Tool equality and is how the
        StatefulToolEnv subclass already does it — keeps both consistent.
        """
        self.tools.remove(tool)
        if self.tool_defs is None:
            self.tool_defs = []
        tool_name = getattr(tool, "__name__", tool.__class__.__name__)
        self.tool_defs = [
            tool_def for tool_def in self.tool_defs if tool_def.name != tool_name
        ]
        self.tool_map.pop(tool_name)
        self.tool_monitor_rubric.remove_tool_metric(tool_name)

    @vf.stop
    async def no_tools_called(self, state: vf.State) -> bool:
        """Stop condition: last assistant message made no tool calls."""
        if len(state["trajectory"]) == 0:
            return False
        last_message = state["trajectory"][-1]["completion"][-1]
        is_assistant_message = last_message.role == "assistant"
        no_tool_calls = (
            not hasattr(last_message, "tool_calls") or not last_message.tool_calls
        )
        return is_assistant_message and no_tool_calls

    async def call_tool(
        self, tool_name: str, tool_args: dict, tool_call_id: str, **kwargs
    ) -> ToolMessage:
        """Call a tool based on JSON command.

        Wraps the call in a Braintrust span (nested under the current turn or
        rollout span when present on ``state``); exceptions propagate after the
        span is closed with the error recorded.
        """
        state = kwargs.get("state")
        bt_parent = (
            (state.get("_bt_turn_span") or state.get("_bt_span")) if state else None
        )
        bt_span = _bt.tool_call_started(
            bt_parent,
            tool_name=tool_name,
            tool_call_id=tool_call_id,
            tool_args=tool_args,
        )
        t0 = time.monotonic()
        err_msg = ""
        result_content = None
        try:
            tool_func = self.tool_map[tool_name]
            result = await maybe_await(tool_func, **tool_args)
            # Non-structured results are stringified for the tool message.
            content = result if is_valid_tool_content_parts(result) else str(result)
            result_content = content
            return ToolMessage(
                role="tool",
                content=content,
                tool_call_id=tool_call_id,
            )
        except Exception as exc:
            err_msg = repr(exc)[:500]
            raise
        finally:
            _bt.tool_call_completed(
                bt_span,
                duration_s=time.monotonic() - t0,
                result=result_content if not err_msg else None,
                error=err_msg,
            )

    async def env_response(
        self, messages: vf.Messages, state: vf.State, **kwargs
    ) -> vf.Messages:
        """Execute every tool call in the last assistant message.

        Parse/call failures become error tool messages via ``error_formatter``
        unless matched by ``stop_errors``, which raise ``ToolParseError`` /
        ``ToolCallError`` respectively.
        """
        last_msg = cast(vf.AssistantMessage, messages[-1])
        assert last_msg.tool_calls is not None
        tool_messages = []
        for tool_call in last_msg.tool_calls:
            tool_call_id: str = tool_call.id
            try:
                tool_name: str = tool_call.name
                tool_args: dict = json.loads(tool_call.arguments)
            except Exception as e:
                if self._should_stop_for_error(e):
                    raise vf.ToolParseError from e
                tool_messages.append(
                    ToolMessage(
                        role="tool",
                        content=self.error_formatter(e),
                        tool_call_id=tool_call_id,
                    )
                )
                continue  # skip tool call below

            try:
                tool_message = await self.call_tool(
                    tool_name, tool_args, tool_call_id, state=state
                )
                tool_messages.append(tool_message)
            except Exception as e:
                if self._should_stop_for_error(e):
                    raise vf.ToolCallError from e
                tool_messages.append(
                    ToolMessage(
                        role="tool",
                        content=self.error_formatter(e),
                        tool_call_id=tool_call_id,
                    )
                )

        return tool_messages