【P2P】【Go】用 Go 语言实现 UDP 打洞（hole punching）、传输速度测试与 ping 测试

发布于:2024-12-19 ⋅ 阅读:(8) ⋅ 点赞:(0)

服务器端 udpserver/main.go

package main

import (
	"fmt"
	"net"
	"sync"
	"sync/atomic"
)

// clientCounter is the source of client IDs. main increments it atomically
// for every newly seen client address and never resets it, so IDs keep
// growing across pairings.
var clientCounter uint64

// mu is held by main while it reads or modifies its map of waiting clients.
var mu sync.Mutex

// main runs the rendezvous server for UDP hole punching on port 3478.
//
// The sender of every datagram is registered under its observed "ip:port".
// When a second, not-yet-paired client arrives, the server introduces the
// two of them (see introducePeers) and then forgets both, so the next two
// clients form a new pair. Datagrams from a client that is still waiting
// are only logged.
func main() {
	addr, err := net.ResolveUDPAddr("udp", ":3478")
	if err != nil {
		fmt.Println("Error resolving UDP address:", err)
		return
	}

	conn, err := net.ListenUDP("udp", addr)
	if err != nil {
		fmt.Println("Error listening on UDP port:", err)
		return
	}
	defer conn.Close()

	fmt.Println("UDP server is running on :3478")

	// pending holds clients that have registered but are not paired yet,
	// keyed by their "ip:port" string. The *net.UDPAddr is kept as received
	// so it never has to be re-parsed from its string form.
	type pendingClient struct {
		id   int
		addr *net.UDPAddr
	}
	pending := make(map[string]pendingClient)

	// The payload is only used within one iteration, so one buffer suffices.
	buffer := make([]byte, 1024)
	for {
		n, remoteAddr, err := conn.ReadFromUDP(buffer)
		if err != nil {
			fmt.Println("Error reading from UDP:", err)
			continue
		}

		clientKey := remoteAddr.String()
		fmt.Printf("Received from client %s: %s\n", clientKey, string(buffer[:n]))

		mu.Lock()
		known, exists := pending[clientKey]
		clientID := known.id
		if !exists {
			clientID = int(atomic.AddUint64(&clientCounter, 1))
			pending[clientKey] = pendingClient{id: clientID, addr: remoteAddr}
			fmt.Printf("New client joined with ID %d\n", clientID)

			// With two clients waiting, tell each about the other.
			if len(pending) == 2 {
				for key, other := range pending {
					if key != clientKey {
						introducePeers(conn, other.id, other.addr, clientID, remoteAddr)
						break
					}
				}
				// Start collecting the next pair from scratch.
				pending = make(map[string]pendingClient)
			}
		}
		mu.Unlock()

		if exists {
			fmt.Printf("Message from known client %d (%s): %s\n", clientID, clientKey, string(buffer[:n]))
		}
	}
}

// introducePeers sends each client of a pair the other's address.
//
// The client that was already waiting (waitingAddr) receives "1@<newAddr>".
// The client that just arrived (newAddr) receives "2@<waitingAddr>".
// The udpclient peer that gets "1@" takes the receiving role first.
// Addresses are formatted with UDPAddr.String, which also works for IPv6
// (bracketed, e.g. "[::1]:4000") and can be parsed back with
// net.ResolveUDPAddr. Send errors are logged, not returned.
func introducePeers(conn *net.UDPConn, waitingID int, waitingAddr *net.UDPAddr, newID int, newAddr *net.UDPAddr) {
	peerAddrStr := "1@" + newAddr.String()
	if _, err := conn.WriteToUDP([]byte(peerAddrStr), waitingAddr); err != nil {
		fmt.Println("Error sending peer address to existing client:", err)
	} else {
		fmt.Printf("Sent peer address %s to existing client %d\n", peerAddrStr, waitingID)
	}

	otherPeerAddrStr := "2@" + waitingAddr.String()
	if _, err := conn.WriteToUDP([]byte(otherPeerAddrStr), newAddr); err != nil {
		fmt.Println("Error sending peer address to new client:", err)
	} else {
		fmt.Printf("Sent peer address %s to new client %d\n", otherPeerAddrStr, newID)
	}
}

