import { useCallback, useEffect, useState } from 'react';
import { useQuery } from '@tanstack/react-query';

import { logError, useFetchCore, useMutation } from '../support/Fetch';
import { endpoints } from '../generated/endpoints';
import type {
  ServerONNXGraphNode,
  ServerONNXModelUpdate,
  ServerONNXModel,
  ServerUser,
  ServerONNXGraphNodeUpdate,
  ServerONNXGraphNodeCreate,
  ServerONNXGraphInitializer,
} from '../types/Api';
import type { ONNXModelCreateInput } from '../types/ONNXModel';
import { useAuth } from '../support/Auth';
import mockedModelVersions from '../generated/mocked_model_versions.json';
import { TensorRTSupportedNodes } from '../constants/Integrations/TensorRTSupportedNodes';
import type { ServerONNXGraphNodeOpType } from '../types/Models/ONNX';
import type { TensorRTPrecisionType } from '../types/Integrations/TensorRT';

import { resolve } from './helpers/resolve';
import { useAuthFetch } from './helpers/useAuthFetch';
import type { FetchState } from './helpers/types';
import { fromServerUser } from './userApi';

/**
 * Page size for the ONNX model list: `useFetchONNXModels` divides the number
 * of fetched models by this value to compute `pagesQuantity`.
 * NOTE(review): exported, presumably so list views slice models with the same
 * page size — confirm against callers.
 */
export const itemsPerPageLimit = 10;
/**
 * Fetches every ONNX model together with the user who owns each one.
 *
 * Owner ids are de-duplicated first, so each distinct owner is requested
 * exactly once, and those owner requests run concurrently.
 *
 * @returns the react-query result for the `['models']` key. Its `data` is
 *   `{ models, modelUsers, pagesQuantity }`:
 *   - `models`: the ONNX models returned by the server;
 *   - `modelUsers`: owner id -> user;
 *   - `pagesQuantity`: number of pages needed to show `models` at
 *     `itemsPerPageLimit` per page (never less than 1).
 */
export function useFetchONNXModels() {
  const fetch = useAuthFetch();

  const fetchONNXModels = async () => {
    const data = await fetch<Array<ServerONNXModel>>(endpoints.get_onnx_models);
    return fromServerONNXModels(data);
  };

  const fetchUser = async (id: string) => {
    const url = resolve(endpoints.get_user, { id });
    const data = await fetch<ServerUser>(url);
    return fromServerUser(data);
  };

  const fetchAll = async () => {
    const models = await fetchONNXModels();

    // Round up instead of `floor + 1`: the latter reports an extra, empty page
    // whenever the model count is an exact multiple of the page size
    // (e.g. 10 models -> 2 pages). Keep a minimum of one page, matching the
    // previous result for an empty list.
    const pagesQuantity = Math.max(
      1,
      Math.ceil(models.length / itemsPerPageLimit),
    );

    // Several models usually share an owner; fetch each owner only once.
    const userIds = new Set<string>();
    for (const model of models) {
      userIds.add(model.owner_id);
    }
    const promises = Array.from(userIds).map(async (userId) => {
      const user = await fetchUser(userId);
      return [userId, user] as const;
    });
    const users = await Promise.all(promises);
    return { models, modelUsers: new Map(users), pagesQuantity };
  };

  return useQuery(['models'], fetchAll);
}

/**
 * Loads a single ONNX model and, when the model has a graph, all of its graph
 * nodes.
 *
 * For every node it also records the TensorRT precisions it supports. These
 * are looked up by `op_type` in the client-side `TensorRTSupportedNodes` table,
 * with an empty set for op types missing from the table.
 * TODO: this information should come from the backend.
 *
 * The fetch runs once on mount; `refetch` reloads it.
 *
 * @param id - id of the ONNX model to load.
 * @returns `[data, { isLoading, error, refetch }]`, the same tuple shape as
 *   `useFetch()`. `data.graphNodes` and `data.supportedTRTPrecisionTypes` are
 *   keyed by node id.
 */
