import { SymbolProps } from "@nivo/legends/dist/types/svg/symbols/types";
import { CustomLayer, CustomLayerProps } from "@nivo/line";
import React, { createElement } from "react";
import colours from "styles/colours";
import { DiamondPoint } from "utils/ChartDrawUtils/pointShapes";

/** How one chart serie is rendered as part of a dumbbell layer. */
interface SerieDumbbellDisplay {
  /** Id of the serie to draw; compared against each layer point's `serieId`. */
  serieId: string | number;
  /** Fill colour applied to every symbol of this serie. */
  colour: string;
  /** Component used to draw each point of this serie (e.g. `DrawCircle`). */
  dumbbellShapeFunction: (props: SymbolProps) => JSX.Element;
  /** Symbol opacity; treated as 1 when omitted. */
  opacity?: number;
}

/** A single symbol, fully resolved and ready to be rendered by the dumbbell layer. */
interface PointInfo {
  /** Id of the originating layer point; also used as the React key. */
  id: string;
  /** Component that draws this symbol. */
  symbol: SerieDumbbellDisplay["dumbbellShapeFunction"];
  /**
   * Top-left corner of the symbol's bounding box, i.e. the point's centre
   * minus half of `pointSize` on each axis.
   */
  x: number;
  y: number;
  /** Width and height of the symbol's bounding box. */
  pointSize: number;
  fill: string;
  borderColor: string;
  opacity: number;
}

/** A start/end pair of points that share the same x position and get joined by a line. */
interface ConnectingLineInfo {
  /** The shared x position, stringified; used as the React key. */
  id: string;
  /** Point taken from the dumbbell start serie. */
  initialForecastPoint: PointInfo;
  /** Point taken from the dumbbell end serie. */
  latestForecastPoint: PointInfo;
}

/**
 * Symbol renderer: a filled, borderless square of side `size` whose top-left
 * corner sits at (`x`, `y`). Pointer events are disabled so the shape does
 * not receive hover/click events itself.
 */
export const DrawSquare = (props: SymbolProps): JSX.Element => {
  const { x, y, size, fill, opacity } = props;
  const nonInteractive = { pointerEvents: "none" } as const;

  return (
    <rect
      style={nonInteractive}
      stroke="transparent"
      strokeWidth={0}
      opacity={opacity}
      fill={fill}
      width={size}
      height={size}
      x={x}
      y={y}
    />
  );
};

/**
 * Symbol renderer: a diamond filled with `fill`. The group is translated to
 * the centre of the `size`×`size` box whose top-left corner is (`x`, `y`);
 * presumably `DiamondPoint` draws itself around its local origin — confirm
 * against its implementation. Pointer events are disabled on the group.
 */
export const DrawDiamond = (props: SymbolProps): JSX.Element => {
  const { x, y, size, fill, opacity } = props;
  const halfSize = size / 2;
  const centreTransform = `translate(${x + halfSize},${y + halfSize})`;

  return (
    <g
      style={{ pointerEvents: "none" }}
      stroke="transparent"
      strokeWidth={0}
      opacity={opacity}
      transform={centreTransform}
    >
      <DiamondPoint color={fill} size={size} />
    </g>
  );
};

/**
 * Symbol renderer: a filled, borderless circle inscribed in the
 * `size`×`size` box whose top-left corner is (`x`, `y`). Pointer events are
 * disabled so the shape does not receive hover/click events itself.
 */
export const DrawCircle = (props: SymbolProps): JSX.Element => {
  const { x, y, size, fill, opacity } = props;
  const radius = size / 2;

  return (
    <circle
      style={{ pointerEvents: "none" }}
      stroke="transparent"
      strokeWidth={0}
      opacity={opacity}
      fill={fill}
      r={radius}
      cx={x + radius}
      cy={y + radius}
    />
  );
};

/** Optional settings accepted by `generateDumbbells`. */
interface DumbbellOptions {
  /** Extra serie whose symbols are drawn beneath both dumbbell ends; it is not connected by lines. */
  innerPointSerie?: SerieDumbbellDisplay;
  /** Size of the start/end symbols (5 when omitted). */
  outerPointSize?: number;
  /** Size of the inner-serie symbols (5 when omitted). */
  innerPointSize?: number;
}

/** A point as handed to a nivo line custom layer. */
type LayerPoint = CustomLayerProps["points"][number];

