服务器端 udpserver/main.go
package main
import (
"fmt"
"net"
"sync"
"sync/atomic"
)
// Server-wide state used by the packet loop in main.
var (
clientCounter uint64 = 0 // running count of clients seen; incremented atomically to hand out client IDs
mu sync.Mutex // held by main while it looks up and updates its client map
)
// main runs a minimal UDP rendezvous server on port 3478.
//
// Every datagram from an unknown address registers that address as a new
// client. As soon as two clients are registered, each one is sent the other's
// address (see introducePeers) and the table is cleared so the next two
// clients form a fresh pair. Datagrams from an already registered client are
// only logged.
func main() {
	addr, err := net.ResolveUDPAddr("udp", ":3478")
	if err != nil {
		fmt.Println("Error resolving UDP address:", err)
		return
	}
	conn, err := net.ListenUDP("udp", addr)
	if err != nil {
		fmt.Println("Error listening on UDP port:", err)
		return
	}
	defer conn.Close()
	fmt.Println("UDP server is running on :3478")

	clientMap := make(map[string]int) // client address ("host:port") -> assigned ID
	// One buffer is enough: its contents are fully consumed before the next read.
	buffer := make([]byte, 1024)
	for {
		n, remoteAddr, err := conn.ReadFromUDP(buffer)
		if err != nil {
			fmt.Println("Error reading from UDP:", err)
			continue
		}
		clientKey := remoteAddr.String()
		fmt.Printf("Received from client %s: %s\n", clientKey, string(buffer[:n]))

		mu.Lock()
		clientID, exists := clientMap[clientKey]
		if !exists {
			clientID = int(atomic.AddUint64(&clientCounter, 1))
			clientMap[clientKey] = clientID
			fmt.Printf("New client joined with ID %d\n", clientID)
			// Once two clients are waiting, introduce them to each other.
			if len(clientMap) == 2 {
				introducePeers(conn, clientMap, remoteAddr, clientID)
				// Empty the table so the next two clients start a new pair.
				for k := range clientMap {
					delete(clientMap, k)
				}
			}
		}
		mu.Unlock()

		// Messages from an already registered client are only logged.
		if exists {
			fmt.Printf("Message from known client %d (%s): %s\n", clientID, clientKey, string(buffer[:n]))
		}
	}
}

