sketch_220421a

sketch.py

# -*- coding: utf-8 -*-
import numpy as np
from scipy import stats
import q5


def hex2color_tuple(h):
    """Convert a ``'#RRGGBB'`` string into an ``(r, g, b)`` tuple of ints.

    The first character (normally ``'#'``) is skipped unconditionally; the
    rest is parsed as hexadecimal and the three low-order bytes are returned
    as red, green and blue.
    """
    packed = int(h[1:], 16)
    return tuple((packed >> shift) & 0xFF for shift in (16, 8, 0))


# Pixels per data unit used when drawing observations and ellipses.
FIG_SCALE = 50.0
# One RGB colour per mixture component, indexed by component id.
COLORS = [hex2color_tuple(code) for code in ('#FF5126', '#37DC94', '#162C9B')]

# Ground-truth mixture used to synthesise observations:
# component means (K, 2), covariances (K, 2, 2) and mixing weights (K,).
mean_t = np.array([
    [0.0, 5.0],
    [-5.0, 2.0],
    [5.0, -2.0],
])
cov_t = np.array([
    [[1.0, 0.0], [0.0, 2.0]],
    [[0.5, 0.75], [0.75, 2.0]],
    [[0.5, 0.1], [0.1, 0.5]],
])
pi_t = np.array([0.4, 0.35, 0.25])

# Frozen scipy Gaussians, one per ground-truth component.
gmm_t = [
    stats.multivariate_normal(mean=mu_k, cov=cov_k)
    for mu_k, cov_k in zip(mean_t, cov_t)
]


def generate_observations(N=1000):
    """Draw ``N`` samples from the ground-truth mixture ``(pi_t, gmm_t)``.

    Returns:
        X: ``(N, 2)`` float array of observations.
        S_idx: ``(N,)`` int array giving the component that produced each row.
    """
    n_components = len(pi_t)

    # One-hot draws from Categorical(pi_t); argmax recovers the component id.
    one_hot = stats.multinomial(1, pi_t).rvs(N)
    labels = one_hot.argmax(axis=1)

    observations = np.empty((N, 2))
    for comp in range(n_components):
        (rows,) = np.nonzero(labels == comp)
        observations[rows] = gmm_t[comp].rvs(len(rows))

    return observations, labels


def plot_gaussian_params(mean, cov, n_std=2.0, n_points=1000):
    """Render the ``n_std``-sigma contour of a 2-D Gaussian via ``q5.polygon``.

    The contour is the set ``{x : (x - mean)^T cov^{-1} (x - mean) = n_std**2}``.
    It is obtained by mapping the unit circle through a square-root factor
    ``A`` of the covariance (``A @ A.T == cov``), then scaling by
    ``FIG_SCALE`` into drawing units.

    Args:
        mean: length-2 centre of the Gaussian, in data units.
        cov: 2x2 symmetric positive semi-definite covariance, in data units.
        n_std: Mahalanobis radius of the contour (default 2.0).
        n_points: number of polygon vertices (default 1000).
    """
    t = np.linspace(0, 2 * np.pi, n_points)
    unit_circle = np.column_stack([np.cos(t), np.sin(t)])

    # Multiplying by ``cov`` itself would scale the axes by variances rather
    # than standard deviations; use a matrix square root instead. ``eigh``
    # (rather than Cholesky) tolerates nearly-singular covariances, and tiny
    # negative eigenvalues from rounding are clipped to zero.
    eigvals, eigvecs = np.linalg.eigh(cov)
    sqrt_cov = eigvecs * np.sqrt(np.clip(eigvals, 0.0, None))

    # Row-vector form of x = mean + n_std * A u.
    pts = np.asarray(mean) + n_std * unit_circle @ sqrt_cov.T

    q5.polygon(pts * FIG_SCALE)


