172 lines
5.5 KiB
Python
172 lines
5.5 KiB
Python
from typing import Tuple
|
||
|
||
from shader_reflection_type import Field, TypeKind, ScalarType
|
||
|
||
# Maps (scalar type name, component count) to the name of the matching
# SG_VERTEXFORMAT_* constant. Combinations without a usable vertex format
# are mapped to SG_VERTEXFORMAT_INVALID.
format_mapping = {
    # 32-bit integers: no vertex format is assigned here.
    ('int32', 1): 'SG_VERTEXFORMAT_INVALID',
    ('int32', 2): 'SG_VERTEXFORMAT_INVALID',
    ('int32', 3): 'SG_VERTEXFORMAT_INVALID',
    ('int32', 4): 'SG_VERTEXFORMAT_INVALID',
    ('uint32', 1): 'SG_VERTEXFORMAT_INVALID',
    ('uint32', 2): 'SG_VERTEXFORMAT_INVALID',
    ('uint32', 3): 'SG_VERTEXFORMAT_INVALID',
    ('uint32', 4): 'SG_VERTEXFORMAT_INVALID',
    # 32-bit floats.
    ('float32', 1): 'SG_VERTEXFORMAT_FLOAT',
    ('float32', 2): 'SG_VERTEXFORMAT_FLOAT2',
    ('float32', 3): 'SG_VERTEXFORMAT_FLOAT3',
    ('float32', 4): 'SG_VERTEXFORMAT_FLOAT4',
    # 8-bit integers: only the 4-component variants have a format.
    ('int8', 2): 'SG_VERTEXFORMAT_INVALID',
    ('int8', 4): 'SG_VERTEXFORMAT_BYTE4',
    ('uint8', 2): 'SG_VERTEXFORMAT_INVALID',
    ('uint8', 4): 'SG_VERTEXFORMAT_UBYTE4',
    # 16-bit integers. The unsigned variants use the normalized USHORT*N
    # formats; uint16x4 previously pointed at the signed SHORT4 format,
    # which disagreed with the uint16x2 entry.
    ('int16', 2): 'SG_VERTEXFORMAT_SHORT2',
    ('int16', 4): 'SG_VERTEXFORMAT_SHORT4',
    ('uint16', 2): 'SG_VERTEXFORMAT_USHORT2N',
    ('uint16', 4): 'SG_VERTEXFORMAT_USHORT4N',
    # 16-bit floats.
    ('float16', 2): 'SG_VERTEXFORMAT_HALF2',
    ('float16', 4): 'SG_VERTEXFORMAT_HALF4',
}
|
||
|
||
# Size in bytes of a single scalar element, keyed by ScalarType.
# Entries are grouped by size; insertion order is INT32, UINT32, FLOAT32,
# INT8, UINT8, INT16, UINT16, FLOAT16.
type_sizes = {
    # 4-byte scalars.
    **dict.fromkeys(
        (ScalarType.INT32, ScalarType.UINT32, ScalarType.FLOAT32), 4
    ),
    # 1-byte scalars.
    **dict.fromkeys((ScalarType.INT8, ScalarType.UINT8), 1),
    # 2-byte scalars. FLOAT16 is included here even though C has no standard
    # half-float type.
    **dict.fromkeys(
        (ScalarType.INT16, ScalarType.UINT16, ScalarType.FLOAT16), 2
    ),
}
|
||
|
||
# C type name emitted for each ScalarType.
c_type_mapping = {
    # 32-bit scalars.
    ScalarType.INT32: "int32_t",
    ScalarType.UINT32: "uint32_t",
    ScalarType.FLOAT32: "float",
    # 8-bit integers.
    ScalarType.INT8: "int8_t",
    ScalarType.UINT8: "uint8_t",
    # 16-bit integers.
    ScalarType.INT16: "int16_t",
    ScalarType.UINT16: "uint16_t",
    # NOTE: C has no standard float16 type, so the mirage_float16 macro is
    # emitted and the user is expected to supply its definition.
    ScalarType.FLOAT16: "mirage_float16",
}
|
||
|
||
# Alignment requirement in bytes of a single scalar element, keyed by
# ScalarType. Entries are grouped by alignment; insertion order is INT32,
# UINT32, FLOAT32, INT8, UINT8, INT16, UINT16, FLOAT16.
alignment_mapping = {
    # 4-byte aligned scalars.
    **dict.fromkeys(
        (ScalarType.INT32, ScalarType.UINT32, ScalarType.FLOAT32), 4
    ),
    # 1-byte aligned scalars.
    **dict.fromkeys((ScalarType.INT8, ScalarType.UINT8), 1),
    # 2-byte aligned scalars.
    **dict.fromkeys(
        (ScalarType.INT16, ScalarType.UINT16, ScalarType.FLOAT16), 2
    ),
}
|
||
|
||
def get_type_info(field_type: Field) -> Tuple[str, int, int]:
    """Return the C type name, element count and byte size of a field.

    Args:
        field_type: Reflected field. Its ``type`` attribute supplies ``kind``,
            ``scalarType``, ``elementCount`` and, for matrices, ``rowCount``
            and ``columnCount``.

    Returns:
        Tuple[str, int, int]: ``(C type name, element count, total bytes)``.

        * Scalars: one element of the scalar's size.
        * Vectors: ``elementCount`` elements; a 3-component vector has its
          size rounded up to a multiple of four times the scalar alignment.
        * Matrices: ``rows * columns`` elements (padding is not counted in
          the element count); each column of a 3-row matrix occupies four
          elements' worth of bytes.
        * Any other kind: a warning is printed and ``('float', 1, 4)`` is
          returned.

    Raises:
        ValueError: If a vector's element count is not 1-4, or a matrix
            dimension is not 2-4.
    """
    kind = field_type.type.kind
    scalar_type = field_type.type.scalarType
    element_count = field_type.type.elementCount

    # Scalar types missing from the tables fall back to a 4-byte 'float'.
    c_type = c_type_mapping.get(scalar_type, 'float')
    element_size = type_sizes.get(scalar_type, 4)

    if kind == TypeKind.SCALAR:
        return c_type, 1, element_size

    if kind == TypeKind.VECTOR:
        if element_count not in (1, 2, 3, 4):
            raise ValueError(
                f"Invalid vector element count: {element_count}. "
                f"Must be 1, 2, 3, or 4."
            )

        size = element_size * element_count

        # A vec3 is padded so that it takes the space of a vec4.
        if element_count == 3:
            alignment = alignment_mapping.get(scalar_type, 4) * 4
            size = ((size + alignment - 1) // alignment) * alignment

        return c_type, element_count, size

    if kind == TypeKind.MATRIX:
        # Missing (falsy) dimensions default to a 4x4 matrix.
        row_count = field_type.type.rowCount or 4
        column_count = field_type.type.columnCount or 4

        if row_count not in (2, 3, 4) or column_count not in (2, 3, 4):
            raise ValueError(
                f"Invalid matrix dimensions: {row_count}x{column_count}. "
                f"Rows and columns must be 2, 3, or 4."
            )

        # Column-major storage: each column holds `row_count` elements, and a
        # 3-element column is padded to the size of 4 elements.
        padded_rows = 4 if row_count == 3 else row_count
        column_size = element_size * padded_rows

        total_size = column_size * column_count
        total_elements = row_count * column_count
        return c_type, total_elements, total_size

    # Unknown kind: warn and fall back to a single float. The message uses the
    # local `kind` (read from field_type.type) rather than `field_type.kind`,
    # which is not the attribute the rest of this function reads.
    print(f"Warning: Unknown field type kind '{kind}'. Using default float.")
    return 'float', 1, 4
|
||
|
||
def get_c_type_declaration(field_type: Field, field_name: str) -> str:
    """Build the C declaration (type plus name) for a reflected field.

    Args:
        field_type: Reflected field describing the type to declare.
        field_name: Identifier to use for the declared member.

    Returns:
        str: The declaration without a trailing semicolon. Vectors with more
        than one element become a one-dimensional array, matrices become a
        ``[columns][rows]`` two-dimensional array, and everything else is a
        plain ``<type> <name>`` declaration.
    """
    base_type, n_elements, _ = get_type_info(field_type)
    type_desc = field_type.type
    kind = type_desc.kind

    # Multi-component vectors are declared as a flat array.
    if kind == TypeKind.VECTOR and n_elements != 1:
        return f"{base_type} {field_name}[{n_elements}]"

    # Matrices are declared column-major; missing dimensions default to 4.
    if kind == TypeKind.MATRIX:
        n_rows = type_desc.rowCount or 4
        n_cols = type_desc.columnCount or 4
        return f"{base_type} {field_name}[{n_cols}][{n_rows}]"

    # Scalars, single-element vectors and unknown kinds share the plain form.
    return f"{base_type} {field_name}"
|
||
|
||
def get_aligned_size(size: int, alignment: int) -> int:
    """Round ``size`` up to the nearest multiple of ``alignment``.

    Args:
        size: Unaligned size in bytes.
        alignment: Required alignment in bytes; must be positive.

    Returns:
        int: The smallest multiple of ``alignment`` that is >= ``size``.

    Raises:
        ValueError: If ``alignment`` is zero or negative. Zero would otherwise
            surface as a ZeroDivisionError and a negative value would
            silently produce a meaningless result.
    """
    if alignment <= 0:
        raise ValueError(
            f"alignment must be a positive integer, got {alignment}"
        )
    return ((size + alignment - 1) // alignment) * alignment
|