Files
mirage/tools/type_mapping.py
2025-11-24 22:35:44 +08:00

224 lines
8.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python3
"""
Type mapping module.

Maps SPIR-V types to C++ types (built on the Eigen math library) and computes
the std430 memory layout (alignment and size) of those types.

Eigen type mapping:
- Vectors: Eigen::Vector{N}{type} (e.g. Vector3f, Vector4i)
- Square matrices: Eigen::Matrix{N}{type} (e.g. Matrix3f, Matrix4f)
- Non-square matrices: Eigen::Matrix<type, rows, cols>
- Alignment: Eigen handles its own alignment requirements (via
  EIGEN_MAKE_ALIGNED_OPERATOR_NEW).
  NOTE(review): that macro only governs heap-allocation alignment; it does not
  make Eigen members follow std430 offsets/padding (e.g. a vec3 is 16-byte
  aligned in std430, while Eigen::Vector3f is not). Confirm that the code
  generator inserts padding from the std430 offsets computed in this module.
"""
from __future__ import annotations
from typing import Dict
from .types import (
ArrayTypeInfo,
BaseType,
MatrixTypeInfo,
ScalarTypeInfo,
StructTypeInfo,
TypeInfo,
VectorTypeInfo,
)
# ============ Type Conversion ============
def _cpp_scalar_name(scalar: ScalarTypeInfo) -> str | None:
    """Return the C++ spelling of a SPIR-V scalar type, or None if it is unmapped."""
    names = {
        (BaseType.FLOAT, 32): "float",
        (BaseType.FLOAT, 64): "double",
        (BaseType.DOUBLE, 64): "double",
        (BaseType.INT, 8): "int8_t",
        (BaseType.INT, 16): "int16_t",
        (BaseType.INT, 32): "int32_t",
        (BaseType.INT, 64): "int64_t",
        (BaseType.UINT, 8): "uint8_t",
        (BaseType.UINT, 16): "uint16_t",
        (BaseType.UINT, 32): "uint32_t",
        (BaseType.UINT, 64): "uint64_t",
        # Booleans are represented as 32-bit unsigned integers.
        (BaseType.BOOL, 32): "uint32_t",
    }
    return names.get((scalar.base_type, scalar.bit_width))


def _vector_scalar(vector: VectorTypeInfo, type_map: Dict[int, TypeInfo]) -> ScalarTypeInfo | None:
    """Return the scalar component type of *vector*, or None if it cannot be resolved."""
    component = vector.resolved_component_type
    if not isinstance(component, ScalarTypeInfo):
        component = type_map.get(vector.component_type_id)
    return component if isinstance(component, ScalarTypeInfo) else None


