import { memo, useCallback } from 'react';
import { EdgeProps, internalsSymbol, Node, Position, useStore } from 'reactflow';
import { colors } from 'styles/colors';
import shallow from 'zustand/shallow';

import useGraphStore from 'views/Graph/state';

/**
 * Inputs for building a cubic Bézier edge path between two endpoints.
 *
 * `getBezierPath` defaults `sourcePosition` to `Position.Bottom`, `targetPosition` to
 * `Position.Top`, and `curvature` to 0.25.
 */
export interface GetBezierPathParams {
  /** X coordinate of the edge's start point. */
  sourceX: number;
  /** Y coordinate of the edge's start point. */
  sourceY: number;
  /** Side of the source handle; decides which axis the source control point extends along. */
  sourcePosition?: Position;
  /** X coordinate of the edge's end point. */
  targetX: number;
  /** Y coordinate of the edge's end point. */
  targetY: number;
  /** Side of the target handle; decides which axis the target control point extends along. */
  targetPosition?: Position;
  /**
   * Curvature factor. It only influences a control point when the opposite endpoint lies behind
   * the handle (negative distance along the handle's axis).
   */
  curvature?: number;
}

/** Inputs for computing the Bézier control point of one edge endpoint. */
interface GetControlWithCurvatureParams {
  /** Side the endpoint's handle is on; selects the axis the control point is pushed along. */
  pos: Position;
  /** X of the endpoint whose control point is being computed. */
  x1: number;
  /** Y of the endpoint whose control point is being computed. */
  y1: number;
  /** X of the opposite endpoint of the edge. */
  x2: number;
  /** Y of the opposite endpoint of the edge. */
  y2: number;
  /** Curvature factor, used when the opposite endpoint lies behind the handle. */
  c: number;
}

/**
 * Approximates the center of a cubic Bézier edge by evaluating the curve at t = 0.5.
 *
 * At t = 0.5 the Bernstein weights are 1/8, 3/8, 3/8, 1/8. The result is the parametric
 * midpoint, which is not the true arc-length midpoint but is cheap to compute.
 * See https://stackoverflow.com/questions/67516101/how-to-find-distance-mid-point-of-bezier-curve
 *
 * @returns `[centerX, centerY, offsetX, offsetY]`, where the offsets are the absolute distances
 *   from the source point to the center along each axis.
 */
export function getBezierEdgeCenter({
  sourceX,
  sourceY,
  targetX,
  targetY,
  sourceControlX,
  sourceControlY,
  targetControlX,
  targetControlY,
}: {
  sourceX: number;
  sourceY: number;
  targetX: number;
  targetY: number;
  sourceControlX: number;
  sourceControlY: number;
  targetControlX: number;
  targetControlY: number;
}): [number, number, number, number] {
  // One coordinate of the cubic curve at t = 0.5 (start, control 1, control 2, end).
  const halfwayAlongCurve = (p0: number, p1: number, p2: number, p3: number): number =>
    p0 * 0.125 + p1 * 0.375 + p2 * 0.375 + p3 * 0.125;

  const centerX = halfwayAlongCurve(sourceX, sourceControlX, targetControlX, targetX);
  const centerY = halfwayAlongCurve(sourceY, sourceControlY, targetControlY, targetY);

  return [centerX, centerY, Math.abs(centerX - sourceX), Math.abs(centerY - sourceY)];
}

/**
 * Distance from an endpoint to its Bézier control point along the handle's axis.
 *
 * @param distance Signed gap to the opposite endpoint along the handle's axis. It is non-negative
 *   when the opposite endpoint lies in the direction the handle faces.
 * @param curvature Scale factor applied only when the opposite endpoint lies behind the handle.
 * @returns The control-point offset.
 */
function calculateControlOffset(distance: number, curvature: number): number {
  // Opposite endpoint is behind the handle: bend outward by a curvature-scaled sqrt of the overshoot.
  if (distance < 0) {
    return curvature * 25 * Math.sqrt(-distance);
  }

  // A fixed extra 15px adds a little artificial curvature, which makes arrow alignment look more intuitive.
  const customOffset = 15;
  return 0.5 * distance + customOffset;
}

