644 lines
22 KiB
Python
644 lines
22 KiB
Python
#!/usr/bin/env python3
|
||
"""
|
||
代码生成模块
|
||
|
||
负责生成C++代码,包括结构体定义、SPIR-V数组、头文件和绑定信息。
|
||
"""
|
||
|
||
from __future__ import annotations
|
||
|
||
import struct
|
||
from typing import Dict, List
|
||
|
||
from .type_mapping import (
|
||
calculate_std430_alignment,
|
||
calculate_std430_size,
|
||
spirv_type_to_cpp,
|
||
)
|
||
from .types import (
|
||
ArrayTypeInfo,
|
||
BufferInfo,
|
||
CompilationResult,
|
||
MemberInfo,
|
||
ShaderMetadata,
|
||
SPIRVReflection,
|
||
ToolError,
|
||
TypeInfo,
|
||
)
|
||
|
||
|
||
# ============ Structure Generation ============
|
||
|
||
def generate_buffer_structures(reflection: SPIRVReflection) -> str:
    """Render every buffer/uniform block found in *reflection* as a C++ struct.

    Each struct is produced by ``generate_single_struct`` and followed by a
    blank line. The whole section is prefixed with a banner comment.

    Returns:
        The section text, or an empty string when there are no buffers.
    """
    buffers = reflection.buffers
    if not buffers:
        return ""

    out: List[str] = [
        "// ============ Buffer Structures ============",
        "// Auto-generated from SPIR-V reflection",
        "",
    ]
    for buf in buffers:
        out += generate_single_struct(buf, reflection.types)
        out.append("")
    return "\n".join(out)
|
||
|
||
|
||
def generate_single_struct(buffer: BufferInfo, type_map: Dict[int, TypeInfo]) -> List[str]:
    """Build the C++ struct definition for one reflected buffer block.

    Args:
        buffer: Reflected buffer. Its ``struct_type`` supplies the members and
            its ``name`` becomes the C++ struct name.
        type_map: SPIR-V type id -> type info, used to resolve member types.

    Returns:
        The definition as a list of lines: a comment header (binding,
        descriptor set, descriptor type), then
        ``struct alignas(N) Name { ... };`` with each member line indented.
    """
    # Named ``struct_type`` (not ``struct``) so the module-level ``struct``
    # import is not shadowed inside this function.
    struct_type = buffer.struct_type
    alignment = calculate_std430_alignment(struct_type, type_map)

    lines = [
        "// Generated from SPIR-V reflection",
        f"// Binding: {buffer.binding}, Descriptor Set: {buffer.descriptor_set}",
        f"// Type: {buffer.descriptor_type}",
        f"struct alignas({alignment}) {buffer.name} {{",
    ]
    for member in struct_type.members:
        lines.extend(f" {line}" for line in generate_member_definition(member, type_map))
    lines.append("};")
    return lines
|
||
|
||
|
||
def generate_member_definition(member: MemberInfo, type_map: Dict[int, TypeInfo]) -> List[str]:
    """Emit the C++ declaration for one struct member, preceded by a layout comment.

    The member's type is ``member.resolved_type`` or, when that is ``None``,
    the entry for ``member.type_id`` in *type_map*. If no type is found, the
    std430 alignment and size are both assumed to be 4.

    Returns:
        Two lines:
        - A ``// Offset: .., Size: .., Alignment: ..`` comment. For runtime
          arrays the size is reported as ``runtime``.
        - The declaration. Runtime arrays are written as ``T name[];``,
          members aligned to more than 4 bytes get an explicit ``alignas``,
          and all others are a plain ``T name;``.
    """
    resolved = member.resolved_type
    if resolved is None:
        resolved = type_map.get(member.type_id)

    cpp_type = spirv_type_to_cpp(resolved, type_map)
    if resolved:
        align = calculate_std430_alignment(resolved, type_map)
        byte_size = calculate_std430_size(resolved, type_map)
    else:
        # Unresolvable type: fall back to a 4-byte scalar layout.
        align = byte_size = 4

    runtime_sized = isinstance(resolved, ArrayTypeInfo) and resolved.is_runtime

    size_text = "runtime" if runtime_sized else byte_size
    header = f"// Offset: {member.offset}, Size: {size_text}, Alignment: {align}"

    if runtime_sized:
        # Unsized trailing array stands in for the SPIR-V runtime array.
        decl = f"{cpp_type} {member.name}[]; // Runtime-sized array"
    elif align > 4:
        decl = f"alignas({align}) {cpp_type} {member.name};"
    else:
        decl = f"{cpp_type} {member.name};"

    return [header, decl]
