Prompt工程学习之思维树(TOT)

发布于:2025-06-09 ⋅ 阅读:(19) ⋅ 点赞:(0)

思维树

定义思维树(Tree of Thoughts, ToT) 是一种先进的推理框架,它通过同时探索多条推理路径对思维链(Chain of Thought)** 进行了扩展。该技术将问题解决视为一个搜索过程 —— 模型生成不同的中间步骤,评估这些步骤的可行性,并探索最有希望的路径。

Tree of Thoughts (ToT) 是一种大语言模型推理框架,通过树状结构探索多条推理路径,允许模型自我评估路径可行性并回溯调整,模拟人类解决复杂问题时的 “试错 - 评估 - 选择” 过程。

目标:解决传统 LLMs 逐 Token 单向决策的局限,提升在需要探索、战略前瞻或多步规划任务(如数学推理、创意写作、谜题)中的表现。

ToT 框架核心机制

  • 核心思路:将问题解决视为树状搜索过程,通过生成 ** 连贯的中间思维单元(Thoughts)** 作为推理的中间步骤,而非单一 Token。
  • 关键能力:多路径探索:同时生成多条推理路径(如不同的解题思路)。
  • 自我评估:评估每条路径的可行性,选择最有希望的分支继续探索。
  • 回溯决策:必要时回溯到之前的思维节点,调整后续策略(类似人类解题的试错过程)。与 Chain of Thought(CoT)的区别:

与COT的对比

CoT 仅生成单一推理链,而 ToT 支持并行探索多条链,并通过评估机制实现全局最优决策。

24点案例

使用数字4、9、10和13以及四种基本运算符(+、-、/、*),生成一个结果为24的表达式。