客户端 udpclient/main.go

package main

import (
	"flag"
	"fmt"
	"math/rand"
	"net"
	"strings"
	"time"
)

// Sizes used by the throughput and ping tests. totalDataSize is not a
// multiple of chunkSize: a throughput run transfers
// totalDataSize/chunkSize = 204 whole chunks (10,444,800 bytes).
const (
	totalDataSize  = 10 * 1024 * 1024 // 10 MiB nominal payload per throughput run
	chunkSize      = 50 * 1024        // 50 KiB per datagram, larger than a typical 1500-byte Ethernet MTU
	pingPacketSize = 64               // bytes per ping datagram
)

// dataChunk is the payload of every throughput-test datagram, and
// pingPacket the payload of every ping datagram.
var (
	dataChunk  = make([]byte, chunkSize)
	pingPacket = make([]byte, pingPacketSize)
)

// init fills the test payloads with random bytes. Receivers only check
// datagram sizes, never content, so the bytes are just filler.
//
// rand.Seed is deliberately not called. It is deprecated since Go 1.20,
// where the global source is seeded automatically. On older toolchains the
// filler is merely deterministic, which is harmless here.
func init() {
	for i := range dataChunk {
		dataChunk[i] = byte(rand.Intn(256))
	}
	for i := range pingPacket {
		pingPacket[i] = byte(rand.Intn(256))
	}
}

// main runs one peer of the UDP hole-punching test.
//
// It runs these steps in order:
//   - register with the rendezvous server given by -server;
//   - wait to be introduced to the other peer;
//   - exchange a handshake with that peer directly;
//   - run the throughput test once in each direction;
//   - finish with the ping test.
func main() {
	serverAddrPtr := flag.String("server", "127.0.0.1:3478", "STUN server address (IP:port)")
	flag.Parse()

	fmt.Printf("Using STUN server address: %s\n", *serverAddrPtr)

	udpAddr, err := net.ResolveUDPAddr("udp", *serverAddrPtr)
	if err != nil {
		fmt.Println("Error resolving UDP address:", err)
		return
	}

	localAddr, err := net.ResolveUDPAddr("udp", ":0")
	if err != nil {
		fmt.Println("Error resolving local UDP address:", err)
		return
	}

	// One socket on an ephemeral port is used both for the server and for
	// the peer, so traffic to the peer leaves from the same local port the
	// server observed.
	listener, err := net.ListenUDP("udp", localAddr)
	if err != nil {
		fmt.Println("Error listening on UDP:", err)
		return
	}
	defer listener.Close()
	fmt.Printf("Local listener created at %s\n", listener.LocalAddr())

	// The server replies "<id>@<ip:port>". Here ip:port is the OTHER peer's
	// address as seen by the server, and id picks this peer's first role.
	reply, err := getPublicAddr(listener, udpAddr)
	if err != nil {
		fmt.Println("Error getting public address:", err)
		return
	}
	fmt.Printf("Public address is %s\n", reply)

	// Slicing with the -1 that strings.Index returns for a missing "@"
	// would panic, so reject malformed replies explicitly.
	sep := strings.Index(reply, "@")
	if sep < 0 {
		fmt.Printf("Malformed reply from server, want id@ip:port: %q\n", reply)
		return
	}
	id, peerAddrStr := reply[:sep], reply[sep+1:]

	peerAddr, err := establishConnection(listener, peerAddrStr)
	if err != nil {
		fmt.Println("Error establishing connection:", err)
		return
	}

	// The peer with id "1" receives first and the other peer sends first.
	// Then the roles are swapped so throughput is measured both ways.
	first, second := runClient, runServer
	if id == "1" {
		first, second = runServer, runClient
	}
	first(listener, peerAddr)
	second(listener, peerAddr)

	// Both transfers are done; finish with the ping test.
	runPingTest(listener, peerAddr)
}