|
||
|
||
|
||
# ============ SPIR-V Array Generation ============
|
||
|
||
def spirv_to_cpp_array(spirv_data: bytes, symbol_name: str) -> str:
    """Render a SPIR-V binary as a C++ ``static constexpr uint32_t`` array.

    Words are emitted as zero-padded lowercase hex, twelve per line, with no
    trailing comma after the last word.

    Args:
        spirv_data: Raw SPIR-V module bytes. Either byte order is accepted.
            The SPIR-V magic number in the first word is used to detect a
            big-endian module, whose words are then decoded big-endian so the
            emitted values are the logical word values.
        symbol_name: Name of the generated C++ array.

    Returns:
        The array definition as a single string (no trailing newline).

    Raises:
        ToolError: if the data is empty or its length is not a multiple of 4.
    """
    if len(spirv_data) % 4 != 0:
        raise ToolError(f"SPIR-V data size {len(spirv_data)} is not a multiple of 4")
    if not spirv_data:
        # ``uint32_t x[] = {}`` would declare a zero-length array, which is
        # ill-formed C++; fail here with a clear message instead.
        raise ToolError("SPIR-V data is empty")

    word_count = len(spirv_data) // 4
    # The SPIR-V magic number 0x07230203 reads as 0x03022307 when a
    # big-endian module is decoded as little-endian.
    (first_word,) = struct.unpack_from("<I", spirv_data)
    byte_order = ">" if first_word == 0x03022307 else "<"
    words = struct.unpack(f"{byte_order}{word_count}I", spirv_data)

    rows = [
        " " + ", ".join(f"0x{word:08x}" for word in words[start : start + 12])
        for start in range(0, word_count, 12)
    ]
    body = ",\n".join(rows)
    return f"static constexpr uint32_t {symbol_name}[] = {{\n{body}\n}};"
|
||
|
||
|
||
# ============ Buffer Helper Functions Generation ============
|
||
|
||
def generate_buffer_helper_functions(reflection: SPIRVReflection, shader_name: str) -> str:
    """Emit create/upload/download C++ helper functions for every buffer.

    For each buffer the three helpers are generated in that order, each one
    followed by a blank line, under a banner comment.

    Returns:
        The section text, or an empty string when there are no buffers.
    """
    if not reflection.buffers:
        return ""

    sections: List[str] = [
        "// ============ Buffer Helper Functions ============",
        "// Auto-generated convenience functions for buffer data management",
        "",
    ]
    for buf in reflection.buffers:
        generated = (
            _generate_create_buffer_function(buf, shader_name),
            _generate_upload_function(buf),
            _generate_download_function(buf),
        )
        for helper_lines in generated:
            sections.extend(helper_lines)
            sections.append("")
    return "\n".join(sections)