/**
 * Computes the Bézier control point for the endpoint (x1, y1).
 *
 * Along the handle's axis, the control point is pushed away from the endpoint by
 * `calculateControlOffset`. On the cross axis, it is nudged 1/8 of the way toward the
 * opposite endpoint (x2, y2).
 *
 * @returns The `[x, y]` control point.
 */
function getControlWithCurvature({
  pos,
  x1,
  y1,
  x2,
  y2,
  c,
}: GetControlWithCurvatureParams): [number, number] {
  // 8 is a magic number that looked aesthetically pleasing. The cross-axis nudge helps path ends
  // line up with the arrows when nodes are laterally spaced apart.
  const denominator = 8;
  const nudgedX = x1 - (x1 - x2) / denominator;
  const nudgedY = y1 - (y1 - y2) / denominator;

  switch (pos) {
    case Position.Left:
      return [x1 - calculateControlOffset(x1 - x2, c), nudgedY];
    case Position.Right:
      return [x1 + calculateControlOffset(x2 - x1, c), nudgedY];
    case Position.Top:
      return [nudgedX, y1 - calculateControlOffset(y1 - y2, c)];
    case Position.Bottom:
      return [nudgedX, y1 + calculateControlOffset(y2 - y1, c)];
  }
}

/**
 * Builds an SVG cubic Bézier path between two edge endpoints.
 *
 * Each endpoint's control point comes from `getControlWithCurvature`. The label point is the
 * curve's t = 0.5 point from `getBezierEdgeCenter`.
 *
 * @returns `[path, labelX, labelY, offsetX, offsetY]`. `path` is an SVG `d` string of the form
 *   `Msx,sy Cc1x,c1y c2x,c2y tx,ty`.
 */
function getBezierPath({
  sourceX,
  sourceY,
  sourcePosition = Position.Bottom,
  targetX,
  targetY,
  targetPosition = Position.Top,
  curvature = 0.25,
}: GetBezierPathParams): [
  path: string,
  labelX: number,
  labelY: number,
  offsetX: number,
  offsetY: number
] {
  const sourceControl = getControlWithCurvature({
    pos: sourcePosition,
    x1: sourceX,
    y1: sourceY,
    x2: targetX,
    y2: targetY,
    c: curvature,
  });
  // The target's control point is computed from the target's side, so the endpoint roles are swapped.
  const targetControl = getControlWithCurvature({
    pos: targetPosition,
    x1: targetX,
    y1: targetY,
    x2: sourceX,
    y2: sourceY,
    c: curvature,
  });
  const [sourceControlX, sourceControlY] = sourceControl;
  const [targetControlX, targetControlY] = targetControl;

  const [labelX, labelY, offsetX, offsetY] = getBezierEdgeCenter({
    sourceX,
    sourceY,
    targetX,
    targetY,
    sourceControlX,
    sourceControlY,
    targetControlX,
    targetControlY,
  });

  const path = [
    `M${sourceX},${sourceY}`,
    `C${sourceControlX},${sourceControlY}`,
    `${targetControlX},${targetControlY}`,
    `${targetX},${targetY}`,
  ].join(' ');

  return [path, labelX, labelY, offsetX, offsetY];
}

// React Flow uses handle positions as the start and end points of an edge.
// This custom edge checks whether the target lies more than 50px to the left of the source.
// If so, it starts the edge at the source's left handle and ends it at the target's right handle,
// falling back to the source's right / target's left handle when those handles don't exist.

/**
 * Picks the handles used as an edge's endpoints.
 *
 * By default the edge leaves the source's right handle and enters the target's left handle. When
 * the target sits more than 50px to the left of the source, the source's left handle and the
 * target's right handle are preferred. Each side falls back to the default handle if the preferred
 * one does not exist.
 *
 * @returns Endpoint coordinates (`sx`, `sy`, `tx`, `ty`) and the side each endpoint's handle
 *   is on (`sourcePos`, `targetPos`).
 *
 * NOTE(review): coordinates are built from `node.position`. If nodes can be nested under a parent,
 * `positionAbsolute` may be needed instead — confirm.
 */
