import type { TensorRTPrecisionType } from '../../types/Integrations/TensorRT';
import type { ServerONNXGraphNodeOpType } from '../../types/Models/ONNX';
import { ServerONNXGraphNodeOps } from '../../types/Models/ONNX';

/**
 * Lookup table keyed by ONNX graph node op type; each value is the set of
 * TensorRT precision types recorded for that op.
 *
 * The type is `Partial` because the object is created empty and is populated
 * by the statements that follow.
 */
export const TensorRTSupportedNodes: Partial<
  Record<ServerONNXGraphNodeOpType, Set<TensorRTPrecisionType>>
> = {};

// Give every known op an empty set up front, so an op that never receives an
// explicit precision list still has an (empty) entry in the table.
const seedEmptyPrecisionSet = (op: ServerONNXGraphNodeOpType): void => {
  TensorRTSupportedNodes[op] = new Set<TensorRTPrecisionType>();
};
ServerONNXGraphNodeOps.forEach(seedEmptyPrecisionSet);

// Per-operator precision lists, taken from here:
// https://github.com/onnx/onnx-tensorrt/blob/main/docs/operators.md

// Recurring precision combinations. These are templates only: every operator
// below receives its own fresh Set built from one of them, so no Set instance
// is shared between operators.
const FP_ONLY = ['FP32', 'FP16'] as const;
const FP_INT8 = ['FP32', 'FP16', 'INT8'] as const;
const FP_INT32 = ['FP32', 'FP16', 'INT32'] as const;
const FP_INT32_INT8 = ['FP32', 'FP16', 'INT32', 'INT8'] as const;
const ALL_PRECISIONS = ['FP32', 'FP16', 'INT32', 'INT8', 'BOOL'] as const;
const INT32_ONLY = ['INT32'] as const;
const BOOL_ONLY = ['BOOL'] as const;

