Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions config/whackamole.ini
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
[base]
env_name = whackamole

[env]
num_envs = 4096

[policy]
hidden_size = 64
num_layers = 1

[train]
learning_rate = 0.001
total_timesteps = 500_000_000
22 changes: 22 additions & 0 deletions ocean/whackamole/binding.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
#include "whackamole.h"

// Build-time configuration for the generic vectorized-env binding.
// NOTE(review): these macros are presumably consumed by vecenv.h (included
// below) - confirm their exact meaning against that header.

// One observation value per grid cell.
#define OBS_SIZE TOTAL_CELLS
// A single action component per step.
#define NUM_ATNS 1
// That component has TOTAL_CELLS possible values (one per grid cell).
#define ACT_SIZES {TOTAL_CELLS}
// Observations are floats (matches `float* observations` in Whackamole).
#define OBS_TENSOR_T FloatTensor

// Alias the generic `Env` name used in this file to the Whackamole struct.
#define Env Whackamole
#include "vecenv.h"

// Binding hook that puts an environment into its initial state:
// a single agent, no hits recorded and the step counter at zero.
// `kwargs` is not read by this environment.
void my_init(Env* env, Dict* kwargs) {
    (void)kwargs;  // intentionally unused

    env->tick = 0;
    env->hits = 0;
    env->num_agents = 1;
}

// Binding hook that exports the aggregated statistics in `log` into `out`.
// Each Log field is stored via dict_set under a key with the same name as
// the field. Log.n is not exported here.
void my_log(Log* log, Dict* out) {
dict_set(out, "perf", log->perf);
dict_set(out, "score", log->score);
dict_set(out, "episode_return", log->episode_return);
dict_set(out, "episode_length", log->episode_length);
}
Binary file added ocean/whackamole/pufferfish.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
50 changes: 50 additions & 0 deletions ocean/whackamole/whackamole.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
#include "whackamole.h"

// Standalone interactive demo.
//
// Allocates the buffers the environment writes into, then runs a
// step/render loop until the window is closed:
//   - while LEFT SHIFT is held, the user plays: a left click on a cell
//     whacks that cell, R resets the episode, otherwise the action is NOOP;
//   - otherwise a random cell is whacked every 10th frame.
//
// Returns 0 on a normal exit and 1 if the buffers cannot be allocated.
int main() {
    Whackamole env = {0};
    env.num_agents = 1;
    env.rng = (unsigned int)time(NULL);

    env.observations = (float*)calloc(TOTAL_CELLS, sizeof(float));
    env.actions = (float*)calloc(1, sizeof(float));
    env.rewards = (float*)calloc(1, sizeof(float));
    env.terminals = (float*)calloc(1, sizeof(float));
    if (env.observations == NULL || env.actions == NULL ||
        env.rewards == NULL || env.terminals == NULL) {
        // free(NULL) is a no-op, so release whatever was obtained.
        free(env.observations);
        free(env.actions);
        free(env.rewards);
        free(env.terminals);
        return 1;
    }

    c_reset(&env);
    c_render(&env);  // first call opens the window

    int frame = 0;
    // Leave the loop when the window is closed so the cleanup below runs.
    while (!WindowShouldClose()) {
        frame += 1;

        if (IsKeyDown(KEY_LEFT_SHIFT)) {
            // Manual control: NOOP unless a cell is clicked this frame.
            env.actions[0] = NOOP;
            if (IsMouseButtonPressed(MOUSE_LEFT_BUTTON)) {
                Vector2 mouse = GetMousePosition();
                int c = (int)(mouse.x / CELL_SIZE);
                int r = (int)(mouse.y / CELL_SIZE);
                if (r >= 0 && r < GRID_SIZE && c >= 0 && c < GRID_SIZE) {
                    env.actions[0] = (float)(r * GRID_SIZE + c);
                }
            }
            if (IsKeyPressed(KEY_R)) {
                c_reset(&env);
            }
        } else {
            // Demo mode: whack a random cell every 10th frame.
            if (frame % 10 == 0) {
                env.actions[0] = (float)(rand_r(&env.rng) % TOTAL_CELLS);
            } else {
                env.actions[0] = NOOP;
            }
        }

        c_step(&env);
        c_render(&env);
    }

    // Release render resources while the window still exists, then buffers.
    c_close(&env);
    free(env.observations);
    free(env.actions);
    free(env.rewards);
    free(env.terminals);

    return 0;
}
197 changes: 197 additions & 0 deletions ocean/whackamole/whackamole.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,197 @@
#include <stdlib.h>
#include <string.h>
#include <math.h>
#include <time.h>
#include "raylib.h"

