import React, { useState, useReducer, useMemo, useEffect } from "react";
import { Button, ButtonGroup, Divider, Icon, Spinner } from "@blueprintjs/core";
import { interpolateRdBu } from "d3";
import Heatmap from "./../common/Heatmap";
import {
  StructureViewer,
  registerColorSchemeFromResidueMap
} from "./StructureViewer";
import { ErrorRefetcher } from "../common/Helpers";
import { Legend, makeColorbar } from "../common/Legend";
import { createColorMapper } from "./../../utils/Helpers";
import { createSiteId, DEFAULT_SEGMENT_ID } from "../../utils/Segments";
import { RESET_SELECTION_ICON } from "../../utils/Constants";
import { createPanel } from "./Panel";

// One-letter codes of the 20 standard amino acids; this string defines the
// order of the heatmap substitution axis (presumably grouped by
// physicochemical property, as the name suggests — TODO confirm)
const AA_LIST_PROPERTY = "WFYPMILVAGCSTQNDEHRK";
// Bounds (in px) for the heatmap cell size changed by the zoom buttons
// NOTE(review): "MATIRX" is a typo for "MATRIX"; renaming requires updating all uses
const MIN_MATIRX_CELL_SIZE = 10;
const MAX_MATRIX_CELL_SIZE = 20;
// Fallback color handed to registerColorSchemeFromResidueMap, presumably used
// for residues without an entry in the residue color map (confirm in StructureViewer)
const DEFAULT_STRUCTURE_COLOR = "#FFFFFF";

/**
 * Arrange the values of one column of a single position's rows according to
 * a substitution order.
 *
 * The output length is based on `order` rather than on `x`, since the
 * self-mutant (wildtype -> wildtype) has no row in the data table; slots
 * without a matching row keep the value 0. Rows whose substitution does not
 * occur in `order` are ignored.
 *
 * @param {Array<Object>} x - rows for one position; each row's `subs` field
 *   selects its slot and `row[targetColumn]` is the value placed there
 * @param {string} targetColumn - name of the row field to extract
 * @param {string|Array<string>} order - substitution order defining the slots
 * @param {boolean} reverse - if true, slots are filled in reverse order
 * @returns {Array} values of `targetColumn` arranged according to `order`
 */
export const reorderValues = (x, targetColumn, order, reverse) => {
  // map each substitution to its first index in `order` once, instead of
  // scanning `order` for every row
  const slotBySubs = new Map();
  for (let i = 0; i < order.length; i++) {
    if (!slotBySubs.has(order[i])) {
      slotBySubs.set(order[i], i);
    }
  }

  const ordered = Array(order.length).fill(0);
  x.forEach(row => {
    const slot = slotBySubs.get(row.subs);
    // unknown substitutions would otherwise map to index -1 (or to
    // order.length when reversed) and corrupt the output row
    if (slot === undefined) {
      return;
    }
    const idx = reverse ? order.length - slot - 1 : slot;
    ordered[idx] = row[targetColumn];
  });

  return ordered;
};

/**
 * Convert the mutation table into per-position arrays and matrices for the
 * heatmap.
 *
 * @param {Object} mutationTable - table exposing getColumnNames() and
 *   groupBy(); each row is read for `segment` (optional), `pos`, `wt`, `subs`,
 *   `prediction_epistatic` and `prediction_independent`
 * @param {string|Array<string>} substitutionOrder - order of the substitution axis
 * @param {boolean} reverseOrder - whether to reverse the substitution axis
 * @returns {{segments: (Array|null), positions: Array, wildtype: Array,
 *   epistaticMatrix: Array<Array>, independentMatrix: Array<Array>}}
 *   one entry per position group; `segments` is null if the table has no
 *   segment column, otherwise it is aligned index-by-index with `positions`
 */
