# sketch.py
# -*- coding: utf-8 -*-
import numpy as np
from scipy import stats
import q5
def hex2color_tuple(h):
    """Convert a hexadecimal colour string into an ``(r, g, b)`` tuple.

    Parameters
    ----------
    h : str
        Colour written as ``'#RRGGBB'``.  The leading ``'#'`` is optional and
        the three-digit shorthand ``'#RGB'`` is expanded to ``'#RRGGBB'``.

    Returns
    -------
    tuple of int
        ``(r, g, b)`` with every channel in ``0..255``.

    Raises
    ------
    ValueError
        If ``h`` is not 3 or 6 hexadecimal digits (after the optional ``'#'``).
    """
    digits = h[1:] if h.startswith('#') else h
    if len(digits) == 3:
        # Shorthand form: every digit is doubled ('F0A' -> 'FF00AA').
        digits = ''.join(c * 2 for c in digits)
    # Validate explicitly: int(..., 16) alone would also accept things such as
    # '0x1234' or 'FF_512', which are not colour codes.
    if len(digits) != 6 or any(
        c not in '0123456789abcdefABCDEF' for c in digits
    ):
        raise ValueError('invalid hex colour: {!r}'.format(h))
    v = int(digits, 16)
    return (v >> 16 & 0xFF, v >> 8 & 0xFF, v & 0xFF)
# Factor applied to data coordinates before they are handed to q5 for drawing.
FIG_SCALE = 50.0

# One RGB tuple per mixture component, indexed by component number.
COLORS = [
    hex2color_tuple(code)
    for code in ('#FF5126', '#37DC94', '#162C9B')
]

# Ground-truth ("_t") mixture used to generate the observations:
# component means, covariance matrices and mixing weights.
mean_t = np.array([
    [0., 5.],
    [-5., 2.],
    [5., -2.],
])
cov_t = np.array([
    [[1., 0.], [0., 2.]],
    [[.5, .75], [.75, 2.]],
    [[.5, .1], [.1, .5]],
])
pi_t = np.array([0.4, 0.35, 0.25])

# Frozen scipy distributions, one per ground-truth component.
gmm_t = [
    stats.multivariate_normal(mu, sigma)
    for mu, sigma in zip(mean_t, cov_t)
]
def generate_observations(N=1000):
    """Draw ``N`` points from the ground-truth mixture (``pi_t``, ``gmm_t``).

    Parameters
    ----------
    N : int
        Number of observations to generate.

    Returns
    -------
    X : numpy.ndarray, shape (N, 2)
        The sampled points.
    S_idx : numpy.ndarray, shape (N,)
        Index of the component each row of ``X`` was drawn from.
    """
    n_components = len(pi_t)

    # One-hot component assignments, collapsed to integer labels.
    one_hot = stats.multinomial(1, pi_t).rvs(N)
    S_idx = np.argmax(one_hot, axis=1)

    X = np.empty((N, 2))
    for k in range(n_components):
        # Fill all rows assigned to component k with draws from that Gaussian.
        rows = np.flatnonzero(S_idx == k)
        X[rows] = gmm_t[k].rvs(len(rows))
    return X, S_idx
def plot_gaussian_params(mean, cov, n_std=2.0):
    """Draw the ``n_std``-sigma contour of a 2-D Gaussian as a q5 polygon.

    The contour is the image of a circle of radius ``n_std`` under a
    square-root factor ``A`` of the covariance (``A @ A.T == cov``), shifted
    by ``mean``.  Multiplying the circle by ``cov`` itself would instead give
    an ellipse whose semi-axes scale with the variances rather than the
    standard deviations.

    Parameters
    ----------
    mean : array_like, shape (2,)
        Centre of the Gaussian in data coordinates.
    cov : array_like, shape (2, 2)
        Symmetric covariance matrix.
    n_std : float, optional
        Number of standard deviations enclosed by the contour (default 2).
    """
    theta = np.linspace(0, 2 * np.pi, 1000)
    circle = n_std * np.column_stack([np.cos(theta), np.sin(theta)])

    # Symmetric eigendecomposition gives A = V * sqrt(w).  Clipping guards
    # against tiny negative eigenvalues caused by round-off.
    eigvals, eigvecs = np.linalg.eigh(cov)
    root = eigvecs * np.sqrt(np.clip(eigvals, 0.0, None))

    # Points are stored as rows, hence the transpose of the factor.
    pts = (mean + circle @ root.T) * FIG_SCALE
    q5.polygon(pts)