|
||
|
||
|
||
def _generate_create_buffer_function(buffer: BufferInfo, shader_name: str) -> List[str]:
|
||
"""生成创建 buffer 的便捷函数"""
|
||
lines = []
|
||
|
||
func_name = f"create_{buffer.variable_name}_buffer"
|
||
struct_name = buffer.name
|
||
|
||
lines.append(f"// Create buffer for {buffer.variable_name} (binding {buffer.binding})")
|
||
lines.append(f"inline auto {func_name}(")
|
||
lines.append(f" mirage::render::vulkan::resource_manager& rm,")
|
||
lines.append(f" size_t element_count")
|
||
lines.append(f") -> mirage::render::vulkan::expected<mirage::render::vulkan::resource_manager::buffer_resource>")
|
||
lines.append(f"{{")
|
||
|
||
# 确定 buffer 用途标志
|
||
usage_flags = "vk::BufferUsageFlagBits::eStorageBuffer | vk::BufferUsageFlagBits::eTransferDst | vk::BufferUsageFlagBits::eTransferSrc"
|
||
|
||
# 确定内存属性标志
|
||
memory_flags = "vk::MemoryPropertyFlagBits::eDeviceLocal | vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent"
|
||
|
||
lines.append(f" return rm.create_buffer(")
|
||
lines.append(f" element_count * sizeof({struct_name}),")
|
||
lines.append(f" {usage_flags},")
|
||
lines.append(f" {memory_flags}")
|
||
lines.append(f" );")
|
||
lines.append(f"}}")
|
||
|
||
return lines
|
||
|
||
|
||
def _generate_upload_function(buffer: BufferInfo) -> List[str]:
|
||
"""生成上传数据到 buffer 的函数"""
|
||
lines = []
|
||
|
||
func_name = f"upload_{buffer.variable_name}"
|
||
struct_name = buffer.name
|
||
|
||
lines.append(f"// Upload data to {buffer.variable_name} buffer")
|
||
lines.append(f"inline void {func_name}(")
|
||
lines.append(f" vk::Device device,")
|
||
lines.append(f" const mirage::render::vulkan::resource_manager::buffer_resource& buffer,")
|
||
lines.append(f" std::span<const {struct_name}> data")
|
||
lines.append(f")")
|
||
lines.append(f"{{")
|
||
lines.append(f" if (data.size_bytes() > buffer.size) {{")
|
||
lines.append(f" throw std::runtime_error(\"Data size exceeds buffer capacity\");")
|
||
lines.append(f" }}")
|
||
lines.append(f" ")
|
||
lines.append(f" void* mapped = device.mapMemory(buffer.memory, 0, buffer.size);")
|
||
lines.append(f" std::memcpy(mapped, data.data(), data.size_bytes());")
|
||
lines.append(f" device.unmapMemory(buffer.memory);")
|
||
lines.append(f"}}")
|
||
|
||
return lines
|
||
|
||
|
||
def _generate_download_function(buffer: BufferInfo) -> List[str]:
|
||
"""生成从 buffer 下载数据的函数"""
|
||
lines = []
|
||
|
||
func_name = f"download_{buffer.variable_name}"
|
||
struct_name = buffer.name
|
||
|
||
lines.append(f"// Download data from {buffer.variable_name} buffer")
|
||
lines.append(f"inline void {func_name}(")
|
||
lines.append(f" vk::Device device,")
|
||
lines.append(f" const mirage::render::vulkan::resource_manager::buffer_resource& buffer,")
|
||
lines.append(f" std::span<{struct_name}> out_data")
|
||
lines.append(f")")
|
||
lines.append(f"{{")
|
||
lines.append(f" if (out_data.size_bytes() > buffer.size) {{")
|
||
lines.append(f" throw std::runtime_error(\"Output buffer size exceeds buffer capacity\");")
|
||
lines.append(f" }}")
|
||
lines.append(f" ")
|
||
lines.append(f" void* mapped = device.mapMemory(buffer.memory, 0, buffer.size);")
|
||
lines.append(f" std::memcpy(out_data.data(), mapped, out_data.size_bytes());")
|
||
lines.append(f" device.unmapMemory(buffer.memory);")
|
||
lines.append(f"}}")
|
||
|
||
return lines
|
||
|
||
|
||
# ============ Typed Buffer Aliases Generation ============
|
||
|
||
def generate_typed_buffer_aliases(reflection: SPIRVReflection, shader_name: str) -> str:
    """Emit a typed_buffer alias and factory function for every buffer.

    Returns:
        The section text (banner comment, then one alias/factory group per
        buffer, each followed by a blank line), or an empty string when there
        are no buffers.
    """
    if not reflection.buffers:
        return ""

    chunks: List[str] = [
        "// ============ Typed Buffer Aliases ============",
        "// Type aliases and factory functions for type-safe buffer wrappers",
        "",
    ]
    for buf in reflection.buffers:
        chunks += _generate_buffer_type_alias_and_factory(buf, shader_name)
        chunks += [""]
    return "\n".join(chunks)
