[spring6: Mvc-函数式编程]-源码解析

发布于:2025-07-27 ⋅ 阅读:(16) ⋅ 点赞:(0)

接口

ServerRequest

/**
 * Server-side HTTP request as seen by a {@link HandlerFunction}, backed by an
 * {@link HttpServletRequest} (see {@link #servletRequest()}).
 *
 * <p>Excerpt: members not relevant to this walkthrough are elided and marked
 * with {@code // ...}.
 */
public interface ServerRequest {

	/** Return the HTTP method of this request. */
	HttpMethod method();

	/** Return the URI of this request. */
	URI uri();

	/** Return a {@link UriBuilder} for this request. */
	UriBuilder uriBuilder();

	/**
	 * Return the path of this request within the application.
	 * <p>Derived from {@link #requestPath()}.
	 */
	default String path() {
		return requestPath().pathWithinApplication().value();
	}

	/**
	 * Return the parsed {@link RequestPath}, looked up on the underlying servlet
	 * request via {@code ServletRequestPathUtils.getParsedRequestPath}.
	 * <p>NOTE(review): this presumably requires the path to have been parsed and
	 * stored on the servlet request beforehand — confirm against ServletRequestPathUtils.
	 */
	default RequestPath requestPath() {
		return ServletRequestPathUtils.getParsedRequestPath(servletRequest());
	}

	/** Return a typed view of the request headers. */
	Headers headers();

	/** Return the request cookies, keyed by cookie name. */
	MultiValueMap<String, Cookie> cookies();

	/** Return the remote address of the client, if available. */
	Optional<InetSocketAddress> remoteAddress();

	/** Return the HTTP message converters associated with this request. */
	List<HttpMessageConverter<?>> messageConverters();

	/**
	 * Read the request body into an instance of the given class.
	 * @param bodyType the target class
	 * @return the converted body
	 * @throws ServletException on servlet-level failure
	 * @throws IOException on I/O failure while reading the body
	 */
	<T> T body(Class<T> bodyType) throws ServletException, IOException;

	/**
	 * Read the request body into an instance of the given (generic) type.
	 * @param bodyType the target type reference
	 * @return the converted body
	 * @throws ServletException on servlet-level failure
	 * @throws IOException on I/O failure while reading the body
	 */
	<T> T body(ParameterizedTypeReference<T> bodyType) throws ServletException, IOException;

	/**
	 * Bind request data onto a new instance of {@code bindType}, without any
	 * data-binder customization.
	 * @throws BindException if binding fails
	 */
	default <T> T bind(Class<T> bindType) throws BindException {
		return bind(bindType, dataBinder -> {});
	}

	/**
	 * Bind request data onto a new instance of {@code bindType}, letting the
	 * caller customize the {@link WebDataBinder} first.
	 * @throws BindException if binding fails
	 */
	<T> T bind(Class<T> bindType, Consumer<WebDataBinder> dataBinderCustomizer) throws BindException;

	/**
	 * Return the request attribute with the given name.
	 * <p>Absent attributes and attributes mapped to {@code null} both yield
	 * {@link Optional#empty()}. {@code Optional.ofNullable} is used deliberately:
	 * {@code Optional.of} would throw {@link NullPointerException} for a key
	 * that maps to {@code null}.
	 * @param name the attribute name
	 * @return the attribute value, if present and non-null
	 */
	default Optional<Object> attribute(String name) {
		return Optional.ofNullable(attributes().get(name));
	}

	/** Return a mutable map of the request attributes. */
	Map<String, Object> attributes();

	/**
	 * Return the first value of the parameter with the given name.
	 * <p>A parameter that is present with a {@code null} first value is
	 * reported as the empty string rather than as absent.
	 * @param name the parameter name
	 * @return the first parameter value, or empty if the parameter is missing
	 */
	default Optional<String> param(String name) {
		List<String> paramValues = params().get(name);
		if (CollectionUtils.isEmpty(paramValues)) {
			return Optional.empty();
		}
		else {
			String value = paramValues.get(0);
			if (value == null) {
				value = "";
			}
			return Optional.of(value);
		}
	}

	/** Return all request parameters. */
	MultiValueMap<String, String> params();

	/**
	 * Return the parts of a multipart request, keyed by part name.
	 * @throws IOException on I/O failure
	 * @throws ServletException on servlet-level failure
	 */
	MultiValueMap<String, Part> multipartData() throws IOException, ServletException;

	/**
	 * Return the path variable with the given name.
	 * @param name the variable name
	 * @return the variable value
	 * @throws IllegalArgumentException if no path variable with that name exists
	 */
	default String pathVariable(String name) {
		Map<String, String> pathVariables = pathVariables();
		if (pathVariables.containsKey(name)) {
			return pathVariables.get(name);
		}
		else {
			throw new IllegalArgumentException("No path variable with name \"" + name + "\" available");
		}
	}

	/** Return all path variables extracted while routing this request. */
	Map<String, String> pathVariables();

	/** Return the HTTP session of this request. */
	HttpSession session();

	/** Return the authenticated principal, if any. */
	Optional<Principal> principal();

	/** Return the underlying servlet request. */
	HttpServletRequest servletRequest();

	/**
	 * Check the request's conditional headers against the given last-modified
	 * timestamp; delegates to {@code DefaultServerRequest.checkNotModified}.
	 * @return a response to return immediately if the resource is unmodified
	 * (presumably 304), or empty to continue handling
	 */
	default Optional<ServerResponse> checkNotModified(Instant lastModified) {
		Assert.notNull(lastModified, "LastModified must not be null");
		return DefaultServerRequest.checkNotModified(servletRequest(), lastModified, null);
	}

	/**
	 * Same as {@link #checkNotModified(Instant)} but based on the given ETag.
	 */
	default Optional<ServerResponse> checkNotModified(String etag) {
		Assert.notNull(etag, "Etag must not be null");
		return DefaultServerRequest.checkNotModified(servletRequest(), null, etag);
	}

	/**
	 * Same as {@link #checkNotModified(Instant)} but based on both a
	 * last-modified timestamp and an ETag.
	 */
	default Optional<ServerResponse> checkNotModified(Instant lastModified, String etag) {
		Assert.notNull(lastModified, "LastModified must not be null");
		Assert.notNull(etag, "Etag must not be null");
		return DefaultServerRequest.checkNotModified(servletRequest(), lastModified, etag);
	}

	/**
	 * Create a {@code ServerRequest} (a {@code DefaultServerRequest}) backed by
	 * the given servlet request.
	 * @param servletRequest the underlying request
	 * @param messageReaders the message converters for body conversion
	 */
	static ServerRequest create(HttpServletRequest servletRequest, List<HttpMessageConverter<?>> messageReaders) {
		return new DefaultServerRequest(servletRequest, messageReaders);
	}

	/**
	 * Create a {@link Builder} starting from the given request.
	 */
	static Builder from(ServerRequest other) {
		return new DefaultServerRequestBuilder(other);
	}

	// ...
}

Headers

/**
 * Typed, read-only view of request headers, as returned by
 * {@code ServerRequest.headers()}.
 */
interface Headers {

	/** Return the acceptable media types ({@code Accept}). */
	List<MediaType> accept();

	/** Return the acceptable charsets ({@code Accept-Charset}). */
	List<Charset> acceptCharset();

	/** Return the acceptable languages ({@code Accept-Language}). */
	List<Locale.LanguageRange> acceptLanguage();

	/** Return the body length ({@code Content-Length}), if known. */
	OptionalLong contentLength();

	/** Return the body media type ({@code Content-Type}), if known. */
	Optional<MediaType> contentType();

	/** Return the value of the {@code Host} header, or {@code null}. */
	@Nullable
	InetSocketAddress host();

	/** Return the requested byte ranges ({@code Range}). */
	List<HttpRange> range();

	/** Return all values of the header with the given name. */
	List<String> header(String headerName);

	/**
	 * Return the first value of the named header.
	 * @param headerName the header to look up
	 * @return the first value, or {@code null} when the header has no values
	 */
	@Nullable
	default String firstHeader(String headerName) {
		List<String> values = header(headerName);
		if (values.isEmpty()) {
			return null;
		}
		return values.get(0);
	}

	/** Return these headers as an {@link HttpHeaders} instance. */
	HttpHeaders asHttpHeaders();
}

Builder

/**
 * Builder for a {@link ServerRequest}, obtained from {@code ServerRequest.from(other)}.
 * <p>NOTE(review): presumably pre-populated with the state of {@code other};
 * the implementation ({@code DefaultServerRequestBuilder}) is not shown here.
 */
interface Builder {

	/** Set the HTTP method of the request to build. */
	Builder method(HttpMethod method);

	/** Set the URI of the request to build. */
	Builder uri(URI uri);

	/** Add the given value(s) under the given header name (add vs. replace: see implementation). */
	Builder header(String headerName, String... headerValues);
	
	/** Manipulate this builder's headers through the given consumer. */
	Builder headers(Consumer<HttpHeaders> headersConsumer);

	/** Add a cookie with the given name and value(s). */
	Builder cookie(String name, String... values);

	/** Manipulate this builder's cookies through the given consumer. */
	Builder cookies(Consumer<MultiValueMap<String, Cookie>> cookiesConsumer);

	/** Set the request body as raw bytes. */
	Builder body(byte[] body);

	/** Set the request body as a string. */
	Builder body(String body);

	/** Set the attribute with the given name to the given value. */
	Builder attribute(String name, Object value);
	
	/** Manipulate this builder's attributes through the given consumer. */
	Builder attributes(Consumer<Map<String, Object>> attributesConsumer);

	/** Add a parameter with the given name and value(s). */
	Builder param(String name, String... values);

	/** Manipulate this builder's parameters through the given consumer. */
	Builder params(Consumer<MultiValueMap<String, String>> paramsConsumer);
	
	/** Set the remote address of the request to build. */
	Builder remoteAddress(InetSocketAddress remoteAddress);

	/** Build the {@link ServerRequest}. */
	ServerRequest build();
	
}

RequestPredicate

/**
 * Predicate evaluated against a {@link ServerRequest} to decide whether a route
 * applies. Factories for common predicates, and the composite implementations
 * returned by {@link #and}, {@link #or} and {@link #negate}, live in
 * {@code RequestPredicates}.
 */
@FunctionalInterface
public interface RequestPredicate {

	/**
	 * Evaluate this predicate against the given request.
	 * @param request the request to test
	 * @return {@code true} if the request matches
	 */
	boolean test(ServerRequest request);

	/** Return a predicate that matches only when both this and {@code other} match. */
	default RequestPredicate and(RequestPredicate other) {
		return new RequestPredicates.AndRequestPredicate(this, other);
	}

	/** Return a predicate that matches when either this or {@code other} matches. */
	default RequestPredicate or(RequestPredicate other) {
		return new RequestPredicates.OrRequestPredicate(this, other);
	}

	/** Return a predicate that matches exactly when this one does not. */
	default RequestPredicate negate() {
		return new RequestPredicates.NegateRequestPredicate(this);
	}

	/**
	 * Hook for nested routing. The default returns the request unchanged when it
	 * matches, and empty otherwise.
	 * @param request the request to test
	 * @return the request to route further, or empty if there is no match
	 */
	default Optional<ServerRequest> nest(ServerRequest request) {
		if (!test(request)) {
			return Optional.empty();
		}
		return Optional.of(request);
	}

	/**
	 * Accept the given visitor. The default reports this predicate via
	 * {@code visitor.unknown(this)}.
	 */
	default void accept(RequestPredicates.Visitor visitor) {
		visitor.unknown(this);
	}
}

RouterFunction

/**
 * Routes a {@link ServerRequest} to a {@link HandlerFunction}. Factories and the
 * composite/filtered/attributed implementations used below live in
 * {@code RouterFunctions}.
 *
 * @param <T> the response type of the handler functions this router returns
 */
@FunctionalInterface
public interface RouterFunction<T extends ServerResponse> {

	/**
	 * Return the handler function that should handle the given request.
	 * @param request the request to route
	 * @return the matching handler, or empty if this router does not match
	 */
	Optional<HandlerFunction<T>> route(ServerRequest request);

	/** Compose this router with {@code other}, which has the same response type. */
	default RouterFunction<T> and(RouterFunction<T> other) {
		return new RouterFunctions.SameComposedRouterFunction<>(this, other);
	}

	/** Compose this router with {@code other}, whose response type may differ. */
	default RouterFunction<?> andOther(RouterFunction<?> other) {
		return new RouterFunctions.DifferentComposedRouterFunction(this, other);
	}

	/** Compose this router with a new route built from the predicate and handler. */
	default RouterFunction<T> andRoute(RequestPredicate predicate, HandlerFunction<T> handlerFunction) {
		RouterFunction<T> extraRoute = RouterFunctions.route(predicate, handlerFunction);
		return and(extraRoute);
	}

	/** Compose this router with {@code routerFunction} nested under {@code predicate}. */
	default RouterFunction<T> andNest(RequestPredicate predicate, RouterFunction<T> routerFunction) {
		RouterFunction<T> nestedRoute = RouterFunctions.nest(predicate, routerFunction);
		return and(nestedRoute);
	}

	/**
	 * Wrap every handler returned by this router with the given filter.
	 * @param <S> the response type produced by the filter
	 */
	default <S extends ServerResponse> RouterFunction<S> filter(HandlerFilterFunction<T, S> filterFunction) {
		return new RouterFunctions.FilteredRouterFunction<>(this, filterFunction);
	}

	/** Accept the given visitor; the default reports this router via {@code visitor.unknown(this)}. */
	default void accept(RouterFunctions.Visitor visitor) {
		visitor.unknown(this);
	}

	/**
	 * Return a router that carries the single given attribute.
	 * @throws IllegalArgumentException if {@code name} is empty or {@code value} is null
	 */
	default RouterFunction<T> withAttribute(String name, Object value) {
		Assert.hasLength(name, "Name must not be empty");
		Assert.notNull(value, "Value must not be null");

		Map<String, Object> singleAttribute = new LinkedHashMap<>();
		singleAttribute.put(name, value);
		return new RouterFunctions.AttributesRouterFunction<>(this, singleAttribute);
	}

	/**
	 * Return a router that carries the attributes populated by the given consumer.
	 * Insertion order is preserved (a {@link LinkedHashMap} is used).
	 * @throws IllegalArgumentException if {@code attributesConsumer} is null
	 */
	default RouterFunction<T> withAttributes(Consumer<Map<String, Object>> attributesConsumer) {
		Assert.notNull(attributesConsumer, "AttributesConsumer must not be null");

		Map<String, Object> collected = new LinkedHashMap<>();
		attributesConsumer.accept(collected);
		return new RouterFunctions.AttributesRouterFunction<>(this, collected);
	}

}

HandlerFilterFunction

/**
 * Filter applied around a {@link HandlerFunction}: it may inspect or replace the
 * request, invoke (or skip) the next handler, and post-process the response.
 *
 * @param <T> the response type of the wrapped handler
 * @param <R> the response type produced by this filter
 */
@FunctionalInterface
public interface HandlerFilterFunction<T extends ServerResponse, R extends ServerResponse> {

	/**
	 * Apply this filter to the given request.
	 * @param request the current request
	 * @param next the next handler in the chain
	 * @return the (possibly post-processed) response
	 * @throws Exception whatever the filter or downstream handler throws
	 */
	R filter(ServerRequest request, HandlerFunction<T> next) throws Exception;

	/**
	 * Return a composed filter in which this filter runs first and {@code after}
	 * wraps the original {@code next} handler.
	 * @throws IllegalArgumentException if {@code after} is null
	 */
	default HandlerFilterFunction<T, R> andThen(HandlerFilterFunction<T, T> after) {
		Assert.notNull(after, "HandlerFilterFunction must not be null");
		return (request, next) -> filter(request, innerRequest -> after.filter(innerRequest, next));
	}

	/**
	 * Turn the given handler into a filtered handler.
	 * @throws IllegalArgumentException if {@code handler} is null
	 */
	default HandlerFunction<R> apply(HandlerFunction<T> handler) {
		Assert.notNull(handler, "HandlerFunction must not be null");
		return request -> filter(request, handler);
	}

	/**
	 * Create a filter that transforms the request before passing it on.
	 * @throws IllegalArgumentException if {@code requestProcessor} is null
	 */
	static <T extends ServerResponse> HandlerFilterFunction<T, T>
	ofRequestProcessor(Function<ServerRequest, ServerRequest> requestProcessor) {

		Assert.notNull(requestProcessor, "Function must not be null");
		return (request, next) -> {
			ServerRequest processedRequest = requestProcessor.apply(request);
			return next.handle(processedRequest);
		};
	}

	/**
	 * Create a filter that transforms the handler's response.
	 * @throws IllegalArgumentException if {@code responseProcessor} is null
	 */
	static <T extends ServerResponse, R extends ServerResponse> HandlerFilterFunction<T, R>
	ofResponseProcessor(BiFunction<ServerRequest, T, R> responseProcessor) {

		Assert.notNull(responseProcessor, "Function must not be null");
		return (request, next) -> {
			T handled = next.handle(request);
			return responseProcessor.apply(request, handled);
		};
	}

	/**
	 * Create a filter that maps errors accepted by {@code predicate} to a response.
	 * <p>Errors thrown synchronously by the handler are handled here. If the
	 * handler instead returns an {@code ErrorHandlingServerResponse}, the handler
	 * is registered on it, presumably for errors raised later (e.g. async) —
	 * confirm against that type. Errors rejected by {@code predicate} are rethrown.
	 * @throws IllegalArgumentException if either argument is null
	 */
	static <T extends ServerResponse> HandlerFilterFunction<T, T>
	ofErrorHandler(Predicate<Throwable> predicate, BiFunction<Throwable, ServerRequest, T> errorHandler) {

		Assert.notNull(predicate, "Predicate must not be null");
		Assert.notNull(errorHandler, "ErrorHandler must not be null");

		return (request, next) -> {
			try {
				T result = next.handle(request);
				if (result instanceof ErrorHandlingServerResponse response) {
					response.addErrorHandler(predicate, errorHandler);
				}
				return result;
			}
			catch (Throwable ex) {
				// 'ex' is effectively final, so precise rethrow keeps the checked
				// type to what the try block can throw (Exception).
				if (!predicate.test(ex)) {
					throw ex;
				}
				return errorHandler.apply(ex, request);
			}
		};
	}

}

HandlerFunction

/**
 * Function that handles a {@link ServerRequest} and produces a response.
 * {@code HandlerFunctionAdapter} supports exactly the handlers that are
 * instances of this interface.
 *
 * @param <T> the type of response produced
 */
@FunctionalInterface
public interface HandlerFunction<T extends ServerResponse> {
	/**
	 * Handle the given request.
	 * @param request the request to handle
	 * @return the response
	 * @throws Exception on any handling failure
	 */
	T handle(ServerRequest request) throws Exception;
}

ServerResponse

/**
 * Server-side HTTP response, as returned by a {@link HandlerFunction} or a
 * {@code HandlerFilterFunction}. Related types include
 * {@code AsyncServerResponse}, {@code EntityResponse<T>} and
 * {@code RenderingResponse}.
 *
 * <p>Excerpt: members not relevant to this walkthrough are elided ({@code // ...}).
 */
public interface ServerResponse {

	/** Return the status code of this response. */
	HttpStatusCode statusCode();

	/** Return the headers of this response. */
	HttpHeaders headers();

	/** Return the cookies of this response, keyed by name. */
	MultiValueMap<String, Cookie> cookies();

	/**
	 * Write this response to the given servlet response.
	 * @param request the current servlet request
	 * @param response the servlet response to write to
	 * @param context supplies the message converters for body conversion
	 * @return a {@code ModelAndView} to render, or {@code null} if the response was fully written
	 */
	@Nullable
	ModelAndView writeTo(HttpServletRequest request, HttpServletResponse response, Context context) throws ServletException, IOException;


	/** Create a builder initialized from the given response. */
	static BodyBuilder from(ServerResponse other) {
		return new DefaultServerResponseBuilder(other);
	}

	/** Create a response that carries the status, headers and body of the given {@link ErrorResponse}. */
	static ServerResponse from(ErrorResponse response) {
		BodyBuilder builder = status(response.getStatusCode())
				.headers(headers -> headers.putAll(response.getHeaders()));
		return builder.body(response.getBody());
	}

	/** Create a builder with the given status. */
	static BodyBuilder status(HttpStatusCode status) {
		return new DefaultServerResponseBuilder(status);
	}

	/** Create a builder with the given raw status code. */
	static BodyBuilder status(int status) {
		HttpStatusCode statusCode = HttpStatusCode.valueOf(status);
		return new DefaultServerResponseBuilder(statusCode);
	}

	/** Create a builder with status 200 OK. */
	static BodyBuilder ok() {
		return status(HttpStatus.OK);
	}

	/** Create a builder with status 201 Created and the given {@code Location}. */
	static BodyBuilder created(URI location) {
		return status(HttpStatus.CREATED).location(location);
	}

	/** Create a builder with status 202 Accepted. */
	static BodyBuilder accepted() {
		return status(HttpStatus.ACCEPTED);
	}

	/** Create a builder with status 204 No Content. */
	static HeadersBuilder<?> noContent() {
		return status(HttpStatus.NO_CONTENT);
	}

	/** Create a builder with status 303 See Other and the given {@code Location}. */
	static BodyBuilder seeOther(URI location) {
		return status(HttpStatus.SEE_OTHER).location(location);
	}

	/** Create a builder with status 307 Temporary Redirect and the given {@code Location}. */
	static BodyBuilder temporaryRedirect(URI location) {
		return status(HttpStatus.TEMPORARY_REDIRECT).location(location);
	}

	/** Create a builder with status 308 Permanent Redirect and the given {@code Location}. */
	static BodyBuilder permanentRedirect(URI location) {
		return status(HttpStatus.PERMANENT_REDIRECT).location(location);
	}

	/** Create a builder with status 400 Bad Request. */
	static BodyBuilder badRequest() {
		return status(HttpStatus.BAD_REQUEST);
	}

	/** Create a builder with status 404 Not Found. */
	static HeadersBuilder<?> notFound() {
		return status(HttpStatus.NOT_FOUND);
	}

	/** Create a builder with status 422 Unprocessable Entity. */
	static BodyBuilder unprocessableEntity() {
		return status(HttpStatus.UNPROCESSABLE_ENTITY);
	}

	/** Create an asynchronous response from the given async value, with no explicit timeout. */
	static ServerResponse async(Object asyncResponse) {
		return AsyncServerResponse.create(asyncResponse);
	}

	/** Create an asynchronous response from the given async value and timeout. */
	static ServerResponse async(Object asyncResponse, Duration timeout) {
		return AsyncServerResponse.create(asyncResponse, timeout);
	}

	/** Create a Server-Sent Events response configured by the given consumer; no timeout ({@code null}). */
	static ServerResponse sse(Consumer<SseBuilder> consumer) {
		return SseServerResponse.create(consumer, null);
	}

	/** Create a Server-Sent Events response configured by the given consumer, with a timeout. */
	static ServerResponse sse(Consumer<SseBuilder> consumer, Duration timeout) {
		return SseServerResponse.create(consumer, timeout);
	}

	// ...
}

HeadersBuilder

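HeadersBuilder 是一个用于构建 HTTP 响应状态与响应头的链式构建器接口，支持添加和修改响应头、Cookie、缓存控制（Cache-Control）、允许的方法（Allow）、ETag、Last-Modified、资源位置（Location）以及 Vary 等常用 HTTP 头部信息，最终通过 build() 或 build(WriteFunction) 生成 ServerResponse 对象。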

/**
 * Fluent builder for the status line and headers of a {@link ServerResponse}.
 * Every setter returns {@code B} (the concrete builder type) for chaining.
 *
 * @param <B> the concrete builder type
 */
interface HeadersBuilder<B extends HeadersBuilder<B>> {

	/** Add the given value(s) under the given header name (add vs. replace: see implementation). */
	B header(String headerName, String... headerValues);
	
	/** Manipulate the response headers through the given consumer. */
	B headers(Consumer<HttpHeaders> headersConsumer);
	
	/** Add the given cookie to the response. */
	B cookie(Cookie cookie);
	
	/** Manipulate the response cookies through the given consumer. */
	B cookies(Consumer<MultiValueMap<String, Cookie>> cookiesConsumer);

	/** Set the allowed HTTP methods (the {@code Allow} header). */
	B allow(HttpMethod... allowedMethods);

	/** Set the allowed HTTP methods (the {@code Allow} header). */
	B allow(Set<HttpMethod> allowedMethods);

	/** Set the entity tag (the {@code ETag} header). */
	B eTag(String eTag);

	/** Set the {@code Last-Modified} header. */
	B lastModified(ZonedDateTime lastModified);
	
	/** Set the {@code Last-Modified} header. */
	B lastModified(Instant lastModified);

	/** Set the {@code Location} header. */
	B location(URI location);
	
	/** Set the {@code Cache-Control} header from the given directives. */
	B cacheControl(CacheControl cacheControl);

	/** Set the request header names the response varies by (the {@code Vary} header). */
	B varyBy(String... requestHeaders);
	
	/** Build a response without a body. */
	ServerResponse build();

	/** Build a response whose output is produced by the given write function. */
	ServerResponse build(WriteFunction writeFunction);

	/**
	 * Callback that writes directly to the servlet response.
	 */
	@FunctionalInterface
	interface WriteFunction {
	
		/**
		 * Write the response.
		 * @return a {@code ModelAndView} to render, or {@code null}
		 * @throws Exception on any write failure
		 */
		@Nullable
		ModelAndView write(HttpServletRequest servletRequest, HttpServletResponse servletResponse) throws Exception;

	}

}

BodyBuilder

BodyBuilder 是在 HeadersBuilder 基础上扩展的接口，额外支持设置 Content-Length 与 Content-Type，用于构建包含响应体的 HTTP 响应：可以直接写出对象（body）、渲染模板视图（render），或以流式方式输出数据（stream）。

/**
 * Extension of {@link HeadersBuilder} that also sets body-related headers and
 * terminates with a body: an object, a rendered template, or a stream.
 */
interface BodyBuilder extends HeadersBuilder<BodyBuilder> {

	/** Set the {@code Content-Length} header. */
	BodyBuilder contentLength(long contentLength);
	
	/** Set the {@code Content-Type} header. */
	BodyBuilder contentType(MediaType contentType);

	/**
	 * Build a response with the given body object.
	 * NOTE(review): presumably converted at write time with the converters supplied
	 * through {@code Context} — confirm in the implementation.
	 */
	ServerResponse body(Object body);

	/** Build a response with the given body, using {@code bodyType} to carry generic type information. */
	<T> ServerResponse body(T body, ParameterizedTypeReference<T> bodyType);

	/** Build a response that renders the named template with the given model attributes. */
	ServerResponse render(String name, Object... modelAttributes);

	/** Build a response that renders the named template with the given model map. */
	ServerResponse render(String name, Map<String, ?> model);
	
	/** Build a low-level streaming response, configured through the given consumer. */
	ServerResponse stream(Consumer<StreamBuilder> streamConsumer);

}

SseBuilder

SseBuilder 是用于构建并发送 Server-Sent Events（SSE）的构建器接口，支持设置事件 ID、事件名称、重试间隔和注释并发送数据，同时可以注册超时、异常和完成回调，并通过 complete() 或 error() 主动结束事件流。

/**
 * Builder handed to the consumer of {@code ServerResponse.sse(...)}. It is used
 * to emit Server-Sent Events and to react to the end of the stream.
 */
interface SseBuilder {

	/** Complete the event stream with the given error. */
	void error(Throwable t);
	
	/** Complete the event stream normally. */
	void complete();

	/** Register a callback invoked when the response times out. */
	SseBuilder onTimeout(Runnable onTimeout);
	
	/** Register a callback invoked when an error occurs. */
	SseBuilder onError(Consumer<Throwable> onError);
	
	/** Register a callback invoked when the response completes. */
	SseBuilder onComplete(Runnable onCompletion);
	
	/** Send the given object as a single event. */
	void send(Object object) throws IOException;
	
	/**
	 * Send the event fields set so far (id/event/retry/comment).
	 * NOTE(review): the no-arg form presumably sends no data field — confirm.
	 */
	void send() throws IOException;
	
	/** Set the {@code id} field of the current event. */
	SseBuilder id(String id);
	
	/** Set the {@code event} (name) field of the current event. */
	SseBuilder event(String eventName);
	
	/** Set the {@code retry} field (reconnection delay) of the current event. */
	SseBuilder retry(Duration duration);

	/** Add a comment line to the current event. */
	SseBuilder comment(String comment);

	/**
	 * Set the data of the current event.
	 * NOTE(review): presumably also sends the event, since it declares IOException — confirm.
	 */
	void data(Object object) throws IOException;
	
}

StreamBuilder

StreamBuilder 是用于构建底层响应流的构建器接口，支持向客户端逐步写入对象数据（可指定媒体类型）和刷新缓冲区，可以注册超时、异常和完成回调，并通过 complete() 或 error() 主动结束响应流。

/**
 * Builder handed to the consumer of {@code BodyBuilder.stream(...)}. It is used
 * to write a response incrementally and to react to the end of the stream.
 */
interface StreamBuilder {

	/** Complete the stream with the given error. */
	void error(Throwable t);
	
	/** Complete the stream normally. */
	void complete();
	
	/** Register a callback invoked when the response times out. */
	StreamBuilder onTimeout(Runnable onTimeout);
	
	/** Register a callback invoked when an error occurs. */
	StreamBuilder onError(Consumer<Throwable> onError);
	
	/** Register a callback invoked when the response completes. */
	StreamBuilder onComplete(Runnable onCompletion);
	
	/** Write the given object to the stream; returns this builder for chaining. */
	StreamBuilder write(Object object) throws IOException;
	
	/** Write the given object with the given media type (may be {@code null}); returns this builder. */
	StreamBuilder write(Object object, @Nullable MediaType mediaType) throws IOException;
	
	/** Flush buffered output to the client. */
	void flush() throws IOException;

}

Context

Context 接口定义了在 writeTo(HttpServletRequest, HttpServletResponse, Context) 方法中使用的上下文，提供用于转换响应体的 HttpMessageConverter 列表。在 HandlerFunctionAdapter 中，它由 ServerRequestContext 实现，转换器列表直接取自 ServerRequest.messageConverters()。

/**
 * Context passed to {@code ServerResponse.writeTo(...)}.
 */
interface Context {
	/** Return the message converters used to convert the response body. */
	List<HttpMessageConverter<?>> messageConverters();
}

执行

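推荐阅读：[spring6: DispatcherServlet.doDispatch]-源码分析。在 DispatcherServlet 的分发流程中，RouterFunctionMapping 负责为请求查找匹配的 HandlerFunction，HandlerFunctionAdapter 负责调用该函数并将得到的 ServerResponse 写入响应。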

RouterFunctionMapping

public class RouterFunctionMapping extends AbstractHandlerMapping implements InitializingBean, MatchableHandlerMapping {

	@Nullable
	private RouterFunction<?> routerFunction;

	private List<HttpMessageConverter<?>> messageConverters = Collections.emptyList();

	private boolean detectHandlerFunctionsInAncestorContexts = false;
	
	public RouterFunctionMapping() {}

	public RouterFunctionMapping(RouterFunction<?> routerFunction) {
		this.routerFunction = routerFunction;
	}

	@Override
	public void afterPropertiesSet() {
		if (this.routerFunction == null) {
			initRouterFunctions();
		}
		if (CollectionUtils.isEmpty(this.messageConverters)) {
			initMessageConverters();
		}
		if (this.routerFunction != null) {
			PathPatternParser patternParser = getPatternParser();
			if (patternParser == null) {
				patternParser = new PathPatternParser();
				setPatternParser(patternParser);
			}
			RouterFunctions.changeParser(this.routerFunction, patternParser);
		}
	}

	private void initRouterFunctions() {
		List<RouterFunction<?>> routerFunctions = obtainApplicationContext()
				.getBeanProvider(RouterFunction.class)
				.orderedStream()
				.map(router -> (RouterFunction<?>) router)
				.collect(Collectors.toList());

		ApplicationContext parentContext = obtainApplicationContext().getParent();
		if (parentContext != null && !this.detectHandlerFunctionsInAncestorContexts) {
			parentContext.getBeanProvider(RouterFunction.class).stream().forEach(routerFunctions::remove);
		}

		this.routerFunction = routerFunctions.stream().reduce(RouterFunction::andOther).orElse(null);
		logRouterFunctions(routerFunctions);
	}

	private void initMessageConverters() {
		List<HttpMessageConverter<?>> messageConverters = new ArrayList<>(4);
		messageConverters.add(new ByteArrayHttpMessageConverter());
		messageConverters.add(new StringHttpMessageConverter());
		messageConverters.add(new AllEncompassingFormHttpMessageConverter());

		this.messageConverters = messageConverters;
	}

	@Override
	@Nullable
	protected Object getHandlerInternal(HttpServletRequest servletRequest) throws Exception {
		if (this.routerFunction != null) {
			// DefaultServerRequest
			ServerRequest request = ServerRequest.create(servletRequest, this.messageConverters);
			HandlerFunction<?> handlerFunction = this.routerFunction.route(request).orElse(null);
			setAttributes(servletRequest, request, handlerFunction);
			return handlerFunction;
		}
		else {
			return null;
		}
	}

	private void setAttributes(HttpServletRequest servletRequest, ServerRequest request,
			@Nullable HandlerFunction<?> handlerFunction) {

		PathPattern matchingPattern =
				(PathPattern) servletRequest.getAttribute(RouterFunctions.MATCHING_PATTERN_ATTRIBUTE);
		if (matchingPattern != null) {
			servletRequest.removeAttribute(RouterFunctions.MATCHING_PATTERN_ATTRIBUTE);
			servletRequest.setAttribute(BEST_MATCHING_PATTERN_ATTRIBUTE, matchingPattern.getPatternString());
			ServerHttpObservationFilter.findObservationContext(request.servletRequest())
					.ifPresent(context -> context.setPathPattern(matchingPattern.getPatternString()));
		}
		servletRequest.setAttribute(BEST_MATCHING_HANDLER_ATTRIBUTE, handlerFunction);
		servletRequest.setAttribute(RouterFunctions.REQUEST_ATTRIBUTE, request);
	}

	@Nullable
	@Override
	public RequestMatchResult match(HttpServletRequest request, String pattern) {
		throw new UnsupportedOperationException("This HandlerMapping uses PathPatterns");
	}
}

HandlerFunctionAdapter

/**
 * {@code HandlerAdapter} that invokes {@link HandlerFunction} handlers and
 * writes the resulting {@link ServerResponse}. It also resumes requests whose
 * async processing has produced a concurrent result.
 */
public class HandlerFunctionAdapter implements HandlerAdapter, Ordered {

	private static final Log logger = LogFactory.getLog(HandlerFunctionAdapter.class);

	/** Adapter order; lowest precedence unless configured. */
	private int order = Ordered.LOWEST_PRECEDENCE;

	/** Timeout passed to each {@code AsyncWebRequest}; {@code null} when not configured. */
	@Nullable
	private Long asyncRequestTimeout;

	/** Set the order of this adapter. */
	public void setOrder(int order) {
		this.order = order;
	}

	@Override
	public int getOrder() {
		return this.order;
	}

	/** Set the async request timeout (presumably in milliseconds — see AsyncWebRequest#setTimeout). */
	public void setAsyncRequestTimeout(long timeout) {
		this.asyncRequestTimeout = timeout;
	}

	/** Supports exactly the handlers that are {@link HandlerFunction}s. */
	@Override
	public boolean supports(Object handler) {
		return handler instanceof HandlerFunction;
	}

	/**
	 * Obtain a {@link ServerResponse}, either from a pending async result or by
	 * invoking the handler function, then write it.
	 * @return the {@code ModelAndView} produced by {@code writeTo}, or {@code null}
	 * when there is no response
	 */
	@Nullable
	@Override
	public ModelAndView handle(HttpServletRequest servletRequest, HttpServletResponse servletResponse, Object handler) throws Exception {
		WebAsyncManager asyncManager = getWebAsyncManager(servletRequest, servletResponse);
		// Write through the response exposed by the AsyncWebRequest (NOTE(review): may be a wrapper).
		HttpServletResponse effectiveResponse = getWrappedResponse(asyncManager);
		ServerRequest serverRequest = getServerRequest(servletRequest);

		ServerResponse serverResponse = asyncManager.hasConcurrentResult()
				? handleAsync(asyncManager)
				: ((HandlerFunction<?>) handler).handle(serverRequest);

		if (serverResponse == null) {
			return null;
		}
		return serverResponse.writeTo(servletRequest, effectiveResponse, new ServerRequestContext(serverRequest));
	}

	/** Create an {@code AsyncWebRequest} with the configured timeout and register it on the request's async manager. */
	private WebAsyncManager getWebAsyncManager(HttpServletRequest servletRequest, HttpServletResponse servletResponse) {
		AsyncWebRequest webRequest = WebAsyncUtils.createAsyncWebRequest(servletRequest, servletResponse);
		webRequest.setTimeout(this.asyncRequestTimeout);

		WebAsyncManager manager = WebAsyncUtils.getAsyncManager(servletRequest);
		manager.setAsyncWebRequest(webRequest);
		return manager;
	}

	/** Return the native servlet response held by the manager's {@code AsyncWebRequest}. */
	private static HttpServletResponse getWrappedResponse(WebAsyncManager asyncManager) {
		AsyncWebRequest webRequest = asyncManager.getAsyncWebRequest();
		Assert.notNull(webRequest, "No AsyncWebRequest");

		HttpServletResponse nativeResponse = webRequest.getNativeResponse(HttpServletResponse.class);
		Assert.notNull(nativeResponse, "No HttpServletResponse");
		return nativeResponse;
	}

	/**
	 * Return the {@link ServerRequest} that the handler mapping stored under
	 * {@code RouterFunctions.REQUEST_ATTRIBUTE}.
	 * @throws IllegalStateException if the attribute is missing
	 */
	private ServerRequest getServerRequest(HttpServletRequest servletRequest) {
		Object stored = servletRequest.getAttribute(RouterFunctions.REQUEST_ATTRIBUTE);
		ServerRequest serverRequest = (ServerRequest) stored;
		Assert.state(serverRequest != null, () -> "Required attribute '" +
				RouterFunctions.REQUEST_ATTRIBUTE + "' is missing");
		return serverRequest;
	}

	/**
	 * Consume and clear the concurrent result and turn it into a response.
	 * An {@code Exception} is rethrown as-is; any other {@code Throwable} is
	 * wrapped in a {@code ServletException}.
	 * @throws IllegalArgumentException for a result of any other type
	 */
	@Nullable
	private ServerResponse handleAsync(WebAsyncManager asyncManager) throws Exception {
		Object result = asyncManager.getConcurrentResult();
		asyncManager.clearConcurrentResult();
		LogFormatUtils.traceDebug(logger, traceOn -> {
			String formatted = LogFormatUtils.formatValue(result, !traceOn);
			return "Resume with async result [" + formatted + "]";
		});
		if (result == null) {
			return null;
		}
		if (result instanceof ServerResponse response) {
			return response;
		}
		if (result instanceof Exception exception) {
			throw exception;
		}
		if (result instanceof Throwable throwable) {
			throw new ServletException("Async processing failed", throwable);
		}
		throw new IllegalArgumentException("Unknown result from WebAsyncManager: [" + result + "]");
	}

	/** Last-modified support is not provided; always {@code -1}. */
	@Override
	@SuppressWarnings("deprecation")
	public long getLastModified(HttpServletRequest request, Object handler) {
		return -1L;
	}


	/** Exposes the request's message converters to {@code ServerResponse.writeTo}. */
	private record ServerRequestContext(ServerRequest serverRequest) implements ServerResponse.Context {

		@Override
		public List<HttpMessageConverter<?>> messageConverters() {
			return this.serverRequest.messageConverters();
		}
	}
}

实战

/**
 * Minimal WebMvc.fn demo. A predicate, a handler and a filter are declared as
 * separate beans and composed into a {@link RouterFunction} bean, which
 * {@code RouterFunctionMapping} collects from the application context.
 */
@SpringBootApplication
public class Application {

    /** Matches {@code GET /}. */
    @Bean
    public RequestPredicate requestPredicate() {
        return RequestPredicates.GET("/");
    }

    /** Responds 200 OK with a fixed text body. */
    @Bean
    public HandlerFunction<ServerResponse> handlerFunction() {
        return req -> ServerResponse.ok().body("hello, world");
    }

    /** Prints a line before and after delegating to the next handler. */
    @Bean
    public HandlerFilterFunction<ServerResponse, ServerResponse> handlerFilterFunction() {
        return (req, chain) -> {
            System.out.println("Before handler");
            ServerResponse result = chain.handle(req);
            System.out.println("After handler");
            return result;
        };
    }

    /** Binds the predicate to the handler and wraps the route with the filter. */
    @Bean
    public RouterFunction<ServerResponse> routerFunction() {
        RouterFunction<ServerResponse> route = RouterFunctions.route(requestPredicate(), handlerFunction());
        return route.filter(handlerFilterFunction());
    }

    public static void main(String[] args) {
        SpringApplication.run(Application.class, args);
    }

}

网站公告

今日签到

点亮在社区的每一天
去签到