Linux防御式编程完全指南

发布于:2025-04-17 ⋅ 阅读:(29) ⋅ 点赞:(0)

Linux防御式编程完全指南

1. 基础防御原则

1.1 输入验证

/*
 * Validate a NUL-terminated string from an untrusted source.
 *
 * The string is valid when it is non-NULL, non-empty, strictly shorter than
 * max_len (so it fits in a max_len-byte buffer together with its NUL), and
 * consists only of printable characters.
 *
 * Returns 0 if valid, or -1 with errno = EINVAL otherwise.
 */
int validate_string(const char* input, size_t max_len) {
    if (!input) {
        errno = EINVAL;
        return -1;
    }

    // Scan at most max_len bytes: anything longer is rejected anyway, so
    // there is no need to walk an arbitrarily long (or unterminated) input.
    size_t len = 0;
    while (len < max_len && input[len] != '\0') {
        // <ctype.h> functions take an int in the unsigned char range (or EOF);
        // passing a negative plain char (bytes >= 0x80 where char is signed)
        // is undefined behaviour, hence the cast.
        if (!isprint((unsigned char)input[len])) {
            errno = EINVAL;
            return -1;
        }
        len++;
    }

    if (len == 0 || len >= max_len) {
        errno = EINVAL;
        return -1;
    }

    return 0;
}

1.2 边界检查

/*
 * Array bounds-checking macros.
 *
 * ARRAY_SIZE yields the element count of a true array. It is wrong for
 * pointers and for array-typed function parameters, which decay to pointers.
 */
#define ARRAY_SIZE(arr) (sizeof(arr) / sizeof((arr)[0]))

/*
 * Nonzero when `index` is a valid subscript of array `arr`.
 * The negative test short-circuits first, so the cast to size_t only sees
 * non-negative values; the explicit cast avoids an implicit signed/unsigned
 * comparison (-Wsign-compare).
 * NOTE: `index` is evaluated twice; do not pass expressions with side effects.
 */
#define CHECK_ARRAY_BOUNDS(index, arr) \
    ((index) >= 0 && (size_t)(index) < ARRAY_SIZE(arr))

// Usage example
/*
 * Bounds-checked read of arr[index], where arr holds `size` elements.
 *
 * Returns the element on success. On failure returns -1 and sets errno:
 *   EINVAL  arr is NULL
 *   ERANGE  index is negative or >= size
 * Because -1 may also be a legitimate element value, a caller that needs to
 * tell the two apart should set errno = 0 before the call and check it
 * afterwards (the success path does not modify errno).
 */
int safe_array_access(int arr[], size_t size, int index) {
    if (!arr) {
        errno = EINVAL;
        return -1;
    }
    if (index < 0 || (size_t)index >= size) {
        errno = ERANGE;
        return -1;
    }
    return arr[index];
}

2. 内存管理防御

2.1 安全的内存分配

/*
 * Allocate `size` zero-initialised bytes.
 *
 * Returns NULL with errno = EINVAL when size is 0 (malloc(0) may return
 * either NULL or a unique pointer, which is easy to misuse). On allocation
 * failure, logs the failure and returns NULL with the allocator's errno
 * preserved.
 */
void* safe_malloc(size_t size) {
    if (size == 0) {
        errno = EINVAL;
        return NULL;
    }

    // calloc allocates and zero-fills in one call (replaces malloc + memset).
    void* ptr = calloc(1, size);
    if (!ptr) {
        // The logging call may itself change errno; keep the allocator's value.
        int saved_errno = errno;
        log_error("Failed to allocate %zu bytes", size);
        errno = saved_errno;
        return NULL;
    }

    return ptr;
}

2.2 内存泄漏防护

/*
 * Scope-based cleanup helpers built on the GCC/Clang `cleanup` attribute.
 *
 * The attribute only has an effect on automatic (local) variables. Attached
 * to a typedef it is ignored (GCC: "'cleanup' attribute ignored"), so a
 * variable declared through such a typedef is never released. The attribute
 * is therefore applied per variable with AUTO_CLEANUP(type).
 */

/*
 * Declares cleanup_<type>() and a plain <type>_cleanup pointer typedef. The
 * typedef is kept for source compatibility only; it does not free anything.
 */
