import wave
import random
import numpy as np

#################################################################################
# Swap one short, randomly placed stretch of audio between two WAV files.
#
# Reads example1.wav and example2.wav (which must share channel count, sample
# width and frame rate), trims both to the shorter file's length, exchanges the
# same 1/8-second window of frames between them, and writes the results to
# swapped_example1.wav and swapped_example2.wav.

output_filename1 = 'swapped_example1.wav'
output_filename2 = 'swapped_example2.wav'


def _read_frames(wav_file, num_frames):
    """Read up to *num_frames* frames from an open ``wave`` reader.

    Returns a writable ``uint8`` array of shape ``(frames, bytes_per_frame)``.
    Working on raw frame bytes (instead of assuming int16 samples) keeps the
    swap correct for any sample width and independent of byte order.
    """
    frame_bytes = wav_file.getsampwidth() * wav_file.getnchannels()
    raw = wav_file.readframes(num_frames)
    # np.frombuffer over bytes is read-only; copy so slices can be assigned.
    return np.frombuffer(raw, dtype=np.uint8).reshape(-1, frame_bytes).copy()


def swap_random_segment(data1, data2, segment_len, rng=None):
    """Exchange a random window of *segment_len* rows between two arrays in place.

    Args:
        data1, data2: arrays indexed by frame along axis 0, with matching
            trailing shape.
        segment_len: number of frames to swap.
        rng: object with a ``randint`` method (e.g. ``random.Random``);
            defaults to the ``random`` module.

    Returns:
        The start frame of the swapped window.

    Raises:
        ValueError: if ``segment_len`` is negative or longer than the shorter
            of the two arrays.
    """
    if rng is None:
        rng = random
    num_frames = min(len(data1), len(data2))
    if not 0 <= segment_len <= num_frames:
        raise ValueError(
            f"cannot swap {segment_len} frames from audio with {num_frames} frames")
    # Any start that keeps the whole window inside the audio is allowed.
    start = rng.randint(0, num_frames - segment_len)
    end = start + segment_len
    # Slices are views: copy both sides first, otherwise the second assignment
    # would read back data the first one just overwrote.
    data1[start:end], data2[start:end] = (
        data2[start:end].copy(), data1[start:end].copy())
    return start


def main(input_filename1='example1.wav', input_filename2='example2.wav',
         out_filename1=output_filename1, out_filename2=output_filename2,
         segment_seconds=1 / 8, rng=None):
    """Swap a random *segment_seconds*-long window between two WAV files.

    Both outputs are truncated to the shorter input's length and written with
    the inputs' shared channel count, sample width and frame rate.

    Raises:
        ValueError: if the inputs differ in channels, sample width or frame
            rate, or are too short to hold one segment.
    """
    with wave.open(input_filename1, 'rb') as wav_file1, \
            wave.open(input_filename2, 'rb') as wav_file2:
        params1 = (wav_file1.getnchannels(), wav_file1.getsampwidth(),
                   wav_file1.getframerate())
        params2 = (wav_file2.getnchannels(), wav_file2.getsampwidth(),
                   wav_file2.getframerate())
        # Swapping raw frames is only meaningful when both files share a layout.
        if params1 != params2:
            raise ValueError(
                "incompatible WAV formats (channels, sample width, frame rate): "
                f"{params1} vs {params2}")
        num_channels, sample_width, frame_rate = params1
        num_frames = min(wav_file1.getnframes(), wav_file2.getnframes())
        data1 = _read_frames(wav_file1, num_frames)
        data2 = _read_frames(wav_file2, num_frames)

    swap_random_segment(data1, data2, int(frame_rate * segment_seconds), rng)

    # Write modified data back to new WAV files.
    for path, data in ((out_filename1, data1), (out_filename2, data2)):
        with wave.open(path, 'wb') as output_file:
            output_file.setnchannels(num_channels)
            output_file.setsampwidth(sample_width)
            output_file.setframerate(frame_rate)
            output_file.writeframes(data.tobytes())


if __name__ == "__main__":
    main()