// Board is GRID_SIZE x GRID_SIZE cells, indexed row-major as r * GRID_SIZE + c.
#define GRID_SIZE 5
#define TOTAL_CELLS (GRID_SIZE * GRID_SIZE)
// Side length of one cell in pixels when rendering.
#define CELL_SIZE 128
// Action value meaning "do not whack this step". Parenthesized so the
// negative literal cannot combine with neighbouring operators on expansion.
#define NOOP (-1.0f)
// Number of c_step calls that make up one episode.
#define ATTEMPTS_PER_EPISODE 3

// Aggregated episode statistics. Every field is a running sum that add_log
// increases once per finished episode.
// NOTE(review): presumably the consumer divides by `n` to get means, and may
// rely on this being a flat run of floats - confirm before reordering fields.
typedef struct {
float perf; // +1 for each logged episode whose rewards[0] was > 0 at logging time
float score; // sum of rewards[0] at logging time
float episode_return; // sum of rewards[0] at logging time (same value as score)
float episode_length; // sum of env->tick at logging time
float n; // number of add_log calls
} Log;

// Render-only state, allocated by c_render and released by c_close.
typedef struct {
Texture2D puffer; // sprite drawn at the mole position; id == 0 if loading failed
int flash_timer; // frames of flash left; set to 15 by c_step, decremented by c_render
Color flash_color; // green after an exact hit, red after a miss
int last_action_r; // grid row of the cell highlighted by the flash
int last_action_c; // grid column of the cell highlighted by the flash
bool show_flash; // true while a flash should be drawn
} Client;

// State of one whack-a-mole environment. The four float buffers are owned by
// the caller (the demo main() allocates them; c_close does not free them).
typedef struct {
Log log; // running statistics, updated by add_log
float* observations; // TOTAL_CELLS floats: 1.0 at the mole's cell, 0.0 elsewhere
float* actions; // actions[0]: cell index to whack, or NOOP
float* rewards; // rewards[0]: written by c_step and c_reset
float* terminals; // terminals[0]: written by c_step and c_reset
int num_agents; // set to 1 by my_init and by the demo main()
unsigned int rng; // state passed to rand_r
int mole_r; // grid row of the current mole
int mole_c; // grid column of the current mole
int hits; // exact hits so far; c_reset does not clear it
int tick; // c_step calls in the current episode
Client* client; // render state; NULL until c_render creates it
} Whackamole;

// Folds the episode that just finished into the running totals in env->log.
//
// NOTE(review): only rewards[0] as it stands at call time (the reward of the
// last step) is added to score/episode_return; rewards from earlier steps of
// the episode are not included.
void add_log(Whackamole* env) {
    const float last_reward = env->rewards[0];

    env->log.n += 1.0f;
    env->log.episode_length += env->tick;
    env->log.episode_return += last_reward;
    env->log.score += last_reward;
    if (last_reward > 0) {
        env->log.perf += 1.0f;
    }
}

