# sketch.py
# -*- coding: utf-8 -*-
import q5
import torch
import numpy as np
# RGB fill colours; App.setup assigns each circle a random index into this list.
COLORS = [
    (224, 221, 0),
    (112, 22, 224),
    (224, 72, 11),
    (22, 224, 192),
]
class App(q5.BaseApp):
    """Circle-packing sketch driven by gradient descent on circle centers.

    Each circle i has a fixed radius r_i and a trainable 2-D center c_i.
    Every ``step`` evaluates, for all ordered pairs i != j with
    d = |c_i - c_j| and R = r_i + r_j, the term

        exp(-d**2 / (R * rbf_w)) * (d - R * r_margin)**2

    sums those terms and moves the centers one plain gradient-descent step
    of size ``eta``. The squared factor vanishes at d = R * r_margin; the
    exponential factor shrinks as a pair moves apart, so distant pairs
    contribute little.
    """

    def setup(self):
        """Configure the sketch parameters and sample the initial circles."""
        q5.title('sketch_220516a')
        # Recording toggle (disabled):
        # q5.loop_ntimes(60 * 10)
        self.N = 500               # number of circles
        self.eta = 5e-2            # gradient-descent step size
        self.rbf_w = 10.0          # width factor of the exp(...) pair weight
        self.r_margin = 1.05       # target pair distance as a multiple of r_i + r_j
        # NOTE(review): reg_weight is not read by step(); the regulariser it
        # belonged to is disabled. Kept so external readers of it still work.
        self.reg_weight = 10.0
        self.steps_per_frame = 5   # optimiser steps performed by each update()
        self._reset_circles()
        # One random COLORS index per circle; not re-sampled on reset.
        self.color_idxs = np.random.randint(0, len(COLORS), self.N)

    def _reset_circles(self):
        """Re-sample centers and radii and mark the centers as trainable."""
        self.centers, self.radii = self.init_circles(self.N)
        self.centers.requires_grad_(True)

    def init_circles(self, N):
        """Return ``(centers, radii)`` for ``N`` randomly placed circles.

        centers: float tensor of shape (N, 2); x is uniform in
            [-width/2, width/2) and y in [-height/2, height/2) of the q5
            canvas.
        radii: float tensor of shape (N,), uniform in [8, 136).
        """
        x = torch.rand(N) * q5.width - q5.width / 2.0
        y = torch.rand(N) * q5.height - q5.height / 2.0
        r = torch.rand(N) * 128.0 + 8.0
        # stack(dim=1) yields a contiguous (N, 2) tensor, unlike vstack(...).T.
        return torch.stack([x, y], dim=1), r

    def step(self):
        """Perform one gradient-descent step on ``self.centers``."""
        n = self.centers.shape[0]
        # Radii are constants of the optimisation, so no graph is needed here.
        with torch.no_grad():
            r_mat = self.radii + self.radii[:, None]
            # Self-pairs (i == j) only add a constant to the loss, and their
            # zero distance makes the cdist backward numerically noisy, so
            # they are excluded from the sum.
            off_diag = ~torch.eye(n, dtype=torch.bool,
                                  device=self.centers.device)
        # Distances must stay in the autograd graph for backward() to work.
        dist_mat = torch.cdist(self.centers, self.centers, p=2)
        pair_loss = (
            torch.exp(-dist_mat ** 2 / (r_mat * self.rbf_w))
            * (dist_mat - r_mat * self.r_margin) ** 2
        )
        loss = pair_loss[off_diag].sum()
        loss.backward()
        with torch.no_grad():
            # Rebinding creates a fresh tensor, so gradients never
            # accumulate across steps.
            self.centers = self.centers - self.eta * self.centers.grad
        self.centers.requires_grad = True

    def update(self):
        """Run ``steps_per_frame`` optimiser steps (presumably once per frame)."""
        for _ in range(self.steps_per_frame):
            self.step()

    def draw(self):
        """Clear the canvas and draw every circle in its assigned colour."""
        q5.background(64)
        centers = self.centers.detach().numpy()
        radii = self.radii.detach().numpy()
        q5.stroke_weight(5.0)
        for (x, y), r, color_idx in zip(centers, radii, self.color_idxs):
            q5.fill(*COLORS[color_idx])
            q5.circle(x, y, r)
        # Recording toggle (disabled):
        # q5.save_frame('frames/{:05d}.png'.format(q5.frame_count))

    def mouse_pressed(self):
        """Left click re-samples the circles; right click saves the frame."""
        if q5.mouse_button == q5.MOUSE_LEFT:
            self._reset_circles()
        elif q5.mouse_button == q5.MOUSE_RIGHT:
            q5.save_frame('frames/{:05d}.png'.format(q5.frame_count))
if __name__ == '__main__':
    # Entry point: build the sketch and hand control to q5's run loop.
    app = App()
    app.run()