|
||
|
||
|
||
def _generate_buffer_type_alias_and_factory(buffer: BufferInfo, shader_name: str) -> List[str]:
|
||
"""为特定 buffer 生成类型别名和工厂函数"""
|
||
lines = []
|
||
|
||
struct_name = buffer.name
|
||
var_name = buffer.variable_name
|
||
alias_name = f"{var_name}_buffer"
|
||
factory_name = f"create_{var_name}_typed"
|
||
|
||
# 类型别名
|
||
lines.append(f"// Type alias for {var_name} typed buffer")
|
||
lines.append(f"using {alias_name} = mirage::render::vulkan::typed_buffer<{struct_name}>;")
|
||
lines.append("")
|
||
|
||
# 工厂函数
|
||
lines.append(f"// Create typed buffer for {var_name} (binding {buffer.binding})")
|
||
lines.append(f"inline auto {factory_name}(")
|
||
lines.append(f" vk::Device device,")
|
||
lines.append(f" mirage::render::vulkan::resource_manager& rm,")
|
||
lines.append(f" size_t count")
|
||
lines.append(f") -> mirage::render::vulkan::expected<{alias_name}>")
|
||
lines.append(f"{{")
|
||
lines.append(f" auto buffer_result = create_{var_name}_buffer(rm, count);")
|
||
lines.append(f" if (!buffer_result) {{")
|
||
lines.append(f" return std::unexpected(buffer_result.error());")
|
||
lines.append(f" }}")
|
||
lines.append(f" ")
|
||
lines.append(f" return {alias_name}(device, std::move(*buffer_result), count);")
|
||
lines.append(f"}}")
|
||
|
||
return lines
|
||
|
||
|
||
# ============ Buffer Manager Class Generation ============
|
||
|
||
def generate_buffer_manager_class(reflection: SPIRVReflection, shader_name: str) -> str:
    """Emit the ``<shader>_buffer_manager`` C++ class that owns all shader buffers.

    The public section contains, in order:
    - the nested ``buffers`` struct,
    - the static ``create`` factory,
    - ``initialize``,
    - ``bind_to_descriptor_set``,
    - the ``get_buffers`` accessors.

    The private section holds the constructor and the ``device_`` and
    ``buffers_`` members.

    Returns:
        The class text, or an empty string when there are no buffers.
    """
    buffers = reflection.buffers
    if not buffers:
        return ""

    class_name = f"{shader_name}_buffer_manager"
    out: List[str] = [
        "// ============ Buffer Manager Class ============",
        "// Automatic lifecycle management for all shader buffers",
        "",
        f"class {class_name} {{",
        "public:",
    ]

    # Each public piece is followed by a blank separator line.
    public_parts = (
        _generate_buffers_struct(buffers),
        _generate_manager_factory(buffers, class_name),
        _generate_initialize_method(buffers),
        _generate_bind_method(buffers),
    )
    for part in public_parts:
        out.extend(part)
        out.append("")

    out.extend([
        " // Access buffers",
        " auto& get_buffers() { return buffers_; }",
        " const auto& get_buffers() const { return buffers_; }",
        "",
        "private:",
    ])
    out.extend(_generate_manager_constructor(buffers, class_name))
    out.extend(["", " vk::Device device_;", " buffers buffers_;", "};"])
    return "\n".join(out)
