在阅读这篇文章前,推荐先阅读以下内容:
WebSocketClientHandshakerFactory
WebSocketClientHandshakerFactory 是根据 URI 和协议版本创建对应 WebSocket 握手器(Handshaker)的工厂类,用于简化客户端握手流程。
/**
 * Static factory that creates the {@link WebSocketClientHandshaker} for a requested
 * WebSocket protocol version. The class only hosts static methods and cannot be instantiated.
 */
public final class WebSocketClientHandshakerFactory {

    private WebSocketClientHandshakerFactory() {
        // static holder, never instantiated
    }

    // ...

    /**
     * Creates the client handshaker for the given protocol version.
     * <p>
     * Call site noted by the original author: {@code new WebSocketClientProtocolHandler(config)}.
     *
     * @param webSocketURL            target WebSocket URI, e.g. {@code ws://example.com/chat}
     * @param version                 requested protocol version; only {@code V13} is handled here
     * @param subprotocol             subprotocol(s) the client wants to negotiate; may be {@code null}
     * @param allowExtensions         forwarded to the frame decoder of the created handshaker
     * @param customHeaders           extra headers added to the handshake request; may be {@code null}
     * @param maxFramePayloadLength   maximum payload length accepted for a single frame
     * @param performMasking          forwarded to the frame encoder of the created handshaker
     * @param allowMaskMismatch       forwarded to the frame decoder of the created handshaker
     * @param forceCloseTimeoutMillis force-close timeout in milliseconds
     * @param absoluteUpgradeUrl      whether the upgrade request uses an absolute URI
     * @param generateOriginHeader    whether an {@code Origin} header is generated when none is supplied
     * @return a handshaker for protocol version 13
     * @throws WebSocketClientHandshakeException if {@code version} is anything other than {@code V13}
     */
    public static WebSocketClientHandshaker newHandshaker(
            URI webSocketURL, WebSocketVersion version, String subprotocol,
            boolean allowExtensions, HttpHeaders customHeaders, int maxFramePayloadLength,
            boolean performMasking, boolean allowMaskMismatch, long forceCloseTimeoutMillis,
            boolean absoluteUpgradeUrl, boolean generateOriginHeader) {
        // Honour the requested version instead of silently handing out a V13 handshaker
        // for a version this factory cannot speak.
        if (version == WebSocketVersion.V13) {
            return new WebSocketClientHandshaker13(
                    webSocketURL, subprotocol, allowExtensions, customHeaders,
                    maxFramePayloadLength, performMasking, allowMaskMismatch, forceCloseTimeoutMillis,
                    absoluteUpgradeUrl, generateOriginHeader);
        }
        throw new WebSocketClientHandshakeException("Protocol version " + version + " not supported.");
    }
}
WebSocketClientHandshaker13
WebSocketClientHandshaker13 是实现 WebSocket 协议 RFC 6455(版本 13)的客户端握手器,负责构造握手请求、验证响应并完成协议升级。
public class WebSocketClientHandshaker13 extends WebSocketClientHandshaker {
// Handed to the WebSocket13FrameDecoder created in newWebsocketDecoder().
private final boolean allowExtensions;
// Handed to the WebSocket13FrameEncoder created in newWebSocketEncoder().
private final boolean performMasking;
// Handed to the WebSocket13FrameDecoder created in newWebsocketDecoder().
private final boolean allowMaskMismatch;
// Nonce sent as Sec-WebSocket-Key by newHandshakeRequest(); verify() reads it back
// to compute the expected Sec-WebSocket-Accept value.
private volatile String sentNonce;
/**
 * Creates a protocol version 13 client handshaker.
 * <p>
 * Call site noted by the original author: {@code WebSocketClientHandshakerFactory.newHandshaker}.
 * Everything except the three masking/extension flags is forwarded to the superclass,
 * with the version fixed to {@code WebSocketVersion.V13}.
 *
 * @param webSocketURL            target WebSocket URI
 * @param subprotocol             subprotocol(s) the client wants to negotiate
 * @param allowExtensions         later passed to the frame decoder
 * @param customHeaders           extra headers for the handshake request
 * @param maxFramePayloadLength   maximum payload length of a single frame
 * @param performMasking          later passed to the frame encoder
 * @param allowMaskMismatch       later passed to the frame decoder
 * @param forceCloseTimeoutMillis force-close timeout in milliseconds
 * @param absoluteUpgradeUrl      whether the upgrade request uses an absolute URI
 * @param generateOriginHeader    whether an Origin header is generated automatically
 */
WebSocketClientHandshaker13(
        URI webSocketURL,
        String subprotocol,
        boolean allowExtensions,
        HttpHeaders customHeaders,
        int maxFramePayloadLength,
        boolean performMasking,
        boolean allowMaskMismatch,
        long forceCloseTimeoutMillis,
        boolean absoluteUpgradeUrl,
        boolean generateOriginHeader) {
    super(
            webSocketURL,
            WebSocketVersion.V13,
            subprotocol,
            customHeaders,
            maxFramePayloadLength,
            forceCloseTimeoutMillis,
            absoluteUpgradeUrl,
            generateOriginHeader);
    this.allowMaskMismatch = allowMaskMismatch;
    this.performMasking = performMasking;
    this.allowExtensions = allowExtensions;
}
/**
 * Builds the opening handshake request that is sent to the server:
 *
 * <pre>
 * GET /chat HTTP/1.1
 * Host: server.example.com
 * Upgrade: websocket
 * Connection: Upgrade
 * Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==
 * Sec-WebSocket-Protocol: chat, superchat
 * Sec-WebSocket-Version: 13
 * </pre>
 *
 * As a side effect the generated nonce is stored in {@code sentNonce} so that
 * {@code verify} can later check the server's {@code Sec-WebSocket-Accept} value.
 *
 * @param allocator allocator used to create the empty request payload
 * @return the HTTP upgrade request
 */
@Override
protected FullHttpRequest newHandshakeRequest(BufferAllocator allocator) {
    URI wsURL = uri();
    FullHttpRequest request = new DefaultFullHttpRequest(
            HttpVersion.HTTP_1_1,
            HttpMethod.GET, upgradeUrl(wsURL),
            allocator.allocate(0)
    );
    HttpHeaders headers = request.headers();

    // Custom headers are added first; the mandatory handshake headers set below overwrite them.
    if (customHeaders != null) {
        headers.add(customHeaders);
        if (!headers.contains(HttpHeaderNames.HOST)) {
            headers.set(HttpHeaderNames.HOST, websocketHostValue(wsURL));
        }
    } else {
        headers.set(HttpHeaderNames.HOST, websocketHostValue(wsURL));
    }

    String nonce = createNonce();
    headers.set(HttpHeaderNames.UPGRADE, HttpHeaderValues.WEBSOCKET)
            .set(HttpHeaderNames.CONNECTION, HttpHeaderValues.UPGRADE)
            .set(HttpHeaderNames.SEC_WEBSOCKET_KEY, nonce);

    // An origin is serialized as scheme://host[:port] (RFC 6454), so the bare host[:port]
    // used for the Host header is not a valid Origin value; use the origin helper instead.
    if (generateOriginHeader && !headers.contains(HttpHeaderNames.ORIGIN)) {
        headers.set(HttpHeaderNames.ORIGIN, websocketOriginValue(wsURL));
    }

    // Remember the nonce for the Sec-WebSocket-Accept check in verify().
    sentNonce = nonce;

    String expectedSubprotocol = expectedSubprotocol();
    if (!StringUtil.isNullOrEmpty(expectedSubprotocol)) {
        headers.set(HttpHeaderNames.SEC_WEBSOCKET_PROTOCOL, expectedSubprotocol);
    }
    headers.set(HttpHeaderNames.SEC_WEBSOCKET_VERSION, version().toAsciiString());
    return request;
}
/**
 * Checks the server's answer to the opening handshake. A valid answer looks like:
 *
 * <pre>
 * HTTP/1.1 101 Switching Protocols
 * Upgrade: websocket
 * Connection: Upgrade
 * Sec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=
 * Sec-WebSocket-Protocol: chat
 * </pre>
 *
 * The status, the {@code Upgrade} and {@code Connection} headers, and the
 * {@code Sec-WebSocket-Accept} value (compared against the one derived from the
 * nonce stored by {@code newHandshakeRequest}) are checked in that order.
 *
 * @param response HTTP response returned by the server for the handshake request
 * @throws WebSocketClientHandshakeException if any of the checks fails
 */
@Override
protected void verify(FullHttpResponse response) {
    final HttpResponseStatus responseStatus = response.status();
    if (!HttpResponseStatus.SWITCHING_PROTOCOLS.equals(responseStatus)) {
        throw new WebSocketClientHandshakeException(
                "Invalid handshake response status: " + responseStatus, response);
    }

    final HttpHeaders responseHeaders = response.headers();

    final CharSequence upgradeValue = responseHeaders.get(HttpHeaderNames.UPGRADE);
    final boolean upgradesToWebSocket = HttpHeaderValues.WEBSOCKET.contentEqualsIgnoreCase(upgradeValue);
    if (!upgradesToWebSocket) {
        throw new WebSocketClientHandshakeException(
                "Invalid handshake response upgrade: " + upgradeValue, response);
    }

    final boolean connectionIsUpgrade =
            responseHeaders.containsIgnoreCase(HttpHeaderNames.CONNECTION, HttpHeaderValues.UPGRADE);
    if (!connectionIsUpgrade) {
        throw new WebSocketClientHandshakeException(
                "Invalid handshake response connection: " + responseHeaders.get(HttpHeaderNames.CONNECTION),
                response);
    }

    final CharSequence acceptValue = responseHeaders.get(HttpHeaderNames.SEC_WEBSOCKET_ACCEPT);
    if (acceptValue == null) {
        throw new WebSocketClientHandshakeException(
                "Invalid handshake response sec-websocket-accept: null", response);
    }

    // The header value is trimmed before it is compared with the locally computed one.
    final String expectedAccept = WebSocketUtil.calculateV13Accept(sentNonce);
    final CharSequence trimmedAccept = AsciiString.trim(acceptValue);
    if (!AsciiString.contentEquals(expectedAccept, trimmedAccept)) {
        throw new WebSocketClientHandshakeException(
                "Invalid handshake response sec-websocket-accept: " + acceptValue + ", expected: " + expectedAccept,
                response);
    }
}
/**
 * Creates the version 13 frame decoder, configured from this handshaker's
 * extension, payload-length and mask-mismatch settings.
 *
 * @return a new {@code WebSocket13FrameDecoder}
 */
@Override
protected WebSocketFrameDecoder newWebsocketDecoder() {
    final int payloadLimit = maxFramePayloadLength();
    // NOTE(review): the leading 'false' is presumably the "expect masked frames" flag
    // of WebSocket13FrameDecoder - confirm against its constructor.
    return new WebSocket13FrameDecoder(false, allowExtensions, payloadLimit, allowMaskMismatch);
}
/**
 * Creates the version 13 frame encoder, configured with this handshaker's
 * {@code performMasking} setting.
 *
 * @return a new {@code WebSocket13FrameEncoder}
 */
@Override
protected WebSocketFrameEncoder newWebSocketEncoder() {
    final WebSocketFrameEncoder frameEncoder = new WebSocket13FrameEncoder(performMasking);
    return frameEncoder;
}
/**
 * Updates the force-close timeout by delegating to the superclass and returns this
 * instance under its concrete type so calls can be chained.
 *
 * @param forceCloseTimeoutMillis new timeout in milliseconds
 * @return this handshaker
 */
@Override
public WebSocketClientHandshaker13 setForceCloseTimeoutMillis(long forceCloseTimeoutMillis) {
super.setForceCloseTimeoutMillis(forceCloseTimeoutMillis);
return this;
}
/**
 * Produces the value used for the {@code Sec-WebSocket-Key} header:
 * 16 random bytes, Base64-encoded.
 *
 * @return the Base64 text of a fresh 16-byte random nonce
 */
private static String createNonce() {
    return WebSocketUtil.base64(WebSocketUtil.randomBytes(16));
}
}
WebSocketClientHandshaker
public abstract class WebSocketClientHandshaker {
// NOTE(review): presumably the default for forceCloseTimeoutMillis (10 s); its use is not shown in this excerpt.
protected static final int DEFAULT_FORCE_CLOSE_TIMEOUT_MILLIS = 10000;
// Target address of the handshake, e.g. ws://example.com/chat
private final URI uri;
// Protocol version; determines the format of the handshake request and of data frames (e.g. RFC 6455)
private final WebSocketVersion version;
// Whether the handshake has completed; volatile so the value is visible across threads
private volatile boolean handshakeComplete;
// After the handshake, if waiting for the WebSocket connection to close times out, a forced close is triggered.
private volatile long forceCloseTimeoutMillis;
// Marks whether the force-close procedure has been initialized; updated atomically via AtomicIntegerFieldUpdater
private volatile int forceCloseInit;
private static final AtomicIntegerFieldUpdater<WebSocketClientHandshaker> FORCE_CLOSE_INIT_UPDATER = AtomicIntegerFieldUpdater.newUpdater(WebSocketClientHandshaker.class, "forceCloseInit");
// Marks whether the force-close procedure has completed.
private volatile boolean forceCloseComplete;
// Subprotocol(s) the client wants to negotiate during the handshake, e.g. a video or chat subprotocol name
private final String expectedSubprotocol;
// Subprotocol confirmed by the server; only has a value once the handshake succeeded.
private volatile String actualSubprotocol;
// Used when building the handshake request; lets the user pass custom header information.
protected final HttpHeaders customHeaders;
// Upper bound on the payload length of a single WebSocket frame, guarding against oversized data exhausting memory.
private final int maxFramePayloadLength;
// Whether the handshake request uses an absolute URI as the upgrade URL; typically for special proxy or protocol setups
private final boolean absoluteUpgradeUrl;
// Whether the Origin request header is generated automatically
protected final boolean generateOriginHeader;
/**
 * Stores the handshake configuration shared by all client handshaker implementations.
 *
 * @param uri                     target address of the handshake
 * @param version                 protocol version spoken by the concrete handshaker
 * @param subprotocol             subprotocol(s) the client wants to negotiate; kept as the expected subprotocol
 * @param customHeaders           user-supplied headers for the handshake request
 * @param maxFramePayloadLength   maximum payload length of a single frame
 * @param forceCloseTimeoutMillis force-close timeout in milliseconds
 * @param absoluteUpgradeUrl      whether the upgrade request uses an absolute URI
 * @param generateOriginHeader    whether an Origin header is generated automatically
 */
protected WebSocketClientHandshaker(
        URI uri,
        WebSocketVersion version,
        String subprotocol,
        HttpHeaders customHeaders,
        int maxFramePayloadLength,
        long forceCloseTimeoutMillis,
        boolean absoluteUpgradeUrl,
        boolean generateOriginHeader) {
    // Connection target and protocol.
    this.uri = uri;
    this.version = version;
    this.expectedSubprotocol = subprotocol;

    // Request construction options.
    this.customHeaders = customHeaders;
    this.absoluteUpgradeUrl = absoluteUpgradeUrl;
    this.generateOriginHeader = generateOriginHeader;

    // Limits and timeouts.
    this.maxFramePayloadLength = maxFramePayloadLength;
    this.forceCloseTimeoutMillis = forceCloseTimeoutMillis;
}
/**
 * Starts the opening handshake on the given channel.
 * <p>
 * Call site noted by the original author: {@code WebSocketClientProtocolHandshakeHandler.channelActive}.
 * <p>
 * The method checks that the pipeline can decode HTTP responses and that the Host/Origin
 * headers can be produced, writes the handshake request, and - once the write succeeded -
 * installs the WebSocket frame encoder as {@code "ws-encoder"} right after the HTTP encoder.
 *
 * @param channel channel to perform the handshake on; must not be {@code null}
 * @return a future that completes when the request was written and the encoder installed,
 *         or fails with the cause of the problem
 */
public Future<Void> handshake(Channel channel) {
    requireNonNull(channel, "channel");

    // The pipeline needs an HTTP response decoder, standalone or as part of the client codec.
    final ChannelPipeline pipeline = channel.pipeline();
    final boolean canDecodeResponse = pipeline.get(HttpResponseDecoder.class) != null
            || pipeline.get(HttpClientCodec.class) != null;
    if (!canDecodeResponse) {
        return channel.newFailedFuture(new IllegalStateException("ChannelPipeline does not contain " + "an HttpResponseDecoder or HttpClientCodec"));
    }

    // Without a host in the URI, Host (and Origin, when requested) must come from customHeaders.
    if (uri.getHost() == null) {
        final boolean hostSupplied = customHeaders != null && customHeaders.contains(HttpHeaderNames.HOST);
        if (!hostSupplied) {
            return channel.newFailedFuture(new IllegalArgumentException("Cannot generate the 'host' header value," + " webSocketURI should contain host or passed through customHeaders"));
        }
        if (generateOriginHeader && !customHeaders.contains(HttpHeaderNames.ORIGIN)) {
            final String originHeaderName = HttpHeaderNames.ORIGIN.toString();
            return channel.newFailedFuture(
                    new IllegalArgumentException("Cannot generate the '" + originHeaderName + "' header" + " value, webSocketURI should contain host or disable generateOriginHeader or pass value" + " through customHeaders"));
        }
    }

    // Build the request and write it out asynchronously.
    final FullHttpRequest upgradeRequest = newHandshakeRequest(channel.bufferAllocator());
    final Promise<Void> handshakePromise = channel.newPromise();
    channel.writeAndFlush(upgradeRequest).addListener(channel, (ch, writeResult) -> {
        if (!writeResult.isSuccess()) {
            handshakePromise.setFailure(writeResult.cause());
            return;
        }

        // Locate the HTTP request encoder (standalone or the client codec) ...
        final ChannelPipeline p = ch.pipeline();
        ChannelHandlerContext encoderCtx = p.context(HttpRequestEncoder.class);
        if (encoderCtx == null) {
            encoderCtx = p.context(HttpClientCodec.class);
        }
        if (encoderCtx == null) {
            handshakePromise.setFailure(new IllegalStateException("ChannelPipeline does not contain " + "an HttpRequestEncoder or HttpClientCodec"));
            return;
        }

        // ... and place the WebSocket frame encoder directly after it.
        p.addAfter(encoderCtx.name(), "ws-encoder", newWebSocketEncoder());
        handshakePromise.setSuccess(null);
    });
    return handshakePromise.asFuture();
}
/**
 * Validates the server's handshake response and switches the pipeline from HTTP to
 * WebSocket framing.
 * <p>
 * Call site noted by the original author: {@code WebSocketClientProtocolHandshakeHandler.channelRead}.
 * <p>
 * Steps: run {@link #verify}, validate the negotiated subprotocol, mark the handshake
 * complete, remove the HTTP decompressor and aggregator if present, add the WebSocket
 * frame decoder as {@code "ws-decoder"}, and schedule removal of the HTTP decoder/codec
 * on the channel's executor.
 *
 * @param channel  channel whose pipeline is reconfigured
 * @param response the server's response to the handshake request
 * @throws WebSocketClientHandshakeException if the subprotocol returned by the server is not
 *         acceptable ({@link #verify} implementations may throw for other response defects)
 * @throws IllegalStateException if the pipeline contains neither an {@code HttpResponseDecoder}
 *         nor an {@code HttpClientCodec}
 */
public final void finishHandshake(Channel channel, FullHttpResponse response) {
    verify(response);

    // Subprotocol returned by the server (trimmed when present).
    CharSequence receivedProtocol = response.headers().get(HttpHeaderNames.SEC_WEBSOCKET_PROTOCOL);
    receivedProtocol = receivedProtocol != null ? AsciiString.trim(receivedProtocol) : null;
    // Subprotocol(s) requested by the client; null is treated like "none requested".
    String expectedProtocol = expectedSubprotocol != null ? expectedSubprotocol : "";
    boolean protocolValid = false;

    if (expectedProtocol.isEmpty() && receivedProtocol == null) {
        // Nothing requested and nothing returned: accept, echoing the requested value (null or "").
        protocolValid = true;
        setActualSubprotocol(expectedSubprotocol);
    } else if (!expectedProtocol.isEmpty() && receivedProtocol != null && receivedProtocol.length() > 0) {
        // Something requested and something returned: it must be one of the comma-separated requested names.
        for (String protocol : expectedProtocol.split(",")) {
            if (AsciiString.contentEquals(protocol.trim(), receivedProtocol)) {
                protocolValid = true;
                setActualSubprotocol(receivedProtocol.toString());
                break;
            }
        }
    }
    // Every other combination is a mismatch.
    if (!protocolValid) {
        throw new WebSocketClientHandshakeException(String.format(
                "Invalid subprotocol. Actual: %s. Expected one of: %s",
                receivedProtocol, expectedSubprotocol), response);
    }

    setHandshakeComplete();

    final ChannelPipeline p = channel.pipeline();
    // HTTP content decompression and aggregation are not needed for WebSocket traffic.
    HttpContentDecompressor decompressor = p.get(HttpContentDecompressor.class);
    if (decompressor != null) {
        p.remove(decompressor);
    }
    HttpObjectAggregator aggregator = p.get(HttpObjectAggregator.class);
    if (aggregator != null) {
        p.remove(aggregator);
    }

    // Replace the HTTP decoding side with the WebSocket frame decoder:
    // 1. HttpClientCodec: drop its outbound part, add "ws-decoder" after it, then remove the codec later.
    // 2. Standalone HttpResponseDecoder: remove the HttpRequestEncoder if present, add "ws-decoder"
    //    after the decoder, then remove the decoder later.
    ChannelHandlerContext ctx = p.context(HttpResponseDecoder.class);
    if (ctx == null) {
        ctx = p.context(HttpClientCodec.class);
        if (ctx == null) {
            // The lookups above are for the response decoder / client codec, so name those
            // (this also matches the message used for the same check in handshake()).
            throw new IllegalStateException("ChannelPipeline does not contain " +
                    "an HttpResponseDecoder or HttpClientCodec");
        }
        final HttpClientCodec codec = (HttpClientCodec) ctx.handler();
        codec.removeOutboundHandler();
        p.addAfter(ctx.name(), "ws-decoder", newWebsocketDecoder());
        // Removal is handed to the channel's executor rather than performed inline.
        channel.executor().execute(() -> p.remove(codec));
    } else {
        if (p.get(HttpRequestEncoder.class) != null) {
            p.remove(HttpRequestEncoder.class);
        }
        final ChannelHandlerContext context = ctx;
        p.addAfter(context.name(), "ws-decoder", newWebsocketDecoder());
        // Removal is handed to the channel's executor rather than performed inline.
        channel.executor().execute(() -> p.remove(context.handler()));
    }
}
// ...
/**
 * Builds the HTTP request that opens the handshake; {@code handshake} writes the
 * returned request to the channel.
 *
 * @param allocator buffer allocator of the channel the handshake runs on
 * @return the handshake request
 */
protected abstract FullHttpRequest newHandshakeRequest(BufferAllocator allocator);
/**
 * Validates the server's handshake response; called first by {@code finishHandshake}.
 * Implementations signal an invalid response by throwing.
 *
 * @param response the server's response to the handshake request
 */
protected abstract void verify(FullHttpResponse response);
/**
 * Creates the frame decoder that {@code finishHandshake} adds to the pipeline as {@code "ws-decoder"}.
 *
 * @return a new frame decoder
 */
protected abstract WebSocketFrameDecoder newWebsocketDecoder();
/**
 * Creates the frame encoder that {@code handshake} adds to the pipeline as {@code "ws-encoder"}
 * once the handshake request has been written successfully.
 *
 * @return a new frame encoder
 */
protected abstract WebSocketFrameEncoder newWebSocketEncoder();
}