import { AtomContext, Keyframe, TimeSignature } from "../@types";
import { last, uniq } from "../base/utils/ramdaEquivalents.utils";
import { KnownKeyframeControlPath } from "../constants/keyframe.constants";
import { MusicKey } from "../constants/musicKeys.constants";
import { MusicScaleName } from "../constants/musicScales.constants";
import {
  DefaultXpm,
  convertAbstractXToDurationInSeconds,
  convertDurationInSecondsToAbstractX,
  getBpxFromTimeSignature,
  getXpmFromBpmBpx,
} from "./beats.utils";

/**
 * A step function sampled at discrete points: a list of `[x, value]` stops,
 * where each stop's value applies from its `x` until the next stop.
 */
export type ValueAtXMap<T = unknown> = Array<[number, T]>;

// Specialised maps, one per quantity tracked along the x axis.
export type BpmMap = ValueAtXMap<number>;
export type TimeMap = ValueAtXMap<number>;
export type ScaledXMap = ValueAtXMap<number>;
export type SpeedScalarMap = ValueAtXMap<number>;
export type TimeSignatureMap = ValueAtXMap<TimeSignature>;
export type XpmMap = ValueAtXMap<number>;
export type ScaledTimeMap = ValueAtXMap<number>;
export type MusicKeyMap = ValueAtXMap<MusicKey>;
export type MusicScaleMap = ValueAtXMap<MusicScaleName>;

/**
 * Returns the stop in effect at `x`: the last stop whose x is `<= x`.
 * If `x` lies before the first stop, the first stop is returned; if no stop
 * lies beyond `x`, the last stop is returned.
 */
export const getStopInValueAtXMap = <T = unknown>(
  x: number,
  map: ValueAtXMap<T>
) => {
  const nextIndex = map.findIndex(([stopX]) => stopX > x);
  if (nextIndex < 0) return last(map)!;
  // nextIndex === 0 clamps to the first stop; otherwise take the one before.
  return map[Math.max(nextIndex - 1, 0)];
};
/**
 * Returns the stop governing the span that ends at `x`: the last stop whose
 * x is strictly less than `x`.
 *
 * Unlike getStopInValueAtXMap, when `x` falls exactly on a stop this returns
 * the *previous* stop, which is the value in effect up till `x`.
 * If no stop lies before `x`, the first stop is returned. If every stop lies
 * before `x`, the last stop is returned.
 * Assumes `map` is sorted by ascending x.
 */
export const getStopUpTillInValueAtXMap = <T = unknown>(
  x: number,
  map: ValueAtXMap<T>
) => {
  // First stop at or after x; the stop right before it covers the span ending at x.
  const i = map.findIndex(m => m[0] >= x);
  if (i === -1) return last(map)!;
  if (i === 0) return map[0];
  return map[i - 1];
};

/** Returns the value of the stop in effect at `x` (see getStopInValueAtXMap). */
export const getValueAtXFromMap = <T = unknown>(
  x: number,
  map: ValueAtXMap<T>
) => {
  const stop = getStopInValueAtXMap<T>(x, map);
  return stop[1];
};

/** Returns the value of the stop governing the span up till `x` (see getStopUpTillInValueAtXMap). */
export const getValueUpTillXFromMap = <T = unknown>(
  x: number,
  map: ValueAtXMap<T>
) => {
  const stop = getStopUpTillInValueAtXMap<T>(x, map);
  return stop[1];
};

