工作流概念
工作流是以相对固化的模式来人为地拆解任务,将一个大任务拆解为包含多个分支的固化流程。工作流的优势是确定性强,模型作为流程中的一个节点,承担的更多是分类决策、内容生成的职责,因此它更适合意图识别等类别属性强的应用场景。
参考文档:https://java2ai.com/docs/1.0.0.2/get-started/workflow/?spm=4347728f.7cee0e64.0.0.39076dd1jbppqZ
工作流程图
商品评价分类流程图:
如用户反馈
This product is excellent, I love it!
则输出:Praise, no action taken.
说明:很好,不需要改进措施
如用户反馈
The product broke after one day, very disappointed.
则输出:product quality
说明:有问题,产品质量问题
spring-boot 编码
附maven的pom.xml
<?xml version="1.0" encoding="UTF-8"?>
<!-- Maven build descriptor for the demo-spring-test Spring Boot application used by the workflow example. -->
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 https://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<!-- Inherits plugin and dependency management from the Spring Boot parent POM. -->
<parent>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-parent</artifactId>
<version>3.4.6</version>
<relativePath/> <!-- lookup parent from repository -->
</parent>
<groupId>com.example</groupId>
<artifactId>demo-spring-test</artifactId>
<version>0.0.1-SNAPSHOT</version>
<name>demo-spring-test</name>
<description>Demo project for Spring Boot</description>
<!-- The url/licenses/developers/scm elements below are empty placeholders. -->
<url/>
<licenses>
<license/>
</licenses>
<developers>
<developer/>
</developers>
<scm>
<connection/>
<developerConnection/>
<tag/>
<url/>
</scm>
<properties>
<java.version>17</java.version>
<!-- NOTE(review): spring-ai.version is declared here but no dependency below references it. -->
<spring-ai.version>1.0.0</spring-ai.version>
</properties>
<dependencies>
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-web</artifactId>
</dependency>
<!-- Spring AI Alibaba (Tongyi large-model support) -->
<!-- NOTE(review): the starter and autoconfigure artifacts are pinned to 1.0.0-M6.1 and spring-ai-core to
     1.0.0-M6, while graph-core and the Tika parser below are 1.0.0.2; confirm that these milestone and
     release versions are compatible with each other. -->
<dependency>
<groupId>com.alibaba.cloud.ai</groupId>
<artifactId>spring-ai-alibaba-starter</artifactId>
<version>1.0.0-M6.1</version>
</dependency>
<dependency>
<groupId>org.springframework.ai</groupId>
<artifactId>spring-ai-core</artifactId>
<version>1.0.0-M6</version>
</dependency>
<dependency>
<groupId>com.alibaba.cloud.ai</groupId>
<artifactId>spring-ai-alibaba-autoconfigure</artifactId>
<version>1.0.0-M6.1</version>
</dependency>
<!-- Graph core dependency (StateGraph, OverAllState, node and edge actions used by the example) -->
<dependency>
<groupId>com.alibaba.cloud.ai</groupId>
<artifactId>spring-ai-alibaba-graph-core</artifactId>
<version>1.0.0.2</version>
</dependency>
<dependency>
<groupId>com.alibaba.cloud.ai</groupId>
<artifactId>spring-ai-alibaba-starter-document-parser-tika</artifactId>
<version>1.0.0.2</version>
</dependency>
<!-- Test-scoped only; not packaged with the application. -->
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-test</artifactId>
<scope>test</scope>
</dependency>
</dependencies>
<build>
<plugins>
<plugin>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-maven-plugin</artifactId>
</plugin>
</plugins>
</build>
</project>
定义节点 (Node)
创建工作流中的核心节点,包括两个文本分类节点和一个记录节点
分类
// Sentiment classifier: sorts the feedback into "positive feedback" or "negative feedback".
List<String> sentimentCategories = List.of("positive feedback", "negative feedback");
List<String> sentimentInstructions =
        List.of("Try to understand the user's feeling when he/she is giving the feedback.");
