baller2vec

This is the repository for the paper:

Michael A. Alcorn and Anh Nguyen. baller2vec: A Multi-Entity Transformer For Multi-Agent Spatiotemporal Modeling. arXiv. 2021.

The input to baller2vec at each time step t is an unordered set of feature vectors describing the identities and locations of the NBA players on the court and the location of the ball. The model uses these inputs to classify either the binned trajectory of each player (left) or the binned trajectory of the ball (right).
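Because each time step contributes an unordered set of entities, one natural way to feed them to a Transformer is to flatten T time steps of K entities into T × K tokens and apply a block-causal attention mask, so every token can attend to all entities at its own time step and at earlier ones. The sketch below builds such a mask in plain Python. It illustrates that assumption and is not taken from this repository's code; the function name multi_entity_mask is hypothetical.

```python
def multi_entity_mask(n_steps, n_entities):
    """Return an (n_steps * n_entities) x (n_steps * n_entities) boolean mask.

    Token i belongs to time step i // n_entities. mask[i][j] is True when
    token i may attend to token j, i.e. when j's time step is not later
    than i's. All entities within one time step can see each other, so
    the order of entities inside a step does not matter.
    """
    size = n_steps * n_entities
    return [
        [(j // n_entities) <= (i // n_entities) for j in range(size)]
        for i in range(size)
    ]


# Usage: 2 time steps, 3 entities per step (e.g. two players and the ball).
mask = multi_entity_mask(2, 3)
for row in mask:
    print("".join("x" if allowed else "." for allowed in row))
# xxx...
# xxx...
# xxx...
# xxxxxx
# xxxxxx
# xxxxxx
```

If you pass a mask like this to PyTorch's nn.MultiheadAttention as a boolean attn_mask, invert it first: there, True marks positions that may not be attended to.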
By exclusively learning to predict the trajectory of the ball, baller2vec was able to infer idiosyncratic player attributes.
Further, nearest neighbors in baller2vec's embedding space are plausible doppelgängers. Credit for the images: Erik Drost, Keith Allison, Jose Garcia, Keith Allison, Verse Photography, and Joe Glorioso.
Additionally, several attention heads in baller2vec appear to perform different basketball-relevant functions, such as anticipating passes. Code to generate the GIF was adapted from @linouk23's NBA Player Movements repository.
Here, a baller2vec model trained to simultaneously predict the trajectories of all the players on the court uses both the historical and current context to forecast the target player's trajectory at each time step. The left grid shows the target player's true trajectory at each time step while the right grid shows baller2vec's forecast distribution. The blue-bordered center cell is the "stationary" trajectory.
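The training config further down sets player_traj_n: 11 and max_player_move: 4.5 for players (ball_traj_n: 19 and max_ball_move: 8.5 for the ball), which suggests an n × n grid of displacement cells with the center cell meaning "stationary". The following is a minimal sketch of mapping one per-step displacement (dx, dy) to a grid-cell label. The uniform grid, the edge placement, the clipping of large moves into the outer cells, and the x-major label order are all assumptions for illustration; the repository's own binning may differ.

```python
import bisect


def bin_edges(n_bins, max_move):
    """Return n_bins - 1 evenly spaced edges covering [-max_move, max_move]."""
    step = 2 * max_move / (n_bins - 2)
    return [-max_move + i * step for i in range(n_bins - 1)]


def traj_class(dx, dy, n_bins, max_move):
    """Map a 2D displacement to one of n_bins ** 2 grid-cell labels.

    Displacements beyond +/- max_move fall into the outermost cells.
    """
    edges = bin_edges(n_bins, max_move)
    ix = bisect.bisect_right(edges, dx)  # column index in [0, n_bins - 1]
    iy = bisect.bisect_right(edges, dy)  # row index in [0, n_bins - 1]
    return ix * n_bins + iy


# Player grid from the training config: player_traj_n = 11, max_player_move = 4.5.
print(traj_class(0.0, 0.0, 11, 4.5))   # 60, the center ("stationary") cell of 121
print(traj_class(9.0, -9.0, 11, 4.5))  # 110, clipped to the outermost column/row
```

With these settings the inner cells are one unit wide and centered on integer displacements, so a standing player lands in the middle cell.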

Citation

If you use this code for your own research, please cite:

@article{alcorn2021baller2vec,
   title={\texttt{baller2vec}: A Multi-Entity Transformer For Multi-Agent Spatiotemporal Modeling},
   author={Alcorn, Michael A. and Nguyen, Anh},
   journal={arXiv preprint arXiv:2102.03291},
   year={2021}
}

Training baller2vec

Setting up .basketball_profile

After you've cloned the repository to your desired location, create a file called .basketball_profile in your home directory:

nano ~/.basketball_profile

and paste in the contents of the repository's .basketball_profile, replacing each variable's value with a path appropriate for your environment. Next, add the following line to the end of your ~/.bashrc:

source ~/.basketball_profile

and either log out and log back in again or run:

source ~/.bashrc

You should now be able to copy and paste all of the commands in the various instructions sections. For example:

echo ${PROJECT_DIR}

should print the path you set for PROJECT_DIR in .basketball_profile.
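If a later command fails with an empty path, a quick check that the profile was actually sourced can save time. The sketch below looks for the four variables this README uses (PROJECT_DIR, DATA_DIR, GAMES_DIR, and EXPERIMENTS_DIR); your .basketball_profile may define others, and the function name check_profile is hypothetical.

```python
import os

# Variables referenced in this README; your .basketball_profile may define more.
REQUIRED_VARS = ("PROJECT_DIR", "DATA_DIR", "GAMES_DIR", "EXPERIMENTS_DIR")


def check_profile(env=None, required=REQUIRED_VARS):
    """Return the names of required variables that are unset or empty."""
    env = os.environ if env is None else env
    return [name for name in required if not env.get(name)]


if __name__ == "__main__":
    missing = check_profile()
    if missing:
        print("Not set (did you source ~/.basketball_profile?):", ", ".join(missing))
    else:
        for name in REQUIRED_VARS:
            print(f"{name}={os.environ[name]}")
```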

Installing the necessary Python packages

cd ${PROJECT_DIR}
pip3 install --upgrade -r requirements.txt

Organizing the play-by-play and tracking data

  1. Copy events.zip (which I acquired from here [mirror here] using https://downgit.github.io) to the DATA_DIR directory and unzip it:
mkdir -p ${DATA_DIR}
cp ${PROJECT_DIR}/events.zip ${DATA_DIR}
cd ${DATA_DIR}
unzip -q events.zip
rm events.zip

Descriptions for the various EVENTMSGTYPEs can be found here (mirror here).
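To get a feel for the play-by-play data after unzipping, you can tally event types per game. This sketch assumes each unzipped file is a CSV with an EVENTMSGTYPE header column (an assumption about the file layout), and the code-to-name mapping is the commonly cited NBA play-by-play convention, so verify it against the EVENTMSGTYPE description referenced above. The function name count_event_types is hypothetical.

```python
import csv
import io
from collections import Counter

# Commonly cited NBA play-by-play codes (assumed); verify against the linked description.
EVENTMSGTYPE_NAMES = {
    1: "made shot", 2: "missed shot", 3: "free throw", 4: "rebound",
    5: "turnover", 6: "foul", 7: "violation", 8: "substitution",
    9: "timeout", 10: "jump ball", 11: "ejection",
    12: "start of period", 13: "end of period",
}


def count_event_types(csv_file):
    """Tally EVENTMSGTYPE values in a play-by-play CSV (any file-like object)."""
    counts = Counter(int(row["EVENTMSGTYPE"]) for row in csv.DictReader(csv_file))
    return {EVENTMSGTYPE_NAMES.get(code, f"code {code}"): n for code, n in counts.items()}


# Usage with a tiny in-memory example; for real data, open a file under ${DATA_DIR}.
sample = io.StringIO("EVENTNUM,EVENTMSGTYPE\n1,12\n2,10\n3,1\n4,1\n")
print(count_event_types(sample))
```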

  2. Clone the tracking data from here (mirror here) to the DATA_DIR directory:
cd ${DATA_DIR}
git clone git@github.com:linouk23/NBA-Player-Movements.git

A description of the tracking data can be found here.
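For orientation, the sketch below unpacks a single tracking "moment". It assumes the SportVU-style layout used by the NBA-Player-Movements repository as we understand it: [quarter, timestamp, game_clock, shot_clock, unused, positions], where positions is a list of [team_id, player_id, x, y, z] entries with the ball listed first. Treat that layout as an assumption and check it against the linked description; the Entity class and parse_moment function are hypothetical, and the moment below is synthetic.

```python
from dataclasses import dataclass


@dataclass
class Entity:
    team_id: int
    player_id: int
    x: float
    y: float
    z: float


def parse_moment(moment):
    """Split one moment into (quarter, game_clock, ball, players).

    Assumed layout: [quarter, timestamp, game_clock, shot_clock, unused,
    [[team_id, player_id, x, y, z], ...]] with the ball listed first.
    """
    quarter, _timestamp, game_clock, _shot_clock, _unused, positions = moment[:6]
    entities = [Entity(*p) for p in positions]
    return quarter, game_clock, entities[0], entities[1:]


# Synthetic moment: the ball plus one player (real moments list ten players).
moment = [3, 0, 115.0, 20.3, None,
          [[-1, -1, 47.2, 25.1, 6.4], [1, 7, 45.0, 24.0, 0.0]]]
quarter, clock, ball, players = parse_moment(moment)
print(quarter, clock, ball.x, ball.y, len(players))
```

Under this layout the game clock is stored as seconds remaining in the period, so 115.0 corresponds to "1:55" on the scoreboard.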

Generating the training data

cd ${PROJECT_DIR}
nohup python3 generate_game_numpy_arrays.py > data.log &

You can monitor its progress with:

top

or:

ls -U ${GAMES_DIR} | wc -l

There should be 1,262 NumPy arrays (corresponding to 631 X/y pairs) when finished.
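A file count alone will not catch a game whose X array was written but whose y array was not. The sketch below groups files by game ID, assuming the {gameid}_X.npy / {gameid}_y.npy naming used in the scp command further down; the function name check_pairs is hypothetical.

```python
import os
import re

# Assumed naming scheme: <gameid>_X.npy and <gameid>_y.npy.
PATTERN = re.compile(r"^(\d+)_([Xy])\.npy$")


def check_pairs(filenames):
    """Return (number of complete X/y pairs, sorted game IDs missing a partner)."""
    seen = {}
    for name in filenames:
        match = PATTERN.match(name)
        if match:
            gameid, kind = match.groups()
            seen.setdefault(gameid, set()).add(kind)
    complete = [g for g, kinds in seen.items() if kinds == {"X", "y"}]
    incomplete = sorted(g for g, kinds in seen.items() if kinds != {"X", "y"})
    return len(complete), incomplete


if __name__ == "__main__":
    games_dir = os.environ.get("GAMES_DIR")
    if games_dir and os.path.isdir(games_dir):
        n_pairs, incomplete = check_pairs(os.listdir(games_dir))
        print(f"{n_pairs} complete pairs (expected 631); incomplete: {incomplete}")
    else:
        print("GAMES_DIR is not set or does not exist.")
```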

Animating a sequence

  1. If you don't have a display hooked up to your GPU machine, you'll need to first clone the repository to your local machine and retrieve certain files from the remote server:
# From your local machine.
mkdir -p ~/scratch
cd ~/scratch

username=michael
server=gpu3.cse.eng.auburn.edu
data_dir=/home/michael/baller2vec_data
scp ${username}@${server}:${data_dir}/baller2vec_config.pydict .

games_dir=${data_dir}/games
gameid=0021500622

scp ${username}@${server}:${games_dir}/\{${gameid}_X.npy,${gameid}_y.npy\} .
  2. You can then run this code in the Python interpreter from within the repository (make sure you source .basketball_profile first if running locally):
import os

from animator import Game
from settings import DATA_DIR, GAMES_DIR

gameid = "0021500622"
try:
    game = Game(DATA_DIR, GAMES_DIR, gameid)
except FileNotFoundError:
    home_dir = os.path.expanduser("~")
    DATA_DIR = f"{home_dir}/scratch"
    GAMES_DIR = f"{home_dir}/scratch"
    game = Game(DATA_DIR, GAMES_DIR, gameid)

# https://youtu.be/FRrh_WkyXko?t=109
start_period = 3
start_time = "1:55"
stop_period = 3
stop_time = "1:51"
game.show_seq(start_period, start_time, stop_period, stop_time)

to generate an animation of that sequence.
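Note that start_time and stop_time are game-clock readings, and the game clock counts down, so start_time ("1:55") shows more time remaining than stop_time ("1:51"). The sketch below converts a (period, "M:SS") pair into seconds elapsed since tip-off to make that explicit. The 12-minute quarters and 5-minute overtimes are NBA regulation lengths; the helper names are hypothetical and not part of animator.Game.

```python
def clock_to_seconds(clock):
    """Convert an 'M:SS' game-clock string to seconds remaining in the period."""
    minutes, seconds = clock.split(":")
    return int(minutes) * 60 + float(seconds)


def elapsed_game_seconds(period, clock, period_secs=720, ot_secs=300):
    """Seconds elapsed since tip-off, assuming 12-minute quarters and 5-minute overtimes."""
    if period <= 4:
        before = (period - 1) * period_secs
        length = period_secs
    else:
        before = 4 * period_secs + (period - 5) * ot_secs
        length = ot_secs
    return before + (length - clock_to_seconds(clock))


start = elapsed_game_seconds(3, "1:55")
stop = elapsed_game_seconds(3, "1:51")
print(stop - start)  # 4.0 seconds of play
```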

Running the training script

Run (or copy and paste) the following script, editing the variables as appropriate.

#!/usr/bin/env bash

# Experiment identifier. Output will be saved to ${EXPERIMENTS_DIR}/${JOB}.
JOB=$(date +%Y%m%d%H%M%S)

# Training options.
echo "train:" >> ${JOB}.yaml
task=ball_traj  # "ball_traj" or "player_traj".
echo "  task: ${task}" >> ${JOB}.yaml
echo "  min_playing_time: 0" >> ${JOB}.yaml  # 0/13314/39917/1.0e+6 --> 100%/75%/50%/0%.
echo "  train_valid_prop: 0.95" >> ${JOB}.yaml
echo "  train_prop: 0.95" >> ${JOB}.yaml
echo "  train_samples_per_epoch: 20000" >> ${JOB}.yaml
echo "  valid_samples: 1000" >> ${JOB}.yaml
echo "  workers: 10" >> ${JOB}.yaml
echo "  learning_rate: 1.0e-5" >> ${JOB}.yaml
echo "  patience: 20" >> ${JOB}.yaml
if [[ ("$task" = "event") || ("$task" = "score") ]]
then
    echo "  prev_model: False" >> ${JOB}.yaml
fi

# Dataset options.
echo "dataset:" >> ${JOB}.yaml
echo "  hz: 5" >> ${JOB}.yaml
echo "  secs: 4.2" >> ${JOB}.yaml
echo "  player_traj_n: 11" >> ${JOB}.yaml
echo "  max_player_move: 4.5" >> ${JOB}.yaml
echo "  ball_traj_n: 19" >> ${JOB}.yaml
echo "  max_ball_move: 8.5" >> ${JOB}.yaml
echo "  n_players: 10" >> ${JOB}.yaml
echo "  next_score_change_time_max: 35" >> ${JOB}.yaml
echo "  n_time_to_next_score_change: 36" >> ${JOB}.yaml
echo "  n_ball_loc_x: 95" >> ${JOB}.yaml
echo "  n_ball_loc_y: 51" >> ${JOB}.yaml
echo "  ball_future_secs: 2" >> ${JOB}.yaml

# Model options.
echo "model:" >> ${JOB}.yaml
echo "  embedding_dim: 20" >> ${JOB}.yaml
echo "  sigmoid: none" >> ${JOB}.yaml
echo "  mlp_layers: [128, 256, 512]" >> ${JOB}.yaml
echo "  nhead: 8" >> ${JOB}.yaml
echo "  dim_feedforward: 2048" >> ${JOB}.yaml
echo "  num_layers: 6" >> ${JOB}.yaml
echo "  dropout: 0.0" >> ${JOB}.yaml

if [[ "$task" != "seq2seq" ]]
then
    echo "  use_cls: False" >> ${JOB}.yaml
    echo "  embed_before_mlp: True" >> ${JOB}.yaml
fi

# Save experiment settings.
mkdir -p ${EXPERIMENTS_DIR}/${JOB}
mv ${JOB}.yaml ${EXPERIMENTS_DIR}/${JOB}/

# Start training the model.
gpu=0
cd ${PROJECT_DIR}
nohup python3 train_baller2vec.py ${JOB} ${gpu} > ${EXPERIMENTS_DIR}/${JOB}/train.log &
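Before launching a long run, it can help to sanity-check a few quantities implied by the config. The sketch below copies values from the script above and derives the number of time steps per sample, the number of trajectory classes, and the per-head attention width. Several readings here are assumptions, flagged in the comments: that a sample spans hz × secs time steps, that the Transformer width equals the last entry of mlp_layers, and that train_valid_prop and train_prop define a nested train/valid/test split over games. The function names are hypothetical.

```python
# Values copied from the training script above; edit to match your ${JOB}.yaml.
CONFIG = {
    "dataset": {"hz": 5, "secs": 4.2, "player_traj_n": 11, "ball_traj_n": 19},
    "model": {"mlp_layers": [128, 256, 512], "nhead": 8},
    "train": {"train_valid_prop": 0.95, "train_prop": 0.95},
}


def derived_sizes(config):
    ds, model = config["dataset"], config["model"]
    seq_len = int(round(ds["hz"] * ds["secs"]))  # assumed: time steps per sample
    d_model = model["mlp_layers"][-1]            # assumed: width fed to the Transformer
    if d_model % model["nhead"] != 0:
        raise ValueError("attention width must be divisible by nhead")
    return {
        "seq_len": seq_len,
        "player_classes": ds["player_traj_n"] ** 2,  # n x n trajectory grid
        "ball_classes": ds["ball_traj_n"] ** 2,
        "head_dim": d_model // model["nhead"],
    }


def split_counts(n_games, config):
    """Hypothetical nested split: train+valid vs. test, then train vs. valid."""
    tr = config["train"]
    n_train_valid = int(n_games * tr["train_valid_prop"])
    n_train = int(n_train_valid * tr["train_prop"])
    return n_train, n_train_valid - n_train, n_games - n_train_valid


print(derived_sizes(CONFIG))
print(split_counts(631, CONFIG))  # 631 games = the X/y pairs counted earlier
```

The divisibility check mirrors a real constraint in PyTorch's nn.MultiheadAttention, whose embedding width must be divisible by the number of heads.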