// Starts a new episode: clears the board, places the mole on a random cell,
// zeroes the step counter, reward and terminal flag, and cancels any pending
// flash when a render client exists. env->hits and env->log are left alone.
void c_reset(Whackamole* env) {
    // Empty board, then mark a single randomly chosen cell.
    memset(env->observations, 0, sizeof(float) * TOTAL_CELLS);

    const int spawn_cell = rand_r(&env->rng) % TOTAL_CELLS;
    env->mole_r = spawn_cell / GRID_SIZE;
    env->mole_c = spawn_cell % GRID_SIZE;
    env->observations[spawn_cell] = 1.0f;

    // Per-episode bookkeeping.
    env->tick = 0;
    env->rewards[0] = 0.0f;
    env->terminals[0] = 0.0f;

    if (env->client == NULL) {
        return;
    }
    env->client->flash_timer = 0;
    env->client->show_flash = false;
}

// Advances the environment by one attempt.
//
// Reads actions[0] as the cell to whack (NOOP or any out-of-range value means
// "no whack") and writes rewards[0]:
//   - 0.0 for no whack,
//   - 1.0 for an exact hit (also increments env->hits),
//   - max(0, 1 - 0.25 * Manhattan distance to the mole) for a miss.
// When a render client exists, a green (hit) or red (miss) flash is armed on
// the whacked cell.
//
// After ATTEMPTS_PER_EPISODE steps the episode is logged and the environment
// resets itself; terminals[0] is 1.0 and rewards[0] still holds the final
// step's reward when this function returns. Otherwise terminals[0] is 0.0 and
// the mole moves to a new random cell.
void c_step(Whackamole* env) {
    env->tick += 1;

    const float raw_action = env->actions[0];
    const int mole_idx = env->mole_r * GRID_SIZE + env->mole_c;

    if (env->client != NULL) {
        env->client->show_flash = false;
    }

    // Range-check in float space before converting: casting a NaN or a value
    // outside the int range to int is undefined. NaN fails both comparisons.
    // Values in (-1, 0) truncate to cell 0, as a plain cast would.
    int action = -1;
    if (raw_action > -1.0f && raw_action < (float)TOTAL_CELLS) {
        action = (int)raw_action;
    }

    if (action < 0) {
        env->rewards[0] = 0.0f;
    } else {
        const int action_r = action / GRID_SIZE;
        const int action_c = action % GRID_SIZE;
        const int is_hit = (action == mole_idx);

        if (is_hit) {
            env->rewards[0] = 1.0f;
            env->hits += 1;
        } else {
            // Shaped reward: the closer the miss, the larger the reward.
            int dist = abs(action_r - env->mole_r) + abs(action_c - env->mole_c);
            env->rewards[0] = fmaxf(0.0f, 1.0f - dist * 0.25f);
        }

        // Flash GREEN for a hit, RED for a miss.
        if (env->client != NULL) {
            if (is_hit) {
                env->client->flash_color = (Color){0, 255, 0, 180};
            } else {
                env->client->flash_color = (Color){255, 0, 0, 180};
            }
            env->client->flash_timer = 15;
            env->client->show_flash = true;
            env->client->last_action_r = action_r;
            env->client->last_action_c = action_c;
        }
    }

    if (env->tick >= ATTEMPTS_PER_EPISODE) {
        env->terminals[0] = 1.0f;
        add_log(env);

        // c_reset zeroes rewards[0] and terminals[0]; put the terminal flag
        // and the final reward back so the caller can observe them.
        const float final_reward = env->rewards[0];
        c_reset(env);
        env->rewards[0] = final_reward;
        env->terminals[0] = 1.0f;
    } else {
        env->terminals[0] = 0.0f;
        // Move the puffer for the next attempt. Clear first, then set, so the
        // cell stays marked when the new position equals the old one.
        int new_idx = rand_r(&env->rng) % TOTAL_CELLS;
        env->observations[mole_idx] = 0.0f;
        env->observations[new_idx] = 1.0f;
        env->mole_r = new_idx / GRID_SIZE;
        env->mole_c = new_idx % GRID_SIZE;
    }
}