const extractMutationData = (
  mutationTable,
  substitutionOrder,
  reverseOrder
) => {
  // segments is a list only if the table has a segment column, otherwise null
  const segments = mutationTable.getColumnNames().includes("segment")
    ? []
    : null;
  const positions = [];
  const wildtype = [];
  const epistaticMatrix = [];
  const independentMatrix = [];

  // TODO: make sure order of groupby is guaranteed
  // TODO: remove head
  const dataByPos = mutationTable.groupBy(row =>
    row.segment ? row.segment + "__" + row.pos : row.pos
  );

  // go through sub-dataframe for each position
  dataByPos.forEach(x => {
    const xArr = x.toArray();
    const first = xArr[0];

    // add segment (if column present) for every position so that
    // segments[i] always corresponds to positions[i]
    if (segments !== null) {
      segments.push(first.segment);
    }
    positions.push(first.pos);

    // store wildtype residue/base
    wildtype.push(first.wt);

    // store current data column in right order of entries
    // reverse order so rendered in same way as by pipeline
    epistaticMatrix.push(
      reorderValues(
        xArr,
        "prediction_epistatic",
        substitutionOrder,
        reverseOrder
      )
    );
    independentMatrix.push(
      reorderValues(
        xArr,
        "prediction_independent",
        substitutionOrder,
        reverseOrder
      )
    );
  });

  return {
    segments,
    positions,
    wildtype,
    epistaticMatrix,
    independentMatrix
  };
};

// TODO: replace with global helper function
/*
export const createColorMapper = (data) => {
  const minVal = data.min();
  const maxVal = data.max();
  const range = Math.max(Math.abs(minVal), Math.abs(maxVal));

  return v => color(interpolateRdBu(1 - (v + range) / (2 * range))).hex();
};
*/

// Fresh empty selection state; a new object on every call so that reset
// states never share arrays/objects with each other
const createEmptySelection = () => ({
  sites: [],
  pairs: [],
  substitutions: [],
  siteSubstitutions: {}
});

/**
 * Reducer for the mutation panel selection state.
 *
 * Supported `action.action` values:
 * - "setAll": replace each of sites/pairs/substitutions/siteSubstitutions with
 *   the value given in the action; fields that are null/undefined keep their
 *   current value
 * - "addSites": toggle `action.sites` in the site selection (already selected
 *   sites are removed, new ones appended) and clear siteSubstitutions
 * - "reset": clear the whole selection
 *
 * @param {Object} state - current selection state
 * @param {Object} action - dispatched action, discriminated by `action.action`
 * @returns {Object} new selection state
 */
export const selectionReducer = (state, action) => {
  // keep current value unless the action provides one
  const pick = (value, current) => (value != null ? value : current);

  switch (action.action) {
    case "setAll":
      return {
        sites: pick(action.sites, state.sites),
        pairs: pick(action.pairs, state.pairs),
        substitutions: pick(action.substitutions, state.substitutions),
        siteSubstitutions: pick(
          action.siteSubstitutions,
          state.siteSubstitutions
        )
      };

    case "addSites": {
      // dispatched change to site selection
      const toggledSites = action.sites;
      const toggledSet = new Set(toggledSites);
      const previousSet = new Set(state.sites);

      // keep already selected sites not selected again, and add newly
      // selected sites not yet in the selection (symmetric difference)
      const sites = state.sites
        .filter(s => !toggledSet.has(s))
        .concat(toggledSites.filter(s => !previousSet.has(s)));

      // adding a site in the use case of this panel eliminates siteSubstitutions
      return {
        ...state,
        sites,
        siteSubstitutions: {}
      };
    }

    case "reset":
      return createEmptySelection();

    // NOTE(review): unknown actions clear the selection instead of returning
    // `state` unchanged; kept as-is since dispatchers may rely on it
    default:
      return createEmptySelection();
  }
};

