// tSNE class
// Based on: https://github.com/flynnwang/tsnejs/tree/opt which is forked from https://github.com/karpathy/tsnejs

import { MathRandom, Randomizer } from 'util/random';

/** Optional settings accepted by the TSNE constructor. */
interface Options {
  /** Learning rate (15 when omitted). */
  epsilon?: number;
  /** Effective number of nearest neighbours (30 when omitted). */
  perplexity?: number;
  /** Embedding dimensionality (2 when omitted). */
  dim?: number;
  /** Random number source (a new MathRandom when omitted). */
  randomizer?: Randomizer;
}

// Flat numeric vector: a Float64Array when typed arrays exist, a plain array otherwise.
type FloatArray = Float64Array | number[];
// Dense matrix stored as one inner array per row.
type Matrix = Array<number[]>;

/**
 * t-SNE (t-distributed Stochastic Neighbor Embedding) optimiser.
 *
 * Usage: construct, load data with `initDataRaw` (feature vectors) or
 * `initDataDist` (pairwise distances), then call `step()` repeatedly and read
 * the embedding with `getSolution()`.
 */
export default class TSNE {
  randomizer: Randomizer;
  perplexity: number;
  epsilon: number;
  /** Number of columns of the embedding. */
  dim: number;
  /** Optimisation steps taken since the last `initSolution()`. */
  steps = 0;
  private N = 0; // number of points in the loaded dataset
  private P: FloatArray = []; // packed symmetric joint probabilities (layout from idx)
  private Y: Matrix = []; // current embedding, N rows x dim columns
  private gains: Matrix = []; // per-coordinate adaptive step gains
  private ystep: Matrix = []; // previous update, reused as momentum
  private return_v = false; // true when v_val holds an unused polar-method sample
  private v_val = 0.0;

  /**
   * @param options perplexity (default 30), epsilon / learning rate
   *   (default 15), dim (default 2) and randomizer (default: a new
   *   MathRandom). A value of 0 for a numeric option also falls back to its
   *   default.
   */
  constructor(options?: Options) {
    const opts: Options = options || {};
    this.perplexity = opts.perplexity || 30; // effective number of nearest neighbors
    this.epsilon = opts.epsilon || 15; // learning rate
    // Options declares `dim`; honour it instead of always embedding in 2-D.
    this.dim = opts.dim || 2;
    this.randomizer = opts.randomizer || new MathRandom();
  }

  /**
   * Build P from raw high-dimensional points and reset the solution.
   * @param X one row per point; rows are assumed to share X[0]'s length.
   * @throws if X is empty or X[0] has no columns.
   */
  initDataRaw(X: Matrix) {
    this.resetRandom();
    const N = X.length;
    assert(N > 0, ' X is empty? You must have some data!');
    // Only read X[0] once we know it exists; otherwise an empty X would
    // fail with a TypeError instead of the assertion above.
    const D = X[0].length;
    assert(D > 0, ' X[0] is empty? Where is the data?');
    const dists = xtod(X); // packed pairwise squared Euclidean distances
    this.N = N; // back up the size of the dataset
    this.P = d2p(dists, this.perplexity, 1e-4, N);
    this.initSolution();
  }

  /**
   * Build P from a precomputed N x N distance matrix and reset the solution.
   * Only the strict upper triangle (j > i) of D is read, so D is assumed to
   * be symmetric.
   * NOTE(review): d2p uses these values as exp(-d * beta), the same way
   * initDataRaw uses squared distances — confirm callers pass squared values.
   * @throws if D is empty.
   */
  initDataDist(D: Matrix) {
    this.resetRandom();
    const N = D.length;
    assert(N > 0, ' X is empty? You must have some data!');
    const dists = zeros((N * (N + 1)) / 2); // packed upper triangle
    for (let i = 0; i < N; i++) {
      const row = D[i];
      for (let j = i + 1; j < N; j++) {
        dists[idx(i, j, N)] = row[j];
      }
    }
    this.N = N;
    this.P = d2p(dists, this.perplexity, 1e-4, N);
    this.initSolution();
  }

  /**
   * Call randomizer.reset() (presumably rewinding it to its seed — confirm
   * in util/random) and drop any cached Gaussian sample, so that the next
   * initialisation does not consume a leftover sample from the old stream.
   */
  private resetRandom() {
    this.randomizer.reset();
    this.return_v = false;
  }

