Autonomous NBA Shot-Tracking Engine

Building a physics-based engine to detect made baskets ("bucket getting").

Final Result: Production Logic (Generalization Test)

Table of Contents

Introduction: The Problem

The core question: Was there a made basket in this NBA clip?

In Machine Learning terms, this is a Binary Classification problem (Yes/No). Standard approaches might try to retrain a generic model (like ResNet) to "look" for a made basket by feeding it tons of video data. But that is computationally expensive and prone to overfitting.

If we want to be smarter, we need to decompose the problem.

Initial Assumption

Let's define a made basket in physical terms: the basketball enters the rim from above and passes down through it.

This changes our task. We still answer yes/no, but we no longer need to learn that answer end-to-end with a classifier; instead we need Object Detection plus a simple geometric rule. We need to figure out:

  • What a Basketball looks like (identify its coordinates).
  • What a Basketball Net looks like (identify its coordinates).

I started from a pretrained YOLO11n (nano) model and public datasets from Roboflow, and tuned the confidence thresholds carefully using the F1-confidence curves below.

F1 Rim

FIG 1.1: RIM F1 SCORE

F1 Ball

FIG 1.2: BALL F1 SCORE

train_initial.py
from ultralytics import YOLO

def train_local(data='data.yaml', epochs=50, imgsz=640, batch=16,
                name='nba_hoop_v11', device=0):
    """Fine-tune a pretrained YOLO11-nano detector on a local dataset.

    Every argument defaults to the value this script originally hard-coded
    and is passed straight through to ``model.train``.

    Args:
        data: Path to the edited dataset YAML (the default assumes
            ``data.yaml`` is in the same folder).
        epochs: Number of training epochs.
        imgsz: Training image size in pixels.
        batch: Batch size.
        name: Run name. Results go under ``runs/detect/<name>``; if that
            folder already exists Ultralytics appends a number to the name.
        device: GPU index (e.g. ``0``), or ``'cpu'`` if you don't have a GPU.

    Returns:
        Whatever ``model.train`` returns (the training metrics object).
    """
    # 1. Load the pretrained nano checkpoint.
    # It will download automatically if you don't have it.
    model = YOLO('yolo11n.pt')

    # 2. Train on the local dataset.
    print("Starting Local Training...")
    results = model.train(
        data=data,
        epochs=epochs,  # 50 epochs was judged plenty for a single class
        imgsz=imgsz,
        batch=batch,
        name=name,
        device=device,
    )

    # Report where the weights really went: re-running with the same name
    # creates a renamed run folder, so a hard-coded path can point at an
    # older run. Fall back to the nominal path if the trainer isn't exposed.
    best = getattr(getattr(model, 'trainer', None), 'best', None)
    if best is None:
        best = f"runs/detect/{name}/weights/best.pt"
    print(f"Success! Weights saved at: {best}")
    return results


if __name__ == "__main__":
    train_local()

The Failure Loop

I trained on this first dataset, but it failed immediately: the rim was detected reliably, but the basketball was not, because of motion blur.

Fail 1: Baseline (fail_1.mp4)

I switched to another dataset, but it still struggled.

Fail 2: Dataset Swap (fail_2.mp4)

I tried a third dataset. It became clear that the differences in camera angles, court colors, and lighting were breaking the model.

Fail 3: Angle Issues (fail_3.mp4)

The Pivot: Data Mining

I needed a better way to get data. I built a custom Data Miner to pull "hard examples" (frames where the detector missed the ball) out of my own video files and label them by hand.

smart_miner.py
import cv2
import os
import numpy as np

# --- CONFIG ---
VIDEO_PATH = 'jaquez_fadeaway_make_1.mp4'  # CHANGE THIS to your video file
OUTPUT_DIR = 'nba_smart_data'
CLASS_ID = 0  # 0 for Ball (since we are training a single-class model)