#define DECLARE_CLEANUP(type) \
    void cleanup_##type(type** ptr); \
    typedef type* type##_cleanup

/* Defines cleanup_<type>(): frees *ptr and resets it to NULL. */
#define DEFINE_CLEANUP(type) \
    void cleanup_##type(type** ptr) { \
        if (ptr) { \
            free(*ptr); /* free(NULL) is a no-op */ \
            *ptr = NULL; \
        } \
    }

/* Attach to a local pointer: `char* buf AUTO_CLEANUP(char) = ...;` */
#define AUTO_CLEANUP(type) __attribute__((cleanup(cleanup_##type)))

// Usage example
DECLARE_CLEANUP(char);
DEFINE_CLEANUP(char)

void process_data(void) {
    char* buffer AUTO_CLEANUP(char) = safe_malloc(1024);
    if (!buffer) {
        return;
    }
    // Use buffer...
} // cleanup_char(&buffer) runs here, on every exit path

3. 并发安全

3.1 互斥锁包装器

/* A pthread mutex plus a flag recording whether it was successfully initialised. */
struct mutex_wrapper {
    pthread_mutex_t mutex;
    bool initialized;
};

/*
 * Initialise wrapper->mutex with default attributes.
 *
 * Returns 0 on success, -1 if wrapper is NULL, or the error number returned
 * by pthread_mutex_init(). `initialized` is written on every non-NULL call,
 * so a failed init can never leave a stale `true` (e.g. from uninitialised
 * memory) that would let mutex_lock() touch an invalid mutex.
 */
int mutex_init(struct mutex_wrapper* wrapper) {
    if (!wrapper) {
        return -1;
    }

    int ret = pthread_mutex_init(&wrapper->mutex, NULL);
    wrapper->initialized = (ret == 0);
    return ret;
}

/*
 * Lock the wrapped mutex.
 * Returns -1 if wrapper is NULL or not initialised, otherwise the result of
 * pthread_mutex_lock() (0 or an error number).
 */
int mutex_lock(struct mutex_wrapper* wrapper) {
    if (!wrapper || !wrapper->initialized) {
        return -1;
    }
    return pthread_mutex_lock(&wrapper->mutex);
}

/*
 * Unlock the wrapped mutex.
 * Returns -1 if wrapper is NULL or not initialised, otherwise the result of
 * pthread_mutex_unlock() (0 or an error number).
 */
int mutex_unlock(struct mutex_wrapper* wrapper) {
    if (!wrapper || !wrapper->initialized) {
        return -1;
    }
    return pthread_mutex_unlock(&wrapper->mutex);
}

3.2 条件变量安全使用

/* A mutex/condition-variable pair guarding a boolean predicate. */
struct safe_condition {
    pthread_mutex_t mutex;
    pthread_cond_t cond;
    bool predicate;
};

/*
 * Block until sc->predicate becomes true, then consume it (reset to false).
 *
 * The predicate is re-tested in a loop, so spurious wakeups from
 * pthread_cond_wait() are harmless.
 * Returns 0 on success, EINVAL if sc is NULL, or the pthread error number of
 * the first failing lock/wait/unlock call (pthread convention: no errno).
 */
int safe_condition_wait(struct safe_condition* sc) {
    if (!sc) {
        return EINVAL;
    }

    int ret = pthread_mutex_lock(&sc->mutex);
    if (ret != 0) {
        return ret;
    }

    while (!sc->predicate) {
        ret = pthread_cond_wait(&sc->cond, &sc->mutex);
        if (ret != 0) {
            break;
        }
    }

    if (ret == 0) {
        sc->predicate = false;  // consume the condition
    }

    // Report an unlock failure instead of silently ignoring it.
    int unlock_ret = pthread_mutex_unlock(&sc->mutex);
    return ret != 0 ? ret : unlock_ret;
}

4. 信号处理

4.1 信号处理框架

#include <signal.h>

/* Main-loop flag: starts at 1 and is cleared by signal_handler() on shutdown. */
static volatile sig_atomic_t running = 1;

/*
 * Signal handler. It only stores to a volatile sig_atomic_t, which is
 * async-signal-safe.
 *   SIGTERM, SIGINT  -> request shutdown by clearing `running`
 *   anything else    -> no-op (SIGPIPE is caught here only so that its
 *                       default, process-terminating action does not occur)
 */