QuestionClassifierNode feedbackClassifier = QuestionClassifierNode.builder()
        .chatClient(chatClient)
        .inputTextKey("input")
        .categories(sentimentCategories)
        .classificationInstructions(sentimentInstructions)
        .build();

// Issue classifier for negative feedback: narrows the complaint down to one concrete problem category.
List<String> issueCategories =
        List.of("after-sale service", "transportation", "product quality", "others");
List<String> issueInstructions = List.of(
        "What kind of service or help the customer is trying to get from us? "
                + "Classify the question based on your understanding.");
QuestionClassifierNode specificQuestionClassifier = QuestionClassifierNode.builder()
        .chatClient(chatClient)
        .inputTextKey("input")
        .categories(issueCategories)
        .classificationInstructions(issueInstructions)
        .build();
记录节点 RecordingNode:
import com.alibaba.cloud.ai.graph.OverAllState;
import com.alibaba.cloud.ai.graph.action.NodeAction;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.util.HashMap;
import java.util.Map;
/**
 * Workflow node that converts the classifier result into the final {@code "solution"} state value.
 *
 * <p>It reads the {@code "classifier_output"} key from the state. When that text contains
 * {@code "positive"} the solution is the fixed string {@code "Praise, no action taken."};
 * otherwise the classifier output itself is stored as the solution.
 */
public class RecordingNode implements NodeAction {

    private static final Logger logger = LoggerFactory.getLogger(RecordingNode.class);

    /**
     * Produces the state update carrying the {@code "solution"} entry.
     *
     * @param state current workflow state; must contain a String under {@code "classifier_output"}
     * @return a map with the single key {@code "solution"}
     * @throws IllegalStateException if {@code "classifier_output"} is absent from the state
     */
    @Override
    public Map<String, Object> apply(OverAllState state) throws Exception {
        // Fail with a descriptive message instead of the bare NoSuchElementException of Optional.get().
        String feedback = (String) state.value("classifier_output")
                .orElseThrow(() -> new IllegalStateException(
                        "Missing 'classifier_output' in workflow state; cannot record feedback"));

        Map<String, Object> updatedState = new HashMap<>();
        if (feedback.contains("positive")) {
            logger.info("Received positive feedback: {}", feedback);
            updatedState.put("solution", "Praise, no action taken.");
        } else {
            logger.info("Received negative feedback: {}", feedback);
            // For negative feedback the classifier output is passed through unchanged.
            updatedState.put("solution", feedback);
        }
        return updatedState;
    }
}
定义节点图StateGraph
// Assemble the workflow graph. The leftover debug System.out.println("\n") has been dropped;
// building the graph should not write to stdout.
StateGraph graph = new StateGraph("Consumer Service Workflow Demo", stateFactory)
        // Register the nodes
        .addNode("feedback_classifier", node_async(feedbackClassifier))
        .addNode("specific_question_classifier", node_async(specificQuestionClassifier))
        .addNode("recorder", node_async(recordingNode))
        // Define the edges (execution order)
        .addEdge(START, "feedback_classifier") // entry point
        // Positive feedback goes straight to the recorder; negative feedback is classified further.
        .addConditionalEdges("feedback_classifier",
                edge_async(new CustomerServiceController.FeedbackQuestionDispatcher()),
                Map.of("positive", "recorder", "negative", "specific_question_classifier"))
        // Every specific category ends up at the recorder.
        .addConditionalEdges("specific_question_classifier",
                edge_async(new CustomerServiceController.SpecificQuestionDispatcher()),
                Map.of("after-sale", "recorder", "transportation", "recorder",
                        "quality", "recorder", "others", "recorder"))
        .addEdge("recorder", END); // exit point
