Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/seclab_taskflow_agent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@
default_model = 'gpt-4o'
case AI_API_ENDPOINT_ENUM.AI_API_MODELS_GITHUB:
default_model = 'openai/gpt-4o'
case AI_API_ENDPOINT_ENUM.AI_API_OPENAI:
default_model = 'gpt-4o'
case _:
raise ValueError(
f"Unsupported Model Endpoint: {api_endpoint}\n"
Expand Down
14 changes: 14 additions & 0 deletions src/seclab_taskflow_agent/capi.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
class AI_API_ENDPOINT_ENUM(StrEnum):
AI_API_MODELS_GITHUB = 'models.github.ai'
AI_API_GITHUBCOPILOT = 'api.githubcopilot.com'
AI_API_OPENAI = 'api.openai.com'

def to_url(self):
"""
Expand All @@ -23,6 +24,8 @@ def to_url(self):
return f"https://{self}"
case AI_API_ENDPOINT_ENUM.AI_API_MODELS_GITHUB:
return f"https://{self}/inference"
case AI_API_ENDPOINT_ENUM.AI_API_OPENAI:
return f"https://{self}/v1"
case _:
raise ValueError(f"Unsupported endpoint: {self}")

Expand Down Expand Up @@ -61,6 +64,8 @@ def list_capi_models(token: str) -> dict[str, dict]:
models_catalog = 'models'
case AI_API_ENDPOINT_ENUM.AI_API_MODELS_GITHUB:
models_catalog = 'catalog/models'
case AI_API_ENDPOINT_ENUM.AI_API_OPENAI:
models_catalog = 'models'
case _:
raise ValueError(f"Unsupported Model Endpoint: {api_endpoint}\n"
f"Supported endpoints: {[e.to_url() for e in AI_API_ENDPOINT_ENUM]}")
Expand All @@ -77,6 +82,8 @@ def list_capi_models(token: str) -> dict[str, dict]:
models_list = r.json().get('data', [])
case AI_API_ENDPOINT_ENUM.AI_API_MODELS_GITHUB:
models_list = r.json()
case AI_API_ENDPOINT_ENUM.AI_API_OPENAI:
models_list = r.json().get('data', [])
for model in models_list:
models[model.get('id')] = dict(model)
except httpx.RequestError as e:
Expand All @@ -98,6 +105,13 @@ def supports_tool_calls(model: str, models: dict) -> bool:
case AI_API_ENDPOINT_ENUM.AI_API_MODELS_GITHUB:
return 'tool-calling' in models.get(model, {}).\
get('capabilities', [])
case AI_API_ENDPOINT_ENUM.AI_API_OPENAI:
# OpenAI doesn't expose capabilities in the models list
# Check if model name indicates function calling support
model_lower = model.lower()
return any([
'gpt-' in model_lower,
])
case _:
raise ValueError(
f"Unsupported Model Endpoint: {api_endpoint}\n"
Expand Down
5 changes: 5 additions & 0 deletions tests/test_api_endpoint_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,11 @@ def test_to_url_githubcopilot(self):
endpoint = AI_API_ENDPOINT_ENUM.AI_API_GITHUBCOPILOT
assert endpoint.to_url() == 'https://api.githubcopilot.com'

def test_to_url_openai(self):
"""Test to_url method for OpenAI endpoint."""
endpoint = AI_API_ENDPOINT_ENUM.AI_API_OPENAI
assert endpoint.to_url() == 'https://api.openai.com/v1'

def test_unsupported_endpoint(self, monkeypatch):
"""Test that unsupported API endpoint raises ValueError."""
api_endpoint = 'https://unsupported.example.com'
Expand Down