import { useState, useEffect, useRef, createRef } from "react";
import {
  Dialog,
  DialogActions,
  DialogContent,
  DialogTitle,
  Grid,
  Typography,
  Alert,
  AlertTitle,
  FormControlLabel,
  Checkbox,
  Box,
  Select,
  MenuItem,
  Skeleton,
  Button,
  CircularProgress,
  useTheme,
  Stack,
} from "@mui/material";
import { useMutation, useQuery, useQueryClient } from "react-query";
import WhatifAnalyseDialog from "../whatif/WhatifAnalyseDialog";
import ApiClient from "src/axios";
import {
  DialogMode,
  GET_COUNTERFATUAL_LOGS_QUERY_KEY,
  GET_EXPLANATIONS_QUERY_KEY,
  GET_EXPLANATION_QUERY_KEY,
} from "src/utils/types";
import {
  DEFAULT_RESULT_COUNTS,
  EXPLANATION_STATUS_CHECK_INTERVAL,
  EXPLANATION_STATUS_RETRY_COUNT,
  ExplanationState,
  CounterfactualAlgorithmType,
  ModelType,
  DEFAULT_CODICE_MAX_ITER,
  ContinuousDistanceType,
  LossType,
  DEFAULT_CODICE_WEIGHT,
} from "./utils";
import ChangeExpressionDialog from "../common/expression-table/ChangeExpressionDialog";
import SelectedExpressionTable from "../common/expression-table/SelectedExpressionTable";
import SuggestionsTable from "./SuggestionsTable";
import CounterfactualLogs from "./CounterfactualLogs";
import SelectedInstanceTable from "./SelectedInstanceTable";
import useNotifier, { NotificationType } from "src/hooks/use-notify";
import DiceFormSection from "./DiceFormSection";
import CodiceFormSection from "./CodiceFormSection";
import { waiter } from "src/utils/Utils";
import { GitBranch } from "phosphor-react";

