from typing import Tuple

from shader_reflection_type import Field, TypeKind, ScalarType

# Maps (scalar type name, component count) to the matching sokol_gfx vertex
# format. Combinations without a corresponding format map to
# 'SG_VERTEXFORMAT_INVALID'.
format_mapping = {
    ('int32', 1): 'SG_VERTEXFORMAT_INVALID',
    ('int32', 2): 'SG_VERTEXFORMAT_INVALID',
    ('int32', 3): 'SG_VERTEXFORMAT_INVALID',
    ('int32', 4): 'SG_VERTEXFORMAT_INVALID',
    ('uint32', 1): 'SG_VERTEXFORMAT_INVALID',
    ('uint32', 2): 'SG_VERTEXFORMAT_INVALID',
    ('uint32', 3): 'SG_VERTEXFORMAT_INVALID',
    ('uint32', 4): 'SG_VERTEXFORMAT_INVALID',
    ('float32', 1): 'SG_VERTEXFORMAT_FLOAT',
    ('float32', 2): 'SG_VERTEXFORMAT_FLOAT2',
    ('float32', 3): 'SG_VERTEXFORMAT_FLOAT3',
    ('float32', 4): 'SG_VERTEXFORMAT_FLOAT4',
    ('int8', 2): 'SG_VERTEXFORMAT_INVALID',
    ('int8', 4): 'SG_VERTEXFORMAT_BYTE4',
    ('uint8', 2): 'SG_VERTEXFORMAT_INVALID',
    ('uint8', 4): 'SG_VERTEXFORMAT_UBYTE4',
    ('int16', 2): 'SG_VERTEXFORMAT_SHORT2',
    ('int16', 4): 'SG_VERTEXFORMAT_SHORT4',
    ('uint16', 2): 'SG_VERTEXFORMAT_USHORT2N',
    # Unsigned 16-bit data needs an unsigned format: the signed SHORT4 used
    # previously would reinterpret values >= 32768 as negative. USHORT4N
    # mirrors the ('uint16', 2) -> USHORT2N entry above.
    ('uint16', 4): 'SG_VERTEXFORMAT_USHORT4N',
    ('float16', 2): 'SG_VERTEXFORMAT_HALF2',
    ('float16', 4): 'SG_VERTEXFORMAT_HALF4',
}

# Size in bytes of each scalar type.
type_sizes = {
    ScalarType.INT32: 4,
    ScalarType.UINT32: 4,
    ScalarType.FLOAT32: 4,
    ScalarType.INT8: 1,
    ScalarType.UINT8: 1,
    ScalarType.INT16: 2,
    ScalarType.UINT16: 2,
    ScalarType.FLOAT16: 2,
}

# C type name emitted for each scalar type.
c_type_mapping = {
    ScalarType.INT32: 'int32_t',
    ScalarType.UINT32: 'uint32_t',
    ScalarType.FLOAT32: 'float',
    ScalarType.INT8: 'int8_t',
    ScalarType.UINT8: 'uint8_t',
    ScalarType.INT16: 'int16_t',
    ScalarType.UINT16: 'uint16_t',
    # C has no standard float16 type, so emit the `mirage_float16` macro and
    # leave its definition to the user.
    ScalarType.FLOAT16: 'mirage_float16',
}

# Required alignment in bytes of each scalar type.
alignment_mapping = {
    ScalarType.INT32: 4,
    ScalarType.UINT32: 4,
    ScalarType.FLOAT32: 4,
    ScalarType.INT8: 1,
    ScalarType.UINT8: 1,
    ScalarType.INT16: 2,
    ScalarType.UINT16: 2,
    ScalarType.FLOAT16: 2,
}