static void signal_handler(int signo) {
    const int wants_shutdown = (signo == SIGTERM) || (signo == SIGINT);
    if (wants_shutdown) {
        running = 0;
    }
}

int setup_signal_handlers(void) {
    struct sigaction sa;
    memset(&sa, 0, sizeof(sa));
    sa.sa_handler = signal_handler;
    sigemptyset(&sa.sa_mask);
    
    if (sigaction(SIGTERM, &sa, NULL) == -1) return -1;
    if (sigaction(SIGINT, &sa, NULL) == -1) return -1;
    if (sigaction(SIGPIPE, &sa, NULL) == -1) return -1;
    
    return 0;
}

4.2 信号安全的数据结构

/*
 * Counter shared between normal code and signal handlers.
 *
 * The member must be _Atomic: C11 atomic_fetch_add()/atomic_load() are only
 * defined for atomic types. Clang rejects a plain sig_atomic_t*, and GCC
 * accepts one only as an extension. A signal handler may only rely on
 * atomics that are lock-free, so confirm that the matching ATOMIC_*_LOCK_FREE
 * macro is 2 on the target (ATOMIC_INT_LOCK_FREE when sig_atomic_t is int,
 * as on glibc).
 */
struct sig_safe_counter {
    _Atomic sig_atomic_t value;
};

/* Atomically add 1 to the counter (sequentially consistent). */
void increment_counter(struct sig_safe_counter* counter) {
    atomic_fetch_add(&counter->value, 1);
}

/* Atomically read the current counter value (sequentially consistent). */
sig_atomic_t get_counter(struct sig_safe_counter* counter) {
    return atomic_load(&counter->value);
}

5. 文件操作防御

5.1 安全的文件打开

/*
 * fopen() wrapper that validates its arguments first.
 *
 * Returns NULL with errno set when:
 *   - filename or mode is NULL                                  -> EINVAL
 *   - filename does not fit in PATH_MAX bytes including its NUL -> ENAMETOOLONG
 *   - mode is empty, does not start with 'r', 'w' or 'a', or
 *     contains characters outside "rwa+b"                       -> EINVAL
 * If fopen() itself fails, the failure is logged and fopen()'s errno is
 * preserved for the caller.
 */
FILE* safe_fopen(const char* filename, const char* mode) {
    if (!filename || !mode) {
        errno = EINVAL;
        return NULL;
    }

    // PATH_MAX counts the terminating NUL, so the longest usable length is
    // PATH_MAX - 1.
    if (strlen(filename) >= PATH_MAX) {
        errno = ENAMETOOLONG;
        return NULL;
    }

    // fopen() requires the mode to begin with r, w or a. An empty mode used
    // to pass the strspn() test (0 == 0). Test for '\0' first, because
    // strchr() would otherwise match the terminator of "rwa".
    if (mode[0] == '\0' || !strchr("rwa", mode[0]) ||
        strspn(mode, "rwa+b") != strlen(mode)) {
        errno = EINVAL;
        return NULL;
    }

    FILE* fp = fopen(filename, mode);
    if (!fp) {
        // The logging call may overwrite errno; keep fopen()'s value.
        int saved_errno = errno;
        log_error("Failed to open file %s: %s",
                  filename, strerror(saved_errno));
        errno = saved_errno;
        return NULL;
    }

    return fp;
}

5.2 安全的文件读写

#ifndef SSIZE_MAX
/* Fallback for strict ISO C builds where <limits.h> does not expose POSIX
 * SSIZE_MAX; assumes ssize_t is the signed counterpart of size_t (as on glibc). */
#define SSIZE_MAX ((ssize_t)(SIZE_MAX >> 1))
#endif

/*
 * Read exactly `count` bytes from fd into buf unless end-of-file comes first.
 * Short reads are continued and EINTR is retried.
 *
 * Returns the number of bytes read (less than count only at EOF), or -1 with
 * errno set:
 *   EINVAL  buf is NULL, count is 0, or count exceeds SSIZE_MAX (such a total
 *           could not be represented in the ssize_t return value)
 *   other   the errno of the failing read(); any bytes already read before
 *           the error are not reported
 * NOTE: on a non-blocking fd, EAGAIN/EWOULDBLOCK is returned as an error.
 */
