wfp 域名拦截

发布于:2025-08-18 ⋅ 阅读:(18) ⋅ 点赞:(0)

接上篇:WFP DNS 域名解析_dns解析,压缩指针-CSDN博客

添加IP解析,并拦截对应的IP。

DNS 报文示例:

/* 
DNS响应报文示例(十六进制):
包含对www.baidu.com的A记录查询响应,其中baidu.com被压缩

整体结构(十六进制,按字节排列):
00 01   ; 事务ID(Transaction ID)
81 80   ; Flags(响应报文,成功,支持递归)
00 01   ; QDCOUNT(1个问题)
00 02   ; ANCOUNT(2个回答)
00 00   ; NSCOUNT(0个权威记录)
00 00   ; ARCOUNT(0个附加记录)

; 问题部分(Question Section)
03 77 77 77 05 62 61 69 64 75 03 63 6f 6d 00  ; QNAME: www.baidu.com
00 01   ; QTYPE: A记录(IPv4)
00 01   ; QCLASS: IN(互联网)

; 回答部分1(Answer Section 1)
03 77 77 77 C0 10  ; NAME: www.baidu.com(压缩格式,C0 10指向偏移0x10处的"baidu.com")
00 01   ; TYPE: A记录
00 01   ; CLASS: IN
00 00 0E 10   ; TTL: 3600秒
00 04   ; RDLENGTH: 4字节
B4 65 32 F2   ; RDATA: 180.101.50.242(IPv4地址)

; 回答部分2(Answer Section 2)
C0 10  ; NAME: baidu.com(压缩格式,C0 10指向偏移0x10)
00 01   ; TYPE: A记录
00 01   ; CLASS: IN
00 00 0E 10   ; TTL: 3600秒
00 04   ; RDLENGTH: 4字节
B4 65 32 F3   ; RDATA: 180.101.50.243(另一个IPv4地址)
*/

/* 指针压缩解析说明:
1. 偏移量计算以报文起始位置为0x00

2. 问题部分QNAME解析:
   03 77 77 77 → 标签"www"(长度3)
   05 62 61 69 64 75 → 标签"baidu"(长度5)
   03 63 6f 6d → 标签"com"(长度3)
   00 → 终止符
   完整域名:www.baidu.com
   各标签在报文中的偏移(DNS头部占0x00-0x0B,QNAME从0x0C开始):
   - "www.baidu.com"起始于0x0C(03 www)
   - "baidu.com"起始于0x10(05 baidu,紧接在03 www之后)
   - "com"起始于0x16(03 com,紧接在05 baidu之后)

3. 回答部分1的NAME字段(03 77 77 77 C0 10):
   - 03 77 77 77 → 标签"www"
   - C0 10 → 压缩指针(0x10为偏移量)
   - 偏移0x10指向"05 62 61 69 64 75 03 63 6f 6d 00"(即"baidu.com")
   - 完整域名:www + . + baidu.com(拼接后为www.baidu.com)

4. 回答部分2的NAME字段(C0 10):
   - C0 10 → 压缩指针(0x10为偏移量)
   - 偏移0x10指向问题部分的"05 62 61 69 64 75"(即"baidu")
   - 拼接后续的"com"后,完整域名:baidu.com
*/


/* 输出结果:
问题部分域名: www.baidu.com (长度: 15字节)
回答1域名: www.baidu.com (压缩后长度: 6字节)  // 未压缩为15字节,节省9字节
回答2域名: baidu.com (压缩后长度: 2字节)     // 未压缩为11字节,节省9字节
*/
    

//dns_parse.h
// Declarations for the DNS wire-format helpers implemented in dns_parse.c:
// the DNS message header layout, protocol constants and the name/query/response parsers.
#ifndef _DNS_HEADER_H
#define  _DNS_HEADER_H


// Earlier draft of the header with the flag bits split into bit-fields.
// It is commented out and unused; DNS_HEADER below keeps the flags as one UINT16.
//typedef struct _DNS_HEADER {
//    WORD Xid;
//    WORD Flags;
//    BYTE RecursionDesired : 1;
//    BYTE Truncation : 1;
//    BYTE Authoritative : 1;
//    BYTE Opcode : 4;
//    BYTE IsResponse : 1;
//    BYTE ResponseCode : 4;
//    BYTE CheckingDisabled : 1;
//    BYTE AuthenticatedData : 1;
//    BYTE Reserved : 1;
//    BYTE RecursionAvailable : 1;
//    WORD QuestionCount;
//    WORD AnswerCount;
//    WORD NameServerCount;
//    WORD AdditionalCount;
//} DNS_HEADER, *PDNS_HEADER;


// Another unused draft, kept for its description of the header fields.
//struct DNS_HEADER {
//    uint16_t transID; // Transaction ID: identifies the DNS message.
//                      // A request and its matching response carry the same value,
//                      // which is how a response is matched to the request it answers.
//    uint16_t flags; // Flags field of the DNS message
//    /*
//    Bit 15: QR (Response), query/response flag. 0 = query, 1 = response.
//
//    Bits 14-11: Opcode. 0 = standard query; 1 = inverse query; 2 = server status request.
//
//    Bit 10: AA (Authoritative answer), meaningful in responses. 1 = the name server is authoritative; 0 = it is not.
//
//    Bit 9: TC (Truncated). 1 = the response exceeded 512 bytes (the classic UDP limit) and only the first 512 bytes were returned.
//
//    Bit 8: RD (Recursion Desired). May be set in a query and is echoed in the response.
//                                   1 tells the name server to resolve the query itself (a recursive query).
//                                   0 means that if the queried server has no authoritative answer,
//                                                it returns a list of other name servers that can answer (an iterative query).
//                                Whether recursion is requested (set by the query, returned unchanged in the response).
//
//    Bit 7: RA (Recursion Available). Only present in responses; 1 = the DNS server supports recursive queries.
//
//    Bits 6-4: Z, reserved; must be 0 in all queries and responses (RFC 4035 later reused bits 5-4 as AD/CD).
//
//    Bits 3-0: RCODE (reply code), the error status of the response.
//                                   0: no error;
//                                   1: format error, the server could not interpret the query;
//                                   2: server failure, the server could not process the query;
//                                   3: name error, meaningful only from an authoritative server: the queried name does not exist;
//                                   4: not implemented, the server does not support the query type;
//                                   5: refused, the server declines to answer by policy (e.g. for certain requesters);
//                                   6-15: reserved*/
//    
//    uint16_t questions; // Question count: number of DNS questions
//    uint16_t answerRRs; // Answer RR count: number of answer records
//    uint16_t authorityRRs; // Authority RR count: number of authoritative name server records
//    uint16_t additionalRRs; // Additional RR count: number of extra records (e.g. addresses of the authoritative servers)
//};
//
//// DNS question section
//struct DNS_QUERIES {
//    string   qName;  // queried name, variable length (e.g. www.baidu.com is encoded as 3www5baidu3com0)
//    uint16_t qType;  // query type: resource record type requested
//    uint16_t qClass; // query class: protocol family of the information
//};



// DNS message header: six 16-bit fields, 12 bytes, packed (no padding).
// The fields hold the raw wire (big-endian) values; callers byte-swap them
// (htons) before interpreting them.
#pragma pack(push, 1)
typedef struct _DNS_HEADER {
    UINT16 transactionID;   // pairs a response with its query
    UINT16 flags;           // QR | Opcode | AA | TC | RD | RA | Z | RCODE
    UINT16 questionCount;   // QDCOUNT
    UINT16 answerCount;     // ANCOUNT
    UINT16 authorityCount;  // NSCOUNT
    UINT16 additionalCount; // ARCOUNT
} DNS_HEADER, *PDNS_HEADER;
#pragma pack(pop)



// DNS protocol constants
#define DNS_HEADER_SIZE         12  // the DNS header is always 12 bytes
#define DNS_TYPE_A              1   // IPv4 address record
#define DNS_TYPE_AAAA           28  // IPv6 address record
#define DNS_CLASS_IN            1   // Internet class
#define DNS_FLAG_RESPONSE       0x8000  // response (QR) bit in header.flags, after byte-swapping

// Parse result: a queried domain and the IP addresses resolved for it.
// NOTE(review): not used by the dns_parse.c functions visible here; each
// address family holds at most 8 entries.
typedef struct _DNS_PARSE_RESULT {
    UNICODE_STRING  queryDomain;      // queried domain name (e.g. "www.baidu.com")
    UINT32          ipV4Count;        // number of IPv4 addresses
    UINT32          ipV6Count;        // number of IPv6 addresses
    UINT32          ipV4Addresses[8]; // resolved IPv4 addresses (network byte order)
    UINT8           ipV6Addresses[8][16]; // resolved IPv6 addresses (network byte order)
} DNS_PARSE_RESULT, *PDNS_PARSE_RESULT;

// Byte-order helpers: checkCPUendian returns nonzero on a big-endian host;
// UTIL_ntohl / UTIL_ntohs convert network-order values to host order.
int checkCPUendian();
UINT32 UTIL_ntohl(UINT32 n);
UINT16 UTIL_ntohs(UINT16 n);

// Decodes the encoded DNS name at szEncodedStr into dotted form in szDotStr
// (nDotStrSize bytes). *pusEncodedStrLen receives the number of bytes the
// encoded name occupies. szPacketStartPos is the start of the DNS message,
// used to resolve compression pointers. Returns TRUE on success.
BOOLEAN ParseDNetName(
    char *szEncodedStr, 
    USHORT *pusEncodedStrLen, 
    char *szDotStr, 
    USHORT nDotStrSize, 
    char *szPacketStartPos );

