import { Series } from "data-forge";

/**
 * Count the distinct sites referenced by a stacked EC table.
 *
 * A site is identified by the tuple (segment, position). Both ends of each
 * pair are considered: (segment_i, i) and (segment_j, j).
 *
 * @param pairs - EC table (data-forge style dataframe) with columns
 *   i, j, segment_i and segment_j
 * @returns number of distinct (segment, position) tuples
 */
export const extractNumSites = pairs => {
  // identity of a site once both ends share the same column names
  const siteKey = row => `${row.segment_i}_${row.i}`;

  const firstEnds = pairs.subset(["i", "segment_i"]);

  // relabel the second end of each pair so it can be stacked below the first
  const secondEnds = pairs
    .subset(["j", "segment_j"])
    .renameSeries({ j: "i", segment_j: "segment_i" });

  return firstEnds
    .concat(secondEnds)
    .distinct(siteKey)
    .count();
};

/**
 * Filter an EC table based on predefined filter criteria.
 *
 * Assumes that segment_i and segment_j are specified and that the table
 * is already sorted (rank-based filters select by row position).
 *
 * Order of application:
 *   1. sequence distance (applied first since it influences the rank)
 *   2. absolute rank (maxRank)
 *   3. relative rank (maxRankFraction * number of sites, e.g. 1.0 for top L)
 *   4. score / probability thresholds
 *
 * @param pairs - EC table (data-forge style dataframe)
 * @param filters - object with optional members minSeqDist, minScore,
 *   minProbability, maxRank and maxRankFraction; null/undefined members
 *   (or a falsy minSeqDist) disable the corresponding filter
 * @returns filtered EC table
 */
export const filterCouplings = (pairs, filters) => {
  const {
    minSeqDist,
    minScore,
    minProbability,
    maxRank,
    maxRankFraction
  } = filters;

  // pairs spanning two segments always pass the distance criterion
  const isFarEnough = row =>
    row.segment_i !== row.segment_j ||
    Math.abs(row.i - row.j) >= minSeqDist;

  // reset index in either case to get a consecutive count from 0, which
  // the index-based rank selection below depends on
  const candidates = minSeqDist ? pairs.where(isFarEnough) : pairs;
  let ranked = candidates.resetIndex();

  // rank is selected through the index rather than through a column;
  // note that the upper bound of between() is inclusive
  if (maxRank != null) {
    ranked = ranked.between(0, maxRank - 1);
  }

  if (maxRankFraction != null) {
    // translate relative fraction into absolute number of pairs using the
    // number of sites of the full (unfiltered) table
    const siteCount = extractNumSites(pairs);
    const rankCutoff = Math.round(maxRankFraction * siteCount);
    ranked = ranked.between(0, rankCutoff - 1);
  }

  const passesScore = row => minScore == null || row.score >= minScore;
  const passesProbability = row =>
    minProbability == null || row.probability >= minProbability;

  return ranked.where(row => passesScore(row) && passesProbability(row));
};

/**
 * Swap i and j, and all columns ending in _i/_j, in an EC dataframe.
 *
 * Only the trailing suffix of a column name is exchanged: e.g. "segment_i"
 * becomes "segment_j", and "is_interface_i" becomes "is_interface_j" (the
 * "_i" inside "_interface" is left alone).
 *
 * @param pairs - EC table (data-forge style dataframe)
 * @returns dataframe with inverted column suffixes/positions
 */
