Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/docs/configuration.md
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ You can use both exact matching and glob patterns for OIDC user authorization:
| ---------------------- | -------------------- | ------- | ----------------------------------------------------------------------------------------------------- |
| `--proxy-bearer-token` | `PROXY_BEARER_TOKEN` | - | Bearer token to add to Authorization header when proxying requests |
| `--proxy-headers` | `PROXY_HEADERS` | - | Comma-separated list of headers to add when proxying requests (format: Header1:Value1,Header2:Value2) |
| `--http-streaming-only` | `HTTP_STREAMING_ONLY` | `false` | Reject SSE (GET) requests and keep the backend operating in HTTP streaming-only mode |
| `--trusted-proxies` | `TRUSTED_PROXIES` | - | Comma-separated list of trusted proxies (IP addresses or CIDR ranges) |

For practical configuration examples including environment variables, Docker Compose, and Kubernetes deployments, see the [Configuration Examples](./examples.md) page.
54 changes: 50 additions & 4 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,52 @@ func splitWithEscapes(s, delimiter string) []string {
return result
}

type proxyRunnerFunc func(
listen string,
tlsListen string,
autoTLS bool,
tlsHost string,
tlsDirectoryURL string,
tlsAcceptTOS bool,
tlsCertFile string,
tlsKeyFile string,
dataPath string,
repositoryBackend string,
repositoryDSN string,
externalURL string,
googleClientID string,
googleClientSecret string,
googleAllowedUsers []string,
googleAllowedWorkspaces []string,
githubClientID string,
githubClientSecret string,
githubAllowedUsers []string,
githubAllowedOrgs []string,
oidcConfigurationURL string,
oidcClientID string,
oidcClientSecret string,
oidcScopes []string,
oidcUserIDField string,
oidcProviderName string,
oidcAllowedUsers []string,
oidcAllowedUsersGlob []string,
noProviderAutoSelect bool,
password string,
passwordHash string,
trustedProxy []string,
proxyHeaders []string,
proxyBearerToken string,
proxyTarget []string,
httpStreamingOnly bool,
) error

func main() {
if err := newRootCommand(mcpproxy.Run).Execute(); err != nil {
panic(err)
}
}

