%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
import torch
import sys, os
import pystk
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print('device =', device)
# Initialize SuperTuxKart with low-resolution graphics for fast rollouts
config = pystk.GraphicsConfig.hd()
config.screen_width = 100
config.screen_height = 80
pystk.init(config)
class ActionNet(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Small convolutional backbone; global average pooling below keeps the
        # linear head independent of the input resolution.
        self.network = torch.nn.Sequential(
            torch.nn.Conv2d(3, 32, 5, stride=2),
            torch.nn.ReLU(),
            torch.nn.Conv2d(32, 64, 5, stride=2),
            torch.nn.ReLU(),
            torch.nn.Conv2d(64, 96, 5, stride=2),
            torch.nn.ReLU(),
            torch.nn.Conv2d(96, 128, 5),
            torch.nn.ReLU()
        )
        # Three outputs: steer, acceleration, and a brake logit
        self.classifier = torch.nn.Linear(128, 3)

    def forward(self, x):
        f = self.network(x)
        # Global average pool over the spatial dimensions, then predict actions
        return self.classifier(f.mean(dim=(2, 3)))
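
# Optional sanity check: with the 100x80 frames configured above, a dummy
# forward pass should yield one 3-vector of actions per image.
_dummy = torch.zeros(1, 3, 80, 100)  # (batch, channels, height, width)
print(ActionNet()(_dummy).shape)  # expected: torch.Size([1, 3])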
def rollout_agent(agent, n_step=200):
    config = pystk.RaceConfig()
    config.track = 'lighthouse'
    # config.players[0].controller = pystk.PlayerConfig.Controller.AI_CONTROL
    k = pystk.Race(config)
    state = pystk.WorldState()
    k.start()
    k.step()
    state.update()
    data = []
    try:
        for i in range(n_step):
            # Convert the rendered frame to a normalized (1, 3, H, W) float tensor
            x = torch.as_tensor(np.array(k.render_data[0].image))[None].permute(0, 3, 1, 2).float() / 255. - 0.5
            with torch.no_grad():
                a = agent(x.to(device))[0]
            # Outputs: steer, acceleration, and a thresholded brake decision
            k.step(pystk.Action(steer=float(a[0]), acceleration=float(a[1]), brake=float(a[2]) > 0.5))
            state.update()
            data.append((np.array(k.render_data[0].image), state.karts[0].distance_down_track))
    finally:
        k.stop()
        del k
    return data
actor = ActionNet().to(device)
data = rollout_agent(actor, 200)
print(data[-1][1])  # distance down the track reached by the untrained agent

for i in range(5):
    plt.figure()
    plt.imshow(data[i * 20 + 10][0])
    plt.axis('off')
n_epochs = 10
n_trials = 10  # random perturbations tried per epoch
def get_flat_params(actor):
    return np.concatenate([p.data.cpu().numpy().flatten() for p in actor.parameters()])

def set_flat_params(actor, v):
    i = 0
    for p in actor.parameters():
        n = int(np.prod(p.data.size()))
        p.data[:] = torch.as_tensor(v[i:i + n].reshape(tuple(p.data.size())), dtype=p.dtype, device=p.device)
        i += n  # advance to the next parameter's slice of the flat vector
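
# Optional round-trip check: writing back the parameters we just read should
# leave the network unchanged.
_p = get_flat_params(actor)
set_flat_params(actor, _p)
assert np.allclose(get_flat_params(actor), _p)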
for epoch in range(n_epochs):
    eps = 1e-2
    p = get_flat_params(actor)
    scored_deltas = []
    for i in range(n_trials):
        dp = np.random.normal(0, eps, p.shape)
        # Try the positive perturbation
        set_flat_params(actor, p + dp)
        data = rollout_agent(actor, 200)
        score_1 = data[-1][1]
        # Try the negative perturbation
        set_flat_params(actor, p - dp)
        data = rollout_agent(actor, 200)
        score_2 = data[-1][1]
        scored_deltas.append((score_1, dp))
        scored_deltas.append((score_2, -dp))
    # Update the parameters: keep the single best-scoring perturbation
    # (greedy random search, not a gradient estimate)
    best_score, best_dp = scored_deltas[np.argmax([s for s, _ in scored_deltas])]
    set_flat_params(actor, p + best_dp)
    print(best_score)
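
# Optional: evaluate the trained agent with one more rollout and inspect a few
# frames, mirroring the untrained-agent check above.
data = rollout_agent(actor, 200)
print(data[-1][1])  # distance down the track after training
for i in range(5):
    plt.figure()
    plt.imshow(data[i * 20 + 10][0])
    plt.axis('off')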