#!/usr/bin/env python3
"""
SDL3_GPU Slang Compiler - Shader Parser

Shader source parsing functionality.
"""

import json
import os
import re
import subprocess
import tempfile
from typing import List, Dict, Optional

from compiler_cmd import make_cmd
from global_vars import global_vars
from shader_layout_generator import ShaderLayoutGenerator
from shader_types import ShaderStage, ResourceType, Resource, ShaderInfo


class ShaderParser:
    """Discover the entry points of a Slang shader and collect their resources.

    ``parse_slang_shader`` reads ``global_vars.source_file``, guesses the entry
    points from the source text, then compiles every entry point with the
    command built by ``make_cmd`` plus ``-reflection-json``. Resources come
    from the reflection JSON when it is usable; otherwise they are scraped
    from the source text with regular expressions.
    """

    def parse_slang_shader(self) -> Dict[str, ShaderInfo]:
        """Parse the Slang source file and return a ShaderInfo per entry point.

        Returns:
            Mapping of entry-point name to its ShaderInfo. Entry points for
            which no ShaderInfo could be produced are reported and omitted.
        """
        with open(global_vars.source_file, 'r', encoding='utf-8') as f:
            source = f.read()

        # Find all candidate entry points in the source first.
        entry_points = self._find_entry_points(source)
        print(f"Found potential entry points: {entry_points}")

        shaders = {}
        # Compile and reflect every entry point separately.
        for entry_name, stage in entry_points.items():
            print(f"\nProcessing entry point: {entry_name} (stage: {stage.value})")
            shader_info = self._compile_and_reflect_entry_point(entry_name, stage, source)
            if shader_info:
                shaders[entry_name] = shader_info
            else:
                print(f"Failed to process entry point: {entry_name}")

        return shaders

    def _find_entry_points(self, source: str) -> Dict[str, ShaderStage]:
        """Guess the entry-point functions in ``source`` and the stage of each.

        Heuristics, applied in order:

        1. ``[shader("<stage>")]`` attributes (every match is recorded, so a
           later attribute pattern can overwrite an earlier one).
        2. Well-known entry-point names such as ``vertexMain`` or ``PS``.
        3. Functions returning ``void``/``float*`` whose name contains a stage
           word (``vert``, ``frag``, ``ps``, ``comp``, ...).
        4. Functions whose return semantic is SV_Position/SV_Target/POSITION.
        5. If nothing was found, ``main`` is assumed to be a vertex shader.

        Steps 2-4 never overwrite a name recorded by an earlier step.
        """
        entry_points: Dict[str, ShaderStage] = {}

        # 1. [shader("<stage>")] attributes. Further attribute groups such as
        #    [numthreads(8, 8, 1)] are skipped so the captured identifier is the
        #    function name instead of the next attribute's name.
        attribute_patterns = [
            (r'\[shader\s*\(\s*["\']vertex["\']\s*\)\s*\](?:\s*\[[^\]]*\])*.*?\b(\w+)\s*\(', ShaderStage.VERTEX),
            (r'\[shader\s*\(\s*["\']fragment["\']\s*\)\s*\](?:\s*\[[^\]]*\])*.*?\b(\w+)\s*\(', ShaderStage.FRAGMENT),
            (r'\[shader\s*\(\s*["\']pixel["\']\s*\)\s*\](?:\s*\[[^\]]*\])*.*?\b(\w+)\s*\(', ShaderStage.FRAGMENT),
            (r'\[shader\s*\(\s*["\']compute["\']\s*\)\s*\](?:\s*\[[^\]]*\])*.*?\b(\w+)\s*\(', ShaderStage.COMPUTE),
        ]
        for pattern, stage in attribute_patterns:
            for match in re.finditer(pattern, source, re.DOTALL | re.IGNORECASE):
                entry_points[match.group(1)] = stage
                print(f"Found attributed entry point: {match.group(1)} -> {stage.value}")

        # 2. Common naming conventions.
        common_entry_points = {
            # Vertex shaders
            'vertex_main': ShaderStage.VERTEX,
            'vertexMain': ShaderStage.VERTEX,
            'vert_main': ShaderStage.VERTEX,
            'vs_main': ShaderStage.VERTEX,
            'vertex_shader': ShaderStage.VERTEX,
            'VS': ShaderStage.VERTEX,

            # Fragment shaders
            'fragment_main': ShaderStage.FRAGMENT,
            'fragmentMain': ShaderStage.FRAGMENT,
            'frag_main': ShaderStage.FRAGMENT,
            'pixel_main': ShaderStage.FRAGMENT,
            'ps_main': ShaderStage.FRAGMENT,
            'fragment_shader': ShaderStage.FRAGMENT,
            'PS': ShaderStage.FRAGMENT,

            # Compute shaders
            'compute_main': ShaderStage.COMPUTE,
            'computeMain': ShaderStage.COMPUTE,
            'comp_main': ShaderStage.COMPUTE,
            'cs_main': ShaderStage.COMPUTE,
            'compute_shader': ShaderStage.COMPUTE,
            'CS': ShaderStage.COMPUTE,
        }
        for func_name, stage in common_entry_points.items():
            # Word boundary so only the complete function name matches. The
            # search is case-insensitive, so record the spelling actually used
            # in the source: that is the name the compiler must be given.
            match = re.search(r'\b(' + re.escape(func_name) + r')\s*\(', source, re.IGNORECASE)
            if match and match.group(1) not in entry_points:
                actual_name = match.group(1)
                entry_points[actual_name] = stage
                print(f"Found common entry point: {actual_name} -> {stage.value}")

        # 3. Infer the stage of plain function declarations from their names.
        #    A keyword must start one of the identifier's words, so "vsMain",
        #    "mainPS" and "frag_main" qualify while helpers such as "steps",
        #    "offset" or "invertColor" do not.
        function_pattern = r'(?:float4|void|float3|float2|float)\s+(\w+)\s*\([^)]*\)(?:\s*:\s*\w+)?'
        inference_rules = (
            (ShaderStage.VERTEX, 'vertex', ('vert', 'vs')),
            (ShaderStage.FRAGMENT, 'fragment', ('frag', 'pixel', 'ps', 'fs')),
            (ShaderStage.COMPUTE, 'compute', ('comp', 'cs')),
        )
        for match in re.finditer(function_pattern, source):
            func_name = match.group(1)
            if func_name in entry_points:
                continue
            words = self._identifier_words(func_name)
            for stage, label, keywords in inference_rules:
                if any(word.startswith(keywords) for word in words):
                    entry_points[func_name] = stage
                    print(f"Inferred {label} shader: {func_name}")
                    break

        # 4. Functions carrying an HLSL return semantic.
        semantic_patterns = [
            (r'(\w+)\s*\([^)]*\)\s*:\s*SV_Position', ShaderStage.VERTEX),
            (r'(\w+)\s*\([^)]*\)\s*:\s*SV_Target', ShaderStage.FRAGMENT),
            (r'(\w+)\s*\([^)]*\)\s*:\s*POSITION', ShaderStage.VERTEX),
        ]
        for pattern, stage in semantic_patterns:
            for match in re.finditer(pattern, source, re.IGNORECASE):
                func_name = match.group(1)
                if func_name not in entry_points:
                    entry_points[func_name] = stage
                    print(f"Found semantic-based entry point: {func_name} -> {stage.value}")

        # 5. Nothing found: fall back to a function called main.
        if not entry_points and re.search(r'\bmain\s*\(', source, re.IGNORECASE):
            entry_points['main'] = ShaderStage.VERTEX
            print("Found main function, assuming vertex shader")

        print(f"Total entry points found: {len(entry_points)}")
        return entry_points

    @staticmethod
    def _identifier_words(identifier: str) -> List[str]:
        """Split an identifier at underscores, digits and camelCase humps, lower-cased.

        For example "VSMain" -> ["vs", "main"] and "frag_main" -> ["frag", "main"].
        """
        return [word.lower() for word in re.findall(r'[A-Z]+(?![a-z])|[A-Z]?[a-z]+|\d+', identifier)]

    def _compile_and_reflect_entry_point(self, entry_name: str, stage: ShaderStage, source: str) -> Optional[ShaderInfo]:
        """Compile one entry point with reflection enabled and build its ShaderInfo.

        Falls back to ``_create_shader_info_manual`` when compilation fails or
        times out, or when no usable reflection JSON is produced. The
        temporary output and reflection files are removed on every path,
        including when an exception propagates.
        """
        temp_files: List[str] = []
        try:
            temp_output_path = self._make_temp_path('.shader.tmp')
            temp_files.append(temp_output_path)
            reflection_path = self._make_temp_path('.json')
            temp_files.append(reflection_path)

            shader_info = self._run_compiler_with_reflection(
                entry_name, stage, source, temp_output_path, reflection_path
            )
        finally:
            for temp_file in temp_files:
                if os.path.exists(temp_file):
                    os.unlink(temp_file)

        if shader_info is not None:
            return shader_info

        # Reflection failed: fall back to parsing the source by hand.
        print(f"Falling back to manual parsing for {entry_name}")
        return self._create_shader_info_manual(entry_name, stage, source)

    @staticmethod
    def _make_temp_path(suffix: str) -> str:
        """Create an empty, closed temporary file and return its path (caller deletes it)."""
        with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp:
            return tmp.name

    def _run_compiler_with_reflection(self, entry_name: str, stage: ShaderStage, source: str,
                                      temp_output_path: str, reflection_path: str) -> Optional[ShaderInfo]:
        """Run the compiler for one entry point and build a ShaderInfo from its reflection JSON.

        Returns None when the compiler times out or fails, or when the
        reflection file is missing, empty or not valid JSON.
        """
        cmd = make_cmd(global_vars.source_file, global_vars.target, stage, entry_name, temp_output_path)
        cmd.extend([
            '-reflection-json', reflection_path,  # reflection output
        ])

        print(f"Command: {' '.join(cmd)}")

        try:
            result = subprocess.run(
                cmd,
                capture_output=True,
                text=True,
                timeout=60,
            )
        except subprocess.TimeoutExpired:
            # Treat a hung compiler like a failed one so the manual fallback runs.
            print(f"Compilation timed out for entry point {entry_name}")
            return None

        print(f"Return code: {result.returncode}")
        if result.stdout:
            print(f"Stdout: {result.stdout}")
        if result.stderr:
            print(f"Stderr: {result.stderr}")

        if result.returncode != 0:
            print(f"Compilation failed for entry point {entry_name}")
            return None

        if stage == ShaderStage.VERTEX:
            self._write_vertex_layout_header(reflection_path)

        if not os.path.exists(reflection_path):
            print(f"Reflection file not created: {reflection_path}")
            return None

        with open(reflection_path, 'r', encoding='utf-8') as f:
            reflection_json = f.read()
        print(f"Reflection JSON length: {len(reflection_json)}")

        if not reflection_json.strip():
            return None

        try:
            reflection = json.loads(reflection_json)
        except json.JSONDecodeError as e:
            print(f"JSON parsing error: {e}")
            print(f"Raw JSON: {reflection_json[:500]}...")
            return None

        return self._create_shader_info_from_reflection(reflection, entry_name, stage, source)

    def _write_vertex_layout_header(self, reflection_path: str) -> None:
        """Write ``<source_file_name>_layout.h`` to the output dir using ShaderLayoutGenerator."""
        generator = ShaderLayoutGenerator()
        header_content = generator.generate_header(reflection_path)
        layout_path_filename = os.path.join(global_vars.output_dir, f"{global_vars.source_file_name}_layout.h")
        with open(layout_path_filename, 'w', encoding='utf-8') as out_file:
            out_file.write(header_content)

    def _create_shader_info_from_reflection(self, reflection: dict, entry_name: str,
                                            stage: ShaderStage, source: str) -> ShaderInfo:
        """Build a ShaderInfo from parsed reflection JSON.

        Resources are taken from the ``parameters`` of the located entry-point
        data. If no entry-point data can be located, resources are scraped from
        ``source`` instead.
        """
        print(f"Processing reflection data for {entry_name}")
        print(f"Reflection keys: {list(reflection.keys())}")

        shader_info = ShaderInfo(
            stage=stage,
            entry_point=entry_name,
            resources=[],
            source_code=source,
        )

        entry_point_data = self._find_entry_point_data(reflection, entry_name)

        if entry_point_data:
            print(f"Found entry point data: {list(entry_point_data.keys())}")

            # Parse resource parameters.
            parameters = entry_point_data.get('parameters') or []
            print(f"Found {len(parameters)} parameters")

            for param in parameters:
                print(f"Processing parameter: {param}")
                resource = self._parse_resource(param)
                if resource:
                    shader_info.resources.append(resource)
                    print(f"Added resource: {resource.name} ({resource.type.value})")
        else:
            print("No entry point data found in reflection, using manual parsing")
            # Manual parsing as the fallback.
            shader_info.resources.extend(self._extract_resources_from_source(source))

        print(f"Shader {entry_name} has {len(shader_info.resources)} resources")
        return shader_info

    @staticmethod
    def _find_entry_point_data(reflection: dict, entry_name: str) -> Optional[dict]:
        """Locate the reflection object holding ``entry_name``'s parameters.

        Tries three layouts in order and uses only the first one present:
        ``parameters`` at the root, a root ``entryPoints`` list, and
        ``modules[*].entryPoints``. Within a list the first entry whose
        ``name`` equals ``entry_name`` wins. Returns None if nothing matches.
        """
        # Layout 1: parameters directly at the root.
        if 'parameters' in reflection:
            return reflection

        # Layout 2: a root-level entryPoints array.
        if 'entryPoints' in reflection:
            return next(
                (ep for ep in reflection['entryPoints'] or [] if ep.get('name') == entry_name),
                None,
            )

        # Layout 3: entryPoints nested inside modules.
        if 'modules' in reflection:
            return next(
                (ep
                 for module in reflection['modules'] or []
                 for ep in module.get('entryPoints') or []
                 if ep.get('name') == entry_name),
                None,
            )

        return None

    def _create_shader_info_manual(self, entry_name: str, stage: ShaderStage, source: str) -> ShaderInfo:
        """Build a ShaderInfo from source-text scraping only (fallback path)."""
        print(f"Creating shader info manually for {entry_name}")
        shader_info = ShaderInfo(
            stage=stage,
            entry_point=entry_name,
            resources=[],
            source_code=source
        )

        # Scrape the resources from the source text.
        resources = self._extract_resources_from_source(source)
        shader_info.resources.extend(resources)

        print(f"Manual parsing found {len(resources)} resources")
        return shader_info

    def _extract_resources_from_source(self, source: str) -> List[Resource]:
        """Scrape resource declarations from ``source`` with regular expressions.

        Every pattern starts with ``\\b`` so that, for example,
        ``RWTexture2D<float4> out`` is not also reported as a sampled
        ``Texture2D`` and ``RWStructuredBuffer`` is not reported twice.
        A given (name, type) pair is returned only once, in order of discovery.
        Matches inside comments are not filtered out.
        """
        patterns = {
            # Texture resources
            ResourceType.SAMPLED_TEXTURE: [
                r'\bTexture2D\s*(?:<[^>]*>)?\s+(\w+)',
                r'\bTexture3D\s*(?:<[^>]*>)?\s+(\w+)',
                r'\bTextureCube\s*(?:<[^>]*>)?\s+(\w+)',
            ],
            ResourceType.STORAGE_TEXTURE: [
                r'\bRWTexture2D\s*(?:<[^>]*>)?\s+(\w+)',
                r'\bRWTexture3D\s*(?:<[^>]*>)?\s+(\w+)',
            ],
            # Buffer resources
            ResourceType.STORAGE_BUFFER: [
                r'\bRWStructuredBuffer\s*<[^>]*>\s+(\w+)',
                r'\bRWByteAddressBuffer\s+(\w+)',
                r'\bStructuredBuffer\s*<[^>]*>\s+(\w+)',
                r'\bByteAddressBuffer\s+(\w+)',
            ],
            ResourceType.UNIFORM_BUFFER: [
                r'\bConstantBuffer\s*<[^>]*>\s+(\w+)',
                r'\bcbuffer\s+(\w+)',
            ],
            ResourceType.SAMPLER: [
                r'\bSamplerState\s+(\w+)',
                r'\bSamplerComparisonState\s+(\w+)',
            ],
        }

        resources: List[Resource] = []
        seen = set()
        for resource_type, type_patterns in patterns.items():
            for pattern in type_patterns:
                for name in re.findall(pattern, source, re.IGNORECASE):
                    key = (name, resource_type)
                    if key in seen:
                        continue
                    seen.add(key)
                    resources.append(Resource(name, resource_type))
                    print(f"Found resource: {name} ({resource_type.value})")

        return resources

    def _parse_resource(self, param: dict) -> Optional[Resource]:
        """Map one reflected parameter to a Resource, or None if it is not a binding.

        First tries the type ``name`` (e.g. "RWTexture2D", "StructuredBuffer"),
        then the type ``kind``. As a last resort it reads Slang's structural
        description (``kind: "resource"`` with ``baseShape``/``access``, or
        ``kind: "samplerState"``).
        """
        type_info = param.get('type') or {}
        type_name = type_info.get('name') or ''
        kind = type_info.get('kind')
        name = param.get('name')
        if not name:
            return None

        # Name-based classification.
        if 'Texture2D' in type_name or 'Texture3D' in type_name or 'TextureCube' in type_name:
            if 'RW' in type_name:
                return Resource(name, ResourceType.STORAGE_TEXTURE)
            return Resource(name, ResourceType.SAMPLED_TEXTURE)
        if 'StructuredBuffer' in type_name or 'ByteAddressBuffer' in type_name:
            # Read-only and RW buffers are both exposed as storage buffers.
            # NOTE(review): SDL3 GPU may need read-only storage buffers handled
            # separately; confirm against the binding layout in use.
            return Resource(name, ResourceType.STORAGE_BUFFER)
        if 'constantBuffer' in type_name or kind in ('constantBuffer', 'parameterBlock'):
            return Resource(name, ResourceType.UNIFORM_BUFFER)
        if 'Sampler' in type_name:
            return Resource(name, ResourceType.SAMPLER)

        # Structural classification (Slang -reflection-json resource types).
        if kind == 'samplerState':
            return Resource(name, ResourceType.SAMPLER)
        if kind == 'resource':
            shape = (type_info.get('baseShape') or '').lower()
            writable = type_info.get('access') in ('readWrite', 'write', 'rasterOrdered')
            if shape in ('texture2d', 'texture3d', 'texturecube'):
                return Resource(name, ResourceType.STORAGE_TEXTURE if writable else ResourceType.SAMPLED_TEXTURE)
            if shape in ('structuredbuffer', 'byteaddressbuffer'):
                return Resource(name, ResourceType.STORAGE_BUFFER)

        return None