|
||
|
||
|
||
def _generate_buffers_struct(buffers: List[BufferInfo]) -> List[str]:
|
||
"""生成 buffers 结构体"""
|
||
lines = []
|
||
|
||
lines.append(" // Container for all buffers")
|
||
lines.append(" struct buffers {")
|
||
|
||
for buffer in buffers:
|
||
var_name = buffer.variable_name
|
||
alias_name = f"{var_name}_buffer"
|
||
lines.append(f" {alias_name} {var_name};")
|
||
|
||
lines.append(" };")
|
||
|
||
return lines
|
||
|
||
|
||
def _generate_manager_factory(buffers: List[BufferInfo], class_name: str) -> List[str]:
|
||
"""生成管理器的工厂方法"""
|
||
lines = []
|
||
|
||
lines.append(" // Factory method to create buffer manager")
|
||
lines.append(" static auto create(")
|
||
lines.append(" vk::Device device,")
|
||
lines.append(" mirage::render::vulkan::resource_manager& rm,")
|
||
lines.append(" size_t element_count")
|
||
lines.append(f" ) -> mirage::render::vulkan::expected<{class_name}>")
|
||
lines.append(" {")
|
||
|
||
# 创建所有 buffers
|
||
for i, buffer in enumerate(buffers):
|
||
var_name = buffer.variable_name
|
||
factory_name = f"create_{var_name}_typed"
|
||
result_var = f"{var_name}_result"
|
||
|
||
lines.append(f" auto {result_var} = {factory_name}(device, rm, element_count);")
|
||
lines.append(f" if (!{result_var}) {{")
|
||
lines.append(f" return std::unexpected({result_var}.error());")
|
||
lines.append(f" }}")
|
||
if i < len(buffers) - 1:
|
||
lines.append("")
|
||
|
||
lines.append("")
|
||
lines.append(f" return {class_name}(")
|
||
lines.append(" device,")
|
||
|
||
for i, buffer in enumerate(buffers):
|
||
var_name = buffer.variable_name
|
||
result_var = f"{var_name}_result"
|
||
comma = "," if i < len(buffers) - 1 else ""
|
||
lines.append(f" std::move(*{result_var}){comma}")
|
||
|
||
lines.append(" );")
|
||
lines.append(" }")
|
||
|
||
return lines
|
||
|
||
|
||
def _generate_initialize_method(buffers: List[BufferInfo]) -> List[str]:
|
||
"""生成初始化方法"""
|
||
lines = []
|
||
|
||
lines.append(" // Initialize all buffers with data")
|
||
lines.append(" void initialize(")
|
||
|
||
for i, buffer in enumerate(buffers):
|
||
var_name = buffer.variable_name
|
||
struct_name = buffer.name
|
||
comma = "," if i < len(buffers) - 1 else ""
|
||
lines.append(f" std::span<const {struct_name}> {var_name}_data{comma}")
|
||
|
||
lines.append(" ) {")
|
||
|
||
for buffer in buffers:
|
||
var_name = buffer.variable_name
|
||
lines.append(f" buffers_.{var_name}.upload({var_name}_data);")
|
||
|
||
lines.append(" }")
|
||
|
||
return lines
|
||
|
||
|
||
def _generate_bind_method(buffers: List[BufferInfo]) -> List[str]:
|
||
"""生成绑定到 descriptor set 的方法"""
|
||
lines = []
|
||
|
||
lines.append(" // Bind all buffers to a descriptor set")
|
||
lines.append(" void bind_to_descriptor_set(")
|
||
lines.append(" const mirage::render::vulkan::compute_pipeline& pipeline,")
|
||
lines.append(" vk::DescriptorSet descriptor_set")
|
||
lines.append(" ) const {")
|
||
|
||
for buffer in buffers:
|
||
var_name = buffer.variable_name
|
||
binding = buffer.binding
|
||
descriptor_type = "vk::DescriptorType::eStorageBuffer" if buffer.descriptor_type == "storage_buffer" else "vk::DescriptorType::eUniformBuffer"
|
||
|
||
lines.append(f" pipeline.bind_buffer(")
|
||
lines.append(f" descriptor_set,")
|
||
lines.append(f" {binding},")
|
||
lines.append(f" buffers_.{var_name}.get_buffer(),")
|
||
lines.append(f" {descriptor_type}")
|
||
lines.append(f" );")
|
||
|
||
lines.append(" }")
|
||
|
||
return lines
|
||
|
||
|
||
def _generate_manager_constructor(buffers: List[BufferInfo], class_name: str) -> List[str]:
|
||
"""生成管理器的私有构造函数"""
|
||
lines = []
|
||
|
||
lines.append(f" {class_name}(")
|
||
lines.append(" vk::Device device,")
|
||
|
||
for i, buffer in enumerate(buffers):
|
||
var_name = buffer.variable_name
|
||
alias_name = f"{var_name}_buffer"
|
||
comma = "," if i < len(buffers) - 1 else ""
|
||
lines.append(f" {alias_name} {var_name}{comma}")
|
||
|
||
lines.append(" ) : device_(device),")
|
||
lines.append(" buffers_{")
|
||
|
||
for i, buffer in enumerate(buffers):
|
||
var_name = buffer.variable_name
|
||
comma = "," if i < len(buffers) - 1 else ""
|
||
lines.append(f" .{var_name} = std::move({var_name}){comma}")
|
||
|
||
lines.append(" } {}")
|
||
|
||
return lines
|
||
|
||
|
||
# ============ Vulkan Type Conversion ============
|
||
|
||
def vk_descriptor_type_to_cpp(descriptor_type: str) -> str:
    """Map a reflected descriptor type name to its ``vk::DescriptorType`` enumerator.

    Besides the original five names, the remaining core Vulkan descriptor
    types are mapped as well. Previously these silently fell back to
    ``eStorageBuffer``, which produced a wrong binding table if reflection
    ever reported them.

    Unknown names still fall back to ``vk::DescriptorType::eStorageBuffer``,
    which existing callers may rely on.
    """
    type_map = {
        "storage_buffer": "vk::DescriptorType::eStorageBuffer",
        "uniform_buffer": "vk::DescriptorType::eUniformBuffer",
        "sampled_image": "vk::DescriptorType::eSampledImage",
        "storage_image": "vk::DescriptorType::eStorageImage",
        "sampler": "vk::DescriptorType::eSampler",
        "combined_image_sampler": "vk::DescriptorType::eCombinedImageSampler",
        "uniform_texel_buffer": "vk::DescriptorType::eUniformTexelBuffer",
        "storage_texel_buffer": "vk::DescriptorType::eStorageTexelBuffer",
        "uniform_buffer_dynamic": "vk::DescriptorType::eUniformBufferDynamic",
        "storage_buffer_dynamic": "vk::DescriptorType::eStorageBufferDynamic",
        "input_attachment": "vk::DescriptorType::eInputAttachment",
    }
    return type_map.get(descriptor_type, "vk::DescriptorType::eStorageBuffer")
