Files
mirage/tools/shader_parser.py
2025-06-19 11:11:56 +08:00

216 lines
8.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python3
"""
Slang Compiler - Shader Parser
Shader source parsing: discovers entry points in a Slang source file and
compiles each one with reflection output enabled.
"""
import json
import os
import re
import subprocess
import tempfile
from typing import List, Dict, Optional
from compiler_cmd import make_cmd
from global_vars import global_vars
from shader_reflection_type import *
# Stage keywords accepted inside a Slang ``[shader("...")]`` attribute, in the
# order they are scanned. "pixel" is accepted as a synonym for "fragment".
_ATTRIBUTE_STAGES = (
    ('vertex', Stage.VERTEX),
    ('fragment', Stage.FRAGMENT),
    ('pixel', Stage.FRAGMENT),
    ('compute', Stage.COMPUTE),
)

# Conventional entry-point function names and the stage each one implies.
_COMMON_ENTRY_POINTS = {
    # Vertex shaders
    'vertex_main': Stage.VERTEX,
    'vertexMain': Stage.VERTEX,
    'vert_main': Stage.VERTEX,
    'vs_main': Stage.VERTEX,
    'vertex_shader': Stage.VERTEX,
    'VS': Stage.VERTEX,
    # Fragment shaders
    'fragment_main': Stage.FRAGMENT,
    'fragmentMain': Stage.FRAGMENT,
    'frag_main': Stage.FRAGMENT,
    'pixel_main': Stage.FRAGMENT,
    'ps_main': Stage.FRAGMENT,
    'fragment_shader': Stage.FRAGMENT,
    'PS': Stage.FRAGMENT,
    # Compute shaders
    'compute_main': Stage.COMPUTE,
    'computeMain': Stage.COMPUTE,
    'comp_main': Stage.COMPUTE,
    'cs_main': Stage.COMPUTE,
    'compute_shader': Stage.COMPUTE,
    'CS': Stage.COMPUTE,
}

# Functions returning one of these types are candidates for name-based stage
# inference (heuristic step 3). Group 1 is the function name.
_CANDIDATE_FUNCTION_RE = re.compile(
    r'\b(?:float4|void|float3|float2|float)\s+(\w+)\s*\([^)]*\)(?:\s*:\s*\w+)?'
)

# Splits an identifier into words at underscores, digits and camelCase humps,
# e.g. "VSMain" -> ["VS", "Main"], "vertex_main" -> ["vertex", "main"].
_NAME_TOKEN_RE = re.compile(r'[A-Z]+(?![a-z])|[A-Z]?[a-z]+|\d+')

# (stage, label used in log output, whole-word abbreviations, word prefixes),
# checked in this order. Abbreviations must be an entire word ("mainVS",
# "ps_main"); longer keywords only need to start a word ("vertexShader").
# Matching words rather than raw substrings keeps helper names such as
# "convert", "steps", "offset" or "calcShadow" from being read as stage hints.
_STAGE_NAME_HINTS = (
    (Stage.VERTEX, 'vertex', ('vs',), ('vert',)),
    (Stage.FRAGMENT, 'fragment', ('ps', 'fs'), ('frag', 'pixel')),
    (Stage.COMPUTE, 'compute', ('cs',), ('comp',)),
)

# Functions whose return value carries one of these HLSL semantics.
# Group 1 is the function name.
_SEMANTIC_PATTERNS = (
    (re.compile(r'(\w+)\s*\([^)]*\)\s*:\s*SV_Position', re.IGNORECASE), Stage.VERTEX),
    (re.compile(r'(\w+)\s*\([^)]*\)\s*:\s*SV_Target', re.IGNORECASE), Stage.FRAGMENT),
    (re.compile(r'(\w+)\s*\([^)]*\)\s*:\s*POSITION', re.IGNORECASE), Stage.VERTEX),
)