/**
 * Selects the layer points belonging to `serie` and converts them into
 * drawable `PointInfo`s of the given `size`.
 *
 * `x`/`y` are shifted by half of `size` so they address the top-left corner
 * of the symbol's bounding box, which is what the `Draw*` symbol components
 * in this module expect.
 */
const toPointInfos = (
  points: readonly LayerPoint[],
  serie: SerieDumbbellDisplay,
  size: number
): PointInfo[] =>
  points
    .filter((point) => point.serieId === serie.serieId)
    .map((point) => ({
      id: point.id,
      symbol: serie.dumbbellShapeFunction,
      pointSize: size,
      opacity: serie.opacity ?? 1,
      x: point.x - size / 2,
      y: point.y - size / 2,
      fill: serie.colour,
      borderColor: point.borderColor,
    }));

/**
 * Pairs each start-serie point with the end-serie point at the same x
 * position. Only the first point of each serie at a given x is used, and x
 * positions present in only one of the two series produce no pair.
 * Uses a lookup map to avoid a quadratic repeated-`find` scan.
 */
const pairDumbbellEnds = (
  startPoints: readonly PointInfo[],
  endPoints: readonly PointInfo[]
): ConnectingLineInfo[] => {
  const firstEndAtX = new Map<number, PointInfo>();
  for (const end of endPoints) {
    if (!firstEndAtX.has(end.x)) {
      firstEndAtX.set(end.x, end);
    }
  }

  const pairs: ConnectingLineInfo[] = [];
  const pairedXs = new Set<number>();
  for (const start of startPoints) {
    const end = firstEndAtX.get(start.x);
    // Skip unmatched x positions and any later start point at an already-paired x.
    if (end === undefined || pairedXs.has(start.x)) {
      continue;
    }
    pairedXs.add(start.x);
    pairs.push({
      id: start.x.toString(),
      initialForecastPoint: start,
      latestForecastPoint: end,
    });
  }
  return pairs;
};

/**
 * Builds a nivo line custom layer that renders "dumbbells". For every x
 * position where both `dumbbellStartSerie` and `dumbbellEndSerie` have a
 * point, a grey line joins the two points' centres. Each serie's points are
 * drawn with its own symbol component, colour and opacity.
 *
 * @param dumbbellStartSerie - Serie for one end of each dumbbell; its symbols are painted on top.
 * @param dumbbellEndSerie - Serie for the other end of each dumbbell.
 * @param options - Optional inner serie (painted beneath the ends, never
 *   connected) and symbol sizes; both sizes default to 5.
 * @returns A nivo `CustomLayer` that draws the dumbbells.
 */
export const generateDumbbells = (
  dumbbellStartSerie: SerieDumbbellDisplay,
  dumbbellEndSerie: SerieDumbbellDisplay,
  {
    innerPointSerie,
    outerPointSize = 5,
    innerPointSize = 5,
  }: DumbbellOptions = {}
): CustomLayer =>
  function dumbbells({ points }: CustomLayerProps): React.ReactNode {
    const startPoints = toPointInfos(points, dumbbellStartSerie, outerPointSize);
    const endPoints = toPointInfos(points, dumbbellEndSerie, outerPointSize);
    const innerPoints = innerPointSerie
      ? toPointInfos(points, innerPointSerie, innerPointSize)
      : [];

    // Later entries are painted over earlier ones: start-serie points sit on top of
    // end-serie points, and both sit on top of any inner points where they overlap.
    const dumbbellPoints = [...innerPoints, ...endPoints, ...startPoints];

    const pairsToConnect = pairDumbbellEnds(startPoints, endPoints);

    // PointInfo x/y are top-left corners, so shift BOTH axes by half the outer symbol
    // size to make each connecting line run from centre to centre.
    const centreOffset = outerPointSize / 2;

    return (
      <g>
        <g>
          {pairsToConnect.map((pair) => (
            <line
              key={pair.id}
              x1={pair.initialForecastPoint.x + centreOffset}
              y1={pair.initialForecastPoint.y + centreOffset}
              x2={pair.latestForecastPoint.x + centreOffset}
              y2={pair.latestForecastPoint.y + centreOffset}
              stroke={colours.highlightGrey}
            />
          ))}
        </g>
        {dumbbellPoints.map(({ symbol, ...point }) =>
          createElement(symbol, {
            ...point,
            size: point.pointSize,
            key: point.id,
          })
        )}
      </g>
    );
  };