|
||
|
||
|
||
def vk_shader_stage_to_cpp(stages: List[str]) -> str:
    """Combine stage names into a ``vk::ShaderStageFlagBits`` OR-expression.

    An empty or ``None`` list is treated as ``["compute"]``. Unknown names are
    skipped. If nothing recognised remains, the result is the compute stage
    flag.
    """
    flag_for = {
        "vertex": "vk::ShaderStageFlagBits::eVertex",
        "fragment": "vk::ShaderStageFlagBits::eFragment",
        "compute": "vk::ShaderStageFlagBits::eCompute",
        "geometry": "vk::ShaderStageFlagBits::eGeometry",
        "tess_control": "vk::ShaderStageFlagBits::eTessellationControl",
        "tess_eval": "vk::ShaderStageFlagBits::eTessellationEvaluation",
    }
    requested = stages or ["compute"]
    flags = [flag_for[name] for name in requested if name in flag_for]
    return " | ".join(flags) if flags else "vk::ShaderStageFlagBits::eCompute"
|
||
|
||
|
||
# ============ Header File Generation ============
|
||
|
||
def generate_header(
    metadata: ShaderMetadata,
    compilation_results: List[CompilationResult],
) -> str:
    """Generate the complete C++ header for a compiled shader.

    The header contains, in order:
    - the includes;
    - an optional namespace;
    - buffer struct definitions, plus the optional buffer helpers,
      typed-buffer aliases and buffer-manager class (only when reflection
      found buffers);
    - one SPIR-V word array per compilation result;
    - a ``<name>_metadata`` struct;
    - an optional binding table;
    - for compute shaders, a pipeline factory function.

    The optional buffer sections build on each other. The manager class uses
    the ``<var>_buffer`` aliases and ``create_<var>_typed`` factories, and
    those factories call the ``create_<var>_buffer`` helpers. Requesting a
    section therefore also emits the sections it depends on, so the header
    compiles.

    Args:
        metadata: Shader description (name, namespace, reflection, bindings,
            shader type and generation flags).
        compilation_results: Compiled stages. The first entry supplies the
            metadata entry point and, for compute shaders, the pipeline SPIR-V.

    Returns:
        The header text.

    Raises:
        ToolError: if ``compilation_results`` is empty, or if a SPIR-V blob
            is rejected by ``spirv_to_cpp_array``.
    """
    if not compilation_results:
        raise ToolError(f"No compilation results to generate a header for shader '{metadata.name}'")

    reflection = metadata.reflection
    has_buffers = bool(reflection and reflection.buffers)
    # Resolve section dependencies: manager -> typed buffers -> helpers.
    # Every buffer section is skipped when there is no reflection/buffers;
    # calling the generators with ``reflection=None`` would raise.
    want_manager = has_buffers and bool(metadata.generate_buffer_manager)
    want_typed = has_buffers and (bool(metadata.generate_typed_buffers) or want_manager)
    want_helpers = has_buffers and (bool(metadata.generate_buffer_helpers) or want_typed)

    lines = [
        "#pragma once",
        "",
        "#include <cstdint>",
        "#include <span>",
        "#include <array>",
        # std::string_view is used by the metadata struct below.
        "#include <string_view>",
    ]
    if want_helpers:
        # size_t, std::memcpy and std::runtime_error appear in the helpers.
        lines += ["#include <cstddef>", "#include <cstring>", "#include <stdexcept>"]
    if want_typed:
        # std::move appears in the typed-buffer factories / manager.
        lines.append("#include <utility>")
    lines += ["#include <vulkan/vulkan.hpp>", "#include <Eigen/Eigen>"]
    if want_typed:
        lines.append('#include "vulkan/typed_buffer.h"')
    lines.append("")

    if metadata.namespace:
        lines += [f"namespace {metadata.namespace} {{", ""]

    # Buffer-related sections (structs first: everything else refers to them).
    if has_buffers:
        lines += [generate_buffer_structures(reflection), ""]
    if want_helpers:
        lines += [generate_buffer_helper_functions(reflection, metadata.name), ""]
    if want_typed:
        lines += [generate_typed_buffer_aliases(reflection, metadata.name), ""]
    if want_manager:
        lines += [generate_buffer_manager_class(reflection, metadata.name), ""]

    # SPIR-V word arrays, one per compiled stage.
    for result in compilation_results:
        symbol_name = f"{metadata.name}_{result.shader_type}_spirv"
        lines += [
            f"// {result.shader_type.capitalize()} shader: {result.source_file.name}",
            "// Compiled with: glslc ... --target-env=vulkan1.2",
            spirv_to_cpp_array(result.spirv_data, symbol_name),
            "",
        ]

    # Static metadata.
    lines += [
        f"struct {metadata.name}_metadata {{",
        f' static constexpr std::string_view name = "{metadata.name}";',
        f' static constexpr std::string_view entry_point = "{compilation_results[0].entry_point}";',
        "};",
        "",
    ]

    if metadata.bindings:
        lines += _generate_bindings_array(metadata)

    if metadata.shader_type == "compute":
        lines += _generate_pipeline_factory(metadata, compilation_results[0])

    if metadata.namespace:
        lines += [f"}} // namespace {metadata.namespace}", ""]

    return "\n".join(lines)


