在阅读这篇文章前,推荐先阅读:[netty5: MessageToMessageCodec & MessageToMessageEncoder & MessageToMessageDecoder]-源码分析
WebSocketProtocolHandler
`WebSocketProtocolHandler` 是 WebSocket 处理的基础抽象类,负责管理 WebSocket 帧的解码、关闭流程及通用协议逻辑。
/**
 * Common base for the WebSocket protocol handlers in this file.
 *
 * <p>Inbound, it answers ping frames with a pong, optionally drops pong frames, and forwards
 * every other frame. Outbound, it tracks whether a close frame has been written, rejects further
 * writes after that point, and can send a close frame before the channel is closed.
 */
abstract class WebSocketProtocolHandler extends MessageToMessageDecoder<WebSocketFrame> {

    /** When {@code true}, inbound pong frames are closed here instead of being forwarded. */
    private final boolean dropPongFrames;

    /** Status used for the close frame written by {@link #close}; {@code null} disables that frame. */
    private final WebSocketCloseStatus closeStatus;

    /** Delay in milliseconds before a still-pending close frame write is failed; negative disables the timer. */
    private final long forceCloseTimeoutMillis;

    /** Promise tracking the outbound close frame; stays {@code null} until one has been registered. */
    private Promise<Void> closeSent;

    /** Creates a handler that drops pong frames and sends no close frame on {@link #close}. */
    WebSocketProtocolHandler() {
        this(true);
    }

    /**
     * Creates a handler that sends no close frame on {@link #close}.
     *
     * @param dropPongFrames whether inbound pong frames are dropped rather than forwarded
     */
    WebSocketProtocolHandler(boolean dropPongFrames) {
        this(dropPongFrames, null, 0L);
    }

    /**
     * @param dropPongFrames         whether inbound pong frames are dropped rather than forwarded
     * @param closeStatus            status for the close frame sent from {@link #close}, or {@code null} for none
     * @param forceCloseTimeoutMillis delay before a pending close frame write is failed; negative disables it
     */
    WebSocketProtocolHandler(boolean dropPongFrames, WebSocketCloseStatus closeStatus, long forceCloseTimeoutMillis) {
        this.dropPongFrames = dropPongFrames;
        this.closeStatus = closeStatus;
        this.forceCloseTimeoutMillis = forceCloseTimeoutMillis;
    }

    /**
     * Not supported by this handler; frames are processed in {@link #decodeAndClose} instead.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    protected void decode(ChannelHandlerContext ctx, WebSocketFrame msg) throws Exception {
        throw new UnsupportedOperationException("WebSocketProtocolHandler use decodeAndClose().");
    }

    /**
     * Handles ping and pong frames locally and forwards every other frame down the pipeline.
     *
     * @param ctx   the handler context
     * @param frame the inbound frame; closed here when it is consumed locally
     */
    @Override
    protected void decodeAndClose(ChannelHandlerContext ctx, WebSocketFrame frame) throws Exception {
        if (frame instanceof PingWebSocketFrame) {
            // Reply with a pong carrying the ping's payload; the ping itself is closed on exit.
            try (frame) {
                ctx.writeAndFlush(new PongWebSocketFrame(frame.binaryData().send()));
            }
        } else if (dropPongFrames && frame instanceof PongWebSocketFrame) {
            frame.close();
        } else {
            ctx.fireChannelRead(frame);
            return;
        }
        // The frame was consumed here, so nothing downstream will trigger the next read.
        readIfNeeded(ctx);
    }

    /** Requests another read when the channel's {@code AUTO_READ} option is off. */
    private static void readIfNeeded(ChannelHandlerContext ctx) {
        final boolean autoRead = ctx.channel().getOption(ChannelOption.AUTO_READ);
        if (autoRead) {
            return;
        }
        ctx.read();
    }