ssize_t safe_read(int fd, void* buf, size_t count) {
    if (!buf || count == 0 || count > (size_t)SSIZE_MAX) {
        errno = EINVAL;
        return -1;
    }

    size_t total = 0;
    char* ptr = buf;

    while (total < count) {
        ssize_t n = read(fd, ptr + total, count - total);
        if (n < 0) {
            if (errno == EINTR)
                continue;
            return -1;
        }
        if (n == 0)  // EOF
            break;
        total += (size_t)n;
    }

    return (ssize_t)total;
}

6. 网络编程防御

6.1 套接字包装器

/*
 * Per-connection socket state.
 * Of these fields, socket_connect_timeout() reads only `fd`; the others are
 * not read or written by the code in this section.
 */
struct socket_wrapper {
    int fd;                              // socket descriptor
    struct sockaddr_storage peer_addr;   // presumably the peer's address — not filled in by this section's code; confirm with callers
    socklen_t peer_addr_len;             // presumably the valid length of peer_addr — TODO confirm
    int timeout_ms;                      // NOTE(review): socket_connect_timeout() takes its own timeout_ms argument instead of this field
};

/*
 * Put a descriptor into non-blocking mode while keeping its other file
 * status flags.
 * Returns -1 (errno from fcntl) if the flags cannot be read; otherwise the
 * result of fcntl(F_SETFL).
 */
int socket_set_nonblocking(int sockfd) {
    const int current_flags = fcntl(sockfd, F_GETFL, 0);
    return (current_flags == -1)
        ? -1
        : fcntl(sockfd, F_SETFL, current_flags | O_NONBLOCK);
}

/*
 * Connect sock->fd to addr, waiting at most timeout_ms milliseconds.
 *
 * The socket is switched to non-blocking mode and left that way.
 * Returns 0 once connected. On failure returns -1 with errno set:
 *   EINVAL     sock or addr is NULL, or timeout_ms is negative
 *   EBADF      sock->fd is negative or >= FD_SETSIZE (FD_SET on such a
 *              descriptor writes past the end of the fd_set)
 *   ETIMEDOUT  the connection did not complete within timeout_ms
 *   other      errno from fcntl/connect/select/getsockopt, or the pending
 *              socket error reported by SO_ERROR (e.g. ECONNREFUSED)
 * Earlier versions returned a pending socket error as a positive value,
 * which callers testing `< 0` would mistake for success. It is now reported
 * through errno, like every other failure.
 */
int socket_connect_timeout(struct socket_wrapper* sock,
                           const struct sockaddr* addr,
                           socklen_t addr_len,
                           int timeout_ms) {
    if (!sock || !addr || timeout_ms < 0) {
        errno = EINVAL;
        return -1;
    }
    if (sock->fd < 0 || sock->fd >= FD_SETSIZE) {
        errno = EBADF;
        return -1;
    }

    if (socket_set_nonblocking(sock->fd) == -1)
        return -1;

    if (connect(sock->fd, addr, addr_len) == 0)
        return 0;

    // POSIX: an interrupted connect() keeps establishing the connection
    // asynchronously, so EINTR is waited on just like EINPROGRESS.
    if (errno != EINPROGRESS && errno != EINTR)
        return -1;

    struct timeval tv;
    tv.tv_sec = timeout_ms / 1000;
    tv.tv_usec = (timeout_ms % 1000) * 1000;

    int ready;
    do {
        fd_set write_fds;
        FD_ZERO(&write_fds);
        FD_SET(sock->fd, &write_fds);
        // Linux select() updates tv to the time remaining, so a retry after
        // EINTR does not restart the full timeout.
        ready = select(sock->fd + 1, NULL, &write_fds, NULL, &tv);
    } while (ready == -1 && errno == EINTR);

    if (ready == -1)
        return -1;
    if (ready == 0) {
        errno = ETIMEDOUT;
        return -1;
    }

    int error = 0;
    socklen_t len = sizeof(error);
    if (getsockopt(sock->fd, SOL_SOCKET, SO_ERROR, &error, &len) == -1)
        return -1;

    if (error != 0) {
        errno = error;
        return -1;
    }
    return 0;
}

