import * as d3 from 'd3';
import {
  mean,
  median,
  mode,
  quantile,
  standardDeviation,
} from 'simple-statistics';
import type { CrossPlotDataPoint } from '~/utils/chart';
import {
  crossPlotRegressionFormula,
  crossPlotRSquared,
  zeroInterceptRegressionFormula,
} from '~/utils/chart';

/**
 * Row labels of the statistics table built by `useMeasurementStatistics`.
 * `n` is the number of samples in the X series.
 */
type StatisticName =
  | 'Minimum'
  | 'Maximum'
  | 'P10'
  | 'P50'
  | 'P90'
  | 'Mean'
  | 'Median'
  | 'Mode'
  | 'Std Deviation'
  | 'n';

/**
 * One row of the statistics table: the already-formatted value of a statistic
 * for the X series and, when available, for the Y series.
 */
type Statistic = {
  /** Formatted value for the X series. */
  x: string;
  /** Formatted value for the Y series; `null` when there is no Y value (always `null` for `n`). */
  y: string | null;
};

/**
 * Converts a statistic into its display string.
 *
 * - Numbers are rendered with exactly two decimals.
 * - Strings are returned unchanged.
 * - `null` or `undefined` (including an omitted argument) yields `null`.
 */
function formatValue(value: number | string): string;
function formatValue(value?: number | null): string | null;
function formatValue(value: number | string | null | undefined) {
  // `== null` matches both null and undefined.
  if (value == null) {
    return null;
  }
  return typeof value === 'number' ? value.toFixed(2) : value;
}

/**
 * Builds the summary-statistics table for a measurement series `dataX` and,
 * optionally, a second series `dataY` reported alongside it.
 *
 * Each row's `x` is the formatted statistic of `dataX`. Its `y` is the
 * formatted statistic of `dataY`, or `null` when no non-empty Y series was
 * given.
 *
 * Rows are inserted in this order: Minimum, Maximum, P10, P50, P90, Mean,
 * Median, Mode, Std Deviation, n.
 *
 * @param dataX - Values of the primary series.
 * @param dataY - Optional values of the secondary series. `null`, `undefined`
 *   and `[]` all mean "no Y series".
 * @returns A partial record keyed by statistic name. When `dataX` is empty,
 *   only the `n` row is present, because simple-statistics throws on an empty
 *   sample. Minimum/Maximum are also omitted when `d3.extent` finds no
 *   comparable values.
 */
export function useMeasurementStatistics(
  dataX: number[],
  dataY?: number[] | null,
) {
  const stats: Partial<Record<StatisticName, Statistic>> = {};

  // quantile/mean/mode/standardDeviation from simple-statistics throw on
  // empty input, so an empty X series can only report its size.
  if (dataX.length === 0) {
    stats['n'] = { x: String(dataX.length), y: null };
    return stats;
  }

  // An empty Y array would throw in the same way; treat it like a missing one.
  const valuesY = dataY != null && dataY.length > 0 ? dataY : null;

  const [minX, maxX] = d3.extent(dataX);
  const [minY, maxY] = d3.extent(valuesY ?? []);

  // Compare with `undefined` rather than testing truthiness: an extreme of 0
  // is a real value and must not drop the row.
  if (minX !== undefined) {
    stats['Minimum'] = { x: formatValue(minX), y: formatValue(minY ?? null) };
  }
  if (maxX !== undefined) {
    stats['Maximum'] = { x: formatValue(maxX), y: formatValue(maxY ?? null) };
  }

  // Applies one statistic to both series and formats the results as a row.
  const row = (compute: (values: number[]) => number): Statistic => ({
    x: formatValue(compute(dataX)),
    y: formatValue(valuesY ? compute(valuesY) : null),
  });

  stats['P10'] = row((values) => quantile(values, 0.1));
  stats['P50'] = row((values) => quantile(values, 0.5));
  stats['P90'] = row((values) => quantile(values, 0.9));
  stats['Mean'] = row((values) => mean(values));
  stats['Median'] = row((values) => median(values));
  stats['Mode'] = row((values) => mode(values));
  stats['Std Deviation'] = row((values) => standardDeviation(values));
  stats['n'] = { x: String(dataX.length), y: null };

  return stats;
}