return graph;
完整代码:
import com.alibaba.cloud.ai.graph.OverAllState;
import com.alibaba.cloud.ai.graph.OverAllStateFactory;
import com.alibaba.cloud.ai.graph.StateGraph;
import com.alibaba.cloud.ai.graph.exception.GraphStateException;
import com.alibaba.cloud.ai.graph.node.QuestionClassifierNode;
import com.alibaba.cloud.ai.graph.state.strategy.ReplaceStrategy;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.client.advisor.SimpleLoggerAdvisor;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import java.util.List;
import java.util.Map;
import static com.alibaba.cloud.ai.graph.StateGraph.END;
import static com.alibaba.cloud.ai.graph.StateGraph.START;
import static com.alibaba.cloud.ai.graph.action.AsyncEdgeAction.edge_async;
import static com.alibaba.cloud.ai.graph.action.AsyncNodeAction.node_async;
/**
 * Spring configuration that builds the customer-feedback workflow graph.
 *
 * <p>The graph has three nodes: a sentiment classifier, a classifier for the specific problem
 * behind negative feedback, and a {@link RecordingNode} that writes the final solution.
 */
@Configuration
public class WorkflowAutoconfiguration {

    /**
     * Creates the {@link StateGraph} bean describing the workflow.
     *
     * @param chatModel chat model used to build the {@link ChatClient} shared by both classifier nodes
     * @return the assembled, not yet compiled, state graph
     * @throws GraphStateException declared because the graph-building calls may throw it
     */
    @Bean
    public StateGraph workflowGraph(ChatModel chatModel) throws GraphStateException {
        ChatClient chatClient = ChatClient.builder(chatModel)
                .defaultAdvisors(new SimpleLoggerAdvisor())
                .build();
        RecordingNode recordingNode = new RecordingNode();

        // Sentiment classifier: positive vs. negative feedback.
        QuestionClassifierNode feedbackClassifier = QuestionClassifierNode.builder()
                .chatClient(chatClient)
                .inputTextKey("input")
                .categories(List.of("positive feedback", "negative feedback"))
                .classificationInstructions(
                        List.of("Try to understand the user's feeling when he/she is giving the feedback."))
                .build();

        // Classifier for the specific problem behind negative feedback.
        QuestionClassifierNode specificQuestionClassifier = QuestionClassifierNode.builder()
                .chatClient(chatClient)
                .inputTextKey("input")
                .categories(List.of("after-sale service", "transportation", "product quality", "others"))
                .classificationInstructions(List.of(
                        "What kind of service or help the customer is trying to get from us? " +
                                "Classify the question based on your understanding."))
                .build();

        // Factory that creates the initial global state object for each workflow execution.
        // All three keys are registered with ReplaceStrategy.
        OverAllStateFactory stateFactory = () -> {
            OverAllState state = new OverAllState();
            state.registerKeyAndStrategy("input", new ReplaceStrategy());
            state.registerKeyAndStrategy("classifier_output", new ReplaceStrategy());
            state.registerKeyAndStrategy("solution", new ReplaceStrategy());
            return state;
        };

        // The leftover debug System.out.println("\n") was removed: bean creation should not write to stdout.
        return new StateGraph("Consumer Service Workflow Demo", stateFactory)
                .addNode("feedback_classifier", node_async(feedbackClassifier))
                .addNode("specific_question_classifier", node_async(specificQuestionClassifier))
                .addNode("recorder", node_async(recordingNode))
                // Define the edges (execution order)
                .addEdge(START, "feedback_classifier") // entry point
                // Positive feedback goes straight to the recorder; negative feedback is classified further.
                .addConditionalEdges("feedback_classifier",
                        edge_async(new CustomerServiceController.FeedbackQuestionDispatcher()),
                        Map.of("positive", "recorder", "negative", "specific_question_classifier"))
                // Every specific category ends up at the recorder.
                .addConditionalEdges("specific_question_classifier",
                        edge_async(new CustomerServiceController.SpecificQuestionDispatcher()),
                        Map.of("after-sale", "recorder", "transportation", "recorder",
                                "quality", "recorder", "others", "recorder"))
                .addEdge("recorder", END); // exit point
    }
}
controller测试
- CustomerServiceController 完整代码
import java.util.HashMap;
import java.util.Map;
import com.alibaba.cloud.ai.graph.CompiledGraph;
import com.alibaba.cloud.ai.graph.exception.GraphStateException;
import com.alibaba.cloud.ai.graph.OverAllState;
import com.alibaba.cloud.ai.graph.StateGraph;
import com.alibaba.cloud.ai.graph.action.EdgeAction;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RequestParam;
import org.springframework.web.bind.annotation.RestController;
/**
 * REST endpoint that runs a piece of customer feedback through the compiled workflow graph and
 * returns the resulting {@code "solution"} text. Also hosts the two edge dispatchers that route
 * on the classifier output.
 */
