import { RANDOM_SEED } from 'config/constants';
import * as d3 from 'd3';
import { getPixelDeviceRatio } from 'hooks/useDevice';
import TSNE from 'lib/tsne';
import { DistanceKey } from 'models/entities/Episode';
import { SeededRandom } from 'util/random';
import BaseLayout from './BaseLayout';
import Viewport from './Viewport';
//import TSNELayout from 'models/visual/TSNELayout';

/** Name under which the t-SNE force is registered on the layout's simulation. */
const FORCE_TSNE = 'tsne';

/**
 * Keys of the shared t-SNE cache: a distance key (used when the layout covers
 * all nodes) or `_last` (the most recently laid-out node subset).
 */
type TSNEStoreKeys = '_last' | DistanceKey;

/**
 * Layout that places episode nodes using t-SNE over the pairwise episode
 * distances selected by `distanceKey`. On every simulation tick each node is
 * pulled towards its rescaled t-SNE coordinate.
 *
 * Computed t-SNE instances are cached in a static store shared by all
 * instances: one entry per distance key for the full node set, plus a single
 * `_last` slot for the most recently laid-out subset.
 */
export default class EpisodeClusterLayout extends BaseLayout {
  /** Shared t-SNE cache, keyed by distance key (all nodes) or `_last` (subset). */
  static readonly tsneStore: Partial<Record<TSNEStoreKeys, TSNE>> = {};
  /** Cache key of the node subset whose t-SNE is held in `tsneStore._last`. */
  static lastNodeIds = '';

  private viewport: Viewport;
  /** Width of the square the t-SNE solution is mapped onto; set per node update. */
  private radius = 10561;
  /** Constant term of `getRadius`; depends on the distance key. */
  private baseRadius = 1210;
  /** Upper bound on t-SNE iterations for the current node set. */
  private maxSteps = 200;
  private tsne: TSNE | null = null;
  /** Options shared by every TSNE instance this layout creates. */
  private readonly config = {
    dim: 2,
  };
  private distanceKey: DistanceKey = DistanceKey.THEME;

  constructor(viewport: Viewport, distanceKey: DistanceKey) {
    super();
    this.viewport = viewport;
    this.distanceKey = distanceKey;
    this.addForceTSNE();
    this.updateBaseRadius();

    this.viewport.shouldCenterOnMinScale = true;
  }

  /** Picks the radius offset used by `getRadius` for the current distance key. */
  updateBaseRadius() {
    this.baseRadius = this.distanceKey === DistanceKey.THEME ? 1310 : 2100;
  }

  /**
   * Re-fits the viewport and selects (or builds) the t-SNE for `this.nodes`.
   *
   * @param allNodes - true when `nodes` is the full episode set. The t-SNE is
   *   then cached per distance key instead of in the single `_last` slot.
   */
  onNodesUpdate(allNodes: boolean) {
    this.resetViewport();

    // t-SNE is not used for very small sets: fall back to center gravity.
    if (this.nodes.length <= 6) {
      this.addForceCenterGravity();
      this.active = false;
      return;
    }

    this.removeForceCenterGravity();

    // Size the layout from the summed episode scales (missing/0 scale counts
    // as 1, every scale is floored at 0.8).
    const nodeScaleSum = this.nodes.reduce(
      (total: number, node) =>
        total + Math.max(0.8, node.targetNode.episode?.getScale() || 1),
      0
    );
    this.radius = this.getRadius(nodeScaleSum);
    this.maxSteps = 300 + this.nodes.length / 2;

    const store = EpisodeClusterLayout.tsneStore;
    if (allNodes) {
      // Full node set: one cached t-SNE per distance key.
      this.tsne = store[this.distanceKey] ?? this.buildTSNE();
      store[this.distanceKey] = this.tsne;
    } else {
      // Subset: reuse `_last` only if it was built for exactly these nodes.
      const nodeIds = this.getNodeIdsKey();
      const cached =
        EpisodeClusterLayout.lastNodeIds === nodeIds ? store._last : undefined;
      if (cached) {
        this.tsne = cached;
      } else {
        this.tsne = this.buildTSNE();
        store._last = this.tsne;
        EpisodeClusterLayout.lastNodeIds = nodeIds;
      }
    }

    this.active = true;
  }

  /**
   * Cache key for the current node subset: the distance key (so a layout
   * parameter change forces a rebuild) followed by every episode id. Ids are
   * JSON-encoded so that different id sequences cannot produce the same key.
   */
  private getNodeIdsKey(): string {
    const ids = this.nodes.map((node) => String(node.targetNode.episode?.id));
    return `${this.distanceKey}_${JSON.stringify(ids)}`;
  }