def spirv_type_to_cpp(type_info: TypeInfo, type_map: Dict[int, TypeInfo]) -> str:
    """
    Convert a SPIR-V type description into a C++ type string (Eigen-based).

    Mapping rules:
    - Scalars: float, double, int8/16/32/64_t, uint8/16/32/64_t; bool -> uint32_t
    - Vectors: Eigen::Vector{2,3,4}{f,d,i} for float/double/int32 components,
      otherwise the Eigen 3.4 template alias Eigen::Vector{N}<T>
      (e.g. Eigen::Vector3<uint32_t>)
    - Matrices: Eigen::Matrix{2,3,4}{f,d} when square, otherwise
      Eigen::Matrix<T, rows, cols>; T follows the column vector's component type
    - Fixed-size arrays: std::array<T, N>; runtime arrays: the element type T
    - Structs: the struct's name, or "Struct_<id>" when it has none

    Fallbacks: ``None`` and unmapped scalars/types give "uint32_t"; vectors whose
    component type is unresolvable/unmapped or whose component count is not
    2-4 give "Eigen::Vector4f"; matrices whose column type is unresolvable
    use float.

    Args:
        type_info: The type to convert (may be ``None``).
        type_map: SPIR-V result id -> TypeInfo, consulted when a nested type
            has not been resolved on the object itself.

    Returns:
        The C++ type spelling.
    """
    # Scalar spellings for which Eigen ships predefined typedefs
    # (Vector3f, Matrix4d, Vector2i, ...).
    eigen_suffixes = {"float": "f", "double": "d", "int32_t": "i"}

    if type_info is None:
        return "uint32_t"

    if isinstance(type_info, ScalarTypeInfo):
        return _cpp_scalar_name(type_info) or "uint32_t"

    if isinstance(type_info, VectorTypeInfo):
        scalar = _vector_scalar(type_info, type_map)
        scalar_cpp = _cpp_scalar_name(scalar) if scalar is not None else None
        count = type_info.component_count
        if scalar_cpp is None or count not in (2, 3, 4):
            return "Eigen::Vector4f"
        suffix = eigen_suffixes.get(scalar_cpp)
        if suffix:
            return f"Eigen::Vector{count}{suffix}"
        # No predefined typedef (unsigned or non-32-bit integers):
        # use the Eigen 3.4 template alias.
        return f"Eigen::Vector{count}<{scalar_cpp}>"

    if isinstance(type_info, MatrixTypeInfo):
        rows = type_info.row_count
        cols = type_info.column_count
        column = type_info.resolved_column_type
        if column is None:
            column = type_map.get(type_info.column_type_id)
        scalar = _vector_scalar(column, type_map) if isinstance(column, VectorTypeInfo) else None
        # SPIR-V matrix components are floating point; default to float when
        # the column type cannot be resolved.
        scalar_cpp = (_cpp_scalar_name(scalar) if scalar is not None else None) or "float"
        # Eigen matrices default to column-major storage, matching the
        # SPIR-V/GLSL default.
        # NOTE(review): std430 pads each 3-row column to 16 bytes (a mat3 occupies
        # 48 bytes), while Eigen::Matrix3f is tightly packed (36 bytes). Confirm
        # that the generator accounts for this difference.
        suffix = eigen_suffixes.get(scalar_cpp)
        if rows == cols and cols in (2, 3, 4) and suffix:
            return f"Eigen::Matrix{cols}{suffix}"
        # Eigen::Matrix<Scalar, RowsAtCompileTime, ColsAtCompileTime>
        return f"Eigen::Matrix<{scalar_cpp}, {rows}, {cols}>"

    if isinstance(type_info, ArrayTypeInfo):
        element = type_info.resolved_element_type
        if element is None:
            element = type_map.get(type_info.element_type_id)
        element_cpp = spirv_type_to_cpp(element, type_map)
        if type_info.is_runtime:
            # Runtime-sized array: only the element type is returned;
            # presumably the caller emits the unsized (T[]) part.
            return element_cpp
        # NOTE(review): the std430 array stride is the element size rounded up to
        # its alignment (16 for vec3), which std::array<Eigen::Vector3f, N>
        # (stride 12) does not reproduce.
        return f"std::array<{element_cpp}, {type_info.length}>"

    if isinstance(type_info, StructTypeInfo):
        return type_info.name or f"Struct_{type_info.id}"

    return "uint32_t"