# Open the clip first: a wrong path should stop the tool with a clear message
# instead of opening an empty window with an invalid frame trackbar.
cap = cv2.VideoCapture(VIDEO_PATH)
if not cap.isOpened():
    raise SystemExit(f"Could not open video: {VIDEO_PATH}")
# Container-reported length; may be approximate for some formats.
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

# Setup Folders
img_dir = os.path.join(OUTPUT_DIR, "images")
lbl_dir = os.path.join(OUTPUT_DIR, "labels")
os.makedirs(img_dir, exist_ok=True)
os.makedirs(lbl_dir, exist_ok=True)

# Global state shared with the mouse/trackbar callbacks and the main loop
drawing = False    # True while the left mouse button is held
ix, iy = -1, -1    # corner where the current drag started
bbox = None  # (x1, y1, x2, y2)
current_frame_idx = 0

def on_trackbar(val):
    """Trackbar callback: jump the capture to frame ``val``.

    Any box drawn on the previous frame is discarded, since it no longer
    matches the picture that will be shown.
    """
    global current_frame_idx, bbox
    bbox = None
    current_frame_idx = val
    cap.set(cv2.CAP_PROP_POS_FRAMES, val)

def draw_rect(event, x, y, flags, param):
    """OpenCV mouse callback: click-and-drag to draw the label box.

    While the left button is held, ``bbox`` holds the raw drag rectangle
    (corners possibly "backwards") for the live preview. On release it is
    normalised to ``(x1, y1, x2, y2)`` with x1 <= x2 and y1 <= y2.

    A release without a matching press in this window is ignored, because
    ``ix``/``iy`` would be stale. A click without a drag (zero-area box)
    leaves ``bbox`` as None, since such a box is not a valid label.
    ``flags`` and ``param`` are required by the callback signature but unused.
    """
    global ix, iy, drawing, bbox

    if event == cv2.EVENT_LBUTTONDOWN:
        drawing = True
        ix, iy = x, y
        bbox = None

    elif event == cv2.EVENT_MOUSEMOVE:
        if drawing:
            # Update temporary box for visualization
            bbox = (ix, iy, x, y)

    elif event == cv2.EVENT_LBUTTONUP:
        if not drawing:
            return
        drawing = False
        # Finalize box (handle dragging backwards)
        x1, y1 = min(ix, x), min(iy, y)
        x2, y2 = max(ix, x), max(iy, y)
        bbox = (x1, y1, x2, y2) if (x2 > x1 and y2 > y1) else None

def save_yolo_format(frame, box, frame_num, prefix=None):
    """Save ``frame`` as a JPEG plus a matching one-box YOLO label file.

    Args:
        frame: Image to save (the frame currently shown in the miner).
        box: Pixel corners ``(x1, y1, x2, y2)``. Corner order does not matter
            (a box saved mid-drag can be "backwards"). It is normalised and
            clipped to the image before conversion.
        frame_num: Frame index, appended to the output file name.
        prefix: File-name prefix. Defaults to ``<video stem>_ball_fix``, so
            mining several clips into the same OUTPUT_DIR does not overwrite
            earlier samples that happen to share a frame number.

    Returns:
        True when both files were written. False when nothing was saved,
        either because the box has no area inside the image or because the
        image write failed.
    """
    h, w = frame.shape[:2]
    x1, y1, x2, y2 = box

    # Normalise corner order, then clip to image bounds
    x1, x2 = max(0, min(x1, x2)), min(w, max(x1, x2))
    y1, y2 = max(0, min(y1, y2)), min(h, max(y1, y2))
    bw = x2 - x1
    bh = y2 - y1
    if bw <= 0 or bh <= 0:
        # A zero-area label is invalid; don't leave an orphan image either.
        print(f"⚠️ Box {box} has no area inside the frame. Not saved.")
        return False

    if prefix is None:
        video_stem = os.path.splitext(os.path.basename(VIDEO_PATH))[0]
        prefix = f"{video_stem}_ball_fix"
    filename = f"{prefix}_{frame_num}"

    # 1. Save Image (cv2.imwrite returns False on failure)
    img_path = os.path.join(img_dir, f"{filename}.jpg")
    if not cv2.imwrite(img_path, frame):
        print(f"⚠️ Could not write {img_path}. Not saved.")
        return False

    # 2. Convert Box to YOLO: Center X, Center Y, Width, Height (Normalized 0-1)
    norm_cx = (x1 + bw / 2) / w
    norm_cy = (y1 + bh / 2) / h
    norm_w = bw / w
    norm_h = bh / h

    # 3. Save Label
    txt_path = os.path.join(lbl_dir, f"{filename}.txt")
    with open(txt_path, 'w') as f:
        f.write(f"{CLASS_ID} {norm_cx:.6f} {norm_cy:.6f} {norm_w:.6f} {norm_h:.6f}")

    print(f"[SAVED] {filename} | Box: {x1},{y1},{x2},{y2}")
    return True