def _generate_bindings_array(metadata: ShaderMetadata) -> List[str]:
    """Emit the ``<name>_bindings`` constexpr descriptor-binding table."""
    lines = [
        f"inline constexpr std::array<mirage::render::vulkan::compute_pipeline::binding_info, {len(metadata.bindings)}>",
        f"{metadata.name}_bindings = {{{{",
    ]
    for binding in metadata.bindings:
        lines += [
            " mirage::render::vulkan::compute_pipeline::binding_info{",
            f" .binding = {binding.binding},",
            f" .type = {vk_descriptor_type_to_cpp(binding.descriptor_type)},",
            f" .count = {binding.count},",
            f" .stage = {vk_shader_stage_to_cpp(binding.stages)},",
            " },",
        ]
    lines += ["}};", ""]
    return lines


def _generate_pipeline_factory(metadata: ShaderMetadata, result: CompilationResult) -> List[str]:
    """Emit ``create_<name>_pipeline``, which builds a compute pipeline from the SPIR-V array."""
    spirv_array = f"{metadata.name}_{result.shader_type}_spirv"
    if metadata.bindings:
        bindings_arg = f" {metadata.name}_bindings"
    else:
        # A plain ``{}`` value-initialises an empty span. The former ``{{}}``
        # could select span's initializer_list constructor (C++26) and yield
        # a dangling one-element span.
        bindings_arg = " std::span<const mirage::render::vulkan::compute_pipeline::binding_info>{}"
    return [
        "// ============ Factory Function ============",
        f"inline auto create_{metadata.name}_pipeline(const mirage::render::vulkan::logical_device& device)",
        " -> mirage::render::vulkan::expected<mirage::render::vulkan::compute_pipeline>",
        "{",
        " return mirage::render::vulkan::compute_pipeline::create_from_spv(",
        " device,",
        f" std::span<const uint32_t>({spirv_array}),",
        bindings_arg,
        " );",
        "}",
        "",
    ]