/**
 * Creates a generator that turns the keyframes stored under `controlPath`
 * into a ValueAtXMap: a list of `[x, value]` stops.
 *
 * - The map starts with `[ac.startX, initialValueGetter(ac)]`.
 * - Keyframes whose `startX` or `endX` is `null` are ignored.
 * - Zero-width keyframes (`!k.width`) add a single stop at `startX`; their
 *   value persists until the next stop.
 * - A keyframe with width adds a stop at `startX`. At `endX` it reverts to the
 *   value of the nearest enclosing keyframe still active there. If no such
 *   keyframe exists, it reverts to the initial value. The endpoint is skipped
 *   when an overlapping keyframe ends at or after it.
 * - When several stops share the same x, the one emitted last wins.
 * - A closing stop at `ac.width` (repeating the last value) is appended if the
 *   map does not already end there.
 *
 * NOTE(review): overlap detection only scans forward through consecutive
 * keyframes, so `ac.keyframesCategorized[controlPath]` is presumably sorted by
 * `startX` — confirm.
 *
 * @param controlPath key into `ac.keyframesCategorized` selecting the keyframes.
 * @param initialValueGetter computes the value in effect where no keyframe applies.
 * @returns a function mapping an AtomContext to its ValueAtXMap.
 */
export const makeValueAtXGeneratorFromKeyframePath =
  <T>(
    controlPath: KnownKeyframeControlPath,
    initialValueGetter: (ac: AtomContext) => T
  ) =>
  (ac: AtomContext) => {
    const initialValue = initialValueGetter(ac);
    const initialPoint: [number, T] = [ac.startX, initialValue];
    const map = new Map<number, T>();
    const processed = new Set<Keyframe>();
    map.set(...initialPoint);
    /**
     * Emits the stops for keyframe `k` (at index `i` of `arr`) together with
     * those of every following keyframe that starts before `k` ends.
     * `levels` lists the enclosing keyframes, nearest first.
     */
    const processKeyframe = (
      k: Keyframe,
      i: number,
      arr: Keyframe[],
      ...levels: Keyframe[]
    ) => {
      // add starting point
      const stops = [[k.startX, k.value] as [number, T]];
      // loop through all following keyframes that overlap with this one
      let j = 1;
      let nextLevel: Keyframe | null;
      let nextLevelsCumulativeEndX = 0; // keep track of those following keyframes' largest endX
      do {
        const next = arr[i + j];
        nextLevel = next && next.startX! < k.endX! ? next : null;
        if (nextLevel) {
          if (nextLevel.endX! > nextLevelsCumulativeEndX)
            nextLevelsCumulativeEndX = nextLevel.endX!;
          // A keyframe nested more than one level deep was already handled by
          // the recursive call for its nearer parent, with the correct `levels`
          // chain. Re-processing it here would emit duplicate stops whose end
          // value reverts to `k` instead of that parent. Because later stops
          // win, that wrong value would take effect.
          if (!processed.has(nextLevel))
            stops.push(...processKeyframe(nextLevel, i + j, arr, k, ...levels));
        }
        j++;
      } while (nextLevel);
      // if the overlapping keyframes end later than the current keyframe,
      // the current keyframe's endpoint can be disregarded as the last-ending overlapping keyframes would take precedence.
      const shouldProcessEndPoint = nextLevelsCumulativeEndX < k.endX!;
      if (shouldProcessEndPoint) {
        // revert to the nearest enclosing keyframe still active at k.endX, else the initial value
        const jumpDownTo = levels.find(l => l.endX! >= k.endX!);
        if (jumpDownTo) stops.push([k.endX, jumpDownTo.value] as [number, T]);
        else stops.push([k.endX!, initialValue]);
      }
      // we've processed this keyframe. should not check it again.
      processed.add(k);
      return stops;
    };
    const allStops = ac.keyframesCategorized[controlPath]
      .map((k, i, arr) => {
        if (processed.has(k)) return [];
        if (k.startX === null || k.endX === null) return [];
        if (!k.width) return [[k.startX, k.value]] as [number, T][];
        else return processKeyframe(k, i, arr);
      })
      .flat(1);
    // Sort is stable, so among stops with equal x the emission order is kept;
    // Map.set then lets the later one override the earlier.
    allStops.sort((a, b) => a[0] - b[0]).forEach(stop => map.set(...stop));
    const result = Array.from(map.entries());
    // `result` is never empty: the initial point is always present.
    const finalStop = result[result.length - 1];
    if (finalStop[0] !== ac.width)
      result.push([ac.width, finalStop[1]] as [number, T]);
    return result;
  };

