Add: Google Gemini Callback #314
Open
tykimseoul wants to merge 6 commits into langchain-ai:main from tykimseoul:gemini-callback
Changes from all commits (6)
a0ad0cd  add: gemini callback (tykimseoul)
03e69f8  Update libs/community/langchain_community/callbacks/gemini_info.py (tykimseoul)
1895b85  Update libs/community/langchain_community/callbacks/gemini_info.py (tykimseoul)
7ca0f57  Update libs/community/langchain_community/callbacks/gemini_info.py (tykimseoul)
6358809  Update libs/community/langchain_community/callbacks/gemini_info.py (tykimseoul)
f916483  Refactor constructor in gemini_info.py (tykimseoul)
libs/community/langchain_community/callbacks/gemini_info.py (222 additions, 0 deletions)
import threading
from enum import Enum, auto
from typing import Any, Dict, List

from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.messages import AIMessage
from langchain_core.outputs import ChatGeneration, LLMResult

MODEL_COST_PER_1M_TOKENS = {
    "gemini-2.5-pro": 1.25,
    "gemini-2.5-pro-completion": 10,
    "gemini-2.5-flash": 0.3,
    "gemini-2.5-flash-completion": 2.5,
    "gemini-2.5-flash-lite": 0.1,
    "gemini-2.5-flash-lite-completion": 0.4,
    "gemini-2.0-flash": 0.1,
    "gemini-2.0-flash-completion": 0.4,
    "gemini-2.0-flash-lite": 0.075,
    "gemini-2.0-flash-lite-completion": 0.3,
    "gemini-1.5-pro": 1.25,
    "gemini-1.5-pro-completion": 5,
    "gemini-1.5-flash": 0.075,
    "gemini-1.5-flash-completion": 0.3,
}
MODEL_COST_PER_1K_TOKENS = {k: v / 1000 for k, v in MODEL_COST_PER_1M_TOKENS.items()}
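# Worked example (illustrative): "gemini-2.5-pro" is listed at $1.25 per 1M prompt
# tokens, so the derived per-1K rate is 1.25 / 1000 = $0.00125.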


class TokenType(Enum):
    """Token type enum."""

    PROMPT = auto()
    PROMPT_CACHED = auto()
    COMPLETION = auto()


def standardize_model_name(
    model_name: str,
    token_type: TokenType = TokenType.PROMPT,
) -> str:
    """Standardize the model name for lookup in the cost tables.

    Args:
        model_name: The name of the model to standardize.
        token_type: The type of token, defaults to PROMPT.
    """
    model_name = model_name.lower()
    if token_type == TokenType.COMPLETION:
        return model_name + "-completion"
    else:
        return model_name
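# Illustrative: standardize_model_name("Gemini-2.5-Flash", TokenType.COMPLETION)
# returns "gemini-2.5-flash-completion".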


def get_gemini_token_cost_for_model(
    model_name: str,
    num_tokens: int,
    is_completion: bool = False,
    *,
    token_type: TokenType = TokenType.PROMPT,
) -> float:
    """Get the cost in USD for a given model and number of tokens.

    Args:
        model_name: The name of the Gemini model to calculate cost for.
        num_tokens: The number of tokens to calculate cost for.
        is_completion: Whether the tokens are completion tokens.
            If True, token_type will be set to TokenType.COMPLETION.
        token_type: The type of token (prompt or completion).
            Defaults to TokenType.PROMPT.

    Returns:
        The cost in USD for the specified number of tokens.

    Raises:
        ValueError: If the model name is not recognized as a valid Gemini model.
    """
    if is_completion:
        token_type = TokenType.COMPLETION
    model_name = standardize_model_name(model_name, token_type=token_type)
    if model_name not in MODEL_COST_PER_1K_TOKENS:
        raise ValueError(
            f"Unknown model: {model_name}. Please provide a valid Gemini model "
            "name. Known models are: " + ", ".join(MODEL_COST_PER_1K_TOKENS.keys())
        )
    return MODEL_COST_PER_1K_TOKENS[model_name] * (num_tokens / 1000)
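# Illustrative: get_gemini_token_cost_for_model("gemini-2.5-flash", 2000, is_completion=True)
# looks up "gemini-2.5-flash-completion" (0.0025 per 1K tokens) and returns
# 0.0025 * 2 = $0.005.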


class GeminiCallbackHandler(BaseCallbackHandler):
    """Callback Handler that tracks Gemini info."""

    def __init__(self) -> None:
        super().__init__()
        self._lock = threading.Lock()
        self.total_tokens = 0
        self.prompt_tokens = 0
        self.prompt_tokens_cached = 0
        self.completion_tokens = 0
        self.reasoning_tokens = 0
        self.successful_requests = 0
        self.total_cost = 0.0

    def __repr__(self) -> str:
        return f"""Tokens Used: {self.total_tokens}
\tPrompt Tokens: {self.prompt_tokens}
\tPrompt Cached Tokens: {self.prompt_tokens_cached}
\tCompletion Tokens: {self.completion_tokens}
\tReasoning Tokens: {self.reasoning_tokens}
Successful Requests: {self.successful_requests}
Total Cost (USD): ${self.total_cost}"""

    @property
    def always_verbose(self) -> bool:
        """Whether to call verbose callbacks even if verbose is False."""
        return True

    def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> None:
        """Print out the prompts."""
        pass

    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        """Print out the token."""
        pass

    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        """Collect token usage."""
        # Check for usage_metadata (langchain-core >= 0.2.2)
        try:
            generation = response.generations[0][0]
        except IndexError:
            generation = None
        if isinstance(generation, ChatGeneration):
            try:
                message = generation.message
                if isinstance(message, AIMessage):
                    usage_metadata = message.usage_metadata
                    response_metadata = message.response_metadata
                else:
                    usage_metadata = None
                    response_metadata = None
            except AttributeError:
                usage_metadata = None
                response_metadata = None
        else:
            usage_metadata = None
            response_metadata = None

        prompt_tokens_cached = 0
        reasoning_tokens = 0

        if usage_metadata:
            token_usage = {"total_tokens": usage_metadata["total_tokens"]}
            completion_tokens = usage_metadata["output_tokens"]
            prompt_tokens = usage_metadata["input_tokens"]
            if response_model_name := (response_metadata or {}).get("model_name"):
                model_name = standardize_model_name(response_model_name)
            elif response.llm_output is None:
                model_name = ""
            else:
                model_name = standardize_model_name(
                    response.llm_output.get("model_name", "")
                )
            if "cache_read" in usage_metadata.get("input_token_details", {}):
                prompt_tokens_cached = usage_metadata["input_token_details"][
                    "cache_read"
                ]
            if "reasoning" in usage_metadata.get("output_token_details", {}):
                reasoning_tokens = usage_metadata["output_token_details"]["reasoning"]
        else:
            if response.llm_output is None:
                return None

            if "token_usage" not in response.llm_output:
                with self._lock:
                    self.successful_requests += 1
                return None

            # compute tokens and cost for this request
            token_usage = response.llm_output["token_usage"]
            completion_tokens = token_usage.get("completion_tokens", 0)
            prompt_tokens = token_usage.get("prompt_tokens", 0)
            model_name = standardize_model_name(
                response.llm_output.get("model_name", "")
            )

        if model_name in MODEL_COST_PER_1K_TOKENS:
            uncached_prompt_tokens = prompt_tokens - prompt_tokens_cached
            uncached_prompt_cost = get_gemini_token_cost_for_model(
                model_name, uncached_prompt_tokens, token_type=TokenType.PROMPT
            )
            cached_prompt_cost = get_gemini_token_cost_for_model(
                model_name, prompt_tokens_cached, token_type=TokenType.PROMPT_CACHED
            )

            prompt_cost = uncached_prompt_cost + cached_prompt_cost
            completion_cost = get_gemini_token_cost_for_model(
                model_name, completion_tokens, token_type=TokenType.COMPLETION
            )
        else:
            completion_cost = 0
            prompt_cost = 0

        # update shared state behind lock
        with self._lock:
            self.total_cost += prompt_cost + completion_cost
            self.total_tokens += token_usage.get("total_tokens", 0)
            self.prompt_tokens += prompt_tokens
            self.prompt_tokens_cached += prompt_tokens_cached
            self.completion_tokens += completion_tokens
            self.reasoning_tokens += reasoning_tokens
            self.successful_requests += 1

    def __copy__(self) -> "GeminiCallbackHandler":
        """Return self so that copies share the same running totals."""
        return self

    def __deepcopy__(self, memo: Any) -> "GeminiCallbackHandler":
        """Return self so that deep copies share the same running totals."""
        return self
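For context, a minimal usage sketch (illustrative only, not part of this diff): the handler is attached through LangChain's standard callbacks config, here assuming langchain-google-genai's ChatGoogleGenerativeAI as the chat model.

# Illustrative usage only; not part of the PR. Assumes langchain-google-genai is
# installed and GOOGLE_API_KEY is set in the environment.
from langchain_community.callbacks.gemini_info import GeminiCallbackHandler
from langchain_google_genai import ChatGoogleGenerativeAI

handler = GeminiCallbackHandler()
llm = ChatGoogleGenerativeAI(model="gemini-2.5-flash")

# Token usage and cost accumulate on the handler across calls.
llm.invoke("Say hello in one word.", config={"callbacks": [handler]})
llm.invoke("Now say it in French.", config={"callbacks": [handler]})

print(handler)  # totals: prompt/cached/completion/reasoning tokens, requests, cost

Costs are computed from the hard-coded MODEL_COST_PER_1M_TOKENS table above, so models missing from that table contribute zero cost but still count toward the token totals.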