/** Number of decimal places used when rendering regression coefficients and R² values. */
const SCALE = 4;

/**
 * Renders a fitted line as display text, e.g. `ŷ = 1.5000X + 2.0000` or
 * `ln(ŷ) = 2.0000 · ln(x) - 0.5000`.
 *
 * @param slope - Fitted slope `m`.
 * @param intercept - Fitted intercept `b`. The term is omitted when it rounds
 *   to zero at {@link SCALE} decimals.
 * @param isLogScaleX - Render the X term as `ln(x)`.
 * @param isLogScaleY - Render the left-hand side as `ln(ŷ)`.
 * @returns The formula string.
 */
function regressionFormulaText(
  slope: number,
  intercept: number,
  isLogScaleX: boolean,
  isLogScaleY: boolean,
): string {
  // `toFixed` keeps the sign of tiny negatives ("-0.0000"). Collapse anything
  // that rounds to zero to an unsigned zero. NaN passes through as "NaN".
  const fixed = (value: number) =>
    (Number(value.toFixed(SCALE)) === 0 ? 0 : value).toFixed(SCALE);

  // NOTE(review): the linear X term is upper-case 'X' while the log term uses
  // lower-case 'ln(x)'. Kept as-is to avoid changing existing output.
  const x = isLogScaleX ? ' · ln(x)' : 'X';
  const y = isLogScaleY ? 'ln(ŷ)' : 'ŷ';
  const m = fixed(slope);

  // Omit the intercept whenever it would display as zero, not only when it is
  // exactly zero. Otherwise residual intercepts print as "+ 0.0000" or "- 0.0000".
  const bMagnitude = fixed(Math.abs(intercept));
  const sign = intercept >= 0 ? '+' : '-';
  const b = Number(bMagnitude) === 0 ? '' : ` ${sign} ${bMagnitude}`;

  return `${y} = ${m}${x}${b}`;
}

/**
 * Produces the formatted regression summaries for a cross-plot.
 *
 * @param data - Cross-plot points.
 * @param isLogScaleX - Passed to `crossPlotRegressionFormula` and used to
 *   render the X term of `regressionFormula` as `ln(x)`.
 * @param isLogScaleY - Passed to `crossPlotRegressionFormula` and used to
 *   render the left-hand side of `regressionFormula` as `ln(ŷ)`.
 * @returns An object with four strings:
 *   - `regressionFormula`: the best fit for the given scales.
 *   - `r2`: R² of a fit computed with both log flags off.
 *   - `ziRegressionFormula`: the zero-intercept fit.
 *   - `ziR2`: R² of the zero-intercept fit.
 *   Numeric values use `SCALE` decimals.
 */
export function useRegressionStatistics(
  data: CrossPlotDataPoint[],
  isLogScaleX = false,
  isLogScaleY = false,
) {
  const toScale = (value: number) => value.toFixed(SCALE);

  // Fit in the axes' own (possibly logarithmic) space; this is what gets displayed.
  const fit = crossPlotRegressionFormula(data, isLogScaleX, isLogScaleY);
  const regressionFormula = regressionFormulaText(
    fit.m,
    fit.b,
    isLogScaleX,
    isLogScaleY,
  );

  // NOTE(review): R² always comes from a linear-space fit, even when the
  // displayed formula is a log fit. This looks deliberate; confirm.
  const linearFit = crossPlotRegressionFormula(data, false, false);
  const r2 = toScale(crossPlotRSquared(data, linearFit));

  // Fit constrained through the origin; always rendered without log scaling.
  const originFit = zeroInterceptRegressionFormula(data);
  const ziRegressionFormula = regressionFormulaText(
    originFit.m,
    originFit.b,
    false,
    false,
  );
  const ziR2 = toScale(
    crossPlotRSquared(data, { m: originFit.m, b: originFit.b }),
  );

  return { regressionFormula, r2, ziRegressionFormula, ziR2 };
}
