Files
mirage/tools/code_generator.py
daiqingshuang 358162415b feat: Add post-processing effects and canvas widget
- Implemented a fullscreen quad vertex shader for post-processing effects.
- Added a noise fragment shader to apply noise effects to textures.
- Created a vignette fragment shader for darkening corners of the screen.
- Developed a canvas widget to manage child widgets with anchor points and flexible sizing.
- Introduced an effect chain widget to apply multiple post-processing effects in sequence.
- Added overlay widget to position child widgets with alignment and padding options.
- Implemented a post effect widget to apply effects like blur, vignette, and color adjustments to child widgets.
- Provided convenience functions for easily applying common effects to widgets.
2025-11-27 09:52:00 +08:00

354 lines
12 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python3
"""
代码生成模块
负责生成C++代码包括结构体定义、SPIR-V数组、头文件和绑定信息。
使用Jinja2模板引擎代替字符串拼接。
"""
from __future__ import annotations
import struct
from typing import Dict, List
from .template_loader import get_renderer
from .type_mapping import (
calculate_std430_alignment,
calculate_std430_size,
spirv_type_to_cpp,
)
from .types import (
ArrayTypeInfo,
BufferInfo,
CompilationResult,
MemberInfo,
ShaderMetadata,
SPIRVReflection,
ToolError,
TypeInfo,
)
# ============ Structure Generation ============
def generate_buffer_structures(reflection: SPIRVReflection) -> str:
    """Render C++ struct definitions for every buffer/uniform block in *reflection*.

    Returns an empty string when the reflection has no buffers. Otherwise the
    result is a banner comment, then each struct (as produced by
    ``generate_single_struct``) followed by a blank line, joined with newlines.
    """
    if not reflection.buffers:
        return ""
    banner = [
        "// ============ Buffer Structures ============",
        "// Auto-generated from SPIR-V reflection",
        "",
    ]
    body: List[str] = []
    for buf in reflection.buffers:
        body += generate_single_struct(buf, reflection.types)
        body.append("")
    return "\n".join(banner + body)
def generate_single_struct(buffer: BufferInfo, type_map: Dict[int, TypeInfo]) -> List[str]:
    """Render one buffer's C++ struct definition and return it as a list of lines.

    The ``base/struct.jinja2`` template receives the buffer, the full type map
    and the std430 alignment computed for the buffer's struct type.
    """
    renderer = get_renderer()
    struct_alignment = calculate_std430_alignment(buffer.struct_type, type_map)
    rendered = renderer.render(
        'base/struct.jinja2',
        dict(buffer=buffer, type_map=type_map, alignment=struct_alignment),
    )
    return rendered.splitlines()
def generate_member_definition(member: MemberInfo, type_map: Dict[int, TypeInfo]) -> List[str]:
    """Render a single struct member declaration and return it as a list of lines.

    Delegates to the ``base/struct_member.jinja2`` template, passing the member
    and the full type map.
    """
    renderer = get_renderer()
    template_vars = dict(member=member, type_map=type_map)
    return renderer.render('base/struct_member.jinja2', template_vars).splitlines()
# ============ SPIR-V Array Generation ============
def spirv_to_cpp_array(spirv_data: bytes, symbol_name: str, words_per_line: int = 12) -> str:
    """Convert a SPIR-V binary into C++ ``uint32_t`` array source text.

    The binary is decoded into 32-bit words, each formatted as a zero-padded
    ``0x%08x`` literal. The literals are grouped into rows of *words_per_line*
    and rendered through the ``base/spirv_array.jinja2`` template.

    Words are decoded little-endian by default. The SPIR-V specification uses
    the magic number (0x07230203) as the module's endianness marker. If the
    first word equals the magic number only when read big-endian, the whole
    module is decoded big-endian, so the emitted literals are the real word
    values rather than byte-swapped ones.

    Args:
        spirv_data: Raw SPIR-V module bytes.
        symbol_name: C++ identifier for the generated array.
        words_per_line: Number of words per generated source row (default 12).

    Returns:
        The rendered C++ array definition.

    Raises:
        ToolError: If the data size is not a multiple of 4, or if
            *words_per_line* is not positive.
    """
    if len(spirv_data) % 4 != 0:
        raise ToolError(f"SPIR-V data size {len(spirv_data)} is not a multiple of 4")
    if words_per_line < 1:
        raise ToolError(f"words_per_line must be positive, got {words_per_line}")

    spirv_magic = 0x07230203
    word_count = len(spirv_data) // 4
    # A little-endian module starts with bytes 03 02 23 07, which reads as
    # 0x03022307 big-endian. Only a genuinely big-endian module matches here.
    byte_order = "<"
    if word_count and struct.unpack_from(">I", spirv_data)[0] == spirv_magic:
        byte_order = ">"
    words = struct.unpack(f"{byte_order}{word_count}I", spirv_data)

    # Rows of hex literals; the template lays them out one row per line.
    chunks = [
        [f"0x{word:08x}" for word in words[start : start + words_per_line]]
        for start in range(0, word_count, words_per_line)
    ]

    renderer = get_renderer()
    context = {
        'symbol_name': symbol_name,
        'chunks': chunks,
    }
    return renderer.render('base/spirv_array.jinja2', context)
