From 75d30e608bf1dcb2931a34a2328786db93de4554 Mon Sep 17 00:00:00 2001 From: DavdaJames Date: Sun, 21 Sep 2025 23:07:32 +0530 Subject: [PATCH 1/2] added feature for using alias approximate alongside count_tokens_approximately --- libs/core/langchain_core/messages/utils.py | 85 ++++++++++++++++--- .../tests/unit_tests/messages/test_utils.py | 75 ++++++++++++++++ libs/core/uv.lock | 6 +- uv.lock | 29 ++++--- 4 files changed, 170 insertions(+), 25 deletions(-) diff --git a/libs/core/langchain_core/messages/utils.py b/libs/core/langchain_core/messages/utils.py index 6f681d7a65ea1..181df5f0bf69a 100644 --- a/libs/core/langchain_core/messages/utils.py +++ b/libs/core/langchain_core/messages/utils.py @@ -688,6 +688,7 @@ def trim_messages( *, max_tokens: int, token_counter: Union[ + Literal["approximate"], Callable[[list[BaseMessage]], int], Callable[[BaseMessage], int], BaseLanguageModel, @@ -738,11 +739,16 @@ def trim_messages( BaseMessage. If a BaseLanguageModel is passed in then BaseLanguageModel.get_num_tokens_from_messages() will be used. Set to `len` to count the number of **messages** in the chat history. + You can also use string shortcuts for convenience: + + - ``"approximate"``: Uses `count_tokens_approximately` for fast, approximate + token counts. .. note:: - Use `count_tokens_approximately` to get fast, approximate token counts. - This is recommended for using `trim_messages` on the hot path, where - exact token counting is not necessary. + Use `count_tokens_approximately` (or the shortcut ``"approximate"``) to get + fast, approximate token counts. This is recommended for using + `trim_messages` on the hot path, where exact token counting is not + necessary. strategy: Strategy for trimming. @@ -849,6 +855,35 @@ def trim_messages( HumanMessage(content="what do you call a speechless parrot"), ] + Trim chat history using approximate token counting with the "approximate" shortcut: + + .. code-block:: python + + trim_messages( + messages, + max_tokens=45, + strategy="last", + # Using the "approximate" shortcut for fast approximate token counting + token_counter="approximate", + start_on="human", + include_system=True, + ) + + This is equivalent to using `count_tokens_approximately` directly: + + .. code-block:: python + + from langchain_core.messages.utils import count_tokens_approximately + + trim_messages( + messages, + max_tokens=45, + strategy="last", + token_counter=count_tokens_approximately, + start_on="human", + include_system=True, + ) + Trim chat history based on the message count, keeping the SystemMessage if present, and ensuring that the chat history starts with a HumanMessage ( or a SystemMessage followed by a HumanMessage). @@ -977,24 +1012,43 @@ def dummy_token_counter(messages: list[BaseMessage]) -> int: raise ValueError(msg) messages = convert_to_messages(messages) - if hasattr(token_counter, "get_num_tokens_from_messages"): - list_token_counter = token_counter.get_num_tokens_from_messages - elif callable(token_counter): + + # Handle string shortcuts for token counter + if isinstance(token_counter, str): + if token_counter in _TOKEN_COUNTER_SHORTCUTS: + actual_token_counter = _TOKEN_COUNTER_SHORTCUTS[token_counter] + else: + available_shortcuts = ", ".join( + f"'{key}'" for key in _TOKEN_COUNTER_SHORTCUTS + ) + msg = ( + f"Invalid token_counter shortcut '{token_counter}'. " + f"Available shortcuts: {available_shortcuts}." + ) + raise ValueError(msg) + else: + actual_token_counter = token_counter + + if hasattr(actual_token_counter, "get_num_tokens_from_messages"): + list_token_counter = actual_token_counter.get_num_tokens_from_messages # type: ignore[assignment] + elif callable(actual_token_counter): if ( - next(iter(inspect.signature(token_counter).parameters.values())).annotation + next( + iter(inspect.signature(actual_token_counter).parameters.values()) + ).annotation is BaseMessage ): def list_token_counter(messages: Sequence[BaseMessage]) -> int: - return sum(token_counter(msg) for msg in messages) # type: ignore[arg-type, misc] + return sum(actual_token_counter(msg) for msg in messages) # type: ignore[arg-type, misc] else: - list_token_counter = token_counter + list_token_counter = actual_token_counter # type: ignore[assignment] else: msg = ( f"'token_counter' expected to be a model that implements " f"'get_num_tokens_from_messages()' or a function. Received object of type " - f"{type(token_counter)}." + f"{type(actual_token_counter)}." ) raise ValueError(msg) @@ -1754,3 +1808,14 @@ def count_tokens_approximately( # round up once more time in case extra_tokens_per_message is a float return math.ceil(token_count) + + +# Mapping from string shortcuts to token counter functions +def _approximate_token_counter(messages: Sequence[BaseMessage]) -> int: + """Wrapper for count_tokens_approximately that matches expected signature.""" + return count_tokens_approximately(messages) + + +_TOKEN_COUNTER_SHORTCUTS = { + "approximate": _approximate_token_counter, +} diff --git a/libs/core/tests/unit_tests/messages/test_utils.py b/libs/core/tests/unit_tests/messages/test_utils.py index bedd518589ea0..ac6c3388fa37a 100644 --- a/libs/core/tests/unit_tests/messages/test_utils.py +++ b/libs/core/tests/unit_tests/messages/test_utils.py @@ -660,6 +660,81 @@ def test_trim_messages_start_on_with_allow_partial() -> None: assert messages == messages_copy +def test_trim_messages_token_counter_shortcut_approximate() -> None: + """Test that 'approximate' shortcut works for token_counter.""" + messages = [ + SystemMessage("This is a test message"), + HumanMessage("Another test message", id="first"), + AIMessage("AI response here", id="second"), + ] + messages_copy = [m.model_copy(deep=True) for m in messages] + + # Test using the "approximate" shortcut + result_shortcut = trim_messages( + messages, + max_tokens=50, + token_counter="approximate", + strategy="last", + ) + + # Test using count_tokens_approximately directly + result_direct = trim_messages( + messages, + max_tokens=50, + token_counter=count_tokens_approximately, + strategy="last", + ) + + # Both should produce the same result + assert result_shortcut == result_direct + assert messages == messages_copy + + +def test_trim_messages_token_counter_shortcut_invalid() -> None: + """Test that invalid token_counter shortcut raises ValueError.""" + messages = [ + SystemMessage("This is a test message"), + HumanMessage("Another test message"), + ] + + # Test with invalid shortcut + with pytest.raises(ValueError, match="Invalid token_counter shortcut 'invalid'"): + trim_messages( + messages, + max_tokens=50, + token_counter="invalid", + strategy="last", + ) + + +def test_trim_messages_token_counter_shortcut_with_options() -> None: + """Test that 'approximate' shortcut works with different trim options.""" + messages = [ + SystemMessage("System instructions"), + HumanMessage("First human message", id="first"), + AIMessage("First AI response", id="ai1"), + HumanMessage("Second human message", id="second"), + AIMessage("Second AI response", id="ai2"), + ] + messages_copy = [m.model_copy(deep=True) for m in messages] + + # Test with various options + result = trim_messages( + messages, + max_tokens=100, + token_counter="approximate", + strategy="last", + include_system=True, + start_on="human", + ) + + # Should include system message and start on human + assert len(result) >= 2 + assert isinstance(result[0], SystemMessage) + assert any(isinstance(msg, HumanMessage) for msg in result[1:]) + assert messages == messages_copy + + class FakeTokenCountingModel(FakeChatModel): @override def get_num_tokens_from_messages( diff --git a/libs/core/uv.lock b/libs/core/uv.lock index 5e5675550adfb..7107924ce63e9 100644 --- a/libs/core/uv.lock +++ b/libs/core/uv.lock @@ -1,5 +1,5 @@ version = 1 -revision = 3 +revision = 2 requires-python = ">=3.9.0, <4.0.0" resolution-markers = [ "python_full_version >= '3.14' and platform_python_implementation == 'PyPy'", @@ -567,7 +567,7 @@ name = "importlib-metadata" version = "8.7.0" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "zipp", marker = "python_full_version < '3.13'" }, + { name = "zipp", marker = "python_full_version < '3.10'" }, ] sdist = { url = "https://files.pythonhosted.org/packages/76/66/650a33bd90f786193e4de4b3ad86ea60b53c89b669a5c7be931fac31cdb0/importlib_metadata-8.7.0.tar.gz", hash = "sha256:d13b81ad223b890aa16c5471f2ac3056cf76c5f10f82d6f9292f0b415f389000", size = 56641, upload-time = "2025-04-27T15:29:01.736Z" } wheels = [ @@ -1126,6 +1126,8 @@ test = [ test-integration = [ { name = "en-core-web-sm", url = "https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.8.0/en_core_web_sm-3.8.0-py3-none-any.whl" }, { name = "nltk", specifier = ">=3.9.1,<4.0.0" }, + { name = "scipy", marker = "python_full_version == '3.12.*'", specifier = ">=1.7.0,<2.0.0" }, + { name = "scipy", marker = "python_full_version >= '3.13'", specifier = ">=1.14.1,<2.0.0" }, { name = "sentence-transformers", specifier = ">=3.0.1,<4.0.0" }, { name = "spacy", specifier = ">=3.8.7,<4.0.0" }, { name = "thinc", specifier = ">=8.3.6,<9.0.0" }, diff --git a/uv.lock b/uv.lock index 82a5dbae26cf7..e77b3a2abe31e 100644 --- a/uv.lock +++ b/uv.lock @@ -1,5 +1,5 @@ version = 1 -revision = 3 +revision = 2 requires-python = ">=3.9.0, <4.0.0" resolution-markers = [ "python_full_version >= '3.13' and platform_python_implementation == 'PyPy'", @@ -1476,8 +1476,8 @@ name = "html5lib" version = "1.1" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "six" }, - { name = "webencodings" }, + { name = "six", marker = "python_full_version < '3.13'" }, + { name = "webencodings", marker = "python_full_version < '3.13'" }, ] sdist = { url = "https://files.pythonhosted.org/packages/ac/b6/b55c3f49042f1df3dcd422b7f224f939892ee94f22abcf503a9b7339eaf2/html5lib-1.1.tar.gz", hash = "sha256:b2e5b40261e20f354d198eae92afc10d750afb487ed5e50f9c4eaf07c184146f", size = 272215, upload-time = "2020-06-22T23:32:38.834Z" } wheels = [ @@ -2364,6 +2364,7 @@ test = [ { name = "onnxruntime", marker = "python_full_version >= '3.10'" }, { name = "pytest", specifier = ">=7.3.0,<8.0.0" }, { name = "pytest-asyncio", specifier = ">=0.21.1,<1.0.0" }, + { name = "pytest-benchmark" }, { name = "pytest-mock", specifier = ">=3.10.0,<4.0.0" }, { name = "pytest-socket", specifier = ">=0.7.0,<1.0.0" }, { name = "pytest-watcher", specifier = ">=0.3.4,<1.0.0" }, @@ -2484,7 +2485,7 @@ requires-dist = [ { name = "aiohttp", specifier = ">=3.9.1,<4.0.0" }, { name = "fireworks-ai", specifier = ">=0.13.0,<1.0.0" }, { name = "langchain-core", editable = "libs/core" }, - { name = "openai", specifier = ">=1.10.0,<2.0.0" }, + { name = "openai", specifier = ">=1.0.0,<1.108.0" }, { name = "requests", specifier = ">=2.0.0,<3.0.0" }, ] @@ -2737,7 +2738,7 @@ dependencies = [ [package.metadata] requires-dist = [ { name = "langchain-core", editable = "libs/core" }, - { name = "openai", specifier = ">=1.104.2,<2.0.0" }, + { name = "openai", specifier = ">=1.104.2,<1.108.0" }, { name = "tiktoken", specifier = ">=0.7.0,<1.0.0" }, ] @@ -2822,6 +2823,8 @@ test = [ test-integration = [ { name = "en-core-web-sm", url = "https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.8.0/en_core_web_sm-3.8.0-py3-none-any.whl" }, { name = "nltk", specifier = ">=3.9.1,<4.0.0" }, + { name = "scipy", marker = "python_full_version == '3.12.*'", specifier = ">=1.7.0,<2.0.0" }, + { name = "scipy", marker = "python_full_version >= '3.13'", specifier = ">=1.14.1,<2.0.0" }, { name = "sentence-transformers", specifier = ">=3.0.1,<4.0.0" }, { name = "spacy", specifier = ">=3.8.7,<4.0.0" }, { name = "thinc", specifier = ">=8.3.6,<9.0.0" }, @@ -2862,7 +2865,7 @@ name = "langdetect" version = "1.0.9" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "six" }, + { name = "six", marker = "python_full_version < '3.13'" }, ] sdist = { url = "https://files.pythonhosted.org/packages/0e/72/a3add0e4eec4eb9e2569554f7c70f4a3c27712f40e3284d483e88094cc0e/langdetect-1.0.9.tar.gz", hash = "sha256:cbc1fef89f8d062739774bd51eda3da3274006b3661d199c2655f6b3f6d605a0", size = 981474, upload-time = "2021-05-07T07:54:13.562Z" } @@ -3630,10 +3633,10 @@ name = "nltk" version = "3.9.1" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "click" }, - { name = "joblib" }, - { name = "regex" }, - { name = "tqdm" }, + { name = "click", marker = "python_full_version < '3.13'" }, + { name = "joblib", marker = "python_full_version < '3.13'" }, + { name = "regex", marker = "python_full_version < '3.13'" }, + { name = "tqdm", marker = "python_full_version < '3.13'" }, ] sdist = { url = "https://files.pythonhosted.org/packages/3c/87/db8be88ad32c2d042420b6fd9ffd4a149f9a0d7f0e86b3f543be2eeeedd2/nltk-3.9.1.tar.gz", hash = "sha256:87d127bd3de4bd89a4f81265e5fa59cb1b199b27440175370f7417d2bc7ae868", size = 2904691, upload-time = "2024-08-18T19:48:37.769Z" } wheels = [ @@ -4944,9 +4947,9 @@ name = "python-oxmsg" version = "0.0.2" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "click" }, - { name = "olefile" }, - { name = "typing-extensions" }, + { name = "click", marker = "python_full_version < '3.13'" }, + { name = "olefile", marker = "python_full_version < '3.13'" }, + { name = "typing-extensions", marker = "python_full_version < '3.13'" }, ] sdist = { url = "https://files.pythonhosted.org/packages/a2/4e/869f34faedbc968796d2c7e9837dede079c9cb9750917356b1f1eda926e9/python_oxmsg-0.0.2.tar.gz", hash = "sha256:a6aff4deb1b5975d44d49dab1d9384089ffeec819e19c6940bc7ffbc84775fad", size = 34713, upload-time = "2025-02-03T17:13:47.415Z" } wheels = [ From 69fb6a0f072f1bba5d79e8e50bd18306aff1be4f Mon Sep 17 00:00:00 2001 From: DavdaJames Date: Sat, 27 Sep 2025 01:07:02 +0530 Subject: [PATCH 2/2] reverted the lock files --- libs/core/uv.lock | 6 ++---- uv.lock | 29 +++++++++++++---------------- 2 files changed, 15 insertions(+), 20 deletions(-) diff --git a/libs/core/uv.lock b/libs/core/uv.lock index 7107924ce63e9..5e5675550adfb 100644 --- a/libs/core/uv.lock +++ b/libs/core/uv.lock @@ -1,5 +1,5 @@ version = 1 -revision = 2 +revision = 3 requires-python = ">=3.9.0, <4.0.0" resolution-markers = [ "python_full_version >= '3.14' and platform_python_implementation == 'PyPy'", @@ -567,7 +567,7 @@ name = "importlib-metadata" version = "8.7.0" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "zipp", marker = "python_full_version < '3.10'" }, + { name = "zipp", marker = "python_full_version < '3.13'" }, ] sdist = { url = "https://files.pythonhosted.org/packages/76/66/650a33bd90f786193e4de4b3ad86ea60b53c89b669a5c7be931fac31cdb0/importlib_metadata-8.7.0.tar.gz", hash = "sha256:d13b81ad223b890aa16c5471f2ac3056cf76c5f10f82d6f9292f0b415f389000", size = 56641, upload-time = "2025-04-27T15:29:01.736Z" } wheels = [ @@ -1126,8 +1126,6 @@ test = [ test-integration = [ { name = "en-core-web-sm", url = "https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.8.0/en_core_web_sm-3.8.0-py3-none-any.whl" }, { name = "nltk", specifier = ">=3.9.1,<4.0.0" }, - { name = "scipy", marker = "python_full_version == '3.12.*'", specifier = ">=1.7.0,<2.0.0" }, - { name = "scipy", marker = "python_full_version >= '3.13'", specifier = ">=1.14.1,<2.0.0" }, { name = "sentence-transformers", specifier = ">=3.0.1,<4.0.0" }, { name = "spacy", specifier = ">=3.8.7,<4.0.0" }, { name = "thinc", specifier = ">=8.3.6,<9.0.0" }, diff --git a/uv.lock b/uv.lock index e77b3a2abe31e..82a5dbae26cf7 100644 --- a/uv.lock +++ b/uv.lock @@ -1,5 +1,5 @@ version = 1 -revision = 2 +revision = 3 requires-python = ">=3.9.0, <4.0.0" resolution-markers = [ "python_full_version >= '3.13' and platform_python_implementation == 'PyPy'", @@ -1476,8 +1476,8 @@ name = "html5lib" version = "1.1" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "six", marker = "python_full_version < '3.13'" }, - { name = "webencodings", marker = "python_full_version < '3.13'" }, + { name = "six" }, + { name = "webencodings" }, ] sdist = { url = "https://files.pythonhosted.org/packages/ac/b6/b55c3f49042f1df3dcd422b7f224f939892ee94f22abcf503a9b7339eaf2/html5lib-1.1.tar.gz", hash = "sha256:b2e5b40261e20f354d198eae92afc10d750afb487ed5e50f9c4eaf07c184146f", size = 272215, upload-time = "2020-06-22T23:32:38.834Z" } wheels = [ @@ -2364,7 +2364,6 @@ test = [ { name = "onnxruntime", marker = "python_full_version >= '3.10'" }, { name = "pytest", specifier = ">=7.3.0,<8.0.0" }, { name = "pytest-asyncio", specifier = ">=0.21.1,<1.0.0" }, - { name = "pytest-benchmark" }, { name = "pytest-mock", specifier = ">=3.10.0,<4.0.0" }, { name = "pytest-socket", specifier = ">=0.7.0,<1.0.0" }, { name = "pytest-watcher", specifier = ">=0.3.4,<1.0.0" }, @@ -2485,7 +2484,7 @@ requires-dist = [ { name = "aiohttp", specifier = ">=3.9.1,<4.0.0" }, { name = "fireworks-ai", specifier = ">=0.13.0,<1.0.0" }, { name = "langchain-core", editable = "libs/core" }, - { name = "openai", specifier = ">=1.0.0,<1.108.0" }, + { name = "openai", specifier = ">=1.10.0,<2.0.0" }, { name = "requests", specifier = ">=2.0.0,<3.0.0" }, ] @@ -2738,7 +2737,7 @@ dependencies = [ [package.metadata] requires-dist = [ { name = "langchain-core", editable = "libs/core" }, - { name = "openai", specifier = ">=1.104.2,<1.108.0" }, + { name = "openai", specifier = ">=1.104.2,<2.0.0" }, { name = "tiktoken", specifier = ">=0.7.0,<1.0.0" }, ] @@ -2823,8 +2822,6 @@ test = [ test-integration = [ { name = "en-core-web-sm", url = "https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.8.0/en_core_web_sm-3.8.0-py3-none-any.whl" }, { name = "nltk", specifier = ">=3.9.1,<4.0.0" }, - { name = "scipy", marker = "python_full_version == '3.12.*'", specifier = ">=1.7.0,<2.0.0" }, - { name = "scipy", marker = "python_full_version >= '3.13'", specifier = ">=1.14.1,<2.0.0" }, { name = "sentence-transformers", specifier = ">=3.0.1,<4.0.0" }, { name = "spacy", specifier = ">=3.8.7,<4.0.0" }, { name = "thinc", specifier = ">=8.3.6,<9.0.0" }, @@ -2865,7 +2862,7 @@ name = "langdetect" version = "1.0.9" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "six", marker = "python_full_version < '3.13'" }, + { name = "six" }, ] sdist = { url = "https://files.pythonhosted.org/packages/0e/72/a3add0e4eec4eb9e2569554f7c70f4a3c27712f40e3284d483e88094cc0e/langdetect-1.0.9.tar.gz", hash = "sha256:cbc1fef89f8d062739774bd51eda3da3274006b3661d199c2655f6b3f6d605a0", size = 981474, upload-time = "2021-05-07T07:54:13.562Z" } @@ -3633,10 +3630,10 @@ name = "nltk" version = "3.9.1" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "click", marker = "python_full_version < '3.13'" }, - { name = "joblib", marker = "python_full_version < '3.13'" }, - { name = "regex", marker = "python_full_version < '3.13'" }, - { name = "tqdm", marker = "python_full_version < '3.13'" }, + { name = "click" }, + { name = "joblib" }, + { name = "regex" }, + { name = "tqdm" }, ] sdist = { url = "https://files.pythonhosted.org/packages/3c/87/db8be88ad32c2d042420b6fd9ffd4a149f9a0d7f0e86b3f543be2eeeedd2/nltk-3.9.1.tar.gz", hash = "sha256:87d127bd3de4bd89a4f81265e5fa59cb1b199b27440175370f7417d2bc7ae868", size = 2904691, upload-time = "2024-08-18T19:48:37.769Z" } wheels = [ @@ -4947,9 +4944,9 @@ name = "python-oxmsg" version = "0.0.2" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "click", marker = "python_full_version < '3.13'" }, - { name = "olefile", marker = "python_full_version < '3.13'" }, - { name = "typing-extensions", marker = "python_full_version < '3.13'" }, + { name = "click" }, + { name = "olefile" }, + { name = "typing-extensions" }, ] sdist = { url = "https://files.pythonhosted.org/packages/a2/4e/869f34faedbc968796d2c7e9837dede079c9cb9750917356b1f1eda926e9/python_oxmsg-0.0.2.tar.gz", hash = "sha256:a6aff4deb1b5975d44d49dab1d9384089ffeec819e19c6940bc7ffbc84775fad", size = 34713, upload-time = "2025-02-03T17:13:47.415Z" } wheels = [