  /**
   * Creates a t-SNE instance for the current nodes and initialises it with
   * their pairwise distances (rows and columns ordered like `this.nodes`).
   */
  private buildTSNE(): TSNE {
    // Perplexity grows with the node count, clamped to [3, 160].
    const perplexity = Math.min(Math.max(this.nodes.length * 0.7, 3), 160);

    const tsne = new TSNE({
      ...this.config,
      // A fresh seeded randomizer per instance makes a node set's layout
      // independent of which layouts were computed before (presumably TSNE
      // draws all its randomness from `randomizer` — confirm in lib/tsne).
      randomizer: new SeededRandom(RANDOM_SEED),
      perplexity,
    });

    // Column of each node inside an episode's distance row. `??` (not `||`)
    // so that a valid storeIndex of 0 is not replaced by the list position.
    const storeIndices = this.nodes.map(
      (node, index) => node.targetNode.episode?.storeIndex ?? index
    );

    // Restrict each node's distance row to the columns of the current nodes.
    const distances = this.nodes.map((node) => {
      const row = node.targetNode.episode?.distances[this.distanceKey] || [];
      return row.length ? storeIndices.map((index) => row[index]) : row;
    });

    tsne.initDataDist(distances);
    return tsne;
  }

  /**
   * Resets the viewport zoom for the current node count and sets `minScale`
   * to 80% of the resulting target scale (at least 0.2 on narrow windows).
   */
  resetViewport() {
    const n = this.nodes.length;
    const pixelRatio = getPixelDeviceRatio();
    // Quadratic in the node count (presumably empirically tuned).
    const nodeFactor = 0.000005 * n * n - 0.0038 * n + 0.8;
    const viewportSize = Math.max(
      400 * pixelRatio,
      Math.min(this.viewport.getHeight(), this.viewport.getWidth())
    );

    this.viewport.reset(
      Math.min(
        this.viewport.maxScale,
        (nodeFactor * viewportSize) / (2400 * pixelRatio)
      )
    );

    this.viewport.minScale = Math.max(
      window.innerWidth < 500 ? 0.2 : 0.001,
      this.viewport.getTargetScale() * 0.8
    );
  }

  /**
   * Layout radius as a quadratic of the summed node scale `x`.
   * NOTE(review): this peaks at x ≈ 712.5 and falls back to `baseRadius` at
   * x ≈ 1425 (negative shortly after); confirm inputs stay below the peak.
   */
  getRadius(x: number) {
    return -0.02393 * (x * x) + 34.1 * x + this.baseRadius;
  }

  /**
   * Registers the t-SNE force. On each tick it advances t-SNE by one
   * iteration (until `maxSteps`), maps the solution into a square of width
   * `radius` centred on the origin, and moves each node towards its target
   * by a strength that blends the simulation alpha with a step-based decay.
   */
  addForceTSNE() {
    const scale = d3.scaleLinear();

    this.simulation.force(FORCE_TSNE, (alpha) => {
      const tsne = this.tsne;
      if (!this.active || !tsne) {
        return;
      }

      // Limit the number of t-SNE iterations.
      if (tsne.steps < this.maxSteps) {
        tsne.step();
      }

      const stepDecay = Math.max(
        0,
        this.baseAlpha - this.baseAlpha * (tsne.steps / this.maxSteps)
      );
      const strength = 0.5 * alpha + 0.5 * stepDecay;

      const pos = tsne.getSolution();
      const [xMin, xMax] = d3.extent(pos, (d) => d[0]);
      const [yMin, yMax] = d3.extent(pos, (d) => d[1]);
      if (
        xMin === undefined ||
        xMax === undefined ||
        yMin === undefined ||
        yMax === undefined
      ) {
        return; // empty solution: nothing to map
      }

      // One domain spanning both axes: every point fits in the range and
      // the aspect ratio of the t-SNE embedding is preserved.
      scale
        .domain([Math.min(xMin, yMin), Math.max(xMax, yMax)])
        .range([-this.radius / 2, this.radius / 2]);

      this.nodes.forEach((node, i) => {
        const target = pos[i];
        if (!target) {
          return; // guard against a solution shorter than `nodes`
        }
        node.x += strength * (scale(target[0]) - node.x);
        node.y += strength * (scale(target[1]) - node.y);
      });
    });
  }
}