/** BPM map from `bpmChange` keyframes; base value is the interpretation's `bpm` option, or 60 when unset. */
export const generateBpmMap = makeValueAtXGeneratorFromKeyframePath(
  KnownKeyframeControlPath.bpmChange,
  (ac: AtomContext) => ac.interpretation?.options.bpm ?? 60
);

/** Speed-scalar map from `speedScalar` keyframes; base value is always 1. */
export const generateSpeedScalarMap = makeValueAtXGeneratorFromKeyframePath(
  KnownKeyframeControlPath.speedScalar,
  (): number => 1
);

/** Music-key map from `musicKeyChange` keyframes; base value is the composition's key, or C when falsy. */
export const generateMusicKeyMap = makeValueAtXGeneratorFromKeyframePath(
  KnownKeyframeControlPath.musicKeyChange,
  (ac: AtomContext) => ac.composition?.options.musicKey || MusicKey.C
);

/** Music-scale map from `musicScaleChange` keyframes; base value is the composition's scale, or Ionian when falsy. */
export const generateMusicScaleNameMap = makeValueAtXGeneratorFromKeyframePath(
  KnownKeyframeControlPath.musicScaleChange,
  (ac: AtomContext) =>
    ac.composition?.options.musicScaleName || MusicScaleName.Ionian
);

/**
 * Builds the scaled-x map by accumulating the speed-scalar map.
 *
 * Scaled x is 0 at the first speed-scalar stop. Across each segment it advances
 * by the segment's width times `1 / scalar`, using the stop that opens the
 * segment, so a scalar of 2 halves the scaled width.
 * With no speed-scalar stops, the identity map `[[0, 0], [width, width]]` is
 * returned. Otherwise the map is closed with a stop at `ac.width` if it does
 * not already end there.
 */
export const generateScaledXMap = (ac: AtomContext) => {
  const scalarStops = ac.valueAtXMaps.speedScalar;
  const scaled = [] as ScaledXMap;
  for (let idx = 0; idx < scalarStops.length; idx++) {
    const x = scalarStops[idx][0];
    if (idx > 0) {
      const [prevX, prevScalar] = scalarStops[idx - 1];
      const runningScaledX = last(scaled)![1];
      scaled.push([x, runningScaledX + (x - prevX) * (1 / prevScalar)]);
    } else {
      scaled.push([x, 0]);
    }
  }
  if (scaled.length === 0) {
    scaled.push([0, 0], [ac.width, ac.width]);
    return scaled;
  }
  const [finalX, finalScaledX] = last(scaled)!;
  if (finalX !== ac.width) scaled.push([ac.width, finalScaledX]);
  return scaled;
};

/**
 * Builds the time-signature map from the bars: one stop at each bar whose
 * signature differs from the previous stop's, closed at `ac.width`.
 * With no bars, returns a single stop `[0, ac.timeSignature]`.
 */
export const generateTimeSignatureMap = (ac: AtomContext) => {
  if (ac.bars.length === 0) {
    return [[0, ac.timeSignature]] as [number, TimeSignature][];
  }
  const stops: [number, TimeSignature][] = [];
  ac.bars.forEach(bar => {
    const previous = last(stops);
    const isRepeat =
      !!previous && previous[1].join("/") === bar.timeSignature.join("/");
    if (!isRepeat) stops.push([bar.x, bar.timeSignature]);
  });
  const [finalX, finalSignature] = last(stops)!;
  if (finalX !== ac.width) stops.push([ac.width, finalSignature]);
  return stops;
};

/**
 * Converts the time-signature map into a bpx map via getBpxFromTimeSignature,
 * dropping stops whose bpx equals the previous stop's, and closes it at
 * `ac.width`.
 */