class App(q5.BaseApp):
    """Interactive Gibbs sampler for a Bayesian Gaussian mixture model.

    Priors: Dirichlet(alpha0) on the mixing weights and, per component,
    Lambda_k ~ Wishart(nu0, W0), mu_k | Lambda_k ~ N(mu0, (beta0 Lambda_k)^-1).
    ``update`` performs one Gibbs sweep (presumably invoked once per frame by
    ``q5.BaseApp`` — confirm against q5). ``draw`` colours each observation by
    its current assignment and overlays each component's ellipse via
    ``plot_gaussian_params``. A mouse press regenerates data and model.
    """

    def setup(self):
        q5.title('app')
        # q5.loop_ntimes(900)

        self.init_data()
        self.init_model()

    def init_data(self):
        """Generate a fresh synthetic training set of ``self.N`` points."""
        self.N = 1000
        self.X_train, self.S_train = generate_observations(N=self.N)

    def init_model(self):
        """Set prior hyper-parameters and draw initial parameters from them.

        Must run after ``init_data`` because it sizes buffers with ``self.N``.
        """
        # Hyper-parameters.
        self.K = 3  # number of mixture components
        self.D = 2  # observation dimensionality

        self.alpha0 = np.ones(self.K)            # Dirichlet concentration
        self.mu0 = np.zeros((self.K, self.D))    # prior mean of each mu_k
        self.beta0 = np.ones(self.K)             # prior precision scale of mu_k
        self.nu0 = np.ones(self.K) * self.D      # Wishart degrees of freedom
        self.W0 = np.tile(np.identity(self.D), (self.K, 1, 1))  # Wishart scale

        # Parameters, initialised by sampling from the priors.
        self.pi = stats.dirichlet.rvs(self.alpha0)[0]
        self.mu = np.empty((self.K, self.D))
        self.Lambda = np.empty((self.K, self.D, self.D))
        self.Sigma = np.empty((self.K, self.D, self.D))

        for k in range(self.K):
            self.Lambda[k] = stats.wishart.rvs(
                df=self.nu0[k], scale=self.W0[k]
            )
            self.Sigma[k] = np.linalg.inv(self.Lambda[k])
            self.mu[k] = stats.multivariate_normal.rvs(
                mean=self.mu0[k], cov=self.Sigma[k] / self.beta0[k]
            )

        # Latent component assignment of every observation.
        self.hidden_state = np.zeros(self.N, dtype=int)

        # Scratch buffer: normalised posterior assignment probabilities (N, K).
        self.p_xi = np.empty((self.N, self.K))

    def update(self):
        """Run one Gibbs sweep: z | params, pi | z, then (mu_k, Lambda_k) | z."""
        self._sample_hidden_states()
        counts = np.bincount(self.hidden_state, minlength=self.K)
        self._sample_mixing_weights(counts)
        for k in range(self.K):
            self._sample_component(k, counts[k])

    def _sample_hidden_states(self):
        """Resample every latent assignment from its posterior categorical."""
        # Work in log space: for points far from every component the raw
        # densities underflow to 0, and p / p.sum() would become NaN.
        log_p = np.empty((self.N, self.K))
        for k in range(self.K):
            log_p[:, k] = stats.multivariate_normal.logpdf(
                self.X_train, mean=self.mu[k], cov=self.Sigma[k]
            )
        log_p += np.log(self.pi)
        log_p -= log_p.max(axis=1, keepdims=True)
        resp = np.exp(log_p)
        resp /= resp.sum(axis=1, keepdims=True)
        self.p_xi[:] = resp

        # Vectorised inverse-CDF sampling: one categorical draw per row.
        cdf = np.cumsum(resp, axis=1)
        cdf[:, -1] = 1.0  # guard against rounding leaving the last entry < 1
        u = np.random.random_sample((self.N, 1))
        self.hidden_state[:] = (cdf <= u).sum(axis=1)

    def _sample_mixing_weights(self, counts):
        """Resample pi from its posterior Dirichlet(alpha0 + counts)."""
        self.alpha_ast = self.alpha0 + counts
        self.pi[:] = stats.dirichlet.rvs(self.alpha_ast)[0]

    def _sample_component(self, k, m_k):
        """Resample (mu_k, Lambda_k) from their Gaussian-Wishart posterior."""
        # Plain ndarray (np.matrix is deprecated). An empty cluster gives a
        # (0, D) array whose sum and scatter are exactly zero, so no special
        # case is needed.
        data_k = self.X_train[self.hidden_state == k]

        self.beta_ast = self.beta0[k] + m_k
        self.mu_ast = (
            data_k.sum(axis=0) + self.beta0[k] * self.mu0[k]
        ) / self.beta_ast

        # W*^{-1} = W0^{-1} + sum x x^T + beta0 mu0 mu0^T - beta* mu* mu*^T
        scatter = data_k.T @ data_k
        self.nu_ast = self.nu0[k] + m_k
        W_ast_inv = (
            np.linalg.inv(self.W0[k])
            + scatter
            + self.beta0[k] * np.outer(self.mu0[k], self.mu0[k])
            - self.beta_ast * np.outer(self.mu_ast, self.mu_ast)
        )
        W_ast = np.linalg.inv(W_ast_inv)
        self.W_ast = 0.5 * (W_ast + W_ast.T)  # remove rounding asymmetry

        self.Lambda[k] = stats.wishart.rvs(df=self.nu_ast, scale=self.W_ast)
        self.Sigma[k] = np.linalg.inv(self.Lambda[k])

        # mu_k | Lambda_k ~ N(mu*, (beta* Lambda_k)^-1)
        self.mu[k] = stats.multivariate_normal.rvs(
            mean=self.mu_ast, cov=self.Sigma[k] / self.beta_ast
        )

    def draw(self):
        """Draw observations coloured by assignment, then component ellipses."""
        q5.background(220)
        q5.scale(0.75, 0.75)

        q5.stroke(0)
        q5.stroke_weight(1.0)
        for x_n, z_n in zip(self.X_train * FIG_SCALE, self.hidden_state):
            q5.fill(*COLORS[z_n])
            q5.circle(x_n[0], x_n[1], 8.0)

        q5.no_fill()
        q5.stroke_weight(5.0)
        for k in range(self.K):
            q5.stroke(*COLORS[k])
            plot_gaussian_params(self.mu[k], self.Sigma[k])

        # q5.save_frame('frames/{:04d}.png'.format(q5.frame_count))

    def mouse_pressed(self):
        """Regenerate the data set and restart the sampler from the prior."""
        self.init_data()
        self.init_model()


if __name__ == "__main__":
    # Instantiate the sketch and hand control to q5's run loop.
    app = App()
    app.run()