import streamlit as st
import numpy as np
from PIL import Image
from pydub import AudioSegment
import wave
import moviepy.editor as mp
import os
import io

# Function to encode text into an image with delimiter
def encode_image(image_path, payload_text, num_lsb):
    """Hide ``payload_text`` in the low bits of an image's R, G, B channels.

    Each character is stored as one 8-bit value, followed by the 16-bit
    terminator ``1111111111111110``.  The bit stream is cut into
    ``num_lsb``-sized pieces that replace the lowest ``num_lsb`` bits of
    successive channel values in row-major order.

    Args:
        image_path: Path or file-like object accepted by ``Image.open``.
        payload_text: Text to hide; every character must have a code point
            of 254 or lower.
        num_lsb: Number of low bits to overwrite per channel (1-8).

    Returns:
        A new PIL image carrying the payload.

    Raises:
        ValueError: If ``num_lsb`` is out of range, the payload contains a
            character that cannot be stored, or the payload does not fit.
    """
    if not 1 <= num_lsb <= 8:
        raise ValueError("num_lsb must be between 1 and 8.")
    # 0xFF is the decoder's end-of-data marker and wider code points would
    # break the fixed 8-bit framing, so refuse them instead of corrupting.
    if any(ord(char) > 0xFE for char in payload_text):
        raise ValueError(
            "Payload contains characters that cannot be stored as a single "
            "byte below the 0xFF end-of-data marker."
        )

    image = Image.open(image_path)
    # Grayscale / palette images have no per-pixel channel axis to index.
    if image.mode not in ("RGB", "RGBA"):
        image = image.convert("RGB")
    img_data = np.array(image)

    # Convert payload to binary and add delimiter
    payload_bin = ''.join(format(ord(char), '08b') for char in payload_text) + '1111111111111110'
    # Right-pad so the final piece is a full num_lsb bits wide.
    payload_bin += '0' * (-len(payload_bin) % num_lsb)

    # Capacity is counted in bits and only over the three colour channels
    # that are actually written (alpha is left untouched).
    height, width = img_data.shape[:2]
    capacity_bits = height * width * 3 * num_lsb
    if len(payload_bin) > capacity_bits:
        raise ValueError("Payload size exceeds cover image capacity.")

    chunks = np.array(
        [int(payload_bin[i:i + num_lsb], 2) for i in range(0, len(payload_bin), num_lsb)],
        dtype=np.uint8,
    )
    # Build the mask as an unsigned byte: a negative Python int cannot be
    # combined with uint8 data on NumPy >= 2.
    keep_mask = np.uint8(0xFF & ~((1 << num_lsb) - 1))

    rgb = np.ascontiguousarray(img_data[..., :3]).reshape(-1)
    rgb[:chunks.size] = (rgb[:chunks.size] & keep_mask) | chunks
    img_data[..., :3] = rgb.reshape(height, width, 3)

    return Image.fromarray(img_data)

# Function to decode text from an image with delimiter
def decode_image(stego_image_path, num_lsb):
    """Recover text hidden in the low bits of an image's R, G, B channels.

    The lowest ``num_lsb`` bits of every colour channel are concatenated in
    row-major order, regrouped into bytes, and decoded one character per
    byte until the first ``0xFF`` byte (the end-of-data marker).

    Args:
        stego_image_path: Path or file-like object accepted by ``Image.open``.
        num_lsb: Number of low bits that were used per channel (1-8).

    Returns:
        The decoded text.  If no marker is present, every byte of the image
        is decoded (including a final partial byte).

    Raises:
        ValueError: If ``num_lsb`` is outside 1-8.
    """
    if not 1 <= num_lsb <= 8:
        raise ValueError("num_lsb must be between 1 and 8.")

    image = Image.open(stego_image_path)
    # Grayscale / palette images have no per-pixel channel axis to index.
    if image.mode not in ("RGB", "RGBA"):
        image = image.convert("RGB")
    img_data = np.array(image)

    values = np.ascontiguousarray(img_data[..., :3], dtype=np.uint8).reshape(-1)
    # unpackbits is MSB-first, so the last num_lsb columns are the low bits
    # in the same order the encoder wrote them.
    bits = np.unpackbits(values[:, np.newaxis], axis=1)[:, 8 - num_lsb:].reshape(-1)

    whole = bits.size - bits.size % 8
    byte_values = np.packbits(bits[:whole])

    # Stop at the first end-of-data byte.
    marker = np.flatnonzero(byte_values == 0xFF)
    if marker.size:
        return byte_values[:marker[0]].tobytes().decode('latin-1')

    # No marker: latin-1 maps each byte to the same code point chr() would.
    payload = byte_values.tobytes().decode('latin-1')
    if whole < bits.size:
        payload += chr(int(''.join(map(str, bits[whole:].tolist())), 2))
    return payload

# Function to encode text into text with LSB
def encode_text_to_text(payload_text, num_lsb):
    """Return ``payload_text`` as a string of '0'/'1' characters.

    Every character becomes exactly eight bits, followed by the 16-bit
    terminator ``1111111111111110``.  A fixed width is required so the
    stream can be split back into characters.

    Args:
        payload_text: Text to convert; every character must have a code
            point of 254 or lower.
        num_lsb: Accepted for signature compatibility; it does not affect
            the output.

    Returns:
        The binary string including the terminator.

    Raises:
        ValueError: If a character cannot be stored in one byte below the
            ``0xFF`` end-of-data marker.
    """
    # 0xFF would be read back as the terminator; wider code points would
    # break the 8-bit framing.
    if any(ord(char) > 0xFE for char in payload_text):
        raise ValueError(
            "Payload contains characters that cannot be stored as a single "
            "byte below the 0xFF end-of-data marker."
        )
    # Convert text to binary and add delimiter
    return ''.join(format(ord(char), '08b') for char in payload_text) + '1111111111111110'

# Function to decode text from text binary
def decode_text_from_text(payload_bin, num_lsb):
    """Turn a string of '0'/'1' characters back into text.

    The input is read eight bits at a time, one character per group, until
    a group equal to ``11111111`` (the end-of-data marker) is reached.

    Args:
        payload_bin: Binary string, optionally containing whitespace such
            as a trailing newline from a text file.
        num_lsb: Accepted for signature compatibility; it does not affect
            decoding because characters are always eight bits wide.

    Returns:
        The decoded text.

    Raises:
        ValueError: If the input contains characters other than 0/1
            (after whitespace removal).
    """
    # Text files commonly end with a newline; drop whitespace so it is not
    # mistaken for payload bits.
    payload_bin = ''.join(payload_bin.split())

    decoded = []
    # Always advance a whole byte: stepping by num_lsb would make the
    # 8-bit windows overlap.
    for start in range(0, len(payload_bin), 8):
        byte = payload_bin[start:start + 8]
        if byte == '11111111':  # End of data (delimiter)
            break
        decoded.append(chr(int(byte, 2)))
    return ''.join(decoded)

# Function to encode text into an audio file (MP3 and WAV supported)
def encode_audio(audio_path, text, num_lsb):
    """Hide ``text`` in the lowest bit of successive audio data bytes.

    Each character is stored as eight bits followed by the 16-bit
    terminator ``1111111111111110``; bit *i* replaces the least significant
    bit of byte *i* of the raw audio data.  The result is always written as
    a WAV file next to the input, named ``<input stem>_stego.wav``.

    Args:
        audio_path: Path to an ``.mp3`` or WAV file.
        text: Text to hide; every character must have a code point of 254
            or lower.
        num_lsb: Accepted for signature compatibility; exactly one bit per
            byte is used regardless of its value.

    Returns:
        Path of the written stego WAV file.

    Raises:
        ValueError: If the text contains a character that cannot be stored,
            or the payload needs more bytes than the audio provides.
    """
    # 0xFF is the decoder's end-of-data marker and wider code points would
    # break the fixed 8-bit framing.
    if any(ord(char) > 0xFE for char in text):
        raise ValueError(
            "Payload contains characters that cannot be stored as a single "
            "byte below the 0xFF end-of-data marker."
        )

    is_mp3 = audio_path.lower().endswith('.mp3')
    if is_mp3:
        audio = AudioSegment.from_mp3(audio_path)
        frames = bytearray(audio.raw_data)
        wav_params = None
    else:
        # Context manager guarantees the reader is closed.
        with wave.open(audio_path, 'rb') as source:
            wav_params = source.getparams()
            frames = bytearray(source.readframes(source.getnframes()))

    binary_text = ''.join(format(ord(char), '08b') for char in text) + '1111111111111110'  # Add delimiter

    # One payload bit per audio byte; refuse rather than silently truncate.
    if len(binary_text) > len(frames):
        raise ValueError("Payload size exceeds cover audio capacity.")

    for index, bit in enumerate(binary_text):
        frames[index] = (frames[index] & 0xFE) | int(bit)

    new_audio_path = os.path.splitext(audio_path)[0] + "_stego.wav"
    if is_mp3:
        new_audio = AudioSegment(
            data=bytes(frames),
            sample_width=audio.sample_width,
            frame_rate=audio.frame_rate,
            channels=audio.channels
        )
        new_audio.export(new_audio_path, format="wav")
    else:
        with wave.open(new_audio_path, 'wb') as target:
            target.setparams(wav_params)
            target.writeframes(frames)
    return new_audio_path

# Function to decode text from an audio file
def decode_audio(audio_path, num_lsb):
    """Recover text hidden in the lowest bit of successive audio data bytes.

    The least significant bit of each raw audio byte is collected, grouped
    into bytes (most significant bit first) and decoded one character per
    byte until a full ``0xFF`` byte, the end-of-data marker, is reached.

    Args:
        audio_path: Path to an ``.mp3`` or WAV file, or a binary file-like
            object (its ``name`` attribute, when present, is used to detect
            MP3 input).
        num_lsb: Accepted for signature compatibility; exactly one bit per
            byte is read regardless of its value.

    Returns:
        The decoded text.  If no marker is present, the whole file is
        decoded (including a final partial byte).
    """
    # File-like uploads have no .lower(); fall back to their name if any.
    source_name = getattr(audio_path, "name", audio_path)
    if str(source_name).lower().endswith('.mp3'):
        audio = AudioSegment.from_mp3(audio_path)
        frames = bytearray(audio.raw_data)
    else:
        # Context manager guarantees the reader is closed.
        with wave.open(audio_path, 'rb') as audio:
            frames = bytearray(audio.readframes(audio.getnframes()))

    decoded = []
    for start in range(0, len(frames), 8):
        group = frames[start:start + 8]
        value = 0
        for sample in group:
            value = (value << 1) | (sample & 1)
        # Only a complete byte of ones counts as the delimiter.
        if len(group) == 8 and value == 0xFF:
            break
        decoded.append(chr(value))

    return ''.join(decoded)

# Function to encode text into an MP4 video file
def encode_video(video_path, text, num_lsb):
    """Hide ``text`` in the low bits of a video's R, G, B channel values.

    Each character is stored as eight bits followed by the 16-bit
    terminator ``1111111111111110``.  The bit stream is cut into
    ``num_lsb``-sized pieces that replace the lowest ``num_lsb`` bits of
    successive channel values, frame by frame in row-major order.  The
    result is written next to the input as ``<input stem>_stego.mp4`` with
    the original audio track attached.

    Args:
        video_path: Path to the cover video.
        text: Text to hide; every character must have a code point of 254
            or lower.
        num_lsb: Number of low bits to overwrite per channel (1-8).

    Returns:
        Path of the written stego video.

    Raises:
        ValueError: If ``num_lsb`` is out of range, the text contains a
            character that cannot be stored, or the payload does not fit.
    """
    if not 1 <= num_lsb <= 8:
        raise ValueError("num_lsb must be between 1 and 8.")
    # 0xFF is the decoder's end-of-data marker and wider code points would
    # break the fixed 8-bit framing.
    if any(ord(char) > 0xFE for char in text):
        raise ValueError(
            "Payload contains characters that cannot be stored as a single "
            "byte below the 0xFF end-of-data marker."
        )

    # Convert text to binary
    binary_text = ''.join(format(ord(char), '08b') for char in text) + '1111111111111110'
    # Right-pad so the final piece is a full num_lsb bits wide.
    binary_text += '0' * (-len(binary_text) % num_lsb)
    chunks = np.array(
        [int(binary_text[i:i + num_lsb], 2) for i in range(0, len(binary_text), num_lsb)],
        dtype=np.uint8,
    )
    # Build the mask as an unsigned byte: a negative Python int cannot be
    # combined with uint8 data on NumPy >= 2.
    keep_mask = np.uint8(0xFF & ~((1 << num_lsb) - 1))

    video = mp.VideoFileClip(video_path)
    try:
        frames = [frame.copy() for frame in video.iter_frames()]  # Make frames writable

        capacity = sum(frame.shape[0] * frame.shape[1] * 3 for frame in frames)
        if chunks.size > capacity:
            raise ValueError("Payload size exceeds cover video capacity.")

        written = 0
        for frame in frames:
            if written >= chunks.size:
                break  # payload fully embedded; leave later frames untouched
            height, width = frame.shape[:2]
            rgb = np.ascontiguousarray(frame[..., :3]).reshape(-1)
            take = min(rgb.size, chunks.size - written)
            rgb[:take] = (rgb[:take] & keep_mask) | chunks[written:written + take]
            frame[..., :3] = rgb.reshape(height, width, 3)
            written += take

        # Save the modified frames back to a video
        new_video_path = os.path.splitext(video_path)[0] + "_stego.mp4"
        new_video = mp.ImageSequenceClip(frames, fps=video.fps)
        new_video = new_video.set_audio(video.audio)  # Add audio to the stego video
        # NOTE(review): libx264 is normally a lossy codec, so the embedded
        # low bits may not survive this step - confirm a round trip through
        # decode_video, or switch to a lossless codec/container.
        new_video.write_videofile(new_video_path, codec='libx264')
    finally:
        # Release the reader held by the source clip.
        video.close()

    return new_video_path

# Function to decode text from an MP4 video file
def decode_video(video_path, num_lsb):
    """Recover text hidden in the low bits of a video's R, G, B values.

    The lowest ``num_lsb`` bits of every colour channel are concatenated
    frame by frame in row-major order, regrouped into bytes, and decoded one
    character per byte until the first ``0xFF`` byte (end-of-data marker).
    Frames are processed one at a time and reading stops at the marker.

    Args:
        video_path: Path to the stego video.
        num_lsb: Number of low bits that were used per channel (1-8).

    Returns:
        The decoded text.  If no marker is present, every frame is decoded
        (including a final partial byte).

    Raises:
        ValueError: If ``num_lsb`` is outside 1-8.
    """
    if not 1 <= num_lsb <= 8:
        raise ValueError("num_lsb must be between 1 and 8.")

    video = mp.VideoFileClip(video_path)
    decoded = []
    # Bits left over when a frame does not end on a byte boundary.
    pending = np.empty(0, dtype=np.uint8)
    try:
        for frame in video.iter_frames():
            # assumes frames are 8-bit (height, width, channels) arrays - TODO confirm
            values = np.ascontiguousarray(frame[..., :3], dtype=np.uint8).reshape(-1)
            # unpackbits is MSB-first, so the last num_lsb columns are the
            # low bits in the order the encoder wrote them.
            bits = np.unpackbits(values[:, np.newaxis], axis=1)[:, 8 - num_lsb:].reshape(-1)
            bits = np.concatenate((pending, bits))

            whole = bits.size - bits.size % 8
            byte_values = np.packbits(bits[:whole])
            pending = bits[whole:]

            marker = np.flatnonzero(byte_values == 0xFF)
            if marker.size:
                # latin-1 maps each byte to the code point chr() would give.
                decoded.append(byte_values[:marker[0]].tobytes().decode('latin-1'))
                return ''.join(decoded)
            decoded.append(byte_values.tobytes().decode('latin-1'))
    finally:
        # Release the reader held by the clip.
        video.close()

    if pending.size:
        decoded.append(chr(int(''.join(map(str, pending.tolist())), 2)))
    return ''.join(decoded)

# Main Streamlit interface
def _persist_upload(uploaded_file):
    """Write a Streamlit upload to a temporary file and return its path."""
    import tempfile

    # Keep the original extension: the audio codec branches on it.
    suffix = os.path.splitext(uploaded_file.name)[1]
    with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as handle:
        handle.write(uploaded_file.getvalue())
    return handle.name


def _discard(*paths):
    """Best-effort removal of temporary files."""
    for path in paths:
        if not path:
            continue
        try:
            os.remove(path)
        except OSError:
            # Cleanup only: a missing or still-locked temp file must not
            # mask the result (or the real error) of the operation.
            pass


def _encode_via_path(encoder, cover_file, payload_text, num_lsb):
    """Run a path-based encoder on an upload and return the stego bytes."""
    cover_path = _persist_upload(cover_file)
    stego_path = None
    try:
        stego_path = encoder(cover_path, payload_text, num_lsb)
        with open(stego_path, "rb") as stego:
            return stego.read()
    finally:
        _discard(cover_path, stego_path)


def _decode_via_path(decoder, stego_file, num_lsb):
    """Run a path-based decoder on an upload and return the decoded text."""
    stego_path = _persist_upload(stego_file)
    try:
        return decoder(stego_path, num_lsb)
    finally:
        _discard(stego_path)


# Main Streamlit interface
def main():
    """Render the Streamlit UI for hiding text in, and recovering it from, media.

    Encoding supports a text payload (typed or uploaded as ``.txt``) with a
    text, image, audio or video cover.  Other payload types are reported as
    unsupported.  Decoding shows the recovered text for every stego type.
    Errors raised by the codecs are displayed with ``st.error``.
    """
    st.title("Steganography with Streamlit")
    st.write("Encode and decode text, images, audio, and videos using LSB steganography.")

    # Sidebar for encoding/decoding
    st.sidebar.title("Options")
    operation = st.sidebar.radio("Choose operation", ("Encode", "Decode"))

    kinds = ("Text", "Image", "Audio", "Video")
    # File type filters for each category
    extensions = {
        "Text": ["txt"],
        "Image": ["jpeg", "jpg", "png", "bmp", "gif"],
        "Audio": ["mp3", "wav"],
        "Video": ["mp4"],
    }

    if operation == "Encode":
        payload_type = st.sidebar.radio("Select payload type", kinds)
        cover_type = st.sidebar.radio("Select cover object type", kinds)

        payload_labels = {
            "Text": "Upload a text file payload",
            "Image": "Upload an image file payload",
            "Audio": "Upload an audio file payload",
            "Video": "Upload a video file payload",
        }
        cover_labels = {
            "Text": "Upload a text file cover object",
            "Image": "Upload an image file cover object",
            "Audio": "Upload an audio file cover object",
            "Video": "Upload a video file cover object",
        }

        # Upload Payload File
        payload_file = st.file_uploader(payload_labels[payload_type], type=extensions[payload_type])
        typed_text = ""
        if payload_type == "Text":
            typed_text = st.text_area("Or enter the text payload:")

        # Upload Cover Object File
        cover_file = st.file_uploader(cover_labels[cover_type], type=extensions[cover_type])

        # Select number of LSBs
        num_lsb = st.slider("Select number of LSBs to use", 1, 8, 1)

        # Typed text alone is a valid payload; an upload is not required.
        has_payload = payload_file is not None or bool(typed_text)
        if has_payload and cover_file is not None:
            if st.button("Encode Payload into Cover Object"):
                try:
                    if payload_type != "Text":
                        st.warning("Only text payloads are supported at the moment.")
                    else:
                        # Typed text wins; otherwise fall back to the file.
                        if typed_text:
                            payload_text = typed_text
                        else:
                            payload_text = payload_file.getvalue().decode("utf-8")

                        if cover_type == "Image":
                            stego_image = encode_image(cover_file, payload_text, num_lsb)
                            st.image(stego_image, caption="Stego Image with Hidden Text")
                            buf = io.BytesIO()
                            stego_image.save(buf, format='PNG')
                            byte_im = buf.getvalue()
                            st.download_button("Download Stego Image", data=byte_im, file_name="stego_image.png", mime="image/png")
                        elif cover_type == "Audio":
                            stego_bytes = _encode_via_path(encode_audio, cover_file, payload_text, num_lsb)
                            st.audio(stego_bytes, format="audio/wav")
                            st.download_button("Download Stego Audio", data=stego_bytes, file_name="stego_audio.wav", mime="audio/wav")
                        elif cover_type == "Video":
                            stego_bytes = _encode_via_path(encode_video, cover_file, payload_text, num_lsb)
                            st.video(stego_bytes)
                            st.download_button("Download Stego Video", data=stego_bytes, file_name="stego_video.mp4", mime="video/mp4")
                        else:
                            stego_text = encode_text_to_text(payload_text, num_lsb)
                            st.text_area("Stego Text", value=stego_text)
                            st.download_button("Download Stego Text", data=stego_text, file_name="stego_text.txt", mime="text/plain")
                except Exception as e:
                    # UI boundary: show any codec failure to the user.
                    st.error(f"Error: {e}")

    # Stego object type selection for decoding
    elif operation == "Decode":
        # Select the type of stego object that contains the hidden payload
        stego_type = st.sidebar.radio("Select stego object type", kinds)

        stego_labels = {
            "Text": "Upload a stego text file",
            "Image": "Upload a stego image file",
            "Audio": "Upload a stego audio file",
            "Video": "Upload a stego video file",
        }

        # Upload Stego File (the file containing hidden data)
        stego_file = st.file_uploader(stego_labels[stego_type], type=extensions[stego_type])

        # Select number of LSBs used during encoding
        num_lsb = st.slider("Select number of LSBs used during encoding", 1, 8, 1)

        # Decode Button
        if stego_file is not None:
            if st.button("Decode Payload from Stego Object"):
                try:
                    if stego_type == "Text":
                        decoded_text = decode_text_from_text(stego_file.read().decode("utf-8"), num_lsb)
                    elif stego_type == "Image":
                        decoded_text = decode_image(stego_file, num_lsb)
                        stego_file.seek(0)  # decoding consumed the stream
                        st.image(stego_file, caption="Stego Image with Hidden Data")
                    elif stego_type == "Audio":
                        decoded_text = _decode_via_path(decode_audio, stego_file, num_lsb)
                        is_mp3 = stego_file.name.lower().endswith(".mp3")
                        st.audio(stego_file.getvalue(), format="audio/mpeg" if is_mp3 else "audio/wav")
                    else:
                        decoded_text = _decode_via_path(decode_video, stego_file, num_lsb)
                        st.video(stego_file.getvalue())
                    # Every stego type yields text; always show it.
                    st.text_area("Decoded Text", value=decoded_text)
                except Exception as e:
                    # UI boundary: show any codec failure to the user.
                    st.error(f"Error: {e}")

# Build the UI only when this file is executed as a script, not when it is
# imported as a module.
if __name__ == "__main__":
    main()