// introducePeers sends each of the two clients in clientMap the address of
// the other one. The client that was already waiting receives
// "1@<new client address>" and the client that just joined (newAddr, newID)
// receives "2@<waiting client address>". Addresses are sent in host:port form
// exactly as reported by net.UDPAddr.String, so IPv6 peers are represented
// correctly as well. Send failures are logged and do not abort the exchange.
func introducePeers(conn *net.UDPConn, clientMap map[string]int, newAddr *net.UDPAddr, newID int) {
	newKey := newAddr.String()
	for key, otherID := range clientMap {
		if key == newKey {
			continue
		}
		otherAddr, err := net.ResolveUDPAddr("udp", key)
		if err != nil {
			fmt.Println("Error resolving existing client address:", err)
			return
		}

		peerAddrStr := "1@" + newKey
		if _, err := conn.WriteToUDP([]byte(peerAddrStr), otherAddr); err != nil {
			fmt.Println("Error sending peer address to existing client:", err)
		} else {
			fmt.Printf("Sent peer address %s to existing client %d\n", peerAddrStr, otherID)
		}

		otherPeerAddrStr := "2@" + otherAddr.String()
		if _, err := conn.WriteToUDP([]byte(otherPeerAddrStr), newAddr); err != nil {
			fmt.Println("Error sending peer address to new client:", err)
		} else {
			fmt.Printf("Sent peer address %s to new client %d\n", otherPeerAddrStr, newID)
		}
		return
	}
}
客户端 udpclient/main.go
package main
import (
"flag"
"fmt"
"math/rand"
"net"
"strings"
"time"
)
// Sizing parameters for the transfer test and the ping test.
const (
totalDataSize = 10 * 1024 * 1024 // 10 MB
chunkSize = 50 * 1024 // 50 KB
pingPacketSize = 64 // size in bytes of one small UDP ping packet
)
// Payload buffers sent by the tests; init fills both with random bytes.
var (
dataChunk = make([]byte, chunkSize)
pingPacket = make([]byte, pingPacketSize)
)
// init seeds the math/rand generator from the current time and then fills
// the two payload buffers, dataChunk first and pingPacket second, with
// random bytes.
func init() {
	rand.Seed(time.Now().UnixNano())

	// randomize overwrites every byte of dst with a random value in [0, 255].
	randomize := func(dst []byte) {
		for pos := 0; pos < len(dst); pos++ {
			dst[pos] = byte(rand.Intn(256))
		}
	}

	randomize(dataChunk)
	randomize(pingPacket)
}
// main drives the client side of the UDP hole-punching test.
//
// It opens a UDP socket on an ephemeral port, contacts the rendezvous server
// given by -server, and waits for a reply of the form "id@ip:port". The id
// decides which role this process plays first and ip:port is the address of
// the peer. After the handshake with the peer it runs the bulk transfer once
// in each role and finishes with the ping test.
func main() {
	serverAddrPtr := flag.String("server", "127.0.0.1:3478", "STUN server address (IP:port)")
	flag.Parse()
	fmt.Printf("Using STUN server address: %s\n", *serverAddrPtr)

	udpAddr, err := net.ResolveUDPAddr("udp", *serverAddrPtr)
	if err != nil {
		fmt.Println("Error resolving UDP address:", err)
		return
	}
	localAddr, err := net.ResolveUDPAddr("udp", ":0")
	if err != nil {
		fmt.Println("Error resolving local UDP address:", err)
		return
	}
	listener, err := net.ListenUDP("udp", localAddr)
	if err != nil {
		fmt.Println("Error listening on UDP:", err)
		return
	}
	defer listener.Close()
	fmt.Printf("Local listener created at %s\n", listener.LocalAddr())

	// Ask the server for our pairing information.
	publicAddr, err := getPublicAddr(listener, udpAddr)
	if err != nil {
		fmt.Println("Error getting public address:", err)
		return
	}
	fmt.Printf("Public address is %s\n", publicAddr)

	// The reply has the form id@ip:port. Reject anything without the
	// separator instead of slicing with a negative index.
	sep := strings.Index(publicAddr, "@")
	if sep < 0 {
		fmt.Printf("Malformed response from server, expected id@ip:port: %q\n", publicAddr)
		return
	}
	id := publicAddr[:sep]
	peerHostPort := publicAddr[sep+1:]

	// Punch through to the peer and learn the address it actually sends from.
	peerAddr, err := establishConnection(listener, peerHostPort)
	if err != nil {
		fmt.Println("Error establishing connection:", err)
		return
	}

	// The side that was given id "1" receives first; then the roles are
	// swapped so the transfer is measured in both directions.
	if id == "1" {
		runServer(listener, peerAddr)
		runClient(listener, peerAddr)
	} else {
		runClient(listener, peerAddr)
		runServer(listener, peerAddr)
	}

	// Transfer tests are done; finish with the ping test.
	runPingTest(listener, peerAddr)
}
// getPublicAddr sends a "Binding Request" datagram to stunAddr through
// listener and blocks until one datagram arrives on listener.
//
// The payload of that datagram is returned verbatim as a string; no STUN
// message parsing is performed. An error is returned when the request cannot
// be sent, when the read fails (the underlying error is wrapped and can be
// inspected with errors.Is / errors.As), or when the datagram comes from an
// address other than stunAddr.
//
// No read deadline is set here, so the call waits as long as the server takes
// to answer unless the caller has set a deadline on listener.
func getPublicAddr(listener *net.UDPConn, stunAddr *net.UDPAddr) (string, error) {
	if _, err := listener.WriteToUDP([]byte("Binding Request"), stunAddr); err != nil {
		return "", err
	}

	buffer := make([]byte, 1024)
	n, addr, err := listener.ReadFromUDP(buffer)
	if err != nil {
		return "", fmt.Errorf("failed to receive response from STUN server: %w", err)
	}
	if addr.String() != stunAddr.String() {
		return "", fmt.Errorf("failed to receive response from STUN server: unexpected packet from %s", addr)
	}
	return string(buffer[:n]), nil
}
// establishConnection performs the hole-punching handshake with the peer at
// publicAddr (host:port) over listener.
//
// It resolves publicAddr first, so an invalid address is reported immediately
// and nothing is left running in the background. It then waits two seconds to
// give the other side time to start punching, sends one
// "Connection established" datagram to the peer, and reads from listener
// until a datagram with exactly that payload arrives. The sender of that
// datagram is returned as the peer address; it is taken from the packet
// itself and is not compared with publicAddr.
//
// Datagrams with any other payload are logged and ignored. Individual read
// errors are logged and retried, but after maxReadErrors consecutive failures
// the function gives up and returns the last error instead of spinning on a
// broken socket. On error the returned address is nil.
//
// No read deadline is set here: if the peer's datagram never arrives the call
// blocks until the caller closes listener or a deadline set by the caller
// expires.
func establishConnection(listener *net.UDPConn, publicAddr string) (*net.UDPAddr, error) {
	const (
		handshake     = "Connection established"
		maxReadErrors = 10
	)

	target, err := net.ResolveUDPAddr("udp", publicAddr)
	if err != nil {
		return nil, fmt.Errorf("resolving peer address %q: %w", publicAddr, err)
	}

	// Give the other side some time to start its own hole punching.
	time.Sleep(2 * time.Second)
	if _, err := listener.WriteToUDP([]byte(handshake), target); err != nil {
		return nil, fmt.Errorf("sending handshake to %s: %w", target, err)
	}

	// Datagrams that arrived while sleeping are queued by the socket and are
	// picked up here.
	buffer := make([]byte, 1024)
	readErrors := 0
	for {
		n, addr, err := listener.ReadFromUDP(buffer)
		if err != nil {
			readErrors++
			if readErrors >= maxReadErrors {
				return nil, fmt.Errorf("reading handshake from peer: %w", err)
			}
			fmt.Println("Error reading from peer:", err)
			continue
		}
		readErrors = 0

		message := string(buffer[:n])
		fmt.Printf("Received from %s: %s\n", addr, message)
		if message == handshake {
			return addr, nil
		}
	}
}
func runServer(listener *net.UDPConn, peerAddr *net.UDPAddr) {
fmt.Println("Running as server...")
buffer := make([]byte, chunkSize)
receivedChunks := 0
startTime := time.Now()
for receivedChunks < totalDataSize/chunkSize {
n, addr, err := listener.ReadFromUDP(buffer)
if err != nil {
fmt.Println("Error reading from peer:", err)
continue
}
if addr.String() == peerAddr.String() && n == chunkSize {
receivedChunks++
fmt.Printf("Received chunk %d/%d\n", receivedChunks, totalDataSize/chunkSize)
// 模拟确认收到数据
listener.WriteToUDP([]byte("ACK"), peerAddr)
}
}
duration := time.Since(startTime)
speed := float64(totalDataSize) / duration.Seconds() / (1024 * 1024) // MB/s
fmt.Printf("Server received all data in %.2f seconds, speed: %.2f MB/s\n", duration.Seconds(), speed)
}
// runClient is the sending side of the bulk transfer test.
//
// It sends dataChunk to peerAddr totalDataSize/chunkSize times in
// stop-and-wait fashion: after each send it waits up to ackTimeout for an
// "ACK" datagram from the peer and resends the same chunk when none arrives.
// When done it prints the elapsed time and the throughput, computed from the
// number of bytes actually acknowledged.
//
// The read deadline used while waiting is cleared before returning, so later
// users of listener are not affected by it.
//
// NOTE(review): chunks and ACKs carry no sequence number, so a delayed ACK
// cannot be told apart from the ACK for a retransmission.
func runClient(listener *net.UDPConn, peerAddr *net.UDPAddr) {
	fmt.Println("Running as client...")
	const (
		totalChunks = totalDataSize / chunkSize
		ackTimeout  = time.Second
		retryDelay  = time.Second
	)
	defer listener.SetReadDeadline(time.Time{})

	// Larger than "ACK" so that a longer datagram is not truncated into a match.
	ackBuf := make([]byte, 16)
	sentChunks := 0
	startTime := time.Now()
	for sentChunks < totalChunks {
		if _, err := listener.WriteToUDP(dataChunk, peerAddr); err != nil {
			fmt.Println("Error sending data chunk:", err)
			time.Sleep(retryDelay) // do not spin on a persistent send failure
			continue
		}
		if err := awaitACK(listener, peerAddr, ackBuf, ackTimeout); err != nil {
			fmt.Println("Failed to receive ACK, resending chunk...")
			// A timeout already waited long enough; back off on anything else.
			if ne, ok := err.(net.Error); !ok || !ne.Timeout() {
				time.Sleep(retryDelay)
			}
			continue
		}
		sentChunks++
		fmt.Printf("Sent chunk %d/%d and received ACK\n", sentChunks, totalChunks)
	}

	duration := time.Since(startTime)
	// Only whole chunks are transferred, which can be slightly less than
	// totalDataSize; report the throughput of what was really sent.
	sentBytes := sentChunks * chunkSize
	speed := float64(sentBytes) / duration.Seconds() / (1024 * 1024) // MB/s
	fmt.Printf("Client sent all data in %.2f seconds, speed: %.2f MB/s\n", duration.Seconds(), speed)
}