# --- MAIN LOOP ---
cv2.namedWindow('Smart Miner')
# The trackbar maximum must be non-negative even if the clip reports no length.
last_frame_idx = max(total_frames - 1, 0)
cv2.createTrackbar('Frame', 'Smart Miner', 0, last_frame_idx, on_trackbar)
cv2.setMouseCallback('Smart Miner', draw_rect)

print("--- SMART MINER CONTROLS ---")
print("MOUSE:   Click & Drag to draw box")
print("SPACE:   Play / Pause")
print("A / D:   Prev / Next Frame (Precise)")
print("S:       SAVE current frame + box")
print("Q:       Quit")
print("-" * 30)

paused = True  # START STOPPED
frame = None      # most recently decoded frame (what is on screen)
frame_idx = None  # index of `frame`; used for the overlay, stepping and saving

while True:
    if not paused:
        ret, next_frame = cap.read()
        if ret:
            frame = next_frame
            # POS_FRAMES now points at the *next* frame to decode, so the
            # frame just read is one behind it.
            current_frame_idx = int(cap.get(cv2.CAP_PROP_POS_FRAMES))
            frame_idx = current_frame_idx - 1
            cv2.setTrackbarPos('Frame', 'Smart Miner', current_frame_idx)
            bbox = None # Clear box on play
        else:
            paused = True # Stop at end
    elif frame_idx != current_frame_idx:
        # Paused on a new index (A/D keys, trackbar, or just paused): decode
        # it once and keep it, instead of seeking and re-decoding the same
        # frame on every 10 ms loop iteration.
        cap.set(cv2.CAP_PROP_POS_FRAMES, current_frame_idx)
        ret, next_frame = cap.read()
        if ret:
            frame = next_frame
            frame_idx = current_frame_idx
        elif frame is None:
            print(f"⚠️ Could not decode frame {current_frame_idx}.")
            break
        else:
            # Past the real end of the clip (the reported frame count can be
            # too high): stay on the last good frame instead of quitting.
            current_frame_idx = frame_idx
            cv2.setTrackbarPos('Frame', 'Smart Miner', current_frame_idx)

    if frame is None: break

    display_frame = frame.copy()

    # Draw the box if it exists
    if bbox:
        x1, y1, x2, y2 = bbox
        cv2.rectangle(display_frame, (x1, y1), (x2, y2), (0, 255, 0), 2)
        # The box may still be mid-drag ("backwards"); anchor the hint above
        # its top-left corner, but keep it inside the frame.
        hint_org = (min(x1, x2), max(min(y1, y2) - 10, 15))
        cv2.putText(display_frame, "READY TO SAVE (Press S)", hint_org,
                    cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 0), 2)

    # Status Overlay (index of the frame actually on screen)
    status = "PAUSED" if paused else "PLAYING"
    cv2.putText(display_frame, f"Status: {status} | Frame: {frame_idx}", (20, 30),
                cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 0, 255), 2)

    cv2.imshow('Smart Miner', display_frame)

    # KEYS
    key = cv2.waitKey(10) & 0xFF

    if key == ord('q'):
        break
    elif key == ord(' '):
        paused = not paused
    elif key == ord('d'): # Next Frame (relative to the frame on screen)
        paused = True
        current_frame_idx = min(frame_idx + 1, last_frame_idx)
        cv2.setTrackbarPos('Frame', 'Smart Miner', current_frame_idx)
        bbox = None
    elif key == ord('a'): # Prev Frame
        paused = True
        current_frame_idx = max(frame_idx - 1, 0)
        cv2.setTrackbarPos('Frame', 'Smart Miner', current_frame_idx)
        bbox = None
    elif key == ord('s'): # SAVE
        if bbox:
            # A False return means nothing was written, so skip the flash.
            if save_yolo_format(frame, bbox, frame_idx) is not False:
                # Visual Flash
                cv2.rectangle(display_frame, (0,0), (display_frame.shape[1], display_frame.shape[0]), (255,255,255), -1)
                cv2.imshow('Smart Miner', display_frame)
                cv2.waitKey(50)
            bbox = None # Reset after save attempt
        else:
            print("⚠️ No box drawn! Cannot save.")

