spring-ai 工作流

发布于:2025-07-01 ⋅ 阅读:(15) ⋅ 点赞:(0)

工作流概念

工作流是以相对固化的模式来人为地拆解任务,将一个大任务拆解为包含多个分支的固化流程。工作流的优势是确定性强,模型作为流程中的一个节点起到的更多是一个分类决策、内容生成的职责,因此它更适合意图识别等类别属性强的应用场景。

参考文档:https://java2ai.com/docs/1.0.0.2/get-started/workflow/?spm=4347728f.7cee0e64.0.0.39076dd1jbppqZ

工作流程图

商品评价分类流程图:
(原文此处为流程图图片,已缺失。流程为:用户评价 → 正面/负面分类;正面评价直接记录;负面评价再细分为售后服务、运输、产品质量或其他问题后记录。)

例如,对于以下用户反馈:

  • This product is excellent, I love it!
    输出:Praise, no action taken.
    说明:正面评价,无需采取改进措施。

  • The product broke after one day, very disappointed.
    输出:product quality
    说明:负面评价,归类为产品质量问题。

Spring Boot 编码实现

使用框架:Spring AI Alibaba Graph

附 Maven 的 pom.xml:

<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
	xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 https://maven.apache.org/xsd/maven-4.0.0.xsd">
	<!-- Maven build for the Spring AI Alibaba Graph customer-feedback workflow demo. -->
	<modelVersion>4.0.0</modelVersion>
	<parent>
		<groupId>org.springframework.boot</groupId>
		<artifactId>spring-boot-starter-parent</artifactId>
		<version>3.4.6</version>
		<relativePath/> <!-- lookup parent from repository -->
	</parent>
	<groupId>com.example</groupId>
	<artifactId>demo-spring-test</artifactId>
	<version>0.0.1-SNAPSHOT</version>
	<name>demo-spring-test</name>
	<description>Demo project for Spring Boot</description>
	<url/>
	<licenses>
		<license/>
	</licenses>
	<developers>
		<developer/>
	</developers>
	<scm>
		<connection/>
		<developerConnection/>
		<tag/>
		<url/>
	</scm>
	<properties>
		<java.version>17</java.version>
		<!-- NOTE(review): spring-ai.version is not referenced by any dependency below
		     (spring-ai-core pins 1.0.0-M6 explicitly); either use ${spring-ai.version} or drop it. -->
		<spring-ai.version>1.0.0</spring-ai.version>
	</properties>
	<!-- NOTE(review): the AI dependencies mix versions: 1.0.0-M6.1 (alibaba starter/autoconfigure),
	     1.0.0-M6 (spring-ai-core) and 1.0.0.2 (graph-core, tika parser). Confirm these releases are
	     mutually compatible; pinning them through one BOM in dependencyManagement would avoid drift. -->
	<dependencies>
		<dependency>
			<groupId>org.springframework.boot</groupId>
			<artifactId>spring-boot-starter-web</artifactId>
		</dependency>

		<!-- Spring AI Alibaba (Tongyi large-model support) -->
		<dependency>
			<groupId>com.alibaba.cloud.ai</groupId>
			<artifactId>spring-ai-alibaba-starter</artifactId>
			<version>1.0.0-M6.1</version>
		</dependency>
		<dependency>
			<groupId>org.springframework.ai</groupId>
			<artifactId>spring-ai-core</artifactId>
			<version>1.0.0-M6</version>
		</dependency>
		<dependency>
			<groupId>com.alibaba.cloud.ai</groupId>
			<artifactId>spring-ai-alibaba-autoconfigure</artifactId>
			<version>1.0.0-M6.1</version>
		</dependency>

		<!-- Graph core dependency (StateGraph, QuestionClassifierNode, OverAllState, ...) -->
		<dependency>
			<groupId>com.alibaba.cloud.ai</groupId>
			<artifactId>spring-ai-alibaba-graph-core</artifactId>
			<version>1.0.0.2</version>
		</dependency>
		<!-- NOTE(review): none of the code shown in this article uses the Tika document parser;
		     verify it is actually needed. -->
		<dependency>
			<groupId>com.alibaba.cloud.ai</groupId>
			<artifactId>spring-ai-alibaba-starter-document-parser-tika</artifactId>
			<version>1.0.0.2</version>
		</dependency>

		<dependency>
			<groupId>org.springframework.boot</groupId>
			<artifactId>spring-boot-starter-test</artifactId>
			<scope>test</scope>
		</dependency>
	</dependencies>

	<build>
		<plugins>
			<plugin>
				<groupId>org.springframework.boot</groupId>
				<artifactId>spring-boot-maven-plugin</artifactId>
			</plugin>
		</plugins>
	</build>