# ============ Layout Calculation ============
def align_up(value: int, alignment: int) -> int:
    """Round ``value`` up to the next multiple of ``alignment`` (in bytes).

    An alignment of 0 means "no constraint": ``value`` is returned unchanged.
    """
    if alignment != 0:
        padded = value + (alignment - 1)
        value = (padded // alignment) * alignment
    return value
def calculate_std430_alignment(type_info: TypeInfo, type_map: Dict[int, TypeInfo]) -> int:
    """
    Compute the std430 base alignment of a type, in bytes.

    Rules applied:
    - scalar: bit_width / 8
    - 2-component vector: 2 x component alignment
    - 3- or 4-component vector: 4 x component alignment
    - matrix: alignment of its column vector type
    - array: alignment of its element type (std430 does not round up to 16)
    - struct: the largest alignment among its members (no 16-byte round-up
      and no artificial 4-byte floor)

    Args:
        type_info: Type to measure; ``None`` is treated as a 4-byte scalar.
        type_map: SPIR-V result id -> TypeInfo, consulted when a nested type
            is not already resolved on the object.

    Returns:
        Alignment in bytes. Fallbacks:
        - 16 for vector/matrix/array types whose nested type cannot be resolved,
          and for vectors with a component count other than 2-4;
        - 4 for unknown types, and for structs with no resolvable members.
    """
    if type_info is None:
        return 4

    if isinstance(type_info, ScalarTypeInfo):
        return type_info.bit_width // 8

    if isinstance(type_info, VectorTypeInfo):
        component = type_info.resolved_component_type
        if component is None:
            component = type_map.get(type_info.component_type_id)
        if component:
            component_align = calculate_std430_alignment(component, type_map)
            count = type_info.component_count
            if count == 2:
                return 2 * component_align
            if count in (3, 4):
                # vec3 is aligned like vec4.
                return 4 * component_align
        return 16

    if isinstance(type_info, MatrixTypeInfo):
        column = type_info.resolved_column_type
        if column is None:
            column = type_map.get(type_info.column_type_id)
        if column:
            return calculate_std430_alignment(column, type_map)
        return 16

    if isinstance(type_info, ArrayTypeInfo):
        element = type_info.resolved_element_type
        if element is None:
            element = type_map.get(type_info.element_type_id)
        if element:
            return calculate_std430_alignment(element, type_map)
        return 16

    if isinstance(type_info, StructTypeInfo):
        member_aligns = []
        for member in type_info.members:
            member_type = member.resolved_type
            if member_type is None:
                member_type = type_map.get(member.type_id)
            if member_type:
                member_aligns.append(calculate_std430_alignment(member_type, type_map))
        # std430: a struct is aligned to its most-aligned member. Starting at 4
        # would mis-align structs made only of 8/16-bit members.
        return max(member_aligns, default=4)

    return 4
def calculate_std430_size(type_info: TypeInfo, type_map: Dict[int, TypeInfo]) -> int:
    """
    Compute the std430 size of a type, in bytes.

    Rules applied:
    - scalar: bit_width / 8
    - vector: component_count x component size (tightly packed: vec3 = 12)
    - matrix: column_count x column stride, where the stride is the column
      size rounded up to the column alignment (a vec3 column takes 16 bytes)
    - fixed array: length x stride, with stride = element size rounded up to
      the element alignment; runtime arrays contribute 0
    - struct: furthest byte covered by any member (offset + size), rounded up
      to the struct's std430 alignment; an empty struct is 0

    Explicit MatrixStride/ArrayStride decorations are not consulted; strides
    are derived from the std430 rules above.

    Args:
        type_info: Type to measure; ``None`` is treated as a 4-byte scalar.
        type_map: SPIR-V result id -> TypeInfo, consulted when a nested type
            is not already resolved on the object.

    Returns:
        Size in bytes. Fallbacks for unresolvable nested types:
        - 16 for vectors and 64 for matrices;
        - 0 for arrays with no resolvable element or no length;
        - 4 for unresolvable struct members and unknown types.
    """
    if type_info is None:
        return 4

    if isinstance(type_info, ScalarTypeInfo):
        return type_info.bit_width // 8

    if isinstance(type_info, VectorTypeInfo):
        component = type_info.resolved_component_type
        if component is None:
            component = type_map.get(type_info.component_type_id)
        if component:
            return type_info.component_count * calculate_std430_size(component, type_map)
        return 16

    if isinstance(type_info, MatrixTypeInfo):
        column = type_info.resolved_column_type
        if column is None:
            column = type_map.get(type_info.column_type_id)
        if column:
            # Laid out like an array of column vectors: each column is padded
            # to its own alignment.
            column_stride = align_up(
                calculate_std430_size(column, type_map),
                calculate_std430_alignment(column, type_map),
            )
            return type_info.column_count * column_stride
        return 64

    if isinstance(type_info, ArrayTypeInfo):
        if type_info.is_runtime:
            return 0  # runtime-sized array: size unknown at compile time
        element = type_info.resolved_element_type
        if element is None:
            element = type_map.get(type_info.element_type_id)
        if element and type_info.length:
            stride = align_up(
                calculate_std430_size(element, type_map),
                calculate_std430_alignment(element, type_map),
            )
            return stride * type_info.length
        return 0

    if isinstance(type_info, StructTypeInfo):
        if not type_info.members:
            return 0
        # Use the furthest byte touched by ANY member, not just the last declared
        # one. Declaration order is not guaranteed to match Offset order in
        # SPIR-V (GLSL enforces it; other front ends may not).
        end = 0
        for member in type_info.members:
            member_type = member.resolved_type
            if member_type is None:
                member_type = type_map.get(member.type_id)
            member_size = calculate_std430_size(member_type, type_map) if member_type else 4
            offset = member.offset
            if offset is None:
                # NOTE(review): assumes a member without an Offset decoration has
                # offset None; place it std430-style after everything seen so far.
                offset = align_up(end, calculate_std430_alignment(member_type, type_map))
            end = max(end, offset + member_size)
        return align_up(end, calculate_std430_alignment(type_info, type_map))

    return 4