    /**
     * Closes the channel, first sending a close frame when a close status is configured and the
     * channel is still active.
     *
     * @param ctx the handler context
     * @return a future completed with the outcome of the eventual {@code ctx.close()}
     */
    @Override
    public Future<Void> close(final ChannelHandlerContext ctx) {
        final boolean sendCloseFrame = closeStatus != null && ctx.channel().isActive();
        if (!sendCloseFrame) {
            return ctx.close();
        }

        // Reuse the close frame already in flight, otherwise write a new one.
        final Future<Void> closeFrameFuture;
        if (closeSent != null) {
            closeFrameFuture = closeSent.asFuture();
        } else {
            closeFrameFuture = write(ctx, new CloseWebSocketFrame(ctx.bufferAllocator(), closeStatus));
        }
        flush(ctx);
        applyCloseSentTimeout(ctx);

        // The channel is closed only once the close frame future has completed.
        final Promise<Void> channelClosed = ctx.newPromise();
        closeFrameFuture.addListener(done -> ctx.close().cascadeTo(channelClosed));
        return channelClosed.asFuture();
    }

    /**
     * Writes a message, refusing anything once a close frame has been registered.
     *
     * @param ctx the handler context
     * @param msg the outbound message; disposed when the write is refused
     * @return the write future, or a future failed with {@link ClosedChannelException} when refused
     */
    @Override
    public Future<Void> write(final ChannelHandlerContext ctx, Object msg) {
        if (closeSent != null) {
            Resource.dispose(msg);
            return ctx.newFailedFuture(new ClosedChannelException());
        }
        if (!(msg instanceof CloseWebSocketFrame)) {
            return ctx.write(msg);
        }
        // First close frame: register its promise so later writes are refused.
        final Promise<Void> closeFramePromise = ctx.newPromise();
        closeSent(closeFramePromise);
        ctx.write(msg).cascadeTo(closeSent);
        return closeFramePromise.asFuture();
    }

    /** Registers the promise that tracks the outbound close frame. */
    void closeSent(Promise<Void> promise) {
        closeSent = promise;
    }

    /**
     * Schedules a task that fails the close frame promise if it is still pending after
     * {@code forceCloseTimeoutMillis}; the task is cancelled once the promise completes.
     */
    private void applyCloseSentTimeout(ChannelHandlerContext ctx) {
        if (closeSent.isDone()) {
            return;
        }
        if (forceCloseTimeoutMillis < 0) {
            return;
        }

        Future<?> timeoutTask = ctx.executor().schedule(() -> {
            if (!closeSent.isDone()) {
                closeSent.tryFailure(buildHandshakeException("send close frame timed out"));
            }
        }, forceCloseTimeoutMillis, TimeUnit.MILLISECONDS);

        closeSent.asFuture().addListener(completed -> timeoutTask.cancel());
    }

    /**
     * Builds the exception used to fail the close frame promise on timeout.
     *
     * @param message the exception message
     * @return a new {@link WebSocketHandshakeException}
     */
    protected WebSocketHandshakeException buildHandshakeException(String message) {
        return new WebSocketHandshakeException(message);
    }

    /** Forwards the exception to the next handler and then closes the channel. */
    @Override
    public void channelExceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
        ctx.fireChannelExceptionCaught(cause);
        ctx.close();
    }
}
WebSocketServerProtocolHandler
`WebSocketServerProtocolHandler` 负责在服务器端管理 WebSocket 握手、帧的解码与关闭处理,并支持协议校验与异常处理。
/**
 * Server-side WebSocket protocol handler.
 *
 * <p>When added to a pipeline it installs the handshake handler and, if configured, the UTF-8
 * frame validator in front of itself. It handles inbound close frames when the configuration
 * asks for it and answers handshake failures with an HTTP 400 response.
 */
public class WebSocketServerProtocolHandler extends WebSocketProtocolHandler {

    /** Channel attribute under which the handshaker for the connection is stored. */
    private static final AttributeKey<WebSocketServerHandshaker> HANDSHAKER_ATTR_KEY =
            AttributeKey.valueOf(WebSocketServerHandshaker.class, "HANDSHAKER");

    private final WebSocketServerProtocolConfig serverConfig;

    /**
     * @param serverConfig the server protocol configuration; must not be {@code null}
     * @throws NullPointerException if {@code serverConfig} is {@code null}
     */
    public WebSocketServerProtocolHandler(WebSocketServerProtocolConfig serverConfig) {
        super(Objects.requireNonNull(serverConfig, "serverConfig").dropPongFrames(),
                serverConfig.sendCloseFrame(),
                serverConfig.forceCloseTimeoutMillis()
        );
        this.serverConfig = serverConfig;
    }

