Files
mirage_slang/shader_parser.py

359 lines
15 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python3
"""
SDL3_GPU Slang Compiler - Shader Parser
着色器源码解析功能
"""
import json
import os
import re
import subprocess
import tempfile
from typing import List, Dict, Optional
from compiler_cmd import make_cmd
from global_vars import global_vars
from shader_layout_generator import ShaderLayoutGenerator
from shader_types import ShaderStage, ResourceType, Resource, ShaderInfo
class ShaderParser:
    """Locate entry points in a Slang shader and collect each one's resources.

    Each entry point is compiled with the command produced by ``make_cmd``
    plus ``-reflection-json``; the resulting JSON is converted into a
    ``ShaderInfo``. When compilation or reflection fails, resources are
    scraped from the source text with regular expressions instead.
    """

    # Seconds to wait for one compiler invocation before falling back.
    _COMPILE_TIMEOUT_SECONDS = 60

    # `//` line comments and `/* ... */` block comments (non-greedy, may span lines).
    _COMMENT_RE = re.compile(r'//[^\n]*|/\*.*?\*/', re.DOTALL)

    # Splits camelCase / PascalCase / snake_case identifiers into words: an
    # all-caps run not followed by lowercase ("VS" in "VSMain"), a capitalised
    # or lowercase word, or a run of digits. Underscores are never matched.
    _NAME_WORD_RE = re.compile(r'[A-Z]+(?![a-z])|[A-Z]?[a-z]+|\d+')

    def parse_slang_shader(self) -> Dict[str, ShaderInfo]:
        """Parse ``global_vars.source_file`` and return ShaderInfo per entry point.

        Returns:
            Mapping of entry-point function name to its ShaderInfo, in the
            order the entry points were discovered.
        """
        with open(global_vars.source_file, 'r', encoding='utf-8') as f:
            source = f.read()

        # Find every candidate entry point first, then compile/reflect each separately.
        entry_points = self._find_entry_points(source)
        print(f"Found potential entry points: {entry_points}")

        shaders = {}
        for entry_name, stage in entry_points.items():
            print(f"\nProcessing entry point: {entry_name} (stage: {stage.value})")
            shader_info = self._compile_and_reflect_entry_point(entry_name, stage, source)
            if shader_info:
                shaders[entry_name] = shader_info
            else:
                print(f"Failed to process entry point: {entry_name}")
        return shaders

    @classmethod
    def _strip_comments(cls, source: str) -> str:
        """Return *source* with every comment replaced by a single space.

        Line comments keep their terminating newline. NOTE: string literals are
        not special-cased, so a `//` or `/*` inside a string would also be
        treated as a comment; this text is only used for scanning, never emitted.
        """
        return cls._COMMENT_RE.sub(' ', source)

    @classmethod
    def _infer_stage_from_name(cls, func_name: str) -> Optional[ShaderStage]:
        """Guess a shader stage from the words of *func_name*, or return None.

        Words are compared individually, so incidental substrings do not match
        (the "fs" inside "offset" is ignored, but "fsMain" is a fragment shader).
        Long keywords match as word prefixes ("vertices", "fragmentShader");
        short ones must be a whole word ("vs", "ps", "fs", "cs", "comp").
        Vertex beats fragment beats compute when several apply.
        """
        words = [word.lower() for word in cls._NAME_WORD_RE.findall(func_name)]
        rules = (
            (ShaderStage.VERTEX, ('vert',), ('vs',)),
            (ShaderStage.FRAGMENT, ('frag', 'pixel'), ('ps', 'fs')),
            (ShaderStage.COMPUTE, ('compute',), ('comp', 'cs')),
        )
        for stage, prefixes, exact_words in rules:
            if any(word.startswith(prefixes) or word in exact_words for word in words):
                return stage
        return None

    def _find_entry_points(self, source: str) -> Dict[str, ShaderStage]:
        """Find entry-point functions in *source* and guess the stage of each.

        Strategies run in priority order, and a name recorded by an earlier
        strategy is not re-classified by a later one:

        1. Slang ``[shader("vertex"|"fragment"|"pixel"|"compute")]`` attributes.
        2. Conventional names such as ``vertexMain``, ``ps_main`` or ``CS``
           (matched case-insensitively; the spelling found in the source is kept,
           because that is the name handed to the compiler).
        3. Functions returning ``float``/``float2``/``float3``/``float4``/``void``
           whose name contains a stage word (see ``_infer_stage_from_name``).
        4. Functions whose return value carries an ``SV_Position``,
           ``SV_Target`` or ``POSITION`` semantic.
        5. If nothing was found, a ``main`` function, assumed to be a vertex shader.

        Comments are removed before scanning so commented-out code is ignored.
        """
        code = self._strip_comments(source)
        entry_points: Dict[str, ShaderStage] = {}

        # 1. Functions carrying a Slang [shader("stage")] attribute. Further
        #    attributes (e.g. [numthreads(8, 8, 1)]) may sit between it and the
        #    declaration; they are skipped so their name is not taken for the
        #    function name. `[^(;{]*?` covers the return type without crossing
        #    into another declaration.
        attribute_stages = [
            ('vertex', ShaderStage.VERTEX),
            ('fragment', ShaderStage.FRAGMENT),
            ('pixel', ShaderStage.FRAGMENT),
            ('compute', ShaderStage.COMPUTE),
        ]
        for stage_name, stage in attribute_stages:
            pattern = (
                r'\[\s*shader\s*\(\s*["\']' + stage_name + r'["\']\s*\)\s*\]'
                r'(?:\s*\[[^\]]*\])*'
                r'[^(;{]*?\b(\w+)\s*\('
            )
            for match in re.finditer(pattern, code, re.IGNORECASE):
                entry_points[match.group(1)] = stage
                print(f"Found attributed entry point: {match.group(1)} -> {stage.value}")

        # 2. Common naming conventions.
        common_entry_points = {
            # Vertex shaders
            'vertex_main': ShaderStage.VERTEX,
            'vertexMain': ShaderStage.VERTEX,
            'vert_main': ShaderStage.VERTEX,
            'vs_main': ShaderStage.VERTEX,
            'vertex_shader': ShaderStage.VERTEX,
            'VS': ShaderStage.VERTEX,
            # Fragment shaders
            'fragment_main': ShaderStage.FRAGMENT,
            'fragmentMain': ShaderStage.FRAGMENT,
            'frag_main': ShaderStage.FRAGMENT,
            'pixel_main': ShaderStage.FRAGMENT,
            'ps_main': ShaderStage.FRAGMENT,
            'fragment_shader': ShaderStage.FRAGMENT,
            'PS': ShaderStage.FRAGMENT,
            # Compute shaders
            'compute_main': ShaderStage.COMPUTE,
            'computeMain': ShaderStage.COMPUTE,
            'comp_main': ShaderStage.COMPUTE,
            'cs_main': ShaderStage.COMPUTE,
            'compute_shader': ShaderStage.COMPUTE,
            'CS': ShaderStage.COMPUTE,
        }
        for func_name, stage in common_entry_points.items():
            # Word boundary so that e.g. "VS" does not match inside "myVS(".
            match = re.search(r'\b(' + re.escape(func_name) + r')\s*\(', code, re.IGNORECASE)
            if match is None:
                continue
            actual_name = match.group(1)
            if actual_name not in entry_points:
                entry_points[actual_name] = stage
                print(f"Found common entry point: {actual_name} -> {stage.value}")

        # 3. Functions with a typical entry-point return type, classified by name.
        function_pattern = r'\b(?:float4|void|float3|float2|float)\s+(\w+)\s*\([^)]*\)(?:\s*:\s*\w+)?'
        stage_labels = {
            ShaderStage.VERTEX: 'vertex',
            ShaderStage.FRAGMENT: 'fragment',
            ShaderStage.COMPUTE: 'compute',
        }
        for match in re.finditer(function_pattern, code):
            func_name = match.group(1)
            if func_name in entry_points:
                continue
            stage = self._infer_stage_from_name(func_name)
            if stage is not None:
                entry_points[func_name] = stage
                print(f"Inferred {stage_labels[stage]} shader: {func_name}")

        # 4. Functions whose return value has an HLSL system-value semantic.
        semantic_patterns = [
            (r'(\w+)\s*\([^)]*\)\s*:\s*SV_Position', ShaderStage.VERTEX),
            (r'(\w+)\s*\([^)]*\)\s*:\s*SV_Target', ShaderStage.FRAGMENT),
            (r'(\w+)\s*\([^)]*\)\s*:\s*POSITION', ShaderStage.VERTEX),
        ]
        for pattern, stage in semantic_patterns:
            for match in re.finditer(pattern, code, re.IGNORECASE):
                func_name = match.group(1)
                if func_name not in entry_points:
                    entry_points[func_name] = stage
                    print(f"Found semantic-based entry point: {func_name} -> {stage.value}")

        # 5. Last resort: a plain main function, assumed to be a vertex shader.
        if not entry_points and re.search(r'\bmain\s*\(', code, re.IGNORECASE):
            entry_points['main'] = ShaderStage.VERTEX
            print("Found main function, assuming vertex shader")

        print(f"Total entry points found: {len(entry_points)}")
        return entry_points

    def _compile_and_reflect_entry_point(self, entry_name: str, stage: ShaderStage, source: str) -> Optional[ShaderInfo]:
        """Compile and reflect a single entry point, falling back to manual parsing.

        Args:
            entry_name: Function name passed to the compiler as the entry point.
            stage: Stage the entry point is compiled for.
            source: Full shader source; used by the manual fallback and stored
                on the returned ShaderInfo.

        Returns:
            A ShaderInfo built from reflection JSON when compilation succeeds and
            yields a parseable JSON object, otherwise one built by
            ``_create_shader_info_manual``.

        The temporary compiler-output and reflection files are deleted on every
        path, including success and exceptions.
        """
        temp_paths: List[str] = []
        try:
            # delete=False: the files are closed here and reopened by the compiler.
            for suffix in ('.shader.tmp', '.json'):
                with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp_file:
                    temp_paths.append(tmp_file.name)
            temp_output_path, reflection_path = temp_paths
            shader_info = self._reflect_entry_point(
                entry_name, stage, source, temp_output_path, reflection_path
            )
        finally:
            for temp_file in temp_paths:
                try:
                    os.unlink(temp_file)
                except FileNotFoundError:
                    pass
                except OSError as e:
                    # Best effort: a leftover temp file must not abort processing.
                    print(f"Failed to remove temporary file {temp_file}: {e}")

        if shader_info is not None:
            return shader_info

        # Reflection was unavailable; scan the source text instead.
        print(f"Falling back to manual parsing for {entry_name}")
        return self._create_shader_info_manual(entry_name, stage, source)

    def _reflect_entry_point(self, entry_name: str, stage: ShaderStage, source: str,
                             temp_output_path: str, reflection_path: str) -> Optional[ShaderInfo]:
        """Run the compiler for one entry point and convert its reflection JSON.

        For the vertex stage this also writes ``<source_file_name>_layout.h``
        into ``global_vars.output_dir`` using ShaderLayoutGenerator.

        Returns:
            The ShaderInfo from ``_create_shader_info_from_reflection``, or None
            when the compiler times out, exits non-zero, or leaves no parseable
            JSON object in *reflection_path*.
        """
        cmd = make_cmd(global_vars.source_file, global_vars.target, stage, entry_name, temp_output_path)
        cmd.extend([
            '-reflection-json', reflection_path,  # write reflection data to this file
        ])
        print(f"Command: {' '.join(cmd)}")
        try:
            result = subprocess.run(
                cmd,
                capture_output=True,
                text=True,
                timeout=self._COMPILE_TIMEOUT_SECONDS,
            )
        except subprocess.TimeoutExpired:
            print(f"Compilation timed out for entry point {entry_name}")
            return None

        print(f"Return code: {result.returncode}")
        if result.stdout:
            print(f"Stdout: {result.stdout}")
        if result.stderr:
            print(f"Stderr: {result.stderr}")
        if result.returncode != 0:
            print(f"Compilation failed for entry point {entry_name}")
            return None

        if stage == ShaderStage.VERTEX:
            # Emit the vertex layout header. NOTE(review): every vertex entry
            # point writes the same file name, so with several vertex entry
            # points the last one processed wins.
            header_content = ShaderLayoutGenerator().generate_header(reflection_path)
            layout_path_filename = os.path.join(
                global_vars.output_dir, f"{global_vars.source_file_name}_layout.h"
            )
            with open(layout_path_filename, 'w', encoding='utf-8') as out_file:
                out_file.write(header_content)

        if not os.path.exists(reflection_path):
            print(f"Reflection file not created: {reflection_path}")
            return None
        with open(reflection_path, 'r', encoding='utf-8') as f:
            reflection_json = f.read()
        print(f"Reflection JSON length: {len(reflection_json)}")
        if not reflection_json.strip():
            return None

        try:
            reflection = json.loads(reflection_json)
        except json.JSONDecodeError as e:
            print(f"JSON parsing error: {e}")
            print(f"Raw JSON: {reflection_json[:500]}...")
            return None
        if not isinstance(reflection, dict):
            print(f"Unexpected reflection JSON root type: {type(reflection).__name__}")
            return None
        return self._create_shader_info_from_reflection(reflection, entry_name, stage, source)

    @staticmethod
    def _find_entry_point_data(reflection: dict, entry_name: str) -> Optional[dict]:
        """Return the reflection object whose 'parameters' describe *entry_name*.

        Only the first of these layouts present in *reflection* is searched:
        the root object itself (when it has 'parameters'), then the root
        'entryPoints' list, then each module's 'entryPoints' list. Returns None
        when no match is found.
        """
        if 'parameters' in reflection:
            return reflection
        if 'entryPoints' in reflection:
            return next(
                (ep for ep in reflection['entryPoints'] if ep.get('name') == entry_name),
                None,
            )
        if 'modules' in reflection:
            for module in reflection['modules']:
                for ep in module.get('entryPoints') or []:
                    if ep.get('name') == entry_name:
                        return ep
        return None

    def _create_shader_info_from_reflection(self, reflection: dict, entry_name: str,
                                            stage: ShaderStage, source: str) -> ShaderInfo:
        """Build a ShaderInfo from parsed reflection JSON.

        Parameters of the matching reflection object (see
        ``_find_entry_point_data``) are classified with ``_parse_resource``;
        ones that are not resources are skipped. If no matching object exists,
        resources are extracted from *source* instead.
        """
        print(f"Processing reflection data for {entry_name}")
        print(f"Reflection keys: {list(reflection.keys())}")
        shader_info = ShaderInfo(
            stage=stage,
            entry_point=entry_name,
            resources=[],
            source_code=source,
        )

        entry_point_data = self._find_entry_point_data(reflection, entry_name)
        if entry_point_data:
            print(f"Found entry point data: {list(entry_point_data.keys())}")
            parameters = entry_point_data.get('parameters') or []
            print(f"Found {len(parameters)} parameters")
            for param in parameters:
                print(f"Processing parameter: {param}")
                resource = self._parse_resource(param)
                if resource:
                    shader_info.resources.append(resource)
                    print(f"Added resource: {resource.name} ({resource.type.value})")
        else:
            print("No entry point data found in reflection, using manual parsing")
            manual_resources = self._extract_resources_from_source(source)
            shader_info.resources.extend(manual_resources)

        print(f"Shader {entry_name} has {len(shader_info.resources)} resources")
        return shader_info

    def _create_shader_info_manual(self, entry_name: str, stage: ShaderStage, source: str) -> ShaderInfo:
        """Build a ShaderInfo purely from regex scanning of *source* (fallback path)."""
        print(f"Creating shader info manually for {entry_name}")
        shader_info = ShaderInfo(
            stage=stage,
            entry_point=entry_name,
            resources=[],
            source_code=source
        )
        resources = self._extract_resources_from_source(source)
        shader_info.resources.extend(resources)
        print(f"Manual parsing found {len(resources)} resources")
        return shader_info

    def _extract_resources_from_source(self, source: str) -> List[Resource]:
        """Scan *source* for resource declarations with regular expressions.

        Comments are ignored. Results are grouped by type in this order:
        sampled textures, storage textures, storage buffers, uniform buffers,
        samplers. Within each type they appear in declaration order. Type
        keywords are anchored at a word boundary so that an ``RWTexture2D`` or
        ``RWStructuredBuffer`` declaration is reported once, with its RW type,
        rather than a second time as ``Texture2D`` / ``StructuredBuffer``.
        """
        code = self._strip_comments(source)
        patterns = {
            ResourceType.SAMPLED_TEXTURE:
                r'\bTexture(?:2D|3D|Cube)(?:Array)?\s*(?:<[^>]*>)?\s+(\w+)',
            ResourceType.STORAGE_TEXTURE:
                r'\bRWTexture(?:2D|3D)(?:Array)?\s*(?:<[^>]*>)?\s+(\w+)',
            # Any *StructuredBuffer / *ByteAddressBuffer, read-only or RW.
            ResourceType.STORAGE_BUFFER:
                r'\b\w*(?:StructuredBuffer\s*<[^>]*>|ByteAddressBuffer)\s+(\w+)',
            ResourceType.UNIFORM_BUFFER:
                r'\b(?:(?:ConstantBuffer|ParameterBlock)\s*<[^>]*>|cbuffer)\s+(\w+)',
            ResourceType.SAMPLER:
                r'\bSampler(?:Comparison)?State\s+(\w+)',
        }

        resources = []
        for resource_type, pattern in patterns.items():
            for name in re.findall(pattern, code, re.IGNORECASE):
                resources.append(Resource(name, resource_type))
                print(f"Found resource: {name} ({resource_type.value})")
        return resources

    def _parse_resource(self, param: dict) -> Optional[Resource]:
        """Classify one reflection parameter as a Resource, or return None.

        The type's ``name`` is checked first (HLSL-style spellings such as
        ``RWTexture2D`` or ``StructuredBuffer``), then its ``kind``. Types
        described as ``kind: "resource"`` with ``baseShape`` / ``access``
        fields, or as ``kind: "samplerState"``, are handled as well.
        NOTE(review): that field layout follows slangc's ``-reflection-json``
        writer; verify against the slangc version in use.

        Returns None for parameters without a name and for non-resource types.
        """
        name = param.get('name')
        if not name:
            return None
        type_info = param.get('type') or {}
        type_name = type_info.get('name') or ''
        kind = type_info.get('kind') or ''

        if 'Texture2D' in type_name or 'Texture3D' in type_name or 'TextureCube' in type_name:
            if 'RW' in type_name:
                return Resource(name, ResourceType.STORAGE_TEXTURE)
            return Resource(name, ResourceType.SAMPLED_TEXTURE)
        if 'StructuredBuffer' in type_name or 'ByteAddressBuffer' in type_name:
            # NOTE(review): read-only and read-write buffers both map to
            # STORAGE_BUFFER; the original author flagged that SDL3 GPU may need
            # special handling for read-only storage buffers.
            return Resource(name, ResourceType.STORAGE_BUFFER)
        if 'constantbuffer' in type_name.lower() or kind in ('constantBuffer', 'parameterBlock'):
            return Resource(name, ResourceType.UNIFORM_BUFFER)
        if 'Sampler' in type_name or kind == 'samplerState':
            return Resource(name, ResourceType.SAMPLER)
        if kind == 'resource':
            base_shape = type_info.get('baseShape') or ''
            if base_shape in ('texture2D', 'texture3D', 'textureCube'):
                # A missing 'access' field is treated as read-only.
                if (type_info.get('access') or 'read') == 'read':
                    return Resource(name, ResourceType.SAMPLED_TEXTURE)
                return Resource(name, ResourceType.STORAGE_TEXTURE)
            if base_shape in ('structuredBuffer', 'byteAddressBuffer'):
                return Resource(name, ResourceType.STORAGE_BUFFER)
        return None