sketch_220516a

sketch.py

# -*- coding: utf-8 -*-
import q5
import torch
import numpy as np

# Fill palette (RGB triples). Each circle gets one entry, chosen at random
# when the sketch is set up.
COLORS = [
    (224, 221, 0),
    (112, 22, 224),
    (224, 72, 11),
    (22, 224, 192),
]


class App(q5.BaseApp):
    """Circle-packing sketch: N random circles are relaxed by gradient descent.

    Only the circle centres are optimised; the radii stay fixed. For every
    pair (i, j) of distinct circles the loss penalises the squared deviation
    of their distance from ``r_margin * (r_i + r_j)``. That penalty is
    weighted by ``exp(-d**2 / ((r_i + r_j) * rbf_w))``, so close pairs
    dominate.

    Left click re-randomises the layout; right click saves the current frame.
    """

    def setup(self):
        """Set the title and hyper-parameters and create the first layout."""
        q5.title('sketch_220516a')
        # q5.loop_ntimes(60 * 10)

        self.N = 500                # number of circles
        self.eta = 5*1e-2           # gradient-descent step size
        self.rbf_w = 10.0           # width factor of the distance weighting
        self.r_margin = 1.05        # target pair distance = r_margin * (r_i + r_j)
        self.reg_weight = 10.0      # only used by the disabled term in step()
        self.steps_per_frame = 5    # optimisation steps per update() call

        self._reset_circles()
        self.color_idxs = np.random.randint(0, len(COLORS), self.N)

    def _reset_circles(self):
        """Replace centres/radii with a fresh random layout; centres are trainable."""
        self.centers, self.radii = self.init_circles(self.N)
        self.centers.requires_grad_(True)

    def init_circles(self, N):
        """Return ``(centers, radii)`` for ``N`` randomly placed circles.

        ``centers`` is an ``(N, 2)`` tensor with x uniform in
        ``[-width/2, width/2)`` and y uniform in ``[-height/2, height/2)``.
        ``radii`` is an ``(N,)`` tensor uniform in ``[8, 136)``.
        """
        x = torch.rand(N) * q5.width - q5.width / 2.0
        y = torch.rand(N) * q5.height - q5.height / 2.0
        r = torch.rand(N) * 128.0 + 8.0
        # Same values as vstack([x, y]).T, but a contiguous tensor rather than
        # a transposed view.
        return torch.stack((x, y), dim=1), r

    def step(self):
        """Apply one gradient-descent step to ``self.centers``.

        Returns:
            The loss (a Python float) evaluated before the step.
        """
        n = self.centers.shape[0]
        with torch.no_grad():
            # Sum of radii for every pair: r_mat[i, j] = r_i + r_j.
            r_mat = self.radii + self.radii[:, None]

        dist_mat = torch.cdist(self.centers, self.centers, p=2)
        pair_loss = (
            torch.exp(-dist_mat ** 2 / (r_mat * self.rbf_w))
            * (dist_mat - r_mat * self.r_margin) ** 2
        )
        # Drop the self-pairs on the diagonal. They only add a constant to the
        # loss. For more than 25 points cdist computes Euclidean distances via
        # matrix products, so the diagonal is only approximately zero, and the
        # gradient through it is numerically unreliable.
        self_pairs = torch.eye(n, dtype=torch.bool, device=dist_mat.device)
        loss = pair_loss.masked_fill(self_pairs, 0.0).sum()
        # Disabled regulariser:
        # + self.reg_weight * (self.centers.norm(dim=1) - q5.width).sum()

        loss.backward()

        with torch.no_grad():
            # Rebinding creates a fresh leaf tensor, so no gradient carries
            # over to the next step.
            self.centers = self.centers - self.eta * self.centers.grad
        self.centers.requires_grad_(True)
        return loss.item()

    def update(self):
        """Run ``steps_per_frame`` optimisation steps."""
        for _ in range(self.steps_per_frame):
            self.step()

    def draw(self):
        """Clear the background and draw every circle in its palette colour."""
        q5.background(64)

        centers = self.centers.detach().numpy()
        radii = self.radii.detach().numpy()

        q5.stroke_weight(5.0)
        for (x, y), r, i in zip(centers, radii, self.color_idxs):
            q5.fill(*COLORS[i])
            # NOTE(review): the loss treats r as a radius (contact distance is
            # r_i + r_j). Confirm whether q5.circle's third argument is a
            # radius or, as in p5.js, a diameter.
            q5.circle(x, y, r)

        # q5.save_frame('frames/{:05d}.png'.format(q5.frame_count))

    def mouse_pressed(self):
        """Left click: new random layout. Right click: save the current frame."""
        if q5.mouse_button == q5.MOUSE_LEFT:
            self._reset_circles()
        elif q5.mouse_button == q5.MOUSE_RIGHT:
            q5.save_frame('frames/{:05d}.png'.format(q5.frame_count))


if __name__ == '__main__':
    # Create the sketch and start it. App does not define run() itself, so it
    # is inherited from q5.BaseApp.
    app = App()
    app.run()