Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 38 additions & 7 deletions ms_agent/agent/llm_agent.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import asyncio
import importlib
import inspect
import os.path
Expand Down Expand Up @@ -26,6 +27,11 @@
from ..config.config import Config, ConfigLifecycleHandler
from .base import Agent

# Token usage totals shared by every agent running in the current process
TOTAL_PROMPT_TOKENS = 0
TOTAL_COMPLETION_TOKENS = 0
TOKEN_LOCK = asyncio.Lock()
Comment on lines +31 to +33

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The use of module-level global variables TOTAL_PROMPT_TOKENS and TOTAL_COMPLETION_TOKENS for tracking token usage introduces tight coupling between all agent instances within the same process. This can lead to incorrect accounting if multiple independent agents are running, and makes the code harder to test and maintain.

A better approach would be to encapsulate this state within a dedicated usage tracking class or within the LLMAgent instance itself. This would provide better isolation and more flexible usage tracking (e.g., per-agent or per-task).



class LLMAgent(Agent):
"""
Expand Down Expand Up @@ -467,9 +473,27 @@ async def step(
messages = await self.parallel_tool_call(messages)

await self.after_tool_call(messages)

# usage

prompt_tokens = _response_message.prompt_tokens
completion_tokens = _response_message.completion_tokens

# Accumulate this step's usage into the process-wide totals
global TOTAL_PROMPT_TOKENS, TOTAL_COMPLETION_TOKENS, TOKEN_LOCK
async with TOKEN_LOCK:
TOTAL_PROMPT_TOKENS += prompt_tokens
TOTAL_COMPLETION_TOKENS += completion_tokens

# tokens in the current step
self.log_output(
f'[usage] prompt_tokens: {_response_message.prompt_tokens}, '
f'completion_tokens: {_response_message.completion_tokens}')
f'[usage] prompt_tokens: {prompt_tokens}, completion_tokens: {completion_tokens}'
)
# total tokens for the process so far
self.log_output(
f'[usage_total] total_prompt_tokens: {TOTAL_PROMPT_TOKENS}, '
f'total_completion_tokens: {TOTAL_COMPLETION_TOKENS}')

yield messages

def prepare_llm(self):
Expand Down Expand Up @@ -545,13 +569,13 @@ def _get_run_memory_info(self, memory_config: DictConfig):
memory_type = memory_type or None
return user_id, agent_id, run_id, memory_type

async def add_memory(self, messages: List[Message], **kwargs):
async def add_memory(self, messages: List[Message], add_type, **kwargs):
if hasattr(self.config, 'memory') and self.config.memory:
tools_num = len(
self.memory_tools
) if self.memory_tools else 0 # Check index bounds before access to avoid IndexError
for idx, memory_config in enumerate(self.config.memory):
if self.runtime.should_stop:
if add_type == 'add_after_task':
user_id, agent_id, run_id, memory_type = self._get_run_memory_info(
memory_config)
else:
Expand Down Expand Up @@ -631,7 +655,8 @@ async def run_loop(self, messages: Union[List[Message], str],
yield messages
self.runtime.round += 1
# save memory and history
await self.add_memory(messages, **kwargs)
await self.add_memory(
messages, add_type='add_after_step', **kwargs)
self.save_history(messages)

# +1 means the next round the assistant may give a conclusion
Expand All @@ -647,11 +672,17 @@ async def run_loop(self, messages: Union[List[Message], str],
yield messages

# save memory
await self.add_memory(messages, **kwargs)

await self.on_task_end(messages)
await self.cleanup_tools()
yield messages

def _add_memory():
asyncio.run(
self.add_memory(
messages, add_type='add_after_task', **kwargs))

loop = asyncio.get_running_loop()
loop.run_in_executor(None, _add_memory)
Comment on lines +679 to +685

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

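The current implementation for running add_memory in the background is overly complex and potentially problematic. Using loop.run_in_executor to call a function that itself uses asyncio.run() creates a second event loop inside a worker thread. asyncio objects that were created on the main loop (for example an asyncio.Lock held by the memory object) are not safe to use from that second loop, and because the future returned by run_in_executor is never awaited, a failure inside add_memory can go unnoticed.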

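For running a "fire-and-forget" async task from within an async function, asyncio.create_task() is the standard and much simpler approach. It schedules the coroutine to run on the current event loop without blocking. Keep a reference to the returned task (for example in a set, removing it in a done-callback): the event loop only holds a weak reference to tasks, so an unreferenced task can be garbage-collected before it finishes.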

Suggested change
def _add_memory():
asyncio.run(
self.add_memory(
messages, add_type='add_after_task', **kwargs))
loop = asyncio.get_running_loop()
loop.run_in_executor(None, _add_memory)
# Schedule add_memory to run in the background without blocking.
asyncio.create_task(
self.add_memory(messages, add_type='add_after_task', **kwargs))

except Exception as e:
import traceback
logger.warning(traceback.format_exc())
Expand Down
4 changes: 4 additions & 0 deletions ms_agent/llm/openai_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,10 @@ def _call_llm(self,
"""
messages = self._format_input_message(messages)

if kwargs.get('stream', False) and self.args.get(
'stream_options', {}).get('include_usage', True):
kwargs.setdefault('stream_options', {})['include_usage'] = True

return self.client.chat.completions.create(
model=self.model, messages=messages, tools=tools, **kwargs)

Expand Down
75 changes: 46 additions & 29 deletions ms_agent/memory/default_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,7 +314,6 @@ def search(self,
agent_id = meta_info.get('agent_id', None)
run_id = meta_info.get('run_id', None)
limit = meta_info.get('limit', self.search_limit)

relevant_memories = self.memory.search(
query,
user_id=user_id or self.user_id,
Expand Down Expand Up @@ -391,21 +390,36 @@ def _analyze_messages(
"""
new_blocks = self._split_into_blocks(messages)
self.cache_messages = dict(sorted(self.cache_messages.items()))

cache_messages = [(key, value)
for key, value in self.cache_messages.items()]

first_unmatched_idx = -1

for idx in range(len(new_blocks)):
block_hash = self._hash_block(new_blocks[idx])
if idx < len(cache_messages) - 1 and str(block_hash) == str(

# Must allow comparison up to the last cache entry
if idx < len(cache_messages) and str(block_hash) == str(
cache_messages[idx][1][1]):
continue

# mismatch
first_unmatched_idx = idx
break

# If all new_blocks match but the cache has extra entries → delete the extra cache entries
if first_unmatched_idx == -1:
should_add_messages = []
should_delete = [
item[0] for item in cache_messages[len(new_blocks):]
]
return should_add_messages, should_delete

# On mismatch: add all new blocks and delete all cache entries starting from the mismatch index
should_add_messages = new_blocks[first_unmatched_idx:]
should_delete = [
item[0] for item in cache_messages[first_unmatched_idx:]
] if first_unmatched_idx != -1 else []
should_add_messages = new_blocks[first_unmatched_idx:]
]

return should_add_messages, should_delete

Expand Down Expand Up @@ -440,6 +454,7 @@ async def add(
logger.info(item[1])
if should_add_messages:
for messages in should_add_messages:
messages = self.parse_messages(messages)
await self.add_single(
messages,
user_id=user_id,
Expand All @@ -448,6 +463,23 @@ async def add(
memory_type=memory_type)
self.save_cache()

# Filter one block of messages down to the ones that should be stored in memory.
# Builds and returns a new list; the input list is not modified.
def parse_messages(self, messages: List[Message]) -> List[Message]:
new_messages = []
for msg in messages:
# getattr with a None default tolerates objects that lack these attributes
role = getattr(msg, 'role', None)
content = getattr(msg, 'content', None)

# system messages are kept unless 'system' is listed in self.ignore_roles
if 'system' not in self.ignore_roles and role == 'system':
new_messages.append(msg)
# user messages are always kept; self.ignore_roles is not consulted for them
if role == 'user':
new_messages.append(msg)
# assistant messages are kept only when not ignored and their content is not None
if 'assistant' not in self.ignore_roles and role == 'assistant' and content is not None:
new_messages.append(msg)
# tool messages are kept unless 'tool' is listed in self.ignore_roles
if 'tool' not in self.ignore_roles and role == 'tool':
new_messages.append(msg)

# a message with any other role matches none of the checks above and is dropped
return new_messages
Comment on lines +466 to +481

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The parse_messages method contains a series of if statements that can be simplified to improve readability and reduce repetition. Consolidating the logic for checking roles against self.ignore_roles would make the function's intent clearer and easier to maintain.

    def parse_messages(self, messages: List[Message]) -> List[Message]:
        """Return the messages from ``messages`` that should be stored in memory.

        Selection rules, applied per message in input order:

        * Only the roles ``system``, ``user``, ``assistant`` and ``tool`` are
          ever stored; a message with any other role (or no ``role``
          attribute at all) is dropped.
        * ``user`` messages are always kept, even if ``'user'`` appears in
          ``self.ignore_roles``.
        * ``system``, ``assistant`` and ``tool`` messages are dropped when
          their role is listed in ``self.ignore_roles``.
        * ``assistant`` messages whose ``content`` is ``None`` are dropped.

        Args:
            messages: The messages of one block to filter.

        Returns:
            A new list holding the retained message objects; the input list
            is left untouched.
        """
        storable_roles = ('system', 'user', 'assistant', 'tool')
        new_messages = []
        for msg in messages:
            role = getattr(msg, 'role', None)

            # Unknown or missing roles must not fall through to the append
            # at the bottom of the loop.
            if role not in storable_roles:
                continue

            # User messages bypass the ignore list on purpose.
            if role != 'user' and role in self.ignore_roles:
                continue

            if role == 'assistant' and getattr(msg, 'content', None) is None:
                continue

            new_messages.append(msg)

        return new_messages


def delete(self,
user_id: Optional[str] = None,
agent_id: Optional[str] = None,
Expand Down Expand Up @@ -484,7 +516,6 @@ def get_all(self,
user_id=user_id or self.user_id,
agent_id=agent_id,
run_id=run_id)
print(res['results'])
return res['results']
except Exception:
return []
Expand Down Expand Up @@ -538,6 +569,8 @@ async def run(
query = self._get_latest_user_message(messages)
if not query:
return messages
if meta_infos is None:
meta_infos = [{'user_id': self.user_id}]
async with self._lock:
try:
memories = self.search(query, meta_infos)
Expand All @@ -558,34 +591,14 @@ def _init_memory_obj(self):
)
raise

parse_messages_origin = mem0.memory.main.parse_messages
capture_event_origin = mem0.memory.main.capture_event

@wraps(parse_messages_origin)
def patched_parse_messages(messages, ignore_roles):
response = ''
for msg in messages:
if 'system' not in ignore_roles and msg['role'] == 'system':
response += f"system: {msg['content']}\n"
if msg['role'] == 'user':
response += f"user: {msg['content']}\n"
if msg['role'] == 'assistant' and msg['content'] is not None:
response += f"assistant: {msg['content']}\n"
if 'tool' not in ignore_roles and msg['role'] == 'tool':
response += f"tool: {msg['content']}\n"
return response

@wraps(capture_event_origin)
def patched_capture_event(event_name,
memory_instance,
additional_data=None):
pass

mem0.memory.main.parse_messages = partial(
patched_parse_messages,
ignore_roles=self.ignore_roles,
)

mem0.memory.main.capture_event = partial(patched_capture_event, )

# emb config
Expand Down Expand Up @@ -705,9 +718,13 @@ def sanitize_database_name(ori_name: str,
mem0_config['llm'] = llm
logger.info(f'Memory config: {mem0_config}')
# Prompt content is too long, default logging reduces readability
mem0_config['custom_fact_extraction_prompt'] = getattr(
self.config, 'fact_retrieval_prompt', get_fact_retrieval_prompt()
) + f'Today\'s date is {datetime.now().strftime("%Y-%m-%d")}.'
custom_fact_extraction_prompt = getattr(
self.config, 'fact_retrieval_prompt',
getattr(self.config, 'custom_fact_extraction_prompt', None))
if custom_fact_extraction_prompt is not None:
mem0_config['custom_fact_extraction_prompt'] = (
custom_fact_extraction_prompt
+ f'Today\'s date is {datetime.now().strftime("%Y-%m-%d")}.')
try:
memory = mem0.Memory.from_config(mem0_config)
memory._telemetry_vector_store = None
Expand Down
2 changes: 1 addition & 1 deletion ms_agent/tools/image_generator/ds_image_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import uuid
from io import BytesIO

import aiohttp
from PIL import Image


Expand All @@ -19,6 +18,7 @@ async def generate_image(self,
size=None,
ratio=None,
**kwargs):
import aiohttp
image_generator = self.config.tools.image_generator
base_url = (
getattr(image_generator, 'base_url', None)
Expand Down
2 changes: 1 addition & 1 deletion ms_agent/tools/image_generator/ms_image_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import uuid
from io import BytesIO

import aiohttp
import json
from PIL import Image

Expand All @@ -20,6 +19,7 @@ async def generate_image(self,
negative_prompt=None,
size=None,
**kwargs):
import aiohttp
image_generator = self.config.tools.image_generator
base_url = (getattr(image_generator, 'base_url', None)
or 'https://api-inference.modelscope.cn').strip('/')
Expand Down
3 changes: 2 additions & 1 deletion ms_agent/tools/video_generator/ds_video_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import os
import uuid

import aiohttp
from ms_agent.utils import get_logger

logger = get_logger()
Expand Down Expand Up @@ -34,6 +33,7 @@ async def generate_video(self,

@staticmethod
async def download_video(video_url, output_file):
import aiohttp
max_retries = 3
retry_count = 0

Expand All @@ -57,6 +57,7 @@ async def download_video(video_url, output_file):

@staticmethod
async def _generate_video(base_url, api_key, model, prompt, size, seconds):
import aiohttp
base_url = base_url.strip('/')
create_endpoint = '/api/v1/services/aigc/model-evaluation/async-inference/'

Expand Down
Loading