// awaitACK waits up to timeout for a datagram from peerAddr whose payload is
// exactly "ACK", using buf as the receive buffer. Datagrams from other
// senders or with other payloads are skipped. It returns nil on success and
// the read error (a timeout once the deadline passes) otherwise.
func awaitACK(listener *net.UDPConn, peerAddr *net.UDPAddr, buf []byte, timeout time.Duration) error {
	if err := listener.SetReadDeadline(time.Now().Add(timeout)); err != nil {
		return err
	}
	peer := peerAddr.String()
	for {
		n, addr, err := listener.ReadFromUDP(buf)
		if err != nil {
			return err
		}
		if addr.String() == peer && string(buf[:n]) == "ACK" {
			return nil
		}
	}
}
func runPingTest(listener *net.UDPConn, peerAddr *net.UDPAddr) {
fmt.Println("Starting ping test...")
const numPings = 10
pingTimes := make([]time.Duration, numPings)
for i := 0; i < numPings; i++ {
startTime := time.Now()
listener.WriteToUDP(pingPacket, peerAddr)
fmt.Printf("Sent ping packet %d/%d\n", i+1, numPings)
buffer := make([]byte, pingPacketSize)
n, addr, err := listener.ReadFromUDP(buffer)
if err != nil || addr.String() != peerAddr.String() || n != pingPacketSize {
fmt.Println("Failed to receive pong, skipping this ping...")
continue
}
elapsed := time.Since(startTime)
pingTimes[i] = elapsed
fmt.Printf("Received pong after %.2f ms\n", float64(elapsed)/float64(time.Millisecond))
}
// 计算平均延时
var totalDelay time.Duration
for _, t := range pingTimes {
totalDelay += t
}
averageDelay := totalDelay / time.Duration(numPings)
fmt.Printf("Average ping delay over %d pings: %.2f ms\n", numPings, float64(averageDelay)/float64(time.Millisecond))
}
本机测试
由于数据块设置得比较小(受 MTU 限制),且每次发送后都要等待 ACK,导致速度不高。后续可以考虑基于 UDT 协议来优化。