// Length-checked decoder for a (possibly compressed) DNS name.
BOOLEAN ParseDnsName(
    const UINT8* packetData,        // base that compression-pointer offsets are relative to
    SIZE_T packetLength,            // bytes available from packetData
    const UINT8* dnsData,           // current parse position (where the encoded name starts)
    SIZE_T dnsDataLength,           // bytes remaining from dnsData
    CHAR* domainBuffer,             // receives the dotted, NUL-terminated name
    SIZE_T bufferSize,
    SIZE_T* parsedLength,           // receives the bytes this name occupies at dnsData (a pointer counts as 2)
    int recursionDepth              // current pointer-hop depth; callers pass 0
);

// Extracts the question name from a DNS query that follows udpHeaderLen bytes of UDP header.
BOOLEAN ParseDnsQuery(
    const UINT8* packetData,
    SIZE_T packetLength,
    size_t udpHeaderLen,
    CHAR* domainBuffer,
    SIZE_T bufferSize);

// Extracts A-record IPv4 addresses (raw RDATA bytes, i.e. network byte order)
// from a DNS response that follows udpHeaderLength bytes of UDP header;
// *ipCount receives the number of addresses stored.
BOOLEAN ParseDnsResponse(
    const UINT8* packetData,
    SIZE_T packetLength,
    size_t udpHeaderLength,
    CHAR* domainBuffer,
    SIZE_T bufferSize,
    UINT32* ipAddresses,
    SIZE_T maxIpCount,
    SIZE_T* ipCount
);
#endif


//dns_parse.c
#include "stdinc.h"


/*
 * checkCPUendian - report the byte order of the host CPU.
 *
 * Stores a known 32-bit pattern and inspects its lowest-addressed byte.
 * Returns 1 when that byte is the most significant one (big-endian host),
 * 0 otherwise (little-endian host).
 */
int checkCPUendian()
{
    const UINT32 probe = 0x12345678;
    /* Character pointers may legally alias any object. */
    const unsigned char* lowestByte = (const unsigned char*)&probe;

    if (lowestByte[0] == 0x12)
    {
        return 1;
    }
    return 0;
}

/*
 * UTIL_ntohl - convert a 32-bit value from network (big-endian) to host order.
 *
 * A big-endian host already matches network order and gets n back unchanged;
 * on a little-endian host the bytes are reversed with BigLittleSwap32.
 */
UINT32 UTIL_ntohl(UINT32 n)
{
    if (!checkCPUendian())
    {
        return BigLittleSwap32(n);
    }
    return n;
}


/*
 * UTIL_ntohs - convert a 16-bit value from network (big-endian) to host order.
 *
 * A big-endian host returns n unchanged; a little-endian host swaps the two
 * bytes with BigLittleSwap16.
 */
UINT16 UTIL_ntohs(UINT16 n)
{
    if (!checkCPUendian())
    {
        return BigLittleSwap16(n);
    }
    return n;
}


// Upper bound on nested compression-pointer hops followed by ParseDnsName,
// so a pointer loop in a crafted packet cannot recurse without end.
#define DNS_MAX_RECURSION 10

/*
 * ParseDnsName - decode a (possibly compressed) DNS name into dotted text.
 *
 * packetData, packetLength : the DNS message; compression-pointer offsets are
 *                            taken relative to packetData.
 * dnsData, dnsDataLength   : start of the encoded name and the number of
 *                            bytes available from there.
 * domainBuffer, bufferSize : receives the NUL-terminated dotted name.
 * parsedLength             : receives the number of bytes the encoded name
 *                            occupies at dnsData: its labels plus 1 for the
 *                            terminating zero, or plus 2 for a final pointer.
 * recursionDepth           : pointer hops already followed; callers pass 0.
 *
 * Returns TRUE on success. Returns FALSE for malformed or truncated input, for
 * a pointer chain deeper than DNS_MAX_RECURSION, or when the name does not
 * fit in domainBuffer.
 */
BOOLEAN ParseDnsName(
    const UINT8* packetData,
    SIZE_T packetLength,
    const UINT8* dnsData,
    SIZE_T dnsDataLength,
    CHAR* domainBuffer,
    SIZE_T bufferSize,
    SIZE_T* parsedLength,
    int recursionDepth
)
{
    if (recursionDepth > DNS_MAX_RECURSION)
        return FALSE;

    if (packetData == NULL || dnsData == NULL || domainBuffer == NULL || parsedLength == NULL)
        return FALSE;

    SIZE_T domainNameLen = 0;   // characters written to domainBuffer so far
    SIZE_T localParsedLen = 0;  // bytes consumed at dnsData (pointer targets excluded)
    const UINT8* currentPtr = dnsData;

    while (dnsDataLength > 0)
    {
        UINT8 labelLen = *currentPtr;

        if ((labelLen & 0xC0) == 0xC0)
        {
            // Compression pointer: 2 bytes, top two bits set, 14-bit big-endian offset.
            if (dnsDataLength < 2)
                return FALSE;

            UINT16 offset = (UINT16)(((labelLen & 0x3F) << 8) | currentPtr[1]);
            if (offset >= packetLength)
                return FALSE;

            // Labels copied before the pointer must be joined to the pointed-to
            // suffix with a '.', e.g. "www" + "baidu.com" -> "www.baidu.com".
            SIZE_T suffixStart = domainNameLen;
            if (domainNameLen > 0)
            {
                if (domainNameLen + 1 >= bufferSize)
                    return FALSE;   // no room for the '.' plus at least the terminator
                domainBuffer[domainNameLen] = '.';
                suffixStart = domainNameLen + 1;
            }

            SIZE_T unused = 0;
            if (!ParseDnsName(packetData, packetLength, packetData + offset, packetLength - offset,
                    domainBuffer + suffixStart, bufferSize - suffixStart, &unused, recursionDepth + 1))
                return FALSE;

            // A pointer to the root name yields an empty suffix: drop the separator again.
            if (suffixStart > domainNameLen && domainBuffer[suffixStart] == '\0')
                domainBuffer[domainNameLen] = '\0';

            // The pointer ends the name and occupies only its 2 bytes here.
            *parsedLength = localParsedLen + 2;
            return TRUE;
        }
        else if (labelLen == 0)
        {
            // Zero-length label: end of the name.
            localParsedLen += 1;
            *parsedLength = localParsedLen;
            if (domainNameLen < bufferSize)
                domainBuffer[domainNameLen] = '\0';
            else if (bufferSize > 0)
                domainBuffer[bufferSize - 1] = '\0';
            return TRUE;
        }
        else
        {
            // Ordinary label: length byte followed by labelLen characters.
            if (labelLen > 63 || labelLen + 1 > dnsDataLength)
                return FALSE;

            // Leaves room for this label, a '.' separator and the terminator.
            if (domainNameLen + labelLen + 1 >= bufferSize)
                return FALSE;

            if (domainNameLen > 0)
            {
                domainBuffer[domainNameLen++] = '.';
            }

            RtlCopyMemory(domainBuffer + domainNameLen, currentPtr + 1, labelLen);
            domainNameLen += labelLen;

            currentPtr += labelLen + 1;
            dnsDataLength -= labelLen + 1;
            localParsedLen += labelLen + 1;
        }
    }

    // Ran out of data before a terminating zero or a pointer.
    return FALSE;
}

/*
 * ParseDnsQuery - extract the first question name from an outbound DNS query.
 *
 * packetData, packetLength : copied UDP datagram (UDP header + DNS message).
 * udpHeaderLen             : number of bytes in front of the DNS message.
 * domainBuffer, bufferSize : receives the dotted, NUL-terminated query name.
 *
 * Returns TRUE only for a query (QR bit clear) that has at least one question
 * whose name could be decoded into domainBuffer.
 */
BOOLEAN ParseDnsQuery(
    const UINT8* packetData,
    SIZE_T packetLength,
    size_t udpHeaderLen,
    CHAR* domainBuffer,
    SIZE_T bufferSize)
{
    // %lu expects a 32-bit ULONG while SIZE_T is 64-bit on x64: cast explicitly.
    KdPrint((DPREFIX "开始解析UDP DNS查询,数据包长度: %lu, UDP头部长度: %lu\n",
        (ULONG)packetLength, (ULONG)udpHeaderLen));

    if (packetData == NULL || domainBuffer == NULL)
    {
        return FALSE;
    }

    // Two comparisons instead of "udpHeaderLen + sizeof(DNS_HEADER)" so the sum cannot wrap.
    if (udpHeaderLen > packetLength || packetLength - udpHeaderLen < sizeof(DNS_HEADER))
    {
        KdPrint((DPREFIX "UDP DNS数据包长度不足\n"));
        return FALSE;
    }

    const UINT8* dnsStart = packetData + udpHeaderLen;
    SIZE_T dnsDataLength = packetLength - udpHeaderLen;

    PDNS_HEADER dnsHeader = (PDNS_HEADER)dnsStart;
    // Header fields are big-endian on the wire; htons performs the same 16-bit swap as ntohs.
    UINT16 flags = htons(dnsHeader->flags);
    KdPrint((DPREFIX "DNS Flags: 0x%04X\n", flags));

    // QR bit (0x8000) set means response; only queries are handled here.
    if ((flags & 0x8000) != 0)
    {
        KdPrint((DPREFIX "忽略DNS响应报文\n"));
        return FALSE;
    }

    // Without a question there is no QNAME after the header to decode.
    if (htons(dnsHeader->questionCount) == 0)
    {
        KdPrint((DPREFIX "ParseDnsQuery: DNS query has no question\n"));
        return FALSE;
    }

    const UINT8* namePtr = dnsStart + sizeof(DNS_HEADER);
    SIZE_T nameDataLen = dnsDataLength - sizeof(DNS_HEADER);

    KdPrint((DPREFIX "开始解析DNS域名,域名数据偏移: %lu, 长度: %lu\n",
        (ULONG)(namePtr - packetData), (ULONG)nameDataLen));

    // Compression-pointer offsets count from the start of the DNS message, not
    // from the UDP header, so the DNS message is the packet base here.
    SIZE_T parsedLen = 0;
    BOOLEAN ret = ParseDnsName(dnsStart, dnsDataLength, namePtr, nameDataLen, domainBuffer, bufferSize, &parsedLen, 0);
    if (ret)
    {
        KdPrint((DPREFIX "成功解析UDP DNS域名: %s \n", domainBuffer));
    }
    else
    {
        KdPrint((DPREFIX "解析UDP DNS域名失败\n"));
    }
    return ret;
}


