Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions cmd/thv/app/run_flags.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"fmt"
"log/slog"
"strings"
"time"

"github.com/spf13/cobra"

Expand Down Expand Up @@ -104,6 +105,9 @@ type RunFlags struct {
// Endpoint prefix for SSE endpoint URLs
EndpointPrefix string

// SessionTTL is the session inactivity timeout. Zero uses the transport default.
SessionTTL time.Duration

// Network mode
Network string

Expand Down Expand Up @@ -264,6 +268,8 @@ func AddRunFlags(cmd *cobra.Command, config *RunFlags) {
cmd.Flags().BoolVar(&config.Stateless, "stateless", false,
"Declare the server as stateless (POST-only, no SSE). "+
"Use for MCP servers implementing streamable-HTTP stateless mode.")
cmd.Flags().DurationVar(&config.SessionTTL, "session-ttl", 0,
"Session inactivity timeout (e.g., 30m, 2h); zero uses the default (2h)")
cmd.Flags().StringVar(&config.EndpointPrefix, "endpoint-prefix", "",
"Path prefix to prepend to SSE endpoint URLs (e.g., /playwright)")
cmd.Flags().StringVar(&config.Network, "network", "",
Expand Down Expand Up @@ -665,6 +671,7 @@ func buildRunnerConfig(
runner.WithAllowDockerGateway(runFlags.AllowDockerGateway),
runner.WithTrustProxyHeaders(runFlags.TrustProxyHeaders),
runner.WithStateless(runFlags.Stateless),
runner.WithSessionTTL(runFlags.SessionTTL),
runner.WithEndpointPrefix(runFlags.EndpointPrefix),
runner.WithNetworkMode(runFlags.Network),
runner.WithK8sPodPatch(runFlags.K8sPodPatch),
Expand Down
5 changes: 5 additions & 0 deletions cmd/thv/app/vmcp.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ package app

import (
"fmt"
"time"

"github.com/spf13/cobra"

Expand Down Expand Up @@ -39,6 +40,7 @@ func newVMCPServeCommand() *cobra.Command {
enableEmbedding bool
embeddingModel string
embeddingImage string
sessionTTL time.Duration
)
cmd := &cobra.Command{
Use: "serve",
Expand All @@ -64,6 +66,7 @@ configuration file is needed for the common case of aggregating a local group.`,
EnableEmbedding: enableEmbedding,
EmbeddingModel: embeddingModel,
EmbeddingImage: embeddingImage,
SessionTTL: sessionTTL,
})
},
}
Expand All @@ -80,6 +83,8 @@ configuration file is needed for the common case of aggregating a local group.`,
cmd.Flags().StringVar(&host, "host", "127.0.0.1", "Host address to bind to")
cmd.Flags().IntVar(&port, "port", 4483, "Port to listen on")
cmd.Flags().BoolVar(&enableAudit, "enable-audit", false, "Enable audit logging with default configuration")
cmd.Flags().DurationVar(&sessionTTL, "session-ttl", 0,
"Session inactivity timeout (e.g., 30m, 2h); zero uses the default (30m)")
return cmd
}

Expand Down
5 changes: 5 additions & 0 deletions cmd/vmcp/app/commands.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ package app
import (
"fmt"
"log/slog"
"time"

"github.com/spf13/cobra"
"github.com/spf13/viper"
Expand Down Expand Up @@ -96,12 +97,14 @@ from all configured backend MCP servers.`,
host, _ := cmd.Flags().GetString("host")
port, _ := cmd.Flags().GetInt("port")
enableAudit, _ := cmd.Flags().GetBool("enable-audit")
sessionTTL, _ := cmd.Flags().GetDuration("session-ttl")

return vmcpcli.Serve(cmd.Context(), vmcpcli.ServeConfig{
ConfigPath: configPath,
Host: host,
Port: port,
EnableAudit: enableAudit,
SessionTTL: sessionTTL,
})
},
}
Expand All @@ -110,6 +113,8 @@ from all configured backend MCP servers.`,
cmd.Flags().String("host", "127.0.0.1", "Host address to bind to")
cmd.Flags().Int("port", 4483, "Port to listen on")
cmd.Flags().Bool("enable-audit", false, "Enable audit logging with default configuration")
cmd.Flags().Duration("session-ttl", time.Duration(0),
"Session inactivity timeout (e.g., 30m, 2h); zero uses the default (30m)")

return cmd
}
Expand Down
1 change: 1 addition & 0 deletions docs/cli/thv_run.md

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions docs/cli/thv_vmcp_serve.md

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 4 additions & 0 deletions docs/server/docs.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 4 additions & 0 deletions docs/server/swagger.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 5 additions & 0 deletions docs/server/swagger.yaml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 5 additions & 0 deletions pkg/runner/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"fmt"
"io"
"log/slog"
"time"

"github.com/stacklok/toolhive-core/permissions"
v1beta1 "github.com/stacklok/toolhive/cmd/thv-operator/api/v1beta1"
Expand Down Expand Up @@ -203,6 +204,10 @@ type RunConfig struct {
// Applies to both remote URLs and local container workloads.
Stateless bool `json:"stateless,omitempty" yaml:"stateless,omitempty"`

// SessionTTL is the inactivity timeout for proxy sessions.
// Zero uses the transport default (2h). Negative values are rejected by the builder.
SessionTTL time.Duration `json:"session_ttl,omitempty" yaml:"session_ttl,omitempty" swaggertype:"primitive,integer"`

// ProxyMode is the effective HTTP protocol the proxy uses.
// For stdio transports, this is the configured mode (sse or streamable-http).
// For direct transports (sse/streamable-http), this matches the transport type.
Expand Down
14 changes: 14 additions & 0 deletions pkg/runner/config_builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"path/filepath"
"slices"
"strings"
"time"

"github.com/stacklok/toolhive-core/permissions"
regtypes "github.com/stacklok/toolhive-core/registry/types"
Expand Down Expand Up @@ -362,6 +363,19 @@ func WithEndpointPrefix(prefix string) RunConfigBuilderOption {
}
}

// WithSessionTTL sets the inactivity timeout for proxy sessions.
// Zero is valid and means "use the transport default" (2h).
// Negative values return an error.
func WithSessionTTL(ttl time.Duration) RunConfigBuilderOption {
return func(b *runConfigBuilder) error {
if ttl < 0 {
return fmt.Errorf("session-ttl must be non-negative, got %s", ttl)
}
b.config.SessionTTL = ttl
return nil
}
}

// WithNetworkMode sets the network mode for the container.
// The network mode will be applied to the permission profile after it is loaded.
func WithNetworkMode(networkMode string) RunConfigBuilderOption {
Expand Down
51 changes: 51 additions & 0 deletions pkg/runner/config_builder_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1489,6 +1489,57 @@ func TestWithRegistryServerName(t *testing.T) {
}
}

func TestWithSessionTTL(t *testing.T) {
t.Parallel()

tests := []struct {
name string
ttl time.Duration
expectErr bool
expectedTTL time.Duration
}{
{
name: "zero is accepted and means use default",
ttl: 0,
expectErr: false,
expectedTTL: 0,
},
{
name: "positive duration is accepted",
ttl: 45 * time.Minute,
expectErr: false,
expectedTTL: 45 * time.Minute,
},
{
name: "large positive duration is accepted",
ttl: 24 * time.Hour,
expectErr: false,
expectedTTL: 24 * time.Hour,
},
{
name: "negative duration returns an error",
ttl: -1 * time.Second,
expectErr: true,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()

builder := &runConfigBuilder{config: NewRunConfig()}
err := WithSessionTTL(tt.ttl)(builder)

if tt.expectErr {
require.Error(t, err)
return
}
require.NoError(t, err)
assert.Equal(t, tt.expectedTTL, builder.config.SessionTTL)
})
}
}

func TestResolveRegistryServerName(t *testing.T) {
t.Parallel()

Expand Down
11 changes: 10 additions & 1 deletion pkg/runner/runner.go
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,14 @@ func (c *RunConfig) GetPort() int {
//
//nolint:gocyclo // This function is complex but manageable
func (r *Runner) Run(ctx context.Context) error {
// Resolve session TTL once so both the transport proxy and Redis storage use
// the same effective value, rather than each applying their own zero-fallback
// independently.
effectiveSessionTTL := r.Config.SessionTTL
if effectiveSessionTTL <= 0 {
effectiveSessionTTL = session.DefaultSessionTTL
}

// Create transport with runtime
transportConfig := types.Config{
Type: r.Config.Transport,
Expand All @@ -177,6 +185,7 @@ func (r *Runner) Run(ctx context.Context) error {
Debug: r.Config.Debug,
TrustProxyHeaders: r.Config.TrustProxyHeaders,
EndpointPrefix: r.Config.EndpointPrefix,
SessionTTL: effectiveSessionTTL,
}

// Set proxy mode for stdio transport
Expand Down Expand Up @@ -368,7 +377,7 @@ func (r *Runner) Run(ctx context.Context) error {
Password: os.Getenv(session.RedisPasswordEnvVar),
DB: int(redisCfg.DB),
KeyPrefix: keyPrefix,
}, session.DefaultSessionTTL)
}, effectiveSessionTTL)
if err != nil {
return fmt.Errorf("failed to create Redis session storage: %w", err)
}
Expand Down
3 changes: 3 additions & 0 deletions pkg/transport/factory.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ func (*Factory) Create(config types.Config, opts ...Option) (types.Transport, er
if config.SessionStorage != nil {
stdio.SetSessionStorage(config.SessionStorage)
}
stdio.SetSessionTTL(config.SessionTTL)
tr = stdio
case types.TransportTypeSSE:
httpTransport := NewHTTPTransport(
Expand All @@ -73,6 +74,7 @@ func (*Factory) Create(config types.Config, opts ...Option) (types.Transport, er
config.Middlewares...,
)
httpTransport.sessionStorage = config.SessionStorage
httpTransport.sessionTTL = config.SessionTTL
tr = httpTransport
case types.TransportTypeStreamableHTTP:
httpTransport := NewHTTPTransport(
Expand All @@ -91,6 +93,7 @@ func (*Factory) Create(config types.Config, opts ...Option) (types.Transport, er
config.Middlewares...,
)
httpTransport.sessionStorage = config.SessionStorage
httpTransport.sessionTTL = config.SessionTTL
tr = httpTransport
case types.TransportTypeInspector:
// HTTP transport is not implemented yet
Expand Down
8 changes: 8 additions & 0 deletions pkg/transport/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"os"
"strings"
"sync"
"time"

"golang.org/x/oauth2"

Expand Down Expand Up @@ -81,6 +82,10 @@ type HTTPTransport struct {
// Used for Redis-backed session sharing across replicas.
sessionStorage session.Storage

// sessionTTL overrides the inactivity timeout for sessions managed by the
// underlying proxy. Zero uses the proxy's default.
sessionTTL time.Duration

// Transparent proxy
proxy types.Proxy

Expand Down Expand Up @@ -432,6 +437,9 @@ func (t *HTTPTransport) buildProxyOptions(remoteBasePath, remoteRawQuery string)
if t.stateless {
opts = append(opts, transparent.WithStateless())
}
if t.sessionTTL > 0 {
opts = append(opts, transparent.WithSessionTTL(t.sessionTTL))
}
if t.sessionStorage != nil {
opts = append(opts, transparent.WithSessionStorage(t.sessionStorage))
}
Expand Down
Loading
Loading