From e66c989f6ebd025d10c875d3ee49d64fd3aecf0d Mon Sep 17 00:00:00 2001 From: michel-laterman Date: Tue, 3 Jun 2025 16:14:15 -0600 Subject: [PATCH] Add support for ProxyConnectHeader in the dialer. Add the ability to pass ProxyConnectHeader to the dialer. This set of headers will be used when a CONNECT request is made to an http(s) proxy. --- client.go | 5 ++++- client_server_test.go | 40 ++++++++++++++++++++++++++++++++++++++++ proxy.go | 19 ++++++++++++++----- 3 files changed, 58 insertions(+), 6 deletions(-) diff --git a/client.go b/client.go index 00917ea3..52e2c88e 100644 --- a/client.go +++ b/client.go @@ -87,6 +87,9 @@ type Dialer struct { // If Proxy is nil or returns a nil *URL, no proxy is used. Proxy func(*http.Request) (*url.URL, error) + // ProxyConnectHeader specifies optional headers to use during proxy connect requests. + ProxyConnectHeader http.Header + // TLSClientConfig specifies the TLS configuration to use with tls.Client. // If nil, the default configuration is used. // If NetDialTLSContext is set, Dial assumes the TLS handshake @@ -416,7 +419,7 @@ func (d *Dialer) netDialFn(ctx context.Context, proxyURL *url.URL, backendURL *u } // Proxy dialing is wrapped to implement CONNECT method and possibly proxy auth. if proxyURL != nil { - return proxyFromURL(proxyURL, netDial) + return proxyFromURL(proxyURL, netDial, d.ProxyConnectHeader) } return netDial, nil } diff --git a/client_server_test.go b/client_server_test.go index e4546aea..0e5c9129 100644 --- a/client_server_test.go +++ b/client_server_test.go @@ -242,6 +242,46 @@ func TestProxyAuthorizationDial(t *testing.T) { sendRecv(t, ws) } +func TestProxyDialConnectHeaders(t *testing.T) { + s := newServer(t) + defer s.Close() + + surl, _ := url.Parse(s.Server.URL) + + cstDialer := cstDialer // make local copy for modification on next line. + cstDialer.Proxy = http.ProxyURL(surl) + cstDialer.ProxyConnectHeader = http.Header{"User-Agent": []string{"test-proxy-agent"}} + + connect := false + origHandler := s.Server.Config.Handler + + // Capture the request Host header. + s.Server.Config.Handler = http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + t.Logf("Request headers: %v", r.Header) + userAgent := r.Header.Get("User-Agent") + if r.Method == http.MethodConnect && userAgent == "test-proxy-agent" { + connect = true + w.WriteHeader(http.StatusOK) + return + } + + if !connect { + t.Log("connect with proxy connect headers not received") + http.Error(w, "connect with proxy connect headers not received", http.StatusMethodNotAllowed) + return + } + origHandler.ServeHTTP(w, r) + }) + + ws, _, err := cstDialer.Dial(s.URL, nil) + if err != nil { + t.Fatalf("Dial: %v", err) + } + defer ws.Close() + sendRecv(t, ws) +} + func TestDial(t *testing.T) { s := newServer(t) defer s.Close() diff --git a/proxy.go b/proxy.go index d716a058..edb861a5 100644 --- a/proxy.go +++ b/proxy.go @@ -28,9 +28,9 @@ func (fn netDialerFunc) DialContext(ctx context.Context, network, addr string) ( return fn(ctx, network, addr) } -func proxyFromURL(proxyURL *url.URL, forwardDial netDialerFunc) (netDialerFunc, error) { +func proxyFromURL(proxyURL *url.URL, forwardDial netDialerFunc, connectHeader http.Header) (netDialerFunc, error) { if proxyURL.Scheme == "http" || proxyURL.Scheme == "https" { - return (&httpProxyDialer{proxyURL: proxyURL, forwardDial: forwardDial}).DialContext, nil + return (&httpProxyDialer{proxyURL: proxyURL, forwardDial: forwardDial, proxyConnectHeader: connectHeader}).DialContext, nil } dialer, err := proxy.FromURL(proxyURL, forwardDial) if err != nil { @@ -45,8 +45,13 @@ func proxyFromURL(proxyURL *url.URL, forwardDial netDialerFunc) (netDialerFunc, } type httpProxyDialer struct { - proxyURL *url.URL - forwardDial netDialerFunc + proxyURL *url.URL + forwardDial netDialerFunc + proxyConnectHeader http.Header +} + +func (hpd *httpProxyDialer) Dial(network, addr string) (net.Conn, error) { + return hpd.DialContext(context.Background(), network, addr) } func (hpd *httpProxyDialer) DialContext(ctx context.Context, network string, addr string) (net.Conn, error) { @@ -56,7 +61,11 @@ func (hpd *httpProxyDialer) DialContext(ctx context.Context, network string, add return nil, err } - connectHeader := make(http.Header) + connectHeader := hpd.proxyConnectHeader + if hpd.proxyConnectHeader == nil { + connectHeader = make(http.Header) + } + if user := hpd.proxyURL.User; user != nil { proxyUser := user.Username() if proxyPassword, passwordSet := user.Password(); passwordSet {