/* Wire-format sizes and record types used when walking DNS resource records (RFC 1035 4.1). */
#define MAX_DOMAINNAME_LEN 255
#ifndef DNS_PORT
#define DNS_PORT   53
#endif
#define DNS_TYPE_SIZE  2     /* TYPE field */
#define DNS_CLASS_SIZE  2    /* CLASS field */
#define DNS_TTL_SIZE  4      /* TTL field */
#define DNS_DATALEN_SIZE 2   /* RDLENGTH field */
/* dns_parse.h may already define DNS_TYPE_A (as 1); redefining it with a
   different spelling is a macro-redefinition diagnostic, so only define it here
   when it is missing. */
#ifndef DNS_TYPE_A
#define DNS_TYPE_A   0x0001 /* 1: a host address */
#endif
#define DNS_TYPE_CNAME  0x0005 /* 5: the canonical name for an alias */
/* Header + longest name + QTYPE + QCLASS; uses the DNS_HEADER type from dns_parse.h. */
#define DNS_PACKET_MAX_SIZE (sizeof(DNS_HEADER) + MAX_DOMAINNAME_LEN + DNS_TYPE_SIZE + DNS_CLASS_SIZE)


/*
 * ParseDnsResponse - collect the IPv4 addresses from a DNS response.
 *
 * packetData, packetLength : copied UDP datagram (UDP header + DNS message).
 * udpHeaderLength          : number of bytes in front of the DNS message.
 * domainBuffer, bufferSize : receives the first question name (dotted,
 *                            NUL-terminated); must hold at least 1 byte.
 * ipAddresses, maxIpCount  : receives A-record addresses as their four RDATA
 *                            bytes copied verbatim (network byte order); at
 *                            most maxIpCount are stored, extra ones are dropped.
 * ipCount                  : receives the number of addresses stored.
 *
 * Every read is checked against the message length, and names are decoded
 * with ParseDnsName relative to the DNS message.
 *
 * Returns TRUE when the message is a successful, untruncated response to a
 * standard query and yielded at least one IPv4 address.
 */
BOOLEAN ParseDnsResponse(
    const UINT8* packetData,
    SIZE_T packetLength,
    size_t udpHeaderLength,
    CHAR* domainBuffer,
    SIZE_T bufferSize,
    UINT32* ipAddresses,
    SIZE_T maxIpCount,
    SIZE_T* ipCount
)
{
    // TYPE + CLASS + TTL + RDLENGTH, the fixed part after every answer NAME.
    const SIZE_T rrFixedLength = DNS_TYPE_SIZE + DNS_CLASS_SIZE + DNS_TTL_SIZE + DNS_DATALEN_SIZE;
    CHAR scratchName[256];  // decoded names that are only logged or skipped

    if (packetData == NULL || udpHeaderLength > packetLength ||
        packetLength - udpHeaderLength < sizeof(DNS_HEADER) ||
        domainBuffer == NULL || bufferSize == 0 ||
        ipAddresses == NULL || ipCount == NULL) {
        KdPrint((DPREFIX "ParseDnsResponse: 无效的输入数据包\n"));
        return FALSE;
    }

    *ipCount = 0;
    domainBuffer[0] = '\0';

    const UINT8* dnsStart = packetData + udpHeaderLength;
    const SIZE_T dnsLength = packetLength - udpHeaderLength;
    PDNS_HEADER dnsHeader = (PDNS_HEADER)dnsStart;

    UINT16 flags = htons(dnsHeader->flags);
    UINT16 qdCount = htons(dnsHeader->questionCount);
    UINT16 anCount = htons(dnsHeader->answerCount);

    // RFC 1035 4.1.1: QR (bit 15) must be 1 and OPCODE (bits 14-11) must be 0.
    // AA/RD/RA/Z/AD/CD are deliberately ignored: e.g. DNSSEC-validating
    // resolvers set AD, and such responses are still ordinary answers.
    if ((flags & 0x8000) == 0 || (flags & 0x7800) != 0) {
        KdPrint((DPREFIX "ParseDnsResponse: 不是响应包\n"));
        return FALSE;
    }

    if (qdCount == 0 || anCount == 0) {
        KdPrint((DPREFIX "ParseDnsResponse: 无问题或回答记录\n"));
        return FALSE;
    }

    UINT8 rcode = flags & 0x000F;

    if (rcode != 0)
    {
        KdPrint((DPREFIX "RCODE值 失败\n"));
        return FALSE;
    }

    if ((flags & 0x0200) != 0)
    {
        KdPrint((DPREFIX "DNS 报文被截断\n"));
        return FALSE;
    }

    // Offset of the next unread byte inside the DNS message; ParseDnsName never
    // reports more bytes than were available, so pos never exceeds dnsLength.
    SIZE_T pos = sizeof(DNS_HEADER);

    // Question section: NAME + QTYPE + QCLASS per entry.
    for (UINT16 q = 0; q < qdCount; ++q)
    {
        // The first question name is handed back to the caller; later ones are only skipped.
        CHAR* nameOut = (q == 0) ? domainBuffer : scratchName;
        SIZE_T nameOutSize = (q == 0) ? bufferSize : sizeof(scratchName);
        SIZE_T nameLength = 0;

        if (!ParseDnsName(dnsStart, dnsLength, dnsStart + pos, dnsLength - pos,
                nameOut, nameOutSize, &nameLength, 0))
        {
            KdPrint((DPREFIX "Question ParseDNetName: 解析错误\n"));
            return FALSE;
        }
        pos += nameLength;

        if (dnsLength - pos < DNS_TYPE_SIZE + DNS_CLASS_SIZE)
        {
            KdPrint((DPREFIX "ParseDnsResponse: question section truncated\n"));
            return FALSE;
        }
        pos += DNS_TYPE_SIZE + DNS_CLASS_SIZE;
    }

    // Answer section: NAME, fixed fields, then RDLENGTH bytes of RDATA.
    for (UINT16 a = 0; a < anCount; ++a)
    {
        SIZE_T nameLength = 0;
        if (!ParseDnsName(dnsStart, dnsLength, dnsStart + pos, dnsLength - pos,
                scratchName, sizeof(scratchName), &nameLength, 0))
        {
            KdPrint((DPREFIX "Answer ParseDNetName: 解析错误\n"));
            return FALSE;
        }
        pos += nameLength;

        if (dnsLength - pos < rrFixedLength)
        {
            KdPrint((DPREFIX "ParseDnsResponse: answer record truncated\n"));
            return FALSE;
        }

        // Assemble big-endian fields byte by byte: no unaligned or aliased loads.
        const UINT8* rr = dnsStart + pos;
        USHORT answerType = (USHORT)((rr[0] << 8) | rr[1]);
        USHORT answerDataLen = (USHORT)((rr[8] << 8) | rr[9]);
        pos += rrFixedLength;

        if (dnsLength - pos < answerDataLen)
        {
            KdPrint((DPREFIX "ParseDnsResponse: answer RDATA truncated\n"));
            return FALSE;
        }
        const UINT8* rdata = dnsStart + pos;

        if (answerType == DNS_TYPE_A && answerDataLen == sizeof(UINT32))
        {
            UINT32 ip;
            RtlCopyMemory(&ip, rdata, sizeof(ip));  // raw bytes: stays in network byte order

            if (*ipCount < maxIpCount)
            {
                ipAddresses[*ipCount] = ip;
                (*ipCount)++;
                KdPrint((DPREFIX "ParseDnsResponse, 提取IP成功: %d.%d.%d.%d\n",
                     ip & 0xFF, (ip >> 8) & 0xFF, (ip >> 16) & 0xFF, (ip >> 24) & 0xFF));
            }
            else
            {
                KdPrint((DPREFIX "ParseDnsResponse: IP list full, extra A record ignored\n"));
            }
        }
        else if (answerType == DNS_TYPE_CNAME)
        {
            // The alias target lives inside RDATA, so bound the decode by RDLENGTH.
            SIZE_T cnameLength = 0;
            if (!ParseDnsName(dnsStart, dnsLength, rdata, answerDataLen,
                    scratchName, sizeof(scratchName), &cnameLength, 0))
            {
                return FALSE;
            }
            KdPrint((DPREFIX "ParseDNetName DNS_TYPE_CNAME: %s\n", scratchName));
        }

        pos += answerDataLen;
    }

    return *ipCount > 0;
}

/*
 * convert "\x03www\x05baidu\x03com\x00" to "www.baidu.com"
 * 0x0000 03 77 77 77 05 62 61 69 64 75 03 63 6f 6d 00 ff
 * convert "\x03www\x05baidu\xc0\x13" to "www.baidu.com"
 * 0x0000 03 77 77 77 05 62 61 69 64 75 c0 13 ff ff ff ff
 * 0x0010 ff ff ff 03 63 6f 6d 00 ff ff ff ff ff ff ff ff
 */


// Maximum number of nested compression-pointer hops ParseDNetNameEx follows;
// guards against pointer loops exhausting the kernel stack.
#define MAX_RECURSION_DEPTH 5

/*
 * ParseDNetNameEx - decode the encoded DNS name at szEncodedStr into dotted text.
 *
 * szEncodedStr     : first byte of the encoded name inside the DNS message.
 * pusEncodedStrLen : receives how many bytes the name occupies at szEncodedStr
 *                    (its labels, plus 1 for a terminating 0x00 or 2 for a pointer).
 * szDotStr         : receives the NUL-terminated dotted name ("" for the root).
 * nDotStrSize      : size of szDotStr in bytes.
 * szPacketStartPos : first byte of the DNS message; pointer offsets are relative to it.
 * recursionDepth   : pointer hops already followed (ParseDNetName passes 0).
 *
 * Returns TRUE on success. Memory reads are only guarded by MmIsAddressValid
 * and __try/__except. The function is not told the message length, so it
 * cannot reject a name that runs past the end of the message while the
 * memory there is still mapped.
 */
