177 lines
5.8 KiB
Python
177 lines
5.8 KiB
Python
from typing import Tuple
|
||
|
||
from shader_types import FieldType
|
||
|
||
# SDL GPU vertex element formats, keyed by (scalar type name, component count).
# Each value is an SDL_GPU_VERTEXELEMENTFORMAT_* enumerator name built from a stem;
# single-component entries carry no numeric suffix (..._FLOAT vs ..._FLOAT2).
format_mapping = {
    (scalar_name, components): (
        'SDL_GPU_VERTEXELEMENTFORMAT_'
        + sdl_stem
        + (str(components) if components > 1 else '')
    )
    for scalar_name, sdl_stem, component_options in (
        ('int32', 'INT', (1, 2, 3, 4)),
        ('uint32', 'UINT', (1, 2, 3, 4)),
        ('float32', 'FLOAT', (1, 2, 3, 4)),
        # The 8- and 16-bit types are only listed with 2 and 4 components.
        ('int8', 'BYTE', (2, 4)),
        ('uint8', 'UBYTE', (2, 4)),
        ('int16', 'SHORT', (2, 4)),
        ('uint16', 'USHORT', (2, 4)),
        ('float16', 'HALF', (2, 4)),
    )
    for components in component_options
}


# Size in bytes of a single scalar of each supported type.
type_sizes = dict(
    int32=4,
    uint32=4,
    float32=4,
    int8=1,
    uint8=1,
    int16=2,
    uint16=2,
    float16=2,
)


# C type name emitted for each supported scalar type. C has no standard
# half-precision type, so float16 maps to the mirage_float16 macro, which users
# are expected to implement themselves.
_C_TYPE_MAPPING = {
    'int32': 'int32_t',
    'uint32': 'uint32_t',
    'float32': 'float',
    'int8': 'int8_t',
    'uint8': 'uint8_t',
    'int16': 'int16_t',
    'uint16': 'uint16_t',
    'float16': 'mirage_float16',
}

# Alignment requirement in bytes of each supported scalar type.
_SCALAR_ALIGNMENT = {
    'int32': 4,
    'uint32': 4,
    'float32': 4,
    'int8': 1,
    'uint8': 1,
    'int16': 2,
    'uint16': 2,
    'float16': 2,
}


