import { useState, useRef, useMemo, useEffect, Fragment } from "react";
import {
  Chart as ChartJS,
  LinearScale,
  PointElement,
  LineElement,
  Tooltip,
  Legend,
} from "chart.js";
import { Scatter } from "react-chartjs-2";
import zoomPlugin from "chartjs-plugin-zoom";
import { useMultiTreeSession } from "src/hooks/use-multitree-session";
import { useGetMultiTreeComparisonChartData } from "src/hooks/sessions";
import {
  defaultPopulationModelKeys,
  MultiTreeFilterType,
  SessionResultsPlotType,
} from "src/utils/types";
import { Alert, Box, CircularProgress, Stack } from "@mui/material";
import { getErrorMsg } from "src/utils/Utils";
import Dropdown from "src/components/Dropdown";

// Chart.js components needed by the scatter chart below: a linear scale for
// both axes, point/line elements, tooltip and legend plugins, and the zoom
// plugin that provides wheel/pinch zoom and panning. Registration order is
// kept identical to the list below.
const scatterChartRegistrables = [
  LinearScale,
  PointElement,
  LineElement,
  Tooltip,
  Legend,
  zoomPlugin,
];

ChartJS.register(...scatterChartRegistrables);

function ComparisonScatter() {
  const {
    selectedSession,
    comparisonFilterFields,
    comparisonFilters,
    sessionId,
  } = useMultiTreeSession();
  const chartRef = useRef(null);
  const [chartData, setChartData] = useState(undefined);

  const [scatterplotY, setscatterplotY] = useState({
    value: "",
    label: "select y axis field",
  });

  const [scatterplotX, setscatterplotX] = useState({
    value: "",
    label: "select x axis field",
  });

  const {
    mutateAsync: getChartData,
    isLoading,
    error,
  } = useGetMultiTreeComparisonChartData({
    sessionId,
  });

  useEffect(async () => {
    try {
      if (!scatterplotX.value || !scatterplotY.value) {
        return;
      }

      let allFields = [...defaultPopulationModelKeys];
      allFields.push(scatterplotX.value);
      allFields.push(scatterplotY.value);

      const { datasets } = await getChartData({
        filterInfo: comparisonFilters,
        fieldsInfo: allFields,
        prepPlotData: SessionResultsPlotType.SCATTER,
      });

      const data = { datasets: datasets || [] };
      setChartData(data);
    } catch (error) {
      console.error("Error fetching chart data:", error);
    }
  }, [
    scatterplotX.value,
    scatterplotY.value,
    JSON.stringify(comparisonFilters),
  ]);

  const labels = useMemo(() => {
    if ((comparisonFilterFields || []).length === 0) {
      return [];
    }

    const filteredFields = comparisonFilterFields.filter(
      (item) =>
        !defaultPopulationModelKeys.includes(item.key) &&
        item.metricType != MultiTreeFilterType.TREE
    );

    if (!scatterplotX.value && filteredFields.length > 0) {
      const firstFilterObj = filteredFields[0];
      setscatterplotX({
        value: firstFilterObj.isOtherField
          ? `other.${firstFilterObj.key}`
          : firstFilterObj.key,
        label: firstFilterObj.key,
      });
    }
    if (!scatterplotY.value && filteredFields.length > 1) {
      const secondFilterObj = filteredFields[1];
      setscatterplotY({
        value: secondFilterObj.isOtherField
          ? `other.${secondFilterObj.key}`
          : secondFilterObj.key,
        label: secondFilterObj.key,
      });
    }

    return filteredFields.map((item) => ({
      value: item.isOtherField ? `other.${item.key}` : `${item.key}`,
      label: item.key,
    }));
  }, [JSON.stringify(comparisonFilterFields)]);

  if (!selectedSession || (comparisonFilterFields || []).length === 0)
    return <div className="tile-no-solutions">No solutions found</div>;

  const options = {
    responsive: true,
    plugins: {
      legend: {
        display: false,
      },
      tooltip: {
        callbacks: {
          label: function (context) {
            return `${context.raw.identifier}: (${context.raw.x}, ${context.raw.y})`;
          },
        },
      },
      zoom: {
        zoom: {
          wheel: {
            enabled: true,
          },
          pinch: {
            enabled: true,
          },
          mode: "xy",
        },
        pan: {
          enabled: true,
          mode: "xy",
        },
      },
    },
    scales: {
      y: {
        beginAtZero: true,
        title: {
          display: true,
          text: scatterplotY.label,
        },
      },
      x: {
        beginAtZero: true,
        title: {
          display: true,
          text: scatterplotX.label,
        },
      },
    },
  };

  return (
    <div id="scatter-container">
      {error ? (
        <Box sx={{ width: "100%" }}>
          <Alert severity="error" sx={{ m: "1rem" }}>
            {getErrorMsg(error)}
          </Alert>
        </Box>
      ) : (
        <Fragment>
          <Box
            sx={{
              display: "flex",
              justifyContent: "flex-end",
              flexWrap: "wrap",
              alignItems: "center",
              gap: ".5rem",
              mt: ".25rem",
            }}
          >
            {scatterplotY && (
              <Dropdown
                value={{ value: scatterplotY.value, label: scatterplotY.label }}
                onChange={setscatterplotY}
                options={labels.map((item) => {
                  return { value: item.value, label: item.label };
                })}
                isSearchable={false}
                p
              />
            )}
            {scatterplotX && (
              <Dropdown
                value={{ value: scatterplotX.value, label: scatterplotX.label }}
                onChange={setscatterplotX}
                options={labels.map((item) => {
                  return { value: item.value, label: item.label };
                })}
                isSearchable={false}
              />
            )}
          </Box>
          {isLoading ? (
            <Box
              sx={{
                width: "100%",
                display: "flex",
                justifyContent: "center",
                mt: "3rem",
              }}
            >
              <CircularProgress />
            </Box>
          ) : chartData ? (
            <Fragment>
              {chartData.datasets.length == 0 ? (
                <Alert severity="warning" sx={{ margin: "1rem" }}>
                  No data found for selected fields.
                </Alert>
              ) : (
                <Stack gap={"1.5rem"} mt={"-2rem"}>
                  <Scatter options={options} data={chartData} ref={chartRef} />
                </Stack>
              )}
            </Fragment>
          ) : (
            <Box sx={{ padding: ".5rem" }}>
              <Alert severity="info">
                Select axis fields to fill chart data.
              </Alert>
            </Box>
          )}
        </Fragment>
      )}
    </div>
  );
}

export default ComparisonScatter;
