Files
mirage_slang/code_generator_helper.py

177 lines
5.8 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from typing import Tuple
from shader_types import FieldType
# SDL GPU vertex-element format names, grouped by scalar type and keyed by
# component count. Only 2- and 4-component variants are listed for the 8- and
# 16-bit scalar types.
_VERTEX_FORMATS_BY_SCALAR = {
    'int32': {
        1: 'SDL_GPU_VERTEXELEMENTFORMAT_INT',
        2: 'SDL_GPU_VERTEXELEMENTFORMAT_INT2',
        3: 'SDL_GPU_VERTEXELEMENTFORMAT_INT3',
        4: 'SDL_GPU_VERTEXELEMENTFORMAT_INT4',
    },
    'uint32': {
        1: 'SDL_GPU_VERTEXELEMENTFORMAT_UINT',
        2: 'SDL_GPU_VERTEXELEMENTFORMAT_UINT2',
        3: 'SDL_GPU_VERTEXELEMENTFORMAT_UINT3',
        4: 'SDL_GPU_VERTEXELEMENTFORMAT_UINT4',
    },
    'float32': {
        1: 'SDL_GPU_VERTEXELEMENTFORMAT_FLOAT',
        2: 'SDL_GPU_VERTEXELEMENTFORMAT_FLOAT2',
        3: 'SDL_GPU_VERTEXELEMENTFORMAT_FLOAT3',
        4: 'SDL_GPU_VERTEXELEMENTFORMAT_FLOAT4',
    },
    'int8': {
        2: 'SDL_GPU_VERTEXELEMENTFORMAT_BYTE2',
        4: 'SDL_GPU_VERTEXELEMENTFORMAT_BYTE4',
    },
    'uint8': {
        2: 'SDL_GPU_VERTEXELEMENTFORMAT_UBYTE2',
        4: 'SDL_GPU_VERTEXELEMENTFORMAT_UBYTE4',
    },
    'int16': {
        2: 'SDL_GPU_VERTEXELEMENTFORMAT_SHORT2',
        4: 'SDL_GPU_VERTEXELEMENTFORMAT_SHORT4',
    },
    'uint16': {
        2: 'SDL_GPU_VERTEXELEMENTFORMAT_USHORT2',
        4: 'SDL_GPU_VERTEXELEMENTFORMAT_USHORT4',
    },
    'float16': {
        2: 'SDL_GPU_VERTEXELEMENTFORMAT_HALF2',
        4: 'SDL_GPU_VERTEXELEMENTFORMAT_HALF4',
    },
}

