In [ ]:
%pylab inline
import torch
import sys, os
import pystk
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print('device = ', device)
In [ ]:
# Initialize the SuperTuxKart graphics backend; a small resolution keeps rendering fast
config = pystk.GraphicsConfig.hd()
config.screen_width = 100
config.screen_height = 80
pystk.init(config)
In [ ]:
class ActionNet(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Four strided convolutions reduce the 80x100 frame to a small feature map
        self.network = torch.nn.Sequential(
            torch.nn.Conv2d(3, 32, 5, stride=2),
            torch.nn.ReLU(),
            torch.nn.Conv2d(32, 64, 5, stride=2),
            torch.nn.ReLU(),
            torch.nn.Conv2d(64, 96, 5, stride=2),
            torch.nn.ReLU(),
            torch.nn.Conv2d(96, 128, 5),
            torch.nn.ReLU()
        )
        # Global-average-pooled features map to three outputs: steer, acceleration, brake
        self.classifier = torch.nn.Linear(128, 3)

    def forward(self, x):
        f = self.network(x)
        # Average over the spatial dimensions, then predict the action values
        return self.classifier(f.mean(dim=(2, 3)))
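In [ ]:
# Sanity check (an illustrative addition, not from the original notebook):
# run a dummy 80x100 frame through ActionNet and confirm it yields one
# (steer, acceleration, brake) triple per image.
net = ActionNet()
x = torch.zeros(1, 3, 80, 100)   # (batch, channels, screen_height, screen_width)
print(net(x).shape)              # expected: torch.Size([1, 3])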
In [ ]:
def rollout_agent(agent, n_step=200):
    config = pystk.RaceConfig()
    config.track = 'lighthouse'
#     config.players[0].controller = pystk.PlayerConfig.Controller.AI_CONTROL

    k = pystk.Race(config)
    state = pystk.WorldState()
    k.start()
    k.step()
    state.update()
    data = []
    try:
        for i in range(n_step):
            # Convert the rendered frame to a (1, 3, H, W) float tensor in [-0.5, 0.5]
            x = torch.as_tensor(np.array(k.render_data[0].image))[None].permute(0, 3, 1, 2).float() / 255. - 0.5
            # Query the agent for (steer, acceleration, brake) and step the simulation
            a = agent(x.to(device))[0]
            k.step(pystk.Action(steer=float(a[0]), acceleration=float(a[1]), brake=float(a[2]) > 0.5))
            state.update()
            # Record the frame together with how far down the track the kart has driven
            data.append((np.array(k.render_data[0].image), state.karts[0].distance_down_track))
    finally:
        k.stop()
        del k
    return data

actor = ActionNet().to(device)
data = rollout_agent(actor, 200)
# Final distance down the track after 200 steps (the score we will maximize)
print(data[-1][1])
# Show a few frames from the rollout
for i in range(5):
    figure()
    imshow(data[i*20+10][0])
    axis('off')
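In [ ]:
# Optional visualization (added sketch): plot how far the untrained actor gets
# down the track over the rollout above; data[i][1] is distance_down_track at step i.
figure()
plot([d for _, d in data])
xlabel('step')
ylabel('distance down track')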
In [ ]:
n_epochs = 10
n_step = 10   # number of random perturbations tried per epoch

def get_flat_params(actor):
    # Flatten all parameters into a single 1-D numpy vector
    return np.concatenate([p.data.cpu().numpy().flatten() for p in actor.parameters()])

def set_flat_params(actor, v):
    # Copy a flat parameter vector back into the network, slice by slice
    i = 0
    for p in actor.parameters():
        n = int(np.prod(p.data.size()))
        p.data[:] = torch.as_tensor(v[i:i+n].reshape(p.data.size()), dtype=p.dtype, device=p.device)
        i += n

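# Quick round-trip check (added for illustration): writing the flattened
# parameters back should leave the network unchanged, confirming that
# get_flat_params and set_flat_params are consistent.
p0 = get_flat_params(actor)
set_flat_params(actor, p0)
assert np.allclose(get_flat_params(actor), p0)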

for epoch in range(n_epochs):
    eps = 1e-2

    p = get_flat_params(actor)
    candidates = []
    for i in range(n_step):
        # Sample a random perturbation of the parameters and score both directions
        dp = np.random.normal(0, eps, p.shape)
        # Try positive
        set_flat_params(actor, p + dp)
        data = rollout_agent(actor, 200)
        score_1 = data[-1][1]

        # Try negative
        set_flat_params(actor, p - dp)
        data = rollout_agent(actor, 200)
        score_2 = data[-1][1]

        candidates.append((score_1, dp))
        candidates.append((score_2, -dp))

    # Keep the single perturbation with the highest final distance down the track
    # (greedy random-search hill climbing rather than a true gradient estimate)
    best_score, best_dp = candidates[np.argmax([s for s, _ in candidates])]
    set_flat_params(actor, p + best_dp)
    print(epoch, best_score)
    
In [ ]:
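# Evaluate the actor after the random-search updates (mirrors the first rollout
# cell above; shown here as an illustrative final check).
data = rollout_agent(actor, 200)
print(data[-1][1])
for i in range(5):
    figure()
    imshow(data[i*20 + 10][0])
    axis('off')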