export const generateBpxMap = (ac: AtomContext) => {
  const stops: [number, number][] = [];
  ac.valueAtXMaps.timeSignature.forEach(([x, timeSignature]) => {
    const bpx = getBpxFromTimeSignature(timeSignature);
    const previous = last(stops);
    if (previous && previous[1] === bpx) return;
    stops.push([x, bpx]);
  });
  const [finalX, finalBpx] = last(stops)!;
  if (finalX !== ac.width) stops.push([ac.width, finalBpx]);
  return stops;
};

/**
 * Builds the xpm map. There is one stop at every x where either the BPM map
 * or the bpx map has a stop. Its value is getXpmFromBpmBpx of the BPM and bpx
 * in effect there. The map is closed at `ac.width`.
 */
export const generateXpmMap = (ac: AtomContext) => {
  const bpmMap = ac.valueAtXMaps.bpm;
  const bpxMap = ac.valueAtXMaps.bpx;
  const stopXs = uniq([
    ...bpmMap.map(([x]) => x),
    ...bpxMap.map(([x]) => x),
  ]).sort((a, b) => a - b);
  const stops = stopXs.map(
    (x): [number, number] => [
      x,
      getXpmFromBpmBpx(
        getValueAtXFromMap(x, bpmMap),
        getValueAtXFromMap(x, bpxMap)
      ),
    ]
  );
  const [finalX, finalXpm] = last(stops)!;
  if (finalX !== ac.width) stops.push([ac.width, finalXpm]);
  return stops;
};

/**
 * Builds the time map: for each xpm stop x, the elapsed time in seconds.
 *
 * Time starts at `ac.leadingBeatsInSeconds` at x = 0. Each step adds the
 * duration of the x distance since the previous stop, converted with the xpm
 * governing that span (getValueUpTillXFromMap).
 * When there are leading beats, a `[-leadingBeatsWidth, 0]` stop is prepended.
 * The map is closed at `ac.width`.
 */
export const generateTimeMap = (ac: AtomContext) => {
  const leadingSeconds = ac.leadingBeatsInSeconds;
  const xpmMap = ac.valueAtXMaps.xpm;
  const stopXs = uniq(xpmMap.map(([x]) => x)).sort((a, b) => a - b);
  const timeMap: [number, number][] = [];
  let prevX = 0;
  let prevTime = leadingSeconds;
  for (const x of stopXs) {
    const xpm = getValueUpTillXFromMap(x, xpmMap);
    const time =
      prevTime + convertAbstractXToDurationInSeconds(x - prevX, xpm);
    timeMap.push([x, time]);
    prevX = x;
    prevTime = time;
  }
  if (leadingSeconds > 0) timeMap.unshift([-ac.leadingBeatsWidth, 0]);
  const finalStop = last(timeMap)!;
  if (finalStop[0] !== ac.width) timeMap.push([ac.width, finalStop[1]]);
  return timeMap;
};

/**
 * Builds the scaled time map: like generateTimeMap, but durations are taken
 * over scaled x (via getScaledXValueFromMap) instead of raw x.
 *
 * Stops are placed at every x where the scaled-x map or the xpm map has a
 * stop. Time starts at `ac.leadingBeatsInSeconds` at scaled x = 0.
 * When there are leading beats, a `[-leadingBeatsWidth, 0]` stop is prepended.
 * The map is closed at `ac.width`.
 */