export const invertPairs = pairs => {
  // rename column `src` and every column with suffix `_src` to `dest` / `_dest`
  const replaceColumnNames = (df, src, dest) => {
    const suffix = "_" + src;

    const renameMap = Object.fromEntries(
      df
        .getColumnNames()
        .filter(name => name === src || name.endsWith(suffix))
        .map(name => [
          name,
          // cut off the trailing `src` only; String.replace with a string
          // pattern would hit the first "_src" anywhere in the name instead
          name === src ? dest : name.slice(0, -src.length) + dest
        ])
    );

    return df.renameSeries(renameMap);
  };

  // pick a temporary name that does not clash with an existing column
  // (a pre-existing "k" or "*_k" column would otherwise be renamed to j);
  // the loop terminates since the name eventually outgrows every column name
  const existingNames = pairs.getColumnNames();
  let tmp = "k";
  while (
    existingNames.some(name => name === tmp || name.endsWith("_" + tmp))
  ) {
    tmp += "k";
  }

  // clumsy three-way exchange of column names, kept because (per the
  // original author's note) df.renameSeries can't do a joint rename
  const iToTmp = replaceColumnNames(pairs, "i", tmp);
  const jToI = replaceColumnNames(iToTmp, "j", "i");
  const tmpToJ = replaceColumnNames(jToI, tmp, "j");

  return tmpToJ;
};

/**
 * Compute cumulative coupling strength per site.
 *
 * Every pair is counted for both of its sites: the table is stacked with
 * its i/j-swapped copy and then grouped by (i, segment_i).
 *
 * Note: scores < 0 are replaced by 0 in the calculation.
 *
 * @param pairs - EC table (data-forge style dataframe)
 * @returns table with one row per site (i, segment_i, A_i, pair_count,
 *   cumulative_score), ordered by descending cumulative_score
 */
export const computeCumulativeCoupling = pairs => {
  // forward pairs followed by their swapped (backward) counterparts
  const bothDirections = pairs.concat(invertPairs(pairs));

  // negative scores do not contribute to the sum
  const clampedScore = row => Math.max(0, row.score);

  // collapse all rows of one site into a single summary row
  const summarizeSite = group => {
    const { i, segment_i, A_i } = group.first();
    return {
      i,
      segment_i,
      A_i,
      pair_count: group.count(),
      cumulative_score: group.deflate(clampedScore).sum()
    };
  };

  return bothDirections
    .groupBy(row => `${row.i}_${row.segment_i}`)
    .select(summarizeSite)
    .inflate()
    .orderByDescending(site => site.cumulative_score);
};

/**
 * Add conservation info to a cumulative couplings table.
 *
 * Rows are matched by position i, or by (segment_i, i) if both tables
 * have a segment_i column.
 *
 * @param cumulativeCouplings - per-site table (data-forge style dataframe)
 * @param conservation - table with columns i, conservation and
 *   optionally segment_i
 * @returns cumulativeCouplings with an added "conservation" series; the
 *   input is returned unchanged if either argument is missing
 */
export const addConservation = (cumulativeCouplings, conservation) => {
  if (!(cumulativeCouplings && conservation)) {
    return cumulativeCouplings;
  }

  // segment info can only be used for merging if both tables carry it
  const useSegments = [cumulativeCouplings, conservation].every(table =>
    table.hasSeries("segment_i")
  );

  // same key for building the lookup and for querying it
  const siteKey = useSegments
    ? row => `${row.segment_i}_${row.i}`
    : row => row.i;

  // dictionary-based merge instead of the not-so-great dataforge join;
  // if this ever fails, check if the position column is really still called
  // "i" or was renamed to "i.1" internally by dataforge (as toString() shows it)
  const conservationBySite = conservation.toObject(
    siteKey,
    row => row.conservation
  );

  // all sites in the couplings table must be present in the alignment
  // table by definition
  const perSite = cumulativeCouplings.deflate(
    row => conservationBySite[siteKey(row)]
  );

  return cumulativeCouplings.withSeries("conservation", perSite);
};

/**
 * Compute precision of ECs given a residue-residue distance threshold.
 *
 * For each row, precision is the fraction of pairs with distance <=
 * threshold among all pairs with a known distance, counted from the top
 * of the table down to that row. Rows without a known distance get null.
 *
 * @param pairs - EC table (data-forge style dataframe); returned as-is if
 *   it is missing or has no "dist" series
 * @param distanceThreshold - maximum distance for a pair to count as true
 *   positive
 * @returns dataframe with "precision" series (withSeries replaces an
 *   original precision column if present)
 */