BOOLEAN ParseDNetNameEx(
    char* szEncodedStr,
    USHORT* pusEncodedStrLen,
    char* szDotStr,
    USHORT nDotStrSize,
    char* szPacketStartPos,
    UCHAR recursionDepth)
{
    __try
    {
        // The packet base is required to resolve compression pointers.
        if (szEncodedStr == NULL || pusEncodedStrLen == NULL || szDotStr == NULL || szPacketStartPos == NULL)
        {
            KdPrint(("ParseDNetNameEx: 无效入参\n"));
            return FALSE;
        }

        *pusEncodedStrLen = 0;
        UCHAR* pDecodePos = (UCHAR*)szEncodedStr;
        USHORT usPlainStrLen = 0;   // bytes written to szDotStr so far (labels + '.')
        BYTE nLabelDataLen = 0;

        if (recursionDepth > MAX_RECURSION_DEPTH)
        {
            KdPrint(("ParseDNetNameEx: 递归深度超限,可能存在循环指针\n"));
            return FALSE;
        }

        while (TRUE)
        {
            if (!MmIsAddressValid(pDecodePos))
            {
                KdPrint(("ParseDNetNameEx: 指针无效\n"));
                return FALSE;
            }

            nLabelDataLen = *pDecodePos;

            // 0x00 terminates the name.
            if (nLabelDataLen == 0x00)
            {
                if (usPlainStrLen == 0)
                {
                    // Root name: emit an empty string.
                    if (nDotStrSize < 1)
                        return FALSE;
                    szDotStr[0] = '\0';
                }
                else
                {
                    // Overwrite the '.' written after the last label.
                    szDotStr[usPlainStrLen - 1] = '\0';
                }
                *pusEncodedStrLen += 1;  // the terminator byte
                return TRUE;
            }

            // Top two bits set: 2-byte compression pointer.
            if ((nLabelDataLen & 0xC0) == 0xC0)
            {
                if (!MmIsAddressValid(pDecodePos + 1))
                {
                    KdPrint(("ParseDNetNameEx: 压缩指针边界无效\n"));
                    return FALSE;
                }

                // The 14-bit offset is big-endian on the wire: 6 bits from the first
                // byte, 8 from the second. Loading the two bytes as a USHORT would
                // swap them on little-endian CPUs (C0 0C -> 0x0CC0).
                USHORT usJumpPos = (USHORT)(((nLabelDataLen & 0x3F) << 8) | pDecodePos[1]);

                // RFC 1035 4.1.4: a pointer refers to a prior occurrence of a name,
                // so it must point before the pointer itself. The depth limit above
                // still catches loops that alternate between labels and pointers.
                LONG_PTR currentOffset = (char*)pDecodePos - szPacketStartPos;
                if ((LONG_PTR)usJumpPos >= currentOffset)
                {
                    KdPrint(("ParseDNetNameEx: 无效跳转偏移量 %d\n", usJumpPos));
                    return FALSE;
                }

                // Decode the pointed-to suffix right after the labels copied so far;
                // the sub-call trims its own trailing '.'.
                USHORT nSubEncodeLen = 0;
                BOOLEAN bRet = ParseDNetNameEx(
                    szPacketStartPos + usJumpPos,
                    &nSubEncodeLen,
                    szDotStr + usPlainStrLen,
                    nDotStrSize - usPlainStrLen,
                    szPacketStartPos,
                    recursionDepth + 1
                );

                // Only the 2 pointer bytes count toward this name's encoded length.
                *pusEncodedStrLen += 2;
                return bRet;
            }

            // Ordinary label: length byte (1..63) followed by that many characters.
            if (nLabelDataLen == 0 || nLabelDataLen > 63)
            {
                KdPrint(("ParseDNetNameEx: 无效标签长度 %d\n", nLabelDataLen));
                return FALSE;
            }

            // Last byte of the label content must be mapped.
            if (!MmIsAddressValid(pDecodePos + 1 + nLabelDataLen - 1))
            {
                KdPrint(("ParseDNetNameEx: 标签内容越界\n"));
                return FALSE;
            }

            // Room for the label plus its trailing '.'.
            if (usPlainStrLen + nLabelDataLen + 1 > nDotStrSize)
            {
                KdPrint(("ParseDNetNameEx: 输出缓冲区不足\n"));
                return FALSE;
            }

            memcpy(szDotStr + usPlainStrLen, pDecodePos + 1, nLabelDataLen);
            szDotStr[usPlainStrLen + nLabelDataLen] = '.';

            usPlainStrLen += nLabelDataLen + 1;      // label text + '.'
            *pusEncodedStrLen += nLabelDataLen + 1;  // length byte + label text
            pDecodePos += nLabelDataLen + 1;
        }
    }
    __except (EXCEPTION_EXECUTE_HANDLER)
    {
        KdPrint(("ParseDNetNameEx: 异常捕获,错误码 %x\n", GetExceptionCode()));
        return FALSE;  // never hand partially decoded data to the caller
    }
}

/*
 * ParseDNetName - public entry point for decoding an encoded DNS name.
 *
 * Forwards every argument to ParseDNetNameEx and starts it at pointer-hop
 * depth 0; returns whatever ParseDNetNameEx returns.
 */
BOOLEAN ParseDNetName(
    char* szEncodedStr,
    USHORT* pusEncodedStrLen,
    char* szDotStr,
    USHORT nDotStrSize,
    char* szPacketStartPos)
{
    const UCHAR startDepth = 0;   // top-level call: no pointer hops followed yet
    BOOLEAN decoded = ParseDNetNameEx(szEncodedStr,
                                      pusEncodedStrLen,
                                      szDotStr,
                                      nDotStrSize,
                                      szPacketStartPos,
                                      startDepth);
    return decoded;
}

//callouts.c

#include "stdinc.h"
#include "callouts.h"

// Display strings for the WFP objects this driver registers.
#define FLOW_ESTABLISHED_CALLOUT_DESCRIPTION L"Ethan Flow Established Callout"
#define FLOW_ESTABLISHED_CALLOUT_NAME L"Flow Established Callout"

#define STREAM_CALLOUT_DESCRIPTION L"Ethan Stream Callout"
#define STREAM_CALLOUT_NAME L"Stream Callout"

#define SUBLAYER_NAME L"Ethan Sublayer"
#define RECV_SUBLAYER_NAME L"Ethan Recv Sublayer"

#define PROVIDER_NAME L"Ethan Provider"

// Well-known DNS port; callouts_udpCallout only inspects datagrams with 53 on either end.
#define DNS_PORT 53
// NOTE(review): not referenced by the callouts.c code visible here.
#define DNS_MAX_RECURSION 10

// Index of each callout in g_calloutGuids / g_calloutIds; CG_MAX is the count.
enum CALLOUT_GUIDS
{
    CG_DATAGRAM_DATA_CALLOUT_V4,
    CG_DATAGRAM_DATA_CALLOUT_V6,
    CG_ALE_AUTH_CONNECT_V4,
    CG_MAX
};

// Keys of the sublayers and callouts, and the runtime callout IDs that
// callouts_registerCallout receives back from FwpsCalloutRegister.
static GUID		g_sublayerGuid;
static GUID		g_recvSublayerGuid;
static GUID		g_recvProtSublayerGuid;
static GUID		g_ipSublayerGuid;
static GUID		g_calloutGuids[CG_MAX];
static UINT32	g_calloutIds[CG_MAX];
// Filter-engine handle used by the Fwpm* calls.
static HANDLE	g_engineHandle = NULL;
static GUID		g_providerGuid;
// Behaviour switches and init state; NOTE(review): none are read in the code shown here.
static BOOLEAN	g_blockRST = TRUE;
static BOOLEAN	g_blockUnexpectedRecvDisconnects = TRUE;
static BOOLEAN	g_fastRecvDisconnectOnFinWithData = TRUE;
static BOOLEAN	g_bypassConnectRedirectWithoutActionWrite = TRUE;

static BOOLEAN	g_initialized = FALSE;

// Blocklist read by callouts_aleConnectCallout: IPv4 addresses that
// callouts_udpCallout copies from ParseDnsResponse (raw A-record bytes).
// Slots never written stay 0 from static zero-initialization.
static UINT32	g_BlackIps[32];

// Forward declarations of the classify/notify functions referenced by g_callouts.

// classifyFn for the DATAGRAM_DATA V4/V6 layers (DNS query/response inspection).
VOID callouts_udpCallout(
    IN const FWPS_INCOMING_VALUES* inFixedValues,
    IN const FWPS_INCOMING_METADATA_VALUES* inMetaValues,
    IN VOID* packet,
    IN const void* classifyContext,
    IN const FWPS_FILTER* filter,
    IN UINT64 flowContext,
    OUT FWPS_CLASSIFY_OUT* classifyOut);

// notifyFn for the DATAGRAM_DATA callouts (logs filter add/delete).
NTSTATUS callouts_udpNotify(
    IN  FWPS_CALLOUT_NOTIFY_TYPE         notifyType,
    IN  const GUID*             filterKey,
    IN  const FWPS_FILTER*     filter);

// classifyFn for the ALE_AUTH_CONNECT_V4 layer (blocks connections to g_BlackIps).
VOID callouts_aleConnectCallout(
    IN const FWPS_INCOMING_VALUES* inFixedValues,
    IN const FWPS_INCOMING_METADATA_VALUES* inMetaValues,
    IN VOID* layerData,
    IN const void* classifyContext,
    IN const FWPS_FILTER* filter,
    IN UINT64 flowContext,
    OUT FWPS_CLASSIFY_OUT* classifyOut);

