/**
 * String names of the attribute types used for ONNX graph node attributes,
 * expressed as a union of string literals.
 *
 * This should match the AttributeProto AttributeType enum in:
 * https://github.com/onnx/onnx/blob/main/onnx/onnx.proto3
 *
 * Apart from 'UNDEFINED', every singular name below ('FLOAT', 'INT', ...)
 * has a plural counterpart ('FLOATS', 'INTS', ...).
 */
export type ServerONNXGraphNodeAttributeType =
  | 'UNDEFINED'
  | 'FLOAT'
  | 'INT'
  | 'STRING'
  | 'TENSOR'
  | 'GRAPH'
  | 'SPARSE_TENSOR'
  | 'TYPE_PROTO'
  | 'FLOATS'
  | 'INTS'
  | 'STRINGS'
  | 'TENSORS'
  | 'GRAPHS'
  | 'SPARSE_TENSORS'
  | 'TYPE_PROTOS';

/**
 * String names of tensor element data types, expressed as a union of string
 * literals.
 *
 * These are exactly the values stored in `ServerONNXTensorDataTypeEnum`,
 * which maps numeric codes to these names. When a member is added here, the
 * lookup table needs a matching entry (and vice versa).
 */
export type ServerONNXTensorDataType =
  | 'UNDEFINED'
  | 'FLOAT'
  | 'UINT8'
  | 'INT8'
  | 'UINT16'
  | 'INT16'
  | 'INT32'
  | 'INT64'
  | 'STRING'
  | 'BOOL'
  | 'FLOAT16'
  | 'DOUBLE'
  | 'UINT32'
  | 'UINT64'
  | 'COMPLEX64'
  | 'COMPLEX128'
  | 'BFLOAT16'
  | 'FLOAT8E4M3FN'
  | 'FLOAT8E4M3FNUZ'
  | 'FLOAT8E5M2'
  | 'FLOAT8E5M2FNUZ';

/**
 * Lookup table from a numeric tensor data type code (0-20) to its string
 * name.
 *
 * This should match the TensorProto DataType enum in:
 * https://github.com/onnx/onnx/blob/main/onnx/onnx.proto3
 *
 * NOTE(review): the declared type is `Record<number, ...>`, so indexing with
 * a code that is not listed below type-checks as a valid name but evaluates
 * to `undefined` at runtime. Newer revisions of onnx.proto3 may define codes
 * above 20 — confirm whether callers need to guard against missing entries
 * or whether this table (and `ServerONNXTensorDataType`) should be extended.
 */
export const ServerONNXTensorDataTypeEnum: Record<
  number,
  ServerONNXTensorDataType
> = {
  0: 'UNDEFINED',
  // Basic types.
  1: 'FLOAT', // float
  2: 'UINT8', // uint8_t
  3: 'INT8', // int8_t
  4: 'UINT16', // uint16_t
  5: 'INT16', // int16_t
  6: 'INT32', // int32_t
  7: 'INT64', // int64_t
  8: 'STRING', // string
  9: 'BOOL', // bool

  // IEEE754 half-precision floating-point format (16 bits wide).
  // This format has 1 sign bit, 5 exponent bits, and 10 mantissa bits.
  10: 'FLOAT16',

  11: 'DOUBLE',
  12: 'UINT32',
  13: 'UINT64',
  14: 'COMPLEX64', // complex with float32 real and imaginary components
  15: 'COMPLEX128', // complex with float64 real and imaginary components

  // Non-IEEE floating-point format based on IEEE754 single-precision
  // floating-point number truncated to 16 bits.
  // This format has 1 sign bit, 8 exponent bits, and 7 mantissa bits.
  16: 'BFLOAT16',

  // Non-IEEE floating-point formats based on the papers
  // "FP8 Formats for Deep Learning", https://arxiv.org/abs/2209.05433, and
  // "8-bit Numerical Formats For Deep Neural Networks", https://arxiv.org/pdf/2206.02915.pdf.
  // Operators that support FP8 are Cast, CastLike, QuantizeLinear, DequantizeLinear.
  // The computation usually happens inside a block quantize / dequantize
  // fused by the runtime.
  17: 'FLOAT8E4M3FN', // float 8, mostly used for coefficients, supports nan, not inf
  18: 'FLOAT8E4M3FNUZ', // float 8, mostly used for coefficients, supports nan, not inf, no negative zero
  19: 'FLOAT8E5M2', // follows IEEE 754, supports nan, inf, mostly used for gradients
  20: 'FLOAT8E5M2FNUZ', // follows IEEE 754, supports nan, inf, mostly used for gradients, no negative zero
};

/**
 * Operator names recognised for graph nodes, declared `as const` so the
 * array is a readonly tuple of string literals.
 *
 * `ServerONNXGraphNodeOpType` is derived from this array's element type, so
 * adding or removing an entry here changes that union as well.
 *
 * The entries form two alphabetically sorted runs ('Abs'..'Xor', then
 * 'Bernoulli'..'ThresholdedRelu'); keep new names inside the appropriate run
 * rather than re-sorting the whole list, since that would change indices.
 */