    /**
     * Adds the handshake handler and, when enabled in the decoder configuration, the UTF-8 frame
     * validator to the pipeline in front of this handler. Each is added only if the pipeline does
     * not already contain a handler of that class.
     */
    @Override
    public void handlerAdded(ChannelHandlerContext ctx) {
        ChannelPipeline cp = ctx.pipeline();
        if (cp.get(WebSocketServerProtocolHandshakeHandler.class) == null) {
            // Add the WebSocketHandshakeHandler before this one.
            cp.addBefore(ctx.name(), WebSocketServerProtocolHandshakeHandler.class.getName(),
                    new WebSocketServerProtocolHandshakeHandler(serverConfig));
        }
        if (serverConfig.decoderConfig().withUTF8Validator() && cp.get(Utf8FrameValidator.class) == null) {
            // Add the UTF-8 checking before this one.
            cp.addBefore(ctx.name(), Utf8FrameValidator.class.getName(),
                    new Utf8FrameValidator(serverConfig.decoderConfig().closeOnProtocolViolation()));
        }
    }

    /**
     * Handles inbound close frames when {@code handleCloseFrames} is enabled; every other frame is
     * delegated to the base class.
     *
     * @param ctx   the handler context
     * @param frame the inbound frame
     */
    @Override
    protected void decodeAndClose(ChannelHandlerContext ctx, WebSocketFrame frame) throws Exception {
        // On a close frame, prefer closing through the handshaker bound to the channel;
        // without one, close the connection directly. Non-close frames take the normal path.
        if (serverConfig.handleCloseFrames() && frame instanceof CloseWebSocketFrame) {
            WebSocketServerHandshaker handshaker = getHandshaker(ctx.channel());
            if (handshaker != null) {
                // Register the promise first so later outbound writes are refused by the base class.
                Promise<Void> promise = ctx.newPromise();
                closeSent(promise);
                handshaker.close(ctx, (CloseWebSocketFrame) frame).cascadeTo(promise);
            } else {
                frame.close();
                // Write an empty buffer and attach the CLOSE listener to that write.
                ctx.writeAndFlush(ctx.bufferAllocator().allocate(0)).addListener(ctx, ChannelFutureListeners.CLOSE);
            }
            return;
        }
        super.decodeAndClose(ctx, frame);
    }

    /**
     * Builds the server-specific handshake exception.
     *
     * @param message the exception message
     * @return a new {@link WebSocketServerHandshakeException}
     */
    @Override
    protected WebSocketServerHandshakeException buildHandshakeException(String message) {
        return new WebSocketServerHandshakeException(message);
    }

    /**
     * Answers a {@link WebSocketHandshakeException} with an HTTP 400 response whose body is the
     * exception message, attaching the CLOSE listener to that write. Any other exception is
     * forwarded to the next handler and the channel is closed.
     */
    @Override
    public void channelExceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
        if (cause instanceof WebSocketHandshakeException) {
            // Encode explicitly as UTF-8 so the body does not depend on the platform default
            // charset, and treat a missing message as an empty body instead of throwing a
            // NullPointerException from inside the exception handler.
            // (Fully qualified so that no additional import is required.)
            final String message = cause.getMessage();
            final byte[] bytes = message == null
                    ? new byte[0]
                    : message.getBytes(java.nio.charset.StandardCharsets.UTF_8);
            FullHttpResponse response = new DefaultFullHttpResponse(
                    HTTP_1_1, HttpResponseStatus.BAD_REQUEST,
                    ctx.bufferAllocator().allocate(bytes.length).writeBytes(bytes));
            ctx.channel().writeAndFlush(response).addListener(ctx, ChannelFutureListeners.CLOSE);
        } else {
            ctx.fireChannelExceptionCaught(cause);
            ctx.close();
        }
    }

    /**
     * @param channel the channel to inspect
     * @return the handshaker stored on the channel, or {@code null} if none has been set
     */
    static WebSocketServerHandshaker getHandshaker(Channel channel) {
        return channel.attr(HANDSHAKER_ATTR_KEY).get();
    }

    /**
     * Stores the handshaker on the channel so that close frames can later be routed through it.
     *
     * @param channel    the channel to attach the handshaker to
     * @param handshaker the handshaker for this connection
     */
    static void setHandshaker(Channel channel, WebSocketServerHandshaker handshaker) {
        channel.attr(HANDSHAKER_ATTR_KEY).set(handshaker);
    }
}
WebSocketClientProtocolHandler
`WebSocketClientProtocolHandler` 是 Netty 中用于处理 WebSocket 客户端协议升级与帧处理,并自动注入握手处理器与 UTF-8 校验器的核心 ChannelHandler。
/**
 * Client-side WebSocket protocol handler.
 *
 * <p>It builds a {@link WebSocketClientHandshaker} from its configuration and, when added to a
 * pipeline, installs the handshake handler and the optional UTF-8 frame validator in front of
 * itself. Inbound close frames are handled here when the configuration asks for it.
 */