class ShaderParser:
    """Discover entry points in ``global_vars.source_file`` and compile each
    one with Slang reflection output enabled."""

    def parse_slang_shader(self) -> List[ShaderReflection]:
        """Parse the Slang source and build reflection data per entry point.

        Reads ``global_vars.source_file`` as UTF-8 and finds candidate entry
        points with ``_find_entry_points``. Each one is then compiled
        separately with ``_compile_and_reflect_entry_point``.

        Returns:
            One ``ShaderReflection`` per entry point that compiled
            successfully, in discovery order. The list is empty when no
            entry point was found. Entry points that fail are reported and
            skipped.
        """
        with open(global_vars.source_file, 'r', encoding='utf-8') as f:
            source = f.read()

        # Locate every candidate entry point before compiling anything.
        entry_points = self._find_entry_points(source)
        if not entry_points:
            print("No entry points found in the shader source.")
            return []
        print(f"Found potential entry points: {entry_points}")

        shaders = []
        # Compile and reflect each entry point independently.
        for entry_name, stage in entry_points.items():
            print(f"\nProcessing entry point: {entry_name} (stage: {stage.value})")
            shader_info = self._compile_and_reflect_entry_point(entry_name, stage)
            if shader_info:
                shaders.append(shader_info)
            else:
                print(f"Failed to process entry point: {entry_name}")
        return shaders

    def _find_entry_points(self, source: str) -> Dict[str, Stage]:
        """Locate candidate entry-point functions in Slang/HLSL source.

        Heuristics are applied in priority order. A name found by an earlier
        step is not reassigned by a later step:

        1. ``[shader("vertex"|"fragment"|"pixel"|"compute")]`` attributes.
        2. Conventional names such as ``vertexMain`` or ``PS``.
        3. Functions returning void/float/float2/float3/float4 whose name
           contains a stage hint word (see ``_STAGE_NAME_HINTS``).
        4. Functions whose return semantic is SV_Position/POSITION (vertex)
           or SV_Target (fragment).
        5. A ``main`` function, assumed to be a vertex shader, but only if
           nothing else was found.

        Args:
            source: Full shader source text.

        Returns:
            Mapping from entry-point name (spelled as in the source) to
            stage, in discovery order.
        """
        entry_points: Dict[str, Stage] = {}

        # 1. Functions tagged with a Slang [shader("...")] attribute. Other
        #    attributes such as [numthreads(...)] may sit between the tag and
        #    the function, so skip them rather than capturing their name.
        for keyword, stage in _ATTRIBUTE_STAGES:
            pattern = (
                r'\[shader\s*\(\s*["\']' + keyword + r'["\']\s*\)\s*\]'
                r'(?:\s*\[[^\]]*\])*'      # any further [attribute(...)] blocks
                r'[^\[(;{]*?(\w+)\s*\('    # return type, then the function name
            )
            for match in re.finditer(pattern, source, re.IGNORECASE):
                entry_points[match.group(1)] = stage
                print(f"Found attributed entry point: {match.group(1)} -> {stage.value}")

        # 2. Conventional entry-point names. The search is case-insensitive,
        #    so record the name exactly as spelled in the source: the
        #    compiler is later invoked with this name, and Slang identifiers
        #    are case-sensitive.
        for canonical_name, stage in _COMMON_ENTRY_POINTS.items():
            pattern = r'\b(' + re.escape(canonical_name) + r')\s*\('
            match = re.search(pattern, source, re.IGNORECASE)
            if match is None:
                continue
            func_name = match.group(1)
            if func_name not in entry_points:
                entry_points[func_name] = stage
                print(f"Found common entry point: {func_name} -> {stage.value}")

        # 3. Infer the stage from the names of plausible entry-point functions.
        for match in _CANDIDATE_FUNCTION_RE.finditer(source):
            func_name = match.group(1)
            if func_name in entry_points:
                continue
            inferred = self._infer_stage_from_name(func_name)
            if inferred is not None:
                stage, label = inferred
                entry_points[func_name] = stage
                print(f"Inferred {label} shader: {func_name}")

        # 4. Functions carrying an HLSL return semantic.
        for regex, stage in _SEMANTIC_PATTERNS:
            for match in regex.finditer(source):
                func_name = match.group(1)
                if func_name not in entry_points:
                    entry_points[func_name] = stage
                    print(f"Found semantic-based entry point: {func_name} -> {stage.value}")

        # 5. Last resort: a main() function, assumed to be a vertex shader.
        if not entry_points:
            match = re.search(r'\b(main)\s*\(', source, re.IGNORECASE)
            if match is not None:
                entry_points[match.group(1)] = Stage.VERTEX
                print("Found main function, assuming vertex shader")

        print(f"Total entry points found: {len(entry_points)}")
        return entry_points

    @staticmethod
    def _infer_stage_from_name(func_name: str) -> Optional[tuple]:
        """Guess a stage from the words of *func_name*.

        Returns:
            A ``(stage, label)`` pair for the first matching entry of
            ``_STAGE_NAME_HINTS``, or ``None`` if no word hints at a stage.
        """
        words = [word.lower() for word in _NAME_TOKEN_RE.findall(func_name)]
        for stage, label, abbreviations, prefixes in _STAGE_NAME_HINTS:
            if any(word in abbreviations or word.startswith(prefixes) for word in words):
                return stage, label
        return None

    def _compile_and_reflect_entry_point(self, entry_name: str, stage: Stage) -> Optional[ShaderReflection]:
        """Compile one entry point with reflection enabled and load the result.

        Runs the command built by ``make_cmd`` with ``-reflection-json``
        appended. The reflection JSON is parsed into a ``ShaderReflection``
        whose ``blob`` holds the compiled output bytes.

        Args:
            entry_name: Entry-point function name passed to the compiler.
            stage: Pipeline stage for this entry point.

        Returns:
            The populated ``ShaderReflection``. Returns ``None`` if
            compilation failed, timed out (60 s), or produced no usable
            reflection JSON. The temporary files are removed on every path.
        """
        temp_paths: List[str] = []
        try:
            # Reserve two temp file names for the compiler to write into.
            # The handles are closed straight away; only the paths are kept.
            for suffix in ('.shader.tmp', '.json'):
                with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp:
                    temp_paths.append(tmp.name)
            temp_output_path, reflection_path = temp_paths

            cmd = make_cmd(global_vars.source_file, global_vars.target, stage, entry_name, temp_output_path)
            cmd.extend([
                '-reflection-json', reflection_path,  # reflection output
            ])
            print(f"Command: {' '.join(cmd)}")
            print(f"**Compiling** {entry_name}...")

            try:
                result = subprocess.run(
                    cmd,
                    capture_output=True,
                    text=True,
                    timeout=60
                )
            except subprocess.TimeoutExpired:
                # A hung compiler should fail this entry point, not the
                # whole parse.
                print(f"Compilation timed out for entry point {entry_name}")
                return None

            print(f"Return code: {result.returncode}")
            if result.stdout:
                print(f"Stdout: {result.stdout}")
            if result.stderr:
                print(f"Stderr: {result.stderr}")

            if result.returncode != 0:
                print(f"Compilation failed for entry point {entry_name}")
            elif not os.path.exists(reflection_path):
                print(f"Reflection file not created: {reflection_path}")
            else:
                shader_reflection = self._load_reflection(entry_name, reflection_path, temp_output_path)
                if shader_reflection is not None:
                    return shader_reflection

            # NOTE(review): no manual-parsing fallback is implemented here;
            # the caller just receives None.
            print(f"Falling back to manual parsing for {entry_name}")
            return None
        finally:
            # Best-effort cleanup that runs on success, failure and
            # exceptions alike. It must never mask the real outcome.
            for temp_file in temp_paths:
                try:
                    os.unlink(temp_file)
                except OSError:
                    pass

    @staticmethod
    def _load_reflection(entry_name: str, reflection_path: str, binary_path: str) -> Optional[ShaderReflection]:
        """Parse the reflection JSON and attach the compiled binary.

        Returns:
            A ``ShaderReflection`` with ``blob`` set to the bytes of
            *binary_path*. Returns ``None`` if the JSON file is empty or
            malformed.
        """
        with open(reflection_path, 'r', encoding='utf-8') as f:
            reflection_json = f.read()
        print(f"Reflection JSON length: {len(reflection_json)}")
        if not reflection_json.strip():
            print(f"Reflection JSON is empty for entry point {entry_name}")
            return None

        try:
            reflection = json.loads(reflection_json)
        except json.JSONDecodeError as e:
            print(f"JSON parsing error: {e}")
            print(f"Raw JSON: {reflection_json[:500]}...")
            return None

        shader_reflection = ShaderReflection.from_dict(reflection)
        # Attach the compiled shader binary to the reflection record.
        with open(binary_path, 'rb') as f:
            shader_reflection.blob = f.read()
        return shader_reflection