step1
输入:4, 9, 10, 13  
可能的下一步操作:  
- 4 + 9 = 13(剩余:13, 10, 13- 10 - 4 = 6(剩余:6, 9, 13- 13 - 10 = 3(剩余:4, 9, 3- 9 × 4 = 36(剩余:36, 10, 13- 10 ÷ 4 = 2.5(剩余:2.5, 9, 13)

输入:4, 9, 10, 13  
请给出可能得下一步操作

输出:
4+9 = 13 (left: 13, 10, 13)
10-4 = 6 (left: 6, 9, 13)
13-9 = 4 (left: 4, 9, 10)
...
...

step2
计算是否可以得到24
10 14: 10+14 = 24 sure
10 7 2: 7*2+10 = 24 sure
11 11: 11 + 11 = 22 impossible
输入第一组结果,请给出可能得结果
13, 10, 13:

输出:
10 + 13 + 13 = 36 impossible
...
...

计算是否可以得到24
10 14: 10+14 = 24 sure
10 7 2: 7*2+10 = 24 sure
11 11: 11 + 11 = 22 impossible
输入第一组结果,请给出可能得结果
6, 9, 13:

输出:
6 *  (13-9) = 24 sure
...
...

自动化代码示例
生成思维结点,以树状形式组织;沿着思维结点进行探索,评估结果;根据评估结果选择下一步操作

package com.example.tot24;

import ai.spring.ai.client.ChatClient;
import ai.spring.ai.client.Generation;
import ai.spring.ai.client.Message;
import ai.spring.ai.client.chat.ChatResponse;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.boot.CommandLineRunner;
import org.springframework.boot.SpringApplication;
import org.springframework.boot.autoconfigure.SpringBootApplication;
import org.springframework.context.annotation.Bean;

import java.util.*;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Collectors;


public class Tot24Application {
  	// 思维树节点类
    static class TreeNode {
        private List<Double> numbers;
        private List<String> history;
        private List<TreeNode> children;
        private double score;
        private boolean terminal;
    }

    // 候选操作类
    static class CandidateOperation {
        private String operation;
        private List<Double> expectedNumbers;
        private String reason;
        private double score;
        private String explanation;
    }

    // 24点游戏求解器
    static class TwentyFourSolver {
        private static final double TARGET = 24.0;
        private static final double TOLERANCE = 1e-6;
        private static final int MAX_STEPS = 5;
        private static final int BEAM_WIDTH = 3;

        private final ChatClient chatClient;
        private final String modelName;
        private final String systemPrompt;

        public TwentyFourSolver(ChatClient chatClient, String modelName) {
            this.chatClient = chatClient;
            this.modelName = modelName;
            
            // 构建系统提示
            this.systemPrompt = """
                你是一个解决24点游戏的专家。给定4个1-13之间的数字,使用加、减、乘、除和括号,使最终计算结果为24。
                
                解决过程中,请遵循以下规则:
                1. 每个数字必须且只能使用一次
                2. 中间步骤的计算结果可以是分数
                3. 最终答案必须是精确的24
                
                当被要求生成下一步操作时,请提供JSON格式的候选操作列表(最多5个有希望的操作):
                [
                    {
                        "operation": "具体操作(如:4+5=9)",
                        "expected_numbers": [操作后的数字列表],
                        "reason": "选择该操作的理由"
                    },
                    ...
                ]
                
                当被要求评估状态时,请提供JSON格式的评分和解释:
                {
                    "score": 3,
                    "explanation": "理由..."
                }
                
                评分标准:
                - 1分:当前数字组合不可能得到24
                - 2分:可能得到24,但难度高
                - 3分:有合理可能性得到24
                - 4分:非常有希望得到24
                - 5分:已得到24
                """;
        }

        public Optional<String> solve(List<Integer> numbers) {
            List<Double> initialNumbers = numbers.stream()
                    .map(Double::valueOf)
                    .collect(Collectors.toList());
            
            TreeNode root = new TreeNode(initialNumbers, new ArrayList<>());
            Queue<TreeNode> queue = new LinkedList<>();
            queue.add(root);
            
            while (!queue.isEmpty()) {
                TreeNode currentNode = queue.poll();
                
                // 检查是否已解决
                if (currentNode.getNumbers().stream()
                        .anyMatch(n -> Math.abs(n - TARGET) < TOLERANCE)) {
                    return Optional.of(formatSolution(currentNode));
                }
                
                // 生成候选操作
                List<CandidateOperation> candidates = generateCandidates(currentNode);
                
                // 评估候选操作
                evaluateCandidates(currentNode, candidates);
                
                // 选择最有希望的操作
                List<CandidateOperation> topCandidates = candidates.stream()
                        .sorted(Comparator.comparingDouble(CandidateOperation::getScore).reversed())
                        .limit(BEAM_WIDTH)
                        .collect(Collectors.toList());
                
                // 创建子节点
                for (CandidateOperation candidate : topCandidates) {
                    TreeNode childNode = new TreeNode(
                            candidate.getExpectedNumbers(),
                            new ArrayList<>(currentNode.getHistory())
                    );
                    childNode.getHistory().add(candidate.getOperation());
                    childNode.setScore(candidate.getScore());
                    
                    currentNode.getChildren().add(childNode);
                    
                    // 如果分数足够高,继续探索
                    if (candidate.getScore() >= 3) {
                        queue.add(childNode);
                    }
                }
            }
            
            return Optional.empty(); // 无解
        }

        private List<CandidateOperation> generateCandidates(TreeNode node) {
            String userPrompt = String.format("""
                当前状态:
                数字:%s
                历史:%s
                
                请生成最多5个有希望的下一步操作。
                """, node.getNumbers(), node.getHistory());
            
            String response = callLLM(userPrompt);
            
            try {
                // 解析JSON响应
                List<CandidateOperation> candidates = new ArrayList<>();
                // 实际应用中需要使用真正的JSON解析库
                // 这里简化处理,实际代码应使用Jackson等库
                return candidates;
            } catch (Exception e) {
                System.err.println("解析候选操作失败: " + e.getMessage());
                System.err.println("LLM响应: " + response);
                return Collections.emptyList();
            }
        }

        private void evaluateCandidates(TreeNode node, List<CandidateOperation> candidates) {
            for (CandidateOperation candidate : candidates) {
                String userPrompt = String.format("""
                    当前状态:
                    数字:%s
                    历史:%s
                    
                    候选操作:
                    %s
                    操作后数字:%s
                    
                    请评分并解释。
                    """, 
                    node.getNumbers(), 
                    node.getHistory(),
                    candidate.getOperation(),
                    candidate.getExpectedNumbers());
                
                String response = callLLM(userPrompt);
                
                try {
                    // 解析JSON响应获取评分和解释
                    // 实际应用中需要使用真正的JSON解析库
                    // 这里简化处理
                    double score = 3.0; // 默认值
                    String explanation = "默认评估";
                    
                    candidate.setScore(score);
                    candidate.setExplanation(explanation);
                } catch (Exception e) {
                    System.err.println("解析评估结果失败: " + e.getMessage());
                    System.err.println("LLM响应: " + response);
                    candidate.setScore(2.0); // 保守评分
                }
            }
        }

        private String callLLM(String userPrompt) {
            Message systemMessage = new Message(systemPrompt, "system");
            Message userMessage = new Message(userPrompt, "user");
            
            ChatResponse response = chatClient.generate(
                    List.of(systemMessage, userMessage), 
                    modelName
            );
            
            Generation generation = response.getGenerations().get(0);
            return generation.getContent();
        }

        private String formatSolution(TreeNode node) {
            StringBuilder sb = new StringBuilder();
            for (String step : node.getHistory()) {
                sb.append(step).append("\n");
            }
            return sb.toString();
        }
    }
}

参考

1.TOT 24点,https://learnprompting.org/docs/advanced/decomposition/tree_of_thoughts?srsltid=AfmBOor-YZUZ9nUIH-HpTtxJhTH-MHeQ_aQ6xp6to3gEveLlkqyttWq4
2.TOT,https://arxiv.org/abs/2305.10601