[netty5: WebSocketClientHandshaker & WebSocketClientHandshakerFactory]-源码分析

发布于:2025-07-11 ⋅ 阅读:(15) ⋅ 点赞:(0)

在阅读这篇文章前,推荐先阅读以下内容:

  1. [netty5: WebSocketFrame]-源码分析
  2. [netty5: WebSocketFrameEncoder & WebSocketFrameDecoder]-源码解析

WebSocketClientHandshakerFactory

WebSocketClientHandshakerFactory 是创建客户端 WebSocket 握手器(Handshaker)的工厂类:根据目标 URI、协议版本及各项配置创建对应的握手器,从而简化客户端握手流程。从下面的代码可以看到,netty5 中它实际创建的始终是实现 RFC 6455(版本 13)的 WebSocketClientHandshaker13。

public final class WebSocketClientHandshakerFactory {

    private WebSocketClientHandshakerFactory() {}

    // ...

    /**
     * Creates a client handshaker for the requested WebSocket protocol version.
     * Called from {@code new WebSocketClientProtocolHandler(config)}.
     *
     * @param webSocketURL            URL the handshake request targets
     * @param version                 protocol version to speak; only {@link WebSocketVersion#V13} is implemented
     * @param subprotocol             subprotocol(s) to request, or {@code null} for none
     * @param allowExtensions         forwarded to the version-13 frame decoder
     * @param customHeaders           extra headers for the handshake request; may be {@code null}
     * @param maxFramePayloadLength   maximum payload length of a single frame
     * @param performMasking          forwarded to the version-13 frame encoder
     * @param allowMaskMismatch       forwarded to the version-13 frame decoder
     * @param forceCloseTimeoutMillis timeout used by the handshaker's force-close logic
     * @param absoluteUpgradeUrl      whether the request target is the absolute URL
     * @param generateOriginHeader    whether to add an Origin header when the caller supplied none
     * @return a new {@link WebSocketClientHandshaker13}
     * @throws IllegalArgumentException if {@code version} is not {@link WebSocketVersion#V13}
     */
    public static WebSocketClientHandshaker newHandshaker(
            URI webSocketURL, WebSocketVersion version, String subprotocol,
            boolean allowExtensions, HttpHeaders customHeaders, int maxFramePayloadLength,
            boolean performMasking, boolean allowMaskMismatch, long forceCloseTimeoutMillis,
            boolean absoluteUpgradeUrl, boolean generateOriginHeader) {
        // Only RFC 6455 (version 13) is implemented. Reject anything else instead of
        // silently ignoring the version the caller asked for.
        if (version != WebSocketVersion.V13) {
            throw new IllegalArgumentException("Protocol version " + version + " not supported.");
        }
        return new WebSocketClientHandshaker13(
                webSocketURL, subprotocol, allowExtensions, customHeaders,
                maxFramePayloadLength, performMasking, allowMaskMismatch, forceCloseTimeoutMillis,
                absoluteUpgradeUrl, generateOriginHeader);
    }
}

WebSocketClientHandshaker13

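WebSocketClientHandshaker13 是 WebSocket 协议 RFC 6455(版本 13)的客户端握手器实现,负责构造握手请求、校验服务器的握手响应,并创建该版本对应的帧编码器与解码器。管道中的协议升级(用 WebSocket 编解码器替换 HTTP 编解码器)由父类 WebSocketClientHandshaker 的 handshake 与 finishHandshake 完成。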

public class WebSocketClientHandshaker13 extends WebSocketClientHandshaker {

    // Forwarded to the WebSocket13FrameDecoder built by newWebsocketDecoder().
    private final boolean allowExtensions;
    // Forwarded to the WebSocket13FrameEncoder built by newWebSocketEncoder().
    private final boolean performMasking;
    // Forwarded to the WebSocket13FrameDecoder built by newWebsocketDecoder().
    private final boolean allowMaskMismatch;

    // Sec-WebSocket-Key sent by the latest newHandshakeRequest(); verify() derives the
    // expected Sec-WebSocket-Accept value from it.
    private volatile String sentNonce;

    /**
     * Creates an RFC 6455 (version 13) client handshaker.
     * Called from {@code WebSocketClientHandshakerFactory.newHandshaker}.
     *
     * @param webSocketURL            URL the handshake request targets
     * @param subprotocol             subprotocol(s) to request, or {@code null} for none
     * @param allowExtensions         forwarded to the frame decoder
     * @param customHeaders           extra headers for the handshake request; may be {@code null}
     * @param maxFramePayloadLength   maximum payload length of a single frame
     * @param performMasking          forwarded to the frame encoder
     * @param allowMaskMismatch       forwarded to the frame decoder
     * @param forceCloseTimeoutMillis timeout used by the base class's force-close logic
     * @param absoluteUpgradeUrl      whether the request target is the absolute URL
     * @param generateOriginHeader    whether to add an Origin header when none was supplied
     */
    WebSocketClientHandshaker13(URI webSocketURL, String subprotocol,
            boolean allowExtensions, HttpHeaders customHeaders, int maxFramePayloadLength,
            boolean performMasking, boolean allowMaskMismatch,
            long forceCloseTimeoutMillis, boolean absoluteUpgradeUrl,
            boolean generateOriginHeader) {
        super(webSocketURL, WebSocketVersion.V13,
                subprotocol, customHeaders, maxFramePayloadLength,
                forceCloseTimeoutMillis, absoluteUpgradeUrl, generateOriginHeader);
        this.allowExtensions = allowExtensions;
        this.performMasking = performMasking;
        this.allowMaskMismatch = allowMaskMismatch;
    }

    /**
     * Builds the opening handshake request sent to the server, for example:
     *
     * <pre>
     * GET /chat HTTP/1.1
     * Host: server.example.com
     * Upgrade: websocket
     * Connection: Upgrade
     * Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==
     * Sec-WebSocket-Protocol: chat, superchat
     * Sec-WebSocket-Version: 13
     * </pre>
     *
     * <p>A fresh random key is generated for every call and remembered for {@link #verify}.
     *
     * @param allocator allocator used for the (empty) request body
     * @return the handshake request
     */
    @Override
    protected FullHttpRequest newHandshakeRequest(BufferAllocator allocator) {
        URI wsURL = uri();

        FullHttpRequest request = new DefaultFullHttpRequest(
                HttpVersion.HTTP_1_1, HttpMethod.GET, upgradeUrl(wsURL), allocator.allocate(0));
        HttpHeaders headers = request.headers();

        if (customHeaders != null) {
            headers.add(customHeaders);
        }
        // A Host supplied through customHeaders wins; otherwise derive it from the URI.
        if (!headers.contains(HttpHeaderNames.HOST)) {
            headers.set(HttpHeaderNames.HOST, websocketHostValue(wsURL));
        }

        String nonce = createNonce();
        headers.set(HttpHeaderNames.UPGRADE, HttpHeaderValues.WEBSOCKET)
               .set(HttpHeaderNames.CONNECTION, HttpHeaderValues.UPGRADE)
               .set(HttpHeaderNames.SEC_WEBSOCKET_KEY, nonce);

        if (generateOriginHeader && !headers.contains(HttpHeaderNames.ORIGIN)) {
            // Origin must be a serialized origin (scheme "://" host [":" port], RFC 6454 §6.2),
            // not a bare Host value.
            headers.set(HttpHeaderNames.ORIGIN, originValue(wsURL));
        }

        sentNonce = nonce;
        String expectedSubprotocol = expectedSubprotocol();
        if (!StringUtil.isNullOrEmpty(expectedSubprotocol)) {
            headers.set(HttpHeaderNames.SEC_WEBSOCKET_PROTOCOL, expectedSubprotocol);
        }

        headers.set(HttpHeaderNames.SEC_WEBSOCKET_VERSION, version().toAsciiString());
        return request;
    }

    /**
     * Checks the server's handshake response, for example:
     *
     * <pre>
     * HTTP/1.1 101 Switching Protocols
     * Upgrade: websocket
     * Connection: Upgrade
     * Sec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=
     * Sec-WebSocket-Protocol: chat
     * </pre>
     *
     * @param response HTTP response returned by the server for the request built by
     *                 {@link #newHandshakeRequest}
     * @throws WebSocketClientHandshakeException if the status is not 101, the Upgrade or
     *         Connection header is wrong, or Sec-WebSocket-Accept is missing or does not match
     *         the key that was sent
     */
    @Override
    protected void verify(FullHttpResponse response) {
        HttpResponseStatus status = response.status();
        if (!HttpResponseStatus.SWITCHING_PROTOCOLS.equals(status)) {
            throw new WebSocketClientHandshakeException("Invalid handshake response status: " + status, response);
        }

        HttpHeaders headers = response.headers();
        CharSequence upgrade = headers.get(HttpHeaderNames.UPGRADE);
        if (!HttpHeaderValues.WEBSOCKET.contentEqualsIgnoreCase(upgrade)) {
            throw new WebSocketClientHandshakeException("Invalid handshake response upgrade: " + upgrade, response);
        }

        if (!headers.containsIgnoreCase(HttpHeaderNames.CONNECTION, HttpHeaderValues.UPGRADE)) {
            throw new WebSocketClientHandshakeException("Invalid handshake response connection: " + headers.get(HttpHeaderNames.CONNECTION), response);
        }

        CharSequence accept = headers.get(HttpHeaderNames.SEC_WEBSOCKET_ACCEPT);
        if (accept == null) {
            throw new WebSocketClientHandshakeException("Invalid handshake response sec-websocket-accept: null", response);
        }

        // The server must echo a value derived from exactly the key this client sent.
        String expectedAccept = WebSocketUtil.calculateV13Accept(sentNonce);
        if (!AsciiString.contentEquals(expectedAccept, AsciiString.trim(accept))) {
            throw new WebSocketClientHandshakeException("Invalid handshake response sec-websocket-accept: " + accept + ", expected: " + expectedAccept, response);
        }
    }

    @Override
    protected WebSocketFrameDecoder newWebsocketDecoder() {
        return new WebSocket13FrameDecoder(false, allowExtensions, maxFramePayloadLength(), allowMaskMismatch);
    }

    @Override
    protected WebSocketFrameEncoder newWebSocketEncoder() {
        return new WebSocket13FrameEncoder(performMasking);
    }

    @Override
    public WebSocketClientHandshaker13 setForceCloseTimeoutMillis(long forceCloseTimeoutMillis) {
        super.setForceCloseTimeoutMillis(forceCloseTimeoutMillis);
        return this;
    }

    // Generates the Sec-WebSocket-Key: 16 random bytes, Base64-encoded.
    private static String createNonce() {
        var nonce = WebSocketUtil.randomBytes(16);
        return WebSocketUtil.base64(nonce);
    }

    /**
     * Serializes the origin of {@code wsURL} for the Origin header.
     * {@code wss}/{@code https} (or no scheme with port 443) map to {@code https://}, and
     * everything else maps to {@code http://}. The host is lower-cased. The port is appended
     * only when it is explicit and differs from the scheme's default (443 or 80).
     *
     * @param wsURL handshake URI; its host must be non-null
     * @return the serialized origin, e.g. {@code https://example.com:8443}
     */
    private static String originValue(URI wsURL) {
        String scheme = wsURL.getScheme();
        int port = wsURL.getPort();
        boolean secure = "wss".equalsIgnoreCase(scheme)
                || "https".equalsIgnoreCase(scheme)
                || (scheme == null && port == 443);
        int defaultPort = secure ? 443 : 80;
        String origin = (secure ? "https://" : "http://")
                + wsURL.getHost().toLowerCase(java.util.Locale.ROOT);
        return port == -1 || port == defaultPort ? origin : origin + ':' + port;
    }
}

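WebSocketClientHandshaker

WebSocketClientHandshaker 是客户端握手器的抽象基类,核心是两个方法:
- handshake:检查管道配置后发送握手请求,并在请求写出后加入 WebSocket 帧编码器。
- finishHandshake:校验响应与子协议,然后把管道中的 HTTP 解码器替换为 WebSocket 帧解码器。

构造请求、校验响应、创建编解码器的具体逻辑由子类实现。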

public abstract class WebSocketClientHandshaker {
    protected static final int DEFAULT_FORCE_CLOSE_TIMEOUT_MILLIS = 10000;

    // Target of the handshake, e.g. ws://example.com/chat.
    private final URI uri;

    // Protocol version, which governs the handshake request and frame format (e.g. RFC 6455).
    private final WebSocketVersion version;

    // Whether the handshake has completed; volatile so other threads see the update.
    private volatile boolean handshakeComplete;

    // Timeout (ms) for the force-close logic (elided here): when closing the connection after
    // the handshake does not finish in time, the channel is closed forcibly.
    private volatile long forceCloseTimeoutMillis;

    // Whether the force-close procedure has been initiated; updated atomically through
    // FORCE_CLOSE_INIT_UPDATER.
    private volatile int forceCloseInit;
    private static final AtomicIntegerFieldUpdater<WebSocketClientHandshaker> FORCE_CLOSE_INIT_UPDATER =
            AtomicIntegerFieldUpdater.newUpdater(WebSocketClientHandshaker.class, "forceCloseInit");

    // Whether the force-close procedure has completed.
    private volatile boolean forceCloseComplete;

    // Subprotocol(s) the client asks for; finishHandshake() treats it as a comma-separated list.
    private final String expectedSubprotocol;

    // Subprotocol confirmed by the server; only set once finishHandshake() accepts the response.
    private volatile String actualSubprotocol;

    // Extra headers for the handshake request (may be null); handshake() also consults them for
    // user-supplied Host/Origin values.
    protected final HttpHeaders customHeaders;

    // Maximum payload length of a single WebSocket frame, guarding against oversized frames.
    private final int maxFramePayloadLength;

    // Whether the handshake request uses the absolute URI as its upgrade URL (e.g. for some
    // proxy setups).
    private final boolean absoluteUpgradeUrl;

    // Whether an Origin request header is generated automatically.
    protected final boolean generateOriginHeader;

    /**
     * Base constructor for version-specific handshakers.
     *
     * @param uri                     handshake target
     * @param version                 protocol version implemented by the subclass
     * @param subprotocol             comma-separated subprotocol(s) to request, or {@code null}
     * @param customHeaders           extra request headers, or {@code null}
     * @param maxFramePayloadLength   maximum payload length of a single frame
     * @param forceCloseTimeoutMillis timeout for the force-close logic
     * @param absoluteUpgradeUrl      whether to use the absolute URI as the request target
     * @param generateOriginHeader    whether to generate an Origin header
     */
    protected WebSocketClientHandshaker(URI uri, WebSocketVersion version, String subprotocol,
            HttpHeaders customHeaders, int maxFramePayloadLength,
            long forceCloseTimeoutMillis, boolean absoluteUpgradeUrl, boolean generateOriginHeader) {
        this.uri = uri;
        this.version = version;
        expectedSubprotocol = subprotocol;
        this.customHeaders = customHeaders;
        this.maxFramePayloadLength = maxFramePayloadLength;
        this.forceCloseTimeoutMillis = forceCloseTimeoutMillis;
        this.absoluteUpgradeUrl = absoluteUpgradeUrl;
        this.generateOriginHeader = generateOriginHeader;
    }

    /**
     * Sends the opening handshake request on {@code channel}.
     * Called from {@code WebSocketClientProtocolHandshakeHandler.channelActive}.
     *
     * <p>The returned future fails without writing anything in three cases:
     * the pipeline has neither an {@code HttpResponseDecoder} nor an {@code HttpClientCodec};
     * the URI has no host and no Host header is supplied via {@code customHeaders};
     * or an Origin header must be generated and none is supplied.
     * After the request is written, a {@code "ws-encoder"} from {@link #newWebSocketEncoder()}
     * is added right after the {@code HttpRequestEncoder} / {@code HttpClientCodec}.
     *
     * @param channel channel to perform the handshake on; must not be {@code null}
     * @return a future completed once the request is written and the encoder is installed
     */
    public Future<Void> handshake(Channel channel) {
        requireNonNull(channel, "channel");
        ChannelPipeline pipeline = channel.pipeline();
        // The response must be decodable: require an HttpResponseDecoder or an HttpClientCodec.
        HttpResponseDecoder decoder = pipeline.get(HttpResponseDecoder.class);
        if (decoder == null) {
            HttpClientCodec codec = pipeline.get(HttpClientCodec.class);
            if (codec == null) {
               return channel.newFailedFuture(new IllegalStateException("ChannelPipeline does not contain " + "an HttpResponseDecoder or HttpClientCodec"));
            }
        }

        // Without a host in the URI, Host (and Origin, if generated) must come from customHeaders.
        if (uri.getHost() == null) {
            if (customHeaders == null || !customHeaders.contains(HttpHeaderNames.HOST)) {
                return channel.newFailedFuture(new IllegalArgumentException("Cannot generate the 'host' header value," + " webSocketURI should contain host or passed through customHeaders"));
            }

            if (generateOriginHeader && !customHeaders.contains(HttpHeaderNames.ORIGIN)) {
                final String originName = HttpHeaderNames.ORIGIN.toString();
                return channel.newFailedFuture(
                        new IllegalArgumentException("Cannot generate the '" + originName + "' header" + " value, webSocketURI should contain host or disable generateOriginHeader or pass value" + " through customHeaders"));
            }
        }

        // Build the version-specific handshake request.
        FullHttpRequest request = newHandshakeRequest(channel.bufferAllocator());

        // Write the request asynchronously and complete the promise from the write listener.
        Promise<Void> promise = channel.newPromise();
        channel.writeAndFlush(request).addListener(channel, (ch, future) -> {
            if (future.isSuccess()) {
                ChannelPipeline p = ch.pipeline();
                // Locate the HTTP request encoder (standalone or as part of HttpClientCodec).
                ChannelHandlerContext ctx = p.context(HttpRequestEncoder.class);
                if (ctx == null) {
                    ctx = p.context(HttpClientCodec.class);
                }
                if (ctx == null) {
                    promise.setFailure(new IllegalStateException("ChannelPipeline does not contain " + "an HttpRequestEncoder or HttpClientCodec"));
                    return;
                }

                // Insert the WebSocket frame encoder right after it.
                p.addAfter(ctx.name(), "ws-encoder", newWebSocketEncoder());

                promise.setSuccess(null);
            } else {
                promise.setFailure(future.cause());
            }
        });
        return promise.asFuture();
    }

    /**
     * Validates the server's handshake response and switches the pipeline to WebSocket decoding.
     * Called from {@code WebSocketClientProtocolHandshakeHandler.channelRead}.
     *
     * <p>Any exception thrown by {@link #verify} propagates unchanged.
     *
     * @param channel  channel whose pipeline is upgraded
     * @param response the server's handshake response
     * @throws WebSocketClientHandshakeException if the subprotocol returned by the server does not
     *         match what the client requested
     * @throws IllegalStateException if the pipeline has neither an {@code HttpResponseDecoder}
     *         nor an {@code HttpClientCodec}
     */
    public final void finishHandshake(Channel channel, FullHttpResponse response) {
        verify(response);

        // Subprotocol returned by the server (trimmed), or null when the header is absent.
        CharSequence receivedProtocol = response.headers().get(HttpHeaderNames.SEC_WEBSOCKET_PROTOCOL);
        receivedProtocol = receivedProtocol != null ? AsciiString.trim(receivedProtocol) : null;
        // Subprotocol(s) the client asked for; "" when none.
        String expectedProtocol = expectedSubprotocol != null ? expectedSubprotocol : "";
        boolean protocolValid = false;

        if (expectedProtocol.isEmpty() && receivedProtocol == null) {
            // Neither side mentioned a subprotocol: accept.
            protocolValid = true;
            setActualSubprotocol(expectedSubprotocol);
        } else if (!expectedProtocol.isEmpty() && receivedProtocol != null && receivedProtocol.length() > 0) {
            // The server picked a subprotocol: it must be one of the client's comma-separated choices.
            for (String protocol : expectedProtocol.split(",")) {
                if (AsciiString.contentEquals(protocol.trim(), receivedProtocol)) {
                    protocolValid = true;
                    setActualSubprotocol(receivedProtocol.toString());
                    break;
                }
            }
        }

        // Every other combination fails, including a server that omitted the requested subprotocol.
        if (!protocolValid) {
            throw new WebSocketClientHandshakeException(String.format(
                    "Invalid subprotocol. Actual: %s. Expected one of: %s",
                    receivedProtocol, expectedSubprotocol), response);
        }

        setHandshakeComplete();

        final ChannelPipeline p = channel.pipeline();
        // Drop HTTP-only handlers that WebSocket framing does not use.
        HttpContentDecompressor decompressor = p.get(HttpContentDecompressor.class);
        if (decompressor != null) {
            p.remove(decompressor);
        }

        HttpObjectAggregator aggregator = p.get(HttpObjectAggregator.class);
        if (aggregator != null) {
            p.remove(aggregator);
        }

        // Replace the HTTP response decoder with the WebSocket frame decoder ("ws-decoder"):
        //   1. HttpClientCodec: call removeOutboundHandler(), add ws-decoder after the codec,
        //      then remove the codec in a task submitted to the channel's executor.
        //   2. Standalone HttpResponseDecoder: remove any HttpRequestEncoder, add ws-decoder after
        //      the decoder, then remove the decoder in a task submitted to the channel's executor.
        ChannelHandlerContext ctx = p.context(HttpResponseDecoder.class);
        if (ctx == null) {
            ctx = p.context(HttpClientCodec.class);
            if (ctx == null) {
                // This branch looks up the response decoder, so name it in the message.
                throw new IllegalStateException("ChannelPipeline does not contain " +
                        "an HttpResponseDecoder or HttpClientCodec");
            }
            final HttpClientCodec codec = (HttpClientCodec) ctx.handler();
            codec.removeOutboundHandler();

            p.addAfter(ctx.name(), "ws-decoder", newWebsocketDecoder());
            channel.executor().execute(() -> p.remove(codec));
        } else {
            if (p.get(HttpRequestEncoder.class) != null) {
                p.remove(HttpRequestEncoder.class);
            }
            final ChannelHandlerContext context = ctx;
            p.addAfter(context.name(), "ws-decoder", newWebsocketDecoder());
            channel.executor().execute(() -> p.remove(context.handler()));
        }
    }

    // ...

    // Builds the version-specific opening handshake request.
    protected abstract FullHttpRequest newHandshakeRequest(BufferAllocator allocator);
    // Validates the version-specific parts of the server's handshake response.
    protected abstract void verify(FullHttpResponse response);
    // Creates the frame decoder installed as "ws-decoder" by finishHandshake().
    protected abstract WebSocketFrameDecoder newWebsocketDecoder();
    // Creates the frame encoder installed as "ws-encoder" by handshake().
    protected abstract WebSocketFrameEncoder newWebSocketEncoder();
}

网站公告

今日签到

点亮在社区的每一天
去签到