class App(q5.BaseApp):
    """Animated Gibbs sampler for a Bayesian Gaussian mixture model.

    The model puts a Dirichlet prior (``alpha0``) on the mixing weights and a
    Normal-Wishart prior (``mu0``, ``beta0``, ``nu0``, ``W0``) on each
    component's mean and precision.  ``update`` performs one Gibbs sweep and
    ``draw`` renders the current assignments and component contours.

    NOTE(review): presumably q5 calls ``setup`` once and then ``update`` /
    ``draw`` every frame -- confirm against the q5 documentation.
    """

    def setup(self):
        """Set the window title and initialise the data and the sampler."""
        q5.title('app')
        # q5.loop_ntimes(900)
        self.init_data()
        self.init_model()

    def init_data(self):
        """Generate a fresh training set of ``N`` observations."""
        self.N = 1000
        self.X_train, self.S_train = generate_observations(N=self.N)

    def init_model(self):
        """Set the hyper-parameters and draw initial parameters from the prior."""
        # Hyper Params
        self.K = 3
        self.D = 2
        self.alpha0 = np.ones(self.K)
        self.mu0 = np.zeros((self.K, self.D))
        self.beta0 = np.ones(self.K)
        self.nu0 = np.ones(self.K) * self.D
        self.W0 = np.tile(np.identity(self.D), (self.K, 1, 1))

        # Local Params
        self.pi = stats.dirichlet.rvs(self.alpha0)[0]
        self.mu = np.empty((self.K, self.D))
        self.Lambda = np.empty((self.K, self.D, self.D))
        self.Sigma = np.empty((self.K, self.D, self.D))
        for k in range(self.K):
            self.Lambda[k] = stats.wishart.rvs(
                df=self.nu0[k], scale=self.W0[k]
            )
            self.Sigma[k] = np.linalg.inv(self.Lambda[k])
            self.mu[k] = stats.multivariate_normal.rvs(
                mean=self.mu0[k], cov=self.Sigma[k] / self.beta0[k]
            )

        # Latent variables
        self.hidden_state = np.zeros(self.N, dtype=int)

        # Temporary variables
        self.p_xi = np.empty((self.N, self.K))

    def update(self):
        """Run one Gibbs sweep: assignments, then weights, then components."""
        self._resample_hidden_state()
        m = np.bincount(self.hidden_state, minlength=self.K)

        # Resampling pi
        alpha_ast = self.alpha0 + m
        self.pi[:] = stats.dirichlet.rvs(alpha_ast)[0]

        # Resampling mu_k and Sigma_k
        for k in range(self.K):
            self._resample_component(k, self.X_train[self.hidden_state == k])

    def _resample_hidden_state(self):
        """Sample every point's component from its posterior responsibility."""
        # Work in log space: for points far from every component the raw
        # densities underflow to 0 and normalising them would give NaN.
        log_p = np.empty((self.N, self.K))
        for k in range(self.K):
            log_p[:, k] = stats.multivariate_normal.logpdf(
                self.X_train, mean=self.mu[k], cov=self.Sigma[k]
            )
        with np.errstate(divide='ignore'):  # a weight of exactly 0 -> -inf
            log_p += np.log(self.pi)
        log_p -= log_p.max(axis=1, keepdims=True)
        np.exp(log_p, out=self.p_xi)
        self.p_xi /= self.p_xi.sum(axis=1, keepdims=True)

        # Inverse-CDF sampling for all points at once: the state is the number
        # of cumulative probabilities not exceeding the uniform draw.
        cdf = np.cumsum(self.p_xi, axis=1)
        u = np.random.random(self.N)
        states = (cdf <= u[:, np.newaxis]).sum(axis=1)
        # Round-off can leave cdf[:, -1] marginally below 1; clamp the index.
        self.hidden_state[:] = np.minimum(states, self.K - 1)

    def _resample_component(self, k, data_k):
        """Draw ``Lambda[k]``, ``Sigma[k]`` and ``mu[k]`` from their posterior.

        ``data_k`` holds the rows of ``X_train`` currently assigned to
        component ``k``; it may be empty, in which case the sums below are
        zero and the posterior reduces to the prior.
        """
        m_k = data_k.shape[0]
        beta0 = self.beta0[k]
        mu0 = self.mu0[k]

        beta_ast = m_k + beta0
        mu_ast = (data_k.sum(axis=0) + beta0 * mu0) / beta_ast

        # Resampling Lambda_k
        nu_ast = m_k + self.nu0[k]
        W_ast_inv = (
            data_k.T @ data_k
            + beta0 * np.outer(mu0, mu0)
            - beta_ast * np.outer(mu_ast, mu_ast)
            + np.linalg.inv(self.W0[k])
        )
        W_ast = np.linalg.inv(W_ast_inv)
        self.Lambda[k] = stats.wishart.rvs(df=nu_ast, scale=W_ast)
        self.Sigma[k] = np.linalg.inv(self.Lambda[k])

        # Resampling mu_k
        self.mu[k] = stats.multivariate_normal.rvs(
            mean=mu_ast, cov=self.Sigma[k] / beta_ast
        )

    def draw(self):
        """Render the points coloured by assignment and each component contour."""
        q5.background(220)
        q5.scale(0.75, 0.75)

        # Observations, filled with the colour of their current component.
        q5.stroke(0)
        q5.stroke_weight(1.0)
        for x_n, s_n in zip(self.X_train * FIG_SCALE, self.hidden_state):
            q5.fill(*COLORS[s_n])
            q5.circle(x_n[0], x_n[1], 8.0)

        # Component contours, outlined in the component colour.
        q5.no_fill()
        q5.stroke_weight(5.0)
        for k in range(self.K):
            q5.stroke(*COLORS[k])
            plot_gaussian_params(self.mu[k], self.Sigma[k])
        # q5.save_frame('frames/{:04d}.png'.format(q5.frame_count))

    def mouse_pressed(self):
        """Restart with newly generated data and a freshly initialised model."""
        self.init_data()
        self.init_model()
if __name__ == '__main__':
    # Script entry point: build the sketch and start it via run(), which is
    # not defined in this file (presumably inherited from q5.BaseApp).
    app = App()
    app.run()