diff --git a/README.md b/README.md index 54effb9..498abbf 100644 --- a/README.md +++ b/README.md @@ -132,7 +132,7 @@ If SurfEasy discover returns API error `801`, the app also automatically tries ` | -discover-csv | String | read proxy endpoints from CSV instead of SurfEasy discover API | | -dp-export | - | export configuration for dumbproxy | | -fetch-freeproxy-out | - | download proxy list from `https://advanced.name/freeproxy` and save it as a text file with one ip:port per line. Examples: `-fetch-freeproxy-out proxies.txt` or `-fetch-freeproxy-out D:\myproxy.txt` | -| -fake-SNI | String | domain name to use as SNI in outbound TLS and in tunneled TLS ClientHello when possible | +| -fake-SNI | String | domain name to use as SNI in communications with servers | | -init-retries | Number | number of attempts for initialization steps, zero for unlimited retry | | -init-retry-interval | Duration | delay between initialization retries (default 5s) | | -list-countries | - | list available countries and exit | diff --git a/handler/handler.go b/handler/handler.go index 3c0456b..c152452 100644 --- a/handler/handler.go +++ b/handler/handler.go @@ -41,10 +41,9 @@ type ProxyHandler struct { logger *clog.CondLogger dialer dialer.ContextDialer httptransport http.RoundTripper - fakeSNI string } -func NewProxyHandler(dialer dialer.ContextDialer, logger *clog.CondLogger, fakeSNI string) *ProxyHandler { +func NewProxyHandler(dialer dialer.ContextDialer, logger *clog.CondLogger) *ProxyHandler { httptransport := &http.Transport{ MaxIdleConns: TRANSPORT_MAX_IDLE_CONNS, MaxIdleConnsPerHost: TRANSPORT_MAX_IDLE_CONNS_PER_HOST, @@ -57,7 +56,6 @@ func NewProxyHandler(dialer dialer.ContextDialer, logger *clog.CondLogger, fakeS logger: logger, dialer: dialer, httptransport: httptransport, - fakeSNI: fakeSNI, } } @@ -72,7 +70,7 @@ func (s *ProxyHandler) HandleTunnel(wr http.ResponseWriter, req *http.Request) { if req.ProtoMajor == 0 || req.ProtoMajor == 1 { // Upgrade client connection - localconn, rw, err := hijack(wr) + localconn, _, err := hijack(wr) if err != nil { s.logger.Error("Can't hijack client connection: %v", err) http.Error(wr, "Can't hijack client connection", http.StatusInternalServerError) @@ -83,16 +81,12 @@ func (s *ProxyHandler) HandleTunnel(wr http.ResponseWriter, req *http.Request) { // Inform client connection is built fmt.Fprintf(localconn, "HTTP/%d.%d 200 OK\r\n\r\n", req.ProtoMajor, req.ProtoMinor) - clientReader := io.Reader(localconn) - if rw != nil && rw.Reader.Buffered() > 0 { - clientReader = io.MultiReader(rw.Reader, localconn) - } - proxy(req.Context(), localconn, clientReader, conn, s.fakeSNI) + proxy(req.Context(), localconn, conn) } else if req.ProtoMajor == 2 { wr.Header()["Date"] = nil wr.WriteHeader(http.StatusOK) flush(wr) - proxyh2(req.Context(), req.Body, wr, conn, s.fakeSNI) + proxyh2(req.Context(), req.Body, wr, conn) } else { s.logger.Error("Unsupported protocol version: %s", req.Proto) http.Error(wr, "Unsupported protocol version.", http.StatusBadRequest) @@ -138,14 +132,9 @@ func (s *ProxyHandler) ServeHTTP(wr http.ResponseWriter, req *http.Request) { } } -func proxy(ctx context.Context, left net.Conn, leftReader io.Reader, right net.Conn, fakeSNI string) { +func proxy(ctx context.Context, left, right net.Conn) { wg := sync.WaitGroup{} - ltr := func(dst net.Conn, src io.Reader) { - defer wg.Done() - copyWithSNIRewrite(dst, src, fakeSNI) - dst.Close() - } - rtl := func(dst, src net.Conn) { + cpy := func(dst, src net.Conn) { defer wg.Done() // Grab a pooled buffer for this copy direction. bufp := copyBufPool.Get().(*[]byte) @@ -154,8 +143,8 @@ func proxy(ctx context.Context, left net.Conn, leftReader io.Reader, right net.C dst.Close() } wg.Add(2) - go ltr(right, leftReader) - go rtl(left, right) + go cpy(left, right) + go cpy(right, left) groupdone := make(chan struct{}) go func() { wg.Wait() @@ -172,14 +161,13 @@ func proxy(ctx context.Context, left net.Conn, leftReader io.Reader, right net.C return } -func proxyh2(ctx context.Context, leftreader io.ReadCloser, leftwriter io.Writer, right net.Conn, fakeSNI string) { +func proxyh2(ctx context.Context, leftreader io.ReadCloser, leftwriter io.Writer, right net.Conn) { wg := sync.WaitGroup{} ltr := func(dst net.Conn, src io.Reader) { defer wg.Done() bufp := copyBufPool.Get().(*[]byte) defer copyBufPool.Put(bufp) io.CopyBuffer(dst, src, *bufp) - copyWithSNIRewrite(dst, src, fakeSNI) dst.Close() } rtl := func(dst io.Writer, src io.Reader) { diff --git a/handler/socks.go b/handler/socks.go index e4c3a2b..9781cc7 100644 --- a/handler/socks.go +++ b/handler/socks.go @@ -13,7 +13,7 @@ import ( "github.com/things-go/go-socks5/statute" ) -func NewSocksServer(dialer dialer.ContextDialer, logger *log.Logger, fakeSNI string) (*socks5.Server, error) { +func NewSocksServer(dialer dialer.ContextDialer, logger *log.Logger) (*socks5.Server, error) { opts := []socks5.Option{ socks5.WithLogger(socks5.NewLogger(logger)), socks5.WithRule( @@ -23,13 +23,13 @@ func NewSocksServer(dialer dialer.ContextDialer, logger *log.Logger, fakeSNI str ), socks5.WithResolver(DummySocksResolver{}), socks5.WithConnectHandle(func(ctx context.Context, writer io.Writer, request *socks5.Request) error { - return handleSocksConnect(ctx, writer, request, dialer, fakeSNI) + return handleSocksConnect(ctx, writer, request, dialer) }), } return socks5.NewServer(opts...), nil } -func handleSocksConnect(ctx context.Context, writer io.Writer, request *socks5.Request, upstream dialer.ContextDialer, fakeSNI string) error { +func handleSocksConnect(ctx context.Context, writer io.Writer, request *socks5.Request, upstream dialer.ContextDialer) error { target, err := upstream.DialContext(ctx, "tcp", request.DestAddr.String()) if err != nil { reply := statute.RepHostUnreachable @@ -56,7 +56,7 @@ func handleSocksConnect(ctx context.Context, writer io.Writer, request *socks5.R return fmt.Errorf("writer is %T, expected net.Conn", writer) } - proxy(ctx, clientConn, request.Reader, target, fakeSNI) + proxy(ctx, clientConn, target) return nil } diff --git a/handler/tls_sni.go b/handler/tls_sni.go deleted file mode 100644 index f2dc175..0000000 --- a/handler/tls_sni.go +++ /dev/null @@ -1,227 +0,0 @@ -package handler - -import ( - "bufio" - "encoding/binary" - "errors" - "io" - "strings" -) - -const ( - tlsRecordHeaderLen = 5 - tlsHandshakeHeaderLen = 4 - tlsRecordTypeHandshake = 0x16 - tlsHandshakeTypeClientHello = 0x01 - tlsExtensionServerName = 0x0000 -) - -func copyWithSNIRewrite(dst io.Writer, src io.Reader, fakeSNI string) error { - fakeSNI = strings.TrimSpace(fakeSNI) - if fakeSNI == "" { - _, err := io.Copy(dst, src) - return err - } - - br, ok := src.(*bufio.Reader) - if !ok { - br = bufio.NewReader(src) - } - - header, err := br.Peek(tlsRecordHeaderLen) - if err != nil { - _, copyErr := io.Copy(dst, br) - if copyErr != nil { - return copyErr - } - if errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) { - return nil - } - return err - } - if !looksLikeTLSClientHelloRecord(header) { - _, err = io.Copy(dst, br) - return err - } - - recordLen := int(binary.BigEndian.Uint16(header[3:5])) - record := make([]byte, tlsRecordHeaderLen+recordLen) - n, err := io.ReadFull(br, record) - if err != nil { - if n > 0 { - if _, writeErr := dst.Write(record[:n]); writeErr != nil { - return writeErr - } - } - if errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) { - _, copyErr := io.Copy(dst, br) - return copyErr - } - return err - } - - if rewritten, ok := rewriteTLSClientHelloRecordServerName(record, fakeSNI); ok { - record = rewritten - } - - if _, err := dst.Write(record); err != nil { - return err - } - _, err = io.Copy(dst, br) - return err -} - -func looksLikeTLSClientHelloRecord(header []byte) bool { - return len(header) >= tlsRecordHeaderLen && - header[0] == tlsRecordTypeHandshake && - header[1] == 0x03 && - header[2] <= 0x04 -} - -func rewriteTLSClientHelloRecordServerName(record []byte, fakeSNI string) ([]byte, bool) { - if len(record) < tlsRecordHeaderLen+tlsHandshakeHeaderLen { - return nil, false - } - if !looksLikeTLSClientHelloRecord(record[:tlsRecordHeaderLen]) { - return nil, false - } - - payload := record[tlsRecordHeaderLen:] - if len(payload) < tlsHandshakeHeaderLen || payload[0] != tlsHandshakeTypeClientHello { - return nil, false - } - - handshakeLen := readUint24(payload[1:4]) - if handshakeLen > len(payload)-tlsHandshakeHeaderLen { - // ClientHello is fragmented across multiple TLS records. - return nil, false - } - - hello := payload[tlsHandshakeHeaderLen : tlsHandshakeHeaderLen+handshakeLen] - offset := 0 - if !skipLen(hello, &offset, 2+32) { - return nil, false - } - if !skipOpaque8(hello, &offset) { - return nil, false - } - if !skipOpaque16(hello, &offset) { - return nil, false - } - if !skipOpaque8(hello, &offset) { - return nil, false - } - if offset == len(hello) { - return nil, false - } - if offset+2 > len(hello) { - return nil, false - } - - extensionsLenOffset := offset - extensionsLen := int(binary.BigEndian.Uint16(hello[offset : offset+2])) - offset += 2 - if offset+extensionsLen > len(hello) { - return nil, false - } - - extensionsEnd := offset + extensionsLen - for offset+4 <= extensionsEnd { - extStart := offset - extType := binary.BigEndian.Uint16(hello[offset : offset+2]) - extLen := int(binary.BigEndian.Uint16(hello[offset+2 : offset+4])) - offset += 4 - if offset+extLen > extensionsEnd { - return nil, false - } - if extType != tlsExtensionServerName { - offset += extLen - continue - } - - extDataStart := offset - extDataEnd := offset + extLen - extData := hello[extDataStart:extDataEnd] - if len(extData) < 5 { - return nil, false - } - - serverNameListLen := int(binary.BigEndian.Uint16(extData[:2])) - if 2+serverNameListLen > len(extData) { - return nil, false - } - if extData[2] != 0x00 { - return nil, false - } - - nameLen := int(binary.BigEndian.Uint16(extData[3:5])) - if 5+nameLen > len(extData) { - return nil, false - } - - tail := extData[5+nameLen:] - newExtData := make([]byte, 2+1+2+len(fakeSNI)+len(tail)) - binary.BigEndian.PutUint16(newExtData[:2], uint16(1+2+len(fakeSNI)+len(tail))) - newExtData[2] = 0x00 - binary.BigEndian.PutUint16(newExtData[3:5], uint16(len(fakeSNI))) - copy(newExtData[5:], fakeSNI) - copy(newExtData[5+len(fakeSNI):], tail) - - helloStart := tlsRecordHeaderLen + tlsHandshakeHeaderLen - extLenFieldStart := helloStart + extStart + 2 - extDataAbsStart := helloStart + extDataStart - extDataAbsEnd := helloStart + extDataEnd - extensionsLenFieldStart := helloStart + extensionsLenOffset - - delta := len(newExtData) - len(extData) - newRecord := make([]byte, 0, len(record)+delta) - newRecord = append(newRecord, record[:extDataAbsStart]...) - newRecord = append(newRecord, newExtData...) - newRecord = append(newRecord, record[extDataAbsEnd:]...) - - binary.BigEndian.PutUint16(newRecord[3:5], uint16(len(payload)+delta)) - writeUint24(newRecord[6:9], handshakeLen+delta) - binary.BigEndian.PutUint16(newRecord[extensionsLenFieldStart:extensionsLenFieldStart+2], uint16(extensionsLen+delta)) - binary.BigEndian.PutUint16(newRecord[extLenFieldStart:extLenFieldStart+2], uint16(extLen+delta)) - - return newRecord, true - } - - return nil, false -} - -func readUint24(b []byte) int { - return int(b[0])<<16 | int(b[1])<<8 | int(b[2]) -} - -func writeUint24(dst []byte, v int) { - dst[0] = byte(v >> 16) - dst[1] = byte(v >> 8) - dst[2] = byte(v) -} - -func skipLen(b []byte, offset *int, n int) bool { - if *offset+n > len(b) { - return false - } - *offset += n - return true -} - -func skipOpaque8(b []byte, offset *int) bool { - if *offset >= len(b) { - return false - } - l := int(b[*offset]) - *offset++ - return skipLen(b, offset, l) -} - -func skipOpaque16(b []byte, offset *int) bool { - if *offset+2 > len(b) { - return false - } - l := int(binary.BigEndian.Uint16(b[*offset : *offset+2])) - *offset += 2 - return skipLen(b, offset, l) -} diff --git a/handler/tls_sni_test.go b/handler/tls_sni_test.go deleted file mode 100644 index 5d23414..0000000 --- a/handler/tls_sni_test.go +++ /dev/null @@ -1,133 +0,0 @@ -package handler - -import ( - "bytes" - "encoding/binary" - "testing" -) - -func TestRewriteTLSClientHelloRecordServerName(t *testing.T) { - record := buildClientHelloRecord("example.com") - - rewritten, ok := rewriteTLSClientHelloRecordServerName(record, "fake.example") - if !ok { - t.Fatal("expected ClientHello record to be rewritten") - } - - if got := extractServerName(t, rewritten); got != "fake.example" { - t.Fatalf("unexpected SNI after rewrite: got %q", got) - } - if got := int(binary.BigEndian.Uint16(rewritten[3:5])); got != len(rewritten)-tlsRecordHeaderLen { - t.Fatalf("unexpected TLS record length: got %d want %d", got, len(rewritten)-tlsRecordHeaderLen) - } - if got := readUint24(rewritten[6:9]); got != len(rewritten)-tlsRecordHeaderLen-tlsHandshakeHeaderLen { - t.Fatalf("unexpected handshake length: got %d want %d", got, len(rewritten)-tlsRecordHeaderLen-tlsHandshakeHeaderLen) - } -} - -func TestCopyWithSNIRewritePreservesTrailingBytes(t *testing.T) { - record := buildClientHelloRecord("example.com") - stream := append(append([]byte{}, record...), []byte("tail")...) - - var dst bytes.Buffer - if err := copyWithSNIRewrite(&dst, bytes.NewReader(stream), "fake.example"); err != nil { - t.Fatalf("copyWithSNIRewrite returned error: %v", err) - } - - out := dst.Bytes() - recordLen := int(binary.BigEndian.Uint16(out[3:5])) + tlsRecordHeaderLen - if got := extractServerName(t, out[:recordLen]); got != "fake.example" { - t.Fatalf("unexpected SNI in output stream: got %q", got) - } - if tail := string(out[recordLen:]); tail != "tail" { - t.Fatalf("unexpected trailing bytes: got %q", tail) - } -} - -func TestRewriteTLSClientHelloRecordServerNameSkipsFragmentedHello(t *testing.T) { - record := buildClientHelloRecord("example.com") - record[6] = 0x7f - record[7] = 0xff - record[8] = 0xff - - if _, ok := rewriteTLSClientHelloRecordServerName(record, "fake.example"); ok { - t.Fatal("expected fragmented ClientHello rewrite to be skipped") - } -} - -func buildClientHelloRecord(serverName string) []byte { - sniExtData := make([]byte, 2+1+2+len(serverName)) - binary.BigEndian.PutUint16(sniExtData[:2], uint16(1+2+len(serverName))) - sniExtData[2] = 0x00 - binary.BigEndian.PutUint16(sniExtData[3:5], uint16(len(serverName))) - copy(sniExtData[5:], serverName) - - sniExt := makeExtension(tlsExtensionServerName, sniExtData) - otherExt := makeExtension(0x002b, []byte{0x02, 0x03, 0x04}) - extensions := append(sniExt, otherExt...) - - hello := make([]byte, 0, 128) - hello = append(hello, 0x03, 0x03) - hello = append(hello, bytes.Repeat([]byte{0x11}, 32)...) - hello = append(hello, 0x00) - hello = append(hello, 0x00, 0x02, 0x13, 0x01) - hello = append(hello, 0x01, 0x00) - hello = append(hello, byte(len(extensions)>>8), byte(len(extensions))) - hello = append(hello, extensions...) - - record := make([]byte, 0, tlsRecordHeaderLen+tlsHandshakeHeaderLen+len(hello)) - record = append(record, tlsRecordTypeHandshake, 0x03, 0x03) - recordLen := tlsHandshakeHeaderLen + len(hello) - record = append(record, byte(recordLen>>8), byte(recordLen)) - record = append(record, tlsHandshakeTypeClientHello) - record = append(record, byte(len(hello)>>16), byte(len(hello)>>8), byte(len(hello))) - record = append(record, hello...) - return record -} - -func makeExtension(extType uint16, data []byte) []byte { - ext := make([]byte, 4+len(data)) - binary.BigEndian.PutUint16(ext[:2], extType) - binary.BigEndian.PutUint16(ext[2:4], uint16(len(data))) - copy(ext[4:], data) - return ext -} - -func extractServerName(t *testing.T, record []byte) string { - t.Helper() - - payload := record[tlsRecordHeaderLen:] - handshakeLen := readUint24(payload[1:4]) - hello := payload[tlsHandshakeHeaderLen : tlsHandshakeHeaderLen+handshakeLen] - - offset := 2 + 32 - sessionIDLen := int(hello[offset]) - offset++ - offset += sessionIDLen - - cipherSuitesLen := int(binary.BigEndian.Uint16(hello[offset : offset+2])) - offset += 2 + cipherSuitesLen - - compressionMethodsLen := int(hello[offset]) - offset++ - offset += compressionMethodsLen - - extensionsLen := int(binary.BigEndian.Uint16(hello[offset : offset+2])) - offset += 2 - extensionsEnd := offset + extensionsLen - for offset+4 <= extensionsEnd { - extType := binary.BigEndian.Uint16(hello[offset : offset+2]) - extLen := int(binary.BigEndian.Uint16(hello[offset+2 : offset+4])) - offset += 4 - if extType != tlsExtensionServerName { - offset += extLen - continue - } - extData := hello[offset : offset+extLen] - nameLen := int(binary.BigEndian.Uint16(extData[3:5])) - return string(extData[5 : 5+nameLen]) - } - - t.Fatal("server_name extension not found") - return "" -} diff --git a/main.go b/main.go index 6345804..2195cf2 100644 --- a/main.go +++ b/main.go @@ -1070,7 +1070,7 @@ func run() int { mainLogger.Info("Starting proxy server...") if args.socksMode { - socks, initError := handler.NewSocksServer(handlerDialer, socksLogger, args.fakeSNI) + socks, initError := handler.NewSocksServer(handlerDialer, socksLogger) if initError != nil { mainLogger.Critical("Failed to start: %v", initError) return 16 @@ -1078,7 +1078,7 @@ func run() int { mainLogger.Info("Init complete.") err = socks.ListenAndServe("tcp", args.bindAddress) } else { - h := handler.NewProxyHandler(handlerDialer, proxyLogger, args.fakeSNI) + h := handler.NewProxyHandler(handlerDialer, proxyLogger) mainLogger.Info("Init complete.") err = http.ListenAndServe(args.bindAddress, h) }