import type { ServerONNXTensor } from '../../types/Api';
import type { ServerONNXTensorDataType } from '../../types/Models/ONNX';
import { ServerONNXTensorDataTypeEnum } from '../../types/Models/ONNX';

/**
 * Name of the tensor field that stores the element values for each data
 * type. Several element types share one field: for example, every narrow
 * integer type, BOOL, FLOAT16/BFLOAT16 and all FLOAT8 variants are read
 * from `int32_data`. `UNDEFINED` has no backing field.
 *
 * Typed as a `Record` over every data type name so that adding a new name
 * to `ServerONNXTensorDataType` fails to compile until it is mapped here.
 */
const VALUE_FIELD_BY_DATA_TYPE: Readonly<
  Record<ServerONNXTensorDataType, keyof ServerONNXTensor | undefined>
> = {
  UNDEFINED: undefined,
  FLOAT: 'float_data',
  UINT8: 'int32_data',
  INT8: 'int32_data',
  UINT16: 'int32_data',
  INT16: 'int32_data',
  INT32: 'int32_data',
  INT64: 'int64_data',
  STRING: 'string_data',
  BOOL: 'int32_data',
  FLOAT16: 'int32_data',
  DOUBLE: 'double_data',
  UINT32: 'uint64_data',
  UINT64: 'uint64_data',
  COMPLEX64: 'float_data',
  // TODO(CELL-100): COMPLEX128 has a different encoding method when
  // stored in double_data. Refer to onnx.proto3 for more details.
  COMPLEX128: 'double_data',
  BFLOAT16: 'int32_data',
  FLOAT8E4M3FN: 'int32_data',
  FLOAT8E4M3FNUZ: 'int32_data',
  FLOAT8E5M2: 'int32_data',
  FLOAT8E5M2FNUZ: 'int32_data',
};

/**
 * Returns the element values stored on a tensor. They are read from the
 * field that matches the tensor's `data_type`, according to
 * `VALUE_FIELD_BY_DATA_TYPE`.
 *
 * The function looks up `data_type` in `ServerONNXTensorDataTypeEnum` to get
 * its type name. A falsy `data_type`, or one the enum lookup does not
 * resolve, is treated as `UNDEFINED`.
 *
 * NOTE(review): no formatting happens here. The selected field is returned
 * as-is even though the signature declares `string | undefined`. Confirm
 * what callers expect.
 *
 * @param tensor - Tensor whose `data_type` selects which value field to read.
 * @returns The value of the selected field. Returns `undefined` when the type
 *   resolves to `UNDEFINED` or when the selected field itself is unset.
 */
export function formatTensorDataTypeVal(
  tensor: ServerONNXTensor,
): string | undefined {
  const { data_type: dataType } = tensor;

  let typeName: ServerONNXTensorDataType = 'UNDEFINED';
  if (dataType) {
    typeName = ServerONNXTensorDataTypeEnum[dataType] ?? 'UNDEFINED';
  }

  const field = VALUE_FIELD_BY_DATA_TYPE[typeName];
  // The storage fields are not strings, so the result is widened to `any`
  // to keep the existing caller-facing return type unchanged.
  // eslint-disable-next-line @typescript-eslint/no-explicit-any
  const values: any = field === undefined ? undefined : tensor[field];
  return values;
}