export function useFetchONNXModel(id: string) {
  type Result = {
    model: ServerONNXModel;
    graphNodes: Map<string, ServerONNXGraphNode>;
    supportedTRTPrecisionTypes: Map<string, Set<TensorRTPrecisionType>>;
  };

  const [fetchState, setFetchState] = useState<FetchState<Result>>({
    isLoading: true,
    data: null,
    error: null,
  });
  const fetchONNXModel = useFetchCore({
    url: resolve(endpoints.get_onnx_model, { model_id: id }),
    responseTransform: (data) => fromServerONNXModel(data as ServerONNXModel),
  });
  const fetchONNXGraphNodes = useFetchCore({
    url: resolve(endpoints.get_onnx_graph_nodes, { model_id: id }),
    responseTransform: (data) =>
      fromServerONNXGraphNodes(data as Array<ServerONNXGraphNode>),
  });

  const fetchAll = async () => {
    const modelResponse = await fetchONNXModel();
    if (modelResponse.isError || !modelResponse.ok || !modelResponse.data) {
      throw modelResponse.error || new Error(`Status: ${modelResponse.status}`);
    }
    const model = modelResponse.data;

    const graphNodes = new Map<string, ServerONNXGraphNode>();
    const supportedTRTPrecisionTypes = new Map<
      string,
      Set<TensorRTPrecisionType>
    >();

    // Without a graph there are no nodes to request, so both maps stay empty.
    // NOTE(review): an earlier comment described a missing graph as an error
    // case — confirm what a graph-less model means.
    if (model.graph) {
      const nodesResponse = await fetchONNXGraphNodes();
      if (nodesResponse.isError || !nodesResponse.ok || !nodesResponse.data) {
        throw (
          nodesResponse.error || new Error(`Status: ${nodesResponse.status}`)
        );
      }

      nodesResponse.data.forEach((graphNode) => {
        const opType = graphNode.op_type as ServerONNXGraphNodeOpType;
        graphNodes.set(graphNode.id, graphNode);
        // TODO: temporary client-side lookup until the backend provides this.
        supportedTRTPrecisionTypes.set(
          graphNode.id,
          TensorRTSupportedNodes[opType] ?? new Set<TensorRTPrecisionType>(),
        );
      });
    }

    return { model, graphNodes, supportedTRTPrecisionTypes };
  };

  // Runs the full fetch and records either the data or the stringified error.
  const fetchAndUpdateState = () =>
    fetchAll().then(
      (data) => setFetchState({ isLoading: false, data, error: null }),
      (error) => {
        logError(error);
        setFetchState({ isLoading: false, data: null, error: String(error) });
      },
    );

  // Load once on mount. `id` is not a dependency (exhaustive-deps is disabled),
  // so a new `id` does not trigger a reload by itself.
  useEffect(() => {
    void fetchAndUpdateState();
    // eslint-disable-next-line react-hooks/exhaustive-deps
  }, []);

  // Same tuple shape as useFetch(): [data, { isLoading, error, refetch }].
  const { data, isLoading, error } = fetchState;
  return [data, { isLoading, error, refetch: fetchAndUpdateState }] as const;
}

/**
 * Fetches the data needed to show an ONNX model's version history.
 *
 * The data consists of:
 * - the parent model's `name` and `description`;
 * - the version list;
 * - the user who created each version.
 *
 * NOTE: the version list is currently read from the bundled
 * `mocked_model_versions.json`, not from the backend. Only the parent model
 * and the creator users are actually requested.
 *
 * Each distinct `created_by` user is requested once, and those requests run
 * concurrently. If any request fails, the whole fetch fails.
 *
 * @param id - id of the parent ONNX model.
 * @returns `[data, { isLoading, error, refetch }]`, the same tuple shape as
 *   `useFetch()`. `data.users` maps a creator id to its user.
 */