  /**
   * (Re)initialise the solution: Y gets Gaussian noise with std 1e-4, every
   * gain starts at 1, the momentum buffer is zeroed and `steps` is reset.
   */
  initSolution() {
    this.Y = this.randn2d(this.N, this.dim);
    this.gains = this.randn2d(this.N, this.dim, 1.0);
    this.ystep = this.randn2d(this.N, this.dim, 0.0);
    this.steps = 0;
  }

  /** Return the live embedding matrix (not a copy); `step()` mutates it. */
  getSolution(): Matrix {
    return this.Y;
  }

  /**
   * Perform one gradient-descent step with momentum and per-coordinate
   * adaptive gains, then re-centre Y to zero mean.
   * @returns the cost evaluated before the update.
   */
  step(): number {
    this.steps += 1;
    const N = this.N;
    const dim = this.dim;

    const { cost, grad } = this.costGrad(this.Y);

    // Momentum is constant within a step: 0.5 for steps 1-249, 0.8 afterwards.
    const momval = this.steps < 250 ? 0.5 : 0.8;

    const ymean = zeros(dim);
    for (let i = 0; i < N; i++) {
      const gi = grad[i];
      const si = this.ystep[i];
      const gainsi = this.gains[i];
      const yi = this.Y[i];
      for (let d = 0; d < dim; d++) {
        const gid = gi[d];
        const sid = si[d];

        // Shrink the gain when the gradient has the same sign as the previous
        // step, grow it otherwise; never let it drop below 0.01.
        let newgain = sign(gid) === sign(sid) ? gainsi[d] * 0.8 : gainsi[d] + 0.2;
        if (newgain < 0.01) newgain = 0.01;
        gainsi[d] = newgain;

        const newsid = momval * sid - this.epsilon * newgain * gid;
        si[d] = newsid; // remembered as momentum for the next step
        yi[d] += newsid;

        ymean[d] += yi[d]; // accumulate for re-centring below
      }
    }

    // Re-project Y to zero mean.
    for (let i = 0; i < N; i++) {
      const yi = this.Y[i];
      for (let d = 0; d < dim; d++) {
        yi[d] -= ymean[d] / N;
      }
    }

    return cost;
  }

  /**
   * Debugging aid: log the analytic gradient next to a central-difference
   * numerical estimate for every coordinate of Y. Y is restored afterwards.
   */
  debugGrad() {
    const N = this.N;
    const { grad } = this.costGrad(this.Y);

    const e = 1e-5;
    for (let i = 0; i < N; i++) {
      for (let d = 0; d < this.dim; d++) {
        const yold = this.Y[i][d];

        this.Y[i][d] = yold + e;
        const costPlus = this.costGrad(this.Y).cost;

        this.Y[i][d] = yold - e;
        const costMinus = this.costGrad(this.Y).cost;

        this.Y[i][d] = yold; // restore

        const analytic = grad[i][d];
        const numerical = (costPlus - costMinus) / (2 * e);
        console.debug(`${i},${d}: gradcheck analytic: ${analytic} vs. numerical: ${numerical}`);
      }
    }
  }

