#!/usr/bin/env python3
"""
Slang Compiler - Shader Parser

Shader source code parsing functionality.
"""

import json
import os
import re
import subprocess
import tempfile
from typing import List, Dict, Optional

from compiler_cmd import make_cmd
from global_vars import global_vars
from shader_reflection_type import *

class ShaderParser:
    def parse_slang_shader(self) -> List[ShaderReflection]:
        """Parse the Slang shader source and extract resource information."""

        with open(global_vars.source_file, 'r', encoding='utf-8') as f:
            source = f.read()

        # First, scan the source for all entry points.
        entry_points = self._find_entry_points(source)
        if not entry_points:
            print("No entry points found in the shader source.")
            return []
        print(f"Found potential entry points: {entry_points}")

        shaders = []

        # Compile and reflect each entry point separately.
        for entry_name, stage in entry_points.items():
            print(f"\nProcessing entry point: {entry_name} (stage: {stage.value})")

            shader_info = self._compile_and_reflect_entry_point(
                entry_name, stage
            )

            if shader_info:
                shaders.append(shader_info)
            else:
                print(f"Failed to process entry point: {entry_name}")

        return shaders

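    # Entry-point discovery below works in several passes (the numbered steps
    # inside the method): explicit [shader("...")] attributes, common naming
    # conventions, name-based inference over plain declarations, HLSL output
    # semantics, and finally a bare `main` fallback.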
    def _find_entry_points(self, source: str) -> Dict[str, Stage]:
        """Find entry-point functions in the source code."""
        entry_points = {}

        # 1. Look for functions marked with a Slang [shader("...")] attribute.
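        # The patterns target attributed declarations such as (illustrative
        # example only, not taken from this repository):
        #
        #     [shader("vertex")]
        #     float4 vertexMain(float3 pos) : SV_Position { ... }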
        attribute_patterns = [
            (r'\[shader\s*\(\s*["\']vertex["\']\s*\)\s*\].*?(\w+)\s*\(', Stage.VERTEX),
            (r'\[shader\s*\(\s*["\']fragment["\']\s*\)\s*\].*?(\w+)\s*\(', Stage.FRAGMENT),
            (r'\[shader\s*\(\s*["\']pixel["\']\s*\)\s*\].*?(\w+)\s*\(', Stage.FRAGMENT),
            (r'\[shader\s*\(\s*["\']compute["\']\s*\)\s*\].*?(\w+)\s*\(', Stage.COMPUTE),
        ]

        for pattern, stage in attribute_patterns:
            matches = re.finditer(pattern, source, re.DOTALL | re.IGNORECASE)
            for match in matches:
                entry_points[match.group(1)] = stage
                print(f"Found attributed entry point: {match.group(1)} -> {stage.value}")

        # 2. Look for common entry-point naming conventions.
        common_entry_points = {
            # Vertex shaders
            'vertex_main': Stage.VERTEX,
            'vertexMain': Stage.VERTEX,
            'vert_main': Stage.VERTEX,
            'vs_main': Stage.VERTEX,
            'vertex_shader': Stage.VERTEX,
            'VS': Stage.VERTEX,

            # Fragment shaders
            'fragment_main': Stage.FRAGMENT,
            'fragmentMain': Stage.FRAGMENT,
            'frag_main': Stage.FRAGMENT,
            'pixel_main': Stage.FRAGMENT,
            'ps_main': Stage.FRAGMENT,
            'fragment_shader': Stage.FRAGMENT,
            'PS': Stage.FRAGMENT,

            # Compute shaders
            'compute_main': Stage.COMPUTE,
            'computeMain': Stage.COMPUTE,
            'comp_main': Stage.COMPUTE,
            'cs_main': Stage.COMPUTE,
            'compute_shader': Stage.COMPUTE,
            'CS': Stage.COMPUTE,
        }

        # Check whether any of these function names appear in the source.
        for func_name, stage in common_entry_points.items():
            if func_name not in entry_points:
                # Use a word boundary so only the complete function name matches.
                pattern = r'\b' + re.escape(func_name) + r'\s*\('
                if re.search(pattern, source, re.IGNORECASE):
                    entry_points[func_name] = stage
                    print(f"Found common entry point: {func_name} -> {stage.value}")

        # 3. Scan all function declarations and infer the stage from the
        #    return type and the function name.
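        # This matches plain declarations such as `float4 myVertexFunc(...)`
        # (illustrative example only); the stage is then guessed from keywords
        # in the function name.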
        function_pattern = r'(?:float4|void|float3|float2|float)\s+(\w+)\s*\([^)]*\)(?:\s*:\s*\w+)?'
        matches = re.finditer(function_pattern, source)

        for match in matches:
            func_name = match.group(1)
            if func_name not in entry_points:
                # Infer the stage from keywords in the function name.
                func_lower = func_name.lower()

                # Vertex shader keywords
                if any(keyword in func_lower for keyword in ['vert', 'vs', 'vertex']):
                    entry_points[func_name] = Stage.VERTEX
                    print(f"Inferred vertex shader: {func_name}")
                # Fragment shader keywords
                elif any(keyword in func_lower for keyword in ['frag', 'pixel', 'ps', 'fs', 'fragment']):
                    entry_points[func_name] = Stage.FRAGMENT
                    print(f"Inferred fragment shader: {func_name}")
                # Compute shader keywords
                elif any(keyword in func_lower for keyword in ['comp', 'cs', 'compute']):
                    entry_points[func_name] = Stage.COMPUTE
                    print(f"Inferred compute shader: {func_name}")

        # 4. Look for functions that declare HLSL semantics.
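        # For example (illustrative only), `float4 shade(PSInput i) : SV_Target`
        # would be treated as a fragment entry point by the second pattern below.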
        semantic_patterns = [
            (r'(\w+)\s*\([^)]*\)\s*:\s*SV_Position', Stage.VERTEX),
            (r'(\w+)\s*\([^)]*\)\s*:\s*SV_Target', Stage.FRAGMENT),
            (r'(\w+)\s*\([^)]*\)\s*:\s*POSITION', Stage.VERTEX),
        ]

        for pattern, stage in semantic_patterns:
            matches = re.finditer(pattern, source, re.IGNORECASE)
            for match in matches:
                func_name = match.group(1)
                if func_name not in entry_points:
                    entry_points[func_name] = stage
                    print(f"Found semantic-based entry point: {func_name} -> {stage.value}")

        # 5. If no entry point was found, fall back to a plain `main` function.
        if not entry_points:
            if re.search(r'\bmain\s*\(', source, re.IGNORECASE):
                entry_points['main'] = Stage.VERTEX
                print("Found main function, assuming vertex shader")

        print(f"Total entry points found: {len(entry_points)}")
        return entry_points

    def _compile_and_reflect_entry_point(self, entry_name: str, stage: Stage) -> Optional[ShaderReflection]:
        """Compile a single entry point and read back its reflection data."""

        # Create temporary files for the compiled output and the reflection JSON.
        with tempfile.NamedTemporaryFile(suffix='.shader.tmp', delete=False) as tmp_output:
            temp_output_path = tmp_output.name

        with tempfile.NamedTemporaryFile(suffix='.json', delete=False) as tmp_reflection:
            reflection_path = tmp_reflection.name

        cmd = make_cmd(global_vars.source_file, global_vars.target, stage, entry_name, temp_output_path)
        cmd.extend([
            '-reflection-json', reflection_path,  # reflection output
        ])

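        # `make_cmd` assembles the compiler invocation; a typical slangc call
        # might look roughly like the following (illustrative only, the real
        # flags depend on make_cmd and global_vars.target):
        #   slangc shader.slang -entry vsMain -stage vertex -o out.spv -reflection-json refl.json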
print(f"Command: {' '.join(cmd)}")
|
||
print(f"**Compiling** {entry_name}...")
|
||
result = subprocess.run(
|
||
cmd,
|
||
capture_output=True,
|
||
text=True,
|
||
timeout=60
|
||
)
|
||
|
||
print(f"Return code: {result.returncode}")
|
||
if result.stdout:
|
||
print(f"Stdout: {result.stdout}")
|
||
if result.stderr:
|
||
print(f"Stderr: {result.stderr}")
|
||
|
||
if result.returncode == 0:
|
||
# 读取反射数据
|
||
if os.path.exists(reflection_path):
|
||
with open(reflection_path, 'r', encoding='utf-8') as f:
|
||
reflection_json = f.read()
|
||
print(f"Reflection JSON length: {len(reflection_json)}")
|
||
|
||
if reflection_json.strip():
|
||
try:
|
||
reflection = json.loads(reflection_json)
|
||
shader_reflection = ShaderReflection.from_dict(reflection)
|
||
|
||
# 读取编译后的二进制数据
|
||
with open(temp_output_path, 'rb') as f:
|
||
shader_binary = f.read()
|
||
# 写入shader_info.blob
|
||
shader_reflection.blob = shader_binary
|
||
|
||
return shader_reflection
|
||
except json.JSONDecodeError as e:
|
||
print(f"JSON parsing error: {e}")
|
||
print(f"Raw JSON: {reflection_json[:500]}...")
|
||
else:
|
||
print(f"Reflection file not created: {reflection_path}")
|
||
else:
|
||
print(f"Compilation failed for entry point {entry_name}")
|
||
|
||
# 如果反射失败,尝试手动解析
|
||
print(f"Falling back to manual parsing for {entry_name}")
|
||
# 清理临时文件
|
||
for temp_file in [temp_output_path, reflection_path]:
|
||
if os.path.exists(temp_file):
|
||
os.unlink(temp_file)
|
||
|
||
return None |
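
# Minimal usage sketch (assumption: `global_vars` exposes writable
# `source_file` and `target` fields that are configured elsewhere, e.g. by a
# CLI front end; that setup is not shown in this file):
#
#     global_vars.source_file = 'example.slang'
#     parser = ShaderParser()
#     for shader in parser.parse_slang_shader():
#         print(shader.blob is not None)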