mirror of
https://github.com/Alexey71/opera-proxy.git
synced 2026-05-13 14:11:00 +00:00
Add new API settings
This commit is contained in:
@@ -0,0 +1,93 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"log"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"testing"
|
||||
|
||||
"github.com/Snawoot/opera-proxy/dialer"
|
||||
clog "github.com/Snawoot/opera-proxy/log"
|
||||
)
|
||||
|
||||
type recordingDialer struct {
|
||||
name string
|
||||
addresses []string
|
||||
}
|
||||
|
||||
func (d *recordingDialer) Dial(network, address string) (net.Conn, error) {
|
||||
return d.DialContext(context.Background(), network, address)
|
||||
}
|
||||
|
||||
func (d *recordingDialer) DialContext(_ context.Context, _ string, address string) (net.Conn, error) {
|
||||
d.addresses = append(d.addresses, address)
|
||||
return nil, errors.New(d.name)
|
||||
}
|
||||
|
||||
func testProxyLogger() *clog.CondLogger {
|
||||
return clog.NewCondLogger(log.New(io.Discard, "", 0), clog.CRITICAL)
|
||||
}
|
||||
|
||||
func TestProxyHandlerBypassesHTTPRequestsByTargetHost(t *testing.T) {
|
||||
direct := &recordingDialer{name: "direct"}
|
||||
proxied := &recordingDialer{name: "proxied"}
|
||||
bypassDialer, err := dialer.NewBypassDialer([]string{"*.example.com"}, direct, proxied)
|
||||
if err != nil {
|
||||
t.Fatalf("NewBypassDialer() error = %v", err)
|
||||
}
|
||||
|
||||
h := NewProxyHandler(bypassDialer, testProxyLogger(), "")
|
||||
req := httptest.NewRequest(http.MethodGet, "http://check.example.com/path", nil)
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
h.ServeHTTP(rr, req)
|
||||
|
||||
if rr.Code != http.StatusInternalServerError {
|
||||
t.Fatalf("ServeHTTP() status = %d, want %d", rr.Code, http.StatusInternalServerError)
|
||||
}
|
||||
if len(direct.addresses) != 1 || direct.addresses[0] != "check.example.com:80" {
|
||||
t.Fatalf("direct dialer addresses = %#v, want []string{\"check.example.com:80\"}", direct.addresses)
|
||||
}
|
||||
if len(proxied.addresses) != 0 {
|
||||
t.Fatalf("proxied dialer should not be used, got %#v", proxied.addresses)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProxyHandlerBypassesConnectRequestsByTargetHost(t *testing.T) {
|
||||
direct := &recordingDialer{name: "direct"}
|
||||
proxied := &recordingDialer{name: "proxied"}
|
||||
bypassDialer, err := dialer.NewBypassDialer([]string{"*.example.com"}, direct, proxied)
|
||||
if err != nil {
|
||||
t.Fatalf("NewBypassDialer() error = %v", err)
|
||||
}
|
||||
|
||||
h := NewProxyHandler(bypassDialer, testProxyLogger(), "")
|
||||
req := &http.Request{
|
||||
Method: http.MethodConnect,
|
||||
URL: &url.URL{Host: "secure.example.com:443"},
|
||||
Host: "secure.example.com:443",
|
||||
RequestURI: "secure.example.com:443",
|
||||
Proto: "HTTP/1.1",
|
||||
ProtoMajor: 1,
|
||||
ProtoMinor: 1,
|
||||
Header: make(http.Header),
|
||||
}
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
h.ServeHTTP(rr, req)
|
||||
|
||||
if rr.Code != http.StatusBadGateway {
|
||||
t.Fatalf("ServeHTTP() status = %d, want %d", rr.Code, http.StatusBadGateway)
|
||||
}
|
||||
if len(direct.addresses) != 1 || direct.addresses[0] != "secure.example.com:443" {
|
||||
t.Fatalf("direct dialer addresses = %#v, want []string{\"secure.example.com:443\"}", direct.addresses)
|
||||
}
|
||||
if len(proxied.addresses) != 0 {
|
||||
t.Fatalf("proxied dialer should not be used, got %#v", proxied.addresses)
|
||||
}
|
||||
}
|
||||
+21
-10
@@ -25,9 +25,10 @@ type ProxyHandler struct {
|
||||
logger *clog.CondLogger
|
||||
dialer dialer.ContextDialer
|
||||
httptransport http.RoundTripper
|
||||
fakeSNI string
|
||||
}
|
||||
|
||||
func NewProxyHandler(dialer dialer.ContextDialer, logger *clog.CondLogger) *ProxyHandler {
|
||||
func NewProxyHandler(dialer dialer.ContextDialer, logger *clog.CondLogger, fakeSNI string) *ProxyHandler {
|
||||
httptransport := &http.Transport{
|
||||
MaxIdleConns: 100,
|
||||
IdleConnTimeout: 90 * time.Second,
|
||||
@@ -39,6 +40,7 @@ func NewProxyHandler(dialer dialer.ContextDialer, logger *clog.CondLogger) *Prox
|
||||
logger: logger,
|
||||
dialer: dialer,
|
||||
httptransport: httptransport,
|
||||
fakeSNI: fakeSNI,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -53,7 +55,7 @@ func (s *ProxyHandler) HandleTunnel(wr http.ResponseWriter, req *http.Request) {
|
||||
|
||||
if req.ProtoMajor == 0 || req.ProtoMajor == 1 {
|
||||
// Upgrade client connection
|
||||
localconn, _, err := hijack(wr)
|
||||
localconn, rw, err := hijack(wr)
|
||||
if err != nil {
|
||||
s.logger.Error("Can't hijack client connection: %v", err)
|
||||
http.Error(wr, "Can't hijack client connection", http.StatusInternalServerError)
|
||||
@@ -64,12 +66,16 @@ func (s *ProxyHandler) HandleTunnel(wr http.ResponseWriter, req *http.Request) {
|
||||
// Inform client connection is built
|
||||
fmt.Fprintf(localconn, "HTTP/%d.%d 200 OK\r\n\r\n", req.ProtoMajor, req.ProtoMinor)
|
||||
|
||||
proxy(req.Context(), localconn, conn)
|
||||
clientReader := io.Reader(localconn)
|
||||
if rw != nil && rw.Reader.Buffered() > 0 {
|
||||
clientReader = io.MultiReader(rw.Reader, localconn)
|
||||
}
|
||||
proxy(req.Context(), localconn, clientReader, conn, s.fakeSNI)
|
||||
} else if req.ProtoMajor == 2 {
|
||||
wr.Header()["Date"] = nil
|
||||
wr.WriteHeader(http.StatusOK)
|
||||
flush(wr)
|
||||
proxyh2(req.Context(), req.Body, wr, conn)
|
||||
proxyh2(req.Context(), req.Body, wr, conn, s.fakeSNI)
|
||||
} else {
|
||||
s.logger.Error("Unsupported protocol version: %s", req.Proto)
|
||||
http.Error(wr, "Unsupported protocol version.", http.StatusBadRequest)
|
||||
@@ -115,16 +121,21 @@ func (s *ProxyHandler) ServeHTTP(wr http.ResponseWriter, req *http.Request) {
|
||||
}
|
||||
}
|
||||
|
||||
func proxy(ctx context.Context, left, right net.Conn) {
|
||||
func proxy(ctx context.Context, left net.Conn, leftReader io.Reader, right net.Conn, fakeSNI string) {
|
||||
wg := sync.WaitGroup{}
|
||||
cpy := func(dst, src net.Conn) {
|
||||
ltr := func(dst net.Conn, src io.Reader) {
|
||||
defer wg.Done()
|
||||
copyWithSNIRewrite(dst, src, fakeSNI)
|
||||
dst.Close()
|
||||
}
|
||||
rtl := func(dst, src net.Conn) {
|
||||
defer wg.Done()
|
||||
io.Copy(dst, src)
|
||||
dst.Close()
|
||||
}
|
||||
wg.Add(2)
|
||||
go cpy(left, right)
|
||||
go cpy(right, left)
|
||||
go ltr(right, leftReader)
|
||||
go rtl(left, right)
|
||||
groupdone := make(chan struct{})
|
||||
go func() {
|
||||
wg.Wait()
|
||||
@@ -141,11 +152,11 @@ func proxy(ctx context.Context, left, right net.Conn) {
|
||||
return
|
||||
}
|
||||
|
||||
func proxyh2(ctx context.Context, leftreader io.ReadCloser, leftwriter io.Writer, right net.Conn) {
|
||||
func proxyh2(ctx context.Context, leftreader io.ReadCloser, leftwriter io.Writer, right net.Conn, fakeSNI string) {
|
||||
wg := sync.WaitGroup{}
|
||||
ltr := func(dst net.Conn, src io.Reader) {
|
||||
defer wg.Done()
|
||||
io.Copy(dst, src)
|
||||
copyWithSNIRewrite(dst, src, fakeSNI)
|
||||
dst.Close()
|
||||
}
|
||||
rtl := func(dst io.Writer, src io.Reader) {
|
||||
|
||||
+39
-2
@@ -2,14 +2,18 @@ package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net"
|
||||
"strings"
|
||||
|
||||
"github.com/Alexey71/opera-proxy/dialer"
|
||||
"github.com/things-go/go-socks5"
|
||||
"github.com/things-go/go-socks5/statute"
|
||||
)
|
||||
|
||||
func NewSocksServer(dialer dialer.ContextDialer, logger *log.Logger) (*socks5.Server, error) {
|
||||
func NewSocksServer(dialer dialer.ContextDialer, logger *log.Logger, fakeSNI string) (*socks5.Server, error) {
|
||||
opts := []socks5.Option{
|
||||
socks5.WithLogger(socks5.NewLogger(logger)),
|
||||
socks5.WithRule(
|
||||
@@ -17,12 +21,45 @@ func NewSocksServer(dialer dialer.ContextDialer, logger *log.Logger) (*socks5.Se
|
||||
EnableConnect: true,
|
||||
},
|
||||
),
|
||||
socks5.WithDial(dialer.DialContext),
|
||||
socks5.WithResolver(DummySocksResolver{}),
|
||||
socks5.WithConnectHandle(func(ctx context.Context, writer io.Writer, request *socks5.Request) error {
|
||||
return handleSocksConnect(ctx, writer, request, dialer, fakeSNI)
|
||||
}),
|
||||
}
|
||||
return socks5.NewServer(opts...), nil
|
||||
}
|
||||
|
||||
func handleSocksConnect(ctx context.Context, writer io.Writer, request *socks5.Request, upstream dialer.ContextDialer, fakeSNI string) error {
|
||||
target, err := upstream.DialContext(ctx, "tcp", request.DestAddr.String())
|
||||
if err != nil {
|
||||
reply := statute.RepHostUnreachable
|
||||
msg := err.Error()
|
||||
if strings.Contains(msg, "refused") {
|
||||
reply = statute.RepConnectionRefused
|
||||
} else if strings.Contains(msg, "network is unreachable") {
|
||||
reply = statute.RepNetworkUnreachable
|
||||
}
|
||||
if sendErr := socks5.SendReply(writer, reply, nil); sendErr != nil {
|
||||
return fmt.Errorf("failed to send reply, %v", sendErr)
|
||||
}
|
||||
return fmt.Errorf("connect to %v failed, %v", request.RawDestAddr, err)
|
||||
}
|
||||
|
||||
if err := socks5.SendReply(writer, statute.RepSuccess, target.LocalAddr()); err != nil {
|
||||
target.Close()
|
||||
return fmt.Errorf("failed to send reply, %v", err)
|
||||
}
|
||||
|
||||
clientConn, ok := writer.(net.Conn)
|
||||
if !ok {
|
||||
target.Close()
|
||||
return fmt.Errorf("writer is %T, expected net.Conn", writer)
|
||||
}
|
||||
|
||||
proxy(ctx, clientConn, request.Reader, target, fakeSNI)
|
||||
return nil
|
||||
}
|
||||
|
||||
type DummySocksResolver struct{}
|
||||
|
||||
func (_ DummySocksResolver) Resolve(ctx context.Context, name string) (context.Context, net.IP, error) {
|
||||
|
||||
@@ -0,0 +1,227 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"io"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const (
|
||||
tlsRecordHeaderLen = 5
|
||||
tlsHandshakeHeaderLen = 4
|
||||
tlsRecordTypeHandshake = 0x16
|
||||
tlsHandshakeTypeClientHello = 0x01
|
||||
tlsExtensionServerName = 0x0000
|
||||
)
|
||||
|
||||
func copyWithSNIRewrite(dst io.Writer, src io.Reader, fakeSNI string) error {
|
||||
fakeSNI = strings.TrimSpace(fakeSNI)
|
||||
if fakeSNI == "" {
|
||||
_, err := io.Copy(dst, src)
|
||||
return err
|
||||
}
|
||||
|
||||
br, ok := src.(*bufio.Reader)
|
||||
if !ok {
|
||||
br = bufio.NewReader(src)
|
||||
}
|
||||
|
||||
header, err := br.Peek(tlsRecordHeaderLen)
|
||||
if err != nil {
|
||||
_, copyErr := io.Copy(dst, br)
|
||||
if copyErr != nil {
|
||||
return copyErr
|
||||
}
|
||||
if errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
if !looksLikeTLSClientHelloRecord(header) {
|
||||
_, err = io.Copy(dst, br)
|
||||
return err
|
||||
}
|
||||
|
||||
recordLen := int(binary.BigEndian.Uint16(header[3:5]))
|
||||
record := make([]byte, tlsRecordHeaderLen+recordLen)
|
||||
n, err := io.ReadFull(br, record)
|
||||
if err != nil {
|
||||
if n > 0 {
|
||||
if _, writeErr := dst.Write(record[:n]); writeErr != nil {
|
||||
return writeErr
|
||||
}
|
||||
}
|
||||
if errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) {
|
||||
_, copyErr := io.Copy(dst, br)
|
||||
return copyErr
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
if rewritten, ok := rewriteTLSClientHelloRecordServerName(record, fakeSNI); ok {
|
||||
record = rewritten
|
||||
}
|
||||
|
||||
if _, err := dst.Write(record); err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = io.Copy(dst, br)
|
||||
return err
|
||||
}
|
||||
|
||||
func looksLikeTLSClientHelloRecord(header []byte) bool {
|
||||
return len(header) >= tlsRecordHeaderLen &&
|
||||
header[0] == tlsRecordTypeHandshake &&
|
||||
header[1] == 0x03 &&
|
||||
header[2] <= 0x04
|
||||
}
|
||||
|
||||
func rewriteTLSClientHelloRecordServerName(record []byte, fakeSNI string) ([]byte, bool) {
|
||||
if len(record) < tlsRecordHeaderLen+tlsHandshakeHeaderLen {
|
||||
return nil, false
|
||||
}
|
||||
if !looksLikeTLSClientHelloRecord(record[:tlsRecordHeaderLen]) {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
payload := record[tlsRecordHeaderLen:]
|
||||
if len(payload) < tlsHandshakeHeaderLen || payload[0] != tlsHandshakeTypeClientHello {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
handshakeLen := readUint24(payload[1:4])
|
||||
if handshakeLen > len(payload)-tlsHandshakeHeaderLen {
|
||||
// ClientHello is fragmented across multiple TLS records.
|
||||
return nil, false
|
||||
}
|
||||
|
||||
hello := payload[tlsHandshakeHeaderLen : tlsHandshakeHeaderLen+handshakeLen]
|
||||
offset := 0
|
||||
if !skipLen(hello, &offset, 2+32) {
|
||||
return nil, false
|
||||
}
|
||||
if !skipOpaque8(hello, &offset) {
|
||||
return nil, false
|
||||
}
|
||||
if !skipOpaque16(hello, &offset) {
|
||||
return nil, false
|
||||
}
|
||||
if !skipOpaque8(hello, &offset) {
|
||||
return nil, false
|
||||
}
|
||||
if offset == len(hello) {
|
||||
return nil, false
|
||||
}
|
||||
if offset+2 > len(hello) {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
extensionsLenOffset := offset
|
||||
extensionsLen := int(binary.BigEndian.Uint16(hello[offset : offset+2]))
|
||||
offset += 2
|
||||
if offset+extensionsLen > len(hello) {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
extensionsEnd := offset + extensionsLen
|
||||
for offset+4 <= extensionsEnd {
|
||||
extStart := offset
|
||||
extType := binary.BigEndian.Uint16(hello[offset : offset+2])
|
||||
extLen := int(binary.BigEndian.Uint16(hello[offset+2 : offset+4]))
|
||||
offset += 4
|
||||
if offset+extLen > extensionsEnd {
|
||||
return nil, false
|
||||
}
|
||||
if extType != tlsExtensionServerName {
|
||||
offset += extLen
|
||||
continue
|
||||
}
|
||||
|
||||
extDataStart := offset
|
||||
extDataEnd := offset + extLen
|
||||
extData := hello[extDataStart:extDataEnd]
|
||||
if len(extData) < 5 {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
serverNameListLen := int(binary.BigEndian.Uint16(extData[:2]))
|
||||
if 2+serverNameListLen > len(extData) {
|
||||
return nil, false
|
||||
}
|
||||
if extData[2] != 0x00 {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
nameLen := int(binary.BigEndian.Uint16(extData[3:5]))
|
||||
if 5+nameLen > len(extData) {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
tail := extData[5+nameLen:]
|
||||
newExtData := make([]byte, 2+1+2+len(fakeSNI)+len(tail))
|
||||
binary.BigEndian.PutUint16(newExtData[:2], uint16(1+2+len(fakeSNI)+len(tail)))
|
||||
newExtData[2] = 0x00
|
||||
binary.BigEndian.PutUint16(newExtData[3:5], uint16(len(fakeSNI)))
|
||||
copy(newExtData[5:], fakeSNI)
|
||||
copy(newExtData[5+len(fakeSNI):], tail)
|
||||
|
||||
helloStart := tlsRecordHeaderLen + tlsHandshakeHeaderLen
|
||||
extLenFieldStart := helloStart + extStart + 2
|
||||
extDataAbsStart := helloStart + extDataStart
|
||||
extDataAbsEnd := helloStart + extDataEnd
|
||||
extensionsLenFieldStart := helloStart + extensionsLenOffset
|
||||
|
||||
delta := len(newExtData) - len(extData)
|
||||
newRecord := make([]byte, 0, len(record)+delta)
|
||||
newRecord = append(newRecord, record[:extDataAbsStart]...)
|
||||
newRecord = append(newRecord, newExtData...)
|
||||
newRecord = append(newRecord, record[extDataAbsEnd:]...)
|
||||
|
||||
binary.BigEndian.PutUint16(newRecord[3:5], uint16(len(payload)+delta))
|
||||
writeUint24(newRecord[6:9], handshakeLen+delta)
|
||||
binary.BigEndian.PutUint16(newRecord[extensionsLenFieldStart:extensionsLenFieldStart+2], uint16(extensionsLen+delta))
|
||||
binary.BigEndian.PutUint16(newRecord[extLenFieldStart:extLenFieldStart+2], uint16(extLen+delta))
|
||||
|
||||
return newRecord, true
|
||||
}
|
||||
|
||||
return nil, false
|
||||
}
|
||||
|
||||
func readUint24(b []byte) int {
|
||||
return int(b[0])<<16 | int(b[1])<<8 | int(b[2])
|
||||
}
|
||||
|
||||
func writeUint24(dst []byte, v int) {
|
||||
dst[0] = byte(v >> 16)
|
||||
dst[1] = byte(v >> 8)
|
||||
dst[2] = byte(v)
|
||||
}
|
||||
|
||||
func skipLen(b []byte, offset *int, n int) bool {
|
||||
if *offset+n > len(b) {
|
||||
return false
|
||||
}
|
||||
*offset += n
|
||||
return true
|
||||
}
|
||||
|
||||
func skipOpaque8(b []byte, offset *int) bool {
|
||||
if *offset >= len(b) {
|
||||
return false
|
||||
}
|
||||
l := int(b[*offset])
|
||||
*offset++
|
||||
return skipLen(b, offset, l)
|
||||
}
|
||||
|
||||
func skipOpaque16(b []byte, offset *int) bool {
|
||||
if *offset+2 > len(b) {
|
||||
return false
|
||||
}
|
||||
l := int(binary.BigEndian.Uint16(b[*offset : *offset+2]))
|
||||
*offset += 2
|
||||
return skipLen(b, offset, l)
|
||||
}
|
||||
@@ -0,0 +1,133 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestRewriteTLSClientHelloRecordServerName(t *testing.T) {
|
||||
record := buildClientHelloRecord("example.com")
|
||||
|
||||
rewritten, ok := rewriteTLSClientHelloRecordServerName(record, "fake.example")
|
||||
if !ok {
|
||||
t.Fatal("expected ClientHello record to be rewritten")
|
||||
}
|
||||
|
||||
if got := extractServerName(t, rewritten); got != "fake.example" {
|
||||
t.Fatalf("unexpected SNI after rewrite: got %q", got)
|
||||
}
|
||||
if got := int(binary.BigEndian.Uint16(rewritten[3:5])); got != len(rewritten)-tlsRecordHeaderLen {
|
||||
t.Fatalf("unexpected TLS record length: got %d want %d", got, len(rewritten)-tlsRecordHeaderLen)
|
||||
}
|
||||
if got := readUint24(rewritten[6:9]); got != len(rewritten)-tlsRecordHeaderLen-tlsHandshakeHeaderLen {
|
||||
t.Fatalf("unexpected handshake length: got %d want %d", got, len(rewritten)-tlsRecordHeaderLen-tlsHandshakeHeaderLen)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCopyWithSNIRewritePreservesTrailingBytes(t *testing.T) {
|
||||
record := buildClientHelloRecord("example.com")
|
||||
stream := append(append([]byte{}, record...), []byte("tail")...)
|
||||
|
||||
var dst bytes.Buffer
|
||||
if err := copyWithSNIRewrite(&dst, bytes.NewReader(stream), "fake.example"); err != nil {
|
||||
t.Fatalf("copyWithSNIRewrite returned error: %v", err)
|
||||
}
|
||||
|
||||
out := dst.Bytes()
|
||||
recordLen := int(binary.BigEndian.Uint16(out[3:5])) + tlsRecordHeaderLen
|
||||
if got := extractServerName(t, out[:recordLen]); got != "fake.example" {
|
||||
t.Fatalf("unexpected SNI in output stream: got %q", got)
|
||||
}
|
||||
if tail := string(out[recordLen:]); tail != "tail" {
|
||||
t.Fatalf("unexpected trailing bytes: got %q", tail)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRewriteTLSClientHelloRecordServerNameSkipsFragmentedHello(t *testing.T) {
|
||||
record := buildClientHelloRecord("example.com")
|
||||
record[6] = 0x7f
|
||||
record[7] = 0xff
|
||||
record[8] = 0xff
|
||||
|
||||
if _, ok := rewriteTLSClientHelloRecordServerName(record, "fake.example"); ok {
|
||||
t.Fatal("expected fragmented ClientHello rewrite to be skipped")
|
||||
}
|
||||
}
|
||||
|
||||
func buildClientHelloRecord(serverName string) []byte {
|
||||
sniExtData := make([]byte, 2+1+2+len(serverName))
|
||||
binary.BigEndian.PutUint16(sniExtData[:2], uint16(1+2+len(serverName)))
|
||||
sniExtData[2] = 0x00
|
||||
binary.BigEndian.PutUint16(sniExtData[3:5], uint16(len(serverName)))
|
||||
copy(sniExtData[5:], serverName)
|
||||
|
||||
sniExt := makeExtension(tlsExtensionServerName, sniExtData)
|
||||
otherExt := makeExtension(0x002b, []byte{0x02, 0x03, 0x04})
|
||||
extensions := append(sniExt, otherExt...)
|
||||
|
||||
hello := make([]byte, 0, 128)
|
||||
hello = append(hello, 0x03, 0x03)
|
||||
hello = append(hello, bytes.Repeat([]byte{0x11}, 32)...)
|
||||
hello = append(hello, 0x00)
|
||||
hello = append(hello, 0x00, 0x02, 0x13, 0x01)
|
||||
hello = append(hello, 0x01, 0x00)
|
||||
hello = append(hello, byte(len(extensions)>>8), byte(len(extensions)))
|
||||
hello = append(hello, extensions...)
|
||||
|
||||
record := make([]byte, 0, tlsRecordHeaderLen+tlsHandshakeHeaderLen+len(hello))
|
||||
record = append(record, tlsRecordTypeHandshake, 0x03, 0x03)
|
||||
recordLen := tlsHandshakeHeaderLen + len(hello)
|
||||
record = append(record, byte(recordLen>>8), byte(recordLen))
|
||||
record = append(record, tlsHandshakeTypeClientHello)
|
||||
record = append(record, byte(len(hello)>>16), byte(len(hello)>>8), byte(len(hello)))
|
||||
record = append(record, hello...)
|
||||
return record
|
||||
}
|
||||
|
||||
func makeExtension(extType uint16, data []byte) []byte {
|
||||
ext := make([]byte, 4+len(data))
|
||||
binary.BigEndian.PutUint16(ext[:2], extType)
|
||||
binary.BigEndian.PutUint16(ext[2:4], uint16(len(data)))
|
||||
copy(ext[4:], data)
|
||||
return ext
|
||||
}
|
||||
|
||||
func extractServerName(t *testing.T, record []byte) string {
|
||||
t.Helper()
|
||||
|
||||
payload := record[tlsRecordHeaderLen:]
|
||||
handshakeLen := readUint24(payload[1:4])
|
||||
hello := payload[tlsHandshakeHeaderLen : tlsHandshakeHeaderLen+handshakeLen]
|
||||
|
||||
offset := 2 + 32
|
||||
sessionIDLen := int(hello[offset])
|
||||
offset++
|
||||
offset += sessionIDLen
|
||||
|
||||
cipherSuitesLen := int(binary.BigEndian.Uint16(hello[offset : offset+2]))
|
||||
offset += 2 + cipherSuitesLen
|
||||
|
||||
compressionMethodsLen := int(hello[offset])
|
||||
offset++
|
||||
offset += compressionMethodsLen
|
||||
|
||||
extensionsLen := int(binary.BigEndian.Uint16(hello[offset : offset+2]))
|
||||
offset += 2
|
||||
extensionsEnd := offset + extensionsLen
|
||||
for offset+4 <= extensionsEnd {
|
||||
extType := binary.BigEndian.Uint16(hello[offset : offset+2])
|
||||
extLen := int(binary.BigEndian.Uint16(hello[offset+2 : offset+4]))
|
||||
offset += 4
|
||||
if extType != tlsExtensionServerName {
|
||||
offset += extLen
|
||||
continue
|
||||
}
|
||||
extData := hello[offset : offset+extLen]
|
||||
nameLen := int(binary.BigEndian.Uint16(extData[3:5]))
|
||||
return string(extData[5 : 5+nameLen])
|
||||
}
|
||||
|
||||
t.Fatal("server_name extension not found")
|
||||
return ""
|
||||
}
|
||||
Reference in New Issue
Block a user