69 lines
2.4 KiB
Python
69 lines
2.4 KiB
Python
#!/usr/bin/env python3
|
||
"""
|
||
SDL3_GPU Slang Compiler - Main Compiler
|
||
主编译器类,整合所有功能模块
|
||
"""
|
||
|
||
import os
|
||
import subprocess
|
||
import tempfile
|
||
from typing import List, Dict
|
||
|
||
from binding_manager import BindingManager
|
||
from code_generator import CodeGenerator
|
||
from compiler_cmd import make_cmd
|
||
from shader_parser import ShaderParser
|
||
from shader_types import ShaderInfo, TargetFormat
|
||
|
||
|
||
class SDL3GPUSlangCompiler:
    """Main compiler facade that wires the individual modules together.

    The instance owns one ``ShaderParser``, one ``BindingManager`` and one
    ``CodeGenerator`` and forwards work to them.
    """

    def __init__(self):
        self.parser = ShaderParser()
        self.binding_manager = BindingManager()
        self.code_generator = CodeGenerator()

    def parse_slang_shader(self) -> Dict[str, ShaderInfo]:
        """Parse the Slang shader source and extract its resource information.

        Thin wrapper around ``ShaderParser.parse_slang_shader``.
        """
        return self.parser.parse_slang_shader()

    def compile_shader(self, shader_info: ShaderInfo, target: TargetFormat) -> Dict:
        """Compile one shader for ``target`` and return its binding information.

        Binding slots are assigned for the requested target, the source with
        injected bindings is written to a temporary ``.slang`` file, the
        command built by ``make_cmd`` is executed, and the result of
        ``BindingManager.generate_binding_info`` is returned.

        Both temporary files (the generated source and the compiler output)
        are removed before this method returns, so no output path is handed
        back to the caller.

        Raises:
            subprocess.CalledProcessError: if the compiler exits with a
                non-zero status.
        """
        # Assign binding slots according to the target format.
        if target == TargetFormat.SPIRV:
            self.binding_manager.assign_bindings_spirv(shader_info)
        elif target in [TargetFormat.DXIL, TargetFormat.DXBC]:
            self.binding_manager.assign_bindings_dxil(shader_info)
        elif target == TargetFormat.MSL:
            self.binding_manager.assign_bindings_msl(shader_info)

        # Produce the shader source with the binding annotations injected.
        modified_source = self.binding_manager.inject_bindings(shader_info, target)

        # Every temporary path is registered here as soon as it exists so the
        # ``finally`` clause can remove exactly what was created.
        temp_paths: List[str] = []
        try:
            # Write the modified source to a temporary file for the compiler.
            with tempfile.NamedTemporaryFile(
                mode='w', suffix='.slang', delete=False, encoding='utf8'
            ) as tmp:
                tmp_path = tmp.name
                temp_paths.append(tmp_path)
                tmp.write(modified_source)

            # mkstemp() creates the file atomically, unlike the deprecated and
            # race-prone mktemp(); the compiler overwrites it.
            out_fd, output_path = tempfile.mkstemp()
            temp_paths.append(output_path)
            os.close(out_fd)

            # Compile the shader.
            cmd = make_cmd(tmp_path, target, shader_info.stage, shader_info.entry_point, output_path)
            print(f"Compiling shader with command: {' '.join(cmd)}")

            subprocess.run(cmd, check=True)
            print("Shader compiled successfully")

            # Build the binding information while the output file still exists.
            return self.binding_manager.generate_binding_info(shader_info, target, output_path)
        finally:
            # Best-effort cleanup: a file that is already gone must not hide
            # the exception (e.g. a compiler failure) that is propagating.
            for path in temp_paths:
                try:
                    os.unlink(path)
                except FileNotFoundError:
                    pass

    def generate_binding_functions(self, binding_infos: List[Dict], output_path: str):
        """Generate the C/C++ binding functions.

        Thin wrapper around ``CodeGenerator.generate_binding_functions``.
        """
        self.code_generator.generate_binding_functions(binding_infos, output_path)