# ============ 缓冲区辅助函数生成 ============
def generate_buffer_helper_functions(reflection: SPIRVReflection, shader_name: str) -> str:
    """Render create/upload/download helper functions for every reflected buffer.

    Returns an empty string when there are no buffers. Otherwise emits a banner
    comment, then for each buffer its create, upload and download helpers in
    that order, each section followed by a blank line.
    """
    if not reflection.buffers:
        return ""
    out: List[str] = [
        "// ============ 缓冲区辅助函数 ============",
        "// 自动生成缓冲区数据管理的便利函数",
        "",
    ]
    for buf in reflection.buffers:
        sections = (
            _generate_create_buffer_function(buf, shader_name),
            _generate_upload_function(buf),
            _generate_download_function(buf),
        )
        for section_lines in sections:
            out.extend(section_lines)
            out.append("")
    return "\n".join(out)
def _generate_create_buffer_function(buffer: BufferInfo, shader_name: str) -> List[str]:
    """Render the convenience "create buffer" helper for *buffer*.

    Uses ``buffer_helpers/create_buffer.jinja2`` with the buffer and the shader
    name, and returns the output split into lines.
    """
    renderer = get_renderer()
    template_vars = dict(buffer=buffer, shader_name=shader_name)
    return renderer.render('buffer_helpers/create_buffer.jinja2', template_vars).splitlines()
def _generate_upload_function(buffer: BufferInfo) -> List[str]:
    """Render the helper that uploads host data into *buffer*.

    Uses ``buffer_helpers/upload.jinja2`` and returns the output split into lines.
    """
    renderer = get_renderer()
    rendered = renderer.render('buffer_helpers/upload.jinja2', dict(buffer=buffer))
    return rendered.splitlines()
def _generate_download_function(buffer: BufferInfo) -> List[str]:
    """Render the helper that downloads *buffer* contents back to the host.

    Uses ``buffer_helpers/download.jinja2`` and returns the output split into lines.
    """
    renderer = get_renderer()
    rendered = renderer.render('buffer_helpers/download.jinja2', dict(buffer=buffer))
    return rendered.splitlines()
# ============ Typed Buffer Aliases Generation ============
def generate_typed_buffer_aliases(reflection: SPIRVReflection, shader_name: str) -> str:
    """Render type-safe buffer type aliases and factory functions for each buffer.

    Returns an empty string when there are no buffers. Otherwise emits a banner
    comment followed by each buffer's alias/factory section and a blank line.
    """
    if not reflection.buffers:
        return ""
    header = [
        "// ============ Typed Buffer Aliases ============",
        "// Type aliases and factory functions for type-safe buffer wrappers",
        "",
    ]
    per_buffer = [
        line
        for buf in reflection.buffers
        for line in (*_generate_buffer_type_alias_and_factory(buf, shader_name), "")
    ]
    return "\n".join(header + per_buffer)
def _generate_buffer_type_alias_and_factory(buffer: BufferInfo, shader_name: str) -> List[str]:
    """Render the type alias and factory function for one buffer.

    Uses ``typed_buffer/alias_and_factory.jinja2`` with the buffer and shader
    name, and returns the output split into lines.
    """
    renderer = get_renderer()
    template_vars = dict(buffer=buffer, shader_name=shader_name)
    rendered = renderer.render('typed_buffer/alias_and_factory.jinja2', template_vars)
    return rendered.splitlines()
# ============ Buffer Manager Class Generation ============
def generate_buffer_manager_class(reflection: SPIRVReflection, shader_name: str) -> str:
    """Render the buffer-manager C++ class that owns all of a shader's buffers.

    Returns an empty string when there are no buffers. Otherwise returns a
    banner comment, a blank line, and the ``buffer_manager/manager_class.jinja2``
    output. The class is named ``<shader_name>_buffer_manager``, and the
    template is told whether the reflected stage is ``"compute"``.
    """
    if not reflection.buffers:
        return ""
    compute_stage = reflection.shader_stage == "compute"
    renderer = get_renderer()
    manager_source = renderer.render(
        'buffer_manager/manager_class.jinja2',
        {
            'class_name': f"{shader_name}_buffer_manager",
            'buffers': reflection.buffers,
            'is_compute': compute_stage,
        },
    )
    banner = (
        "// ============ Buffer Manager Class ============\n"
        "// Automatic lifecycle management for all shader buffers\n"
        "\n"
    )
    return banner + manager_source