cap.release()
cv2.destroyAllWindows()
Tooling (miner.mp4)

This fixed the ball, but now the rim flickered in and out (Fail 4) during camera pans. So I quickly built a similar miner for the rim and retrained.

Fail 4: Rim Flicker (fail_4.mp4)

The Logic: Red & Green

With stable detection, I implemented the scoring logic. We identify the top of the rim (Red) and the bottom (Green). If the ball doesn't go through the Red zone, it is not a made basket.

Proof of Concept: Zone Identification (success_identification_1.mp4)

This logic worked for a wide variety of shots:

The Parallax Bug

However, this logic had a critical failure: Fail 6. A "floater" missed the rim entirely but fell behind the net. Because of the camera angle (parallax), it appeared to pass through the Red Zone and was scored as a make.

Fail 6: Parallax (fail_6.mp4)

The Strict Gate

We accounted for this by adding a Narrow Gate underneath the rim. The ball must exit through this tiny hole to count.

This successfully rejected the floater (Success 6), but it created a new problem (Fail 7). When the ball swished too hard, it pushed the net backward, missing the narrow gate. We had solved the False Positive but created a False Negative.

Success 6: Floater Fixed (success_6.mp4)
Fail 7: Swish Broken (fail_7.mp4)

Final Production: High & Tight

Finally, we moved the gate higher (up to the neck of the net): it now sits halfway down the rim box and spans only the central 25% of the rim's width. This way, even hard swishes pass through the gate before the net deforms.

production_logic.py
import streamlit as st
import cv2
import tempfile
import os
import numpy as np
from collections import deque
from ultralytics import YOLO

# Page metadata/layout, then the app heading.
st.set_page_config(
    page_title="NBA Vision (Final)",
    page_icon="🏀",
    layout="wide",
)
st.title("🏀 NBA Vision: Production Candidate")

# --- 1. UTILS ---
class RimMemory:
    """Bridge short gaps in rim detection by reusing the last detected box.

    When ``update`` receives no detection, the last detected rim box is
    returned for up to ``max_missing_frames`` consecutive calls. After that,
    None is returned until a new detection arrives.
    """

    def __init__(self, max_missing_frames=60):
        self.max_missing_frames = max_missing_frames
        self.last_rim = None
        self.frames_without_rim = 0  # consecutive calls since the last detection

    def update(self, raw_rim):
        """Return ``raw_rim`` if given; otherwise the remembered box or None."""
        if raw_rim is not None:
            self.last_rim = raw_rim
            self.frames_without_rim = 0
            return raw_rim
        if self.last_rim is not None and self.frames_without_rim < self.max_missing_frames:
            self.frames_without_rim += 1
            return self.last_rim
        return None