// notifyFn for the ALE connect callout (logs filter add/delete).
NTSTATUS callouts_aleConnectNotify(
    IN  FWPS_CALLOUT_NOTIFY_TYPE     notifyType,
    IN  const GUID*                 filterKey,
    IN  const FWPS_FILTER*           filter);

/*
 * Callouts registered by callouts_registerCallouts, one entry per
 * CALLOUT_GUIDS slot. The DATAGRAM_DATA V4/V6 entries share the DNS-inspection
 * classify/notify pair; the third entry is the ALE connect blocker.
 */
struct NF_CALLOUT
{
    FWPS_CALLOUT_CLASSIFY_FN classifyFunction;             // classifyFn handed to FwpsCalloutRegister
    FWPS_CALLOUT_NOTIFY_FN notifyFunction;                 // notifyFn
    FWPS_CALLOUT_FLOW_DELETE_NOTIFY_FN flowDeleteFunction; // flowDeleteFn (NULL: no flow contexts)
    GUID const* calloutKey;                                // points into g_calloutGuids
    UINT32 flags;                                          // FWPS_CALLOUT flags
    UINT32* calloutId;                                     // receives the runtime ID (in g_calloutIds)
} g_callouts[] = {
    {
        .classifyFunction   = (FWPS_CALLOUT_CLASSIFY_FN)callouts_udpCallout,
        .notifyFunction     = (FWPS_CALLOUT_NOTIFY_FN)callouts_udpNotify,
        .flowDeleteFunction = NULL,
        .calloutKey         = &g_calloutGuids[CG_DATAGRAM_DATA_CALLOUT_V4],
        .flags              = 0,
        .calloutId          = &g_calloutIds[CG_DATAGRAM_DATA_CALLOUT_V4],
    },
    {
        .classifyFunction   = (FWPS_CALLOUT_CLASSIFY_FN)callouts_udpCallout,
        .notifyFunction     = (FWPS_CALLOUT_NOTIFY_FN)callouts_udpNotify,
        .flowDeleteFunction = NULL,
        .calloutKey         = &g_calloutGuids[CG_DATAGRAM_DATA_CALLOUT_V6],
        .flags              = 0,
        .calloutId          = &g_calloutIds[CG_DATAGRAM_DATA_CALLOUT_V6],
    },
    {
        .classifyFunction   = (FWPS_CALLOUT_CLASSIFY_FN)callouts_aleConnectCallout,
        .notifyFunction     = (FWPS_CALLOUT_NOTIFY_FN)callouts_aleConnectNotify,
        .flowDeleteFunction = NULL,
        .calloutKey         = &g_calloutGuids[CG_ALE_AUTH_CONNECT_V4],
        .flags              = 0,
        .calloutId          = &g_calloutIds[CG_ALE_AUTH_CONNECT_V4],
    },
};

/*
 * callouts_copyBuffer - copy the first dataLength bytes of a NET_BUFFER into pPacket.
 *
 * For inbound data (isSend == FALSE) the data start is first retreated by
 * inMetaValues->transportHeaderSize, so the copy begins that many bytes
 * earlier (at the transport header). The start is advanced back by the same
 * amount afterwards, leaving the NET_BUFFER as it was found.
 *
 * pPacket must hold at least dataLength bytes. Returns TRUE when dataLength
 * bytes were placed in pPacket.
 */
static BOOLEAN callouts_copyBuffer(const FWPS_INCOMING_METADATA_VALUES* inMetaValues, NET_BUFFER* netBuffer, BOOLEAN isSend, PVOID pPacket, ULONG dataLength)
{
    ULONG retreatLength = 0;
    void* buf = NULL;

    if (netBuffer == NULL || pPacket == NULL || dataLength == 0)
    {
        return FALSE;
    }

    if (!isSend)
    {
        if (inMetaValues == NULL)
        {
            return FALSE;
        }

        retreatLength = inMetaValues->transportHeaderSize;

        // Only undo a retreat that actually happened: if it fails, the advance
        // below must not run, or the data start would move past where it began.
        if (NdisRetreatNetBufferDataStart(netBuffer, retreatLength, 0, NULL) != NDIS_STATUS_SUCCESS)
        {
            return FALSE;
        }
    }

    // Returns a pointer into the NET_BUFFER when the bytes are contiguous,
    // otherwise copies them into pPacket and returns pPacket.
    buf = NdisGetDataBuffer(
        netBuffer,
        dataLength,
        pPacket,
        1,
        0);

    if (buf != NULL && buf != pPacket)
    {
        memcpy(pPacket, buf, dataLength);
    }

    if (!isSend)
    {
        NdisAdvanceNetBufferDataStart(
            netBuffer,
            retreatLength,
            FALSE,
            NULL
        );
    }

    return buf != NULL;
}

/*
 * callouts_udpCallout - classifyFn for the DATAGRAM_DATA_V4/V6 layers.
 *
 * Never blocks: the action is always FWP_ACTION_PERMIT. For datagrams with
 * port 53 on either end it copies the first NET_BUFFER (UDP header followed
 * by the DNS message) into a temporary pool buffer and
 *   - outbound: decodes and logs the query name (ParseDnsQuery);
 *   - inbound:  decodes the response (ParseDnsResponse) and, when the question
 *               name contains "baidu", copies the returned IPv4 addresses
 *               into g_BlackIps for callouts_aleConnectCallout.
 */
VOID callouts_udpCallout(
    IN const FWPS_INCOMING_VALUES* inFixedValues,
    IN const FWPS_INCOMING_METADATA_VALUES* inMetaValues,
    IN VOID* layerData,
    IN const void* classifyContext,
    IN const FWPS_FILTER* filter,
    IN UINT64 flowContext,
    OUT FWPS_CLASSIFY_OUT* classifyOut)
{
    UNREFERENCED_PARAMETER(classifyContext);
    UNREFERENCED_PARAMETER(filter);
    UNREFERENCED_PARAMETER(flowContext);

    if (!classifyOut) return;
    classifyOut->actionType = FWP_ACTION_PERMIT;

    // inMetaValues is read below and by callouts_copyBuffer for inbound data.
    if (layerData == NULL || inFixedValues == NULL || inMetaValues == NULL) {
        return;
    }

    // Only the UDP datagram-data layers are handled.
    if (inFixedValues->layerId != FWPS_LAYER_DATAGRAM_DATA_V4 &&
        inFixedValues->layerId != FWPS_LAYER_DATAGRAM_DATA_V6) {
        return;
    }

    // Port values are compared with DNS_PORT as-is (no byte swap).
    UINT16 remotePort = 0, localPort = 0;
    UINT8 direction = 0;

    if (inFixedValues->layerId == FWPS_LAYER_DATAGRAM_DATA_V4) {
        remotePort = inFixedValues->incomingValue[FWPS_FIELD_DATAGRAM_DATA_V4_IP_REMOTE_PORT].value.uint16;
        localPort = inFixedValues->incomingValue[FWPS_FIELD_DATAGRAM_DATA_V4_IP_LOCAL_PORT].value.uint16;
        direction = inFixedValues->incomingValue[FWPS_FIELD_DATAGRAM_DATA_V4_DIRECTION].value.uint8;
    }
    else {
        remotePort = inFixedValues->incomingValue[FWPS_FIELD_DATAGRAM_DATA_V6_IP_REMOTE_PORT].value.uint16;
        localPort = inFixedValues->incomingValue[FWPS_FIELD_DATAGRAM_DATA_V6_IP_LOCAL_PORT].value.uint16;
        direction = inFixedValues->incomingValue[FWPS_FIELD_DATAGRAM_DATA_V6_DIRECTION].value.uint8;
    }

    // Only DNS traffic (either port is 53).
    if (remotePort != DNS_PORT && localPort != DNS_PORT) {
        return;
    }

    KdPrint((DPREFIX "UDP Callout: DNS流量 - 方向: %s, 远程端口: %u, 本地端口: %u \n ",
        direction == FWP_DIRECTION_OUTBOUND ? "出站" : "入站", remotePort, localPort));

    PNET_BUFFER_LIST netBufferList = (PNET_BUFFER_LIST)layerData;
    PNET_BUFFER pNetBuffer = NET_BUFFER_LIST_FIRST_NB(netBufferList);
    if (pNetBuffer == NULL) {
        KdPrint((DPREFIX "UDP Callout: 无NET_BUFFER \n "));
        return;
    }

    ULONG nbDataLength = NET_BUFFER_DATA_LENGTH(pNetBuffer);
    if (nbDataLength == 0) {
        KdPrint((DPREFIX "UDP Callout: 数据长度为0 \n"));
        return;
    }

    BOOLEAN isSend = direction == FWP_DIRECTION_OUTBOUND;

    // How many bytes callouts_copyBuffer returns, and how many of them precede
    // the DNS message.
    ULONG udpHeaderLength;
    ULONG totalDataLength;
    if (isSend) {
        // Outbound: the copy starts at the current data start, which is not moved.
        // Use the reported transport header size when sane, otherwise the 8-byte UDP header.
        totalDataLength = nbDataLength;
        udpHeaderLength = 8;
        if (inMetaValues->transportHeaderSize != 0 && inMetaValues->transportHeaderSize <= totalDataLength) {
            udpHeaderLength = inMetaValues->transportHeaderSize;
        }
    }
    else {
        // Inbound: callouts_copyBuffer retreats the data start by transportHeaderSize,
        // so the copy is those header bytes followed by all nbDataLength bytes.
        // Copying only nbDataLength bytes would drop the tail of the DNS message.
        udpHeaderLength = inMetaValues->transportHeaderSize;
        totalDataLength = nbDataLength + udpHeaderLength;
    }

    if (udpHeaderLength >= totalDataLength) {
        KdPrint((DPREFIX "UDP Callout: 无DNS数据(总长:%lu, 头:%lu)\n", totalDataLength, udpHeaderLength));
        return;
    }

    PVOID localBuf = ExAllocatePoolWithTag(NonPagedPool, totalDataLength, 'dnsT');
    if (localBuf == NULL) {
        KdPrint((DPREFIX "UDP Callout: 内存分配失败 \n"));
        return;
    }

    if (!callouts_copyBuffer(inMetaValues, pNetBuffer, isSend, localBuf, totalDataLength))
    {
        KdPrint((DPREFIX "callouts_copyBuffer: 内存分配失败 \n"));
        ExFreePoolWithTag(localBuf, 'dnsT');   // the buffer must not leak on this path
        return;
    }

    CHAR domain[256] = { 0 };

    if (isSend) {
        if (ParseDnsQuery((const UINT8*)localBuf, totalDataLength, udpHeaderLength, domain, sizeof(domain))) {
            KdPrint((DPREFIX "UDP Callout: 解析到查询域名: %s \n", domain));
            // Blocking logic for queries could be added here.
        }
        else {
            KdPrint((DPREFIX "UDP Callout: 解析查询失败 \n"));
        }
    }
    else {
        UINT32 ipAddresses[10] = { 0 };
        SIZE_T ipCount = 0;
        if (ParseDnsResponse((const UINT8*)localBuf, totalDataLength, udpHeaderLength, domain, sizeof(domain), ipAddresses, ARRAYSIZE(ipAddresses), &ipCount)) {
            KdPrint((DPREFIX "UDP Callout: 响应域名: %s, IPs=%llu \n", domain, (unsigned long long)ipCount));

            if (strstr(domain, "baidu") != NULL)
            {
                // g_BlackIps has a fixed capacity; never write past it.
                for (SIZE_T i = 0; i < ipCount && i < ARRAYSIZE(g_BlackIps); i++)
                {
                    g_BlackIps[i] = ipAddresses[i];
                }
            }
        }
        else {
            KdPrint((DPREFIX "UDP Callout: 解析响应失败 \n"));
        }
    }

    ExFreePoolWithTag(localBuf, 'dnsT');
}




