26 changes: 24 additions & 2 deletions ms_agent/agent/llm_agent.py
@@ -26,6 +26,11 @@
 from ..config.config import Config, ConfigLifecycleHandler
 from .base import Agent
 
+# Shared across the current process
+TOTAL_PROMPT_TOKENS = 0
+TOTAL_COMPLETION_TOKENS = 0
+TOKEN_LOCK = asyncio.Lock()
Comment on lines +29 to +32


Severity: medium

Using global variables for process-wide state can lead to code that is hard to test and maintain. It's better to encapsulate this state within the LLMAgent class itself as class attributes. This clearly associates the state with the agent and avoids polluting the global namespace.

For example, you could define them inside LLMAgent like this:

class LLMAgent(Agent):
    TOTAL_PROMPT_TOKENS = 0
    TOTAL_COMPLETION_TOKENS = 0
    TOKEN_LOCK = asyncio.Lock()
    ...
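
As a concrete example of the testability benefit, class-level counters are easy to reset and assert in a test. A minimal sketch (hypothetical test code, not part of this PR; assumes the class attributes above and Python 3.10+, where asyncio.Lock binds its event loop lazily):

import asyncio

def test_token_accumulation():
    # Reset the shared class-level counters before the test.
    LLMAgent.TOTAL_PROMPT_TOKENS = 0
    LLMAgent.TOTAL_COMPLETION_TOKENS = 0

    async def bump(prompt: int, completion: int):
        # Same update pattern step() would use.
        async with LLMAgent.TOKEN_LOCK:
            LLMAgent.TOTAL_PROMPT_TOKENS += prompt
            LLMAgent.TOTAL_COMPLETION_TOKENS += completion

    async def main():
        await asyncio.gather(bump(10, 5), bump(7, 3))

    asyncio.run(main())
    assert LLMAgent.TOTAL_PROMPT_TOKENS == 17
    assert LLMAgent.TOTAL_COMPLETION_TOKENS == 8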



 class LLMAgent(Agent):
     """
@@ -467,9 +472,26 @@ async def step(
         messages = await self.parallel_tool_call(messages)
 
         await self.after_tool_call(messages)
+
+        # usage
+        prompt_tokens = _response_message.prompt_tokens
+        completion_tokens = _response_message.completion_tokens
+
+        global TOTAL_PROMPT_TOKENS, TOTAL_COMPLETION_TOKENS, TOKEN_LOCK
+        async with TOKEN_LOCK:
+            TOTAL_PROMPT_TOKENS += prompt_tokens
+            TOTAL_COMPLETION_TOKENS += completion_tokens
+
+        # tokens in the current step
+        self.log_output(
+            f'[usage] prompt_tokens: {prompt_tokens}, completion_tokens: {completion_tokens}'
+        )
+        # total tokens for the process so far
         self.log_output(
-            f'[usage] prompt_tokens: {_response_message.prompt_tokens}, '
-            f'completion_tokens: {_response_message.completion_tokens}')
+            f'[usage_total] total_prompt_tokens: {TOTAL_PROMPT_TOKENS}, '
+            f'total_completion_tokens: {TOTAL_COMPLETION_TOKENS}')
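
With this change, each step would emit log lines of the following shape (the token counts here are illustrative):

    [usage] prompt_tokens: 1523, completion_tokens: 210
    [usage_total] total_prompt_tokens: 8410, total_completion_tokens: 1337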

Comment on lines +478 to +494


Severity: medium

Following my earlier suggestion to make the tracking variables class attributes of LLMAgent, this block should access them via the class. This avoids the global statement and makes it clear that you're modifying shared state on the LLMAgent class.

        prompt_tokens = _response_message.prompt_tokens
        completion_tokens = _response_message.completion_tokens

        # Accumulate into the shared class-level totals
        async with LLMAgent.TOKEN_LOCK:
            LLMAgent.TOTAL_PROMPT_TOKENS += prompt_tokens
            LLMAgent.TOTAL_COMPLETION_TOKENS += completion_tokens

        # tokens in the current step
        self.log_output(
            f'[usage] prompt_tokens: {prompt_tokens}, completion_tokens: {completion_tokens}'
        )
        # total tokens for the process so far
        self.log_output(
            f'[usage_total] total_prompt_tokens: {LLMAgent.TOTAL_PROMPT_TOKENS}, '
            f'total_completion_tokens: {LLMAgent.TOTAL_COMPLETION_TOKENS}')

         yield messages

     def prepare_llm(self):
4 changes: 4 additions & 0 deletions ms_agent/llm/openai_llm.py
@@ -132,6 +132,10 @@ def _call_llm(self,
"""
messages = self._format_input_message(messages)

if kwargs.get('stream', False) and self.args.get(
'stream_options', {}).get('include_usage', True):
kwargs.setdefault('stream_options', {})['include_usage'] = True
Comment on lines +135 to +137


Severity: medium

This conditional logic is a bit dense and could be hard to parse. For better readability, you could break it down into a few lines with intermediate variables and a comment explaining the intent.

Suggested change
-        if kwargs.get('stream', False) and self.args.get(
-                'stream_options', {}).get('include_usage', True):
-            kwargs.setdefault('stream_options', {})['include_usage'] = True
+        is_streaming = kwargs.get('stream', False)
+        stream_options_config = self.args.get('stream_options', {})
+        # For streaming responses, we should request usage statistics by default,
+        # unless it's explicitly disabled in the configuration.
+        if is_streaming and stream_options_config.get('include_usage', True):
+            kwargs.setdefault('stream_options', {})['include_usage'] = True


         return self.client.chat.completions.create(
             model=self.model, messages=messages, tools=tools, **kwargs)
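
For context on why include_usage matters here: when stream_options={'include_usage': True} is set, OpenAI-compatible endpoints send the token counts on the final streamed chunk, whose choices list is empty. A minimal consumption sketch (illustrative code, not part of this PR; the model name is a placeholder):

from openai import OpenAI

client = OpenAI()
stream = client.chat.completions.create(
    model='gpt-4o-mini',
    messages=[{'role': 'user', 'content': 'Hello'}],
    stream=True,
    stream_options={'include_usage': True},
)

usage = None
for chunk in stream:
    if chunk.choices:
        print(chunk.choices[0].delta.content or '', end='')
    if chunk.usage is not None:  # only populated on the final chunk
        usage = chunk.usage

if usage is not None:
    print(f'\nprompt_tokens={usage.prompt_tokens}, '
          f'completion_tokens={usage.completion_tokens}')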