// One [operator, precisions] pair per operator, in alphabetical order.
// Entries that are commented out are deliberately not assigned.
// NOTE(review): presumably they are disabled because they are not op types
// known to ServerONNXGraphNodeOpType - confirm before enabling any of them.
const tensorRTOperatorPrecisions = [
  ['Abs', FP_INT8],
  ['Acos', FP_ONLY],
  ['Acosh', FP_ONLY],
  ['Add', FP_INT32],
  ['And', BOOL_ONLY],
  ['ArgMax', FP_INT32],
  ['ArgMin', FP_INT32],
  ['Asin', FP_ONLY],
  ['Asinh', FP_ONLY],
  ['Atan', FP_ONLY],
  ['Atanh', FP_ONLY],
  ['AveragePool', FP_ONLY],
  ['BatchNormalization', FP_ONLY],
  ['BitShift', INT32_ONLY],
  ['Cast', ALL_PRECISIONS],
  ['Ceil', FP_INT8],
  ['Celu', FP_INT8],
  ['Clip', FP_INT32_INT8],
  ['Concat', ALL_PRECISIONS],
  ['Constant', ALL_PRECISIONS],
  ['ConstantOfShape', ALL_PRECISIONS],
  ['Conv', FP_ONLY],
  ['ConvTranspose', FP_ONLY],
  ['Cos', FP_ONLY],
  ['Cosh', FP_ONLY],
  ['CumSum', FP_INT32],
  ['DepthToSpace', FP_INT32],
  ['Div', FP_INT32],
  ['Dropout', FP_ONLY],
  ['Einsum', FP_INT32],
  // ['Else', ALL_PRECISIONS],
  ['Equal', ALL_PRECISIONS],
  ['Erf', FP_ONLY],
  ['Exp', FP_ONLY],
  ['Expand', ALL_PRECISIONS],
  ['EyeLike', ALL_PRECISIONS],
  ['Flatten', ALL_PRECISIONS],
  ['Floor', FP_INT8],
  ['Gather', ALL_PRECISIONS],
  ['GatherElements', ALL_PRECISIONS],
  ['GatherND', ALL_PRECISIONS],
  ['Gemm', FP_ONLY],
  ['Greater', ALL_PRECISIONS],
  ['HardSigmoid', FP_INT8],
  ['Hardmax', FP_ONLY],
  ['Identity', ALL_PRECISIONS],
  ['If', ALL_PRECISIONS],
  ['InstanceNormalization', FP_ONLY],
  ['IsInf', ALL_PRECISIONS],
  ['IsNaN', ALL_PRECISIONS],
  ['Less', ALL_PRECISIONS],
  ['Log', FP_ONLY],
  ['LogSoftmax', FP_ONLY],
  ['Loop', ALL_PRECISIONS],
  ['LpNormalization', FP_ONLY],
  ['LpPool', FP_ONLY],
  ['MatMul', FP_INT32],
  ['Max', FP_INT32],
  ['MaxPool', FP_ONLY],
  ['MaxRoiPool', FP_ONLY],
  ['MaxUnpool', FP_ONLY],
  ['Mean', FP_INT32],
  ['Min', FP_INT32],
  ['Mod', FP_INT32],
  ['Mul', FP_INT32],
  ['Multinomial', FP_INT32],
  ['Neg', FP_INT32_INT8],
  ['NonMaxSuppression', FP_ONLY],
  ['NonZero', ALL_PRECISIONS],
  ['Not', BOOL_ONLY],
  ['OneHot', ALL_PRECISIONS],
  ['Or', BOOL_ONLY],
  ['Pad', ALL_PRECISIONS],
  ['Pow', FP_INT32],
  ['PRelu', FP_ONLY],
  ['Range', ALL_PRECISIONS],
  ['Reciprocal', FP_INT32],
  ['ReduceL1', FP_INT32],
  ['ReduceL2', FP_INT32],
  ['ReduceLogSum', FP_INT32],
  ['ReduceLogSumExp', FP_INT32],
  ['ReduceMax', FP_INT32],
  ['ReduceMean', FP_INT32],
  ['ReduceMin', FP_INT32],
  ['ReduceProd', FP_INT32],
  ['ReduceSum', FP_INT32],
  ['ReduceSumSquare', FP_INT32],
  ['Relu', FP_ONLY],
  ['Reshape', ALL_PRECISIONS],
  ['Resize', FP_ONLY],
  ['ReverseSequence', ALL_PRECISIONS],
  ['RNN', FP_ONLY],
  ['Round', FP_INT32_INT8],
  ['Scan', ALL_PRECISIONS],
  ['Scatter', ALL_PRECISIONS],
  ['ScatterElements', ALL_PRECISIONS],
  ['ScatterND', ALL_PRECISIONS],
  ['Selu', FP_ONLY],
  ['Shape', ALL_PRECISIONS],
  // ['Shift', ALL_PRECISIONS],
  ['Shrink', FP_INT32],
  ['Sigmoid', FP_ONLY],
  ['Sin', FP_INT32],
  ['Size', ALL_PRECISIONS],
  ['Slice', ALL_PRECISIONS],
  ['Softmax', FP_ONLY],
  ['Softplus', FP_ONLY],
  ['Softsign', FP_ONLY],
  ['SpaceToDepth', FP_ONLY],
  ['Split', ALL_PRECISIONS],
  ['Sqrt', FP_INT32],
  ['Squeeze', ALL_PRECISIONS],
  // ['StridedSlice', ALL_PRECISIONS],
  ['Sub', FP_INT32],
  ['Sum', FP_INT32],
  ['Tan', FP_INT32],
  ['Tanh', FP_ONLY],
  ['Tile', ALL_PRECISIONS],
  ['TopK', ALL_PRECISIONS],
  ['Transpose', ALL_PRECISIONS],
  // ['Unary', ALL_PRECISIONS],
  ['Upsample', FP_ONLY],
  ['Where', ALL_PRECISIONS],
  ['Xor', BOOL_ONLY],
] as const;

// Replace each listed operator's entry with a new Set of its precisions.
for (const [op, precisions] of tensorRTOperatorPrecisions) {
  TensorRTSupportedNodes[op] = new Set(precisions);
}
