Added DATA_DIR hyperparameter

This commit is contained in:
Harry Stuart 2020-01-01 15:49:17 +11:00
parent 05cf4c4dac
commit b4f62df3fa

View File

@ -26,6 +26,8 @@ EPOCHS_PER_SAMPLE = 2
BATCH_SIZE = 16 BATCH_SIZE = 16
Fs = 16000 Fs = 16000
DATA_DIR = r"D:\ML_Datasets\mancini_piano\piano\train"
# Define class that contains GAN infrastructure # Define class that contains GAN infrastructure
class GAN: class GAN:
def __init__(self, model_dims=MODEL_DIMS, num_samples=NUM_SAMPLES, def __init__(self, model_dims=MODEL_DIMS, num_samples=NUM_SAMPLES,
@ -132,8 +134,8 @@ gan = GAN()
# Create training data # Create training data
X_train = [] X_train = []
for file in os.listdir(r"D:\ML_Datasets\mancini_piano\piano\train"): for file in os.listdir(DATA_DIR): ### Modify for your data directory
with open(r"D:\ML_Datasets\mancini_piano\piano\train" + fr"\{file}", "rb") as f: with open(DATA_DIR + fr"\{file}", "rb") as f:
samples, _ = librosa.load(f, Fs) samples, _ = librosa.load(f, Fs)
# Pad short audio files to NUM_SAMPLES duration # Pad short audio files to NUM_SAMPLES duration
if len(samples) < NUM_SAMPLES: if len(samples) < NUM_SAMPLES: