import cv2
import numpy as np
import math

# === CONFIGURATION ===
video_path = r"C:\Users\gqdee\OneDrive\tennis\Tournaments\Match Videos\Nikith.mp4"  # 🛠 Replace with your video file
output_path = "tennis_ball_speed_output.mp4"
fps = 30  # 🛠 Set your video frame rate
pixels_per_meter = 40  # 🛠 Calibrate based on your video (see notes below)

# Tennis ball HSV color range (adjust for lighting/camera)
lower_yellow = np.array([25, 80, 80])
upper_yellow = np.array([40, 255, 255])

# === VIDEO SETUP ===
cap = cv2.VideoCapture(video_path)
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))

# === SPEED TRACKING ===
prev_center = None
speeds_mph = []

while cap.isOpened():
    ret, frame = cap.read()
    if not ret:
        break

    # Convert to HSV and apply mask for tennis ball color
    hsv = cv2.cvtColor(frame, cv2.COLOR_BGR2HSV)
    mask = cv2.inRange(hsv, lower_yellow, upper_yellow)
    mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, np.ones((5, 5), np.uint8))

    contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    ball_center = None

    if contours:
        largest = max(contours, key=cv2.contourArea)
        ((x, y), radius) = cv2.minEnclosingCircle(largest)

        if radius > 3:
            ball_center = (int(x), int(y))
            # 🔴 Draw red circle around tennis ball
            cv2.circle(frame, ball_center, int(radius), (0, 0, 255), 2)

            if prev_center:
                dx = (ball_center[0] - prev_center[0]) / pixels_per_meter
                dy = (ball_center[1] - prev_center[1]) / pixels_per_meter
                distance_m = math.sqrt(dx**2 + dy**2)
                speed_mps = distance_m * fps
                speed_mph = speed_mps * 2.23694  # Convert to MPH
                speeds_mph.append(speed_mph)

            prev_center = ball_center

    # === DISPLAY SPEED INFO ===
    if speeds_mph:
        current_speed = speeds_mph[-1]
        avg_speed = sum(speeds_mph) / len(speeds_mph)
        top_speed = max(speeds_mph)
        bottom_speed = min(speeds_mph)

        def draw_speed(label, value, y_offset):
            cv2.putText(frame, f"{label}: {value:.1f} MPH", (10, y_offset),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2)

        draw_speed("Current Speed", current_speed, 30)
        draw_speed("Average Speed", avg_speed, 60)
        draw_speed("Top Speed", top_speed, 90)
        draw_speed("Lowest Speed", bottom_speed, 120)

    out.write(frame)
    cv2.imshow("Tennis Ball Speed Tracker", frame)
    if cv2.waitKey(1) & 0xFF == ord('q'):
        break

# === CLEANUP ===
cap.release()
out.release()
cv2.destroyAllWindows()