// Draws one frame of the board.
//
// Opens the window on first use and creates this environment's render client
// on first use. The two are set up independently, so an environment can render
// into a window that was already opened elsewhere. Pressing ESC terminates the
// process. If the client cannot be allocated the frame is skipped.
void c_render(Whackamole* env) {
    if (!IsWindowReady()) {
        InitWindow(CELL_SIZE * GRID_SIZE, CELL_SIZE * GRID_SIZE, "PufferLib WhacKe-a-PUFFER");
        SetTargetFPS(60);
    }

    if (env->client == NULL) {
        env->client = (Client*)calloc(1, sizeof(Client));
        if (env->client == NULL) {
            return;  // out of memory: skip drawing rather than dereference NULL
        }
        // Try the working directory first, then the path relative to the repo root.
        env->client->puffer = LoadTexture("pufferfish.png");
        if (env->client->puffer.id == 0) {
            env->client->puffer = LoadTexture("ocean/whackamole/pufferfish.png");
        }
        env->client->show_flash = false;
        env->client->flash_timer = 0;
    }

    if (IsKeyDown(KEY_ESCAPE)) {
        exit(0);
    }

    BeginDrawing();
    ClearBackground((Color){34, 139, 34, 255});

    // One hole per cell.
    for (int r = 0; r < GRID_SIZE; r++) {
        for (int c = 0; c < GRID_SIZE; c++) {
            int cx = c * CELL_SIZE + CELL_SIZE / 2;
            int cy = r * CELL_SIZE + CELL_SIZE / 2;
            DrawCircle(cx, cy, CELL_SIZE / 3, DARKGRAY);
        }
    }

    // Interior grid lines.
    for (int i = 1; i < GRID_SIZE; i++) {
        int pos = i * CELL_SIZE;
        DrawLine(pos, 0, pos, CELL_SIZE * GRID_SIZE, BLACK);
        DrawLine(0, pos, CELL_SIZE * GRID_SIZE, pos, BLACK);
    }

    // Hit/miss flash over the last whacked cell; counts down once per frame.
    if (env->client->show_flash && env->client->flash_timer > 0) {
        int fx = env->client->last_action_c * CELL_SIZE;
        int fy = env->client->last_action_r * CELL_SIZE;
        DrawRectangle(fx, fy, CELL_SIZE, CELL_SIZE, env->client->flash_color);
        env->client->flash_timer--;
        if (env->client->flash_timer <= 0) {
            env->client->show_flash = false;
        }
    }

    int x = env->mole_c * CELL_SIZE;
    int y = env->mole_r * CELL_SIZE;

    if (env->client->puffer.id > 0) {
        // Scale the sprite so its width fills one cell.
        float scale = (float)CELL_SIZE / env->client->puffer.width;
        DrawTextureEx(env->client->puffer, (Vector2){(float)x, (float)y}, 0.0f, scale, WHITE);
    } else {
        // Texture missing: draw a two-tone circle as a stand-in.
        DrawCircle(x + CELL_SIZE/2, y + CELL_SIZE/2, CELL_SIZE/3, RED);
        DrawCircle(x + CELL_SIZE/2, y + CELL_SIZE/2, CELL_SIZE/4, YELLOW);
    }

    DrawText(TextFormat("Hits: %i", env->hits), 10, 10, 24, WHITE);
    DrawText(TextFormat("Return: %.1f", env->log.episode_return), 10, 40, 20, WHITE);
    DrawText(TextFormat("Attempt: %i/%i", env->tick + 1, ATTEMPTS_PER_EPISODE), 10, 70, 20, WHITE);

    EndDrawing();
}

// Releases render resources: unloads the sprite if one was loaded, frees the
// render client and clears the pointer, then closes the window if it is open.
// The observation/action/reward/terminal buffers are not touched.
void c_close(Whackamole* env) {
    if (env->client) {
        // Only unload a texture that was actually loaded.
        if (env->client->puffer.id > 0)
            UnloadTexture(env->client->puffer);

        free(env->client);
        env->client = NULL;
    }

    if (!IsWindowReady())
        return;
    CloseWindow();
}
Loading