cs336之注意pytorch的tensor在哪里?(assert的使用)

发布于:2025-08-03 ⋅ 阅读:(13) ⋅ 点赞:(0)

问题

记住:无论何时你在pytorch中有一个张量tensor,你应该始终问一个问题:它当前位于哪里?
注意它在CPU还是在GPU中。
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
要判断它在哪里,可以使用python的assert断言语句。

assert断言

在 Python 中,assert 是一个调试辅助工具,用于在代码中设置检查点。它的核心作用是验证某个条件是否为真,如果条件为假则立即抛出 AssertionError 异常。

基本语法

assert condition, message
  • condition:要测试的条件表达式(返回布尔值)
  • message(可选):断言失败时显示的错误信息(字符串)

工作原理

  1. conditionTrue 时,程序继续执行
  2. conditionFalse 时:
    • 立即抛出 AssertionError
    • 若有 message,则将其作为异常信息输出

注意: 在运行python代码时通过 python -Opython -OO 运行程序可全局禁用断言。这意味着在优化模式(__debug__为False)下,所有的assert语句都不会被执行。

IDE 的 “Debug 运行” 按钮 ≠ Python 的 Debug 模式

  • IDE 的 Debug 按钮:启动调试器(可设置断点、单步执行)
  • Python 的 Debug 模式:由 __debug__ 标志控制,决定 assert 是否生效
    在这里插入图片描述

示例代码

# 验证输入值非负
def calculate_square_root(x):
    assert x >= 0, "输入不能为负数"
    return x ** 0.5

print(calculate_square_root(4))  # 正常执行
print(calculate_square_root(-1)) # 触发 AssertionError

输出结果

2.0
Traceback (most recent call last):
  File "demo.py", line 6, in <module>
    print(calculate_square_root(-1))
  File "demo.py", line 2, in calculate_square_root
    assert x >= 0, "输入不能为负数"
AssertionError: 输入不能为负数

关键特性

  1. 调试工具:用于捕获程序中的逻辑错误
  2. 可禁用性
    • 通过 python -Opython -OO 运行程序可全局禁用断言
    • 禁用后所有 assert 语句会被解释器忽略
  3. 非错误处理机制
    • 不应替代常规的异常处理(如 try/except
    • 不能用于验证用户输入或外部数据

典型使用场景

  1. 验证函数参数有效性

    def process_data(data):
        assert isinstance(data, list), "需要列表类型输入"
        # 处理逻辑
    
  2. 检查中间状态

    def transform(values):
        result = complex_operation(values)
        assert len(result) == len(values), "数据长度不一致"
        return result
    
  3. 测试不变性条件

    class Account:
        def withdraw(self, amount):
            new_balance = self.balance - amount
            assert new_balance >= 0, "余额不足"
            self.balance = new_balance
    

注意事项

  1. 生产环境慎用

    • 断言可能被全局禁用,不可依赖其进行安全检查
    • 重要检查应使用常规异常
    # 生产环境推荐写法
    if x < 0:
        raise ValueError("输入不能为负数")
    
  2. 错误信息优化

    • 添加有意义的错误信息便于调试
    assert len(items) > 0, f"获得空列表,当前项目: {items}"
    
  3. 性能影响

    • 断言语句会增加执行开销
    • 在性能关键代码中避免过度使用

与异常处理的区别

特性 assert 常规异常 (try/except)
设计目的 调试期间捕获程序错误 处理预期可能发生的错误情况
生产环境行为 可被全局禁用 始终生效
适用场景 检查"不可能发生"的条件 验证用户输入/外部资源等
错误类型 固定抛出 AssertionError 可抛出任意异常类型

最佳实践:将 assert 视为代码中的即时文档和调试助手,而非生产环境的错误处理机制。

在 Python 中,assert 是一个调试辅助工具,用于在代码中设置检查点。它的核心作用是验证某个条件是否为真,如果条件为假则立即抛出 AssertionError 异常。


网站公告

今日签到

点亮在社区的每一天
去签到