intel-xpu-backend-for-triton: Calling Unified Runtime Directly, Bypassing PyTorch

Published: 2025-04-05

Background

  • Normally, a Triton kernel depends on PyTorch
  • Could we instead allocate device memory directly through the Unified Runtime (UR) API and hand it to a Triton kernel?

Method

  • Wrap the Unified Runtime Mem, Queue, and related APIs as Python interfaces (via ctypes in the demo)
  • Pass the 64-bit device address allocated by UR straight into the Triton kernel
  • Inside the Triton kernel, cast that address to a concretely typed pointer (e.g. float32); see the sketch after this list
  • Open issue: Triton kernel launches are asynchronous, so a urDeviceSynchronize entry point would need to be added to UR; the demo below works around this with time.sleep
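
The crux is the address-to-pointer cast inside the kernel. A minimal sketch of the idiom (names here are illustrative; the full demo below does the same thing):

import triton
import triton.language as tl

@triton.jit
def copy_kernel(src_addr, dst_addr, n, BLOCK: tl.constexpr):
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    # Reinterpret the raw 64-bit integer arguments as typed device pointers
    src = src_addr.to(tl.pointer_type(tl.float32))
    dst = dst_addr.to(tl.pointer_type(tl.float32))
    tl.store(dst + offs, tl.load(src + offs, mask=mask, other=0.0), mask=mask)

# src_addr / dst_addr would be the integer values returned by urMemGetNativeHandle,
# e.g. copy_kernel[(triton.cdiv(n, 1024),)](src_addr, dst_addr, n, BLOCK=1024)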

Steps

Install intel-xpu-backend-for-triton
docker stop triton_xpu_ipcx
docker rm triton_xpu_ipcx
docker run --shm-size=32g -it --privileged --net=host \
	-v $PWD:/home -w /home \
	--name triton_xpu_ipcx intel/deep-learning-essentials:latest /bin/bash
docker start triton_xpu_ipcx
docker exec -ti triton_xpu_ipcx bash
git clone https://github.com/intel/intel-xpu-backend-for-triton.git
cd intel-xpu-backend-for-triton/
git checkout 197f5f843fd17deab1de1df1c6f17e60978ecdce
source /opt/intel/oneapi/setvars.sh  --force
export PATH=$PATH:/opt/intel/oneapi/compiler/2025.0/bin/compiler
apt install intel-ocloc -y
scripts/install-pytorch.sh --source --force-reinstall
export MAX_JOBS=16
scripts/compile-triton.sh --llvm --triton
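# Optional sanity check (not part of the original steps): confirm the freshly built Triton imports cleanly
python -c "import triton; print(triton.__version__)"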
Run the test case
cat > xpu_triton_ur.py <<-'EOF'
# Import necessary libraries
import ctypes
from ctypes import *
import enum
import sys
import numpy as np

# Import Triton for GPU kernel execution
import triton
import triton.language as tl

# Define a Triton kernel for scaling operation
@triton.jit
def triton_scale_kernel(
    input_ptr_addr,    # Pointer to input data in device memory
    output_ptr_addr,   # Pointer to output data in device memory
    scale,             # Scaling factor to apply
    n_elements,        # Total number of elements to process
    BLOCK_SIZE: tl.constexpr,  # Number of elements processed by each thread block
):
    # Get the program ID (determines which block this thread is in)
    pid = tl.program_id(axis=0)
    # Calculate the starting index for this block
    block_start = pid * BLOCK_SIZE
    # Calculate the offsets for all elements this block will process
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    # Create a mask to avoid out-of-bounds memory access
    mask = offsets < n_elements

    # Convert raw pointers to proper float32 pointers
    input_ptr = input_ptr_addr.to(tl.pointer_type(tl.float32))
    output_ptr = output_ptr_addr.to(tl.pointer_type(tl.float32))

    # Load input data from device memory
    input_data = tl.load(input_ptr + offsets, mask=mask, other=0)
    # Perform the scaling operation
    output_data = input_data * scale
    # Store the results back to device memory
    tl.store(output_ptr + offsets, output_data, mask=mask)

# Python wrapper function for the Triton kernel
def triton_scale(_input, _output, scale, n_elements):
    # Define the grid size (number of thread blocks)
    grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
    # Launch the kernel with specified grid and block size
    triton_scale_kernel[grid](_input, _output, scale, n_elements, BLOCK_SIZE=1024)
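
# Note: with n_elements = 32 (set in main below) and BLOCK_SIZE = 1024, the grid
# collapses to a single block; triton.cdiv rounds up, so larger arrays simply
# launch more blocks.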

# Try loading the UR (Unified Runtime) loader library
try:
    lib = cdll.LoadLibrary("libur_loader.so")
except OSError:
    lib = cdll.LoadLibrary("ur_loader.so")

# Custom exception for UR API errors
class URException(Exception):
    pass

# Function to check UR API call results
def _check_result(result):
    success_codes = [0]  # Assuming UR_RESULT_SUCCESS is 0
    allowed_errors = []  # Add specific error codes if needed
    if result not in success_codes + allowed_errors:
        raise URException(f"UR API call failed with error code: {result}")

# Enum definitions for UR API
class ur_device_init_flag_t(enum.IntEnum):
    UR_DEVICE_INIT_FLAG_GPU = 0x1
    UR_DEVICE_INIT_FLAG_CPU = 0x2
    UR_DEVICE_INIT_FLAG_FPGA = 0x4
    UR_DEVICE_INIT_FLAG_MCA = 0x8
    UR_DEVICE_INIT_FLAG_VPU = 0x10

ur_device_init_flags_t = c_uint32

class ur_device_type_t(enum.IntEnum):
    UR_DEVICE_TYPE_GPU = 3
    UR_DEVICE_TYPE_CPU = 4
    # ... add other device types as needed

class ur_mem_flag_t(enum.IntEnum):
    UR_MEM_FLAG_READ_WRITE = 0x1
    UR_MEM_FLAG_WRITE_ONLY = 0x2
    UR_MEM_FLAG_READ_ONLY = 0x4
    # ... add other flags as needed

ur_mem_flags_t = c_uint32

# Structure definitions for UR API
class ur_context_properties_t(Structure):
    _fields_ = [
        ("stype", c_uint32),    # Structure type identifier
        ("pNext", c_void_p),    # Pointer to extension structure
        ("flags", c_uint32),    # Context creation flags
    ]

class ur_buffer_properties_t(Structure):
    _fields_ = [
        ("stype", c_uint32),    # Structure type identifier
        ("pNext", c_void_p),    # Pointer to extension structure
        ("pHost", c_void_p),    # Optional host pointer
    ]

# Handle types for UR objects
ur_adapter_handle_t = c_void_p
ur_platform_handle_t = c_void_p
ur_device_handle_t = c_void_p
ur_context_handle_t = c_void_p
ur_mem_handle_t = c_void_p
ur_queue_handle_t = c_void_p
ur_native_handle_t = c_void_p

# Function prototypes and Python wrappers for UR API

# urLoaderInit - Initialize the UR loader
lib.urLoaderInit.argtypes = [ur_device_init_flags_t, c_void_p]
lib.urLoaderInit.restype = c_int
def urLoaderInit(device_flags=0, hLoaderConfig=None):
    result = lib.urLoaderInit(device_flags, hLoaderConfig)
    _check_result(result)

# urAdapterGet - Get available adapters
lib.urAdapterGet.argtypes = [c_uint32, POINTER(ur_adapter_handle_t), POINTER(c_uint32)]
lib.urAdapterGet.restype = c_int
def urAdapterGet(NumEntries, phAdapters, pNumAdapters):
    result = lib.urAdapterGet(NumEntries, phAdapters, pNumAdapters)
    _check_result(result)

# urPlatformGet - Get platforms for given adapters
lib.urPlatformGet.argtypes = [POINTER(ur_adapter_handle_t), c_uint32, c_uint32,
                              POINTER(ur_platform_handle_t), POINTER(c_uint32)]
lib.urPlatformGet.restype = c_int
def urPlatformGet(phAdapters, NumAdapters, NumEntries, phPlatforms, pNumPlatforms):
    result = lib.urPlatformGet(phAdapters, NumAdapters, NumEntries, phPlatforms, pNumPlatforms)
    _check_result(result)

# urDeviceGet - Get devices for a platform
lib.urDeviceGet.argtypes = [ur_platform_handle_t, c_int, c_uint32,
                            POINTER(ur_device_handle_t), POINTER(c_uint32)]
lib.urDeviceGet.restype = c_int
def urDeviceGet(hPlatform, DeviceType, NumEntries, phDevices, pNumDevices):
    result = lib.urDeviceGet(hPlatform, DeviceType, NumEntries, phDevices, pNumDevices)
    _check_result(result)

# urContextCreate - Create a context for devices
lib.urContextCreate.argtypes = [c_uint32, POINTER(ur_device_handle_t),
                                POINTER(ur_context_properties_t), POINTER(ur_context_handle_t)]
lib.urContextCreate.restype = c_int
def urContextCreate(DeviceCount, phDevices, pProperties, phContext):
    result = lib.urContextCreate(DeviceCount, phDevices, pProperties, phContext)
    _check_result(result)

# urMemBufferCreate - Create a device memory buffer
lib.urMemBufferCreate.argtypes = [ur_context_handle_t, ur_mem_flags_t, c_size_t,
                                  POINTER(ur_buffer_properties_t), POINTER(ur_mem_handle_t)]
lib.urMemBufferCreate.restype = c_int
def urMemBufferCreate(hContext, flags, size, pProperties, phBuffer):
    result = lib.urMemBufferCreate(hContext, flags, size, pProperties, phBuffer)
    _check_result(result)

# urMemRelease - Release a memory buffer
lib.urMemRelease.argtypes = [ur_mem_handle_t]
lib.urMemRelease.restype = c_int
def urMemRelease(hMem):
    result = lib.urMemRelease(hMem)
    _check_result(result)

# urQueueCreate - Create a command queue
lib.urQueueCreate.argtypes = [ur_context_handle_t, ur_device_handle_t, c_void_p,
                              POINTER(ur_queue_handle_t)]
lib.urQueueCreate.restype = c_int
def urQueueCreate(hContext, hDevice, pProperties, phQueue):
    result = lib.urQueueCreate(hContext, hDevice, pProperties, phQueue)
    _check_result(result)

# urQueueFinish - Wait for all commands in queue to complete
lib.urQueueFinish.argtypes = [ur_queue_handle_t]
lib.urQueueFinish.restype = c_int
def urQueueFinish(hQueue):
    result = lib.urQueueFinish(hQueue)
    _check_result(result)

# urEnqueueMemBufferWrite - Write data to device memory
lib.urEnqueueMemBufferWrite.argtypes = [ur_queue_handle_t, ur_mem_handle_t, c_bool,
                                        c_size_t, c_size_t, c_void_p, c_uint32,
                                        c_void_p, c_void_p]
lib.urEnqueueMemBufferWrite.restype = c_int
def urEnqueueMemBufferWrite(hQueue, hBuffer, blocking, offset, size, pSrc,
                            numEventsInWaitList, phEventWaitList, phEvent):
    result = lib.urEnqueueMemBufferWrite(hQueue, hBuffer, blocking, offset, size, pSrc,
                                         numEventsInWaitList, phEventWaitList, phEvent)
    _check_result(result)

# urEnqueueMemBufferRead - Read data from device memory
lib.urEnqueueMemBufferRead.argtypes = [ur_queue_handle_t, ur_mem_handle_t, c_bool,
                                       c_size_t, c_size_t, c_void_p, c_uint32,
                                       c_void_p, c_void_p]
lib.urEnqueueMemBufferRead.restype = c_int
def urEnqueueMemBufferRead(hQueue, hBuffer, blocking, offset, size, pDst,
                           numEventsInWaitList, phEventWaitList, phEvent):
    result = lib.urEnqueueMemBufferRead(hQueue, hBuffer, blocking, offset, size, pDst,
                                        numEventsInWaitList, phEventWaitList, phEvent)
    _check_result(result)

# urMemGetNativeHandle - Get native handle for memory object
lib.urMemGetNativeHandle.argtypes = [ur_mem_handle_t, ur_device_handle_t, POINTER(ur_native_handle_t)]
lib.urMemGetNativeHandle.restype = c_int
def urMemGetNativeHandle(hMem, hDevice, phNativeMem):
    result = lib.urMemGetNativeHandle(hMem, hDevice, phNativeMem)
    _check_result(result)
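
# Note: the native handle retrieved here is the raw 64-bit device address; main()
# below passes its integer value (e.g. dA_ptr.value) straight to the Triton kernel.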

# Resource release functions

# urContextRelease - Release a context
lib.urContextRelease.argtypes = [ur_context_handle_t]
lib.urContextRelease.restype = c_int
def urContextRelease(hContext):
    result = lib.urContextRelease(hContext)
    _check_result(result)

# urAdapterRelease - Release an adapter
lib.urAdapterRelease.argtypes = [ur_adapter_handle_t]
lib.urAdapterRelease.restype = c_int
def urAdapterRelease(hAdapter):
    result = lib.urAdapterRelease(hAdapter)
    _check_result(result)

# urLoaderTearDown - Shutdown the UR loader
lib.urLoaderTearDown.argtypes = []
lib.urLoaderTearDown.restype = c_int
def urLoaderTearDown():
    result = lib.urLoaderTearDown()
    _check_result(result)

def main():
    try:
        # Initialize the UR loader
        urLoaderInit()

        # Get available adapters
        adapter_count = ctypes.c_uint32()
        # First call gets the count
        urAdapterGet(0, None, ctypes.byref(adapter_count))
        # Allocate array for adapters
        adapters = (ur_adapter_handle_t * adapter_count.value)()
        # Second call gets the adapters
        urAdapterGet(adapter_count.value, adapters, None)

        # Get platforms for the adapters
        platform_count = ctypes.c_uint32()
        # First call gets the count
        urPlatformGet(adapters, adapter_count.value, 0, None, ctypes.byref(platform_count))
        # Allocate array for platforms
        platforms = (ur_platform_handle_t * platform_count.value)()
        # Second call gets the platforms
        urPlatformGet(adapters, adapter_count.value, platform_count.value, platforms, None)

        # Process each platform
        for platform in platforms:
            # Get GPU devices for this platform
            device_count = ctypes.c_uint32()
            # First call gets the count
            urDeviceGet(platform, ur_device_type_t.UR_DEVICE_TYPE_GPU.value, 0, None, ctypes.byref(device_count))
            # Allocate array for devices
            devices = (ur_device_handle_t * device_count.value)()
            # Second call gets the devices
            urDeviceGet(platform, ur_device_type_t.UR_DEVICE_TYPE_GPU.value, device_count.value, devices, None)

            # Process each device (just the first one in this example)
            for i in range(device_count.value):
                device = devices[i]
                # Create array with single device for context creation
                device_array = (ur_device_handle_t * 1)(device)

                # Create context for this device
                hContext = ur_context_handle_t()
                urContextCreate(1, device_array, None, ctypes.byref(hContext))

                # Create input and output buffers in device memory
                n_elements = 32  # Number of elements in our test array

                # Create input buffer
                dA = ur_mem_handle_t()
                urMemBufferCreate(hContext, ur_mem_flag_t.UR_MEM_FLAG_READ_WRITE.value,
                                  n_elements * ctypes.sizeof(ctypes.c_float), None, ctypes.byref(dA))

                # Get native handle for Triton to use
                dA_ptr = ur_native_handle_t()
                urMemGetNativeHandle(dA, device, byref(dA_ptr))

                # Create output buffer
                dB = ur_mem_handle_t()
                urMemBufferCreate(hContext, ur_mem_flag_t.UR_MEM_FLAG_READ_WRITE.value,
                                  n_elements * ctypes.sizeof(ctypes.c_float), None, ctypes.byref(dB))
                dB_ptr = ur_native_handle_t()
                urMemGetNativeHandle(dB, device, byref(dB_ptr))

                # Create command queue for this device
                queue = ur_queue_handle_t()
                urQueueCreate(hContext, device, None, ctypes.byref(queue))

                # Prepare host memory
                host_A = np.ones(n_elements, dtype=np.float32)*1.2  # Input array
                host_B = np.empty(n_elements, dtype=np.float32)     # Output array

                # Copy input data to device
                src_ptr = host_A.ctypes.data_as(ctypes.c_void_p)
                urEnqueueMemBufferWrite(queue, dA, True, 0, n_elements * ctypes.sizeof(ctypes.c_float),
                                         src_ptr, 0, None, None)

                # Execute the Triton kernel to scale the data
                scale = 10.0  # Scaling factor
                triton_scale(dA_ptr.value, dB_ptr.value, scale, n_elements)

                # TODO: Replace with proper synchronization
                # Currently using sleep as a temporary solution
                import time
                time.sleep(2)
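                # (Why not urQueueFinish here? The Triton XPU backend presumably
                # launches the kernel on its own internal queue rather than on the
                # UR queue created above, so finishing that queue would not wait
                # for the kernel to complete.)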

                # Read results back from device
                dst_ptr = host_B.ctypes.data_as(ctypes.c_void_p)
                urEnqueueMemBufferRead(queue, dB, True, 0, n_elements * ctypes.sizeof(ctypes.c_float),
                                        dst_ptr, 0, None, None)

                # Wait for all commands to complete
                urQueueFinish(queue)

                # Compute ground truth for verification
                gt = host_A * scale
                # Calculate mean squared error between actual and expected results
                mse = np.mean((gt - host_B) ** 2)
                print(f"MSE:{mse}")
                print(host_B)

                # Clean up resources
                urMemRelease(dA)
                urMemRelease(dB)
                urContextRelease(hContext)
                break  # Just process first device
            break  # Just process first platform

        # Release adapters
        for adapter in adapters:
            urAdapterRelease(adapter)

        # Shutdown UR loader
        urLoaderTearDown()

    except URException as e:
        print(f"Error: {e}")
        sys.exit(1)

if __name__ == "__main__":
    main()
EOF
python xpu_triton_ur.py
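
If everything works, host_B comes back as 32 copies of 12.0 (1.2 × 10.0), and the printed MSE is 0.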