const CounterFactualDialog = ({
  open,
  selectedExplanation,
  selectedExpression: initialExpression,
  mode,
  desiredOutput,
  explanationID: initialExplanationID,
  selectedInstanceIndex,
  existingFeatureVars,
  instance,
  sessionID,
  featureSet,
  allPopulations,
  targetLabel,
  counterfactualAlgorithm,
  modelType: cfModelType,
  onClose,
}) => {
  const theme = useTheme();
  const { notify } = useNotifier();
  const queryClient = useQueryClient();
  const logsRef = createRef();
  const suggestionsRef = createRef();

  const [algorithmType, setAlgoruthmType] = useState(
    counterfactualAlgorithm || CounterfactualAlgorithmType.DICE
  );
  const [modelType, setModelType] = useState(
    cfModelType || ModelType.REGRESSION
  );
  const [selectedPopulation, setSelectedPopulation] = useState(null);
  const [logs, setLogs] = useState(null);
  const [showLogs, setShowLogs] = useState(false);
  const [loading, setLoading] = useState({
    analyse: false,
    checkExplanationStatus: false,
    checkLogs: false,
    checkResults: false,
  });
  const [error, setError] = useState({
    analyse: null,
    explanation: null,
    logs: null,
  });
  const [result, setResult] = useState(null);
  const [allFeatures, setAllFeatures] = useState(null);
  const [resultsCount, setResultsCount] = useState(DEFAULT_RESULT_COUNTS);
  const [maxIter, setMaxIter] = useState(DEFAULT_CODICE_MAX_ITER);
  const [continuousDistanceType, setContinuosDistanceType] = useState(
    selectedExplanation
      ? selectedExplanation.configuration?.cfsearch?.continuous_distance?.type
      : ContinuousDistanceType.Weighted_l1
  );
  const [lossType, setLossType] = useState(
    selectedExplanation
      ? selectedExplanation.configuration?.cfsearch?.loss_type
      : LossType.Regression.MSE
  );
  const functionWeights =
    selectedExplanation?.configuration?.cfsearch?.objective_function_weights ||
    [];
  const [distance, setDistance] = useState({
    status: true,
    value:
      functionWeights.length > 0 ? functionWeights[0] : DEFAULT_CODICE_WEIGHT,
  });
  const [sparsity, setSparsity] = useState({
    status: selectedExplanation
      ? Boolean(selectedExplanation?.configuration?.cfsearch?.sparsity)
      : true,
    value:
      functionWeights.length > 1 ? functionWeights[1] : DEFAULT_CODICE_WEIGHT,
  });
  const [coherence, setCoherence] = useState({
    status: selectedExplanation
      ? Boolean(selectedExplanation?.configuration?.cfsearch?.coherence)
      : true,
    value:
      functionWeights.length > 2 ? functionWeights[2] : DEFAULT_CODICE_WEIGHT,
  });
  const [outputRange, setOutputRange] = useState({
    start: null,
    finish: null,
  });
  const [classificationOutput, setClassificationOutput] = useState("");
  const [codiceConfigs, setCodiceConfigs] = useState(null);
  const [explanationID, setExplanationID] = useState(null);
  const [explanationStatus, setExplanationStatus] = useState(
    ExplanationState.UNKNOWN
  );
  const [selectedSuggestionIndex, setSelectedSugesstionIndex] = useState(-1);
  const [whatifDialogOpen, setWhatifDialogOpen] = useState(false);
  const [changeExpressionOpen, setChangeExpressionOpen] = useState(false);

  const explanationStatusLoading =
    loading.analyse ||
    explanationStatus === ExplanationState.CREATED ||
    explanationStatus === ExplanationState.PENDING ||
    explanationStatus === ExplanationState.RUNNING;

  const isDice = Boolean(algorithmType == CounterfactualAlgorithmType.DICE);
  const isCodice = Boolean(
    algorithmType == CounterfactualAlgorithmType.CO_DICE
  );

  useEffect(() => {
    let initialPopulation;
    const createCondition = (item) =>
      String(item.model) == String(initialExpression);
    const updateCondition = (item) =>
      String(item.id) == String(selectedExplanation?.exprId);
    (allPopulations || []).forEach((item, index) => {
      if (
        mode == DialogMode.CREATE
          ? createCondition(item)
          : updateCondition(item)
      ) {
        initialPopulation = {
          ...item,
          index: index,
        };
        return;
      }
    });
    setSelectedPopulation(initialPopulation);

    if (mode === DialogMode.UPDATE) {
      if (cfModelType == ModelType.CLASSIFICATION) {
        setClassificationOutput(desiredOutput);
      } else {
        setOutputRange({
          start: desiredOutput[0],
          finish: desiredOutput[1],
        });
      }
      setExplanationID(initialExplanationID);
    }
  }, []);

  useEffect(() => {
    if (result) {
      suggestionsRef?.current?.scrollIntoView({
        behavior: "smooth",
        block: "start",
      });
    }
  }, [result]);

  const { mutateAsync: analyseMutation } = useMutation({
    mutationFn: ({ algorithmType, data }) =>
      ApiClient.post(
        `/api/${algorithmType}?explanationId=${explanationID ?? ""}`,
        data
      ),
  });

  useQuery({
    queryKey: [GET_EXPLANATION_QUERY_KEY],
    queryFn: () =>
      ApiClient.get(`/api/explanation/${explanationID}?type=counterfactuals`),
    enabled:
      Boolean(explanationID) &&
      (mode === DialogMode.CREATE
        ? explanationStatus !== ExplanationState.COMPLETED
        : true) &&
      explanationStatus !== ExplanationState.FAILED,
    refetchInterval: EXPLANATION_STATUS_CHECK_INTERVAL,
    retry: EXPLANATION_STATUS_RETRY_COUNT,
    onSuccess: (res) => {
      const currentState = res.data.state;
      setExplanationStatus(currentState);
      if (
        currentState === ExplanationState.COMPLETED ||
        currentState === ExplanationState.FAILED
      ) {
        setResult(res.data.explanations || []);
        setAllFeatures(res.data.configuration?.allFeatures || []);
        setLoading({
          ...loading,
          checkExplanationStatus: false,
        });
      }
    },
    onError: (err) => {
      setError({
        ...error,
        explanation: err?.response?.data?.errors || [
          { message: "Error occured. Please try later!" },
        ],
      });
      setLoading({
        ...loading,
        checkExplanationStatus: false,
      });
    },
    onSettled: () => {
      setLoading({
        ...loading,
        checkLogs: true,
        checkExplanationStatus: false,
      });
    },
  });

  useQuery({
    queryKey: [GET_COUNTERFATUAL_LOGS_QUERY_KEY],
    queryFn: () =>
      ApiClient.get(
        `/api/explanations/${explanationID}/logs/data?type=counterfactuals`
      ),
    enabled: loading.checkLogs,
    refetchInterval: EXPLANATION_STATUS_CHECK_INTERVAL,
    onSuccess: (res) => {
      setLogs(res.data);
      setLoading({ ...loading, checkLogs: false });
    },
    onError: (err) => {
      setError({
        ...error,
        logs: err?.response?.data?.errors || [
          { message: "Error occured. Please try later!" },
        ],
      });
    },
  });

  const handleAnalyse = () => {
    setLoading({ ...loading, analyse: true });
    setResult(null);

    let body;
    if (isDice) {
      body = {
        queryInstanceIdx: selectedInstanceIndex,
        expr: selectedPopulation?.model,
        exprId: selectedPopulation?.id,
        featuresToVary: existingFeatureVars,
        originalInstance: instance,
        desiredRange: [Number(outputRange.start), Number(outputRange.finish)],
        cfCount: Number(resultsCount),
        sessionId: sessionID,
        allFeatures: featureSet,
      };
    }
    if (isCodice) {
      body = {
        config: {
          params: {
            ...codiceConfigs.params,
            expr: selectedPopulation?.model,
            exprId: selectedPopulation?.id,
            desiredOutput:
              modelType == ModelType.REGRESSION
                ? [Number(outputRange.start), Number(outputRange.finish)]
                : classificationOutput,
            cfCount: Number(resultsCount),
            maxIter: Number(maxIter),
            target: targetLabel,
            queryInstanceIdx: selectedInstanceIndex,
          },
          model: {
            ...codiceConfigs.model,
            model_type: modelType,
          },
          cfsearch: {
            ...codiceConfigs.cfsearch,
            loss_type: lossType,
            continuous_distance: {
              ...codiceConfigs.cfsearch.continuous_distance,
              type: continuousDistanceType,
            },
            sparsity: sparsity.status,
            coherence: coherence.status,
            objective_function_weights: [
              distance.value,
              sparsity.value,
              coherence.value,
            ],
          },
          allFeatures: featureSet,
          originalInstance: instance,
        },
        sessionId: sessionID,
        expr: selectedPopulation?.model,
        exprId: selectedPopulation?.id,
      };
    }

    analyseMutation({
      algorithmType: isDice
        ? CounterfactualAlgorithmType.DICE
        : CounterfactualAlgorithmType.CO_DICE,
      data: body,
    })
      .then((res) => {
        setExplanationStatus(res.data.state);
        setExplanationID(res.data.id);
        queryClient.invalidateQueries([GET_EXPLANATIONS_QUERY_KEY, sessionID]);
      })
      .catch((err) => {
        setError({
          ...error,
          analyse: err?.response?.data?.errors || [
            { message: "Error occured. Please try later!" },
          ],
        });
      })
      .finally(async () => {
        await waiter(1000); // wait 1 sec to show loading
        setLoading({
          ...loading,
          analyse: false,
          checkExplanationStatus: true,
        });
      });
  };

  const handleChangeShowLogs = (event) => {
    const isChecked = event.target.checked;
    setShowLogs(isChecked);
    if (isChecked) {
      logsRef?.current?.scrollIntoView({ behavior: "smooth", block: "start" });
    }
  };

  const handleOpenWhatifDialog = () => {
    if (!selectedPopulation) {
      notify(
        NotificationType.ERROR,
        "Select an expression to open what-if analyse"
      );
      return;
    } else {
      setWhatifDialogOpen(true);
    }
  };

  return (
    <Dialog open={open} onClose={onClose} maxWidth="lg" fullWidth>
      <DialogTitle sx={{ m: 0, p: 2 }} variant="h6" fontWeight={600}>
        {`Counterfactual Analyse`}
      </DialogTitle>
      <DialogContent dividers>
        <Box sx={{ display: "flex", flexDirection: "column", gap: "1.5rem" }}>
          {mode == DialogMode.UPDATE && !selectedPopulation && (
            <Alert severity="warning">
              <AlertTitle>Expression not found</AlertTitle>
              No expression found. You may have run your session again. Select
              new expression to continue.
            </Alert>
          )}
          <SelectedExpressionTable
            selectedPopulation={selectedPopulation}
            onChange={() => setChangeExpressionOpen(true)}
          />
        </Box>
        <Box
          sx={{
            display: "flex",
            alignItems: "center",
            gap: "1rem",
            marginTop: "2.5rem",
            marginBottom: "2rem",
          }}
        >
          <Typography
            sx={{
              fontWeight: "600",
            }}
          >
            Algorithm Type:
          </Typography>
          <Select
            id="algorithm-select"
            size="small"
            value={algorithmType}
            onChange={(e) => setAlgoruthmType(e.target.value)}
            sx={{ width: "12rem" }}
          >
            <MenuItem
              key={CounterfactualAlgorithmType.DICE}
              value={CounterfactualAlgorithmType.DICE}
            >
              Dice
            </MenuItem>
            <MenuItem
              key={CounterfactualAlgorithmType.CO_DICE}
              value={CounterfactualAlgorithmType.CO_DICE}
            >
              Co-Dice
            </MenuItem>
          </Select>
        </Box>
        <Grid container spacing={3}>
          {isDice && (
            <DiceFormSection
              outputRange={outputRange}
              resultsCount={resultsCount}
              onChangeResultCount={setResultsCount}
              onChangeOutputRange={setOutputRange}
            />
          )}
          {isCodice && (
            <CodiceFormSection
              modelType={modelType}
              resultsCount={resultsCount}
              outputRange={outputRange}
              maxIter={maxIter}
              classificationOutput={classificationOutput}
              continuousDistanceType={continuousDistanceType}
              lossType={lossType}
              distance={distance}
              sparsity={sparsity}
              coherence={coherence}
              onChangeCodiceConfigs={setCodiceConfigs}
              onChangeMaxIter={setMaxIter}
              onChangeClassificationOutput={setClassificationOutput}
              onChangeOutputRange={setOutputRange}
              onChangeResultCount={setResultsCount}
              onChangeModelType={setModelType}
              onChangeDistanceType={setContinuosDistanceType}
              onChangeLossType={setLossType}
              onChangeDistance={setDistance}
              onChangeSparsity={setSparsity}
              onChangeCoherence={setCoherence}
            />
          )}
          <Grid item xs={12}>
            <SelectedInstanceTable
              targetLabel={targetLabel}
              featureSet={featureSet}
              instance={instance}
            />
          </Grid>
          {result && (
            <Grid item xs={12}>
              <Box
                sx={{
                  mb: 1,
                  mt: 1,
                  display: "flex",
                  alignItems: "center",
                  justifyContent: "space-between",
                  gap: "1.5rem",
                }}
              >
                <div ref={suggestionsRef}>
                  <Box>
                    <Typography variant="body2" fontWeight={600}>
                      Suggestions
                    </Typography>
                    <Typography variant="caption">
                      (The values ​​on the yellow colored fields are different
                      from the original values.)
                    </Typography>
                  </Box>
                </div>
                {selectedSuggestionIndex > -1 &&
                  result &&
                  result.length > 0 && (
                    <Button
                      variant="outlined"
                      size="small"
                      sx={{
                        textTransform: "none",
                        fontSize: ".875rem",
                        padding: "0.3rem .75rem",
                        height: "fit-content",
                      }}
                      startIcon={<GitBranch />}
                      onClick={handleOpenWhatifDialog}
                    >
                      Use instance at whatif analyse
                    </Button>
                  )}
              </Box>
              {loading.checkExplanationStatus ? (
                <>
                  {new Array(5).fill().map(() => (
                    <Skeleton
                      variant="rectangular"
                      width="100%"
                      height={30}
                      style={{ marginBottom: ".25rem" }}
                    />
                  ))}
                </>
              ) : result.length === 0 ? (
                <Alert severity="info">No suggestions found</Alert>
              ) : (
                <SuggestionsTable
                  allFeatures={featureSet || allFeatures}
                  result={result}
                  instance={instance}
                  selectedSuggestionIndex={selectedSuggestionIndex}
                  onSelectedSugesstionIndex={setSelectedSugesstionIndex}
                />
              )}
            </Grid>
          )}
          {(explanationStatus === ExplanationState.COMPLETED ||
            explanationStatus === ExplanationState.FAILED) &&
            !result &&
            mode !== DialogMode.UPDATE && (
              <Grid item xs={12} mt={2}>
                <Alert severity="info">Counterfactual results not found.</Alert>
              </Grid>
            )}
          {error.analyse && (
            <Grid item xs={12}>
              {error.analyse.map((item) => (
                <Stack mt={1} width={"100%"}>
                  <Alert severity="error">
                    <AlertTitle>Counterfactual analyse error.</AlertTitle>
                    {item.message}
                  </Alert>
                </Stack>
              ))}
            </Grid>
          )}
          {showLogs && (
            <Grid item xs={12}>
              <CounterfactualLogs
                ref={logsRef}
                logs={logs}
                errorLogs={error.logs}
              />
            </Grid>
          )}
        </Grid>
      </DialogContent>
      <DialogActions>
        <Box
          sx={{
            width: "100%",
            display: "flex",
            justifyContent: "space-between",
            alignItems: "center",
            paddingLeft: "20px",
            paddingRight: "20px",
          }}
        >
          <FormControlLabel
            control={
              <Checkbox
                checked={showLogs}
                onChange={handleChangeShowLogs}
                name="checkedA"
              />
            }
            label="Show logs"
          />
          <Box
            sx={{ display: "flex", alignItems: "center", marginLeft: "auto" }}
          >
            <Button
              color={"error"}
              onClick={onClose}
              sx={{ marginRight: ".25rem" }}
            >
              Close
            </Button>
            <Button
              disabled={
                !resultsCount ||
                !selectedPopulation ||
                algorithmType == CounterfactualAlgorithmType.CO_DICE
                  ? !codiceConfigs
                  : modelType == ModelType.CLASSIFICATION
                    ? !classificationOutput
                    : !outputRange.start || !outputRange.finish
              }
              variant="contained"
              startIcon={
                explanationStatusLoading ? (
                  <CircularProgress
                    size={16}
                    sx={{ color: theme.palette.primary.contrastText }}
                  />
                ) : undefined
              }
              onClick={
                explanationStatusLoading ? undefined : () => handleAnalyse()
              }
            >
              {explanationStatusLoading ? "Evaluating..." : "Evaluate"}
            </Button>
          </Box>
        </Box>
      </DialogActions>

      {whatifDialogOpen && selectedPopulation && (
        <WhatifAnalyseDialog
          open={whatifDialogOpen}
          sessionId={sessionID}
          instance={result[selectedSuggestionIndex]}
          dataHeaders={allFeatures}
          expression={selectedPopulation.model}
          allPopulations={allPopulations}
          targetLabel={targetLabel}
          onClose={() => setWhatifDialogOpen(false)}
        />
      )}

      {changeExpressionOpen && (
        <ChangeExpressionDialog
          allPopulations={allPopulations || []}
          onSelect={setSelectedPopulation}
          onClose={() => setChangeExpressionOpen(false)}
        />
      )}
    </Dialog>
  );
};

export default CounterFactualDialog;