# Flat lookup: (scalar type, component count) -> SDL vertex-element format name.
format_mapping = {
    (scalar_type, count): sdl_format
    for scalar_type, formats_by_count in _VERTEX_FORMATS_BY_SCALAR.items()
    for count, sdl_format in formats_by_count.items()
}
# Size in bytes of a single scalar of each supported shader scalar type.
type_sizes = dict(
    int32=4,
    uint32=4,
    float32=4,
    int8=1,
    uint8=1,
    int16=2,
    uint16=2,
    float16=2,
)
def get_type_info(field_type: FieldType) -> Tuple[str, int, int]:
    """Return the C type name, scalar element count and byte size of a field.

    Unset attributes on ``field_type`` fall back to defaults: scalar type
    ``'float32'`` (when ``scalar_type`` is ``None``/empty), 4 vector
    components, and 4 rows / 4 columns for matrices (when the respective
    count is ``None``). Scalar types missing from the lookup tables map to
    ``float`` and 4 bytes.

    The reported size includes padding: a 3-component vector occupies the
    space of 4 components, and every column of a matrix with 3 rows is padded
    to 4 elements.

    Args:
        field_type: Shader field type; its ``kind`` selects the ``'scalar'``,
            ``'vector'`` or ``'matrix'`` handling.

    Returns:
        Tuple[str, int, int]: (C type name, element count, total byte size).
        For matrices the element count is ``rows * columns`` (padding not
        counted). An unknown ``kind`` prints a warning and yields
        ``('float', 1, 4)``.

    Raises:
        ValueError: if a vector's component count is not 1-4, or a matrix's
            row or column count is not 2-4.
    """
    # C spelling of each shader scalar type. C has no standard float16 type,
    # so the mirage_float16 macro is emitted and left for the user to define.
    c_type_mapping = {
        'int32': 'int32_t',
        'uint32': 'uint32_t',
        'float32': 'float',
        'int8': 'int8_t',
        'uint8': 'uint8_t',
        'int16': 'int16_t',
        'uint16': 'uint16_t',
        'float16': 'mirage_float16',
    }

    kind = field_type.kind
    if kind not in ('scalar', 'vector', 'matrix'):
        # Unknown kind: fall back to a single float and warn.
        print(f"Warning: Unknown field type kind '{kind}'. Using default float.")
        return 'float', 1, 4

    scalar_type = field_type.scalar_type or 'float32'
    c_type = c_type_mapping.get(scalar_type, 'float')
    # Every supported scalar type is naturally aligned (alignment == size),
    # so the module-level size table also serves as the alignment table.
    element_size = type_sizes.get(scalar_type, 4)

    if kind == 'scalar':
        return c_type, 1, element_size

    if kind == 'vector':
        # Compare against None (not truthiness) so an explicit 0 reaches the
        # validation below instead of silently becoming 4.
        element_count = field_type.element_count
        if element_count is None:
            element_count = 4
        if element_count not in (1, 2, 3, 4):
            raise ValueError(
                f"Invalid vector element count: {element_count}. "
                "Must be 1, 2, 3, or 4."
            )
        # A 3-component vector is padded to the footprint of 4 components.
        padded_count = 4 if element_count == 3 else element_count
        return c_type, element_count, element_size * padded_count

    # kind == 'matrix'
    row_count = field_type.row_count
    if row_count is None:
        row_count = 4
    column_count = field_type.column_count
    if column_count is None:
        column_count = 4
    if row_count not in (2, 3, 4) or column_count not in (2, 3, 4):
        raise ValueError(f"Invalid matrix dimensions: {row_count}x{column_count}. "
                         f"Rows and columns must be 2, 3, or 4.")
    # Size is computed per column (column-major storage assumed); a 3-row
    # column is padded to 4 elements, matching the vector rule above.
    padded_rows = 4 if row_count == 3 else row_count
    total_size = element_size * padded_rows * column_count
    return c_type, row_count * column_count, total_size
def get_c_type_declaration(field_type: FieldType, field_name: str) -> str:
    """Build the C declaration (no trailing semicolon) for a single field.

    Args:
        field_type: Shader field type to declare.
        field_name: Identifier used for the declared member.

    Returns:
        str: ``"<ctype> <name>"`` for scalars, single-component vectors and
        unrecognised kinds; ``"<ctype> <name>[N]"`` for N-component vectors;
        ``"<ctype> <name>[columns][rows]"`` for matrices, where the outer
        index selects the column.

    Raises:
        ValueError: propagated from ``get_type_info`` for invalid vector or
            matrix dimensions.
    """
    c_type, count, _ = get_type_info(field_type)
    declarator = field_name

    if field_type.kind == 'matrix':
        # Column-major 2-D array: outer index = column, inner index = row.
        rows = field_type.row_count or 4
        cols = field_type.column_count or 4
        declarator = f"{field_name}[{cols}][{rows}]"
    elif field_type.kind == 'vector' and count != 1:
        # Multi-component vectors are declared as plain C arrays.
        declarator = f"{field_name}[{count}]"

    # NOTE(review): get_type_info pads 3-component vectors and 3-row matrix
    # columns to four elements when computing sizes, but the arrays declared
    # here are unpadded — confirm the caller emits explicit padding so the
    # C struct layout matches the computed sizes.
    return f"{c_type} {declarator}"
def get_aligned_size(size: int, alignment: int) -> int:
    """Round ``size`` up to the next multiple of ``alignment``.

    Args:
        size: Unpadded size.
        alignment: Required alignment; the rounding is only meaningful for a
            positive value.

    Returns:
        int: The smallest multiple of ``alignment`` that is >= ``size``
        (``size`` itself when it is already aligned).
    """
    padded = size + alignment - 1
    # Dropping the remainder lands on the multiple of `alignment` at or below
    # `padded`, i.e. the first multiple at or above `size`.
    return padded - padded % alignment