diff --git a/src/seclab_taskflow_agent/agent.py b/src/seclab_taskflow_agent/agent.py index 06097c1..c8fdeaa 100644 --- a/src/seclab_taskflow_agent/agent.py +++ b/src/seclab_taskflow_agent/agent.py @@ -26,6 +26,8 @@ default_model = 'gpt-4o' case AI_API_ENDPOINT_ENUM.AI_API_MODELS_GITHUB: default_model = 'openai/gpt-4o' + case AI_API_ENDPOINT_ENUM.AI_API_OPENAI: + default_model = 'gpt-4o' case _: raise ValueError( f"Unsupported Model Endpoint: {api_endpoint}\n" diff --git a/src/seclab_taskflow_agent/capi.py b/src/seclab_taskflow_agent/capi.py index 478ea56..a4a308f 100644 --- a/src/seclab_taskflow_agent/capi.py +++ b/src/seclab_taskflow_agent/capi.py @@ -13,6 +13,7 @@ class AI_API_ENDPOINT_ENUM(StrEnum): AI_API_MODELS_GITHUB = 'models.github.ai' AI_API_GITHUBCOPILOT = 'api.githubcopilot.com' + AI_API_OPENAI = 'api.openai.com' def to_url(self): """ @@ -23,6 +24,8 @@ def to_url(self): return f"https://{self}" case AI_API_ENDPOINT_ENUM.AI_API_MODELS_GITHUB: return f"https://{self}/inference" + case AI_API_ENDPOINT_ENUM.AI_API_OPENAI: + return f"https://{self}/v1" case _: raise ValueError(f"Unsupported endpoint: {self}") @@ -61,6 +64,8 @@ def list_capi_models(token: str) -> dict[str, dict]: models_catalog = 'models' case AI_API_ENDPOINT_ENUM.AI_API_MODELS_GITHUB: models_catalog = 'catalog/models' + case AI_API_ENDPOINT_ENUM.AI_API_OPENAI: + models_catalog = 'models' case _: raise ValueError(f"Unsupported Model Endpoint: {api_endpoint}\n" f"Supported endpoints: {[e.to_url() for e in AI_API_ENDPOINT_ENUM]}") @@ -77,6 +82,8 @@ def list_capi_models(token: str) -> dict[str, dict]: models_list = r.json().get('data', []) case AI_API_ENDPOINT_ENUM.AI_API_MODELS_GITHUB: models_list = r.json() + case AI_API_ENDPOINT_ENUM.AI_API_OPENAI: + models_list = r.json().get('data', []) for model in models_list: models[model.get('id')] = dict(model) except httpx.RequestError as e: @@ -98,6 +105,13 @@ def supports_tool_calls(model: str, models: dict) -> bool: case AI_API_ENDPOINT_ENUM.AI_API_MODELS_GITHUB: return 'tool-calling' in models.get(model, {}).\ get('capabilities', []) + case AI_API_ENDPOINT_ENUM.AI_API_OPENAI: + # OpenAI doesn't expose capabilities in the models list + # Check if model name indicates function calling support + model_lower = model.lower() + return any([ + 'gpt-' in model_lower, + ]) case _: raise ValueError( f"Unsupported Model Endpoint: {api_endpoint}\n" diff --git a/tests/test_api_endpoint_config.py b/tests/test_api_endpoint_config.py index 96e0216..d415dd3 100644 --- a/tests/test_api_endpoint_config.py +++ b/tests/test_api_endpoint_config.py @@ -53,6 +53,11 @@ def test_to_url_githubcopilot(self): endpoint = AI_API_ENDPOINT_ENUM.AI_API_GITHUBCOPILOT assert endpoint.to_url() == 'https://api.githubcopilot.com' + def test_to_url_openai(self): + """Test to_url method for OpenAI endpoint.""" + endpoint = AI_API_ENDPOINT_ENUM.AI_API_OPENAI + assert endpoint.to_url() == 'https://api.openai.com/v1' + def test_unsupported_endpoint(self, monkeypatch): """Test that unsupported API endpoint raises ValueError.""" api_endpoint = 'https://unsupported.example.com'