@RestController
@RequestMapping("/customer")
public class CustomerServiceController {

    private static final Logger logger = LoggerFactory.getLogger(CustomerServiceController.class);

    /** Graph compiled once at construction time and reused for every request. */
    private final CompiledGraph compiledGraph;

    /**
     * @param stateGraph the {@code workflowGraph} bean
     * @throws GraphStateException declared because {@code stateGraph.compile()} may throw it
     */
    public CustomerServiceController(@Qualifier("workflowGraph") StateGraph stateGraph) throws GraphStateException {
        this.compiledGraph = stateGraph.compile();
    }

    /**
     * Classifies one piece of feedback.
     *
     * <p>Example: {@code localhost:8080/customer/chat?query=The product broke after one day, very disappointed.}
     *
     * @param query feedback text; required, so a request without it is rejected by Spring
     *              before reaching {@code Map.of}, which does not accept null values
     * @return the {@code "solution"} value of the final workflow state
     * @throws IllegalStateException if the workflow yields no final state or no {@code "solution"}
     */
    @GetMapping("/chat")
    public String simpleChat(@RequestParam("query") String query) throws Exception {
        logger.info("simpleChat: {}", query);
        return compiledGraph.invoke(Map.of("input", query))
                .orElseThrow(() -> new IllegalStateException("Workflow finished without a final state"))
                .value("solution")
                .map(Object::toString)
                .orElseThrow(() -> new IllegalStateException("Workflow finished without a 'solution' value"));
    }

    /**
     * Edge dispatcher after the sentiment classifier: returns {@code "positive"} when the
     * classifier output contains that word, otherwise {@code "negative"}.
     */
    public static class FeedbackQuestionDispatcher implements EdgeAction {

        @Override
        public String apply(OverAllState state) throws Exception {
            // A missing classifier output is treated as empty text and therefore routed as "negative".
            String classifierOutput = (String) state.value("classifier_output").orElse("");
            logger.info("classifierOutput: {}", classifierOutput);
            if (classifierOutput.contains("positive")) {
                return "positive";
            }
            return "negative";
        }
    }

    /**
     * Edge dispatcher after the specific-problem classifier: maps the classifier output to one of
     * the route keys {@code "after-sale"}, {@code "quality"}, {@code "transportation"} or
     * {@code "others"}.
     */
    public static class SpecificQuestionDispatcher implements EdgeAction {

        /**
         * Route keys probed against the classifier output, in this fixed order. A constant array
         * avoids rebuilding a map per call and makes the result independent of hash iteration
         * order when the output mentions more than one key.
         */
        private static final String[] ROUTE_KEYS = {"after-sale", "quality", "transportation"};

        @Override
        public String apply(OverAllState state) throws Exception {
            // Example: feedback about product quality is routed as "quality".
            String classifierOutput = (String) state.value("classifier_output").orElse("");
            logger.info("classifierOutput: {}", classifierOutput);
            for (String routeKey : ROUTE_KEYS) {
                if (classifierOutput.contains(routeKey)) {
                    return routeKey;
                }
            }
            // Nothing matched (including a missing classifier output): fall back to the catch-all route.
            return "others";
        }
    }
}