def get_type_info(field_type: FieldType) -> Tuple[str, int, int]:
    """Return the C type name, element count and byte size of a shader field.

    Defaults: a missing ``scalar_type`` means ``'float32'``; a missing
    ``element_count`` / ``row_count`` / ``column_count`` means 4. Only ``None``
    counts as missing -- an explicit 0 is passed to validation and rejected.
    Scalar types not in the lookup tables fall back to ``'float'`` / 4 bytes.

    Reported sizes include padding: a 3-component vector occupies the space of
    4 components, and so does each column of a 3-row matrix.

    Args:
        field_type: Field type description. Reads ``kind``, ``scalar_type``,
            ``element_count``, ``row_count`` and ``column_count``.

    Returns:
        Tuple[str, int, int]: (C type name, number of scalar elements,
        total size in bytes including padding). An unknown ``kind`` prints a
        warning and yields ``('float', 1, 4)``.

    Raises:
        ValueError: If a vector does not have 1-4 elements, or a matrix has a
            row or column count outside 2-4.
    """
    kind = field_type.kind

    if kind == 'scalar':
        scalar_type = field_type.scalar_type or 'float32'
        c_type = _C_TYPE_MAPPING.get(scalar_type, 'float')
        size = type_sizes.get(scalar_type, 4)
        return c_type, 1, size

    elif kind == 'vector':
        scalar_type = field_type.scalar_type or 'float32'
        c_type = _C_TYPE_MAPPING.get(scalar_type, 'float')

        # Default only when the count is absent, so an explicit 0 is rejected
        # below instead of silently becoming a 4-component vector.
        element_count = field_type.element_count
        if element_count is None:
            element_count = 4

        if element_count not in (1, 2, 3, 4):
            raise ValueError(f"Invalid vector element count: {element_count}. Must be 1, 2, 3, or 4.")

        size = type_sizes.get(scalar_type, 4) * element_count

        # A vec3 is aligned like a vec4, so its size is rounded up to 4 components.
        # NOTE(review): get_c_type_declaration declares an unpadded array
        # (e.g. float[3]), so sizeof() of that declaration is smaller than the
        # size returned here -- confirm which layout the consumer relies on.
        if element_count == 3:
            alignment = _SCALAR_ALIGNMENT.get(scalar_type, 4) * 4
            size = ((size + alignment - 1) // alignment) * alignment

        return c_type, element_count, size

    elif kind == 'matrix':
        scalar_type = field_type.scalar_type or 'float32'
        c_type = _C_TYPE_MAPPING.get(scalar_type, 'float')

        # As for vectors, default only on None so 0 fails validation.
        row_count = field_type.row_count
        if row_count is None:
            row_count = 4
        column_count = field_type.column_count
        if column_count is None:
            column_count = 4

        if row_count not in (2, 3, 4) or column_count not in (2, 3, 4):
            raise ValueError(f"Invalid matrix dimensions: {row_count}x{column_count}. "
                             "Rows and columns must be 2, 3, or 4.")

        # Matrices are treated as column-major, with each column aligned.
        element_size = type_sizes.get(scalar_type, 4)

        # A 3-row column (mat3xN) is padded to the size of a vec4.
        # 2- and 4-row columns are not padded.
        if row_count == 3:
            column_size = element_size * 4
        else:
            column_size = element_size * row_count

        total_size = column_size * column_count
        total_elements = row_count * column_count
        return c_type, total_elements, total_size

    else:
        # Unknown kind: fall back to a single float and report it.
        print(f"Warning: Unknown field type kind '{field_type.kind}'. Using default float.")
        return 'float', 1, 4


def get_c_type_declaration(field_type: FieldType, field_name: str) -> str:
    """Build the C declaration for one struct field.

    Scalars (and 1-component vectors) become a plain variable. Other vectors
    become a 1-D array, and matrices become a 2-D array indexed
    ``[column][row]`` (column-major). Unknown kinds fall back to a plain
    variable of whatever type ``get_type_info`` reports.

    Args:
        field_type: Field type information; its ``kind`` selects the shape.
        field_name: Identifier to declare.

    Returns:
        str: The declaration text without a trailing semicolon,
        e.g. ``"float pos[3]"``.
    """
    # get_type_info runs first for every kind, so its validation errors and
    # unknown-kind warning still occur.
    base_type, component_count, _unused_size = get_type_info(field_type)
    kind = field_type.kind

    if kind == 'matrix':
        # NOTE(review): get_type_info reports 3-row columns padded to 4
        # components, but this array is declared unpadded -- confirm the
        # consumer does not rely on the two sizes matching.
        n_rows = field_type.row_count or 4
        n_cols = field_type.column_count or 4
        return f"{base_type} {field_name}[{n_cols}][{n_rows}]"

    if kind == 'vector' and component_count != 1:
        return f"{base_type} {field_name}[{component_count}]"

    # Scalars, single-component vectors and unknown kinds.
    return f"{base_type} {field_name}"


def get_aligned_size(size: int, alignment: int) -> int:
    """Round ``size`` up to the nearest multiple of ``alignment``.

    Args:
        size: Unaligned size in bytes.
        alignment: Required alignment in bytes; must be positive.

    Returns:
        int: The smallest multiple of ``alignment`` that is >= ``size``.

    Raises:
        ValueError: If ``alignment`` is zero or negative.
    """
    # Without this guard, 0 raises ZeroDivisionError and negative alignments
    # silently produce meaningless results.
    if alignment <= 0:
        raise ValueError(f"Alignment must be positive, got {alignment}.")
    return ((size + alignment - 1) // alignment) * alignment