diff --git a/README.md b/README.md index bee1ada..07761bd 100644 --- a/README.md +++ b/README.md @@ -48,6 +48,7 @@ AI_API_TOKEN= # MCP configs GH_TOKEN= CODEQL_DBS_BASE_PATH="/app/my_data/codeql_databases" +AI_API_ENDPOINT="https://models.github.ai/inference" ``` ## Deploying from Source diff --git a/src/seclab_taskflow_agent/agent.py b/src/seclab_taskflow_agent/agent.py index 6c26b0b..06097c1 100644 --- a/src/seclab_taskflow_agent/agent.py +++ b/src/seclab_taskflow_agent/agent.py @@ -27,7 +27,10 @@ case AI_API_ENDPOINT_ENUM.AI_API_MODELS_GITHUB: default_model = 'openai/gpt-4o' case _: - raise ValueError(f"Unsupported Model Endpoint: {api_endpoint}") + raise ValueError( + f"Unsupported Model Endpoint: {api_endpoint}\n" + f"Supported endpoints: {[e.to_url() for e in AI_API_ENDPOINT_ENUM]}" + ) DEFAULT_MODEL = os.getenv('COPILOT_DEFAULT_MODEL', default=default_model) diff --git a/src/seclab_taskflow_agent/capi.py b/src/seclab_taskflow_agent/capi.py index 54744d4..478ea56 100644 --- a/src/seclab_taskflow_agent/capi.py +++ b/src/seclab_taskflow_agent/capi.py @@ -11,8 +11,20 @@ # Enumeration of currently supported API endpoints. class AI_API_ENDPOINT_ENUM(StrEnum): - AI_API_MODELS_GITHUB = 'models.github.ai' - AI_API_GITHUBCOPILOT = 'api.githubcopilot.com' + AI_API_MODELS_GITHUB = 'models.github.ai' + AI_API_GITHUBCOPILOT = 'api.githubcopilot.com' + + def to_url(self): + """ + Convert the endpoint to its full URL. + """ + match self: + case AI_API_ENDPOINT_ENUM.AI_API_GITHUBCOPILOT: + return f"https://{self}" + case AI_API_ENDPOINT_ENUM.AI_API_MODELS_GITHUB: + return f"https://{self}/inference" + case _: + raise ValueError(f"Unsupported endpoint: {self}") COPILOT_INTEGRATION_ID = 'vscode-chat' @@ -21,7 +33,7 @@ class AI_API_ENDPOINT_ENUM(StrEnum): # since different APIs use their own id schema, use -l with your desired # endpoint to retrieve the correct id names to use for your taskflow def get_AI_endpoint(): - return os.getenv('AI_API_ENDPOINT', default='https://models.github.ai/inference') + return os.getenv('AI_API_ENDPOINT', default='https://models.github.ai/inference') def get_AI_token(): """ @@ -50,7 +62,8 @@ def list_capi_models(token: str) -> dict[str, dict]: case AI_API_ENDPOINT_ENUM.AI_API_MODELS_GITHUB: models_catalog = 'catalog/models' case _: - raise ValueError(f"Unsupported Model Endpoint: {api_endpoint}") + raise ValueError(f"Unsupported Model Endpoint: {api_endpoint}\n" + f"Supported endpoints: {[e.to_url() for e in AI_API_ENDPOINT_ENUM]}") r = httpx.get(httpx.URL(api_endpoint).join(models_catalog), headers={ 'Accept': 'application/json', @@ -64,8 +77,6 @@ def list_capi_models(token: str) -> dict[str, dict]: models_list = r.json().get('data', []) case AI_API_ENDPOINT_ENUM.AI_API_MODELS_GITHUB: models_list = r.json() - case _: - raise ValueError(f"Unsupported Model Endpoint: {api_endpoint}") for model in models_list: models[model.get('id')] = dict(model) except httpx.RequestError as e: @@ -88,7 +99,10 @@ def supports_tool_calls(model: str, models: dict) -> bool: return 'tool-calling' in models.get(model, {}).\ get('capabilities', []) case _: - raise ValueError(f"Unsupported Model Endpoint: {api_endpoint}") + raise ValueError( + f"Unsupported Model Endpoint: {api_endpoint}\n" + f"Supported endpoints: {[e.to_url() for e in AI_API_ENDPOINT_ENUM]}" + ) def list_tool_call_models(token: str) -> dict[str, dict]: models = list_capi_models(token) diff --git a/tests/test_api_endpoint_config.py b/tests/test_api_endpoint_config.py index 654b44e..96e0216 100644 --- a/tests/test_api_endpoint_config.py +++ b/tests/test_api_endpoint_config.py @@ -8,7 +8,7 @@ import pytest import os from urllib.parse import urlparse -from seclab_taskflow_agent.capi import get_AI_endpoint, AI_API_ENDPOINT_ENUM +from seclab_taskflow_agent.capi import get_AI_endpoint, AI_API_ENDPOINT_ENUM, list_capi_models class TestAPIEndpoint: """Test API endpoint configuration.""" @@ -43,5 +43,26 @@ def test_api_endpoint_env_override(self): if original_env: os.environ['AI_API_ENDPOINT'] = original_env + def test_to_url_models_github(self): + """Test to_url method for models.github.ai endpoint.""" + endpoint = AI_API_ENDPOINT_ENUM.AI_API_MODELS_GITHUB + assert endpoint.to_url() == 'https://models.github.ai/inference' + + def test_to_url_githubcopilot(self): + """Test to_url method for GitHub Copilot endpoint.""" + endpoint = AI_API_ENDPOINT_ENUM.AI_API_GITHUBCOPILOT + assert endpoint.to_url() == 'https://api.githubcopilot.com' + + def test_unsupported_endpoint(self, monkeypatch): + """Test that unsupported API endpoint raises ValueError.""" + api_endpoint = 'https://unsupported.example.com' + monkeypatch.setenv('AI_API_ENDPOINT', api_endpoint) + with pytest.raises(ValueError) as excinfo: + list_capi_models("abc") + msg = str(excinfo.value) + assert 'Unsupported Model Endpoint' in msg + assert 'https://models.github.ai/inference' in msg + assert 'https://api.githubcopilot.com' in msg + if __name__ == '__main__': pytest.main([__file__, '-v'])