import { useState, useEffect, useRef, createRef } from "react";
import {
  Dialog,
  DialogActions,
  DialogContent,
  DialogTitle,
  Grid,
  Typography,
  Alert,
  AlertTitle,
  FormControlLabel,
  Checkbox,
  Box,
  Chip,
  Select,
  MenuItem,
  Skeleton,
  Button,
  CircularProgress,
  useTheme,
  Stack,
} from "@mui/material";
import { useMutation, useQuery, useQueryClient } from "react-query";
import WhatifAnalyseDialog from "../whatif/WhatifAnalyseDialog";
import ApiClient from "src/axios";
import {
  DialogMode,
  GET_COUNTERFATUAL_LOGS_QUERY_KEY,
  GET_EXPLANATIONS_QUERY_KEY,
  GET_EXPLANATION_QUERY_KEY,
} from "src/utils/types";
import {
  resultsCountsValues,
  DEFAULT_RESULT_COUNTS,
  EXPLANATION_STATUS_CHECK_INTERVAL,
  EXPLANATION_STATUS_RETRY_COUNT,
  ExplanationState,
} from "./utils";
import ChangeExpressionDialog from "../common/expression-table/ChangeExpressionDialog";
import SelectedExpressionTable from "../common/expression-table/SelectedExpressionTable";
import SuggestionsTable from "./SuggestionsTable";
import CounterfactualLogs from "./CounterfactualLogs";
import SelectOutputRange from "./SelectOutputRange";
import SelectedInstanceTable from "./SelectedInstanceTable";