6.2 安全的数据发送

#ifndef SSIZE_MAX
/* Fallback for strict ISO C builds where <limits.h> does not expose POSIX
 * SSIZE_MAX; assumes ssize_t is the signed counterpart of size_t (as on glibc). */
#define SSIZE_MAX ((ssize_t)(SIZE_MAX >> 1))
#endif

/*
 * Send all `len` bytes of buf on sockfd. Short sends are continued and
 * EINTR is retried.
 *
 * On EAGAIN/EWOULDBLOCK (non-blocking socket), waits up to one second per
 * attempt for the socket to become writable, then retries.
 * Returns len on success, or -1 with errno set:
 *   EINVAL  buf is NULL with len > 0, or len exceeds SSIZE_MAX
 *   EAGAIN  the socket stayed unwritable for a full wait, or sockfd cannot
 *           be used with select() (outside [0, FD_SETSIZE))
 *   other   the errno of the failing send()/select()
 */
ssize_t safe_send(int sockfd, const void* buf, size_t len, int flags) {
    if ((!buf && len > 0) || len > (size_t)SSIZE_MAX) {
        errno = EINVAL;
        return -1;
    }

    size_t total = 0;
    const char* ptr = buf;

    while (total < len) {
        ssize_t sent = send(sockfd, ptr + total, len - total, flags);
        if (sent >= 0) {
            total += (size_t)sent;
            continue;
        }
        if (errno == EINTR)
            continue;
        if (errno != EAGAIN && errno != EWOULDBLOCK)
            return -1;

        // Wait for writability. FD_SET on a descriptor outside
        // [0, FD_SETSIZE) overflows the fd_set, so give up instead.
        if (sockfd < 0 || sockfd >= FD_SETSIZE) {
            errno = EAGAIN;
            return -1;
        }
        fd_set write_fds;
        FD_ZERO(&write_fds);
        FD_SET(sockfd, &write_fds);

        struct timeval tv = {.tv_sec = 1, .tv_usec = 0};
        int ret = select(sockfd + 1, NULL, &write_fds, NULL, &tv);
        if (ret > 0 || (ret < 0 && errno == EINTR))
            continue;
        if (ret == 0)
            errno = EAGAIN;  // timed out waiting for the socket
        return -1;
    }

    return (ssize_t)total;
}

7. 错误处理

7.1 错误处理框架

/* Severity of a reported error; handle_error() aborts on ERROR_FATAL. */
enum error_level {
    ERROR_FATAL,
    ERROR_WARNING,
    ERROR_INFO
};

/* Everything handle_error() prints for one error report. */
struct error_context {
    char message[256];       // formatted message (truncated by snprintf to fit)
    int code;                // caller-supplied error code, printed as-is
    enum error_level level;
    const char* file;        // source file (__FILE__ when filled by REPORT_ERROR)
    int line;                // source line (__LINE__ when filled by REPORT_ERROR)
};

/*
 * Print "[timestamp] file:line - message (code: N)" to stderr, then abort()
 * if ctx->level is ERROR_FATAL.
 *
 * A NULL ctx is ignored. A NULL file prints as "?". If the local time cannot
 * be obtained, the timestamp is left empty instead of passing NULL to
 * strftime().
 * NOTE: localtime() returns a pointer to static storage and is not
 * thread-safe; use localtime_r() where POSIX is available.
 */
void handle_error(struct error_context* ctx) {
    if (!ctx) {
        return;
    }

    char timestamp[32] = "";
    time_t now = time(NULL);
    const struct tm* local = localtime(&now);
    if (!local ||
        strftime(timestamp, sizeof(timestamp), "%Y-%m-%d %H:%M:%S", local) == 0) {
        timestamp[0] = '\0';
    }

    fprintf(stderr, "[%s] %s:%d - %s (code: %d)\n",
            timestamp, ctx->file ? ctx->file : "?", ctx->line,
            ctx->message, ctx->code);

    if (ctx->level == ERROR_FATAL) {
        abort();
    }
}