// getPublicAddr sends "Binding Request" from listener to the rendezvous
// server at stunAddr. It returns the text of the first datagram received
// back from stunAddr.
//
// Despite the name, the reply from this article's server is not this
// peer's own public address. It is "<id>@<ip:port>" describing the other
// peer, and it is only sent once a second client has registered. With no
// read deadline set, this call blocks until then. Datagrams from any other
// sender are logged and skipped instead of aborting the exchange.
//
// The reply is treated as plain text. This is not the real STUN message
// format, which would need proper STUN parsing.
func getPublicAddr(listener *net.UDPConn, stunAddr *net.UDPAddr) (string, error) {
	if _, err := listener.WriteToUDP([]byte("Binding Request"), stunAddr); err != nil {
		return "", fmt.Errorf("sending binding request to %s: %w", stunAddr, err)
	}

	buffer := make([]byte, 1024)
	for {
		n, addr, err := listener.ReadFromUDP(buffer)
		if err != nil {
			return "", fmt.Errorf("receiving reply from STUN server %s: %w", stunAddr, err)
		}
		if addr.String() != stunAddr.String() {
			fmt.Printf("Ignoring datagram from unexpected sender %s\n", addr)
			continue
		}
		return string(buffer[:n]), nil
	}
}

// establishConnection performs the handshake with the peer at publicAddr
// ("ip:port" as reported by the rendezvous server).
//
// A background reader waits for a "Connection established" datagram from
// any sender. After giving the peer a 2 s head start, one such datagram is
// sent to publicAddr. The function returns the address the peer's handshake
// actually came from; later traffic uses that address.
//
// On a send error the reader goroutine is left blocked on the socket. The
// caller is expected to give up and close it.
//
// NOTE(review): only one handshake datagram is sent and there is no
// timeout. If a NAT drops it because the peer's side of the hole is not
// open yet, the peer waits forever. Retrying would require the later test
// phases to tolerate stray handshake datagrams.
func establishConnection(listener *net.UDPConn, publicAddr string) (*net.UDPAddr, error) {
	// Resolve before starting the reader. A bad address then fails
	// immediately instead of after the sleep, and no goroutine is leaked.
	target, err := net.ResolveUDPAddr("udp", publicAddr)
	if err != nil {
		return nil, fmt.Errorf("resolving peer address %q: %w", publicAddr, err)
	}

	// Buffered so the reader can hand over the peer address and exit
	// without waiting for the receive below.
	established := make(chan *net.UDPAddr, 1)
	go func() {
		buffer := make([]byte, 1024)
		for {
			n, addr, err := listener.ReadFromUDP(buffer)
			if err != nil {
				fmt.Println("Error reading from peer:", err)
				continue
			}
			message := string(buffer[:n])
			fmt.Printf("Received from %s: %s\n", addr, message)
			if message == "Connection established" {
				established <- addr
				return
			}
		}
	}()

	time.Sleep(2 * time.Second) // let the peer start punching from its side too
	if _, err := listener.WriteToUDP([]byte("Connection established"), target); err != nil {
		return nil, fmt.Errorf("sending handshake to %s: %w", target, err)
	}

	return <-established, nil
}

func runServer(listener *net.UDPConn, peerAddr *net.UDPAddr) {
	fmt.Println("Running as server...")
	buffer := make([]byte, chunkSize)
	receivedChunks := 0
	startTime := time.Now()

	for receivedChunks < totalDataSize/chunkSize {
		n, addr, err := listener.ReadFromUDP(buffer)
		if err != nil {
			fmt.Println("Error reading from peer:", err)
			continue
		}
		if addr.String() == peerAddr.String() && n == chunkSize {
			receivedChunks++
			fmt.Printf("Received chunk %d/%d\n", receivedChunks, totalDataSize/chunkSize)
			// 模拟确认收到数据
			listener.WriteToUDP([]byte("ACK"), peerAddr)
		}
	}

	duration := time.Since(startTime)
	speed := float64(totalDataSize) / duration.Seconds() / (1024 * 1024) // MB/s
	fmt.Printf("Server received all data in %.2f seconds, speed: %.2f MB/s\n", duration.Seconds(), speed)
}