def _generate_buffers_struct(buffers: List[BufferInfo]) -> List[str]:
    """Build the legacy ``struct buffers`` container declaration, line by line.

    This path has been superseded by template rendering and is kept for
    compatibility. Each buffer contributes one member declared as
    ``<variable_name>_buffer <variable_name>;``.
    """
    member_lines = [
        f" {buf.variable_name}_buffer {buf.variable_name};" for buf in buffers
    ]
    return [
        " // Container for all buffers",
        " struct buffers {",
        *member_lines,
        " };",
    ]
def _generate_manager_factory(buffers: List[BufferInfo], class_name: str) -> List[str]:
    """Legacy hook for the manager's factory method; always returns an empty list.

    The factory is now produced by template rendering. This signature is kept
    only so external callers keep working; both arguments are ignored.
    """
    return list()
def _generate_initialize_method(buffers: List[BufferInfo]) -> List[str]:
    """Legacy hook for the manager's initialize method; always returns an empty list.

    Initialization code is now produced by template rendering. The signature is
    retained for external callers; the argument is ignored.
    """
    return list()
def _generate_bind_method(buffers: List[BufferInfo]) -> List[str]:
    """Legacy hook for the descriptor-set bind method; always returns an empty list.

    Binding code is now produced by template rendering. The signature is
    retained for external callers; the argument is ignored.
    """
    return list()
def _generate_manager_constructor(buffers: List[BufferInfo], class_name: str) -> List[str]:
    """Legacy hook for the manager's private constructor; always returns an empty list.

    The constructor is now produced by template rendering. The signature is
    retained for external callers; both arguments are ignored.
    """
    return list()
# ============ Vulkan Type Conversion ============
# Snake_case descriptor-type keys -> Vulkan-Hpp ``vk::DescriptorType`` enumerators.
_VK_DESCRIPTOR_TYPES: Dict[str, str] = {
    "storage_buffer": "vk::DescriptorType::eStorageBuffer",
    "uniform_buffer": "vk::DescriptorType::eUniformBuffer",
    "sampled_image": "vk::DescriptorType::eSampledImage",
    "storage_image": "vk::DescriptorType::eStorageImage",
    "sampler": "vk::DescriptorType::eSampler",
    # The remaining Vulkan descriptor kinds previously fell through to
    # eStorageBuffer, producing a wrong binding table (e.g. for texture-sampling
    # fragment shaders). NOTE(review): these key spellings follow the snake_case
    # pattern of the keys above; confirm that the reflection/metadata layer
    # emits exactly these names.
    "combined_image_sampler": "vk::DescriptorType::eCombinedImageSampler",
    "uniform_texel_buffer": "vk::DescriptorType::eUniformTexelBuffer",
    "storage_texel_buffer": "vk::DescriptorType::eStorageTexelBuffer",
    "uniform_buffer_dynamic": "vk::DescriptorType::eUniformBufferDynamic",
    "storage_buffer_dynamic": "vk::DescriptorType::eStorageBufferDynamic",
    "input_attachment": "vk::DescriptorType::eInputAttachment",
}

# Returned for unrecognized keys; kept as eStorageBuffer for backward compatibility.
_VK_DESCRIPTOR_TYPE_FALLBACK = "vk::DescriptorType::eStorageBuffer"


def vk_descriptor_type_to_cpp(descriptor_type: str) -> str:
    """Map a descriptor-type key to its Vulkan-Hpp C++ enumerator expression.

    Args:
        descriptor_type: Snake_case descriptor kind, e.g. ``"uniform_buffer"``
            or ``"combined_image_sampler"``.

    Returns:
        A string such as ``"vk::DescriptorType::eUniformBuffer"``. Unknown keys
        map to ``"vk::DescriptorType::eStorageBuffer"``, as before.
    """
    return _VK_DESCRIPTOR_TYPES.get(descriptor_type, _VK_DESCRIPTOR_TYPE_FALLBACK)
