-
Notifications
You must be signed in to change notification settings - Fork 428
Feat/agent usage #819
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Feat/agent usage #819
Changes from all commits
b8f842d
257c557
8eb5973
33281b8
82bc063
a0af6d1
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -1,4 +1,5 @@ | ||||||||||||||||||||||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | ||||||||||||||||||||||||
| import asyncio | ||||||||||||||||||||||||
| import importlib | ||||||||||||||||||||||||
| import inspect | ||||||||||||||||||||||||
| import os.path | ||||||||||||||||||||||||
|
|
@@ -26,6 +27,11 @@ | |||||||||||||||||||||||
| from ..config.config import Config, ConfigLifecycleHandler | ||||||||||||||||||||||||
| from .base import Agent | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| # Current process shared | ||||||||||||||||||||||||
| TOTAL_PROMPT_TOKENS = 0 | ||||||||||||||||||||||||
| TOTAL_COMPLETION_TOKENS = 0 | ||||||||||||||||||||||||
| TOKEN_LOCK = asyncio.Lock() | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| class LLMAgent(Agent): | ||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||
|
|
@@ -467,9 +473,27 @@ async def step( | |||||||||||||||||||||||
| messages = await self.parallel_tool_call(messages) | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| await self.after_tool_call(messages) | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| # usage | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| prompt_tokens = _response_message.prompt_tokens | ||||||||||||||||||||||||
| completion_tokens = _response_message.completion_tokens | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| # 使用全局累积 | ||||||||||||||||||||||||
| global TOTAL_PROMPT_TOKENS, TOTAL_COMPLETION_TOKENS, TOKEN_LOCK | ||||||||||||||||||||||||
| async with TOKEN_LOCK: | ||||||||||||||||||||||||
| TOTAL_PROMPT_TOKENS += prompt_tokens | ||||||||||||||||||||||||
| TOTAL_COMPLETION_TOKENS += completion_tokens | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| # tokens in the current step | ||||||||||||||||||||||||
| self.log_output( | ||||||||||||||||||||||||
| f'[usage] prompt_tokens: {_response_message.prompt_tokens}, ' | ||||||||||||||||||||||||
| f'completion_tokens: {_response_message.completion_tokens}') | ||||||||||||||||||||||||
| f'[usage] prompt_tokens: {prompt_tokens}, completion_tokens: {completion_tokens}' | ||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||
| # total tokens for the process so far | ||||||||||||||||||||||||
| self.log_output( | ||||||||||||||||||||||||
| f'[usage_total] total_prompt_tokens: {TOTAL_PROMPT_TOKENS}, ' | ||||||||||||||||||||||||
| f'total_completion_tokens: {TOTAL_COMPLETION_TOKENS}') | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| yield messages | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| def prepare_llm(self): | ||||||||||||||||||||||||
|
|
@@ -545,13 +569,13 @@ def _get_run_memory_info(self, memory_config: DictConfig): | |||||||||||||||||||||||
| memory_type = memory_type or None | ||||||||||||||||||||||||
| return user_id, agent_id, run_id, memory_type | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| async def add_memory(self, messages: List[Message], **kwargs): | ||||||||||||||||||||||||
| async def add_memory(self, messages: List[Message], add_type, **kwargs): | ||||||||||||||||||||||||
| if hasattr(self.config, 'memory') and self.config.memory: | ||||||||||||||||||||||||
| tools_num = len( | ||||||||||||||||||||||||
| self.memory_tools | ||||||||||||||||||||||||
| ) if self.memory_tools else 0 # Check index bounds before access to avoid IndexError | ||||||||||||||||||||||||
| for idx, memory_config in enumerate(self.config.memory): | ||||||||||||||||||||||||
| if self.runtime.should_stop: | ||||||||||||||||||||||||
| if add_type == 'add_after_task': | ||||||||||||||||||||||||
| user_id, agent_id, run_id, memory_type = self._get_run_memory_info( | ||||||||||||||||||||||||
| memory_config) | ||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||
|
|
@@ -631,7 +655,8 @@ async def run_loop(self, messages: Union[List[Message], str], | |||||||||||||||||||||||
| yield messages | ||||||||||||||||||||||||
| self.runtime.round += 1 | ||||||||||||||||||||||||
| # save memory and history | ||||||||||||||||||||||||
| await self.add_memory(messages, **kwargs) | ||||||||||||||||||||||||
| await self.add_memory( | ||||||||||||||||||||||||
| messages, add_type='add_after_step', **kwargs) | ||||||||||||||||||||||||
| self.save_history(messages) | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| # +1 means the next round the assistant may give a conclusion | ||||||||||||||||||||||||
|
|
@@ -647,11 +672,17 @@ async def run_loop(self, messages: Union[List[Message], str], | |||||||||||||||||||||||
| yield messages | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| # save memory | ||||||||||||||||||||||||
| await self.add_memory(messages, **kwargs) | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| await self.on_task_end(messages) | ||||||||||||||||||||||||
| await self.cleanup_tools() | ||||||||||||||||||||||||
| yield messages | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| def _add_memory(): | ||||||||||||||||||||||||
| asyncio.run( | ||||||||||||||||||||||||
| self.add_memory( | ||||||||||||||||||||||||
| messages, add_type='add_after_task', **kwargs)) | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| loop = asyncio.get_running_loop() | ||||||||||||||||||||||||
| loop.run_in_executor(None, _add_memory) | ||||||||||||||||||||||||
|
Comment on lines
+679
to
+685
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The current implementation for running a "fire-and-forget" async task from within an async function is problematic; see the suggested change below.
Suggested change
|
||||||||||||||||||||||||
| except Exception as e: | ||||||||||||||||||||||||
| import traceback | ||||||||||||||||||||||||
| logger.warning(traceback.format_exc()) | ||||||||||||||||||||||||
|
|
||||||||||||||||||||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -314,7 +314,6 @@ def search(self, | |
| agent_id = meta_info.get('agent_id', None) | ||
| run_id = meta_info.get('run_id', None) | ||
| limit = meta_info.get('limit', self.search_limit) | ||
|
|
||
| relevant_memories = self.memory.search( | ||
| query, | ||
| user_id=user_id or self.user_id, | ||
|
|
@@ -391,21 +390,36 @@ def _analyze_messages( | |
| """ | ||
| new_blocks = self._split_into_blocks(messages) | ||
| self.cache_messages = dict(sorted(self.cache_messages.items())) | ||
|
|
||
| cache_messages = [(key, value) | ||
| for key, value in self.cache_messages.items()] | ||
|
|
||
| first_unmatched_idx = -1 | ||
|
|
||
| for idx in range(len(new_blocks)): | ||
| block_hash = self._hash_block(new_blocks[idx]) | ||
| if idx < len(cache_messages) - 1 and str(block_hash) == str( | ||
|
|
||
| # Must allow comparison up to the last cache entry | ||
| if idx < len(cache_messages) and str(block_hash) == str( | ||
| cache_messages[idx][1][1]): | ||
| continue | ||
|
|
||
| # mismatch | ||
| first_unmatched_idx = idx | ||
| break | ||
|
|
||
| # If all new_blocks match but the cache has extra entries → delete the extra cache entries | ||
| if first_unmatched_idx == -1: | ||
| should_add_messages = [] | ||
| should_delete = [ | ||
| item[0] for item in cache_messages[len(new_blocks):] | ||
| ] | ||
| return should_add_messages, should_delete | ||
|
|
||
| # On mismatch: add all new blocks and delete all cache entries starting from the mismatch index | ||
| should_add_messages = new_blocks[first_unmatched_idx:] | ||
| should_delete = [ | ||
| item[0] for item in cache_messages[first_unmatched_idx:] | ||
| ] if first_unmatched_idx != -1 else [] | ||
| should_add_messages = new_blocks[first_unmatched_idx:] | ||
| ] | ||
|
|
||
| return should_add_messages, should_delete | ||
|
|
||
|
|
@@ -440,6 +454,7 @@ async def add( | |
| logger.info(item[1]) | ||
| if should_add_messages: | ||
| for messages in should_add_messages: | ||
| messages = self.parse_messages(messages) | ||
| await self.add_single( | ||
| messages, | ||
| user_id=user_id, | ||
|
|
@@ -448,6 +463,23 @@ async def add( | |
| memory_type=memory_type) | ||
| self.save_cache() | ||
|
|
||
| def parse_messages(self, messages: List[Message]) -> List[Message]: | ||
| new_messages = [] | ||
| for msg in messages: | ||
| role = getattr(msg, 'role', None) | ||
| content = getattr(msg, 'content', None) | ||
|
|
||
| if 'system' not in self.ignore_roles and role == 'system': | ||
| new_messages.append(msg) | ||
| if role == 'user': | ||
| new_messages.append(msg) | ||
| if 'assistant' not in self.ignore_roles and role == 'assistant' and content is not None: | ||
| new_messages.append(msg) | ||
| if 'tool' not in self.ignore_roles and role == 'tool': | ||
| new_messages.append(msg) | ||
|
|
||
| return new_messages | ||
|
Comment on lines
+466
to
+481
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The `parse_messages` method can be simplified: def parse_messages(self, messages: List[Message]) -> List[Message]:
new_messages = []
for msg in messages:
role = getattr(msg, 'role', None)
if role == 'user':
new_messages.append(msg)
continue
if role in self.ignore_roles:
continue
if role == 'assistant' and getattr(msg, 'content', None) is None:
continue
new_messages.append(msg)
return new_messages |
||
|
|
||
| def delete(self, | ||
| user_id: Optional[str] = None, | ||
| agent_id: Optional[str] = None, | ||
|
|
@@ -484,7 +516,6 @@ def get_all(self, | |
| user_id=user_id or self.user_id, | ||
| agent_id=agent_id, | ||
| run_id=run_id) | ||
| print(res['results']) | ||
| return res['results'] | ||
| except Exception: | ||
| return [] | ||
|
|
@@ -538,6 +569,8 @@ async def run( | |
| query = self._get_latest_user_message(messages) | ||
| if not query: | ||
| return messages | ||
| if meta_infos is None: | ||
| meta_infos = [{'user_id': self.user_id}] | ||
| async with self._lock: | ||
| try: | ||
| memories = self.search(query, meta_infos) | ||
|
|
@@ -558,34 +591,14 @@ def _init_memory_obj(self): | |
| ) | ||
| raise | ||
|
|
||
| parse_messages_origin = mem0.memory.main.parse_messages | ||
| capture_event_origin = mem0.memory.main.capture_event | ||
|
|
||
| @wraps(parse_messages_origin) | ||
| def patched_parse_messages(messages, ignore_roles): | ||
| response = '' | ||
| for msg in messages: | ||
| if 'system' not in ignore_roles and msg['role'] == 'system': | ||
| response += f"system: {msg['content']}\n" | ||
| if msg['role'] == 'user': | ||
| response += f"user: {msg['content']}\n" | ||
| if msg['role'] == 'assistant' and msg['content'] is not None: | ||
| response += f"assistant: {msg['content']}\n" | ||
| if 'tool' not in ignore_roles and msg['role'] == 'tool': | ||
| response += f"tool: {msg['content']}\n" | ||
| return response | ||
|
|
||
| @wraps(capture_event_origin) | ||
| def patched_capture_event(event_name, | ||
| memory_instance, | ||
| additional_data=None): | ||
| pass | ||
|
|
||
| mem0.memory.main.parse_messages = partial( | ||
| patched_parse_messages, | ||
| ignore_roles=self.ignore_roles, | ||
| ) | ||
|
|
||
| mem0.memory.main.capture_event = partial(patched_capture_event, ) | ||
|
|
||
| # emb config | ||
|
|
@@ -705,9 +718,13 @@ def sanitize_database_name(ori_name: str, | |
| mem0_config['llm'] = llm | ||
| logger.info(f'Memory config: {mem0_config}') | ||
| # Prompt content is too long, default logging reduces readability | ||
| mem0_config['custom_fact_extraction_prompt'] = getattr( | ||
| self.config, 'fact_retrieval_prompt', get_fact_retrieval_prompt() | ||
| ) + f'Today\'s date is {datetime.now().strftime("%Y-%m-%d")}.' | ||
| custom_fact_extraction_prompt = getattr( | ||
| self.config, 'fact_retrieval_prompt', | ||
| getattr(self.config, 'custom_fact_extraction_prompt', None)) | ||
| if custom_fact_extraction_prompt is not None: | ||
| mem0_config['custom_fact_extraction_prompt'] = ( | ||
| custom_fact_extraction_prompt | ||
| + f'Today\'s date is {datetime.now().strftime("%Y-%m-%d")}.') | ||
| try: | ||
| memory = mem0.Memory.from_config(mem0_config) | ||
| memory._telemetry_vector_store = None | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The use of module-level global variables
`TOTAL_PROMPT_TOKENS` and `TOTAL_COMPLETION_TOKENS` for tracking token usage introduces tight coupling between all agent instances within the same process. This can lead to incorrect accounting if multiple independent agents are running, and makes the code harder to test and maintain. A better approach would be to encapsulate this state within a dedicated usage tracking class or within the
`LLMAgent` instance itself. This would provide better isolation and more flexible usage tracking (e.g., per-agent or per-task).