  /**
   * Cost (the KL divergence minus its constant sum P log P term) and its
   * gradient with respect to the embedding.
   * @param Y embedding with this.N rows of this.dim columns.
   * @returns the cost and an N x dim gradient matrix.
   */
  costGrad(Y: Matrix): { cost: number; grad: Matrix } {
    const N = this.N;
    const dim = this.dim;
    const P = this.P;

    // NOTE(review): both branches are 1, so this multiplier currently has no
    // effect; it looks like a disabled early-exaggeration trick — confirm intent.
    const pmul = this.steps < 100 ? 1 : 1;

    // Unnormalised Student-t affinities Qu_ij = 1 / (1 + |y_i - y_j|^2), packed.
    const Qu = zeros((N * (N + 1)) / 2);
    let qsum = 0.0;
    for (let i = 0; i < N; i++) {
      const yi = Y[i];
      for (let j = i + 1; j < N; j++) {
        const yj = Y[j];
        let dsum = 0.0;
        for (let d = 0; d < dim; d++) {
          const diff = yi[d] - yj[d];
          dsum += diff * diff;
        }
        const qu = 1.0 / (1.0 + dsum);
        Qu[idx(i, j, N)] = qu;
        qsum += 2 * qu; // counts both (i, j) and (j, i)
      }
    }

    let cost = 0.0;
    const grad: Matrix = [];
    for (let i = 0; i < N; i++) {
      const yi = Y[i];
      const gsum = new Array<number>(dim).fill(0);
      for (let j = 0; j < N; j++) {
        // Diagonal terms have P_ii = Qu_ii = 0 and contribute nothing; skipping
        // them also avoids 0 / 0 = NaN when N === 1 (qsum is then 0).
        if (j === i) continue;
        const k = idx(i, j, N);
        const qu = Qu[k];
        const q = Math.max(qu / qsum, 1e-100);
        cost += -P[k] * Math.log(q);
        const premult = 4 * (pmul * P[k] - q) * qu;
        const yj = Y[j];
        for (let d = 0; d < dim; d++) {
          gsum[d] += premult * (yi[d] - yj[d]);
        }
      }
      grad.push(gsum);
    }

    return { cost, grad };
  }

  /**
   * Standard normal sample via the polar method: uniform points in the unit
   * disc are drawn until one is accepted. Each accepted point yields two
   * samples; the second is cached and returned by the next call.
   */
  gaussRandom(): number {
    if (this.return_v) {
      this.return_v = false;
      return this.v_val;
    }
    let u = 0;
    let v = 0;
    let r = 0;
    do {
      u = 2 * this.randomizer.get() - 1;
      v = 2 * this.randomizer.get() - 1;
      r = u * u + v * v;
    } while (r === 0 || r > 1);
    const c = Math.sqrt((-2 * Math.log(r)) / r);
    this.v_val = v * c;
    this.return_v = true;
    return u * c;
  }

  /** Normal sample with mean `mu` and standard deviation `std`. */
  randn(mu: number, std: number): number {
    return mu + this.gaussRandom() * std;
  }

  /**
   * n x d matrix filled with the constant `s` when given, otherwise with
   * normal samples of mean 0 and std 1e-4 (drawn row by row).
   */
  randn2d(n: number, d: number, s?: number): Matrix {
    const out: Matrix = [];
    for (let i = 0; i < n; i++) {
      const row: number[] = [];
      for (let j = 0; j < d; j++) {
        row.push(s === undefined ? this.randn(0.0, 1e-4) : s);
      }
      out.push(row);
    }
    return out;
  }
}

// Throw an Error carrying `message` (or 'Assertion failed' when the message
// is missing or empty) if `condition` is false. The explicit `asserts`
// signature lets TypeScript narrow on the condition after the call.
const assert: (condition: boolean, message?: string) => asserts condition = (
  condition,
  message
) => {
  if (!condition) {
    // Throw a real Error rather than a bare string, so callers get a stack
    // trace and `instanceof Error` / `.message` work as expected.
    throw new Error(message || 'Assertion failed');
  }
};

// Map entry (i, j) of a symmetric n x n matrix to its slot in a packed,
// row-major upper triangle that includes the diagonal (length n * (n + 1) / 2).
// Row r begins at offset r * (2n - r + 1) / 2. Lower-triangle requests
// (j < i) are mirrored onto the upper triangle.
const idx = (i: number, j: number, n: number): number =>
  j >= i ? (i * (2 * n - i + 1)) / 2 + (j - i) : idx(j, i, n);

// Allocate a zero-filled numeric vector of length n: a Float64Array whenever
// the runtime has typed arrays, otherwise a plain array filled element by
// element.
const zeros = (n: number): FloatArray => {
  if (typeof ArrayBuffer !== 'undefined') {
    return new Float64Array(n); // contiguous, already zeroed
  }
  // Fallback for runtimes without typed-array support.
  const out: number[] = new Array(n);
  for (let k = 0; k < n; k += 1) {
    out[k] = 0;
  }
  return out;
};

// Squared Euclidean distance between two vectors: the sum over k of
// (x1[k] - x2[k])^2, iterating over x1's length. Despite the name, no square
// root is taken.
const L2 = (x1: number[], x2: number[]) => {
  let total = 0;
  for (let k = 0, len = x1.length; k < len; k++) {
    const diff = x1[k] - x2[k];
    total += diff * diff;
  }
  return total;
};