/*
 * Build an error_context for the current source location and pass it to
 * handle_error().
 *
 * The macro parameters are named lvl_/code_ on purpose. With parameters
 * named `level` and `code`, the preprocessor also replaced the field names
 * in the designated initialisers (`.level = level` became e.g.
 * `.ERROR_INFO = ERROR_INFO`), so the macro never compiled. The local is
 * named report_ctx_ so that it cannot shadow a caller variable (such as
 * `ctx`) used in the format arguments.
 */
#define REPORT_ERROR(lvl_, code_, fmt, ...) do { \
    struct error_context report_ctx_ = { \
        .level = (lvl_), \
        .code = (code_), \
        .file = __FILE__, \
        .line = __LINE__ \
    }; \
    snprintf(report_ctx_.message, sizeof(report_ctx_.message), \
             fmt, ##__VA_ARGS__); \
    handle_error(&report_ctx_); \
} while (0)

7.2 资源清理框架

/* One registered cleanup action; handlers form a per-thread LIFO list. */
struct cleanup_handler {
    void (*cleanup)(void*);
    void* data;
    struct cleanup_handler* next;
};

/* Top of the calling thread's cleanup stack (NULL when empty). */
static __thread struct cleanup_handler* cleanup_stack = NULL;

/*
 * Push cleanup(data) onto the calling thread's cleanup stack.
 *
 * Returns 0 on success, or -1 with errno = ENOMEM if the handler node could
 * not be allocated. On failure nothing is pushed, so the caller must not
 * call pop_cleanup() for this registration. Otherwise it would pop (and
 * possibly run) a handler that belongs to an enclosing scope. The previous
 * void-returning version failed silently and caused exactly that mismatch.
 */
int push_cleanup(void (*cleanup)(void*), void* data) {
    struct cleanup_handler* handler = malloc(sizeof(*handler));
    if (!handler) {
        errno = ENOMEM;
        return -1;
    }

    handler->cleanup = cleanup;
    handler->data = data;
    handler->next = cleanup_stack;
    cleanup_stack = handler;
    return 0;
}

/*
 * Pop the most recently pushed handler, calling it first when `execute` is
 * nonzero and its callback is non-NULL. Does nothing on an empty stack.
 */
void pop_cleanup(int execute) {
    struct cleanup_handler* handler = cleanup_stack;
    if (!handler) {
        return;
    }

    cleanup_stack = handler->next;
    if (execute && handler->cleanup) {
        handler->cleanup(handler->data);
    }
    free(handler);
}

#define CLEANUP_PUSH(fn, arg) push_cleanup(fn, arg)
#define CLEANUP_POP(execute) pop_cleanup(execute)

8. 配置管理

8.1 安全的配置解析

/* A tagged configuration value: `type` selects the active union member. */
struct config_value {
    enum {
        CONFIG_STRING,
        CONFIG_INT,
        CONFIG_BOOL
    } type;
    union {
        char* str_val;     // heap copy made by config_set_string()
        int int_val;
        bool bool_val;
    } value;
};

/* One key/value pair in a singly linked configuration list. */
struct config_entry {
    char key[64];
    struct config_value val;
    struct config_entry* next;
};

/*
 * Store a heap-allocated copy of `value` in entry as a CONFIG_STRING.
 *
 * Returns 0 on success, or -1 if entry/value is NULL or the copy cannot be
 * allocated. On failure the entry is left untouched. The previous version
 * had already switched `type` and overwritten the union before malloc()
 * failed, leaving a CONFIG_STRING with a NULL string.
 * NOTE(review): a string previously stored in entry is not freed here;
 * the caller owns it and must release it to avoid a leak.
 */
int config_set_string(struct config_entry* entry,
                      const char* value) {
    if (!entry || !value) return -1;

    size_t len = strlen(value);
    char* copy = malloc(len + 1);
    if (!copy) return -1;

    memcpy(copy, value, len + 1);  // includes the terminating NUL

    entry->val.type = CONFIG_STRING;
    entry->val.value.str_val = copy;
    return 0;
}

8.2 配置验证