# --- GRAVITY KILLER (DUMB SMOOTHER) ---
class DumbSmoother:
    """Moving-average smoother for the ball box (no motion model).

    Averages the last ``history_size`` detections element-wise. If the ball
    goes undetected for more than ``max_gap`` consecutive frames, the old
    history is dropped. Otherwise the first detection after the gap would be
    averaged with positions from before it, placing the box somewhere the
    ball never was.
    """

    def __init__(self, history_size=3, max_gap=None):
        self.history = deque(maxlen=history_size)
        # Default: tolerate as many missed frames as the window holds. With an
        # unbounded window (history_size=None) the history is never reset.
        self.max_gap = history_size if max_gap is None else max_gap
        self.missed = 0  # consecutive update() calls without a detection

    def update(self, raw_box, rim_box=None):
        """Add ``raw_box`` (4 numbers) and return the smoothed box as 4 ints.

        A ``raw_box`` of None means "not detected this frame" and returns None.
        ``rim_box`` is accepted for interface compatibility and is ignored.
        """
        if raw_box is None:
            self.missed += 1
            return None
        if self.max_gap is not None and self.missed > self.max_gap:
            self.history.clear()
        self.missed = 0
        self.history.append(raw_box)
        return np.mean(self.history, axis=0).astype(int).tolist()