function getEdgeParams(sourceNode: Node, targetNode: Node) {
  // The `!` assertions assume handle bounds have been measured. If they are missing, `.find` below
  // throws. If an expected handle is missing, the returned coordinates become NaN.
  const sourceHandles = sourceNode[internalsSymbol]?.handleBounds?.source!;
  const targetHandles = targetNode[internalsSymbol]?.handleBounds?.target!;

  const sx = sourceNode.position.x;
  const sy = sourceNode.position.y;
  const tx = targetNode.position.x;
  const ty = targetNode.position.y;

  // Adding an extra 50px helps compensate for situations when stacking shapes of different sizes (i.e., a Diamond on top of a Square)
  const targetOnTheLeft = tx + 50 < sx;
  let fromHandle = sourceHandles.find((h) => h.position === Position.Right)!;
  let toHandle = targetHandles.find((h) => h.position === Position.Left)!;
  let sourcePos = Position.Right;
  let targetPos = Position.Left;

  if (targetOnTheLeft) {
    const sourceLeft = sourceHandles.find((h) => h.position === Position.Left);
    const targetRight = targetHandles.find((h) => h.position === Position.Right);
    fromHandle = sourceLeft || fromHandle;
    toHandle = targetRight || toHandle;
    sourcePos = sourceLeft ? Position.Left : Position.Right;
    targetPos = targetRight ? Position.Right : Position.Left;
  }

  // The offset must follow the handle actually chosen above. A right handle ends at x + width + 2;
  // a left handle ends at x - width / 2. Computing this before the swap used the left handle's
  // width and the "right" direction even when there was no left handle (NaN) or no right handle
  // (wrong side).
  const targetHandleOffset =
    targetPos === Position.Right ? toHandle?.width + 2 : -toHandle?.width / 2;

  return {
    // NOTE(review): x + width is used for both sides. For a left source handle this is the
    // handle's right edge, presumably toward the node body — confirm this is intended.
    sx: sx + fromHandle?.x + fromHandle?.width,
    sy: sy + fromHandle?.y + fromHandle?.height / 2,
    tx: tx + toHandle?.x + targetHandleOffset,
    ty: ty + toHandle?.y + toHandle?.height / 2,
    sourcePos,
    targetPos,
  };
}

/**
 * Custom React Flow edge that draws a Bézier curve between the nodes' side handles.
 *
 * The stroke is `colors.light.action` when this edge's id is in the graph store's `edgesWillRun`
 * list, and `colors.gray[500]` otherwise. Selected edges are drawn 2px wide instead of 1px. A
 * second, transparent 10px-wide path over the same curve captures mouse events.
 * Renders nothing until both endpoint nodes are present in the store.
 */
function SmartEdge({ id, source, target, markerEnd, style, selected }: EdgeProps) {
  const { sourceNode, targetNode } = useStore(
    useCallback(
      (store) => ({
        sourceNode: store.nodeInternals.get(source),
        targetNode: store.nodeInternals.get(target),
      }),
      [source, target]
    ),
    shallow
  );

  const { willRun } = useGraphStore(
    useCallback((graph) => ({ willRun: graph.edgesWillRun.includes(id) }), [id]),
    shallow
  );

  if (!sourceNode || !targetNode) {
    return null;
  }

  const endpoints = getEdgeParams(sourceNode, targetNode);
  const [d] = getBezierPath({
    sourceX: endpoints.sx,
    sourceY: endpoints.sy,
    sourcePosition: endpoints.sourcePos,
    targetX: endpoints.tx,
    targetY: endpoints.ty,
    targetPosition: endpoints.targetPos,
  });

  const visibleStyle: EdgeProps['style'] = {
    ...style,
    strokeWidth: selected ? '2px' : '1px',
    stroke: willRun ? colors.light.action : colors.gray[500],
  };
  // Invisible, wider copy of the curve that captures mouse events.
  const hitAreaStyle: EdgeProps['style'] = {
    ...style,
    stroke: 'transparent',
    fill: 'transparent',
    strokeWidth: '10px',
    cursor: 'pointer',
  };

  return (
    <>
      <path
        id={id}
        style={visibleStyle}
        className="react-flow__edge-path"
        d={d}
        markerEnd={markerEnd}
      />
      <path d={d} style={hitAreaStyle} />
    </>
  );
}

// Memoized so React can skip re-rendering the edge when its props are shallowly equal.
export default memo(SmartEdge);
