import React, { useMemo, useCallback, useRef, useEffect } from 'react';
import { Group } from '@visx/group';
import { Grid } from '@visx/grid';
import { Circle } from '@visx/shape';
import { scaleLinear, scaleOrdinal } from '@visx/scale';
import { withTooltip, TooltipWithBounds } from '@visx/tooltip';
import { WithTooltipProvidedProps } from '@visx/tooltip/lib/enhancers/withTooltip';
import { voronoi } from '@visx/voronoi';
import { localPoint } from '@visx/event';
import { GradientPinkBlue } from '@visx/gradient';
import { AxisLeft, AxisBottom } from '@visx/axis';
import { LegendItem, LegendLabel, LegendOrdinal } from '@visx/legend';

/**
 * One plotted point. Index 0 is plotted on the x axis and index 1 on the y axis;
 * the third element is carried along (e.g. handed to `onDataSelect`) but not read
 * by this component.
 */
export type PointsRange = [number, number, number];

/** A named, coloured series of points drawn as dots and listed in the legend. */
export type DataGroup = {
  /** Legend text; also the domain value of the ordinal colour scale. */
  label: string;
  /** Fill colour of this series' dots and of its legend swatch. */
  color: string;
  /** Points belonging to this series. */
  data: Array<PointsRange>;
};

/** Public props of the scatter-plot component. */
export type DotsProps = {
  /** Outer SVG width; nothing is rendered when this is below 10. */
  width: number;
  /** Outer SVG height. */
  height: number;
  /** Input domain of the x scale; out-of-range values are clamped to the edge. */
  xDomain: [number, number];
  /** Input domain of the y scale; out-of-range values are clamped to the edge. */
  yDomain: [number, number];
  /** Points searched when resolving the tooltip and click selection. */
  data: Array<PointsRange>;
  /** Series rendered as coloured dots and legend entries. */
  dataGroups: Array<DataGroup>;
  /** Bottom-axis label, also used as the x caption in the tooltip. */
  xLabel?: string;
  /** Left-axis label, also used as the y caption in the tooltip. */
  yLabel?: string;
  /** Invoked on mouse down with the point from `data` nearest the pointer, if any. */
  onDataSelect?: (point: PointsRange) => void;
};

/** Value plotted on the x axis: the first tuple element. */
const x = ([xValue]: PointsRange): number => xValue;
/** Value plotted on the y axis: the second tuple element. */
const y = ([, yValue]: PointsRange): number => yValue;

/** Space (px) between the SVG edge and the plot area; the axes are drawn inside it. */
const margin = { top: 16, right: 16, bottom: 48, left: 48 };

/** Maximum distance (px) from the pointer at which a point is still picked. */
const neighborRadius = 100;

/** Delay (ms) before the tooltip hides once the pointer leaves the chart. */
const tooltipHideDelayMs = 300;

type DotsInnerProps = DotsProps & WithTooltipProvidedProps<PointsRange>;

/**
 * Converts a pointer event to coordinates relative to the plot area (i.e. with
 * `margin` subtracted). Returns null when the SVG is not mounted or the pointer
 * position cannot be determined.
 */
function getPlotPoint(
  svg: SVGSVGElement | null,
  event: React.MouseEvent | React.TouchEvent,
): { x: number; y: number } | null {
  if (!svg) return null;
  const point = localPoint(svg, event);
  if (!point) return null;
  return { x: point.x - margin.left, y: point.y - margin.top };
}

/**
 * Scatter plot of `dataGroups` on a gradient background with grid, axes and an
 * ordinal legend. Hovering shows a tooltip for the point in `data` nearest the
 * pointer (within `neighborRadius`); mouse down reports that point via
 * `onDataSelect`.
 *
 * NOTE(review): the enlarged-dot highlight compares `tooltipData === point`, so it
 * presumably expects `data` to hold the same tuple objects as `dataGroups` — confirm
 * with callers.
 */
