Давайте рассмотрим, как использовать линейную алгебру и тензорные операции, чтобы создать всем известную игру в 12 строк.
И у вас сразу точно возникает несколько вопросов:
- Насколько длинные эти 12 строк? Не волнуйтесь, все они соответствуют стандарту PEP 8.
- Зачем это вообще делать? Иногда надо писать код просто ради фана. Кроме того, это отличный способ познакомиться с PyTorch и возможностями, которые предоставляют тензоры.
- Но этом же нет никакой практической пользы? Напротив. Методы, используемые в этой материале, на самом деле являются фундаментальными. И они лежат в основе модуля TensorSnake, который может эмулировать параллельно 100 миллионов игр "Змейка" на карте NVIDIA A6000 с задержкой 20 миллисекунд.
Мы напишем версию "Змейки", в которой она может перетекать за границу поля и выходить с другой стороны. Тем не менее, можно будет изменить 2 строки, чтобы реализовать стандартную версию.
Будем использовать PyTorch и NumPy. Можно было использовать даже какую-то одну из библиотек, но у PyTorch прекрасное Tensor API, а в NumPy есть хорошая функция под названием unravel_index, которую мы и будем использовать.
И договоримся, что в подсчёт строк не будут входить импорты и строка с определением функции 😉.
def do(snake: t.Tensor, action: int):
positions = snake.flatten().topk(2)[1]
[pos_cur, pos_prev] = [T(unravel(x, snake.shape)) for x in positions]
rotation = T([[0, -1], [1, 0]]).matrix_power(3 + action)
pos_next = (pos_cur + (pos_cur - pos_prev) @ rotation) % T(snake.shape)
if (snake[tuple(pos_next)] > 0).any():
return (snake[tuple(pos_cur)] - 2).item()
if snake[tuple(pos_next)] == -1:
pos_food = (snake == 0).flatten().to(t.float).multinomial(1)[0]
snake[unravel(pos_food, snake.shape)] = -1
else:
snake[snake > 0] -= 1
snake[tuple(pos_next)] = snake[tuple(pos_cur)] + 1
snake = t.zeros((32, 32), dtype=t.int)
snake[0, :3] = T([1, 2, -1])
fig, ax = plt.subplots(1, 1)
img = ax.imshow(snake)
action = {'val': 1}
action_dict = {'a': 0, 'd': 2}
fig.canvas.mpl_connect('key_press_event',
lambda e: action.__setitem__('val', action_dict[e.key]))
score = None
while score is None:
img.set_data(snake)
fig.canvas.draw_idle()
plt.pause(0.1)
score = do(snake, action['val'])
action['val'] = 1
print('Score:', score)
Источник: Medium