mirror of
https://github.com/Alexey71/opera-proxy.git
synced 2026-05-13 14:11:00 +00:00
94 lines
2.8 KiB
Go
94 lines
2.8 KiB
Go
package handler
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"io"
|
|
"log"
|
|
"net"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"net/url"
|
|
"testing"
|
|
|
|
"github.com/Snawoot/opera-proxy/dialer"
|
|
clog "github.com/Snawoot/opera-proxy/log"
|
|
)
|
|
|
|
// recordingDialer is a test double for a dialer: it remembers every address
// it is asked to dial and never produces a connection. Each attempt fails
// with an error whose message is the dialer's name, so a test can tell which
// dialer was consulted and with which target.
type recordingDialer struct {
	// name is the message of the error returned by every dial attempt.
	name string

	// addresses holds the dial targets in the order they were requested.
	addresses []string
}

// Dial forwards to DialContext using context.Background().
func (d *recordingDialer) Dial(network, address string) (net.Conn, error) {
	return d.DialContext(context.Background(), network, address)
}

// DialContext appends address to d.addresses and returns a nil connection
// together with an error carrying d.name. The context and the network
// argument are ignored.
func (d *recordingDialer) DialContext(_ context.Context, _ string, address string) (net.Conn, error) {
	d.addresses = append(d.addresses, address)
	return nil, errors.New(d.name)
}
|
|
|
|
func testProxyLogger() *clog.CondLogger {
|
|
return clog.NewCondLogger(log.New(io.Discard, "", 0), clog.CRITICAL)
|
|
}
|
|
|
|
func TestProxyHandlerBypassesHTTPRequestsByTargetHost(t *testing.T) {
|
|
direct := &recordingDialer{name: "direct"}
|
|
proxied := &recordingDialer{name: "proxied"}
|
|
bypassDialer, err := dialer.NewBypassDialer([]string{"*.example.com"}, direct, proxied)
|
|
if err != nil {
|
|
t.Fatalf("NewBypassDialer() error = %v", err)
|
|
}
|
|
|
|
h := NewProxyHandler(bypassDialer, testProxyLogger(), "")
|
|
req := httptest.NewRequest(http.MethodGet, "http://check.example.com/path", nil)
|
|
rr := httptest.NewRecorder()
|
|
|
|
h.ServeHTTP(rr, req)
|
|
|
|
if rr.Code != http.StatusInternalServerError {
|
|
t.Fatalf("ServeHTTP() status = %d, want %d", rr.Code, http.StatusInternalServerError)
|
|
}
|
|
if len(direct.addresses) != 1 || direct.addresses[0] != "check.example.com:80" {
|
|
t.Fatalf("direct dialer addresses = %#v, want []string{\"check.example.com:80\"}", direct.addresses)
|
|
}
|
|
if len(proxied.addresses) != 0 {
|
|
t.Fatalf("proxied dialer should not be used, got %#v", proxied.addresses)
|
|
}
|
|
}
|
|
|
|
func TestProxyHandlerBypassesConnectRequestsByTargetHost(t *testing.T) {
|
|
direct := &recordingDialer{name: "direct"}
|
|
proxied := &recordingDialer{name: "proxied"}
|
|
bypassDialer, err := dialer.NewBypassDialer([]string{"*.example.com"}, direct, proxied)
|
|
if err != nil {
|
|
t.Fatalf("NewBypassDialer() error = %v", err)
|
|
}
|
|
|
|
h := NewProxyHandler(bypassDialer, testProxyLogger(), "")
|
|
req := &http.Request{
|
|
Method: http.MethodConnect,
|
|
URL: &url.URL{Host: "secure.example.com:443"},
|
|
Host: "secure.example.com:443",
|
|
RequestURI: "secure.example.com:443",
|
|
Proto: "HTTP/1.1",
|
|
ProtoMajor: 1,
|
|
ProtoMinor: 1,
|
|
Header: make(http.Header),
|
|
}
|
|
rr := httptest.NewRecorder()
|
|
|
|
h.ServeHTTP(rr, req)
|
|
|
|
if rr.Code != http.StatusBadGateway {
|
|
t.Fatalf("ServeHTTP() status = %d, want %d", rr.Code, http.StatusBadGateway)
|
|
}
|
|
if len(direct.addresses) != 1 || direct.addresses[0] != "secure.example.com:443" {
|
|
t.Fatalf("direct dialer addresses = %#v, want []string{\"secure.example.com:443\"}", direct.addresses)
|
|
}
|
|
if len(proxied.addresses) != 0 {
|
|
t.Fatalf("proxied dialer should not be used, got %#v", proxied.addresses)
|
|
}
|
|
}
|