/*
 * callouts_udpNotify - notifyFn for the DATAGRAM_DATA callouts.
 *
 * Logs when a filter that references the callout is added or deleted and
 * always accepts the notification (STATUS_SUCCESS).
 */
NTSTATUS callouts_udpNotify(
    IN  FWPS_CALLOUT_NOTIFY_TYPE         notifyType,
    IN  const GUID*             filterKey,
    IN  const FWPS_FILTER*     filter)
{
    UNREFERENCED_PARAMETER(notifyType);
    UNREFERENCED_PARAMETER(filterKey);
    UNREFERENCED_PARAMETER(filter);

    if (notifyType == FWPS_CALLOUT_NOTIFY_ADD_FILTER)
    {
        KdPrint((DPREFIX"Filter Added to UDP layer.\n"));
    }
    else if (notifyType == FWPS_CALLOUT_NOTIFY_DELETE_FILTER)
    {
        KdPrint((DPREFIX"Filter Deleted from UDP layer.\n"));
    }

    return STATUS_SUCCESS;
}

/*
 * callouts_addAleConnectFilter - add the ALE connect callout and its filter.
 *
 * Adds a management callout with key *calloutKey on *layer, then a filter on
 * the same layer in subLayer that invokes it (FWP_ACTION_CALLOUT_TERMINATING,
 * automatic weight, no conditions, so every connection reaches the callout).
 *
 * Returns the status of the first failing Fwpm* call, or of FwpmFilterAdd.
 */
NTSTATUS callouts_addAleConnectFilter(const GUID * calloutKey, const GUID * layer, FWPM_SUBLAYER * subLayer)
{
    FWPM_CALLOUT callout;
    FWPM_FILTER filter;
    NTSTATUS status;

    // Management-side callout object.
    RtlZeroMemory(&callout, sizeof(callout));
    callout.calloutKey = *calloutKey;
    callout.displayData.name = L"ALE Connect Callout";
    callout.displayData.description = L"Ethan ALE Connect Callout";
    callout.applicableLayer = *layer;
    callout.flags = 0;

    status = FwpmCalloutAdd(g_engineHandle, &callout, NULL, NULL);
    if (!NT_SUCCESS(status))
    {
        KdPrint((DPREFIX"FwpmCalloutAdd for ALE failed, status=%x\n", status));
        return status;
    }

    // Unconditional filter routing the layer's traffic to the callout.
    RtlZeroMemory(&filter, sizeof(filter));
    filter.layerKey = *layer;
    filter.displayData.name = L"ALE Connect Filter";
    filter.displayData.description = L"ALE Connect Filter";
    filter.action.type = FWP_ACTION_CALLOUT_TERMINATING;
    filter.action.calloutKey = *calloutKey;
    filter.subLayerKey = subLayer->subLayerKey;
    filter.weight.type = FWP_EMPTY;   // let the engine assign the weight
    filter.filterCondition = NULL;
    filter.numFilterConditions = 0;

    status = FwpmFilterAdd(g_engineHandle, &filter, NULL, NULL);
    if (!NT_SUCCESS(status))
    {
        KdPrint((DPREFIX"FwpmFilterAdd for ALE failed, status=%x\n", status));
    }

    return status;
}

/*
 * callouts_aleConnectCallout - classifyFn for FWPM_LAYER_ALE_AUTH_CONNECT_V4.
 *
 * Permits the connection unless its remote IPv4 address is in g_BlackIps
 * (filled by callouts_udpCallout from DNS answers). On a match it sets
 * FWP_ACTION_BLOCK, clears FWPS_RIGHT_ACTION_WRITE and sets the ABSORB flag.
 * Does nothing when the callout lacks the write right.
 */
VOID callouts_aleConnectCallout(
    IN const FWPS_INCOMING_VALUES* inFixedValues,
    IN const FWPS_INCOMING_METADATA_VALUES* inMetaValues,
    IN VOID* layerData,
    IN const void* classifyContext,
    IN const FWPS_FILTER* filter,
    IN UINT64 flowContext,
    OUT FWPS_CLASSIFY_OUT* classifyOut)
{
    UNREFERENCED_PARAMETER(layerData);
    UNREFERENCED_PARAMETER(classifyContext);
    UNREFERENCED_PARAMETER(filter);
    UNREFERENCED_PARAMETER(flowContext);

    if (classifyOut == NULL || inFixedValues == NULL) {
        return;
    }

    // A callout may only set the action when it holds the write right.
    if ((classifyOut->rights & FWPS_RIGHT_ACTION_WRITE) == 0) {
        return;
    }

    classifyOut->actionType = FWP_ACTION_PERMIT;

    // The field indices below belong to the IPv4 ALE connect layer.
    // That layer classifies locally initiated connections and has no DIRECTION
    // field, so no direction filtering is done here.
    if (inFixedValues->layerId != FWPS_LAYER_ALE_AUTH_CONNECT_V4) {
        return;
    }

    // Remote address as delivered by WFP (printed below most-significant octet first).
    UINT32 remoteIp = inFixedValues->incomingValue[FWPS_FIELD_ALE_AUTH_CONNECT_V4_IP_REMOTE_ADDRESS].value.uint32;
    UINT16 remotePort = inFixedValues->incomingValue[FWPS_FIELD_ALE_AUTH_CONNECT_V4_IP_REMOTE_PORT].value.uint16;
    // IP_PROTOCOL is an 8-bit value; read the matching union member.
    UINT8 protocol = inFixedValues->incomingValue[FWPS_FIELD_ALE_AUTH_CONNECT_V4_IP_PROTOCOL].value.uint8;

    KdPrint((DPREFIX"ALE Connect: Remote IP: %d.%d.%d.%d, Port: %d, Protocol: %d\n",
        (remoteIp >> 24) & 0xFF, (remoteIp >> 16) & 0xFF, (remoteIp >> 8) & 0xFF, remoteIp & 0xFF,
        remotePort, protocol));

    // g_BlackIps holds the A-record bytes verbatim (network order), so convert
    // the host-order remote address before comparing. Zero slots are unused.
    const UINT32 remoteIpNet = htonl(remoteIp);
    for (UINT32 i = 0; i < ARRAYSIZE(g_BlackIps); i++)
    {
        if (g_BlackIps[i] != 0 && g_BlackIps[i] == remoteIpNet)
        {
            classifyOut->actionType = FWP_ACTION_BLOCK;
            classifyOut->rights &= ~FWPS_RIGHT_ACTION_WRITE;
            classifyOut->flags |= FWPS_CLASSIFY_OUT_FLAG_ABSORB;

            UINT8 ip1 = (remoteIp >> 24) & 0xFF;
            UINT8 ip2 = (remoteIp >> 16) & 0xFF;
            UINT8 ip3 = (remoteIp >> 8) & 0xFF;
            UINT8 ip4 = remoteIp & 0xFF;
            KdPrint((DPREFIX"ALE Connect: Blocking Baidu IP: %d.%d.%d.%d\n", ip1, ip2, ip3, ip4));
            return;
        }
    }

    // Placeholder for process-based decisions; processPath is only valid when
    // the metadata reports it as present.
    if (inMetaValues != NULL &&
        FWPS_IS_METADATA_FIELD_PRESENT(inMetaValues, FWPS_METADATA_FIELD_PROCESS_PATH) &&
        inMetaValues->processPath != NULL) {
        WCHAR processName[256] = { 0 };
        // Leave room for the terminating NUL supplied by the zero-initialization.
        UINT32 copyLength = UTIL_MIN(inMetaValues->processPath->size, sizeof(processName) - sizeof(WCHAR));
        if (copyLength > 0) {
            RtlCopyMemory(processName, inMetaValues->processPath->data, copyLength);

            if (wcsstr(processName, L"chrome.exe") ||
                wcsstr(processName, L"firefox.exe") ||
                wcsstr(processName, L"iexplore.exe")) {
                // More elaborate per-site logic could be added here.
            }
        }
    }
}