function Dots({
  data,
  dataGroups,
  width,
  height,
  xDomain,
  yDomain,
  xLabel,
  yLabel,
  onDataSelect,
  hideTooltip,
  showTooltip,
  tooltipOpen,
  tooltipData,
  tooltipLeft,
  tooltipTop,
}: DotsInnerProps) {
  // Every hook below must run on every render (Rules of Hooks), so the
  // `width < 10` bail-out lives after the last hook, not at the top.
  const svgRef = useRef<SVGSVGElement>(null);
  // Pending "hide tooltip" timer for THIS chart instance. A module-level
  // variable would be shared (and cancelled) across every mounted chart.
  const hideTimeoutRef = useRef<number | null>(null);

  // bounds
  const innerWidth = width - margin.left - margin.right;
  const innerHeight = height - margin.top - margin.bottom;

  const xScale = useMemo(
    () =>
      scaleLinear<number>({
        domain: xDomain,
        range: [0, innerWidth],
        clamp: true,
      }),
    [innerWidth, xDomain],
  );
  const yScale = useMemo(
    () =>
      scaleLinear<number>({
        domain: yDomain,
        range: [innerHeight, 0],
        clamp: true,
      }),
    [innerHeight, yDomain],
  );
  const voronoiLayout = useMemo(
    () =>
      voronoi<PointsRange>({
        x: (d) => xScale(x(d)) ?? 0,
        y: (d) => yScale(y(d)) ?? 0,
        width: innerWidth,
        height: innerHeight,
      })(data),
    [innerWidth, innerHeight, data, xScale, yScale],
  );
  const ordinalColorScale = useMemo(
    () =>
      scaleOrdinal({
        domain: dataGroups.map((group) => group.label),
        range: dataGroups.map((group) => group.color),
      }),
    [dataGroups],
  );

  const clearHideTimeout = useCallback(() => {
    if (hideTimeoutRef.current !== null) {
      window.clearTimeout(hideTimeoutRef.current);
      hideTimeoutRef.current = null;
    }
  }, []);

  // Cancel a pending hide on unmount so hideTooltip never fires afterwards.
  useEffect(() => clearHideTimeout, [clearHideTimeout]);

  // event handlers
  const handleMouseMove = useCallback(
    (event: React.MouseEvent | React.TouchEvent) => {
      clearHideTimeout();
      const svg = svgRef.current;
      const plotPoint = getPlotPoint(svg, event);
      if (!svg || !plotPoint) return;

      // find the nearest Voronoi site to the current pointer position
      const closest = voronoiLayout.find(plotPoint.x, plotPoint.y, neighborRadius);
      if (closest) {
        showTooltip({
          tooltipLeft: closest[0] + margin.left,
          tooltipTop: closest[1] + margin.top,
          tooltipData: closest.data,
        });
        svg.style.cursor = 'pointer';
      } else {
        hideTooltip();
        svg.style.cursor = 'auto';
      }
    },
    [clearHideTimeout, showTooltip, hideTooltip, voronoiLayout],
  );

  const handleMouseLeave = useCallback(() => {
    // Replace (rather than stack) any earlier pending hide, e.g. touchend + mouseleave.
    clearHideTimeout();
    hideTimeoutRef.current = window.setTimeout(() => {
      hideTimeoutRef.current = null;
      hideTooltip();
    }, tooltipHideDelayMs);
  }, [clearHideTimeout, hideTooltip]);

  const handleMouseDown = useCallback(
    (event: React.MouseEvent) => {
      if (!onDataSelect) return;
      const plotPoint = getPlotPoint(svgRef.current, event);
      if (!plotPoint) return;

      // find the nearest Voronoi site to the current pointer position
      const closest = voronoiLayout.find(plotPoint.x, plotPoint.y, neighborRadius);
      if (closest) onDataSelect(closest.data);
    },
    [onDataSelect, voronoiLayout],
  );

  if (width < 10) return null;

  return (
    <div>
      <svg width={width} height={height} ref={svgRef}>
        <GradientPinkBlue id="dots-pink" />
        {/** capture all mouse events with a rect */}
        <rect
          width={width}
          height={height}
          rx={14}
          fill="url(#dots-pink)"
          onMouseMove={handleMouseMove}
          onMouseLeave={handleMouseLeave}
          onMouseDown={handleMouseDown}
          onTouchMove={handleMouseMove}
          onTouchEnd={handleMouseLeave}
        />
        <Grid
          top={margin.top}
          left={margin.left}
          xScale={xScale}
          yScale={yScale}
          width={innerWidth}
          height={innerHeight}
          stroke="rgba(0,0,0,0.1)"
        />
        <AxisBottom
          label={xLabel}
          labelOffset={12}
          top={height - margin.bottom}
          left={margin.left}
          scale={xScale}
          stroke="rgba(0,0,0,0.1)"
          tickStroke="rgba(0,0,0,0.1)"
          tickLength={4}
        />
        <AxisLeft
          label={yLabel}
          labelOffset={24}
          top={margin.top}
          left={margin.left}
          scale={yScale}
          stroke="rgba(0,0,0,0.1)"
          tickStroke="rgba(0,0,0,0.1)"
          tickLength={2}
        />
        {dataGroups.map((group, i) => (
          <Group key={`group-${i}`} pointerEvents="none" left={margin.left} top={margin.top}>
            {group.data.map((point, j) => (
              <Circle
                key={`point-${x(point)}-${i}-${j}`}
                className="dot"
                cx={xScale(x(point))}
                cy={yScale(y(point))}
                r={tooltipData === point ? 6 : 4}
                fill={group.color}
              />
            ))}
          </Group>
        ))}
      </svg>

      <LegendOrdinal scale={ordinalColorScale}>
        {(labels) => (
          <div style={{ display: 'flex', flexDirection: 'row', justifyContent: 'center' }}>
            {labels.map((label, i) => (
              <LegendItem key={`legend-quantile-${i}`} margin="0 5px">
                <svg width={16} height={16}>
                  <Circle className="dot" cx={8} cy={8} r={4} fill={label.value} stroke="gray" />
                </svg>
                <LegendLabel align="left" margin="0 0 0 4px">
                  {label.text}
                </LegendLabel>
              </LegendItem>
            ))}
          </div>
        )}
      </LegendOrdinal>

      {tooltipOpen && tooltipData && tooltipLeft != null && tooltipTop != null && (
        <TooltipWithBounds left={tooltipLeft + 10} top={tooltipTop + 10}>
          <div>
            <strong>{xLabel}:</strong> {x(tooltipData)}
          </div>
          <div>
            <strong>{yLabel}:</strong> {y(tooltipData)}
          </div>
        </TooltipWithBounds>
      )}
    </div>
  );
}

export default withTooltip<DotsProps, PointsRange>(Dots);