export function useFetchONNXModelVersions(id: string) {
  type Result = {
    name: string | undefined;
    description: string | undefined;
    // TODO: need to change the type here...
    modelVersions: Array<any>;
    users: Map<string, ServerUser>;
  };

  const [fetchState, setFetchState] = useState<FetchState<Result>>({
    isLoading: true,
    data: null,
    error: null,
  });
  const fetchONNXModel = useFetchCore({
    url: resolve(endpoints.get_onnx_model, { model_id: id }),
    responseTransform: (data) => fromServerONNXModel(data as ServerONNXModel),
  });
  // `userId` rather than `id`, so the hook's model `id` is not shadowed.
  const fetchUser = useFetchCore((userId: string) => {
    return {
      url: resolve(endpoints.get_user, { id: userId }),
      responseTransform: (data) => fromServerUser(data as ServerUser),
    };
  });

  const fetchAll = async () => {
    const parentONNXModelResult = await fetchONNXModel();
    if (
      parentONNXModelResult.isError ||
      !parentONNXModelResult.ok ||
      !parentONNXModelResult.data
    ) {
      throw (
        parentONNXModelResult.error ||
        new Error(`Status: ${parentONNXModelResult.status}`)
      );
    }
    const parentONNXModel = parentONNXModelResult.data;
    const { name, description } = parentONNXModel;
    const modelVersions: Array<any> = mockedModelVersions;

    // Many versions share a creator: request each distinct creator once, and
    // all of them concurrently instead of one await per version. The Set keeps
    // first-occurrence order, so the resulting Map iterates in the same order
    // as before.
    const creatorIds = Array.from(
      new Set<string>(modelVersions.map((version) => version.created_by)),
    );
    const userEntries = await Promise.all(
      creatorIds.map(async (creatorId) => {
        const userResult = await fetchUser(creatorId);
        if (userResult.isError || !userResult.ok || !userResult.data) {
          throw userResult.error || new Error(`Status: ${userResult.status}`);
        }
        return [creatorId, userResult.data] as const;
      }),
    );
    const users = new Map<string, ServerUser>(userEntries);

    return { name, description, modelVersions, users };
  };
  const fetchAndUpdateState = async () => {
    try {
      const data = await fetchAll();
      setFetchState({ isLoading: false, data, error: null });
    } catch (error) {
      logError(error);
      setFetchState({ isLoading: false, data: null, error: String(error) });
    }
  };
  useEffect(
    () => {
      fetchAndUpdateState();
    },
    // eslint-disable-next-line react-hooks/exhaustive-deps
    [],
  );
  const { data, isLoading, error } = fetchState;
  // Returning in this way to match the way we do it in useFetch()
  return [data, { isLoading, error, refetch: fetchAndUpdateState }] as const;
}

/**
 * Loads the child nodes of a fused ONNX graph node.
 *
 * TODO: unfinished. NOTE(review): the request resolves
 * `endpoints.get_onnx_graph_nodes` with `model_id: node.id`, which passes a
 * *node* id where a *model* id is expected. It therefore presumably lists the
 * nodes of whatever model has this id, not this node's children. Confirm the
 * intended endpoint before relying on the result.
 *
 * The initial fetch runs once on mount (the effect has no dependencies).
 *
 * @param node - the fused node; echoed back as `data.parentNode`.
 * @returns `[data, { isLoading, error, refetch }]`, the same tuple shape as
 *   `useFetch()`.
 */
export function useFetchFusedONNXGraphNode(node: ServerONNXGraphNode) {
  type Result = {
    parentNode: ServerONNXGraphNode;
    childNodes: Array<ServerONNXGraphNode>;
  };

  const [fetchState, setFetchState] = useState<FetchState<Result>>({
    isLoading: true,
    data: null,
    error: null,
  });
  const fetchChildNodes = useFetchCore({
    url: resolve(endpoints.get_onnx_graph_nodes, { model_id: node.id }),
    responseTransform: (data) =>
      fromServerONNXGraphNodes(data as Array<ServerONNXGraphNode>),
  });

  const fetchAll = async () => {
    const response = await fetchChildNodes();
    if (response.isError || !response.ok || !response.data) {
      throw response.error || new Error(`Status: ${response.status}`);
    }
    return { parentNode: node, childNodes: response.data };
  };

  // Resolve the fetch into the next state first, then commit it once.
  const fetchAndUpdateState = async () => {
    let next: FetchState<Result>;
    try {
      next = { isLoading: false, data: await fetchAll(), error: null };
    } catch (err) {
      logError(err);
      next = { isLoading: false, data: null, error: String(err) };
    }
    setFetchState(next);
  };

  useEffect(() => {
    void fetchAndUpdateState();
    // eslint-disable-next-line react-hooks/exhaustive-deps
  }, []);

  // Same tuple shape as useFetch(): [data, { isLoading, error, refetch }].
  const { data, isLoading, error } = fetchState;
  return [data, { isLoading, error, refetch: fetchAndUpdateState }] as const;
}

/**
 * Returns a mutation that creates a new ONNX model.
 *
 * The mutation input is shallow-copied into the body of a POST to
 * `endpoints.create_onnx_model`. No `responseTransform` is supplied, so the
 * server's response is not converted here.
 */
export function useCreateONNXModel() {
  return useMutation((input: ONNXModelCreateInput) => {
    return {
      url: endpoints.create_onnx_model,
      method: 'post',
      body: { ...input },
    };
  });
}