const CounterFactualDialog = ({
  open,
  selectedExpression: initialExpression,
  mode,
  desiredOutcomeRange,
  explanationID: initialExplanationID,
  selectedInstanceIndex,
  existingFeatureVars,
  instance,
  sessionID,
  featureSet,
  allPopulations,
  targetLabel,
  onClose,
}) => {
  const theme = useTheme();
  const queryClient = useQueryClient();
  const logsRef = createRef();
  const suggestionsRef = createRef();

  const [selectedPopulation, setSelectedPopulation] = useState(null);
  const [logs, setLogs] = useState(null);
  const [showLogs, setShowLogs] = useState(false);
  const [loading, setLoading] = useState({
    analyse: false,
    checkExplanationStatus: false,
    checkLogs: false,
    checkResults: false,
  });
  const [error, setError] = useState({
    analyse: null,
    explanation: null,
    logs: null,
  });
  const [result, setResult] = useState(null);
  const [allFeatures, setAllFeatures] = useState(null);
  const [resultsCount, setResultsCount] = useState(DEFAULT_RESULT_COUNTS);
  const [outputRange, setOutputRange] = useState({
    start: null,
    finish: null,
  });
  const [explanationID, setExplanationID] = useState(null);
  const [explanationStatus, setExplanationStatus] = useState(
    ExplanationState.UNKNOWN
  );
  const [selectedSuggestionIndex, setSelectedSugesstionIndex] = useState(-1);
  const [whatifDialogOpen, setWhatifDialogOpen] = useState(false);
  const [changeExpressionOpen, setChangeExpressionOpen] = useState(false);

  useEffect(() => {
    let initialPopulation;
    (allPopulations || []).forEach((item, index) => {
      if (item.model == initialExpression) {
        initialPopulation = {
          ...item,
          index: index,
        };
        return;
      }
    });
    setSelectedPopulation(initialPopulation);

    if (mode === DialogMode.UPDATE) {
      setOutputRange({
        start: desiredOutcomeRange[0],
        finish: desiredOutcomeRange[1],
      });
      setExplanationID(initialExplanationID);
    }
  }, []);

  useEffect(() => {
    if (showLogs) {
      logsRef?.current?.scrollIntoView({ behavior: "smooth", block: "start" });
    }
  }, [showLogs]);

  useEffect(() => {
    if (result) {
      suggestionsRef?.current?.scrollIntoView({
        behavior: "smooth",
        block: "start",
      });
    }
  }, [result]);

  const { mutate: analyseMutation } = useMutation({
    mutationFn: () =>
      ApiClient.post(
        `/api/counterfactuals?explanationId=${explanationID ?? ""}`,
        {
          queryInstanceIdx: selectedInstanceIndex,
          expr: selectedPopulation?.model,
          featuresToVary: existingFeatureVars,
          originalInstance: instance,
          desiredRange: [Number(outputRange.start), Number(outputRange.finish)],
          cfcount: Number(resultsCount),
          sessionId: sessionID,
          allFeatures: featureSet,
        }
      ),
    onSuccess: (res) => {
      setExplanationStatus(res.data.state);
      setExplanationID(res.data.id);
      queryClient.invalidateQueries([GET_EXPLANATIONS_QUERY_KEY, sessionID]);
    },
    onError: (err) => {
      setError({
        ...error,
        analyse: err?.response?.data?.errors || [
          { message: "Error occured. Please try later!" },
        ],
      });
    },
    onSettled: () => {
      setLoading({ ...loading, analyse: false, checkExplanationStatus: true });
    },
  });

  useQuery({
    queryKey: [GET_EXPLANATION_QUERY_KEY],
    queryFn: () =>
      ApiClient.get(`/api/explanation/${explanationID}?type=counterfactuals`),
    enabled:
      Boolean(explanationID) &&
      (mode === DialogMode.CREATE
        ? explanationStatus !== ExplanationState.COMPLETED
        : true) &&
      explanationStatus !== ExplanationState.FAILED,
    refetchInterval: EXPLANATION_STATUS_CHECK_INTERVAL,
    retry: EXPLANATION_STATUS_RETRY_COUNT,
    onSuccess: (res) => {
      const currentState = res.data.state;
      setExplanationStatus(currentState);
      if (
        currentState === ExplanationState.COMPLETED ||
        currentState === ExplanationState.FAILED
      ) {
        setResult(res.data.explanations || []);
        setAllFeatures(res.data.configuration.allFeatures || []);
        setLoading({
          ...loading,
          checkExplanationStatus: false,
        });
      }
    },
    onError: (err) => {
      setError({
        ...error,
        explanation: err?.response?.data?.errors || [
          { message: "Error occured. Please try later!" },
        ],
      });
      setLoading({
        ...loading,
        checkExplanationStatus: false,
      });
    },
    onSettled: () => {
      setLoading({
        ...loading,
        checkLogs: true,
        checkExplanationStatus: false,
      });
    },
  });

  useQuery({
    queryKey: [GET_COUNTERFATUAL_LOGS_QUERY_KEY],
    queryFn: () =>
      ApiClient.get(
        `/api/explanations/${explanationID}/logs/data?type=counterfactuals`
      ),
    enabled: loading.checkLogs,
    refetchInterval: EXPLANATION_STATUS_CHECK_INTERVAL,
    onSuccess: (res) => {
      setLogs(res.data);
      setLoading({ ...loading, checkLogs: false });
    },
    onError: (err) => {
      setError({
        ...error,
        logs: err?.response?.data?.errors || [
          { message: "Error occured. Please try later!" },
        ],
      });
    },
  });

  const handleAnalyse = () => {
    setLoading({ ...loading, analyse: true });
    setResult(null);
    analyseMutation();
  };

  const handleChangeShowLogs = (event) => {
    const isChecked = event.target.checked;
    setShowLogs(isChecked);
  };

  const explanationStatusLoading =
    explanationStatus === ExplanationState.CREATED ||
    explanationStatus === ExplanationState.PENDING ||
    explanationStatus === ExplanationState.RUNNING;

  return (
    <Dialog open={open} onClose={onClose} maxWidth="lg" fullWidth>
      <DialogTitle sx={{ m: 0, p: 2 }} variant="h6" fontWeight={600}>
        {`Counterfactual Analyse`}
      </DialogTitle>
      <DialogContent dividers>
        <Grid container spacing={3}>
          <Grid item xs={12} mt={2}>
            <SelectedExpressionTable
              selectedPopulation={selectedPopulation}
              onChange={() => setChangeExpressionOpen(true)}
            />
          </Grid>
          <Grid item xs={12} sm={6}>
            <Typography
              sx={{
                fontSize: "14px",
                fontWeight: "600",
                mb: ".3rem",
              }}
            >
              Results Count
            </Typography>
            <Select
              id="resultcount-select"
              fullWidth
              value={resultsCount}
              onChange={(e) => setResultsCount(e.target.value)}
            >
              {resultsCountsValues.map((item) => (
                <MenuItem value={item}>{item}</MenuItem>
              ))}
            </Select>
          </Grid>
          <Grid item xs={12} sm={6}>
            <Typography
              sx={{
                fontSize: "14px",
                fontWeight: "600",
                mb: ".3rem",
              }}
            >
              Algorithm
            </Typography>
            <Select id="algorithm-select" fullWidth defaultValue={"dice"}>
              <MenuItem value={"dice"}>Dice</MenuItem>
            </Select>
          </Grid>
          <Grid item xs={12}>
            <SelectOutputRange
              outputRange={outputRange}
              onSetOutputRange={setOutputRange}
            />
          </Grid>
          <Grid item xs={12}>
            <SelectedInstanceTable
              targetLabel={targetLabel}
              featureSet={featureSet}
              instance={instance}
            />
          </Grid>
          {result && (
            <Grid item xs={12}>
              <Box
                sx={{
                  mb: 1,
                  mt: 2,
                  display: "flex",
                  alignItems: "center",
                  gap: "8px",
                }}
              >
                <div ref={suggestionsRef}>
                  <Typography
                    sx={{
                      fontSize: "14px",
                      fontWeight: "600",
                      fontFamily: "Titillium Web",
                    }}
                  >
                    Suggestions
                  </Typography>
                </div>
                {selectedSuggestionIndex > -1 &&
                  result &&
                  result.length > 0 && (
                    <Chip
                      label="Use instance at whatif analyse"
                      size="small"
                      sx={{ background: "#dbeafe", color: "#1e3a8a" }}
                      onClick={() => setWhatifDialogOpen(true)}
                    />
                  )}
              </Box>
              {loading.checkExplanationStatus ? (
                <>
                  {new Array(5).fill().map(() => (
                    <Skeleton
                      variant="rectangular"
                      width="100%"
                      height={30}
                      style={{ marginBottom: ".25rem" }}
                    />
                  ))}
                </>
              ) : result.length === 0 ? (
                <Alert severity="info">No suggestions found</Alert>
              ) : (
                <SuggestionsTable
                  allFeatures={allFeatures}
                  result={result}
                  instance={instance}
                  selectedSuggestionIndex={selectedSuggestionIndex}
                  onSelectedSugesstionIndex={setSelectedSugesstionIndex}
                />
              )}
            </Grid>
          )}
          {(explanationStatus === ExplanationState.COMPLETED ||
            explanationStatus === ExplanationState.FAILED) &&
            !result &&
            mode !== DialogMode.UPDATE && (
              <Grid item xs={12} mt={2}>
                <Alert severity="info">Counterfactual results not found.</Alert>
              </Grid>
            )}
          {error.analyse && (
            <Grid item xs={12}>
              {error.analyse.map((item) => (
                <Stack mt={2} width={"100%"}>
                  <Alert severity="error">
                    <AlertTitle>Counterfactual analyse error.</AlertTitle>
                    {item.message}
                  </Alert>
                </Stack>
              ))}
            </Grid>
          )}
          {showLogs && (
            <Grid item xs={12}>
              <CounterfactualLogs
                ref={logsRef}
                logs={logs}
                errorLogs={error.logs}
              />
            </Grid>
          )}
        </Grid>
      </DialogContent>
      <DialogActions>
        <Box
          sx={{
            width: "100%",
            display: "flex",
            justifyContent: "space-between",
            alignItems: "center",
            paddingLeft: "20px",
            paddingRight: "20px",
          }}
        >
          <FormControlLabel
            control={
              <Checkbox
                checked={showLogs}
                onChange={handleChangeShowLogs}
                name="checkedA"
              />
            }
            label="Show logs"
          />
          <Box
            sx={{ display: "flex", alignItems: "center", marginLeft: "auto" }}
          >
            <Button
              color={"error"}
              onClick={onClose}
              sx={{ marginRight: ".25rem" }}
            >
              Close
            </Button>
            <Button
              disabled={
                !resultsCount ||
                !selectedPopulation ||
                !outputRange.start ||
                !outputRange.finish
              }
              variant="contained"
              startIcon={
                explanationStatusLoading ? (
                  <CircularProgress
                    size={16}
                    sx={{ color: theme.palette.primary.contrastText }}
                  />
                ) : undefined
              }
              onClick={
                explanationStatusLoading ? undefined : () => handleAnalyse()
              }
            >
              {explanationStatusLoading ? "Evaluating..." : "Evaluate"}
            </Button>
          </Box>
        </Box>
      </DialogActions>

      {whatifDialogOpen && (
        <WhatifAnalyseDialog
          open={whatifDialogOpen}
          sessionId={sessionID}
          instance={result[selectedSuggestionIndex]}
          dataHeaders={allFeatures}
          expression={selectedPopulation.model}
          allPopulations={allPopulations}
          targetLabel={targetLabel}
          onClose={() => setWhatifDialogOpen(false)}
        />
      )}

      {changeExpressionOpen && (
        <ChangeExpressionDialog
          allPopulations={allPopulations || []}
          onSelect={setSelectedPopulation}
          onClose={() => setChangeExpressionOpen(false)}
        />
      )}
    </Dialog>
  );
};

export default CounterFactualDialog;