</project>

定义节点 (Node)

创建工作流中的核心节点,包括两个文本分类节点和一个记录节点。

分类节点(QuestionClassifierNode):

// Classifier node #1: decides whether the feedback is positive or negative.
List<String> polarityCategories = List.of("positive feedback", "negative feedback");
List<String> polarityInstructions =
        List.of("Try to understand the user's feeling when he/she is giving the feedback.");
QuestionClassifierNode feedbackClassifier = QuestionClassifierNode.builder()
        .chatClient(chatClient)
        .inputTextKey("input")
        .categories(polarityCategories)
        .classificationInstructions(polarityInstructions)
        .build();

// Classifier node #2: narrows negative feedback down to a concrete problem area.
List<String> problemCategories =
        List.of("after-sale service", "transportation", "product quality", "others");
String problemInstruction = "What kind of service or help the customer is trying to get from us? "
        + "Classify the question based on your understanding.";
QuestionClassifierNode specificQuestionClassifier = QuestionClassifierNode.builder()
        .chatClient(chatClient)
        .inputTextKey("input")
        .categories(problemCategories)
        .classificationInstructions(List.of(problemInstruction))
        .build();

记录节点 RecordingNode:


import com.alibaba.cloud.ai.graph.OverAllState;
import com.alibaba.cloud.ai.graph.action.NodeAction;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.HashMap;
import java.util.Map;

/**
 * Terminal workflow node that turns the latest classifier output into a {@code solution} entry.
 *
 * <p>If {@code classifier_output} contains {@code "positive"}, the feedback is acknowledged with a
 * fixed praise message. Otherwise the classifier output itself is stored as the solution.
 */
public class RecordingNode implements NodeAction {

    private static final Logger logger = LoggerFactory.getLogger(RecordingNode.class);

    /**
     * Records the outcome of classification.
     *
     * @param state the workflow state; must contain {@code classifier_output}
     * @return a state update containing only the {@code solution} key
     * @throws IllegalStateException if {@code classifier_output} is absent from the state
     */
    @Override
    public Map<String, Object> apply(OverAllState state) throws Exception {
        // Fail with a descriptive message instead of a bare NoSuchElementException from Optional.get().
        String feedback = state.value("classifier_output")
                .map(Object::toString)
                .orElseThrow(() -> new IllegalStateException(
                        "Missing 'classifier_output' in workflow state"));

        Map<String, Object> updatedState = new HashMap<>();
        if (feedback.contains("positive")) {
            logger.info("Received positive feedback: {}", feedback);
            updatedState.put("solution", "Praise, no action taken.");
        } else {
            logger.info("Received negative feedback: {}", feedback);
            updatedState.put("solution", feedback);
        }

        return updatedState;
    }

}

定义状态图 StateGraph(节点与边)

StateGraph graph = new StateGraph("Consumer Service Workflow Demo", stateFactory)
        // Register the nodes
        .addNode("feedback_classifier", node_async(feedbackClassifier))
        .addNode("specific_question_classifier", node_async(specificQuestionClassifier))
        .addNode("recorder", node_async(recordingNode))
        // Define the edges (execution order)
        .addEdge(START, "feedback_classifier")  // entry node
        // Map keys are the labels returned by FeedbackQuestionDispatcher
        .addConditionalEdges("feedback_classifier",
                edge_async(new CustomerServiceController.FeedbackQuestionDispatcher()),
                Map.of("positive", "recorder", "negative", "specific_question_classifier"))
        // Map keys are the labels returned by SpecificQuestionDispatcher
        .addConditionalEdges("specific_question_classifier",
                edge_async(new CustomerServiceController.SpecificQuestionDispatcher()),
                Map.of("after-sale", "recorder", "transportation", "recorder",
                        "quality", "recorder", "others", "recorder"))
        .addEdge("recorder", END);  // exit node