# --- 2. SHOT DETECTOR (HIGH & TIGHT) ---
class ShotDetector:
    """Decide whether the ball passed down through the rim ("High & Tight").

    Boxes are ``[x_left, y_top, width, height]`` in pixels.

    - ENTRY: the ball centre lies in the central ``entry_width_frac`` of the
      rim's width and within the top ``entry_depth_frac`` of the rim box.
      This marks a shot as in progress.
    - FINISH: an in-progress shot crosses the line ``finish_depth_frac`` of
      the way down the rim box.
      - Inside the narrow gate (central ``exit_width_frac`` of the rim width),
        it is a make.
      - Outside the gate, it is rejected as a parallax/floater miss.
    - ABORT: an in-progress shot that rises back above the rim top is cancelled.

    The result is binary per clip: once a make is seen, ``update`` keeps
    returning True.
    """

    def __init__(self, entry_width_frac=0.5, entry_depth_frac=0.4,
                 exit_width_frac=0.25, finish_depth_frac=0.50):
        self.basket_made = False
        self.shot_in_progress = False
        self.cooldown = 0  # reserved; not used by update()
        self.entry_width_frac = entry_width_frac
        self.entry_depth_frac = entry_depth_frac
        self.exit_width_frac = exit_width_frac
        self.finish_depth_frac = finish_depth_frac

    def update(self, ball_box, rim_box):
        """Feed one frame; return True once a made basket has been detected."""
        # Sticking to binary "Did a Make Happen?" for this demo
        if self.basket_made:
            return True
        if ball_box is None or rim_box is None:
            return False

        bx, by, bw, bh = ball_box
        rx, ry, rw, rh = rim_box

        ball_cy = int(by + bh // 2)
        ball_cx = int(bx + bw // 2)
        rim_center_x = int(rx + rw // 2)

        # 1. ENTRY (wide) zone
        entry_width = int(rw * self.entry_width_frac)
        ex1 = int(rim_center_x - entry_width // 2)
        ex2 = int(rim_center_x + entry_width // 2)
        entry_y = ry

        # 2. EXIT (high & tight) gate
        exit_width = int(rw * self.exit_width_frac)
        fx1 = int(rim_center_x - exit_width // 2)
        fx2 = int(rim_center_x + exit_width // 2)
        finish_line_y = ry + int(rh * self.finish_depth_frac)

        # A. DETECT ENTRY
        in_entry_x = ex1 < ball_cx < ex2
        in_entry_y = entry_y < ball_cy < entry_y + rh * self.entry_depth_frac
        if in_entry_x and in_entry_y:
            self.shot_in_progress = True

        if self.shot_in_progress:
            # B. DETECT FINISH
            if ball_cy > finish_line_y:
                if fx1 < ball_cx < fx2:
                    self.basket_made = True
                    self.shot_in_progress = False
                    return True
                # REJECTED: ball drifted wide (Parallax/Floater Miss)
                self.shot_in_progress = False

            # C. ABORT (popped back up above the rim)
            if ball_cy < ry:
                self.shot_in_progress = False

        return False

# --- 3. LOAD MODELS ---
@st.cache_resource
def load_models(ball_path='runs/detect/ball_model_blur_fix/weights/best.pt',
                hoop_path='runs/detect/nba_hoop_custom_v1/weights/best.pt'):
    """Load the ball and hoop detectors (cached by ``st.cache_resource``).

    Args:
        ball_path: Custom ball weights (defaults to the blur-fix run).
        hoop_path: Custom hoop weights (defaults to the custom hoop run).

    Returns:
        ``(ball_model, hoop_model)``. A missing weights file falls back to
        the stock ``yolo11n.pt`` checkpoint. That checkpoint was not trained
        on this project's ball/hoop data, so the UI shows a warning instead
        of silently producing meaningless detections.
    """
    models = []
    for label, path in (("Ball", ball_path), ("Hoop", hoop_path)):
        if os.path.exists(path):
            models.append(YOLO(path))
        else:
            st.warning(
                f"{label} weights not found at '{path}'; falling back to "
                "generic 'yolo11n.pt'. Results will not be reliable."
            )
            models.append(YOLO('yolo11n.pt'))
    return models[0], models[1]

ball_model, hoop_model = load_models()


def _open_video_writer(path, fps, size):
    """Open an MP4 writer, preferring H.264 ('avc1') and falling back to 'mp4v'.

    ``cv2.VideoWriter_fourcc`` only packs four characters into an int and
    never raises for a codec this OpenCV build cannot encode. The only
    reliable check is ``isOpened()`` on a real writer. Returns the first
    writer that opens, or None if neither codec works.
    """
    for codec in ('avc1', 'mp4v'):
        writer = cv2.VideoWriter(path, cv2.VideoWriter_fourcc(*codec), fps, size)
        if writer.isOpened():
            return writer
        writer.release()
    return None


def _annotate_clip(input_path, output_path):
    """Detect, score and draw every frame of ``input_path`` into ``output_path``.

    Uses the module-level models and sidebar confidence thresholds.
    Returns True/False for "a made basket was detected", or None if the clip
    could not be decoded or no video writer could be opened.
    """
    cap = cv2.VideoCapture(input_path)
    out = None
    progress_bar = st.progress(0)
    try:
        if not cap.isOpened():
            return None
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        # Keep fractional rates (e.g. 29.97) so the output isn't retimed; fall
        # back to 30 when the container reports no usable rate.
        fps = cap.get(cv2.CAP_PROP_FPS)
        if not fps > 0:
            fps = 30.0
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

        out = _open_video_writer(output_path, fps, (width, height))
        if out is None:
            return None

        frame_cnt = 0
        rim_memory = RimMemory()
        smoother = DumbSmoother() # NO GRAVITY
        scorer = ShotDetector()

        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break

            # 1. AI
            raw_ball, raw_rim = None, None

            # NOTE(review): ball_model is cached across reruns and tracked with
            # persist=True, so tracker state may carry over between uploads.
            b_res = ball_model.track(frame, persist=True, verbose=False, conf=ball_conf)
            if b_res[0].boxes:
                b = b_res[0].boxes[0].xywh.cpu().numpy()[0]
                raw_ball = [b[0], b[1], b[2], b[3]]

            r_res = hoop_model(frame, verbose=False, conf=hoop_conf)
            if r_res[0].boxes:
                r = r_res[0].boxes[0].xywh.cpu().numpy()[0]
                # Centre-based xywh -> top-left xywh
                raw_rim = [r[0]-r[2]/2, r[1]-r[3]/2, r[2], r[3]]

            # 2. PHYSICS
            stable_rim = rim_memory.update(raw_rim)
            final_ball = smoother.update(raw_ball)

            # 3. LOGIC (scorer expects top-left xywh)
            ball_box_score = None
            if final_ball:
                bx, by, bw, bh = final_ball
                ball_box_score = [int(bx-bw/2), int(by-bh/2), int(bw), int(bh)]

            scorer.update(ball_box_score, stable_rim)

            # 4. DRAW
            # Draw Rim (Clean Green)
            if stable_rim:
                rx, ry, rw, rh = map(int, stable_rim)
                cv2.rectangle(frame, (rx, ry), (rx+rw, ry+rh), (0, 255, 0), 2)
                cv2.putText(frame, "RIM", (rx, ry-10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)

            # Draw Ball (Clean Yellow)
            if ball_box_score:
                x, y, w, h = ball_box_score
                cv2.rectangle(frame, (x, y), (x+w, y+h), (0, 255, 255), 2)
                cv2.putText(frame, "BASKETBALL", (x, y-10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 255), 2)

            # Draw Status
            if scorer.basket_made:
                # Flash Green Border
                cv2.rectangle(frame, (0,0), (width, height), (0, 255, 0), 10)
                # Center Text
                cv2.putText(frame, "BASKET!", (width//2 - 100, 100), cv2.FONT_HERSHEY_SIMPLEX, 2, (0, 255, 0), 4)

            out.write(frame)
            frame_cnt += 1
            if total_frames > 0:  # frame count can be unknown (0)
                progress_bar.progress(min(frame_cnt / total_frames, 1.0))

        return scorer.basket_made
    finally:
        cap.release()
        if out is not None:
            out.release()  # finalizes the file before the caller reads it
        progress_bar.empty()


# --- MAIN UI ---
st.sidebar.header("Configuration")
ball_conf = st.sidebar.slider("Ball Confidence", 0.1, 0.9, 0.25)
hoop_conf = st.sidebar.slider("Hoop Confidence", 0.3, 0.9, 0.5)

uploaded_file = st.file_uploader("Upload Clip", type=["mp4", "mov"])

if uploaded_file is not None:
    # Spill the upload to disk for OpenCV. Leaving the `with` closes, and so
    # flushes, the file; reading it while still buffered can truncate the clip.
    in_suffix = os.path.splitext(uploaded_file.name)[1] or '.mp4'
    with tempfile.NamedTemporaryFile(delete=False, suffix=in_suffix) as tfile:
        tfile.write(uploaded_file.read())
    with tempfile.NamedTemporaryFile(delete=False, suffix='.mp4') as ofile:
        output_path = ofile.name

    try:
        made = _annotate_clip(tfile.name, output_path)
        if made is None:
            st.error("Could not read the uploaded clip or open a video writer.")
        else:
            result = 1 if made else 0
            if result == 1:
                st.success(f"### Made Basket: {result}")
            else:
                st.error(f"### Made Basket: {result}")

            with open(output_path, 'rb') as f:
                st.video(f.read())
    finally:
        # Streamlit reruns this script on every widget change; without cleanup
        # each rerun would leave two more temp files on disk.
        for path in (tfile.name, output_path):
            try:
                os.remove(path)
            except OSError:
                pass

This fixed the swish (Success 7). We checked back with the floater, and it still correctly rejected it (Success 8).

Success 7: Swish Fixed (success_7.mp4)
Success 8: Floater Still Fixed (success_8.mp4)

TL;DR

  • PROBLEM: Binary Classification was too expensive/brittle.
  • SOLUTION: Object Detection + Physics Logic.
  • FAILURES: Generic datasets failed on angles/colors.
  • PIVOT: Built custom Active Learning Miner.
  • FINAL LOGIC: "High & Tight" Gate (50% depth, 25% width).