// Pairwise squared Euclidean distances (via L2) between the rows of X, stored
// in the packed upper-triangle layout defined by idx. Only pairs with
// row < col are written; the diagonal stays 0.
const xtod = (X: Matrix): FloatArray => {
  const n = X.length;
  const packed = zeros((n * (n + 1)) / 2);
  for (let row = 0; row < n; row++) {
    const a = X[row];
    for (let col = row + 1; col < n; col++) {
      packed[idx(row, col, n)] = L2(a, X[col]);
    }
  }
  return packed;
};

// Convert packed pairwise distances into the symmetric joint-probability
// matrix used by t-SNE: P_ij = (p_{j|i} + p_{i|j}) / (2N).
//
// For each point i, a Gaussian precision beta is found by bisection so that
// the entropy (in nats) of the conditional row p_{.|i} is within `tol` of
// ln(perplexity). At most 50 bisection passes are made per row.
//
//   D          packed distances in the upper-triangle layout of idx; each
//              enters the kernel as exp(-D_ij * beta)
//   perplexity target perplexity
//   tol        entropy tolerance for the bisection
//   N          number of points
//
// Returns P in the same packed layout. Off-diagonal entries are floored at
// 1e-100 and diagonal entries are 0.
const d2p = (D: FloatArray, perplexity: number, tol: number, N: number) => {
  const Htarget = Math.log(perplexity); // target entropy of each conditional row
  const maxtries = 50;
  const P = zeros((N * (N + 1)) / 2); // result, accumulated row by row

  const drow = zeros(N); // distances from point i, unpacked once per row
  const prow = zeros(N); // conditional distribution p_{.|i} for the current beta
  for (let i = 0; i < N; i++) {
    // The bisection below reads row i up to maxtries times; unpack it once
    // instead of recomputing the packed index on every pass.
    for (let j = 0; j < N; j++) {
      drow[j] = D[idx(i, j, N)];
    }

    let betamin = -Infinity;
    let betamax = Infinity;
    let beta = 1; // initial precision
    let num = 0;
    let done = false;
    while (!done) {
      // Gaussian kernel row at precision beta; the diagonal is excluded.
      let psum = 0.0;
      for (let j = 0; j < N; j++) {
        const pj = i === j ? 0 : Math.exp(-drow[j] * beta);
        prow[j] = pj;
        psum += pj;
      }

      // Normalise the row and measure its entropy (terms <= 1e-7 are skipped).
      let Hhere = 0.0;
      for (let j = 0; j < N; j++) {
        const pj = psum === 0 ? 0 : prow[j] / psum;
        prow[j] = pj;
        if (pj > 1e-7) Hhere -= pj * Math.log(pj);
      }

      // Bisect on beta. Entropy above target means the row is too diffuse,
      // so raise the precision; otherwise lower it to make it less peaky.
      // (prow keeps the distribution for the beta used above, not the new one.)
      if (Hhere > Htarget) {
        betamin = beta;
        beta = betamax === Infinity ? beta * 2 : (beta + betamax) / 2;
      } else {
        betamax = beta;
        beta = betamin === -Infinity ? beta / 2 : (beta + betamin) / 2;
      }

      // Stop when the entropy is close enough or the try budget is used up.
      num++;
      done = Math.abs(Hhere - Htarget) < tol || num >= maxtries;
    }

    // Row i adds p_{j|i} to packed cell (i, j); row j adds p_{i|j} to the same
    // cell, which symmetrises P.
    for (let j = 0; j < N; j++) {
      P[idx(i, j, N)] += prow[j];
    }
  }

  // Divide the symmetrised sums by 2N, with a floor of 1e-100.
  const N2 = N * 2;
  for (let i = 0; i < N; i++) {
    for (let j = i + 1; j < N; j++) {
      const k = idx(i, j, N);
      P[k] = Math.max(P[k] / N2, 1e-100);
    }
  }
  return P;
};

// Sign of x as 1, -1 or 0. Unlike Math.sign, both -0 and NaN map to 0.
const sign = (x: number) => {
  if (x > 0) return 1;
  if (x < 0) return -1;
  return 0;
};