return graph;

完整配置类代码(WorkflowAutoconfiguration):

import com.alibaba.cloud.ai.graph.OverAllState;
import com.alibaba.cloud.ai.graph.OverAllStateFactory;
import com.alibaba.cloud.ai.graph.StateGraph;
import com.alibaba.cloud.ai.graph.exception.GraphStateException;
import com.alibaba.cloud.ai.graph.node.QuestionClassifierNode;
import com.alibaba.cloud.ai.graph.state.strategy.ReplaceStrategy;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.client.advisor.SimpleLoggerAdvisor;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;

import java.util.List;
import java.util.Map;

import static com.alibaba.cloud.ai.graph.StateGraph.END;
import static com.alibaba.cloud.ai.graph.StateGraph.START;
import static com.alibaba.cloud.ai.graph.action.AsyncEdgeAction.edge_async;
import static com.alibaba.cloud.ai.graph.action.AsyncNodeAction.node_async;

/**
 * Spring configuration that defines the customer-feedback classification workflow.
 *
 * <p>Graph layout: {@code START -> feedback_classifier}. Positive feedback goes straight to
 * {@code recorder}; negative feedback first passes through {@code specific_question_classifier}.
 * Finally {@code recorder -> END}.
 */
@Configuration
public class WorkflowAutoconfiguration {

    /**
     * Builds the (uncompiled) workflow graph definition.
     *
     * @param chatModel the chat model backing both classifier nodes
     * @return the workflow graph definition
     * @throws GraphStateException if the graph definition is invalid
     */
    @Bean
    public StateGraph workflowGraph(ChatModel chatModel) throws GraphStateException {
        ChatClient chatClient = ChatClient.builder(chatModel)
                .defaultAdvisors(new SimpleLoggerAdvisor()).build();

        RecordingNode recordingNode = new RecordingNode();

        // Classifier node: positive vs. negative feedback
        QuestionClassifierNode feedbackClassifier = QuestionClassifierNode.builder()
                .chatClient(chatClient)
                .inputTextKey("input")
                .categories(List.of("positive feedback", "negative feedback"))
                .classificationInstructions(
                        List.of("Try to understand the user's feeling when he/she is giving the feedback."))
                .build();

        // Classifier node: concrete problem area of negative feedback
        QuestionClassifierNode specificQuestionClassifier = QuestionClassifierNode.builder()
                .chatClient(chatClient)
                .inputTextKey("input")
                .categories(List.of("after-sale service", "transportation", "product quality", "others"))
                .classificationInstructions(List.of(
                        "What kind of service or help the customer is trying to get from us? " +
                                "Classify the question based on your understanding."))
                .build();

        // Creates the initial global state for each workflow run. Every key is registered with
        // ReplaceStrategy (presumably a later write replaces the earlier value).
        OverAllStateFactory stateFactory = () -> {
            OverAllState state = new OverAllState();
            state.registerKeyAndStrategy("input", new ReplaceStrategy());
            state.registerKeyAndStrategy("classifier_output", new ReplaceStrategy());
            state.registerKeyAndStrategy("solution", new ReplaceStrategy());
            return state;
        };

        // No stray stdout output here: the bean method only assembles and returns the graph.
        return new StateGraph("Consumer Service Workflow Demo", stateFactory)
                .addNode("feedback_classifier", node_async(feedbackClassifier))
                .addNode("specific_question_classifier", node_async(specificQuestionClassifier))
                .addNode("recorder", node_async(recordingNode))
                // Define the edges (execution order)
                .addEdge(START, "feedback_classifier")  // entry node
                // Map keys are the labels returned by FeedbackQuestionDispatcher
                .addConditionalEdges("feedback_classifier",
                        edge_async(new CustomerServiceController.FeedbackQuestionDispatcher()),
                        Map.of("positive", "recorder", "negative", "specific_question_classifier"))
                // Map keys are the labels returned by SpecificQuestionDispatcher
                .addConditionalEdges("specific_question_classifier",
                        edge_async(new CustomerServiceController.SpecificQuestionDispatcher()),
                        Map.of("after-sale", "recorder", "transportation", "recorder",
                                "quality", "recorder", "others", "recorder"))
                .addEdge("recorder", END);  // exit node
    }

}