int validate_config(const struct config_entry* config) {
    if (!config) return -1;
    
    // 检查必需的配置项
    bool has_required = false;
    const struct config_entry* entry = config;
    
    while (entry) {
        if (strcmp(entry->key, "required_key") == 0) {
            has_required = true;
            break;
        }
        entry = entry->next;
    }
    
    if (!has_required) {
        log_error("Missing required configuration");
        return -1;
    }
    
    // 验证配置值
    entry = config;
    while (entry) {
        switch (entry->val.type) {
            case CONFIG_INT:
                if (entry->val.value.int_val < 0) {
                    log_error("Invalid value for %s", entry->key);
                    return -1;
                }
                break;
            case CONFIG_STRING:
                if (!entry->val.value.str_val || 
                    !*entry->val.value.str_val) {
                    log_error("Empty string for %s", entry->key);
                    return -1;
                }
                break;
            // 其他类型的验证...
        }
        entry = entry->next;
    }
    
    return 0;
}

9. 系统调用防御

9.1 系统调用包装器

#ifndef SSIZE_MAX
/* Fallback for strict ISO C builds where <limits.h> does not expose POSIX
 * SSIZE_MAX; assumes ssize_t is the signed counterpart of size_t (as on glibc). */
#define SSIZE_MAX ((ssize_t)(SIZE_MAX >> 1))
#endif

/*
 * Write all `count` bytes of buf to fd. Short writes are continued and
 * EINTR is retried.
 *
 * Returns count on success (0 when count is 0), or -1 with errno set:
 *   EINVAL  buf is NULL with count > 0, or count exceeds SSIZE_MAX (the
 *           total could not be represented in the ssize_t return value)
 *   other   the errno of the failing write()
 */
ssize_t safe_write(int fd, const void* buf, size_t count) {
    if ((!buf && count > 0) || count > (size_t)SSIZE_MAX) {
        errno = EINVAL;
        return -1;
    }

    // An unsigned counter avoids the old signed-vs-size_t comparison.
    size_t total = 0;
    const char* ptr = buf;

    while (total < count) {
        ssize_t written = write(fd, ptr + total, count - total);
        if (written < 0) {
            if (errno == EINTR)
                continue;
            return -1;
        }
        total += (size_t)written;
    }

    return (ssize_t)total;
}

/*
 * fork() that retries while it fails with EAGAIN (a temporary resource
 * shortage, e.g. the RLIMIT_NPROC process limit), sleeping 100 ms between
 * attempts.
 *
 * Retries at most SAFE_FORK_MAX_RETRIES times (about 5 s in total), so a
 * persistent limit can no longer hang the caller forever. Return values are
 * as for fork(): the child's PID in the parent, 0 in the child, and -1 with
 * errno from the last fork() attempt on failure.
 */
pid_t safe_fork(void) {
    enum { SAFE_FORK_MAX_RETRIES = 50 };

    pid_t pid = fork();
    for (int attempt = 0;
         pid == -1 && errno == EAGAIN && attempt < SAFE_FORK_MAX_RETRIES;
         attempt++) {
        usleep(100000);  // 100 ms back-off before retrying
        pid = fork();
    }
    return pid;
}

9.2 资源限制控制

/*
 * Apply fixed resource limits to the calling process:
 *   RLIMIT_NOFILE  soft 1024, hard 2048 open file descriptors
 *   RLIMIT_AS      soft 1 GiB, hard 2 GiB of address space
 * Returns 0 on success, or -1 (errno from setrlimit) at the first failure.
 * NOTE: an unprivileged process cannot raise a hard limit, and lowering one
 * cannot be undone without privilege.
 */
int set_resource_limits(void) {
    // Sizes are computed in rlim_t. The old `2 * 1024 * 1024 * 1024` was
    // int arithmetic and overflowed (2^31 > INT_MAX), which is undefined behaviour.
    const rlim_t one_gib = (rlim_t)1024 * 1024 * 1024;
    struct rlimit rlim;

    // Maximum number of file descriptors.
    rlim.rlim_cur = 1024;
    rlim.rlim_max = 2048;
    if (setrlimit(RLIMIT_NOFILE, &rlim) != 0) {
        return -1;
    }

    // Maximum address-space size.
    rlim.rlim_cur = one_gib;
    rlim.rlim_max = 2 * one_gib;
    if (setrlimit(RLIMIT_AS, &rlim) != 0) {
        return -1;
    }

    return 0;
}

10. 进程间通信防御

10.1 共享内存安全访问

