自动生成绑定句柄对象

This commit is contained in:
daiqingshuang
2025-06-10 15:51:08 +08:00
parent 274af12589
commit 8e65684dd1
6 changed files with 50 additions and 16 deletions

View File

@@ -256,7 +256,14 @@ class BindingManager:
binding_info = {
'stage': shader_info.stage.value,
'entry_point': shader_info.entry_point,
'resources': []
'resources': [],
'num_resources': dict[ResourceType, int]({
ResourceType.SAMPLED_TEXTURE: 0,
ResourceType.STORAGE_TEXTURE: 0,
ResourceType.STORAGE_BUFFER: 0,
ResourceType.UNIFORM_BUFFER: 0,
ResourceType.SAMPLER: 0
})
}
# 读取output_path中的二进制数据
@@ -281,5 +288,6 @@ class BindingManager:
res_info['index'] = resource.metal_index
binding_info['resources'].append(res_info)
binding_info['num_resources'][resource.type] += 1
return binding_info

View File

@@ -8,6 +8,9 @@ from typing import List, Dict, TextIO, Optional
from pathlib import Path
from contextlib import contextmanager
from global_vars import global_vars
from shader_types import *
class IndentManager:
"""RAII风格的缩进管理器"""
@@ -65,11 +68,11 @@ class CodeGenerator:
'sampler': 'Resource::Sampler'
}
def generate_binding_functions(self, source_file_pathname, binding_infos: List[Dict], output_path: str) -> None:
def generate_binding_functions(self, binding_infos: List[Dict], output_path: str) -> None:
"""生成C++绑定函数的入口方法"""
self._generate_cpp_bindings(source_file_pathname, binding_infos, output_path)
self._generate_cpp_bindings(binding_infos, output_path)
def _generate_cpp_bindings(self, source_file_pathname, binding_infos: List[Dict], output_path: str) -> None:
def _generate_cpp_bindings(self, binding_infos: List[Dict], output_path: str) -> None:
"""生成C++绑定函数"""
output_file = Path(output_path)
# 尝试创建输出目录
@@ -77,16 +80,20 @@ class CodeGenerator:
with open(output_file, 'w', encoding='utf-8') as file:
writer = IndentManager(file)
self._write_complete_file(writer, source_file_pathname, binding_infos)
self._write_complete_file(writer, binding_infos)
def _write_complete_file(self, writer: IndentManager, source_file_pathname, binding_infos: List[Dict]) -> None:
def _write_complete_file(self, writer: IndentManager, binding_infos: List[Dict]) -> None:
"""写入完整的文件内容"""
self._write_header(writer)
with writer.block('namespace SDL3GPU', '} // namespace SDL3GPU'):
writer.write()
self._write_shader_bindings_class(writer, source_file_pathname, binding_infos)
self._write_shader_bindings_class(writer, binding_infos)
writer.write()
if len(binding_infos) == 2:
self._write_handle_class(writer, 'pixel_shader_handle_t', binding_infos)
else:
self._write_handle_class(writer, 'compute_shader_handle_t', binding_infos)
def _write_header(self, writer: IndentManager) -> None:
"""写入文件头部"""
@@ -97,18 +104,18 @@ class CodeGenerator:
'#include <vector>',
'#include <unordered_map>',
'#include <cstdint>',
'#include "shader_handle.h"',
]
for header in headers:
writer.write(header)
writer.write()
def _write_shader_bindings_class(self, writer: IndentManager, source_file_pathname: str, binding_infos: List[Dict]) -> None:
def _write_shader_bindings_class(self, writer: IndentManager, binding_infos: List[Dict]) -> None:
"""写入ShaderBindings类"""
# 获取源文件名(不带路径和拓展名)
source_file_name = Path(source_file_pathname).stem
with writer.block(f'class {source_file_name}ShaderBindings','};'):
with writer.block(f'class {global_vars.source_file_name}_shader_bindings','};'):
self._write_public_section(writer, binding_infos)
self._write_private_section(writer)
self._write_constructor(writer, binding_infos)
@@ -165,7 +172,7 @@ class CodeGenerator:
writer.write('public:')
with writer.indent():
with writer.block('ShaderBindings()'):
with writer.block(f'{global_vars.source_file_name}_shader_bindings()'):
for info in binding_infos:
self._write_shader_initialization(writer, info)
@@ -271,3 +278,22 @@ class CodeGenerator:
with writer.block('const ShaderInfo* getShaderInfo(const std::string& shaderName) const'):
writer.write('auto it = m_shaderInfos.find(shaderName);')
writer.write('return it != m_shaderInfos.end() ? &it->second : nullptr;')
def _write_handle_class(self, writer: IndentManager, parent_class: str, shader_info: List[Dict]) -> None:
"""写入ShaderHandle类"""
# class test_shader_handle_t: public pixel_shader_handle_t
with writer.block(f'class {global_vars.source_file_name}_shader_handle_t : public {parent_class}', '};'):
writer.write('protected:')
for info in shader_info:
with writer.indent(f'virtual SDL_GPUShader* create_{info['stage'].lower()}_shader(SDL_GPUDevice* in_gpu_device) override ''{', '}'):
writer.write('SDL_GPUShaderCreateInfo info{};')
writer.write(f'info.code = {global_vars.source_file_name}_shader_bindings::{info["entry_point"]}_blob;')
writer.write(f'info.code_size = sizeof({global_vars.source_file_name}_shader_bindings::{info["entry_point"]}_blob);')
writer.write(f'info.entrypoint = "{info["entry_point"]}";')
writer.write(f'info.format = SDL_GPU_SHADERFORMAT_{global_vars.target.value.upper()};')
writer.write(f'info.stage = SDL_GPU_SHADERSTAGE_{info['stage'].upper()};')
writer.write(f'info.num_samplers = {info['num_resources'][ResourceType.SAMPLER]};')
writer.write(f'info.num_storage_textures = {info['num_resources'][ResourceType.STORAGE_TEXTURE]};')
writer.write(f'info.num_storage_buffers = {info['num_resources'][ResourceType.STORAGE_BUFFER]};')
writer.write(f'info.num_uniform_buffers = {info['num_resources'][ResourceType.UNIFORM_BUFFER]};')
writer.write('return SDL_CreateGPUShader(in_gpu_device, &info);')

View File

@@ -64,6 +64,6 @@ class SDL3GPUSlangCompiler:
os.unlink(output_path)
def generate_binding_functions(self, source_file_pathname, binding_infos: List[Dict], output_path: str):
def generate_binding_functions(self, binding_infos: List[Dict], output_path: str):
"""生成C/C++绑定函数"""
self.code_generator.generate_binding_functions(source_file_pathname, binding_infos, output_path)
self.code_generator.generate_binding_functions(binding_infos, output_path)

View File

@@ -62,7 +62,7 @@ def main():
# 生成绑定代码
print(f"\n**Generating** binding code to {binding_output_file_pathname}...")
compiler.generate_binding_functions(os.path.abspath(args.input), binding_infos, binding_output_file_pathname)
compiler.generate_binding_functions(binding_infos, binding_output_file_pathname)
print("**Done!**")
if __name__ == '__main__':

View File

@@ -225,7 +225,7 @@ class ShaderParser:
stage=stage,
entry_point=entry_name,
resources=[],
source_code=source
source_code=source,
)
# Slang反射数据的可能结构

View File

@@ -6,7 +6,7 @@ SDL3_GPU Slang Compiler - Type Definitions
from enum import Enum
from dataclasses import dataclass, field
from typing import List, Optional, Dict
from typing import List, Optional, Dict, Mapping
class ShaderStage(Enum):