/**
 * Returns a mutation that PUTs `updates` to the ONNX model identified by `id`.
 * The server's reply is passed through `fromServerONNXModel`.
 */
export function useUpdateONNXModel() {
  return useMutation(
    ({ id, updates }: { id: string; updates: ServerONNXModelUpdate }) => ({
      url: resolve(endpoints.update_onnx_model, { model_id: id }),
      method: 'put',
      body: updates,
      responseTransform: (data) => fromServerONNXModel(data as ServerONNXModel),
    }),
  );
}

/**
 * Returns a mutation that sends a DELETE for the given ONNX model. Only the
 * model's `id` is used.
 */
export function useDeleteONNXModel() {
  return useMutation(({ id }: ServerONNXModel) => ({
    url: resolve(endpoints.delete_onnx_model, { model_id: id }),
    method: 'delete',
  }));
}

/**
 * Returns a callback that uploads an ONNX model file as multipart form data.
 *
 * The file is sent under the form field `onnx_model_file` in a POST to
 * `endpoints.upload_onnx_model_file`, with the token from `getAuthToken()` as a
 * bearer `Authorization` header. The raw `fetch` `Response` is returned
 * unchanged; callers must check `response.ok` and parse the body themselves.
 */
export function useUploadONNXModel() {
  const { getAuthToken } = useAuth();

  return useCallback(
    async (onnx_model_file: Blob) => {
      // Read the token when the upload runs, not when the hook renders.
      const token = getAuthToken();
      const form = new FormData();
      form.append('onnx_model_file', onnx_model_file);
      // No explicit Content-Type: for a FormData body, fetch sets
      // multipart/form-data together with the required boundary.
      const response = await fetch(endpoints.upload_onnx_model_file, {
        method: 'post',
        headers: { Authorization: `Bearer ${token}` },
        body: form,
      });
      return response;
    },
    // Depend on getAuthToken so the callback never closes over a stale
    // function from an earlier render (previously `[]` with the lint disabled).
    [getAuthToken],
  );
}

/**
 * Fetches the graph initializers referenced by a node's inputs and outputs.
 *
 * Each input/output name is looked up concurrently by requesting the resolved
 * `get_onnx_graph_initializers` URL with a `name` query parameter. Many names
 * are expected to have no initializer, so a lookup that is not ok or returns
 * no data is skipped rather than treated as an error.
 *
 * The initial fetch runs once on mount; `refetch` reloads it.
 *
 * @param model_id - id of the ONNX model that owns the graph.
 * @param node - the graph node whose input/output names are looked up.
 * @returns `[data, { isLoading, error, refetch }]`, the same tuple shape as
 *   `useFetch()`. `data.inputInitializers` and `data.outputInitializers` map a
 *   tensor name to its initializer. They iterate in the order of `node.inputs`
 *   and `node.outputs`, respectively.
 */
export function useFetchONNXGraphInitializer(
  model_id: string,
  node: ServerONNXGraphNode,
) {
  type Result = {
    inputInitializers: Map<string, ServerONNXGraphInitializer>;
    outputInitializers: Map<string, ServerONNXGraphInitializer>;
  };
  const [fetchState, setFetchState] = useState<FetchState<Result>>({
    isLoading: true,
    data: null,
    error: null,
  });
  const fetchONNXGraphInitializerByName = useFetchCore((name: string) => {
    const baseUrl = resolve(endpoints.get_onnx_graph_initializers, {
      model_id,
    });
    return {
      url: baseUrl + '?' + new URLSearchParams({ name: name }),
      method: 'get',
      responseTransform: (data) =>
        fromServerONNXGraphInitializer(data as ServerONNXGraphInitializer),
    };
  });

  const fetchAll = async () => {
    // Resolves to the initializer for `name`, or null when none was found.
    const lookupInitializer = async (name: string) => {
      const result = await fetchONNXGraphInitializerByName(name);
      return result.ok && result.data ? result.data : null;
    };

    // Looks up all `names` concurrently, then fills the map in `names` order.
    // Previously entries were inserted as requests completed, which made the
    // Map's iteration order depend on network timing.
    const collectInitializers = async (names: ReadonlyArray<string>) => {
      const found = await Promise.all(
        names.map((name) => lookupInitializer(name)),
      );
      const initializers = new Map<string, ServerONNXGraphInitializer>();
      names.forEach((name, index) => {
        const initializer = found[index];
        if (initializer) {
          initializers.set(name, initializer);
        }
      });
      return initializers;
    };

    // Await both groups through one Promise.all. If a lookup rejects, every
    // other pending lookup still has a handler attached, so no rejection is
    // left unhandled. Two sequential awaits did not guarantee that.
    const [inputInitializers, outputInitializers] = await Promise.all([
      collectInitializers(node.inputs ?? []),
      collectInitializers(node.outputs ?? []),
    ]);
    return {
      inputInitializers,
      outputInitializers,
    };
  };

  const fetchAndUpdateState = async () => {
    try {
      const data = await fetchAll();
      setFetchState({ isLoading: false, data, error: null });
    } catch (error) {
      logError(error);
      setFetchState({ isLoading: false, data: null, error: String(error) });
    }
  };
  useEffect(
    () => {
      fetchAndUpdateState();
    },
    // eslint-disable-next-line react-hooks/exhaustive-deps
    [],
  );
  const { data, isLoading, error } = fetchState;
  // Returning in this way to match the way we do it in useFetch()
  return [data, { isLoading, error, refetch: fetchAndUpdateState }] as const;
}