/*
  EVmutation results panel

  TODO:
  - allow to toggle panels
  - implement all settings (also in parent component)
  - allow to select which matrix to display (epistatic or independent)
*/
export const MutationPanel = ({
  mutationData,
  refetchMutations,
  jobGroup,
  job,
  showSegments,
  showMatrix,
  showViewer,
  mutationFileDownloadLink,
  structures,
  availableStructures,
  dispatchStructureSelection,
  refetchFailedStructures
}) => {
  const [cellSize, setCellSize] = useState(10);
  const [showEpistaticModel, setShowEpistaticModel] = useState(true);

  const [selection, dispatchSelection] = useReducer(selectionReducer, {
    sites: [],
    pairs: [],
    substitutions: [],
    siteSubstitutions: {}
  });

  // Dispatch reset of selections if job/subjob changes
  // Note that exclusion of selection from useEffect condition is on purpose
  useEffect(() => {
    dispatchSelection({ action: "reset" });
  }, [jobGroup, job]);

  const substitutionOrder = AA_LIST_PROPERTY;
  const reverseSubstitutionOrder = true;

  let yLabels = Array.from(AA_LIST_PROPERTY);
  if (reverseSubstitutionOrder) {
    yLabels.reverse();
  }

  let extractedData;
  let siteIds;

  // prepare data for heatmap
  if (mutationData.table) {
    extractedData = extractMutationData(
      mutationData.table,
      substitutionOrder,
      reverseSubstitutionOrder
    );

    // create site IDs for all positions in matrix for selection handling
    siteIds = extractedData.positions.map((pos, i) =>
      createSiteId(
        extractedData.segments ? extractedData.segments[i] : DEFAULT_SEGMENT_ID,
        pos
      )
    );
  }

  // colormap for heatmap rendering;
  // memoize for structure colormap update below
  const curColorMapper = useMemo(() => {
    return mutationData.table
      ? createColorMapper(
          mutationData.table.getSeries(
            showEpistaticModel
              ? "prediction_epistatic"
              : "prediction_independent"
          ),
          interpolateRdBu,
          true,
          true
        )
      : null;
  }, [mutationData.table, showEpistaticModel]);

  const colorScheme = useMemo(() => {
    let colorMap = {};
    if (extractedData) {
      // get respective 2D array with experimental data
      const expMatrix = showEpistaticModel
        ? extractedData.epistaticMatrix
        : extractedData.independentMatrix;

      // compute coloring on a per-residue basis
      extractedData.positions.forEach((pos, i) => {
        let posValue;

        // if substitution selected, color entire structure
        // by substitution effect to this residue
        if (selection.substitutions.length > 0) {
          posValue = expMatrix[i][yLabels.indexOf(selection.substitutions[0])];
        } else {
          // either show average effect per position or selected mutation:
          // if particular substitution selected, show single substitution effect
          // (need to look at actual site ID here rather than residue number alone)
          if (siteIds[i] in selection.siteSubstitutions) {
            // only one residue selected by definition so take 0-th element
            const selectedRes = selection.siteSubstitutions[siteIds[i]][0];
            posValue = expMatrix[i][yLabels.indexOf(selectedRes)];
          } else {
            // show average
            posValue =
              expMatrix[i].reduce((a, b) => a + b) / expMatrix[i].length;
          }
        }

        // put position value through colormapper and assign to residue color map
        colorMap[pos] = curColorMapper(posValue);
      });
    }

    return registerColorSchemeFromResidueMap(colorMap, DEFAULT_STRUCTURE_COLOR);
  }, [
    extractedData,
    showEpistaticModel,
    selection,
    curColorMapper,
    yLabels,
    siteIds
  ]);

  // console.log(extractedData ? extractedData.epistaticMatrix : null);

  /*
    Legend renderer shared between mutation heatmap and 3D viewer panels
  */
  const renderLegend = showAverageNote => {
    let epistaticData;
    let independentData;

    if (mutationData && mutationData.table) {
      epistaticData = mutationData.table.getSeries("prediction_epistatic");
      independentData = mutationData.table.getSeries("prediction_independent");
    }

    const makeColorbarRow = currentData => {
      let colorbarContent;
      if (currentData) {
        const colorbar = makeColorbar(
          curColorMapper,
          currentData.min(),
          currentData.max(),
          "right",
          {
            width: "150px",
            height: "20px",
            marginLeft: "0.5em",
            marginRight: "0.5em"
          }
        );

        colorbarContent = (
          <div
            style={{
              display: "flex",
              flexDirection: "row",
              alignItems: "center"
            }}
          >
            {currentData.min().toFixed(1)}
            <div>{colorbar}</div>
            {currentData.max().toFixed(1)}
          </div>
        );

        return (
          <tr>
            <td>{colorbarContent}</td>
          </tr>
        );
      } else {
        return null;
      }
    };

    const colWidth = "40px";

    const legendContent = (
      <>
        <h6 className="bp3-heading">Mutation effect strength</h6>
        <table
          className="bp3-html-table bp3-html-table-condensed"
          style={{
            verticalAlign: "top",
            width: "100%",
            tableLayout: "fixed",
            marginBottom: "1em"
          }}
        >
          <tbody>
            <tr>
              <td>
                <b>Epistatic model</b>
              </td>
            </tr>
            {makeColorbarRow(epistaticData)}
            <tr>
              <td>
                <b>Independent model</b>
              </td>
            </tr>
            {makeColorbarRow(independentData)}
          </tbody>
        </table>

        <h6 className="bp3-heading">Effect strength interpretation</h6>

        <table
          className="bp3-html-table bp3-html-table-condensed"
          style={{
            verticalAlign: "top",
            width: "100%",
            tableLayout: "fixed",
            marginBottom: "1em"
          }}
        >
          <tbody>
            <tr>
              <td style={{ width: colWidth, verticalAlign: "middle" }}>
                &lt;&nbsp;0
              </td>
              <td>Damaging substitution</td>
            </tr>
            <tr>
              <td style={{ width: colWidth, verticalAlign: "middle" }}>
                =&nbsp;0
              </td>
              <td>Neutral substitution</td>
            </tr>
            <tr>
              <td style={{ width: colWidth, verticalAlign: "middle" }}>
                &gt;&nbsp;0
              </td>
              <td>Beneficial substitution</td>
            </tr>
          </tbody>
        </table>

        {showAverageNote ? (
          <div style={{ display: "flex" }}>
            <Icon icon="error" style={{ marginRight: "0.5em" }} />
            Average effect per position shown on structure unless particular
            substitution selected in heatmap for that position or across
            positions
          </div>
        ) : null}
      </>
    );
    return <Legend content={legendContent} />;
  };

  const renderHeatmap = () => {
    // TODO: cache component to avoid expensive rerenders with lots of divs?
    let fullContent;

    // need different alignment if showing heatmap or spinner/error message
    let justifyContent;

    if (extractedData) {
      // TODO: add segment display (if selected)
      const xLabels = extractedData.positions.map(
        (pos, i) => extractedData.wildtype[i] + " " + pos
      );

      const dataMatrix = showEpistaticModel
        ? extractedData.epistaticMatrix
        : extractedData.independentMatrix;

      // determine indices of selected sites
      let selectedSiteIdx = [];
      siteIds.forEach((siteId, i) => {
        if (selection.sites.includes(siteId)) {
          selectedSiteIdx.push(i);
        }
      });

      // determine indices of selected rows
      let selectedRowIdx = [];
      yLabels.forEach((substitution, j) => {
        if (selection.substitutions.includes(substitution)) {
          selectedRowIdx.push(j);
        }
      });

      // map selected sites + substitutions into heatmap 0-based numbering
      let selectedCellIdx = {};
      for (let siteId in selection.siteSubstitutions) {
        // note this is a list of substitutions (even though we will only have 1 value here)
        selectedCellIdx[siteIds.indexOf(siteId)] = selection.siteSubstitutions[
          siteId
        ].map(id => yLabels.indexOf(id));
      }

      // TODO: probably nicer to have all selection logic below in reducer
      const heatmapContent = (
        <Heatmap
          data={dataMatrix}
          xLabels={xLabels}
          yLabels={yLabels}
          selectedRows={selectedRowIdx}
          selectedCols={selectedSiteIdx}
          selectedCells={selectedCellIdx}
          cellWidth={cellSize + "px"}
          cellHeight={cellSize + "px"}
          colorMap={curColorMapper}
          labelMap={({ i, j }) => {
            const subs =
              substitutionOrder[
                reverseSubstitutionOrder ? AA_LIST_PROPERTY.length - j - 1 : j
              ];
            return (
              <span>
                mutant: <b>{xLabels[i] + " " + subs}</b>
                <br />
                effect: <b>{dataMatrix[i][j]}</b>
              </span>
            );
          }}
          handleColumnClick={(i, label) => {
            const selectedSiteId = siteIds[i];
            // if clicked site is already selected, deselect, otherwise add to selection
            if (selection.sites.includes(selectedSiteId)) {
              dispatchSelection({
                action: "setAll",
                sites: selection.sites.filter(elem => elem !== selectedSiteId),
                siteSubstitutions: {}
              });
            } else {
              dispatchSelection({
                action: "setAll",
                sites: selection.sites.concat([selectedSiteId]),
                siteSubstitutions: {}
              });
            }
          }}
          handleRowClick={(j, label) => {
            const selectedSubstitution = yLabels[j];
            // only allow to select at most one substitution, if already
            // selected, deselect again
            if (selection.substitutions.includes(selectedSubstitution)) {
              dispatchSelection({
                action: "setAll",
                substitutions: [],
                siteSubstitutions: {}
              });
            } else {
              dispatchSelection({
                action: "setAll",
                substitutions: [selectedSubstitution],
                siteSubstitutions: {}
              });
            }
          }}
          handleCellClick={(i, j, cellValue) => {
            let siteSubstitutions;

            // check if cell is already select, if so, deselect...
            if (
              siteIds[i] in selection.siteSubstitutions &&
              selection.siteSubstitutions[siteIds[i]][0] === yLabels[j]
            ) {
              // copy and remove entry for clicked cell
              siteSubstitutions = { ...selection.siteSubstitutions };
              delete siteSubstitutions[siteIds[i]];
            } else {
              const newCellSelection = {};
              newCellSelection[siteIds[i]] = [yLabels[j]];

              // only one cell can be selected per column, solve this
              // using dictionary merging
              siteSubstitutions = {
                ...selection.siteSubstitutions,
                ...newCellSelection
              };
            }

            dispatchSelection({
              action: "setAll",
              sites: [],
              substitutions: [],
              siteSubstitutions: siteSubstitutions
            });
          }}
        />
      );

      const controlContent = (
        <ButtonGroup minimal={true}>
          <Button
            icon="zoom-in"
            title="Zoom in"
            onClick={() => {
              setCellSize(cellSize + 1);
            }}
            disabled={cellSize >= MAX_MATRIX_CELL_SIZE}
          />
          <Button
            icon="zoom-out"
            title="Zoom out"
            onClick={() => {
              setCellSize(cellSize - 1);
            }}
            disabled={cellSize <= MIN_MATIRX_CELL_SIZE}
          />
          <Button
            title="Clear selection"
            icon={RESET_SELECTION_ICON}
            onClick={() => dispatchSelection({ action: "reset" })}
          />
          <Divider />
          <Button
            icon="exchange"
            title="Switch probability model used for mutation effect calculation"
            /* active={showEpistaticModel} */
            onClick={() => {
              setShowEpistaticModel(!showEpistaticModel);
            }}
          >
            {showEpistaticModel ? "Epistatic" : "Independent"} model
          </Button>
          {mutationFileDownloadLink ? (
            <>
              <Divider />
              <a
                download
                href={mutationFileDownloadLink}
                className="bp3-button bp3-minimal"
                role="button"
                title="Export image"
              >
                <Icon icon="import" />
              </a>
            </>
          ) : null}
          <Divider />
          {renderLegend()}
        </ButtonGroup>
      );

      fullContent = (
        <>
          <div
            style={{
              overflow: "auto",
              overflowX: "hidden"
              // need to have overflowY if heatmap is too big too display on viewport
            }}
          >
            {heatmapContent}
          </div>
          {controlContent}
        </>
      );
      justifyContent = "space-between";
    } else {
      justifyContent = "center";
      if (mutationData.error) {
        fullContent = <ErrorRefetcher refetcher={refetchMutations} />;
      } else {
        fullContent = <Spinner />;
      }
    }

    // TODO: clean up CSS mess, reuse this globally for all components
    return createPanel(fullContent, { justifyContent: justifyContent });
  };

  const renderViewer = () => {
    // if selection mode is based on individual substitutions,
    // transform selected sites into new "fake" selection object
    let selectionTransformed = selection;

    if (
      selection.siteSubstitutions &&
      Object.keys(selection.siteSubstitutions).length > 0
    ) {
      selectionTransformed = {
        sites: Object.keys(selection.siteSubstitutions),
        pairs: []
      };
    }

    return (
      <StructureViewer
        selection={selectionTransformed}
        // if selecting sites, reset mutations
        dispatchSelectionOld={event => {
          dispatchSelection({ ...event, siteSubstitutions: {} });
        }}
        colorScheme={colorScheme}
        imageExportFileName={`EVcouplings_3Dstructure_${jobGroup}_${job}_mutations.png`}
        structures={structures}
        availableStructures={availableStructures}
        dispatchStructureSelection={dispatchStructureSelection}
        refetchFailedStructures={refetchFailedStructures}
        legend={renderLegend(true)}
      />
    );
  };

  return (
    <div
      style={{
        display: "flex",
        flexDirection: "row",
        flexWrap: "wrap"
      }}
    >
      {showMatrix ? renderHeatmap() : null}
      {showViewer ? renderViewer() : null}
    </div>
  );
};

export default MutationPanel;