/*
 * callouts_aleConnectNotify - notifyFn for the ALE connect callout.
 *
 * Logs filter additions and deletions and always returns STATUS_SUCCESS.
 */
NTSTATUS callouts_aleConnectNotify(
    IN  FWPS_CALLOUT_NOTIFY_TYPE     notifyType,
    IN  const GUID*                 filterKey,
    IN  const FWPS_FILTER*           filter)
{
    UNREFERENCED_PARAMETER(filterKey);
    UNREFERENCED_PARAMETER(filter);

    if (notifyType == FWPS_CALLOUT_NOTIFY_ADD_FILTER)
    {
        KdPrint((DPREFIX"ALE Connect Filter Added.\n"));
    }
    else if (notifyType == FWPS_CALLOUT_NOTIFY_DELETE_FILTER)
    {
        KdPrint((DPREFIX"ALE Connect Filter Deleted.\n"));
    }

    return STATUS_SUCCESS;
}

/*
 * Unregister every callout key in g_calloutGuids[0..CG_MAX) from the filter engine.
 *
 * A key that was never registered (STATUS_FWP_CALLOUT_NOT_FOUND) is expected
 * during teardown and is not reported. Any other failure is logged and the
 * loop continues with the next key (best-effort cleanup).
 */
void callouts_unregisterCallouts()
{
    int idx = 0;

    while (idx < CG_MAX)
    {
        NTSTATUS unregStatus = FwpsCalloutUnregisterByKey(&g_calloutGuids[idx]);

        if (!NT_SUCCESS(unregStatus) && unregStatus != STATUS_FWP_CALLOUT_NOT_FOUND)
        {
            KdPrint((DPREFIX"Failed to unregister callout %d, status=%x\n", idx, unregStatus));
        }

        ++idx;
    }
}

/*
 * Register a single run-time (FWPS) callout with the filter engine.
 *
 * deviceObject       - device object passed to FwpsCalloutRegister.
 * classifyFunction   - classifyFn for the callout.
 * notifyFunction     - notifyFn for the callout.
 * flowDeleteFunction - flowDeleteFn for the callout (may be NULL).
 * calloutKey         - GUID identifying the callout.
 * flags              - FWPS callout flags.
 * calloutId          - receives the run-time callout id from FwpsCalloutRegister.
 *
 * Returns STATUS_SUCCESS when registration succeeds, and also when the engine
 * reports STATUS_FWP_ALREADY_EXISTS. In that case *calloutId is whatever
 * FwpsCalloutRegister left there; the caller should not rely on it.
 * Otherwise the failure status is logged and returned.
 */
NTSTATUS callouts_registerCallout(
    void* deviceObject,
    FWPS_CALLOUT_CLASSIFY_FN     classifyFunction,
    FWPS_CALLOUT_NOTIFY_FN       notifyFunction,
    FWPS_CALLOUT_FLOW_DELETE_NOTIFY_FN   flowDeleteFunction,
    GUID const* calloutKey,
    UINT32      flags,
    UINT32* calloutId)
{
    FWPS_CALLOUT desc;
    NTSTATUS rc;

    memset(&desc, 0, sizeof desc);
    desc.calloutKey   = *calloutKey;
    desc.flags        = flags;
    desc.classifyFn   = classifyFunction;
    desc.notifyFn     = notifyFunction;
    desc.flowDeleteFn = flowDeleteFunction;

    rc = FwpsCalloutRegister(deviceObject, &desc, calloutId);

    if (rc == STATUS_FWP_ALREADY_EXISTS)
    {
        KdPrint((DPREFIX"Callout already exists, treating as success\n"));
        return STATUS_SUCCESS;
    }

    if (!NT_SUCCESS(rc))
    {
        KdPrint((DPREFIX"Failed to register callout, status=%x\n", rc));
    }

    return rc;
}

NTSTATUS callouts_registerCallouts(void* deviceObject)
{
    NTSTATUS status = STATUS_SUCCESS;
    int i;

    status = FwpmTransactionBegin(g_engineHandle, 0);
    if (!NT_SUCCESS(status))
    {
        KdPrint((DPREFIX"FwpmTransactionBegin failed, status=%x\n", status));
        FwpmEngineClose(g_engineHandle);
        g_engineHandle = NULL;
        return status;
    }

    for (;;)
    {
        for (i = 0; i < sizeof(g_callouts) / sizeof(g_callouts[0]); i++)
        {
            status = callouts_registerCallout(deviceObject,
                g_callouts[i].classifyFunction,
                g_callouts[i].notifyFunction,
                g_callouts[i].flowDeleteFunction,
                g_callouts[i].calloutKey,
                g_callouts[i].flags,
                g_callouts[i].calloutId);

            if (!NT_SUCCESS(status))
            {
                KdPrint((DPREFIX"Failed to register callout %d, status=%x\n", i, status));
                break;
            }
        }

        if (!NT_SUCCESS(status))
        {
            break;
        }

        status = FwpmTransactionCommit(g_engineHandle);

        if (!NT_SUCCESS(status))
        {
            KdPrint((DPREFIX"FwpmTransactionCommit failed, status=%x\n", status));
            break;
        }
        break;
    }

    if (!NT_SUCCESS(status))
    {
        KdPrint((DPREFIX"FwpmTransactionAbort due to failure\n"));
        FwpmTransactionAbort(g_engineHandle);
        FwpmEngineClose(g_engineHandle);
        g_engineHandle = NULL;
    }

    return status;
}

/*
 * Add a management (FWPM) callout plus one terminating filter that invokes it
 * for UDP traffic on the given layer.
 *
 * Despite its name, callouts_addFilters uses this helper for the
 * DATAGRAM_DATA V4/V6 layers.
 *
 * calloutKey - GUID of the callout, also used as the filter's action callout.
 * layer      - layer the callout applies to and the filter is placed in.
 * subLayer   - sublayer whose subLayerKey the filter is attached to.
 *
 * Returns the status of the first failing FWPM call, or the FwpmFilterAdd
 * status on success.
 */
NTSTATUS
callouts_addUdpFlowEstablishedFilter(const GUID * calloutKey, const GUID * layer, FWPM_SUBLAYER * subLayer)
{
    FWPM_CALLOUT callout;
    FWPM_FILTER filter;
    FWPM_FILTER_CONDITION udpOnly[1];
    NTSTATUS status;

    /* Management-side callout object bound to `layer`. */
    RtlZeroMemory(&callout, sizeof(callout));
    callout.calloutKey = *calloutKey;
    callout.applicableLayer = *layer;
    callout.displayData.name = FLOW_ESTABLISHED_CALLOUT_NAME;
    callout.displayData.description = FLOW_ESTABLISHED_CALLOUT_DESCRIPTION;
    callout.flags = 0;

    status = FwpmCalloutAdd(g_engineHandle, &callout, NULL, NULL);
    if (!NT_SUCCESS(status))
    {
        KdPrint((DPREFIX"FwpmCalloutAdd failed, status=%x\n", status));
        return status;
    }

    /* Single condition: IP protocol must equal UDP. */
    RtlZeroMemory(udpOnly, sizeof(udpOnly));
    udpOnly[0].fieldKey = FWPM_CONDITION_IP_PROTOCOL;
    udpOnly[0].matchType = FWP_MATCH_EQUAL;
    udpOnly[0].conditionValue.type = FWP_UINT8;
    udpOnly[0].conditionValue.uint8 = IPPROTO_UDP;

    /* Terminating filter that hands matching traffic to the callout. */
    RtlZeroMemory(&filter, sizeof(filter));
    filter.layerKey = *layer;
    filter.subLayerKey = subLayer->subLayerKey;
    filter.displayData.name = FLOW_ESTABLISHED_CALLOUT_NAME;
    filter.displayData.description = FLOW_ESTABLISHED_CALLOUT_NAME;
    filter.action.type = FWP_ACTION_CALLOUT_TERMINATING;
    filter.action.calloutKey = *calloutKey;
    filter.weight.type = FWP_EMPTY; /* let the engine assign the weight */
    filter.numFilterConditions = 1;
    filter.filterCondition = udpOnly;

    status = FwpmFilterAdd(g_engineHandle, &filter, NULL, NULL);
    if (!NT_SUCCESS(status))
    {
        KdPrint((DPREFIX"FwpmFilterAdd failed, status=%x\n", status));
    }

    return status;
}


/*
 * Add a management (FWPM) callout and an unconditional terminating filter
 * that invokes it on the given layer.
 *
 * The filter has no conditions. The original author's rationale: the TCP
 * stream layer only carries TCP stream data, so no protocol condition is
 * needed. This assumes `layer` is a stream layer; confirm at the call site.
 *
 * calloutKey - GUID of the callout, also used as the filter's action callout.
 * layer      - layer the callout applies to and the filter is placed in.
 * subLayer   - sublayer whose subLayerKey the filter is attached to.
 *
 * Returns the status of the first failing FWPM call, or the FwpmFilterAdd
 * status on success.
 */
NTSTATUS
callouts_addStreamFilter(const GUID * calloutKey, const GUID * layer, FWPM_SUBLAYER * subLayer)
{
    FWPM_CALLOUT streamCallout;
    FWPM_FILTER streamFilter;
    NTSTATUS status;

    RtlZeroMemory(&streamCallout, sizeof(streamCallout));
    streamCallout.calloutKey = *calloutKey;
    streamCallout.applicableLayer = *layer;
    streamCallout.displayData.name = STREAM_CALLOUT_NAME;
    streamCallout.displayData.description = STREAM_CALLOUT_DESCRIPTION;
    streamCallout.flags = 0;

    status = FwpmCalloutAdd(g_engineHandle, &streamCallout, NULL, NULL);
    if (!NT_SUCCESS(status))
    {
        KdPrint((DPREFIX"Stream FwpmCalloutAdd failed, status=%x\n", status));
        return status;
    }

    RtlZeroMemory(&streamFilter, sizeof(streamFilter));
    streamFilter.layerKey = *layer;
    streamFilter.subLayerKey = subLayer->subLayerKey;
    streamFilter.displayData.name = STREAM_CALLOUT_NAME;
    streamFilter.displayData.description = STREAM_CALLOUT_NAME;
    streamFilter.action.type = FWP_ACTION_CALLOUT_TERMINATING;
    streamFilter.action.calloutKey = *calloutKey;
    streamFilter.weight.type = FWP_EMPTY; /* let the engine assign the weight */
    streamFilter.numFilterConditions = 0;
    streamFilter.filterCondition = NULL;

    status = FwpmFilterAdd(g_engineHandle, &streamFilter, NULL, NULL);
    if (!NT_SUCCESS(status))
    {
        KdPrint((DPREFIX"Stream FwpmFilterAdd failed, status=%x\n", status));
    }

    return status;
}

