Files
mirage_slang/binding_manager.py
2025-06-10 17:21:36 +08:00

292 lines
12 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python3
"""
SDL3_GPU Slang Compiler - Binding Manager
资源绑定点分配和管理
"""
import re
import os
from typing import Dict
from shader_types import ShaderStage, ResourceType, TargetFormat, ShaderInfo
class BindingManager:
    """Assign SDL3 GPU resource slots to shader resources and emit them per backend.

    The ``assign_bindings_*`` methods mutate the resource objects in
    ``shader_info.resources`` in place: ``set``/``binding`` for SPIR-V,
    ``register``/``space`` for DXIL/DXBC and ``metal_index`` for MSL. The
    remaining methods read those fields back to build binding attributes,
    patch shader source text, or assemble a binding-info dict for runtime use.

    Within each slot range resources are numbered by type (sampled textures,
    then storage textures, then storage buffers; uniform buffers before
    storage buffers for Metal), following the order SDL3 expects. Samplers
    are numbered so that sampler ``i`` pairs with sampled texture ``i``.
    """

    @staticmethod
    def _select(resources, *types):
        """Return the resources whose type is in ``types``, grouped in ``types`` order.

        Declaration order is preserved within each type group.
        """
        return [resource for wanted in types for resource in resources if resource.type == wanted]

    @staticmethod
    def _stage_slots(stage):
        """Return ``(resource_slot, uniform_slot)`` for a vertex or fragment stage.

        The same numbers serve as SPIR-V descriptor sets and D3D register spaces.

        Raises:
            ValueError: if ``stage`` is neither vertex nor fragment.
        """
        if stage == ShaderStage.VERTEX:
            return 0, 1
        if stage == ShaderStage.FRAGMENT:
            return 2, 3
        raise ValueError(f"Unsupported shader stage: {stage}")

    def assign_bindings_spirv(self, shader_info: ShaderInfo):
        """Assign SPIR-V descriptor sets and bindings to ``shader_info.resources``.

        Vertex shaders use set 0 for sampled textures, storage textures and
        storage buffers (bindings numbered consecutively in that order) and
        set 1 for uniform buffers; fragment shaders use sets 2 and 3. Sampler
        ``i`` reuses the set/binding of sampled texture ``i``; samplers with no
        matching sampled texture are left untouched.

        Raises:
            ValueError: if the stage is neither vertex nor fragment.
        """
        resource_set, uniform_set = self._stage_slots(shader_info.stage)
        resources = shader_info.resources

        texture_slots = self._select(
            resources,
            ResourceType.SAMPLED_TEXTURE,
            ResourceType.STORAGE_TEXTURE,
            ResourceType.STORAGE_BUFFER,
        )
        for binding, resource in enumerate(texture_slots):
            resource.set = resource_set
            resource.binding = binding

        for binding, resource in enumerate(self._select(resources, ResourceType.UNIFORM_BUFFER)):
            resource.set = uniform_set
            resource.binding = binding

        # Pair samplers with sampled textures by position; zip stops at the
        # shorter list, so surplus samplers keep whatever values they had.
        sampled_textures = self._select(resources, ResourceType.SAMPLED_TEXTURE)
        samplers = self._select(resources, ResourceType.SAMPLER)
        for sampler, texture in zip(samplers, sampled_textures):
            sampler.set = texture.set
            sampler.binding = texture.binding

    def assign_bindings_dxil(self, shader_info: ShaderInfo):
        """Assign D3D registers and spaces (for DXIL/DXBC) to ``shader_info.resources``.

        ``t`` registers number sampled textures first, then storage textures,
        then storage buffers; ``s`` registers number samplers in the same
        space; ``b`` registers number uniform buffers in the uniform space.
        Vertex shaders use spaces 0/1, fragment shaders 2/3.

        Raises:
            ValueError: if the stage is neither vertex nor fragment.
        """
        resource_space, uniform_space = self._stage_slots(shader_info.stage)
        resources = shader_info.resources

        # Group by type instead of using declaration order: this matches the
        # SPIR-V/MSL layouts and puts sampled textures at t0..tN-1, so they
        # line up with samplers s0..sN-1.
        t_resources = self._select(
            resources,
            ResourceType.SAMPLED_TEXTURE,
            ResourceType.STORAGE_TEXTURE,
            ResourceType.STORAGE_BUFFER,
        )
        for index, resource in enumerate(t_resources):
            resource.register = f"t{index}"
            resource.space = resource_space

        for index, resource in enumerate(self._select(resources, ResourceType.SAMPLER)):
            resource.register = f"s{index}"
            resource.space = resource_space

        for index, resource in enumerate(self._select(resources, ResourceType.UNIFORM_BUFFER)):
            resource.register = f"b{index}"
            resource.space = uniform_space

    def assign_bindings_msl(self, shader_info: ShaderInfo):
        """Assign Metal texture/sampler/buffer indices (``metal_index``).

        Each of the three index ranges starts at 0: textures are numbered
        sampled first, then storage; samplers in declaration order; buffers
        uniform first, then storage.
        """
        resources = shader_info.resources
        index_ranges = (
            (ResourceType.SAMPLED_TEXTURE, ResourceType.STORAGE_TEXTURE),  # [[texture(n)]]
            (ResourceType.SAMPLER,),  # [[sampler(n)]]
            (ResourceType.UNIFORM_BUFFER, ResourceType.STORAGE_BUFFER),  # [[buffer(n)]]
        )
        for types in index_ranges:
            for index, resource in enumerate(self._select(resources, *types)):
                resource.metal_index = index

    def generate_binding_code(self, shader_info: ShaderInfo, target: TargetFormat) -> str:
        """Return one ``"<binding attribute> <resource name>"`` line per bound resource.

        Returns an empty string for targets other than SPIR-V, DXIL/DXBC and MSL.
        """
        if target == TargetFormat.SPIRV:
            return self._generate_spirv_bindings(shader_info)
        if target in (TargetFormat.DXIL, TargetFormat.DXBC):
            return self._generate_dx_bindings(shader_info)
        if target == TargetFormat.MSL:
            return self._generate_msl_bindings(shader_info)
        return ''

    @staticmethod
    def _binding_attribute(resource, target: TargetFormat):
        """Return the binding annotation text for ``resource`` on ``target``.

        Returns None when no binding has been assigned for that target, when
        (for MSL) the resource type has no Metal attribute, or when the target
        is not recognised.
        """
        if target == TargetFormat.SPIRV:
            if resource.set >= 0 and resource.binding >= 0:
                return f"[[vk::binding({resource.binding}, {resource.set})]]"
        elif target in (TargetFormat.DXIL, TargetFormat.DXBC):
            if resource.register and resource.space >= 0:
                return f"register({resource.register}, space{resource.space})"
        elif target == TargetFormat.MSL:
            if resource.metal_index >= 0:
                if resource.type in (ResourceType.SAMPLED_TEXTURE, ResourceType.STORAGE_TEXTURE):
                    return f"[[texture({resource.metal_index})]]"
                if resource.type == ResourceType.SAMPLER:
                    return f"[[sampler({resource.metal_index})]]"
                if resource.type in (ResourceType.UNIFORM_BUFFER, ResourceType.STORAGE_BUFFER):
                    return f"[[buffer({resource.metal_index})]]"
        return None

    @staticmethod
    def _format_binding_lines(shader_info: ShaderInfo, target: TargetFormat) -> str:
        """Join ``"<attribute> <name>"`` for every resource bound on ``target``."""
        lines = []
        for resource in shader_info.resources:
            attribute = BindingManager._binding_attribute(resource, target)
            if attribute is not None:
                lines.append(f"{attribute} {resource.name}")
        return "\n".join(lines)

    @staticmethod
    def _generate_spirv_bindings(shader_info: ShaderInfo) -> str:
        """Return ``[[vk::binding(b, s)]] name`` lines for resources with a set and binding."""
        return BindingManager._format_binding_lines(shader_info, TargetFormat.SPIRV)

    @staticmethod
    def _generate_dx_bindings(shader_info: ShaderInfo) -> str:
        """Return ``register(r, spaceN) name`` lines for resources with a register and space."""
        return BindingManager._format_binding_lines(shader_info, TargetFormat.DXIL)

    @staticmethod
    def _generate_msl_bindings(shader_info: ShaderInfo) -> str:
        """Return ``[[texture|sampler|buffer(n)]] name`` lines for resources with a Metal index."""
        return BindingManager._format_binding_lines(shader_info, TargetFormat.MSL)

    def inject_bindings(self, shader_info: ShaderInfo, target: TargetFormat) -> str:
        """Return ``shader_info.source_code`` with binding attributes inserted.

        Declarations are found line by line with regular expressions (a
        deliberately simple approach instead of parsing the source). For
        SPIR-V and MSL the attribute is inserted before the declaration; for
        DXIL/DXBC ``" : register(...)"`` is appended after the resource name.
        Every bound resource declared on a line is annotated, including
        several on the same line. Targets with no bindings return the source
        unchanged.
        """
        resource_bindings = {}
        for resource in shader_info.resources:
            attribute = self._binding_attribute(resource, target)
            if attribute is not None:
                resource_bindings[resource.name] = attribute

        append_after = target in (TargetFormat.DXIL, TargetFormat.DXBC)

        # Compile each resource's patterns once instead of once per source line.
        matchers = []
        for name, attribute in resource_bindings.items():
            escaped = re.escape(name)
            patterns = (
                # Any templated declaration: Texture2D<float4> tex, ConstantBuffer<T> cb,
                # (RW)StructuredBuffer<T> buf, ParameterBlock<T> block, ...
                # [^;] keeps a match from starting in an earlier statement on the same line.
                re.compile(rf'\b\w+<[^;]*?>\s+{escaped}\b'),
                re.compile(rf'\bSampler(?:Comparison)?State\s+{escaped}\b'),
            )
            matchers.append((patterns, attribute))

        patched_lines = []
        for current in shader_info.source_code.split('\n'):
            for patterns, attribute in matchers:
                for pattern in patterns:
                    if pattern.search(current):
                        # Substitute into the already-patched line so earlier
                        # resources on the same line keep their annotations.
                        current = pattern.sub(
                            lambda m: (f'{m.group(0)} : {attribute}' if append_after
                                       else f'{attribute} {m.group(0)}'),
                            current,
                        )
                        break
            patched_lines.append(current)
        return '\n'.join(patched_lines)

    def generate_binding_info(self, shader_info: ShaderInfo, target: TargetFormat, output_path: str) -> dict:
        """Build the binding-info dict consumed at runtime.

        Keys: ``stage`` and ``entry_point`` (from ``shader_info``),
        ``resources`` (one dict per resource with ``name``, ``type`` and the
        target-specific slot fields: ``set``/``binding`` for SPIR-V,
        ``register``/``space`` for DXIL/DXBC, ``index`` for MSL),
        ``num_resources`` (per-``ResourceType`` counts) and ``blob`` (the raw
        bytes of the compiled file at ``output_path``).

        Raises:
            FileNotFoundError: if ``output_path`` does not exist.
        """
        counts = {
            resource_type: 0
            for resource_type in (
                ResourceType.SAMPLED_TEXTURE,
                ResourceType.STORAGE_TEXTURE,
                ResourceType.STORAGE_BUFFER,
                ResourceType.UNIFORM_BUFFER,
                ResourceType.SAMPLER,
            )
        }
        binding_info = {
            'stage': shader_info.stage.value,
            'entry_point': shader_info.entry_point,
            'resources': [],
            'num_resources': counts,
        }

        # Open directly (EAFP) rather than checking existence first, which races.
        try:
            with open(output_path, 'rb') as blob_file:
                binding_info['blob'] = blob_file.read()
        except FileNotFoundError as err:
            raise FileNotFoundError(
                f"Output file {output_path} does not exist after compilation."
            ) from err

        for resource in shader_info.resources:
            entry = {
                'name': resource.name,
                'type': resource.type.value,
            }
            if target == TargetFormat.SPIRV:
                entry['set'] = resource.set
                entry['binding'] = resource.binding
            elif target in (TargetFormat.DXIL, TargetFormat.DXBC):
                entry['register'] = resource.register
                entry['space'] = resource.space
            elif target == TargetFormat.MSL:
                entry['index'] = resource.metal_index
            binding_info['resources'].append(entry)
            # .get() tolerates resource types beyond the five pre-seeded ones.
            counts[resource.type] = counts.get(resource.type, 0) + 1
        return binding_info