def get_type_info(field_type: Field) -> Tuple[str, int, int]:
    """Return the C type name, element count and byte size of a field.

    The reported size includes padding. A 3-component vector is padded to the
    size of a 4-component vector. A matrix with 3 rows uses a column stride of
    4 elements.

    Args:
        field_type: Reflected field. Its ``type`` attribute supplies ``kind``,
            ``scalarType``, ``elementCount``, ``rowCount`` and ``columnCount``.

    Returns:
        Tuple[str, int, int]: (C type name, element count, total size in
        bytes). Unknown kinds print a warning and fall back to
        ``('float', 1, 4)``.

    Raises:
        ValueError: If a vector's element count is not 1-4, or a matrix
            dimension is not 2-4.
    """
    type_desc = field_type.type
    kind = type_desc.kind
    scalar_type = type_desc.scalarType
    element_count = type_desc.elementCount
    # Unmapped scalar types fall back to float / 4 bytes.
    c_type = c_type_mapping.get(scalar_type, 'float')
    scalar_size = type_sizes.get(scalar_type, 4)

    if kind == TypeKind.SCALAR:
        return c_type, 1, scalar_size

    if kind == TypeKind.VECTOR:
        if element_count not in (1, 2, 3, 4):
            raise ValueError(f"Invalid vector element count: {element_count}. Must be 1, 2, 3, or 4.")
        size = scalar_size * element_count
        if element_count == 3:
            # A vec3 occupies the footprint of a vec4 (e.g. 12 -> 16 bytes for float32).
            size = get_aligned_size(size, alignment_mapping.get(scalar_type, 4) * 4)
        return c_type, element_count, size

    if kind == TypeKind.MATRIX:
        # Missing (falsy) dimensions default to 4.
        row_count = type_desc.rowCount or 4
        column_count = type_desc.columnCount or 4
        if row_count not in (2, 3, 4) or column_count not in (2, 3, 4):
            raise ValueError(f"Invalid matrix dimensions: {row_count}x{column_count}. "
                             f"Rows and columns must be 2, 3, or 4.")
        # Size is computed as `column_count` consecutive columns (column-major).
        # A 3-row column is padded to 4 elements, like a vec3.
        column_stride = scalar_size * (4 if row_count == 3 else row_count)
        return c_type, row_count * column_count, column_stride * column_count

    # Unknown kind: warn and fall back to a single float. The kind is read
    # from `field_type.type`; `field_type` itself has no `kind` attribute
    # in the access pattern used above.
    print(f"Warning: Unknown field type kind '{kind}'. Using default float.")
    return 'float', 1, 4


def get_c_type_declaration(field_type: Field, field_name: str) -> str:
    """Build the C declaration (without trailing semicolon) for a field.

    - Scalars and 1-component vectors become ``<type> <name>``.
    - Other vectors become ``<type> <name>[N]``.
    - Matrices become ``<type> <name>[cols][rows]`` (column-major).
    - Unknown kinds become ``<type> <name>`` using get_type_info's fallback type.

    NOTE(review): get_type_info pads 3-component vectors and 3-row matrix
    columns to 4 elements, but these declarations carry no padding (e.g.
    ``float v[3]`` is 12 bytes while the reported size is 16). Presumably the
    caller inserts explicit padding based on get_type_info's size. Confirm.

    Args:
        field_type: Reflected field type information.
        field_name: Name to declare.

    Returns:
        str: The C declaration.

    Raises:
        ValueError: Propagated from get_type_info for invalid dimensions.
    """
    c_type, count, _ = get_type_info(field_type)
    kind = field_type.type.kind

    if kind == TypeKind.MATRIX:
        rows = field_type.type.rowCount or 4
        cols = field_type.type.columnCount or 4
        return f"{c_type} {field_name}[{cols}][{rows}]"

    if kind == TypeKind.VECTOR and count != 1:
        return f"{c_type} {field_name}[{count}]"

    return f"{c_type} {field_name}"


def get_aligned_size(size: int, alignment: int) -> int:
    """Round ``size`` up to the next multiple of ``alignment``.

    Args:
        size: Unaligned size in bytes.
        alignment: Alignment in bytes; must be positive.

    Returns:
        int: The smallest multiple of ``alignment`` that is >= ``size``.

    Raises:
        ValueError: If ``alignment`` is not positive (previously this gave a
            ZeroDivisionError for 0 and a meaningless result for negatives).
    """
    if alignment <= 0:
        raise ValueError(f"alignment must be positive, got {alignment}")
    return ((size + alignment - 1) // alignment) * alignment