def vk_shader_stage_to_cpp(stages: List[str]) -> str:
    """Convert a list of stage names into an OR-ed Vulkan-Hpp stage-flags expression.

    Unknown stage names are skipped. An empty/None list, or a list with no
    recognized names, yields ``vk::ShaderStageFlagBits::eCompute``. Recognized
    flags are joined with ``" | "`` in input order.
    """
    flag_for_stage = {
        "vertex": "vk::ShaderStageFlagBits::eVertex",
        "fragment": "vk::ShaderStageFlagBits::eFragment",
        "compute": "vk::ShaderStageFlagBits::eCompute",
        "geometry": "vk::ShaderStageFlagBits::eGeometry",
        "tess_control": "vk::ShaderStageFlagBits::eTessellationControl",
        "tess_eval": "vk::ShaderStageFlagBits::eTessellationEvaluation",
    }
    selected: List[str] = []
    for stage_name in stages or ("compute",):
        flag = flag_for_stage.get(stage_name)
        if flag is not None:
            selected.append(flag)
    if selected:
        return " | ".join(selected)
    return "vk::ShaderStageFlagBits::eCompute"
# ============ Header File Generation ============
def generate_header(
    metadata: ShaderMetadata,
    compilation_results: List[CompilationResult],
) -> str:
    """Render the complete C++ header for one shader via ``base/header.jinja2``.

    Sections passed to the template:
      * buffer struct definitions, whenever the reflection has buffers;
      * buffer helper functions, typed buffer aliases and the buffer manager
        class, each additionally gated on its ``metadata.generate_*`` flag;
      * one ``uint32_t`` SPIR-V array per compilation result, named
        ``<name>_<shader_type>_spirv``;
      * the descriptor binding table built from ``metadata.bindings``.

    The first compilation result supplies ``entry_point`` and
    ``spirv_symbol_name``. With no results these default to ``"main"`` and ``""``.

    Args:
        metadata: Shader description (name, namespace, reflection, bindings,
            generation flags).
        compilation_results: Compiled stages; may be empty.

    Returns:
        The rendered header text.
    """
    name = metadata.name
    reflection = metadata.reflection
    # Evaluate the "reflection has buffers" test once and give the template a
    # real bool. Previously the raw `reflection and reflection.buffers` value
    # (None, [] or the buffer list) was passed through as `has_buffers`.
    has_buffers = bool(reflection and reflection.buffers)

    def spirv_symbol(shader_type: str) -> str:
        # Single source of truth for SPIR-V array symbol names.
        return f"{name}_{shader_type}_spirv"

    buffer_structures = generate_buffer_structures(reflection) if has_buffers else ""
    buffer_helpers = (
        generate_buffer_helper_functions(reflection, name)
        if has_buffers and metadata.generate_buffer_helpers
        else ""
    )
    typed_buffers = (
        generate_typed_buffer_aliases(reflection, name)
        if has_buffers and metadata.generate_typed_buffers
        else ""
    )
    buffer_manager = (
        generate_buffer_manager_class(reflection, name)
        if has_buffers and metadata.generate_buffer_manager
        else ""
    )

    # One embedded SPIR-V array per compiled stage.
    compilation_results_with_arrays = [
        {
            'shader_type': result.shader_type,
            'source_file': result.source_file,
            'spirv_array': spirv_to_cpp_array(result.spirv_data, spirv_symbol(result.shader_type)),
        }
        for result in compilation_results
    ]

    # Descriptor binding table; a missing/empty bindings list yields no entries.
    bindings_data = [
        {
            'binding': binding.binding,
            'descriptor_type_cpp': vk_descriptor_type_to_cpp(binding.descriptor_type),
            'count': binding.count,
            'stage_flags': vk_shader_stage_to_cpp(binding.stages),
        }
        for binding in (metadata.bindings or [])
    ]

    primary = compilation_results[0] if compilation_results else None

    renderer = get_renderer()
    context = {
        'namespace': metadata.namespace,
        'shader_name': name,
        'metadata_struct_name': f"{name}_metadata",
        'bindings_name': f"{name}_bindings",
        'entry_point': primary.entry_point if primary is not None else "main",
        'is_compute': metadata.shader_type == "compute",
        'spirv_symbol_name': spirv_symbol(primary.shader_type) if primary is not None else "",
        'has_buffers': has_buffers,
        'generate_typed_buffers': metadata.generate_typed_buffers,
        'buffer_structures': buffer_structures,
        'buffer_helpers': buffer_helpers,
        'typed_buffers': typed_buffers,
        'buffer_manager': buffer_manager,
        'compilation_results': compilation_results_with_arrays,
        'bindings': bindings_data,
    }
    return renderer.render('base/header.jinja2', context)