func newRootCommand(run proxyRunnerFunc) *cobra.Command {
var listen string
var tlsListen string
var noAutoTLS bool
Expand Down Expand Up @@ -98,6 +143,7 @@ func main() {
var passwordHash string
var proxyBearerToken string
var proxyHeaders string
var httpStreamingOnly bool
var trustedProxies string

rootCmd := &cobra.Command{
Expand Down Expand Up @@ -175,7 +221,7 @@ func main() {
}
}

if err := mcpproxy.Run(
if err := run(
listen,
tlsListen,
(!noAutoTLS) || tlsCertFile != "" || tlsKeyFile != "",
Expand Down Expand Up @@ -211,6 +257,7 @@ func main() {
proxyHeadersList,
proxyBearerToken,
args,
httpStreamingOnly,
); err != nil {
panic(err)
}
Expand Down Expand Up @@ -261,8 +308,7 @@ func main() {
rootCmd.Flags().StringVar(&proxyBearerToken, "proxy-bearer-token", getEnvWithDefault("PROXY_BEARER_TOKEN", ""), "Bearer token to add to Authorization header when proxying requests")
rootCmd.Flags().StringVar(&trustedProxies, "trusted-proxies", getEnvWithDefault("TRUSTED_PROXIES", ""), "Comma-separated list of trusted proxies (IP addresses or CIDR ranges)")
rootCmd.Flags().StringVar(&proxyHeaders, "proxy-headers", getEnvWithDefault("PROXY_HEADERS", ""), "Comma-separated list of headers to add when proxying requests (format: Header1:Value1,Header2:Value2)")
rootCmd.Flags().BoolVar(&httpStreamingOnly, "http-streaming-only", getEnvBoolWithDefault("HTTP_STREAMING_ONLY", false), "Reject SSE (GET) requests and keep the backend in HTTP streaming-only mode")

if err := rootCmd.Execute(); err != nil {
panic(err)
}
return rootCmd
}
119 changes: 119 additions & 0 deletions main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -217,3 +217,122 @@ func TestGetEnvBoolWithDefault(t *testing.T) {
})
}
}

func TestNewRootCommand_HTTPStreamingOnlyFlag(t *testing.T) {
t.Setenv("HTTP_STREAMING_ONLY", "")

var streamingOnly bool
var receivedTargets []string
runner := proxyRunnerFunc(func(listen string,
tlsListen string,
autoTLS bool,
tlsHost string,
tlsDirectoryURL string,
tlsAcceptTOS bool,
tlsCertFile string,
tlsKeyFile string,
dataPath string,
repositoryBackend string,
repositoryDSN string,
externalURL string,
googleClientID string,
googleClientSecret string,
googleAllowedUsers []string,
googleAllowedWorkspaces []string,
githubClientID string,
githubClientSecret string,
githubAllowedUsers []string,
githubAllowedOrgs []string,
oidcConfigurationURL string,
oidcClientID string,
oidcClientSecret string,
oidcScopes []string,
oidcUserIDField string,
oidcProviderName string,
oidcAllowedUsers []string,
oidcAllowedUsersGlob []string,
noProviderAutoSelect bool,
password string,
passwordHash string,
trustedProxy []string,
proxyHeaders []string,
proxyBearerToken string,
proxyTarget []string,
httpStreamingOnly bool,
) error {
streamingOnly = httpStreamingOnly
receivedTargets = proxyTarget
return nil
})

cmd := newRootCommand(runner)
cmd.SetArgs([]string{"--http-streaming-only", "http://backend"})

if err := cmd.Execute(); err != nil {
t.Fatalf("expected command to succeed, got error: %v", err)
}

if !streamingOnly {
t.Fatalf("expected httpStreamingOnly to be true when flag is set")
}
if len(receivedTargets) != 1 || receivedTargets[0] != "http://backend" {
t.Fatalf("expected proxyTarget to receive CLI args, got %v", receivedTargets)
}
}

func TestNewRootCommand_HTTPStreamingOnlyFromEnv(t *testing.T) {
t.Setenv("HTTP_STREAMING_ONLY", "true")

var streamingOnly bool
runner := proxyRunnerFunc(func(listen string,
tlsListen string,
autoTLS bool,
tlsHost string,
tlsDirectoryURL string,
tlsAcceptTOS bool,
tlsCertFile string,
tlsKeyFile string,
dataPath string,
repositoryBackend string,
repositoryDSN string,
externalURL string,
googleClientID string,
googleClientSecret string,
googleAllowedUsers []string,
googleAllowedWorkspaces []string,
githubClientID string,
githubClientSecret string,
githubAllowedUsers []string,
githubAllowedOrgs []string,
oidcConfigurationURL string,
oidcClientID string,
oidcClientSecret string,
oidcScopes []string,
oidcUserIDField string,
oidcProviderName string,
oidcAllowedUsers []string,
oidcAllowedUsersGlob []string,
noProviderAutoSelect bool,
password string,
passwordHash string,
trustedProxy []string,
proxyHeaders []string,
proxyBearerToken string,
proxyTarget []string,
httpStreamingOnly bool,
) error {
streamingOnly = httpStreamingOnly
return nil
})

cmd := newRootCommand(runner)
cmd.SetArgs([]string{"http://backend"})

if err := cmd.Execute(); err != nil {
t.Fatalf("expected command to succeed, got error: %v", err)
}

if !streamingOnly {
t.Fatalf("expected httpStreamingOnly to default to true from env var")
}
}
5 changes: 4 additions & 1 deletion pkg/mcp-proxy/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ import (

var ServerShutdownTimeout = 5 * time.Second

var newProxyRouter = proxy.NewProxyRouter

func Run(
listen string,
tlsListen string,
Expand Down Expand Up @@ -70,6 +72,7 @@ func Run(
proxyHeaders []string,
proxyBearerToken string,
proxyTarget []string,
httpStreamingOnly bool,
) error {
ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt)
defer stop()
Expand Down Expand Up @@ -265,7 +268,7 @@ func Run(
if err != nil {
return fmt.Errorf("failed to create IDP router: %w", err)
}
proxyRouter, err := proxy.NewProxyRouter(externalURL, beHandler, &privKey.PublicKey, proxyHeadersMap)
proxyRouter, err := newProxyRouter(externalURL, beHandler, &privKey.PublicKey, proxyHeadersMap, httpStreamingOnly)
if err != nil {
return fmt.Errorf("failed to create proxy router: %w", err)
}
Expand Down
67 changes: 67 additions & 0 deletions pkg/mcp-proxy/main_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
package mcpproxy

import (
"crypto/rsa"
"errors"
"net/http"
"testing"

"github.com/sigbit/mcp-auth-proxy/pkg/proxy"
"github.com/stretchr/testify/require"
)

func TestRun_PassesHTTPStreamingOnlyToProxyRouter(t *testing.T) {
originalNewProxyRouter := newProxyRouter
t.Cleanup(func() {
newProxyRouter = originalNewProxyRouter
})

var streamingOnlyReceived bool
newProxyRouter = func(externalURL string, proxyHandler http.Handler, publicKey *rsa.PublicKey, proxyHeaders http.Header, httpStreamingOnly bool) (*proxy.ProxyRouter, error) {
streamingOnlyReceived = httpStreamingOnly
return nil, errors.New("proxy router init failed")
}

err := Run(
":0",
":0",
false,
"",
"",
false,
"",
"",
t.TempDir(),
"local",
"",
"http://localhost",
"",
"",
nil,
nil,
"",
"",
nil,
nil,
"",
"",
"",
nil,
"",
"",
nil,
nil,
false,
"",
"",
nil,
nil,
"",
[]string{"http://example.com"},
true,
)

require.Error(t, err)
require.Contains(t, err.Error(), "failed to create proxy router")
require.True(t, streamingOnlyReceived, "httpStreamingOnly should be forwarded to proxy router")
}
44 changes: 36 additions & 8 deletions pkg/proxy/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,23 +11,26 @@ import (
)

type ProxyRouter struct {
externalURL string
proxy http.Handler
publicKey *rsa.PublicKey
proxyHeaders http.Header
externalURL string
proxy http.Handler
publicKey *rsa.PublicKey
proxyHeaders http.Header
httpStreamingOnly bool
}

func NewProxyRouter(
externalURL string,
proxy http.Handler,
publicKey *rsa.PublicKey,
proxyHeaders http.Header,
httpStreamingOnly bool,
) (*ProxyRouter, error) {
return &ProxyRouter{
externalURL: externalURL,
proxy: proxy,
publicKey: publicKey,
proxyHeaders: proxyHeaders,
externalURL: externalURL,
proxy: proxy,
publicKey: publicKey,
proxyHeaders: proxyHeaders,
httpStreamingOnly: httpStreamingOnly,
}, nil
}

Expand Down Expand Up @@ -72,6 +75,11 @@ func (p *ProxyRouter) handleProxy(c *gin.Context) {
return
}

if p.httpStreamingOnly && isSSEGetRequest(c.Request) {
c.AbortWithStatusJSON(http.StatusMethodNotAllowed, gin.H{"error": "SSE (GET) streaming is not supported by this backend; use POST-based HTTP streaming instead"})
return
}

c.Request.Header.Del("Authorization")
for key, values := range p.proxyHeaders {
for _, value := range values {
Expand All @@ -81,3 +89,23 @@ func (p *ProxyRouter) handleProxy(c *gin.Context) {

p.proxy.ServeHTTP(c.Writer, c.Request)
}

func isSSEGetRequest(r *http.Request) bool {
if r.Method != http.MethodGet {
return false
}
accept := r.Header.Get("Accept")
if accept == "" {
return false
}
for _, value := range strings.Split(accept, ",") {
mediaType := strings.TrimSpace(strings.ToLower(value))
if idx := strings.Index(mediaType, ";"); idx != -1 {
mediaType = strings.TrimSpace(mediaType[:idx])
}
if mediaType == "text/event-stream" {
return true
}
}
return false
}
Loading