export const computePrecision = (pairs, distanceThreshold) => {
  // without dataframe or distances there is nothing to do here
  if (!pairs || !pairs.hasSeries("dist")) {
    return pairs;
  }

  let truePositives = 0;
  let knownDistances = 0;
  const runningPrecision = [];

  // use deflate rather than getSeries(), since the latter drops null
  // values, leading to a shorter array than the original series
  pairs
    .deflate(row => row.dist)
    .forEach(dist => {
      // explicit null check, since NaN may be encoded as null
      if (dist === null || !Number.isFinite(dist)) {
        runningPrecision.push(null);
        return;
      }

      knownDistances += 1;
      if (dist <= distanceThreshold) {
        truePositives += 1;
      }
      runningPrecision.push(truePositives / knownDistances);
    });

  // NOTE(review): the new Series is built from a plain array and so carries
  // its own default index; presumably withSeries aligns on index, in which
  // case `pairs` needs a consecutive 0-based index here - confirm
  return pairs.withSeries({
    precision: new Series(runningPrecision)
  });
};

/**
 * Dynamically add residue distances for pairs to an EC table.
 *
 * The "dist" series is (re)computed and any "precision" series is dropped,
 * since it would no longer correspond to the new distances.
 *
 * @param pairs - EC table (data-forge style dataframe); null is returned
 *   if it is missing
 * @param distanceInfo - source of distances, may be null while still
 *   computing/loading. If isDistanceMap is set: object whose
 *   get("<i>_<j>") returns the distance (distance map calculated in
 *   frontend). Otherwise: object from backend with member `data` (truthy
 *   once structure info is available) and optional table `pairDistances`
 *   (columns i, j, dist and optionally dist_intra) that overrides the
 *   distance annotation in the couplings table
 * @param isDistanceMap - selects how distanceInfo is interpreted
 * @param useMultimerDistances - if falsy and a dist_intra column is
 *   available, dist_intra is used instead of dist (ignored for distance
 *   maps, which carry no multimer info for now)
 * @returns dataframe with updated "dist" series and without "precision"
 */
export const addDistances = (
  pairs,
  distanceInfo,
  isDistanceMap,
  useMultimerDistances
) => {
  if (!pairs) {
    return null;
  }

  // missing values need to be set to NaN or somehow dataforge shortens the
  // series; a distance of exactly 0 is a valid value and must be kept
  const orMissing = dist => (dist || dist === 0 ? dist : NaN);

  // distance column consisting of missing values only
  const allMissing = () => pairs.deflate(() => NaN);

  const pairKey = row => row.i + "_" + row.j;

  // store final distance annotation in here
  let dists;

  if (isDistanceMap) {
    // distance map calculated in frontend; distanceInfo may be null
    // (e.g. if still computing or not yet loaded)
    dists = distanceInfo
      ? pairs.deflate(row => orMissing(distanceInfo.get(pairKey(row))))
      : allMissing();
  } else if (!distanceInfo || !distanceInfo.data) {
    // no structure info currently available/loaded (or reloading),
    // so show as missing data
    dists = allMissing();
  } else if (distanceInfo.pairDistances) {
    // pair distances explicitly given in experimental structure data (this
    // means a subset of structures is selected and the distance annotation
    // in the couplings table has to be overridden)

    // if multimer distance is available but deselected, ignore it
    const valueFunc =
      !useMultimerDistances &&
      distanceInfo.pairDistances.hasSeries("dist_intra")
        ? row => row.dist_intra
        : row => row.dist;

    // this should ideally be a join, however, dataforge join is very slow,
    // so use a dictionary i_j -> distance instead
    const distMap = distanceInfo.pairDistances.toObject(pairKey, valueFunc);

    dists = pairs.deflate(row => orMissing(distMap[pairKey(row)]));
  } else {
    // no override, but still need to distinguish if multimer distances
    // should be used or not
    dists =
      !useMultimerDistances && pairs.hasSeries("dist_intra")
        ? pairs.deflate(row => row.dist_intra)
        : pairs.deflate(row => row.dist);
  }

  return pairs.dropSeries("precision").withSeries("dist", dists);
};