/**
 * Returns a mutation that asks the server whether one more model upload fits
 * within the billing quota.
 *
 * The request is a GET to
 * `endpoints.verify_total_model_uploads_within_billing_quota` and takes no
 * input. NOTE(review): it is wrapped with `useMutation` rather than a query,
 * presumably so it only runs when the caller triggers it.
 */
export function useVerifyTotalModelUploadsWithinBillingQuota() {
  return useMutation(() => ({
    url: endpoints.verify_total_model_uploads_within_billing_quota,
    method: 'get',
  }));
}

// ----------------------------- ONNX Graph Nodes ------------------------------
/**
 * Returns a mutation that creates a graph node. The input is shallow-copied
 * into the body of a POST to `endpoints.create_onnx_graph_node`.
 */
export function useCreateONNXGraphNode() {
  return useMutation((input: ServerONNXGraphNodeCreate) => ({
    url: endpoints.create_onnx_graph_node,
    method: 'post',
    body: { ...input },
  }));
}

/**
 * Returns a mutation that PUTs `updates` to graph node `id` of model
 * `model_id`. The server's reply is passed through `fromServerONNXGraphNode`.
 */
export function useUpdateONNXGraphNode() {
  return useMutation(
    ({
      model_id,
      id,
      updates,
    }: {
      model_id: string;
      id: string;
      updates: ServerONNXGraphNodeUpdate;
    }) => ({
      url: resolve(endpoints.update_onnx_graph_node, { model_id, id }),
      method: 'put',
      body: updates,
      responseTransform: (data) =>
        fromServerONNXGraphNode(data as ServerONNXGraphNode),
    }),
  );
}

/**
 * Returns a mutation that sends a DELETE for `graph_node` within model
 * `model_id`. Only the node's `id` is used.
 */
export function useDeleteONNXGraphNode() {
  return useMutation(
    ({
      model_id,
      graph_node: { id },
    }: {
      model_id: string;
      graph_node: ServerONNXGraphNode;
    }) => ({
      url: resolve(endpoints.delete_onnx_graph_node, { model_id, id }),
      method: 'delete',
    }),
  );
}

/**
 * Maps the server's ONNX model list to the client representation. The two
 * shapes are currently identical, so the input array itself is returned.
 */
function fromServerONNXModels(
  serverModels: Array<ServerONNXModel>,
): ServerONNXModel[] {
  return serverModels;
}

/**
 * Maps one server ONNX model to the client representation. Currently the
 * identity: the very same object is returned.
 */
function fromServerONNXModel(serverModel: ServerONNXModel): ServerONNXModel {
  return serverModel;
}

/**
 * Maps one server graph node to the client representation. Currently the
 * identity: the very same object is returned.
 */
function fromServerONNXGraphNode(
  serverNode: ServerONNXGraphNode,
): ServerONNXGraphNode {
  return serverNode;
}

/**
 * Maps the server's graph-node list to the client representation. The two
 * shapes are currently identical, so the input array itself is returned.
 */
function fromServerONNXGraphNodes(
  serverNodes: Array<ServerONNXGraphNode>,
): ServerONNXGraphNode[] {
  return serverNodes;
}

/**
 * Maps one server graph initializer to the client representation. Currently
 * the identity: the very same object is returned.
 */
function fromServerONNXGraphInitializer(
  serverInitializer: ServerONNXGraphInitializer,
): ServerONNXGraphInitializer {
  return serverInitializer;
}