// runClient is the sending side of the throughput test.
//
// It sends dataChunk to peerAddr totalDataSize/chunkSize times, in
// stop-and-wait style. After every datagram it waits for an "ACK" from
// peerAddr. If none arrives within the timeout, or something else arrives,
// it pauses one second and resends the chunk. Finally it prints the
// elapsed time and throughput.
//
// NOTE(review): chunks carry no sequence number. When an ACK (rather than
// the chunk) is lost, the resent chunk is counted twice by the receiver.
func runClient(listener *net.UDPConn, peerAddr *net.UDPAddr) {
	fmt.Println("Running as client...")

	// Without a deadline, a lost chunk or ACK would block ReadFromUDP
	// forever and the resend path below would never run.
	const (
		totalChunks = totalDataSize / chunkSize
		ackTimeout  = 2 * time.Second
	)
	// Clear the deadline on return so later reads on the shared socket
	// (the next test phase) are not affected.
	defer func() { _ = listener.SetReadDeadline(time.Time{}) }()

	ackBuf := make([]byte, 3)
	sentChunks := 0
	startTime := time.Now()

	for sentChunks < totalChunks {
		if _, err := listener.WriteToUDP(dataChunk, peerAddr); err != nil {
			fmt.Println("Error sending data chunk:", err)
			time.Sleep(time.Second) // back off instead of spinning on a persistent error
			continue
		}
		sentChunks++

		if err := listener.SetReadDeadline(time.Now().Add(ackTimeout)); err != nil {
			fmt.Println("Error setting ACK deadline:", err)
		}
		n, addr, err := listener.ReadFromUDP(ackBuf)
		if err != nil || addr.String() != peerAddr.String() || string(ackBuf[:n]) != "ACK" {
			fmt.Println("Failed to receive ACK, resending chunk...")
			sentChunks--            // undo the count so this chunk is sent again
			time.Sleep(time.Second) // pause before resending
			continue
		}
		fmt.Printf("Sent chunk %d/%d and received ACK\n", sentChunks, totalChunks)
	}

	duration := time.Since(startTime)
	// totalDataSize is not a multiple of chunkSize, so base the speed on
	// the bytes actually sent.
	sentBytes := sentChunks * chunkSize
	speed := float64(sentBytes) / duration.Seconds() / (1024 * 1024) // MB/s
	fmt.Printf("Client sent all data in %.2f seconds, speed: %.2f MB/s\n", duration.Seconds(), speed)
}

func runPingTest(listener *net.UDPConn, peerAddr *net.UDPAddr) {
	fmt.Println("Starting ping test...")

	const numPings = 10
	pingTimes := make([]time.Duration, numPings)

	for i := 0; i < numPings; i++ {
		startTime := time.Now()
		listener.WriteToUDP(pingPacket, peerAddr)
		fmt.Printf("Sent ping packet %d/%d\n", i+1, numPings)

		buffer := make([]byte, pingPacketSize)
		n, addr, err := listener.ReadFromUDP(buffer)
		if err != nil || addr.String() != peerAddr.String() || n != pingPacketSize {
			fmt.Println("Failed to receive pong, skipping this ping...")
			continue
		}

		elapsed := time.Since(startTime)
		pingTimes[i] = elapsed
		fmt.Printf("Received pong after %.2f ms\n", float64(elapsed)/float64(time.Millisecond))
	}

	// 计算平均延时
	var totalDelay time.Duration
	for _, t := range pingTimes {
		totalDelay += t
	}
	averageDelay := totalDelay / time.Duration(numPings)
	fmt.Printf("Average ping delay over %d pings: %.2f ms\n", numPings, float64(averageDelay)/float64(time.Millisecond))
}

本机测试

速度不高的主要原因是采用了停等方式：每发送一个 50 KB 的数据块都要等待对端的 ACK，一个往返时间内只能传输一个数据块。（50 KB 并不小，它已超过常见的 1500 字节以太网 MTU，会被 IP 分片传输。）后续可以考虑引入滑动窗口，或基于 UDT 协议来优化。


网站公告

今日签到

点亮在社区的每一天
去签到