export const ServerONNXGraphNodeOps = [
  'Abs',
  'Acos',
  'Acosh',
  'Add',
  'And',
  'ArgMax',
  'ArgMin',
  'Asin',
  'Asinh',
  'Atan',
  'Atanh',
  'AveragePool',
  'BatchNormalization',
  'BitShift',
  'BitwiseAnd',
  'BitwiseNot',
  'BitwiseOr',
  'BitwiseXor',
  'Cast',
  'Ceil',
  'Col2Im',
  'Compress',
  'Concat',
  'ConcatFromSequence',
  'Constant',
  'ConstantOfShape',
  'Conv',
  'ConvInteger',
  'ConvTranspose',
  'Cos',
  'Cosh',
  'CumSum',
  'DFT',
  'DeformConv',
  'DepthToSpace',
  'DequantizeLinear',
  'Det',
  'Div',
  'Dropout',
  'Einsum',
  'Equal',
  'Erf',
  'Exp',
  'Expand',
  'EyeLike',
  'Flatten',
  'Floor',
  'GRU',
  'Gather',
  'GatherElements',
  'GatherND',
  'Gemm',
  'GlobalAveragePool',
  'GlobalLpPool',
  'GlobalMaxPool',
  'Greater',
  'GridSample',
  'Hardmax',
  'Identity',
  'If',
  'InstanceNormalization',
  'IsInf',
  'IsNaN',
  'LRN',
  'LSTM',
  'Less',
  'Log',
  'Loop',
  'LpNormalization',
  'LpPool',
  'MatMul',
  'MatMulInteger',
  'Max',
  'MaxPool',
  'MaxRoiPool',
  'MaxUnpool',
  'Mean',
  'MelWeightMatrix',
  'Min',
  'Mod',
  'Mul',
  'Multinomial',
  'Neg',
  'NonMaxSuppression',
  'NonZero',
  'Not',
  'OneHot',
  'Optional',
  'OptionalGetElement',
  'OptionalHasElement',
  'Or',
  'Pad',
  'Pow',
  'QLinearConv',
  'QLinearMatMul',
  'QuantizeLinear',
  'RNN',
  'RandomNormal',
  'RandomNormalLike',
  'RandomUniform',
  'RandomUniformLike',
  'Reciprocal',
  'ReduceMax',
  'ReduceMean',
  'ReduceMin',
  'ReduceProd',
  'ReduceSum',
  'Reshape',
  'Resize',
  'ReverseSequence',
  'RoiAlign',
  'Round',
  'STFT',
  'Scan',
  'Scatter',
  'ScatterElements',
  'ScatterND',
  'SequenceAt',
  'SequenceConstruct',
  'SequenceEmpty',
  'SequenceErase',
  'SequenceInsert',
  'SequenceLength',
  'Shape',
  'Sigmoid',
  'Sign',
  'Sin',
  'Sinh',
  'Size',
  'Slice',
  'SpaceToDepth',
  'Split',
  'SplitToSequence',
  'Sqrt',
  'Squeeze',
  'StringNormalizer',
  'Sub',
  'Sum',
  'Tan',
  'Tanh',
  'TfIdfVectorizer',
  'Tile',
  'TopK',
  'Transpose',
  'Trilu',
  'Unique',
  'Unsqueeze',
  'Upsample',
  'Where',
  'Xor',
  // NOTE(review): alphabetical order restarts here. This second run looks
  // like a separately sourced group (possibly the operators the ONNX
  // operator listing classifies as functions) — confirm before merging or
  // re-sorting it with the run above.
  'Bernoulli',
  'BlackmanWindow',
  'CastLike',
  'Celu',
  'CenterCropPad',
  'Clip',
  'DynamicQuantizeLinear',
  'Elu',
  'GreaterOrEqual',
  'GroupNormalization',
  'HammingWindow',
  'HannWindow',
  'HardSigmoid',
  'HardSwish',
  'LayerNormalization',
  'LeakyRelu',
  'LessOrEqual',
  'LogSoftmax',
  'MeanVarianceNormalization',
  'Mish',
  'NegativeLogLikelihoodLoss',
  'PRelu',
  'Range',
  'ReduceL1',
  'ReduceL2',
  'ReduceLogSum',
  'ReduceLogSumExp',
  'ReduceSumSquare',
  'Relu',
  'Selu',
  'SequenceMap',
  'Shrink',
  'Softmax',
  'SoftmaxCrossEntropyLoss',
  'Softplus',
  'Softsign',
  'ThresholdedRelu',
] as const;

/**
 * Union of every operator name listed in `ServerONNXGraphNodeOps`, derived
 * from the array's element type so the two cannot drift apart.
 */
export type ServerONNXGraphNodeOpType = (typeof ServerONNXGraphNodeOps)[number];