public class WebSocketClientProtocolHandler extends WebSocketProtocolHandler {

    private final WebSocketClientHandshaker handshaker;
    private final WebSocketClientProtocolConfig clientConfig;

    /**
     * @param clientConfig the client protocol configuration; must not be {@code null}
     * @throws NullPointerException if {@code clientConfig} is {@code null}
     */
    public WebSocketClientProtocolHandler(WebSocketClientProtocolConfig clientConfig) {
        super(Objects.requireNonNull(clientConfig, "clientConfig").dropPongFrames(),
                clientConfig.sendCloseFrame(), clientConfig.forceCloseTimeoutMillis());
        this.handshaker = WebSocketClientHandshakerFactory.newHandshaker(
                clientConfig.webSocketUri(),
                clientConfig.version(),
                clientConfig.subprotocol(),
                clientConfig.allowExtensions(),
                clientConfig.customHeaders(),
                clientConfig.maxFramePayloadLength(),
                clientConfig.performMasking(),
                clientConfig.allowMaskMismatch(),
                clientConfig.forceCloseTimeoutMillis(),
                clientConfig.absoluteUpgradeUrl(),
                clientConfig.generateOriginHeader()
        );
        this.clientConfig = clientConfig;
    }

    /**
     * @return the handshaker created from this handler's configuration
     */
    public WebSocketClientHandshaker handshaker() {
        return handshaker;
    }

    /**
     * Adds the client handshake handler and, when enabled, the UTF-8 frame validator to the
     * pipeline in front of this handler. Each is added only if the pipeline does not already
     * contain a handler of that class.
     */
    @Override
    public void handlerAdded(ChannelHandlerContext ctx) {
        final ChannelPipeline pipeline = ctx.pipeline();

        final boolean handshakeHandlerMissing =
                pipeline.get(WebSocketClientProtocolHandshakeHandler.class) == null;
        if (handshakeHandlerMissing) {
            // Add the WebSocketClientProtocolHandshakeHandler before this one.
            pipeline.addBefore(
                    ctx.name(),
                    WebSocketClientProtocolHandshakeHandler.class.getName(),
                    new WebSocketClientProtocolHandshakeHandler(handshaker, clientConfig.handshakeTimeoutMillis()));
        }

        if (clientConfig.withUTF8Validator() && pipeline.get(Utf8FrameValidator.class) == null) {
            // Add the UTF-8 checking before this one.
            pipeline.addBefore(ctx.name(), Utf8FrameValidator.class.getName(), new Utf8FrameValidator());
        }
    }

    /**
     * Disposes an inbound close frame and closes the channel when {@code handleCloseFrames} is
     * enabled; every other frame is delegated to the base class.
     *
     * @param ctx   the handler context
     * @param frame the inbound frame
     */
    @Override
    protected void decodeAndClose(ChannelHandlerContext ctx, WebSocketFrame frame) throws Exception {
        final boolean closeHandledHere = clientConfig.handleCloseFrames() && frame instanceof CloseWebSocketFrame;
        if (!closeHandledHere) {
            super.decodeAndClose(ctx, frame);
            return;
        }
        Resource.dispose(frame);
        ctx.close();
    }

    /**
     * Builds the client-specific handshake exception.
     *
     * @param message the exception message
     * @return a new {@link WebSocketClientHandshakeException}
     */
    @Override
    protected WebSocketClientHandshakeException buildHandshakeException(String message) {
        return new WebSocketClientHandshakeException(message);
    }
}