export const generateScaledTimeMap = (ac: AtomContext) => {
  const leadingSeconds = ac.leadingBeatsInSeconds;
  const scaledXMap = ac.valueAtXMaps.scaledX;
  const xpmMap = ac.valueAtXMaps.xpm;
  const stopXs = uniq([
    ...scaledXMap.map(([x]) => x),
    ...xpmMap.map(([x]) => x),
  ]).sort((a, b) => a - b);
  const scaledTimeMap: [number, number][] = [];
  let prevScaledX = 0;
  let prevScaledTime = leadingSeconds;
  for (const x of stopXs) {
    const scaledX = getScaledXValueFromMap(x, scaledXMap, 1);
    const xpm = getValueUpTillXFromMap(x, xpmMap);
    const scaledTime =
      prevScaledTime +
      convertAbstractXToDurationInSeconds(scaledX - prevScaledX, xpm);
    scaledTimeMap.push([x, scaledTime]);
    prevScaledX = scaledX;
    prevScaledTime = scaledTime;
  }
  if (leadingSeconds > 0) scaledTimeMap.unshift([-ac.leadingBeatsWidth, 0]);
  const finalStop = last(scaledTimeMap)!;
  if (finalStop[0] !== ac.width)
    scaledTimeMap.push([ac.width, finalStop[1]]);
  return scaledTimeMap;
};

/**
 * Maps an abstract x onto scaled x by linear interpolation within `map`.
 *
 * - Before the first stop, `x` is returned unchanged.
 * - Past the last stop (or with an empty map), the result extends from the
 *   final stop at `scalarForInputLargerThanMap` scaled units per x unit.
 */
export const getScaledXValueFromMap = (
  x: number,
  map: ScaledXMap,
  scalarForInputLargerThanMap: number
) => {
  const nextIndex = map.findIndex(([stopX]) => stopX > x);
  if (nextIndex === 0) return x;
  if (nextIndex === -1) {
    const [lastX, lastScaledX] = last(map) ?? [0, 0];
    return lastScaledX + (x - lastX) * scalarForInputLargerThanMap;
  }
  const [fromX, fromScaledX] = map[nextIndex - 1];
  const [toX, toScaledX] = map[nextIndex];
  const width = toX - fromX;
  if (width === 0) return fromScaledX;
  const ratio = (toScaledX - fromScaledX) / width;
  return fromScaledX + (x - fromX) * ratio;
};

/**
 * Inverts a scaled time map: converts time `t` (seconds) into abstract x by
 * linear interpolation between the surrounding `[x, time]` stops.
 *
 * Outside the map, `t` is extrapolated from the nearest end stop:
 * - before the first stop, using `ac.xpm` (or DefaultXpm);
 * - after the last stop, using `ac.lastBar.xpm` (or DefaultXpm).
 * Both extrapolations are anchored at that end stop, so the result joins the
 * interpolated range without a jump.
 * Assumes the stop times are ascending.
 */
export const getAbstractXFromScaledTime = (
  t: number,
  map: ScaledTimeMap,
  ac?: AtomContext
) => {
  const i = map.findIndex(m => m[1] > t);
  if (i === 0) {
    // t is earlier than the first stop: extrapolate backwards from that stop.
    // Anchoring there, instead of assuming the map starts at (x = 0, t = 0),
    // matters for maps that begin elsewhere. For example, generateScaledTimeMap
    // starts with [-leadingBeatsWidth, 0] when there are leading beats.
    const [firstX, firstTime] = map[0];
    return (
      firstX +
      convertDurationInSecondsToAbstractX(
        t - firstTime,
        ac?.xpm ?? DefaultXpm
      )
    );
  }
  if (i === -1) {
    // t is later than the last stop (or the map is empty):
    // excess time is converted using the last bar's xpm.
    const lastStop = last(map) ?? [0, 0];
    const excessTime = t - lastStop[1];
    const excessTimeInAbstractX = convertDurationInSecondsToAbstractX(
      excessTime,
      ac?.lastBar?.xpm ?? DefaultXpm
    );
    return excessTimeInAbstractX + lastStop[0];
  }
  // map[i - 1][1] <= t < map[i][1]: interpolate within this segment.
  const [fromX, fromTime] = map[i - 1];
  const [toX, toTime] = map[i];
  return fromX + ((t - fromTime) / (toTime - fromTime)) * (toX - fromX);
};