Controller 测试

  • CustomerServiceController 完整代码:
import java.util.HashMap;
import java.util.Map;

import com.alibaba.cloud.ai.graph.CompiledGraph;
import com.alibaba.cloud.ai.graph.exception.GraphStateException;
import com.alibaba.cloud.ai.graph.OverAllState;
import com.alibaba.cloud.ai.graph.StateGraph;
import com.alibaba.cloud.ai.graph.action.EdgeAction;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController;

/**
 * REST entry point for the customer-feedback workflow.
 *
 * <p>It compiles the {@code workflowGraph} bean once at construction time and runs it for each
 * request. It also hosts the two edge dispatchers the graph uses to route between nodes based on
 * {@code classifier_output}.
 */
@RestController
@RequestMapping("/customer")
public class CustomerServiceController {

    private static final Logger logger = LoggerFactory.getLogger(CustomerServiceController.class);

    /** Compiled once in the constructor and reused for every request. */
    private final CompiledGraph compiledGraph;

    /**
     * @param stateGraph the workflow definition bean named {@code workflowGraph}
     * @throws GraphStateException if the graph definition cannot be compiled
     */
    public CustomerServiceController(@Qualifier("workflowGraph") StateGraph stateGraph) throws GraphStateException {
        this.compiledGraph = stateGraph.compile();
    }

    /**
     * Runs the workflow for one piece of feedback and returns the recorded solution.
     *
     * <p>Example: localhost:8080/customer/chat?query=The product broke after one day, very disappointed.
     *
     * <p>NOTE(review): when the {@code query} parameter is absent, {@code query} is null and
     * {@code Map.of} rejects it with a NullPointerException; consider {@code @RequestParam} to
     * return 400 instead.
     *
     * @param query the raw feedback text, stored in the workflow state under {@code input}
     * @return the {@code solution} value written by the workflow
     * @throws IllegalStateException if the workflow yields no final state or no solution
     */
    @GetMapping("/chat")
    public String simpleChat(String query) throws Exception {
        logger.info("simpleChat: {}", query);
        OverAllState finalState = compiledGraph.invoke(Map.of("input", query))
                .orElseThrow(() -> new IllegalStateException("Workflow returned no final state"));
        return finalState.value("solution")
                .map(Object::toString)
                .orElseThrow(() -> new IllegalStateException("Workflow finished without a 'solution'"));
    }

    /**
     * Routes after {@code feedback_classifier}. It returns {@code "positive"} when the classifier
     * output contains "positive", otherwise {@code "negative"}. The returned labels must match the
     * edge keys registered in the graph.
     */
    public static class FeedbackQuestionDispatcher implements EdgeAction {

        @Override
        public String apply(OverAllState state) throws Exception {
            // Example: negative product feedback is routed as "negative".
            String classifierOutput = (String) state.value("classifier_output").orElse("");
            logger.info("classifierOutput: {}", classifierOutput);

            return classifierOutput.contains("positive") ? "positive" : "negative";
        }

    }

    /**
     * Routes after {@code specific_question_classifier}. It returns the first route key contained
     * in the classifier output, or {@code "others"} when none match. The returned labels must match
     * the edge keys registered in the graph.
     */
    public static class SpecificQuestionDispatcher implements EdgeAction {

        /**
         * Route keys checked in this fixed order; the first match wins. A fixed order (rather than
         * hash-map iteration order) keeps routing deterministic when the output mentions several
         * categories.
         */
        private static final String[] ROUTE_KEYS = {"after-sale", "transportation", "quality"};

        @Override
        public String apply(OverAllState state) throws Exception {
            // Example: feedback about a broken product is routed as "quality".
            String classifierOutput = (String) state.value("classifier_output").orElse("");
            logger.info("classifierOutput: {}", classifierOutput);

            for (String routeKey : ROUTE_KEYS) {
                if (classifierOutput.contains(routeKey)) {
                    return routeKey;
                }
            }

            return "others";
        }

    }

}
在浏览器中测试用户输入,例如访问:

localhost:8080/customer/chat?query=The product broke after one day, very disappointed.

(原文此处为三张浏览器测试结果截图,图片已缺失。)


网站公告

今日签到

点亮在社区的每一天
去签到