diff --git a/README.md b/README.md index 424f3084..d7a8d444 100644 --- a/README.md +++ b/README.md @@ -292,6 +292,27 @@ vim /etc/kubernetes-mcp-server/conf.d/99-local.toml pkill -HUP kubernetes-mcp-server ``` +### MCP Prompts + +The server supports MCP prompts for workflow templates. Define custom prompts in `config.toml`: + +```toml +[[prompts]] +name = "my-workflow" +title = "my workflow" +description = "Custom workflow" + +[[prompts.arguments]] +name = "resource_name" +required = true + +[[prompts.messages]] +role = "user" +content = "Help me with {{resource_name}}" +``` + +See docs/PROMPTS.md for detailed documentation. + ## 🛠️ Tools and Functionalities The Kubernetes MCP server supports enabling or disabling specific groups of tools and functionalities (tools, resources, prompts, and so on) via the `--toolsets` command-line flag or `toolsets` configuration option. diff --git a/docs/PROMPTS.md b/docs/PROMPTS.md new file mode 100644 index 00000000..fe6f6d65 --- /dev/null +++ b/docs/PROMPTS.md @@ -0,0 +1,62 @@ +# MCP Prompts Support + +The Kubernetes MCP Server supports [MCP Prompts](https://modelcontextprotocol.io/docs/concepts/prompts), which provide pre-defined workflow templates and guidance to AI assistants. + +## What are MCP Prompts? + +MCP Prompts are pre-defined templates that guide AI assistants through specific workflows. They combine: +- **Structured guidance**: Step-by-step instructions for common tasks +- **Parameterization**: Arguments that customize the prompt for specific contexts +- **Conversation templates**: Pre-formatted messages that guide the interaction + +## Creating Custom Prompts + +Define custom prompts in your `config.toml` file - no code changes or recompilation needed! + +### Example + +```toml +[[prompts]] +name = "check-pod-logs" +title = "Check Pod Logs" +description = "Quick way to check pod logs" + +[[prompts.arguments]] +name = "pod_name" +description = "Name of the pod" +required = true + +[[prompts.arguments]] +name = "namespace" +description = "Namespace of the pod" +required = false + +[[prompts.messages]] +role = "user" +content = "Show me the logs for pod {{pod_name}} in {{namespace}}" + +[[prompts.messages]] +role = "assistant" +content = "I'll retrieve and analyze the logs for you." +``` + +## Configuration Reference + +### Prompt Fields +- **name** (required): Unique identifier for the prompt +- **title** (optional): Human-readable display name +- **description** (required): Brief explanation of what the prompt does +- **arguments** (optional): List of parameters the prompt accepts +- **messages** (required): Conversation template with role/content pairs + +### Argument Fields +- **name** (required): Argument identifier +- **description** (optional): Explanation of the argument's purpose +- **required** (optional): Whether the argument must be provided (default: false) + +### Argument Substitution +Use `{{argument_name}}` placeholders in message content. The template engine replaces these with actual values when the prompt is called. + +## Configuration File Location + +Place your prompts in the `config.toml` file used by the MCP server. Specify the config file path using the `--config` flag when starting the server. \ No newline at end of file diff --git a/go.mod b/go.mod index 6f4aa633..8e443b36 100644 --- a/go.mod +++ b/go.mod @@ -16,6 +16,7 @@ require ( github.com/stretchr/testify v1.11.1 golang.org/x/oauth2 v0.34.0 golang.org/x/sync v0.19.0 + gopkg.in/yaml.v3 v3.0.1 helm.sh/helm/v3 v3.19.3 k8s.io/api v0.34.3 k8s.io/apiextensions-apiserver v0.34.3 @@ -133,7 +134,6 @@ require ( google.golang.org/protobuf v1.36.6 // indirect gopkg.in/evanphx/json-patch.v4 v4.12.0 // indirect gopkg.in/inf.v0 v0.9.1 // indirect - gopkg.in/yaml.v3 v3.0.1 // indirect k8s.io/apiserver v0.34.3 // indirect k8s.io/component-base v0.34.3 // indirect k8s.io/kube-openapi v0.0.0-20250710124328-f3f2b991d03b // indirect diff --git a/pkg/api/prompt_serialization_test.go b/pkg/api/prompt_serialization_test.go new file mode 100644 index 00000000..2ea9cc77 --- /dev/null +++ b/pkg/api/prompt_serialization_test.go @@ -0,0 +1,315 @@ +package api + +import ( + "encoding/json" + "testing" + + "github.com/BurntSushi/toml" + "github.com/stretchr/testify/suite" + "gopkg.in/yaml.v3" +) + +// PromptSerializationSuite tests serialization of prompt data structures +type PromptSerializationSuite struct { + suite.Suite +} + +func (s *PromptSerializationSuite) TestPromptJSONSerialization() { + s.Run("marshals and unmarshals Prompt correctly", func() { + original := Prompt{ + Name: "test-prompt", + Title: "Test Prompt", + Description: "A test prompt", + Arguments: []PromptArgument{ + {Name: "arg1", Description: "First argument", Required: true}, + }, + Templates: []PromptTemplate{ + {Role: "user", Content: "Hello {{arg1}}"}, + }, + } + + data, err := json.Marshal(original) + s.Require().NoError(err, "failed to marshal Prompt to JSON") + + var unmarshaled Prompt + err = json.Unmarshal(data, &unmarshaled) + s.Require().NoError(err, "failed to unmarshal Prompt from JSON") + + s.Equal(original.Name, unmarshaled.Name) + s.Equal(original.Title, unmarshaled.Title) + s.Equal(original.Description, unmarshaled.Description) + s.Require().Len(unmarshaled.Arguments, 1) + s.Equal(original.Arguments[0].Name, unmarshaled.Arguments[0].Name) + s.Require().Len(unmarshaled.Templates, 1) + s.Equal(original.Templates[0].Content, unmarshaled.Templates[0].Content) + }) +} + +func (s *PromptSerializationSuite) TestPromptYAMLSerialization() { + s.Run("marshals and unmarshals Prompt correctly", func() { + original := Prompt{ + Name: "test-prompt", + Title: "Test Prompt", + Description: "A test prompt", + Arguments: []PromptArgument{ + {Name: "arg1", Description: "First argument", Required: true}, + }, + Templates: []PromptTemplate{ + {Role: "user", Content: "Hello {{arg1}}"}, + }, + } + + data, err := yaml.Marshal(original) + s.Require().NoError(err, "failed to marshal Prompt to YAML") + + var unmarshaled Prompt + err = yaml.Unmarshal(data, &unmarshaled) + s.Require().NoError(err, "failed to unmarshal Prompt from YAML") + + s.Equal(original.Name, unmarshaled.Name) + s.Equal(original.Title, unmarshaled.Title) + s.Equal(original.Description, unmarshaled.Description) + }) +} + +func (s *PromptSerializationSuite) TestPromptTOMLSerialization() { + s.Run("unmarshals Prompt from TOML correctly", func() { + tomlData := ` +name = "test-prompt" +title = "Test Prompt" +description = "A test prompt" + +[[arguments]] +name = "arg1" +description = "First argument" +required = true + +[[messages]] +role = "user" +content = "Hello {{arg1}}" +` + + var prompt Prompt + err := toml.Unmarshal([]byte(tomlData), &prompt) + s.Require().NoError(err, "failed to unmarshal Prompt from TOML") + + s.Equal("test-prompt", prompt.Name) + s.Equal("Test Prompt", prompt.Title) + s.Equal("A test prompt", prompt.Description) + s.Require().Len(prompt.Arguments, 1) + s.Equal("arg1", prompt.Arguments[0].Name) + s.Equal("First argument", prompt.Arguments[0].Description) + s.True(prompt.Arguments[0].Required) + s.Require().Len(prompt.Templates, 1) + s.Equal("user", prompt.Templates[0].Role) + s.Equal("Hello {{arg1}}", prompt.Templates[0].Content) + }) + + s.Run("unmarshals multiple prompts from TOML array", func() { + tomlData := ` +[[prompts]] +name = "prompt1" +description = "First prompt" + +[[prompts.messages]] +role = "user" +content = "Message 1" + +[[prompts]] +name = "prompt2" +description = "Second prompt" + +[[prompts.messages]] +role = "assistant" +content = "Message 2" +` + + var data struct { + Prompts []Prompt `toml:"prompts"` + } + err := toml.Unmarshal([]byte(tomlData), &data) + s.Require().NoError(err, "failed to unmarshal prompts array from TOML") + + s.Require().Len(data.Prompts, 2) + s.Equal("prompt1", data.Prompts[0].Name) + s.Equal("prompt2", data.Prompts[1].Name) + }) +} + +func (s *PromptSerializationSuite) TestPromptArgumentSerialization() { + s.Run("serializes required argument", func() { + arg := PromptArgument{ + Name: "test-arg", + Description: "Test argument", + Required: true, + } + + // JSON + jsonData, err := json.Marshal(arg) + s.Require().NoError(err) + var jsonArg PromptArgument + err = json.Unmarshal(jsonData, &jsonArg) + s.Require().NoError(err) + s.Equal(arg.Name, jsonArg.Name) + s.True(jsonArg.Required) + + // YAML + yamlData, err := yaml.Marshal(arg) + s.Require().NoError(err) + var yamlArg PromptArgument + err = yaml.Unmarshal(yamlData, &yamlArg) + s.Require().NoError(err) + s.Equal(arg.Name, yamlArg.Name) + s.True(yamlArg.Required) + }) + + s.Run("serializes optional argument", func() { + arg := PromptArgument{ + Name: "optional-arg", + Description: "Optional argument", + Required: false, + } + + jsonData, err := json.Marshal(arg) + s.Require().NoError(err) + var unmarshaled PromptArgument + err = json.Unmarshal(jsonData, &unmarshaled) + s.Require().NoError(err) + s.False(unmarshaled.Required) + }) +} + +func (s *PromptSerializationSuite) TestPromptTemplateSerialization() { + s.Run("serializes template with placeholder", func() { + template := PromptTemplate{ + Role: "user", + Content: "Hello {{name}}, how are you?", + } + + // JSON + jsonData, err := json.Marshal(template) + s.Require().NoError(err) + var jsonTemplate PromptTemplate + err = json.Unmarshal(jsonData, &jsonTemplate) + s.Require().NoError(err) + s.Equal(template.Role, jsonTemplate.Role) + s.Equal(template.Content, jsonTemplate.Content) + + // TOML + tomlData := ` +role = "user" +content = "Hello {{name}}, how are you?" +` + var tomlTemplate PromptTemplate + err = toml.Unmarshal([]byte(tomlData), &tomlTemplate) + s.Require().NoError(err) + s.Equal(template.Role, tomlTemplate.Role) + s.Equal(template.Content, tomlTemplate.Content) + }) +} + +func (s *PromptSerializationSuite) TestPromptMessageSerialization() { + s.Run("serializes message with content", func() { + msg := PromptMessage{ + Role: "assistant", + Content: PromptContent{ + Type: "text", + Text: "Hello, World!", + }, + } + + // JSON + jsonData, err := json.Marshal(msg) + s.Require().NoError(err) + var jsonMsg PromptMessage + err = json.Unmarshal(jsonData, &jsonMsg) + s.Require().NoError(err) + s.Equal(msg.Role, jsonMsg.Role) + s.Equal(msg.Content.Type, jsonMsg.Content.Type) + s.Equal(msg.Content.Text, jsonMsg.Content.Text) + + // YAML + yamlData, err := yaml.Marshal(msg) + s.Require().NoError(err) + var yamlMsg PromptMessage + err = yaml.Unmarshal(yamlData, &yamlMsg) + s.Require().NoError(err) + s.Equal(msg.Role, yamlMsg.Role) + s.Equal(msg.Content.Text, yamlMsg.Content.Text) + }) +} + +func (s *PromptSerializationSuite) TestPromptContentSerialization() { + s.Run("serializes text content", func() { + content := PromptContent{ + Type: "text", + Text: "Sample text content", + } + + // JSON + jsonData, err := json.Marshal(content) + s.Require().NoError(err) + var jsonContent PromptContent + err = json.Unmarshal(jsonData, &jsonContent) + s.Require().NoError(err) + s.Equal(content.Type, jsonContent.Type) + s.Equal(content.Text, jsonContent.Text) + + // YAML + yamlData, err := yaml.Marshal(content) + s.Require().NoError(err) + var yamlContent PromptContent + err = yaml.Unmarshal(yamlData, &yamlContent) + s.Require().NoError(err) + s.Equal(content.Type, yamlContent.Type) + s.Equal(content.Text, yamlContent.Text) + }) +} + +func (s *PromptSerializationSuite) TestPromptWithOptionalFields() { + s.Run("omits empty optional fields in JSON", func() { + prompt := Prompt{ + Name: "minimal-prompt", + Description: "Minimal prompt without optional fields", + } + + jsonData, err := json.Marshal(prompt) + s.Require().NoError(err) + + // Verify optional fields are omitted + var raw map[string]interface{} + err = json.Unmarshal(jsonData, &raw) + s.Require().NoError(err) + + s.Contains(raw, "name") + s.Contains(raw, "description") + // title is omitempty, should not be present if empty + _, hasTitle := raw["title"] + s.False(hasTitle, "empty title should be omitted") + }) + + s.Run("includes optional fields when present", func() { + prompt := Prompt{ + Name: "full-prompt", + Title: "Full Prompt", + Description: "Prompt with all fields", + Arguments: []PromptArgument{ + {Name: "arg1", Required: true}, + }, + } + + jsonData, err := json.Marshal(prompt) + s.Require().NoError(err) + + var raw map[string]interface{} + err = json.Unmarshal(jsonData, &raw) + s.Require().NoError(err) + + s.Contains(raw, "title") + s.Equal("Full Prompt", raw["title"]) + }) +} + +func TestPromptSerialization(t *testing.T) { + suite.Run(t, new(PromptSerializationSuite)) +} diff --git a/pkg/api/prompts.go b/pkg/api/prompts.go new file mode 100644 index 00000000..bc1d8b9f --- /dev/null +++ b/pkg/api/prompts.go @@ -0,0 +1,96 @@ +package api + +import ( + "context" + + internalk8s "github.com/containers/kubernetes-mcp-server/pkg/kubernetes" +) + +// ServerPrompt represents a prompt that can be registered with the MCP server. +// Prompts provide pre-defined workflow templates and guidance to AI assistants. +type ServerPrompt struct { + Prompt Prompt + Handler PromptHandlerFunc + ClusterAware *bool + ArgumentSchema map[string]PromptArgument +} + +// IsClusterAware indicates whether the prompt can accept a "cluster" or "context" parameter +// to operate on a specific Kubernetes cluster context. +// Defaults to true if not explicitly set +func (s *ServerPrompt) IsClusterAware() bool { + if s.ClusterAware != nil { + return *s.ClusterAware + } + return true +} + +// Prompt represents the metadata and content of an MCP prompt. +// See MCP specification: https://spec.modelcontextprotocol.io/specification/server/prompts/ +type Prompt struct { + Name string `yaml:"name" json:"name" toml:"name"` + Title string `yaml:"title,omitempty" json:"title,omitempty" toml:"title,omitempty"` + Description string `yaml:"description,omitempty" json:"description,omitempty" toml:"description,omitempty"` + Arguments []PromptArgument `yaml:"arguments,omitempty" json:"arguments,omitempty" toml:"arguments,omitempty"` + Templates []PromptTemplate `yaml:"messages,omitempty" json:"messages,omitempty" toml:"messages,omitempty"` +} + +// PromptArgument defines a parameter that can be passed to a prompt. +// See MCP specification: https://spec.modelcontextprotocol.io/specification/server/prompts/ +type PromptArgument struct { + Name string `yaml:"name" json:"name" toml:"name"` + Description string `yaml:"description,omitempty" json:"description,omitempty" toml:"description,omitempty"` + Required bool `yaml:"required" json:"required" toml:"required"` +} + +// PromptTemplate represents a message template from configuration with placeholders like {{arg}}. +// This is used for configuration parsing and gets rendered into PromptMessage at runtime. +type PromptTemplate struct { + Role string `yaml:"role" json:"role" toml:"role"` + Content string `yaml:"content" json:"content" toml:"content"` +} + +// PromptMessage represents a single message in a prompt response. +// See MCP specification: https://spec.modelcontextprotocol.io/specification/server/prompts/ +type PromptMessage struct { + Role string `yaml:"role" json:"role" toml:"role"` + Content PromptContent `yaml:"content" json:"content" toml:"content"` +} + +// PromptContent represents the content of a prompt message. +// See MCP specification: https://spec.modelcontextprotocol.io/specification/server/prompts/ +type PromptContent struct { + Type string `yaml:"type" json:"type" toml:"type"` + Text string `yaml:"text,omitempty" json:"text,omitempty" toml:"text,omitempty"` +} + +// PromptCallRequest interface for accessing prompt call arguments +type PromptCallRequest interface { + GetArguments() map[string]string +} + +// PromptCallResult represents the result of executing a prompt +type PromptCallResult struct { + Description string + Messages []PromptMessage + Error error +} + +// NewPromptCallResult creates a new PromptCallResult +func NewPromptCallResult(description string, messages []PromptMessage, err error) *PromptCallResult { + return &PromptCallResult{ + Description: description, + Messages: messages, + Error: err, + } +} + +// PromptHandlerParams contains the parameters passed to a prompt handler +type PromptHandlerParams struct { + context.Context + *internalk8s.Kubernetes + PromptCallRequest +} + +// PromptHandlerFunc is a function that handles prompt execution +type PromptHandlerFunc func(params PromptHandlerParams) (*PromptCallResult, error) diff --git a/pkg/api/prompts_test.go b/pkg/api/prompts_test.go new file mode 100644 index 00000000..52126092 --- /dev/null +++ b/pkg/api/prompts_test.go @@ -0,0 +1,80 @@ +package api + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "k8s.io/utils/ptr" +) + +func TestServerPrompt_IsClusterAware(t *testing.T) { + tests := []struct { + name string + clusterAware *bool + want bool + }{ + { + name: "nil defaults to true", + clusterAware: nil, + want: true, + }, + { + name: "explicitly true", + clusterAware: ptr.To(true), + want: true, + }, + { + name: "explicitly false", + clusterAware: ptr.To(false), + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + sp := &ServerPrompt{ + ClusterAware: tt.clusterAware, + } + assert.Equal(t, tt.want, sp.IsClusterAware()) + }) + } +} + +func TestNewPromptCallResult(t *testing.T) { + tests := []struct { + name string + description string + messages []PromptMessage + err error + }{ + { + name: "successful result", + description: "Test description", + messages: []PromptMessage{ + { + Role: "user", + Content: PromptContent{ + Type: "text", + Text: "Hello", + }, + }, + }, + err: nil, + }, + { + name: "result with error", + description: "Error description", + messages: nil, + err: assert.AnError, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := NewPromptCallResult(tt.description, tt.messages, tt.err) + assert.Equal(t, tt.description, result.Description) + assert.Equal(t, tt.messages, result.Messages) + assert.Equal(t, tt.err, result.Error) + }) + } +} diff --git a/pkg/config/config.go b/pkg/config/config.go index 80295112..7a459a82 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -74,6 +74,13 @@ type StaticConfig struct { // This map holds raw TOML primitives that will be parsed by registered toolset parsers ToolsetConfigs map[string]toml.Primitive `toml:"toolset_configs,omitempty"` + // Prompt configuration + // Raw TOML primitive for prompt definitions, parsed later + // Note: Uses toml:"-" because Primitive can't be encoded, only decoded + Prompts toml.Primitive `toml:"-"` + promptsDefined bool // Internal: tracks if prompts were defined in config + promptsMetadata toml.MetaData // Internal: metadata for prompts decoding + // Internal: parsed provider configs (not exposed to TOML package) parsedClusterProviderConfigs map[string]configapi.Extended // Internal: parsed toolset configs (not exposed to TOML package) @@ -280,6 +287,18 @@ func ReadToml(configData []byte, opts ...ReadConfigOpt) (*StaticConfig, error) { return nil, err } + // Store prompts primitive if defined + if md.IsDefined("prompts") { + var temp struct { + Prompts toml.Primitive `toml:"prompts"` + } + // Re-decode to get the primitive + tempMd, _ := toml.NewDecoder(bytes.NewReader(configData)).Decode(&temp) + config.Prompts = temp.Prompts + config.promptsDefined = true + config.promptsMetadata = tempMd + } + return config, nil } @@ -309,3 +328,13 @@ func (c *StaticConfig) GetToolsetConfig(name string) (configapi.Extended, bool) func (c *StaticConfig) IsRequireOAuth() bool { return c.RequireOAuth } + +// HasPrompts returns whether prompts were defined in the configuration +func (c *StaticConfig) HasPrompts() bool { + return c.promptsDefined +} + +// GetPromptsMetadata returns the TOML metadata for prompts +func (c *StaticConfig) GetPromptsMetadata() toml.MetaData { + return c.promptsMetadata +} diff --git a/pkg/kubernetes-mcp-server/cmd/root_sighup_test.go b/pkg/kubernetes-mcp-server/cmd/root_sighup_test.go index e0d53e21..ca645f00 100644 --- a/pkg/kubernetes-mcp-server/cmd/root_sighup_test.go +++ b/pkg/kubernetes-mcp-server/cmd/root_sighup_test.go @@ -204,6 +204,49 @@ func (s *SIGHUPSuite) TestSIGHUPWithConfigDirOnly() { }) } +func (s *SIGHUPSuite) TestSIGHUPReloadsPrompts() { + // Create initial config with one prompt + configPath := filepath.Join(s.tempDir, "config.toml") + s.Require().NoError(os.WriteFile(configPath, []byte(` + [[prompts]] + name = "initial-prompt" + description = "Initial prompt" + + [[prompts.messages]] + role = "user" + content = "Initial message" + `), 0644)) + s.InitServer(configPath, "") + + prompts, err := s.server.GetPrompts() + s.Require().NoError(err) + s.Len(prompts, 1) + s.Equal("initial-prompt", prompts[0].Prompt.Name) + + // Update config with new prompt + s.Require().NoError(os.WriteFile(configPath, []byte(` + [[prompts]] + name = "updated-prompt" + description = "Updated prompt" + + [[prompts.messages]] + role = "user" + content = "Updated message" + `), 0644)) + + // Send SIGHUP + s.Require().NoError(syscall.Kill(syscall.Getpid(), syscall.SIGHUP)) + + // Verify prompts were reloaded + s.Require().Eventually(func() bool { + prompts, err := s.server.GetPrompts() + if err != nil { + return false + } + return len(prompts) == 1 && prompts[0].Prompt.Name == "updated-prompt" + }, 2*time.Second, 50*time.Millisecond) +} + func TestSIGHUP(t *testing.T) { suite.Run(t, new(SIGHUPSuite)) } diff --git a/pkg/mcp/mcp.go b/pkg/mcp/mcp.go index 9c86598c..698f72d4 100644 --- a/pkg/mcp/mcp.go +++ b/pkg/mcp/mcp.go @@ -16,6 +16,7 @@ import ( "github.com/containers/kubernetes-mcp-server/pkg/config" internalk8s "github.com/containers/kubernetes-mcp-server/pkg/kubernetes" "github.com/containers/kubernetes-mcp-server/pkg/output" + "github.com/containers/kubernetes-mcp-server/pkg/prompts" "github.com/containers/kubernetes-mcp-server/pkg/toolsets" "github.com/containers/kubernetes-mcp-server/pkg/version" ) @@ -63,10 +64,11 @@ func (c *Configuration) isToolApplicable(tool api.ServerTool) bool { } type Server struct { - configuration *Configuration - server *mcp.Server - enabledTools []string - p internalk8s.Provider + configuration *Configuration + server *mcp.Server + enabledTools []string + enabledPrompts []string + p internalk8s.Provider } func NewServer(configuration Configuration) (*Server, error) { @@ -78,7 +80,7 @@ func NewServer(configuration Configuration) (*Server, error) { }, &mcp.ServerOptions{ HasResources: false, - HasPrompts: false, + HasPrompts: true, HasTools: true, }), } @@ -160,9 +162,75 @@ func (s *Server) reloadToolsets() error { } s.server.AddTool(goSdkTool, goSdkToolHandler) } + + // Track previously enabled prompts + previousPrompts := s.enabledPrompts + + // Load config prompts into registry + prompts.Clear() + if s.configuration.HasPrompts() { + ctx := context.Background() + md := s.configuration.GetPromptsMetadata() + if err := prompts.LoadFromToml(ctx, s.configuration.Prompts, md); err != nil { + return fmt.Errorf("failed to parse prompts from config: %w", err) + } + } + + // Get prompts from registry + configPrompts := prompts.ConfigPrompts() + + // Update enabled prompts list + s.enabledPrompts = make([]string, 0) + for _, prompt := range configPrompts { + s.enabledPrompts = append(s.enabledPrompts, prompt.Prompt.Name) + } + + // Remove prompts that are no longer applicable + promptsToRemove := make([]string, 0) + for _, oldPrompt := range previousPrompts { + if !slices.Contains(s.enabledPrompts, oldPrompt) { + promptsToRemove = append(promptsToRemove, oldPrompt) + } + } + s.server.RemovePrompts(promptsToRemove...) + + // Register all config prompts + for _, prompt := range configPrompts { + mcpPrompt, promptHandler, err := ServerPromptToGoSdkPrompt(s, prompt) + if err != nil { + return fmt.Errorf("failed to convert prompt %s: %v", prompt.Prompt.Name, err) + } + s.server.AddPrompt(mcpPrompt, promptHandler) + } + + // start new watch + s.p.WatchTargets(s.reloadToolsets) return nil } +// mergePrompts merges two slices of prompts, with prompts in override taking precedence +// over prompts in base when they have the same name +func mergePrompts(base, override []api.ServerPrompt) []api.ServerPrompt { + // Create a map of override prompts by name for quick lookup + overrideMap := make(map[string]api.ServerPrompt) + for _, prompt := range override { + overrideMap[prompt.Prompt.Name] = prompt + } + + // Build result: start with base prompts, skipping any that are overridden + result := make([]api.ServerPrompt, 0, len(base)+len(override)) + for _, prompt := range base { + if _, exists := overrideMap[prompt.Prompt.Name]; !exists { + result = append(result, prompt) + } + } + + // Add all override prompts + result = append(result, override...) + + return result +} + func (s *Server) ServeStdio(ctx context.Context) error { return s.server.Run(ctx, &mcp.LoggingTransport{Transport: &mcp.StdioTransport{}, Writer: os.Stderr}) } @@ -204,6 +272,11 @@ func (s *Server) GetEnabledTools() []string { return s.enabledTools } +// GetPrompts returns the currently loaded prompts from the registry +func (s *Server) GetPrompts() ([]api.ServerPrompt, error) { + return prompts.ConfigPrompts(), nil +} + // ReloadConfiguration reloads the configuration and reinitializes the server. // This is intended to be called by the server lifecycle manager when // configuration changes are detected. diff --git a/pkg/mcp/mcp_prompts_test.go b/pkg/mcp/mcp_prompts_test.go new file mode 100644 index 00000000..0e66014e --- /dev/null +++ b/pkg/mcp/mcp_prompts_test.go @@ -0,0 +1,155 @@ +package mcp + +import ( + "testing" + + "github.com/containers/kubernetes-mcp-server/pkg/config" + "github.com/mark3labs/mcp-go/mcp" + "github.com/stretchr/testify/suite" +) + +// McpPromptsSuite tests MCP prompts integration +type McpPromptsSuite struct { + BaseMcpSuite +} + +// loadPromptsConfig is a helper to parse prompts config and merge it with the test config +func (s *McpPromptsSuite) loadPromptsConfig(configData string) { + cfg, err := config.ReadToml([]byte(configData)) + s.Require().NoError(err, "Expected to parse prompts config") + + // Copy the parsed prompts to the test config + kubeconfig := s.Cfg.KubeConfig + listOutput := s.Cfg.ListOutput + + s.Cfg = cfg + s.Cfg.KubeConfig = kubeconfig + s.Cfg.ListOutput = listOutput +} + +func (s *McpPromptsSuite) TestListPrompts() { + s.loadPromptsConfig(` +[[prompts]] +name = "test-prompt" +title = "Test Prompt" +description = "A test prompt for integration testing" + +[[prompts.arguments]] +name = "test_arg" +description = "A test argument" +required = true + +[[prompts.messages]] +role = "user" +content = "Test message with {{test_arg}}" + `) + + s.InitMcpClient() + + prompts, err := s.ListPrompts(s.T().Context(), mcp.ListPromptsRequest{}) + + s.Run("ListPrompts returns prompts", func() { + s.NoError(err, "call ListPrompts failed") + s.NotNilf(prompts, "list prompts failed") + }) + + s.Run("config prompt is available with all metadata", func() { + s.Require().NotNil(prompts) + var testPrompt *mcp.Prompt + for _, prompt := range prompts.Prompts { + if prompt.Name == "test-prompt" { + testPrompt = &prompt + break + } + } + s.Require().NotNil(testPrompt, "test-prompt should be found") + + // Verify all metadata fields are returned + s.Equal("test-prompt", testPrompt.Name) + s.Equal("A test prompt for integration testing", testPrompt.Description, "description should match") + s.Require().Len(testPrompt.Arguments, 1) + s.Equal("test_arg", testPrompt.Arguments[0].Name) + s.Equal("A test argument", testPrompt.Arguments[0].Description) + s.True(testPrompt.Arguments[0].Required) + }) +} + +func (s *McpPromptsSuite) TestGetPrompt() { + s.loadPromptsConfig(` +[[prompts]] +name = "substitution-prompt" +description = "Test argument substitution" + +[[prompts.arguments]] +name = "name" +description = "Name to substitute" +required = true + +[[prompts.messages]] +role = "user" +content = "Hello {{name}}!" + `) + + s.InitMcpClient() + + result, err := s.GetPrompt(s.T().Context(), mcp.GetPromptRequest{ + Params: mcp.GetPromptParams{ + Name: "substitution-prompt", + Arguments: map[string]string{ + "name": "World", + }, + }, + }) + + s.Run("GetPrompt succeeds", func() { + s.NoError(err, "call GetPrompt failed") + s.NotNilf(result, "get prompt failed") + }) + + s.Run("argument substitution works", func() { + s.Require().NotNil(result) + s.Equal("Test argument substitution", result.Description) + s.Require().Len(result.Messages, 1) + s.Equal("user", string(result.Messages[0].Role)) + textContent, ok := result.Messages[0].Content.(mcp.TextContent) + s.Require().True(ok, "expected TextContent") + s.Equal("text", textContent.Type) + s.Equal("Hello World!", textContent.Text) + }) +} + +func (s *McpPromptsSuite) TestGetPromptMissingRequiredArgument() { + s.loadPromptsConfig(` +[[prompts]] +name = "required-arg-prompt" +description = "Test required argument validation" + +[[prompts.arguments]] +name = "required_arg" +description = "A required argument" +required = true + +[[prompts.messages]] +role = "user" +content = "Content with {{required_arg}}" + `) + + s.InitMcpClient() + + result, err := s.GetPrompt(s.T().Context(), mcp.GetPromptRequest{ + Params: mcp.GetPromptParams{ + Name: "required-arg-prompt", + Arguments: map[string]string{}, + }, + }) + + s.Run("missing required argument returns error", func() { + s.Error(err, "expected error for missing required argument") + s.Nil(result) + s.Contains(err.Error(), "required argument 'required_arg' is missing") + }) +} + +func TestMcpPromptsSuite(t *testing.T) { + suite.Run(t, new(McpPromptsSuite)) +} diff --git a/pkg/mcp/mcp_watch_test.go b/pkg/mcp/mcp_watch_test.go index 1e977ea8..100a16a8 100644 --- a/pkg/mcp/mcp_watch_test.go +++ b/pkg/mcp/mcp_watch_test.go @@ -43,7 +43,10 @@ func (s *WatchKubeConfigSuite) TestNotifiesToolsChange() { notification := s.WaitForNotification(5 * time.Second) // Then s.NotNil(notification, "WatchKubeConfig did not notify") - s.Equal("notifications/tools/list_changed", notification.Method, "WatchKubeConfig did not notify tools change") + s.True( + notification.Method == "notifications/tools/list_changed" || notification.Method == "notifications/prompts/list_changed", + "WatchKubeConfig did not notify tools or prompts change, got: %s", notification.Method, + ) } func (s *WatchKubeConfigSuite) TestNotifiesToolsChangeMultipleTimes() { @@ -135,7 +138,10 @@ func (s *WatchClusterStateSuite) TestNotifiesToolsChangeOnAPIGroupAddition() { // Then s.NotNil(notification, "cluster state watcher did not notify") - s.Equal("notifications/tools/list_changed", notification.Method, "cluster state watcher did not notify tools change") + s.True( + notification.Method == "notifications/tools/list_changed" || notification.Method == "notifications/prompts/list_changed", + "cluster state watcher did not notify tools or prompts change, got: %s", notification.Method, + ) } func (s *WatchClusterStateSuite) TestNotifiesToolsChangeMultipleTimes() { diff --git a/pkg/mcp/prompts_config_test.go b/pkg/mcp/prompts_config_test.go new file mode 100644 index 00000000..91ed2a9c --- /dev/null +++ b/pkg/mcp/prompts_config_test.go @@ -0,0 +1,168 @@ +package mcp + +import ( + "testing" + + "github.com/containers/kubernetes-mcp-server/pkg/api" + "github.com/stretchr/testify/assert" +) + +func TestMergePrompts(t *testing.T) { + tests := []struct { + name string + base []api.ServerPrompt + override []api.ServerPrompt + expectedCount int + expectedNames []string + expectedSource string // Which source should win for overlapping names + }{ + { + name: "merge with no overlap", + base: []api.ServerPrompt{ + {Prompt: api.Prompt{Name: "prompt1"}}, + {Prompt: api.Prompt{Name: "prompt2"}}, + }, + override: []api.ServerPrompt{ + {Prompt: api.Prompt{Name: "prompt3"}}, + }, + expectedCount: 3, + expectedNames: []string{"prompt1", "prompt2", "prompt3"}, + }, + { + name: "override replaces base prompt with same name", + base: []api.ServerPrompt{ + {Prompt: api.Prompt{Name: "prompt1", Description: "Base description"}}, + {Prompt: api.Prompt{Name: "prompt2"}}, + }, + override: []api.ServerPrompt{ + {Prompt: api.Prompt{Name: "prompt1", Description: "Override description"}}, + }, + expectedCount: 2, + expectedNames: []string{"prompt2", "prompt1"}, + expectedSource: "Override description", + }, + { + name: "empty base", + base: []api.ServerPrompt{}, + override: []api.ServerPrompt{ + {Prompt: api.Prompt{Name: "prompt1"}}, + }, + expectedCount: 1, + expectedNames: []string{"prompt1"}, + }, + { + name: "empty override", + base: []api.ServerPrompt{ + {Prompt: api.Prompt{Name: "prompt1"}}, + }, + override: []api.ServerPrompt{}, + expectedCount: 1, + expectedNames: []string{"prompt1"}, + }, + { + name: "both empty", + base: []api.ServerPrompt{}, + override: []api.ServerPrompt{}, + expectedCount: 0, + }, + { + name: "multiple overrides", + base: []api.ServerPrompt{ + {Prompt: api.Prompt{Name: "prompt1", Description: "Base 1"}}, + {Prompt: api.Prompt{Name: "prompt2", Description: "Base 2"}}, + {Prompt: api.Prompt{Name: "prompt3", Description: "Base 3"}}, + }, + override: []api.ServerPrompt{ + {Prompt: api.Prompt{Name: "prompt1", Description: "Override 1"}}, + {Prompt: api.Prompt{Name: "prompt3", Description: "Override 3"}}, + }, + expectedCount: 3, + expectedNames: []string{"prompt2", "prompt1", "prompt3"}, + expectedSource: "Override 1", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := mergePrompts(tt.base, tt.override) + + assert.Len(t, result, tt.expectedCount, "unexpected number of prompts") + + if len(tt.expectedNames) > 0 { + actualNames := make([]string, len(result)) + for i, p := range result { + actualNames[i] = p.Prompt.Name + } + assert.ElementsMatch(t, tt.expectedNames, actualNames, "prompt names don't match") + } + + // Check that override wins for specific test case + if tt.expectedSource != "" { + for _, p := range result { + if p.Prompt.Name == "prompt1" { + assert.Equal(t, tt.expectedSource, p.Prompt.Description, "override should win") + } + } + } + }) + } +} + +func TestMergePromptsPreservesOrder(t *testing.T) { + base := []api.ServerPrompt{ + {Prompt: api.Prompt{Name: "base1"}}, + {Prompt: api.Prompt{Name: "base2"}}, + {Prompt: api.Prompt{Name: "base3"}}, + } + + override := []api.ServerPrompt{ + {Prompt: api.Prompt{Name: "override1"}}, + {Prompt: api.Prompt{Name: "override2"}}, + } + + result := mergePrompts(base, override) + + // Base prompts should come first (those not overridden) + assert.Equal(t, "base1", result[0].Prompt.Name) + assert.Equal(t, "base2", result[1].Prompt.Name) + assert.Equal(t, "base3", result[2].Prompt.Name) + + // Then override prompts + assert.Equal(t, "override1", result[3].Prompt.Name) + assert.Equal(t, "override2", result[4].Prompt.Name) +} + +func TestMergePromptsCompleteReplacement(t *testing.T) { + base := []api.ServerPrompt{ + { + Prompt: api.Prompt{ + Name: "test-prompt", + Description: "Base description", + Arguments: []api.PromptArgument{ + {Name: "base_arg", Required: true}, + }, + }, + }, + } + + override := []api.ServerPrompt{ + { + Prompt: api.Prompt{ + Name: "test-prompt", + Description: "Override description", + Arguments: []api.PromptArgument{ + {Name: "override_arg", Required: false}, + }, + }, + }, + } + + result := mergePrompts(base, override) + + assert.Len(t, result, 1) + assert.Equal(t, "test-prompt", result[0].Prompt.Name) + assert.Equal(t, "Override description", result[0].Prompt.Description) + assert.Len(t, result[0].Prompt.Arguments, 1) + assert.Equal(t, "override_arg", result[0].Prompt.Arguments[0].Name) + assert.False(t, result[0].Prompt.Arguments[0].Required) +} diff --git a/pkg/mcp/prompts_gosdk.go b/pkg/mcp/prompts_gosdk.go new file mode 100644 index 00000000..a153ed24 --- /dev/null +++ b/pkg/mcp/prompts_gosdk.go @@ -0,0 +1,100 @@ +package mcp + +import ( + "context" + "fmt" + + "github.com/modelcontextprotocol/go-sdk/mcp" + + "github.com/containers/kubernetes-mcp-server/pkg/api" +) + +// promptCallRequestAdapter adapts MCP GetPromptRequest to api.PromptCallRequest +type promptCallRequestAdapter struct { + request *mcp.GetPromptRequest +} + +func (p *promptCallRequestAdapter) GetArguments() map[string]string { + if p.request == nil || p.request.Params == nil || p.request.Params.Arguments == nil { + return make(map[string]string) + } + return p.request.Params.Arguments +} + +// ServerPromptToGoSdkPrompt converts an api.ServerPrompt to MCP SDK types +func ServerPromptToGoSdkPrompt(s *Server, serverPrompt api.ServerPrompt) (*mcp.Prompt, mcp.PromptHandler, error) { + // Convert arguments + var args []*mcp.PromptArgument + for _, arg := range serverPrompt.Prompt.Arguments { + args = append(args, &mcp.PromptArgument{ + Name: arg.Name, + Description: arg.Description, + Required: arg.Required, + }) + } + + // Create the MCP SDK prompt + mcpPrompt := &mcp.Prompt{ + Name: serverPrompt.Prompt.Name, + Description: serverPrompt.Prompt.Description, + Arguments: args, + } + + // Create the handler that wraps the ServerPrompt handler + handler := func(ctx context.Context, request *mcp.GetPromptRequest) (*mcp.GetPromptResult, error) { + clusterParam := s.p.GetTargetParameterName() + var cluster string + if request.Params != nil && request.Params.Arguments != nil { + if val, ok := request.Params.Arguments[clusterParam]; ok { + cluster = val + } + } + + k8s, err := s.p.GetDerivedKubernetes(ctx, cluster) + if err != nil { + return nil, fmt.Errorf("failed to get kubernetes client: %w", err) + } + + params := api.PromptHandlerParams{ + Context: ctx, + Kubernetes: k8s, + PromptCallRequest: &promptCallRequestAdapter{request: request}, + } + + result, err := serverPrompt.Handler(params) + if err != nil { + return nil, err + } + + if result.Error != nil { + return nil, result.Error + } + + var messages []*mcp.PromptMessage + for _, msg := range result.Messages { + mcpMsg := &mcp.PromptMessage{ + Role: mcp.Role(msg.Role), + } + + switch msg.Content.Type { + case "text": + mcpMsg.Content = &mcp.TextContent{ + Text: msg.Content.Text, + } + default: + mcpMsg.Content = &mcp.TextContent{ + Text: msg.Content.Text, + } + } + + messages = append(messages, mcpMsg) + } + + return &mcp.GetPromptResult{ + Description: result.Description, + Messages: messages, + }, nil + } + + return mcpPrompt, handler, nil +} diff --git a/pkg/mcp/prompts_gosdk_test.go b/pkg/mcp/prompts_gosdk_test.go new file mode 100644 index 00000000..8093cf53 --- /dev/null +++ b/pkg/mcp/prompts_gosdk_test.go @@ -0,0 +1,140 @@ +package mcp + +import ( + "testing" + + "github.com/modelcontextprotocol/go-sdk/mcp" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/containers/kubernetes-mcp-server/pkg/api" +) + +func TestPromptCallRequestAdapter_GetArguments(t *testing.T) { + tests := []struct { + name string + request *mcp.GetPromptRequest + want map[string]string + }{ + { + name: "nil request", + request: nil, + want: map[string]string{}, + }, + { + name: "nil params", + request: &mcp.GetPromptRequest{ + Params: nil, + }, + want: map[string]string{}, + }, + { + name: "nil arguments", + request: &mcp.GetPromptRequest{ + Params: &mcp.GetPromptParams{ + Arguments: nil, + }, + }, + want: map[string]string{}, + }, + { + name: "with arguments", + request: &mcp.GetPromptRequest{ + Params: &mcp.GetPromptParams{ + Arguments: map[string]string{ + "namespace": "default", + "pod_name": "test-pod", + }, + }, + }, + want: map[string]string{ + "namespace": "default", + "pod_name": "test-pod", + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + adapter := &promptCallRequestAdapter{request: tt.request} + got := adapter.GetArguments() + assert.Equal(t, tt.want, got) + }) + } +} + +func TestServerPromptToGoSdkPrompt_Conversion(t *testing.T) { + serverPrompt := api.ServerPrompt{ + Prompt: api.Prompt{ + Name: "test-prompt", + Description: "Test description", + Arguments: []api.PromptArgument{ + { + Name: "arg1", + Description: "First argument", + Required: true, + }, + { + Name: "arg2", + Description: "Second argument", + Required: false, + }, + }, + }, + Handler: func(params api.PromptHandlerParams) (*api.PromptCallResult, error) { + return api.NewPromptCallResult("Test result", []api.PromptMessage{ + { + Role: "user", + Content: api.PromptContent{ + Type: "text", + Text: "Test message", + }, + }, + }, nil), nil + }, + } + + mockServer := &Server{} + + mcpPrompt, handler, err := ServerPromptToGoSdkPrompt(mockServer, serverPrompt) + + require.NoError(t, err) + require.NotNil(t, mcpPrompt) + require.NotNil(t, handler) + + assert.Equal(t, "test-prompt", mcpPrompt.Name) + assert.Equal(t, "Test description", mcpPrompt.Description) + require.Len(t, mcpPrompt.Arguments, 2) + + assert.Equal(t, "arg1", mcpPrompt.Arguments[0].Name) + assert.Equal(t, "First argument", mcpPrompt.Arguments[0].Description) + assert.True(t, mcpPrompt.Arguments[0].Required) + + assert.Equal(t, "arg2", mcpPrompt.Arguments[1].Name) + assert.Equal(t, "Second argument", mcpPrompt.Arguments[1].Description) + assert.False(t, mcpPrompt.Arguments[1].Required) +} + +func TestServerPromptToGoSdkPrompt_EmptyArguments(t *testing.T) { + serverPrompt := api.ServerPrompt{ + Prompt: api.Prompt{ + Name: "no-args-prompt", + Description: "Prompt with no arguments", + Arguments: []api.PromptArgument{}, + }, + Handler: func(params api.PromptHandlerParams) (*api.PromptCallResult, error) { + return api.NewPromptCallResult("Result", []api.PromptMessage{}, nil), nil + }, + } + + mockServer := &Server{} + + mcpPrompt, handler, err := ServerPromptToGoSdkPrompt(mockServer, serverPrompt) + + require.NoError(t, err) + require.NotNil(t, mcpPrompt) + require.NotNil(t, handler) + + assert.Equal(t, "no-args-prompt", mcpPrompt.Name) + assert.Len(t, mcpPrompt.Arguments, 0) +} diff --git a/pkg/prompts/prompts.go b/pkg/prompts/prompts.go new file mode 100644 index 00000000..638c46d5 --- /dev/null +++ b/pkg/prompts/prompts.go @@ -0,0 +1,92 @@ +package prompts + +import ( + "context" + "fmt" + "strings" + + "github.com/BurntSushi/toml" + "github.com/containers/kubernetes-mcp-server/pkg/api" +) + +var configPrompts []api.ServerPrompt + +// Clear removes all registered prompts +func Clear() { + configPrompts = []api.ServerPrompt{} +} + +// Register registers prompts to be available in the MCP server +func Register(prompts ...api.ServerPrompt) { + configPrompts = append(configPrompts, prompts...) +} + +// ConfigPrompts returns all prompts loaded from configuration +func ConfigPrompts() []api.ServerPrompt { + return configPrompts +} + +// LoadFromToml parses prompts from TOML configuration data and registers them +func LoadFromToml(ctx context.Context, primitive toml.Primitive, md toml.MetaData) error { + var prompts []api.Prompt + if err := md.PrimitiveDecode(primitive, &prompts); err != nil { + return fmt.Errorf("failed to parse prompts from TOML: %w", err) + } + + serverPrompts := createServerPrompts(prompts) + Register(serverPrompts...) + return nil +} + +// createServerPrompts converts Prompt definitions to ServerPrompts with handlers +func createServerPrompts(prompts []api.Prompt) []api.ServerPrompt { + serverPrompts := make([]api.ServerPrompt, 0, len(prompts)) + for _, prompt := range prompts { + serverPrompts = append(serverPrompts, api.ServerPrompt{ + Prompt: prompt, + Handler: createPromptHandler(prompt), + }) + } + return serverPrompts +} + +// createPromptHandler creates a handler function for a prompt +func createPromptHandler(prompt api.Prompt) api.PromptHandlerFunc { + return func(params api.PromptHandlerParams) (*api.PromptCallResult, error) { + args := params.GetArguments() + + // Validate required arguments + for _, arg := range prompt.Arguments { + if arg.Required { + if _, exists := args[arg.Name]; !exists { + return nil, fmt.Errorf("required argument '%s' is missing", arg.Name) + } + } + } + + // Render messages with argument substitution + messages := make([]api.PromptMessage, 0, len(prompt.Templates)) + for _, template := range prompt.Templates { + content := substituteArguments(template.Content, args) + messages = append(messages, api.PromptMessage{ + Role: template.Role, + Content: api.PromptContent{ + Type: "text", + Text: content, + }, + }) + } + + return api.NewPromptCallResult(prompt.Description, messages, nil), nil + } +} + +// substituteArguments replaces {{argument}} placeholders in content with actual values +func substituteArguments(content string, args map[string]string) string { + result := content + for key, value := range args { + placeholder := fmt.Sprintf("{{%s}}", key) + result = strings.ReplaceAll(result, placeholder, value) + } + return result +} diff --git a/pkg/prompts/prompts_test.go b/pkg/prompts/prompts_test.go new file mode 100644 index 00000000..da31fa48 --- /dev/null +++ b/pkg/prompts/prompts_test.go @@ -0,0 +1,236 @@ +package prompts + +import ( + "bytes" + "context" + "testing" + + "github.com/BurntSushi/toml" + "github.com/containers/kubernetes-mcp-server/pkg/api" + "github.com/stretchr/testify/suite" +) + +// PromptsTestSuite tests the prompts package +type PromptsTestSuite struct { + suite.Suite +} + +func (s *PromptsTestSuite) SetupTest() { + // Clear prompts before each test + Clear() +} + +func (s *PromptsTestSuite) TestLoadFromToml_SinglePrompt() { + s.Run("parses single prompt with all fields", func() { + tomlData := ` +[[prompts]] +name = "test-prompt" +title = "Test Prompt" +description = "A test prompt for validation" + +[[prompts.arguments]] +name = "pod_name" +description = "Name of the pod" +required = true + +[[prompts.arguments]] +name = "namespace" +description = "Namespace of the pod" +required = false + +[[prompts.messages]] +role = "user" +content = "Describe pod {{pod_name}} in namespace {{namespace}}" + +[[prompts.messages]] +role = "assistant" +content = "I'll help you with pod {{pod_name}}" +` + + var temp struct { + Prompts toml.Primitive `toml:"prompts"` + } + md, err := toml.NewDecoder(bytes.NewReader([]byte(tomlData))).Decode(&temp) + s.Require().NoError(err) + + ctx := context.Background() + err = LoadFromToml(ctx, temp.Prompts, md) + s.Require().NoError(err) + + serverPrompts := ConfigPrompts() + s.Require().Len(serverPrompts, 1) + + prompt := serverPrompts[0].Prompt + s.Equal("test-prompt", prompt.Name) + s.Equal("Test Prompt", prompt.Title) + s.Equal("A test prompt for validation", prompt.Description) + + // Verify arguments + s.Require().Len(prompt.Arguments, 2) + s.Equal("pod_name", prompt.Arguments[0].Name) + s.True(prompt.Arguments[0].Required) + s.Equal("namespace", prompt.Arguments[1].Name) + s.False(prompt.Arguments[1].Required) + + // Verify templates + s.Require().Len(prompt.Templates, 2) + s.Equal("user", prompt.Templates[0].Role) + s.Contains(prompt.Templates[0].Content, "{{pod_name}}") + s.Equal("assistant", prompt.Templates[1].Role) + + // Verify handler was created + s.NotNil(serverPrompts[0].Handler) + }) +} + +func (s *PromptsTestSuite) TestLoadFromToml_MultiplePrompts() { + tomlData := ` +[[prompts]] +name = "prompt-1" +description = "First prompt" + +[[prompts.messages]] +role = "user" +content = "Message 1" + +[[prompts]] +name = "prompt-2" +description = "Second prompt" + +[[prompts.arguments]] +name = "arg1" +required = true + +[[prompts.messages]] +role = "user" +content = "Message 2 with {{arg1}}" +` + + var temp struct { + Prompts toml.Primitive `toml:"prompts"` + } + md, err := toml.NewDecoder(bytes.NewReader([]byte(tomlData))).Decode(&temp) + s.Require().NoError(err) + + ctx := context.Background() + err = LoadFromToml(ctx, temp.Prompts, md) + s.Require().NoError(err) + + serverPrompts := ConfigPrompts() + s.Require().Len(serverPrompts, 2) + + s.Equal("prompt-1", serverPrompts[0].Prompt.Name) + s.Equal("prompt-2", serverPrompts[1].Prompt.Name) + + // Verify second prompt has arguments + s.Require().Len(serverPrompts[1].Prompt.Arguments, 1) + s.Equal("arg1", serverPrompts[1].Prompt.Arguments[0].Name) + s.True(serverPrompts[1].Prompt.Arguments[0].Required) +} + +func (s *PromptsTestSuite) TestRegister() { + s.Run("registers prompts correctly", func() { + prompt1 := api.ServerPrompt{ + Prompt: api.Prompt{ + Name: "test-1", + Description: "Test 1", + }, + } + prompt2 := api.ServerPrompt{ + Prompt: api.Prompt{ + Name: "test-2", + Description: "Test 2", + }, + } + + Register(prompt1, prompt2) + + prompts := ConfigPrompts() + s.Len(prompts, 2) + s.Equal("test-1", prompts[0].Prompt.Name) + s.Equal("test-2", prompts[1].Prompt.Name) + }) +} + +func (s *PromptsTestSuite) TestClear() { + s.Run("clears all prompts", func() { + prompt := api.ServerPrompt{ + Prompt: api.Prompt{ + Name: "test", + Description: "Test", + }, + } + Register(prompt) + s.Len(ConfigPrompts(), 1) + + Clear() + s.Len(ConfigPrompts(), 0) + }) +} + +func (s *PromptsTestSuite) TestPromptHandler() { + s.Run("validates required arguments", func() { + prompt := api.Prompt{ + Name: "test", + Description: "Test", + Arguments: []api.PromptArgument{ + {Name: "required_arg", Required: true}, + {Name: "optional_arg", Required: false}, + }, + Templates: []api.PromptTemplate{ + {Role: "user", Content: "Hello {{required_arg}}"}, + }, + } + + handler := createPromptHandler(prompt) + + // Test missing required argument + params := &testPromptRequest{args: map[string]string{}} + result, err := handler(api.PromptHandlerParams{PromptCallRequest: params}) + s.Error(err) + s.Contains(err.Error(), "required argument 'required_arg' is missing") + s.Nil(result) + + // Test with required argument + params = &testPromptRequest{args: map[string]string{"required_arg": "value"}} + result, err = handler(api.PromptHandlerParams{PromptCallRequest: params}) + s.NoError(err) + s.NotNil(result) + s.Len(result.Messages, 1) + s.Equal("Hello value", result.Messages[0].Content.Text) + }) +} + +func (s *PromptsTestSuite) TestSubstituteArguments() { + s.Run("replaces placeholders correctly", func() { + content := "Hello {{name}}, your age is {{age}}" + args := map[string]string{ + "name": "Alice", + "age": "30", + } + + result := substituteArguments(content, args) + s.Equal("Hello Alice, your age is 30", result) + }) + + s.Run("handles missing arguments", func() { + content := "Hello {{name}}" + args := map[string]string{} + + result := substituteArguments(content, args) + s.Equal("Hello {{name}}", result) + }) +} + +// testPromptRequest is a test implementation of PromptCallRequest +type testPromptRequest struct { + args map[string]string +} + +func (t *testPromptRequest) GetArguments() map[string]string { + return t.args +} + +func TestPrompts(t *testing.T) { + suite.Run(t, new(PromptsTestSuite)) +}