import { Fragment, useMemo, useState } from "react";
import {
  Chart as ChartJS,
  CategoryScale,
  LinearScale,
  PointElement,
  LineElement,
  Title,
  Tooltip,
  Legend,
} from "chart.js";
import { Line } from "react-chartjs-2";
import { useMultiTreeSession } from "src/hooks/use-multitree-session";
import {
  Alert,
  AlertTitle,
  Box,
  Checkbox,
  Stack,
  Typography,
} from "@mui/material";

// Chart.js pieces used by the <Line> chart rendered in this module: the two
// scales, the point/line elements, and the Title, Tooltip and Legend
// components. They are registered once, at module load.
const lineChartComponents = [
  CategoryScale,
  LinearScale,
  PointElement,
  LineElement,
  Title,
  Tooltip,
  Legend,
];
ChartJS.register(...lineChartComponents);

function MetricsEvolution() {
  const { trainGenerationsData: solutions } = useMultiTreeSession();

  // first array represents generation
  // second array represents models
  // models have multiple trees
  const treeCountForEachModel =
    solutions.length > 0
      ? solutions[0].length > 0
        ? solutions[0][0].models.length
        : 0
      : 0;
  const [treeFitnessMetrics, setTreeFitnessMetrics] = useState(
    new Array(treeCountForEachModel)
      .fill("")
      .map((_, index) => ({ index, checked: false }))
  );
  const [treeSizeMetrics, setTreeSizeMetrics] = useState(
    new Array(treeCountForEachModel)
      .fill("")
      .map((_, index) => ({ index, checked: false }))
  );

  const modelDatasets = useMemo(() => {
    if (solutions.length === 0) {
      return [];
    }

    let fields = new Set(["fitness", "size"]);
    solutions.flatMap((item) => {
      item.forEach((model) => {
        Object.keys(model.other || {}).forEach((key) => {
          if (model.other[key] && typeof model.other[key] == "number") {
            fields.add(key);
          }
        });
      });
    });

    return Array.from(fields);
  }, [solutions]);

  const data = {
    labels: solutions.flatMap((solution) =>
      solution.map((item, index) => index)
    ),
    datasets: [
      ...modelDatasets.map((label, i) => {
        return {
          label: `${label}`,
          data: solutions.flatMap((solution) =>
            solution.map((item) =>
              item.hasOwnProperty(label)
                ? item[label]
                : item.other.hasOwnProperty(label)
                  ? item.other[label]
                  : 0
            )
          ),
          borderColor: colors[i % colors.length],
          backgroundColor: colors[i % colors.length],
        };
      }),
      ...treeFitnessMetrics
        .filter((item) => item.checked)
        .map((val, i) => {
          return {
            label: `tree-${val.index + 1}-fitness`,
            data: solutions.flatMap((solution) => {
              return solution.map((item) => {
                return item.objectives[val.index];
              });
            }),
            borderColor: colors[(modelDatasets.length + i) % colors.length],
            backgroundColor: colors[(modelDatasets.length + i) % colors.length],
          };
        }),
      ...treeSizeMetrics
        .filter((item) => item.checked)
        .map((val, i) => {
          return {
            label: `tree-${val.index + 1}-size`,
            data: solutions.flatMap((solution) => {
              return solution.map((item) => {
                return item.objectives[val.index];
              });
            }),
            borderColor: colors[i % colors.length],
            backgroundColor: colors[i % colors.length],
          };
        }),
    ],
  };

  const toggleMetric = (e, index, type) => {
    const isFitness = type === "fitness";
    const status = e.target.checked;
    let oldMetrics = isFitness ? [...treeFitnessMetrics] : [...treeSizeMetrics];
    oldMetrics = oldMetrics.map((item) => {
      return item.index === index ? { ...item, checked: status } : item;
    });
    if (isFitness) {
      setTreeFitnessMetrics(oldMetrics);
    } else {
      setTreeSizeMetrics(oldMetrics);
    }
  };

  return (
    <div id="metrics-line-container">
      <Line options={options} data={data} />
      {treeCountForEachModel > 0 && (
        <Fragment>
          <Alert sx={{ margin: "1rem" }} severity="info">
            Model metrics are default. If you want to see tree specific changes,
            check the tree below.
          </Alert>
          <Box
            sx={{
              display: "flex",
              flexDirection: "column",
              gap: ".625rem",
              paddingX: ".5rem",
              paddingBottom: "3rem",
            }}
          >
            {new Array(treeCountForEachModel).fill("").map((_, index) => {
              return (
                <Stack
                  key={index}
                  direction={"row"}
                  alignItems={"center"}
                  justifyContent={"center"}
                  gap={"1rem"}
                  sx={{
                    borderBottom:
                      index + 1 !== treeCountForEachModel
                        ? (t) => `1px solid ${t.palette.divider}`
                        : "none",
                  }}
                >
                  <Typography sx={{ marginRight: "1rem" }}>
                    <b>tree-{index + 1}:</b>
                  </Typography>
                  <Stack direction={"row"} alignItems={"center"}>
                    <Checkbox
                      checked={treeFitnessMetrics[index].checked}
                      onChange={(e) => toggleMetric(e, index, "fitness")}
                    />
                    <Typography>fitness</Typography>
                  </Stack>
                  <Stack direction={"row"} alignItems={"center"}>
                    <Checkbox
                      checked={treeSizeMetrics[index].checked}
                      onChange={(e) => toggleMetric(e, index, "size")}
                    />
                    <Typography>size</Typography>
                  </Stack>
                </Stack>
              );
            })}
          </Box>
        </Fragment>
      )}
    </div>
  );
}

export default MetricsEvolution;

/**
 * Chart.js options for the metrics line chart.
 *
 * - The chart is responsive; the legend sits below it, left-aligned.
 * - `plugins.zoom` configures wheel/pinch zoom and panning on both axes.
 *   NOTE(review): no zoom plugin (presumably chartjs-plugin-zoom) is imported
 *   or passed to ChartJS.register in this file, so unless it is registered
 *   elsewhere this section has no effect. Confirm.
 * - The y axis is titled "value" and the x axis "generation".
 */
const options = (() => {
  const bothAxes = "xy";
  const titled = (text) => ({ title: { display: true, text } });

  const legend = { position: "bottom", align: "start" };
  const zoom = {
    zoom: {
      wheel: { enabled: true },
      pinch: { enabled: true },
      mode: bothAxes,
    },
    pan: { enabled: true, mode: bothAxes },
  };

  return {
    responsive: true,
    plugins: { legend, zoom },
    scales: {
      y: titled("value"),
      x: titled("generation"),
    },
  };
})();

// Series color palette. Callers index it modulo its length, so every entry
// must be distinct: the original list repeated "#3CDBD3", which made two
// series in the same palette cycle indistinguishable. The palette is frozen
// because it is shared, read-only module state.
const colors = Object.freeze([
  "#0101ff",
  "#2ee09a",
  "#ff684d",
  "#E8E288",
  "#7DCE82",
  "#3CDBD3",
  "#D65DB1",
  "#E3D8F1",
  "#726DA8",
  "#A0D2DB",
]);