/*
 * Create the driver's sublayer and all of its callouts/filters in a single
 * FWPM transaction on g_engineHandle.
 *
 * Adds, in order:
 * - sublayer g_sublayerGuid (weight 0x100);
 * - UDP callout/filter on FWPM_LAYER_DATAGRAM_DATA_V4;
 * - UDP callout/filter on FWPM_LAYER_DATAGRAM_DATA_V6;
 * - ALE connect callout/filter on FWPM_LAYER_ALE_AUTH_CONNECT_V4.
 *
 * If FwpmTransactionBegin fails, its status is returned directly. Any later
 * failure (including commit) aborts the transaction and returns that status.
 * The engine handle is left open in every case.
 */
NTSTATUS
callouts_addFilters()
{
    FWPM_SUBLAYER subLayer;
    NTSTATUS status;

    status = FwpmTransactionBegin(g_engineHandle, 0);
    if (!NT_SUCCESS(status))
    {
        KdPrint((DPREFIX"FwpmTransactionBegin for filters failed, status=%x\n", status));
        return status;
    }

    RtlZeroMemory(&subLayer, sizeof(subLayer));
    subLayer.subLayerKey = g_sublayerGuid;
    subLayer.displayData.name = SUBLAYER_NAME;
    subLayer.displayData.description = SUBLAYER_NAME;
    subLayer.flags = 0;
    subLayer.weight = 0x100;

    status = FwpmSubLayerAdd(g_engineHandle, &subLayer, NULL);
    if (!NT_SUCCESS(status))
    {
        KdPrint((DPREFIX"FwpmSubLayerAdd failed, status=%x\n", status));
        goto rollback;
    }

    status = callouts_addUdpFlowEstablishedFilter(
        &g_calloutGuids[CG_DATAGRAM_DATA_CALLOUT_V4],
        &FWPM_LAYER_DATAGRAM_DATA_V4,
        &subLayer);
    if (!NT_SUCCESS(status))
    {
        KdPrint((DPREFIX"Failed to add UDP v4 filter, status=%x\n", status));
        goto rollback;
    }

    status = callouts_addUdpFlowEstablishedFilter(
        &g_calloutGuids[CG_DATAGRAM_DATA_CALLOUT_V6],
        &FWPM_LAYER_DATAGRAM_DATA_V6,
        &subLayer);
    if (!NT_SUCCESS(status))
    {
        KdPrint((DPREFIX"Failed to add UDP v6 filter, status=%x\n", status));
        goto rollback;
    }

    /* ALE connect filter: shares the same sublayer as the UDP filters. */
    status = callouts_addAleConnectFilter(
        &g_calloutGuids[CG_ALE_AUTH_CONNECT_V4],
        &FWPM_LAYER_ALE_AUTH_CONNECT_V4,
        &subLayer);
    if (!NT_SUCCESS(status))
    {
        KdPrint((DPREFIX"Failed to add ALE connect filter, status=%x\n", status));
        goto rollback;
    }

    status = FwpmTransactionCommit(g_engineHandle);
    if (NT_SUCCESS(status))
    {
        return status;
    }
    KdPrint((DPREFIX"FwpmTransactionCommit for filters failed, status=%x\n", status));

rollback:
    KdPrint((DPREFIX"FwpmTransactionAbort for filters due to failure\n"));
    FwpmTransactionAbort(g_engineHandle);
    return status;
}

/*
 * One-time initialization of the WFP callouts.
 *
 * Steps:
 * 1. Generate fresh GUIDs for the provider, the sublayers and every callout.
 * 2. Open a dynamic FWPM session, stored in g_engineHandle.
 * 3. Add the provider.
 * 4. Register the run-time callouts.
 * 5. Add the sublayer, callouts and filters.
 *
 * Returns STATUS_SUCCESS immediately if already initialized. On failure after
 * the engine is open, callouts_free() is called to tear down whatever was
 * created, and the status of the failing step is returned.
 */
NTSTATUS  callouts_init(PDEVICE_OBJECT deviceObject)
{
    NTSTATUS status = STATUS_SUCCESS;
    FWPM_SESSION session = { 0 };
    FWPM_PROVIDER provider;
    int i;

    if (g_initialized)
    {
        KdPrint((DPREFIX"Callouts already initialized\n"));
        return STATUS_SUCCESS;
    }

    KdPrint((DPREFIX"Initializing callouts\n"));

    /*
     * ExUuidCreate can fail (e.g. STATUS_RETRY early in boot). Continuing
     * would leave an all-zero or duplicate GUID, so check every call.
     * RPC_NT_UUID_LOCAL_ONLY is an informational status, so NT_SUCCESS
     * accepts it.
     */
    status = ExUuidCreate(&g_providerGuid);
    if (NT_SUCCESS(status))
    {
        status = ExUuidCreate(&g_sublayerGuid);
    }
    if (NT_SUCCESS(status))
    {
        status = ExUuidCreate(&g_recvSublayerGuid);
    }
    if (NT_SUCCESS(status))
    {
        status = ExUuidCreate(&g_recvProtSublayerGuid);
    }
    for (i = 0; NT_SUCCESS(status) && i < CG_MAX; i++)
    {
        status = ExUuidCreate(&g_calloutGuids[i]);
    }
    if (!NT_SUCCESS(status))
    {
        KdPrint((DPREFIX"ExUuidCreate failed, status=%x\n", status));
        return status;
    }

    /* Dynamic session: objects added through this handle go away when it is closed. */
    session.flags = FWPM_SESSION_FLAG_DYNAMIC;

    status = FwpmEngineOpen(
        NULL,
        RPC_C_AUTHN_WINNT,
        NULL,
        &session,
        &g_engineHandle
    );
    if (!NT_SUCCESS(status))
    {
        KdPrint((DPREFIX"FwpmEngineOpen failed, status=%x\n", status));
        return status;
    }

    RtlZeroMemory(&provider, sizeof(provider));
    provider.displayData.description = PROVIDER_NAME;
    provider.displayData.name = PROVIDER_NAME;
    provider.providerKey = g_providerGuid;

    /*
     * In kernel mode FwpmProviderAdd returns an NTSTATUS. Check it with
     * NT_SUCCESS and propagate the real code instead of collapsing it into
     * STATUS_UNSUCCESSFUL.
     */
    status = FwpmProviderAdd(g_engineHandle, &provider, NULL);
    if (!NT_SUCCESS(status))
    {
        KdPrint((DPREFIX"FwpmProviderAdd failed, status=%x\n", status));
        goto done;
    }

    status = callouts_registerCallouts(deviceObject);
    if (!NT_SUCCESS(status))
    {
        KdPrint((DPREFIX"callouts_registerCallouts failed, status=%x\n", status));
        goto done;
    }

    status = callouts_addFilters();
    if (!NT_SUCCESS(status))
    {
        KdPrint((DPREFIX"callouts_addFilters failed, status=%x\n", status));
    }

done:
    /*
     * Set even on failure: callouts_free() returns early when g_initialized
     * is FALSE, and it must run to undo the partial initialization.
     */
    g_initialized = TRUE;

    if (!NT_SUCCESS(status))
    {
        KdPrint((DPREFIX"Initialization failed, cleaning up\n"));
        callouts_free();
    } else {
        KdPrint((DPREFIX"Callouts initialization completed successfully\n"));
    }

    return status;
}

/*
 * Tear down everything created by callouts_init(). Safe to call when not
 * initialized (no-op) and when g_engineHandle was already closed and NULLed
 * by a failed callouts_registerCallouts().
 *
 * Order matters:
 * 1. Delete the management objects and close the dynamic-session engine
 *    handle. Closing it removes the filters that reference our callouts.
 * 2. Unregister the run-time callouts once no filter references them any more.
 */
void callouts_free()
{
    KdPrint((DPREFIX"callouts_free\n"));

    if (!g_initialized)
    {
        KdPrint((DPREFIX"Callouts not initialized, nothing to free\n"));
        return;
    }

    g_initialized = FALSE;

    if (g_engineHandle)
    {
        /*
         * Best effort: these deletes may fail while filters still reference
         * the sublayers. Closing the dynamic session below removes all
         * session-owned objects regardless.
         */
        FwpmSubLayerDeleteByKey(g_engineHandle, &g_sublayerGuid);
        FwpmSubLayerDeleteByKey(g_engineHandle, &g_recvSublayerGuid);
        FwpmSubLayerDeleteByKey(g_engineHandle, &g_recvProtSublayerGuid);
        FwpmSubLayerDeleteByKey(g_engineHandle, &g_ipSublayerGuid);

        /* g_providerGuid is a provider key (see FwpmProviderAdd), not a provider-context key. */
        FwpmProviderDeleteByKey(g_engineHandle, &g_providerGuid);

        FwpmEngineClose(g_engineHandle);
        g_engineHandle = NULL;
    }

    /* Filters referencing the callouts are gone now, so unregistering is safe. */
    callouts_unregisterCallouts();
}


网站公告

今日签到

点亮在社区的每一天
去签到