/* A System V shared-memory attachment plus the semaphore that guards it. */
struct shm_context {
    int shmid;
    void* addr;     // attach address from shmat(); NULL after a failed init
    size_t size;
    sem_t* sem;
};

/*
 * Create (or open an existing) System V shared-memory segment for `key`,
 * attach it, and open the named POSIX semaphore "/mysem" (initial value 1).
 *
 * Returns 0 on success with every ctx field set. On failure returns -1 with
 * errno set (EINVAL for a NULL ctx or a zero size; otherwise from the
 * failing call) and leaves ctx->addr NULL. Anything this call attached is
 * detached, and a segment this call created is marked for removal so that
 * it does not leak.
 * NOTE(review): every context shares the fixed semaphore name "/mysem",
 * regardless of key; confirm this is intended.
 */
int shm_init(struct shm_context* ctx, key_t key,
             size_t size) {
    if (!ctx || size == 0) {
        errno = EINVAL;
        return -1;
    }
    ctx->addr = NULL;

    bool created = true;
    ctx->shmid = shmget(key, size,
                        IPC_CREAT | IPC_EXCL | 0600);
    if (ctx->shmid == -1) {
        if (errno != EEXIST) {
            return -1;
        }
        // Segment already exists: attach to it instead of creating one.
        created = false;
        ctx->shmid = shmget(key, size, 0600);
        if (ctx->shmid == -1) {
            return -1;
        }
    }

    void* addr = shmat(ctx->shmid, NULL, 0);
    if (addr != (void*)-1) {
        sem_t* sem = sem_open("/mysem", O_CREAT, 0600, 1);
        if (sem != SEM_FAILED) {
            ctx->addr = addr;
            ctx->size = size;
            ctx->sem = sem;
            return 0;
        }
        int saved_errno = errno;
        shmdt(addr);
        errno = saved_errno;
    }

    // Failure after the segment was obtained: remove it if this call created it.
    if (created) {
        int saved_errno = errno;
        shmctl(ctx->shmid, IPC_RMID, NULL);
        errno = saved_errno;
    }
    return -1;
}

10.2 消息队列安全使用

/* A System V message queue id plus the message type used when sending. */
struct msg_queue {
    int mqid;
    long type;
};

/*
 * Send `size` bytes from data as one System V message of type mq->type.
 *
 * The payload is copied into a local buffer of 8192 bytes. msgsnd() is
 * called without IPC_NOWAIT, so it blocks while the queue is full, and it is
 * retried on EINTR.
 * Returns 0 on success, or -1 with errno set:
 *   EINVAL    mq is NULL, data is NULL with size > 0, or mq->type <= 0
 *             (msgsnd requires a positive message type)
 *   EMSGSIZE  size exceeds the 8192-byte buffer
 *   other     errno from msgsnd()
 */
int mq_send_safe(struct msg_queue* mq, const void* data,
                 size_t size) {
    struct {
        long type;
        char data[8192];  // maximum message payload
    } msg;

    if (!mq || (!data && size > 0) || mq->type <= 0) {
        errno = EINVAL;
        return -1;
    }

    if (size > sizeof(msg.data)) {
        errno = EMSGSIZE;
        return -1;
    }

    msg.type = mq->type;
    // memcpy() from a NULL pointer is undefined even with a zero length.
    if (size > 0) {
        memcpy(msg.data, data, size);
    }

    while (msgsnd(mq->mqid, &msg, size, 0) != 0) {
        if (errno != EINTR) {
            return -1;
        }
    }

    return 0;
}

总结

防御式编程是一种全面的编程思想，需要在代码的各个层面都保持警惕。主要包括：

  1. 输入验证和边界检查
  2. 内存管理和资源控制
  3. 并发安全和同步机制
  4. 信号处理和异常处理
  5. 文件和网络操作安全
  6. 错误处理和日志记录
  7. 配置管理和验证
  8. 系统调用防御
  9. 进程间通信安全
  10. 资源限制和清理

通过实施这些防御措施，我们可以：

  • 提高程序的健壮性和可靠性
  • 增强安全性
  • 改善可维护性
  • 减少运行时错误
  • 提供更好的用户体验

记住：永远不要对输入做任何假设，永远要考虑最坏的情况，这就是防御式编程的核心理念。