292 lines
12 KiB
Python
292 lines
12 KiB
Python
#!/usr/bin/env python3
|
||
"""
|
||
SDL3_GPU Slang Compiler - Binding Manager
|
||
资源绑定点分配和管理
|
||
"""
|
||
|
||
import re
|
||
import os
|
||
from typing import Dict
|
||
|
||
from shader_types import ShaderStage, ResourceType, TargetFormat, ShaderInfo
|
||
|
||
|
||
class BindingManager:
    """Assign SDL3_GPU resource binding slots to Slang shader resources.

    Each ``assign_bindings_*`` method mutates the resources in
    ``shader_info.resources`` in place, filling in the fields used by one
    target backend:

    * SPIR-V: ``set`` / ``binding``
    * DXIL/DXBC: ``register`` / ``space``
    * MSL: ``metal_index``

    The remaining methods turn those assignments into binding attribute text,
    inject that text into the shader source, or collect it into a dict for
    runtime use.
    """

    # ------------------------------------------------------------------
    # Shared helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _stage_sets(stage: ShaderStage) -> tuple[int, int]:
        """Return ``(resource_set, uniform_set)`` for a graphics shader stage.

        Vertex shaders use sets/spaces 0 and 1; fragment shaders use 2 and 3.

        Raises:
            ValueError: If ``stage`` is neither vertex nor fragment.
        """
        if stage == ShaderStage.VERTEX:
            return 0, 1
        if stage == ShaderStage.FRAGMENT:
            return 2, 3
        raise ValueError(f"Unsupported shader stage: {stage}")

    @staticmethod
    def _collect(resources, *types):
        """Return the resources whose type is in ``types``, grouped by ``types`` order.

        Within each group the declaration order of ``resources`` is preserved;
        that is the order in which binding slots are handed out.
        """
        return [res for rtype in types for res in resources if res.type == rtype]

    @staticmethod
    def _binding_attribute(resource, target: TargetFormat) -> str:
        """Return the binding attribute text for ``resource`` on ``target``.

        Returns an empty string when the resource has no binding for that
        target (a negative ``set``/``binding``/``space``/``metal_index``, or an
        empty ``register``), or when ``target`` is not handled.
        """
        if target == TargetFormat.SPIRV:
            if resource.set >= 0 and resource.binding >= 0:
                return f"[[vk::binding({resource.binding}, {resource.set})]]"
        elif target in (TargetFormat.DXIL, TargetFormat.DXBC):
            if resource.register and resource.space >= 0:
                return f"register({resource.register}, space{resource.space})"
        elif target == TargetFormat.MSL and resource.metal_index >= 0:
            if resource.type in (ResourceType.SAMPLED_TEXTURE, ResourceType.STORAGE_TEXTURE):
                return f"[[texture({resource.metal_index})]]"
            if resource.type == ResourceType.SAMPLER:
                return f"[[sampler({resource.metal_index})]]"
            if resource.type in (ResourceType.UNIFORM_BUFFER, ResourceType.STORAGE_BUFFER):
                return f"[[buffer({resource.metal_index})]]"
        return ''

    @staticmethod
    def _join_bindings(shader_info: ShaderInfo, target: TargetFormat) -> str:
        """Return one ``"<attribute> <name>"`` line per resource bound on ``target``."""
        lines = []
        for resource in shader_info.resources:
            attribute = BindingManager._binding_attribute(resource, target)
            if attribute:
                lines.append(f"{attribute} {resource.name}")
        return "\n".join(lines)

    @staticmethod
    def _declaration_patterns(resource_name: str) -> list:
        """Compile the declaration regexes used by ``inject_bindings`` for one resource.

        The first pattern covers any templated declaration (``Texture2D<float4>``,
        ``ConstantBuffer<T>``, ``StructuredBuffer<T>``, ``RWStructuredBuffer<T>``,
        ``ParameterBlock<T>``, ...). ``[^;]`` keeps it from spanning several
        declarations on the same line. Group 1 is the whole declaration.
        """
        name = re.escape(resource_name)
        return [
            re.compile(rf'(\b\w+<[^;]*?>\s+{name}\b)'),
            re.compile(rf'(\bSampler(?:Comparison)?State\s+{name}\b)'),
        ]

    # ------------------------------------------------------------------
    # Binding assignment
    # ------------------------------------------------------------------

    def assign_bindings_spirv(self, shader_info: ShaderInfo):
        """Assign SPIR-V descriptor sets and bindings in place.

        Set 0 (vertex) / 2 (fragment): sampled textures, then storage
        textures, then storage buffers, with bindings numbered from 0.
        Set 1 (vertex) / 3 (fragment): uniform buffers, numbered from 0.
        Samplers are paired by position with sampled textures and reuse the
        texture's set/binding.

        Raises:
            ValueError: For stages other than vertex and fragment.
        """
        resource_set, uniform_set = self._stage_sets(shader_info.stage)
        resources = shader_info.resources

        set_resources = self._collect(
            resources,
            ResourceType.SAMPLED_TEXTURE,
            ResourceType.STORAGE_TEXTURE,
            ResourceType.STORAGE_BUFFER,
        )
        for binding, resource in enumerate(set_resources):
            resource.set = resource_set
            resource.binding = binding

        for binding, buf in enumerate(self._collect(resources, ResourceType.UNIFORM_BUFFER)):
            buf.set = uniform_set
            buf.binding = binding

        # Pair the i-th sampler with the i-th sampled texture. Samplers beyond
        # the number of sampled textures are left untouched (no binding).
        sampled_textures = self._collect(resources, ResourceType.SAMPLED_TEXTURE)
        samplers = self._collect(resources, ResourceType.SAMPLER)
        for sampler, texture in zip(samplers, sampled_textures):
            sampler.set = texture.set
            sampler.binding = texture.binding

    def assign_bindings_dxil(self, shader_info: ShaderInfo):
        """Assign D3D registers and register spaces for DXIL/DXBC in place.

        t registers (space 0 vertex / 2 fragment): sampled textures, then
        storage textures, then storage buffers.
        s registers (same space): samplers in declaration order. SDL3 expects
        ``sN`` to correspond to the sampled texture ``tN``; this relies on
        samplers being declared in the same order as their textures.
        b registers (space 1 vertex / 3 fragment): uniform buffers.

        Raises:
            ValueError: For stages other than vertex and fragment.
        """
        resource_space, uniform_space = self._stage_sets(shader_info.stage)
        resources = shader_info.resources

        # Group t registers by resource type (not declaration order), matching
        # the SPIR-V/MSL assignment above and SDL3's required register layout.
        t_resources = self._collect(
            resources,
            ResourceType.SAMPLED_TEXTURE,
            ResourceType.STORAGE_TEXTURE,
            ResourceType.STORAGE_BUFFER,
        )
        for index, resource in enumerate(t_resources):
            resource.register = f"t{index}"
            resource.space = resource_space

        for index, sampler in enumerate(self._collect(resources, ResourceType.SAMPLER)):
            sampler.register = f"s{index}"
            sampler.space = resource_space

        for index, buf in enumerate(self._collect(resources, ResourceType.UNIFORM_BUFFER)):
            buf.register = f"b{index}"
            buf.space = uniform_space

    def assign_bindings_msl(self, shader_info: ShaderInfo):
        """Assign Metal argument-table indices in place, in SDL3's order.

        ``[[texture(n)]]``: sampled textures, then storage textures.
        ``[[sampler(n)]]``: samplers.
        ``[[buffer(n)]]``: uniform buffers, then storage buffers.
        Each table is numbered from 0. Unlike the SPIR-V/DXIL variants, this
        method does not check the shader stage.
        """
        resources = shader_info.resources
        tables = (
            (ResourceType.SAMPLED_TEXTURE, ResourceType.STORAGE_TEXTURE),
            (ResourceType.SAMPLER,),
            (ResourceType.UNIFORM_BUFFER, ResourceType.STORAGE_BUFFER),
        )
        for types in tables:
            for index, resource in enumerate(self._collect(resources, *types)):
                resource.metal_index = index

    # ------------------------------------------------------------------
    # Binding code generation / injection
    # ------------------------------------------------------------------

    def generate_binding_code(self, shader_info: ShaderInfo, target: TargetFormat) -> str:
        """Return binding declarations for ``target``, one resource per line.

        Returns an empty string for targets other than SPIR-V, DXIL/DXBC and MSL.
        """
        if target == TargetFormat.SPIRV:
            return self._generate_spirv_bindings(shader_info)
        if target in (TargetFormat.DXIL, TargetFormat.DXBC):
            return self._generate_dx_bindings(shader_info)
        if target == TargetFormat.MSL:
            return self._generate_msl_bindings(shader_info)
        return ''

    @staticmethod
    def _generate_spirv_bindings(shader_info: ShaderInfo) -> str:
        """Return ``[[vk::binding(b, s)]] name`` lines for resources with a set and binding."""
        return BindingManager._join_bindings(shader_info, TargetFormat.SPIRV)

    @staticmethod
    def _generate_dx_bindings(shader_info: ShaderInfo) -> str:
        """Return ``register(r, spaceN) name`` lines for resources with a register and space."""
        return BindingManager._join_bindings(shader_info, TargetFormat.DXIL)

    @staticmethod
    def _generate_msl_bindings(shader_info: ShaderInfo) -> str:
        """Return ``[[texture|sampler|buffer(n)]] name`` lines for resources with a Metal index."""
        return BindingManager._join_bindings(shader_info, TargetFormat.MSL)

    def inject_bindings(self, shader_info: ShaderInfo, target: TargetFormat) -> str:
        """Return ``shader_info.source_code`` with binding attributes injected.

        For SPIR-V and MSL the attribute is placed before the matched
        declaration (``[[vk::binding(0, 2)]] Texture2D<float4> tex``). For
        DXIL/DXBC it is appended as a register annotation
        (``Texture2D<float4> tex : register(t0, space2)``).

        This is a line-based regex rewrite, not a parser. A declaration must
        sit on one line. Untemplated declarations other than
        ``SamplerState``/``SamplerComparisonState`` are not recognised. Any
        other line with the same ``Type<...> name`` shape (e.g. a function
        parameter) is rewritten too.
        """
        resource_bindings = {}
        for resource in shader_info.resources:
            attribute = self._binding_attribute(resource, target)
            if attribute:
                resource_bindings[resource.name] = attribute

        # Compile each resource's patterns once instead of once per source line.
        rules = [
            (self._declaration_patterns(name), attribute)
            for name, attribute in resource_bindings.items()
        ]

        if target in (TargetFormat.SPIRV, TargetFormat.MSL):
            def place(declaration, attribute):
                return f'{attribute} {declaration}'
        else:
            def place(declaration, attribute):
                return f'{declaration} : {attribute}'

        modified_lines = []
        for line in shader_info.source_code.split('\n'):
            for patterns, attribute in rules:
                for pattern in patterns:
                    # Rewrite the already-modified line so that several
                    # resources declared on one line all keep their bindings.
                    # A callable replacement avoids backslash interpretation.
                    line, count = pattern.subn(
                        lambda match: place(match.group(1), attribute), line
                    )
                    if count:
                        break
            modified_lines.append(line)

        return '\n'.join(modified_lines)

    def generate_binding_info(self, shader_info: ShaderInfo, target: TargetFormat, output_path: str) -> dict:
        """Build a dict describing the compiled shader and its bindings for runtime use.

        Keys: ``stage``, ``entry_point``, ``resources`` (one dict per resource
        with ``name``, ``type`` and the target-specific binding fields),
        ``num_resources`` (per-``ResourceType`` counts) and ``blob`` (the raw
        bytes of ``output_path``).

        Raises:
            FileNotFoundError: If ``output_path`` does not exist.
        """
        counts = {
            ResourceType.SAMPLED_TEXTURE: 0,
            ResourceType.STORAGE_TEXTURE: 0,
            ResourceType.STORAGE_BUFFER: 0,
            ResourceType.UNIFORM_BUFFER: 0,
            ResourceType.SAMPLER: 0,
        }
        binding_info = {
            'stage': shader_info.stage.value,
            'entry_point': shader_info.entry_point,
            'resources': [],
            'num_resources': counts,
        }

        # Read the compiled binary produced at output_path.
        try:
            with open(output_path, 'rb') as f:
                binding_info['blob'] = f.read()
        except FileNotFoundError as err:
            raise FileNotFoundError(
                f"Output file {output_path} does not exist after compilation."
            ) from err

        for resource in shader_info.resources:
            res_info = {
                'name': resource.name,
                'type': resource.type.value,
            }

            if target == TargetFormat.SPIRV:
                res_info['set'] = resource.set
                res_info['binding'] = resource.binding
            elif target in (TargetFormat.DXIL, TargetFormat.DXBC):
                res_info['register'] = resource.register
                res_info['space'] = resource.space
            elif target == TargetFormat.MSL:
                res_info['index'] = resource.metal_index

            binding_info['resources'].append(res_info)
            # .get: count resource types outside the pre-seeded five instead
            # of raising KeyError.
            counts[resource.type] = counts.get(resource.type, 0) + 1

        return binding_info