diff --git a/ms_agent/agent/llm_agent.py b/ms_agent/agent/llm_agent.py
index b47e3a65a..84dabbbf4 100644
--- a/ms_agent/agent/llm_agent.py
+++ b/ms_agent/agent/llm_agent.py
@@ -1,4 +1,5 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
+import asyncio
 import importlib
 import inspect
 import os.path
@@ -26,6 +27,11 @@
 from ..config.config import Config, ConfigLifecycleHandler
 from .base import Agent
 
+# Shared across the current process
+TOTAL_PROMPT_TOKENS = 0
+TOTAL_COMPLETION_TOKENS = 0
+TOKEN_LOCK = asyncio.Lock()
+
 
 class LLMAgent(Agent):
     """
@@ -467,9 +473,27 @@ async def step(
             messages = await self.parallel_tool_call(messages)
             await self.after_tool_call(messages)
+
+            # usage
+
+            prompt_tokens = _response_message.prompt_tokens
+            completion_tokens = _response_message.completion_tokens
+
+            # Accumulate into the process-wide totals
+            global TOTAL_PROMPT_TOKENS, TOTAL_COMPLETION_TOKENS, TOKEN_LOCK
+            async with TOKEN_LOCK:
+                TOTAL_PROMPT_TOKENS += prompt_tokens
+                TOTAL_COMPLETION_TOKENS += completion_tokens
+
+            # tokens in the current step
             self.log_output(
-                f'[usage] prompt_tokens: {_response_message.prompt_tokens}, '
-                f'completion_tokens: {_response_message.completion_tokens}')
+                f'[usage] prompt_tokens: {prompt_tokens}, completion_tokens: {completion_tokens}'
+            )
+            # total tokens for the process so far
+            self.log_output(
+                f'[usage_total] total_prompt_tokens: {TOTAL_PROMPT_TOKENS}, '
+                f'total_completion_tokens: {TOTAL_COMPLETION_TOKENS}')
+
         yield messages
 
     def prepare_llm(self):
@@ -545,13 +569,13 @@ def _get_run_memory_info(self, memory_config: DictConfig):
         memory_type = memory_type or None
         return user_id, agent_id, run_id, memory_type
 
-    async def add_memory(self, messages: List[Message], **kwargs):
+    async def add_memory(self, messages: List[Message], add_type, **kwargs):
         if hasattr(self.config, 'memory') and self.config.memory:
             tools_num = len(
                 self.memory_tools
             ) if self.memory_tools else 0
             # Check index bounds before access to avoid IndexError
             for idx, memory_config in enumerate(self.config.memory):
-                if self.runtime.should_stop:
+                if add_type == 'add_after_task':
                     user_id, agent_id, run_id, memory_type = self._get_run_memory_info(
                         memory_config)
                 else:
@@ -631,7 +655,8 @@ async def run_loop(self, messages: Union[List[Message], str],
                     yield messages
                 self.runtime.round += 1
                 # save memory and history
-                await self.add_memory(messages, **kwargs)
+                await self.add_memory(
+                    messages, add_type='add_after_step', **kwargs)
                 self.save_history(messages)
 
             # +1 means the next round the assistant may give a conclusion
@@ -647,11 +672,17 @@
                 yield messages
 
             # save memory
-            await self.add_memory(messages, **kwargs)
-
             await self.on_task_end(messages)
             await self.cleanup_tools()
             yield messages
+
+            def _add_memory():
+                asyncio.run(
+                    self.add_memory(
+                        messages, add_type='add_after_task', **kwargs))
+
+            loop = asyncio.get_running_loop()
+            loop.run_in_executor(None, _add_memory)
         except Exception as e:
             import traceback
             logger.warning(traceback.format_exc())
diff --git a/ms_agent/llm/openai_llm.py b/ms_agent/llm/openai_llm.py
index 3d012619d..ff5bb6744 100644
--- a/ms_agent/llm/openai_llm.py
+++ b/ms_agent/llm/openai_llm.py
@@ -132,6 +132,10 @@ def _call_llm(self,
         """
         messages = self._format_input_message(messages)
 
+        if kwargs.get('stream', False) and self.args.get(
+                'stream_options', {}).get('include_usage', True):
+            kwargs.setdefault('stream_options', {})['include_usage'] = True
+
         return self.client.chat.completions.create(
             model=self.model, messages=messages, tools=tools, **kwargs)
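Review note on the usage accounting above: `TOTAL_PROMPT_TOKENS` and `TOTAL_COMPLETION_TOKENS` are module-level, so the `[usage_total]` line aggregates across every `LLMAgent` in the process, and the `asyncio.Lock` keeps concurrent steps from racing on the read-modify-write. Also worth noting: the after-task `run_in_executor(None, _add_memory)` call is fire-and-forget; the returned future is never awaited, so a failing background memory write passes silently. A minimal, self-contained sketch of the counter pattern (the `record_usage` and `main` helpers are invented for illustration):

```python
import asyncio

# Process-wide totals, shared by every agent instance in this process.
TOTAL_PROMPT_TOKENS = 0
TOTAL_COMPLETION_TOKENS = 0
TOKEN_LOCK = asyncio.Lock()  # safe to create at import time on Python 3.10+


async def record_usage(prompt_tokens: int, completion_tokens: int) -> None:
    global TOTAL_PROMPT_TOKENS, TOTAL_COMPLETION_TOKENS
    async with TOKEN_LOCK:  # serialize the read-modify-write across tasks
        TOTAL_PROMPT_TOKENS += prompt_tokens
        TOTAL_COMPLETION_TOKENS += completion_tokens


async def main() -> None:
    # Two concurrent "steps" report usage at the same time without racing.
    await asyncio.gather(record_usage(120, 30), record_usage(80, 25))
    print(TOTAL_PROMPT_TOKENS, TOTAL_COMPLETION_TOKENS)  # -> 200 55


asyncio.run(main())
```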
diff --git a/ms_agent/memory/default_memory.py b/ms_agent/memory/default_memory.py
index 47276ff6c..8fc58f007 100644
--- a/ms_agent/memory/default_memory.py
+++ b/ms_agent/memory/default_memory.py
@@ -314,7 +314,6 @@ def search(self,
         agent_id = meta_info.get('agent_id', None)
         run_id = meta_info.get('run_id', None)
         limit = meta_info.get('limit', self.search_limit)
-
         relevant_memories = self.memory.search(
             query,
             user_id=user_id or self.user_id,
@@ -391,21 +390,36 @@ def _analyze_messages(
         """
         new_blocks = self._split_into_blocks(messages)
         self.cache_messages = dict(sorted(self.cache_messages.items()))
-
         cache_messages = [(key, value)
                           for key, value in self.cache_messages.items()]
+
         first_unmatched_idx = -1
+
         for idx in range(len(new_blocks)):
             block_hash = self._hash_block(new_blocks[idx])
-            if idx < len(cache_messages) - 1 and str(block_hash) == str(
+
+            # Must allow comparison up to the last cache entry
+            if idx < len(cache_messages) and str(block_hash) == str(
                     cache_messages[idx][1][1]):
                 continue
+
+            # mismatch
             first_unmatched_idx = idx
             break
+
+        # If all new_blocks match but the cache has extra entries → delete the extra cache entries
+        if first_unmatched_idx == -1:
+            should_add_messages = []
+            should_delete = [
+                item[0] for item in cache_messages[len(new_blocks):]
+            ]
+            return should_add_messages, should_delete
+
+        # On mismatch: add all new blocks and delete all cache entries starting from the mismatch index
+        should_add_messages = new_blocks[first_unmatched_idx:]
         should_delete = [
             item[0] for item in cache_messages[first_unmatched_idx:]
-        ] if first_unmatched_idx != -1 else []
-        should_add_messages = new_blocks[first_unmatched_idx:]
+        ]
         return should_add_messages, should_delete
@@ -440,6 +454,7 @@ async def add(
                 logger.info(item[1])
         if should_add_messages:
             for messages in should_add_messages:
+                messages = self.parse_messages(messages)
                 await self.add_single(
                     messages,
                     user_id=user_id,
@@ -448,6 +463,23 @@
                     memory_type=memory_type)
         self.save_cache()
 
+    def parse_messages(self, messages: List[Message]) -> List[Message]:
+        new_messages = []
+        for msg in messages:
+            role = getattr(msg, 'role', None)
+            content = getattr(msg, 'content', None)
+
+            if 'system' not in self.ignore_roles and role == 'system':
+                new_messages.append(msg)
+            if role == 'user':
+                new_messages.append(msg)
+            if 'assistant' not in self.ignore_roles and role == 'assistant' and content is not None:
+                new_messages.append(msg)
+            if 'tool' not in self.ignore_roles and role == 'tool':
+                new_messages.append(msg)
+
+        return new_messages
+
     def delete(self,
                user_id: Optional[str] = None,
                agent_id: Optional[str] = None,
@@ -484,7 +516,6 @@ def get_all(self,
                 user_id=user_id or self.user_id,
                 agent_id=agent_id,
                 run_id=run_id)
-            print(res['results'])
             return res['results']
         except Exception:
             return []
@@ -538,6 +569,8 @@ async def run(
         query = self._get_latest_user_message(messages)
         if not query:
             return messages
+        if meta_infos is None:
+            meta_infos = [{'user_id': self.user_id}]
         async with self._lock:
             try:
                 memories = self.search(query, meta_infos)
@@ -558,34 +591,14 @@ def _init_memory_obj(self):
             )
             raise
 
-        parse_messages_origin = mem0.memory.main.parse_messages
         capture_event_origin = mem0.memory.main.capture_event
 
-        @wraps(parse_messages_origin)
-        def patched_parse_messages(messages, ignore_roles):
-            response = ''
-            for msg in messages:
-                if 'system' not in ignore_roles and msg['role'] == 'system':
-                    response += f"system: {msg['content']}\n"
-                if msg['role'] == 'user':
-                    response += f"user: {msg['content']}\n"
-                if msg['role'] == 'assistant' and msg['content'] is not None:
-                    response += f"assistant: {msg['content']}\n"
-                if 'tool' not in ignore_roles and msg['role'] == 'tool':
-                    response += f"tool: {msg['content']}\n"
-            return response
-
         @wraps(capture_event_origin)
         def patched_capture_event(event_name,
                                   memory_instance,
                                   additional_data=None):
             pass
 
-        mem0.memory.main.parse_messages = partial(
-            patched_parse_messages,
-            ignore_roles=self.ignore_roles,
-        )
-
         mem0.memory.main.capture_event = partial(patched_capture_event, )
 
         # emb config
@@ -705,9 +718,13 @@ def sanitize_database_name(ori_name: str,
         mem0_config['llm'] = llm
         logger.info(f'Memory config: {mem0_config}')
         # Prompt content is too long, default logging reduces readability
-        mem0_config['custom_fact_extraction_prompt'] = getattr(
-            self.config, 'fact_retrieval_prompt', get_fact_retrieval_prompt()
-        ) + f'Today\'s date is {datetime.now().strftime("%Y-%m-%d")}.'
+        custom_fact_extraction_prompt = getattr(
+            self.config, 'fact_retrieval_prompt',
+            getattr(self.config, 'custom_fact_extraction_prompt', None))
+        if custom_fact_extraction_prompt is not None:
+            mem0_config['custom_fact_extraction_prompt'] = (
+                custom_fact_extraction_prompt
+                + f'Today\'s date is {datetime.now().strftime("%Y-%m-%d")}.')
         try:
             memory = mem0.Memory.from_config(mem0_config)
             memory._telemetry_vector_store = None
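The `_analyze_messages` fix deserves a note: the old bound `idx < len(cache_messages) - 1` never compared against the last cache entry, so even a fully matching history produced a spurious add/delete pair on every call. A minimal sketch of the corrected prefix-diff logic, with the hash and cache shapes simplified for illustration (`diff_blocks` and its argument names are invented):

```python
from typing import List, Tuple


def diff_blocks(new_hashes: List[str],
                cached: List[Tuple[str, str]]) -> Tuple[List[int], List[str]]:
    """Return (indices of blocks to add, cache keys to delete).

    `cached` is an ordered list of (cache_key, block_hash) pairs.
    """
    first_unmatched = -1
    for idx, block_hash in enumerate(new_hashes):
        # Compare up to and including the last cache entry.
        if idx < len(cached) and block_hash == cached[idx][1]:
            continue
        first_unmatched = idx
        break
    if first_unmatched == -1:
        # Every new block matched; only trim cache entries beyond the list.
        return [], [key for key, _ in cached[len(new_hashes):]]
    # Add everything from the first mismatch on; delete what it supersedes.
    return (list(range(first_unmatched, len(new_hashes))),
            [key for key, _ in cached[first_unmatched:]])


assert diff_blocks(['a', 'b'], [('k0', 'a'), ('k1', 'b')]) == ([], [])
assert diff_blocks(['a', 'x'], [('k0', 'a'), ('k1', 'b')]) == ([1], ['k1'])
assert diff_blocks(['a'], [('k0', 'a'), ('k1', 'b')]) == ([], ['k1'])
```

The new `parse_messages` method makes the same role filtering happen on `Message` objects before `add_single`, which is why the mem0 `parse_messages` monkey-patch above could be removed.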
diff --git a/ms_agent/tools/image_generator/ds_image_gen.py b/ms_agent/tools/image_generator/ds_image_gen.py
index 71aeab432..b7286b218 100644
--- a/ms_agent/tools/image_generator/ds_image_gen.py
+++ b/ms_agent/tools/image_generator/ds_image_gen.py
@@ -2,7 +2,6 @@
 import uuid
 from io import BytesIO
 
-import aiohttp
 from PIL import Image
 
@@ -19,6 +18,7 @@ async def generate_image(self,
                              size=None,
                              ratio=None,
                              **kwargs):
+        import aiohttp
         image_generator = self.config.tools.image_generator
         base_url = (
             getattr(image_generator, 'base_url', None)
diff --git a/ms_agent/tools/image_generator/ms_image_gen.py b/ms_agent/tools/image_generator/ms_image_gen.py
index 648a10b3b..b121458e6 100644
--- a/ms_agent/tools/image_generator/ms_image_gen.py
+++ b/ms_agent/tools/image_generator/ms_image_gen.py
@@ -3,7 +3,6 @@
 import uuid
 from io import BytesIO
 
-import aiohttp
 import json
 from PIL import Image
 
@@ -20,6 +19,7 @@ async def generate_image(self,
                              negative_prompt=None,
                              size=None,
                              **kwargs):
+        import aiohttp
         image_generator = self.config.tools.image_generator
         base_url = (getattr(image_generator, 'base_url', None)
                     or 'https://api-inference.modelscope.cn').strip('/')
diff --git a/ms_agent/tools/video_generator/ds_video_gen.py b/ms_agent/tools/video_generator/ds_video_gen.py
index fbdc38870..d757038d3 100644
--- a/ms_agent/tools/video_generator/ds_video_gen.py
+++ b/ms_agent/tools/video_generator/ds_video_gen.py
@@ -2,7 +2,6 @@
 import os
 import uuid
 
-import aiohttp
 from ms_agent.utils import get_logger
 
 logger = get_logger()
@@ -34,6 +33,7 @@ async def generate_video(self,
     @staticmethod
     async def download_video(video_url, output_file):
+        import aiohttp
         max_retries = 3
         retry_count = 0
@@ -57,6 +57,7 @@ async def download_video(video_url, output_file):
     @staticmethod
     async def _generate_video(base_url, api_key, model, prompt, size,
                               seconds):
+        import aiohttp
         base_url = base_url.strip('/')
         create_endpoint = '/api/v1/services/aigc/model-evaluation/async-inference/'
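A note on the repeated `import aiohttp` moves in the three tool files: hoisting the import into the coroutine defers the dependency until the tool is actually invoked, so importing the module no longer fails when aiohttp is not installed. A minimal sketch of the pattern (`download`, `url`, and `path` are placeholder names):

```python
async def download(url: str, path: str) -> None:
    import aiohttp  # deferred: aiohttp stays an optional dependency

    async with aiohttp.ClientSession() as session:
        async with session.get(url) as resp:
            resp.raise_for_status()
            with open(path, 'wb') as f:
                f.write(await resp.read())
```

Python caches modules in `sys.modules`, so repeating the import inside each coroutine costs only a dictionary lookup after the first call.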