Generate music with an RNN


This tutorial shows you how to generate musical notes using a simple recurrent neural network (RNN). You will train a model using a collection of piano MIDI files from the MAESTRO dataset. Given a sequence of notes, your model will learn to predict the next note in the sequence. You can generate longer sequences of notes by calling the model repeatedly.

This tutorial contains complete code to parse and create MIDI files. You can learn more about how RNNs work by visiting the Text generation with an RNN tutorial.

Setup

This tutorial uses the pretty_midi library to create and parse MIDI files, and pyfluidsynth for generating audio playback in Colab.

sudo apt install -y fluidsynth
pip install --upgrade pyfluidsynth
pip install pretty_midi
import collections
import datetime
import fluidsynth
import glob
import numpy as np
import pathlib
import pandas as pd
import pretty_midi
import seaborn as sns
import tensorflow as tf

from IPython import display
from matplotlib import pyplot as plt
from typing import Optional
seed = 42
tf.random.set_seed(seed)
np.random.seed(seed)

# Sampling rate for audio playback
_SAMPLING_RATE = 16000

Download the MAESTRO dataset

data_dir = pathlib.Path('data/maestro-v2.0.0')
if not data_dir.exists():
  tf.keras.utils.get_file(
      'maestro-v2.0.0-midi.zip',
      origin='https://storage.googleapis.com/magentadata/datasets/maestro/v2.0.0/maestro-v2.0.0-midi.zip',
      extract=True,
      cache_dir='.', cache_subdir='data',
  )
Downloading data from https://storage.googleapis.com/magentadata/datasets/maestro/v2.0.0/maestro-v2.0.0-midi.zip
59243107/59243107 [==============================] - 0s 0us/step

The dataset contains about 1,300 MIDI files.

filenames = glob.glob(str(data_dir/'**/*.mid*'))
print('Number of files:', len(filenames))
Number of files: 1282

Process a MIDI file

First, use pretty_midi to parse a single MIDI file and inspect the format of the notes. If you would like to download the MIDI file below and play it on your computer, you can do so in Colab by running files.download(sample_file) after importing files with from google.colab import files.

sample_file = filenames[1]
print(sample_file)
data/maestro-v2.0.0/2008/MIDI-Unprocessed_05_R1_2008_01-04_ORIG_MID--AUDIO_05_R1_2008_wav--4.midi

Generate a PrettyMIDI object for the sample MIDI file.

pm = pretty_midi.PrettyMIDI(sample_file)

Play the sample file. The playback widget may take several seconds to load.

def display_audio(pm: pretty_midi.PrettyMIDI, seconds=30):
  waveform = pm.fluidsynth(fs=_SAMPLING_RATE)
  # Take a sample of the generated waveform to mitigate kernel resets
  waveform_short = waveform[:seconds*_SAMPLING_RATE]
  return display.Audio(waveform_short, rate=_SAMPLING_RATE)
display_audio(pm)

Inspect the MIDI file. What kinds of instruments does it use?

print('Number of instruments:', len(pm.instruments))
instrument = pm.instruments[0]
instrument_name = pretty_midi.program_to_instrument_name(instrument.program)
print('Instrument name:', instrument_name)
Number of instruments: 1
Instrument name: Acoustic Grand Piano

Extract notes

for i, note in enumerate(instrument.notes[:10]):
  note_name = pretty_midi.note_number_to_name(note.pitch)
  duration = note.end - note.start
  print(f'{i}: pitch={note.pitch}, note_name={note_name},'
        f' duration={duration:.4f}')
0: pitch=54, note_name=F#3, duration=0.0612
1: pitch=51, note_name=D#3, duration=0.0781
2: pitch=58, note_name=A#3, duration=0.0898
3: pitch=39, note_name=D#2, duration=0.0703
4: pitch=46, note_name=A#2, duration=0.1029
5: pitch=39, note_name=D#2, duration=0.0495
6: pitch=51, note_name=D#3, duration=0.0599
7: pitch=46, note_name=A#2, duration=0.0443
8: pitch=54, note_name=F#3, duration=0.0651
9: pitch=63, note_name=D#4, duration=0.9219

You will use three variables to represent a note when training the model: pitch, step and duration. The pitch is the perceptual quality of the sound, encoded as a MIDI note number. The step is the time in seconds between the start of the previous note and the start of the current note; midi_to_notes below assigns a step of 0 to the first note. The duration is how long the note plays in seconds, computed as the difference between the note's end and start times.
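As a small illustration of these definitions, the sketch below derives (pitch, step, duration) from a few made-up (pitch, start, end) tuples in plain Python. The note values and the helper name to_pitch_step_duration are hypothetical. The logic mirrors the sorting and differencing that midi_to_notes performs on pretty_midi notes.

```python
# Hypothetical notes as (pitch, start, end) in seconds; not taken from MAESTRO.
raw = [(64, 0.75, 1.25), (60, 0.50, 1.00), (67, 1.50, 2.00)]

def to_pitch_step_duration(notes):
    """Convert (pitch, start, end) tuples to (pitch, step, duration) tuples."""
    notes = sorted(notes, key=lambda n: n[1])  # order by start time
    prev_start = notes[0][1]  # so the first note gets a step of 0
    out = []
    for pitch, start, end in notes:
        out.append((pitch, start - prev_start, end - start))
        prev_start = start
    return out

features = to_pitch_step_duration(raw)
print(features)  # [(60, 0.0, 0.5), (64, 0.25, 0.5), (67, 0.75, 0.5)]
```

Step and duration are both relative quantities. The model never sees absolute timestamps, only the gaps between onsets and the lengths of notes.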

Extract the notes from the sample MIDI file.

def midi_to_notes(midi_file: str) -> pd.DataFrame:
  pm = pretty_midi.PrettyMIDI(midi_file)
  instrument = pm.instruments[0]
  notes = collections.defaultdict(list)

  # Sort the notes by start time
  sorted_notes = sorted(instrument.notes, key=lambda note: note.start)
  prev_start = sorted_notes[0].start

  for note in sorted_notes:
    start = note.start
    end = note.end
    notes['pitch'].append(note.pitch)
    notes['start'].append(start)
    notes['end'].append(end)
    notes['step'].append(start - prev_start)
    notes['duration'].append(end - start)
    prev_start = start

  return pd.DataFrame({name: np.array(value) for name, value in notes.items()})
raw_notes = midi_to_notes(sample_file)
raw_notes.head()

Note names may be easier to interpret than numeric pitches, so you can use the function below to convert pitch values to note names. A note name shows the pitch class, any accidental, and the octave number (e.g. C#4).

get_note_names = np.vectorize(pretty_midi.note_number_to_name)
sample_note_names = get_note_names(raw_notes['pitch'])
sample_note_names[:10]
array(['D#4', 'A#3', 'D#3', 'A#2', 'F#3', 'D#2', 'D#2', 'D#3', 'F#3',
       'A#3'], dtype='<U3')
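The naming convention can also be sketched without pretty_midi. The helper below, midi_number_to_name, is a hypothetical stand-in written for illustration. It assumes the convention that produces the names printed above: sharps for accidentals, 12 pitch numbers per octave, and MIDI note 60 as middle C (C4), which makes the octave equal to pitch // 12 - 1.

```python
# Pitch classes spelled with sharps, matching names such as 'F#3' and 'A#2' above.
PITCH_CLASSES = ['C', 'C#', 'D', 'D#', 'E', 'F', 'F#', 'G', 'G#', 'A', 'A#', 'B']

def midi_number_to_name(pitch: int) -> str:
    """Hypothetical helper: map a MIDI note number (0-127) to a name like 'C4'."""
    if not 0 <= pitch <= 127:
        raise ValueError(f'MIDI pitch out of range: {pitch}')
    octave = pitch // 12 - 1  # MIDI 60 -> octave 4 (middle C)
    return PITCH_CLASSES[pitch % 12] + str(octave)

print([midi_number_to_name(p) for p in (54, 51, 58, 39, 63)])
# ['F#3', 'D#3', 'A#3', 'D#2', 'D#4']
```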

To visualize the musical piece, plot the pitch, start and end of each note across the length of the track (i.e. a piano roll). Start with the first 100 notes.

def plot_piano_roll(notes: pd.DataFrame, count: Optional[int] = None):
  if count:
    title = f'First {count} notes'
  else:
    title = f'Whole track'
    count = len(notes['pitch'])
  plt.figure(figsize=(20, 4))
  plot_pitch = np.stack([notes['pitch'], notes['pitch']], axis=0)
  plot_start_stop = np.stack([notes['start'], notes['end']], axis=0)
  plt.plot(
      plot_start_stop[:, :count], plot_pitch[:, :count], color="b", marker=".")
  plt.xlabel('Time [s]')
  plt.ylabel('Pitch')
  _ = plt.title(title)
plot_piano_roll(raw_notes, count=100)


Plot the notes for the entire track.

plot_piano_roll(raw_notes)


Check the distribution of each note variable.

def plot_distributions(notes: pd.DataFrame, drop_percentile=2.5):
  plt.figure(figsize=[15, 5])
  plt.subplot(1, 3, 1)
  sns.histplot(notes, x="pitch", bins=20)

  plt.subplot(1, 3, 2)
  max_step = np.percentile(notes['step'], 100 - drop_percentile)
  sns.histplot(notes, x="step", bins=np.linspace(0, max_step, 21))

  plt.subplot(1, 3, 3)
  max_duration = np.percentile(notes['duration'], 100 - drop_percentile)
  sns.histplot(notes, x="duration", bins=np.linspace(0, max_duration, 21))
plot_distributions(raw_notes)


Create a MIDI file

You can generate your own MIDI file from a list of notes using the function below.

def notes_to_midi(
  notes: pd.DataFrame,
  out_file: str, 
  instrument_name: str,
  velocity: int = 100,  # note loudness
) -> pretty_midi.PrettyMIDI:

  pm = pretty_midi.PrettyMIDI()
  instrument = pretty_midi.Instrument(
      program=pretty_midi.instrument_name_to_program(
          instrument_name))

  prev_start = 0
  for i, note in notes.iterrows():
    start = float(prev_start + note['step'])
    end = float(start + note['duration'])
    note = pretty_midi.Note(
        velocity=velocity,
        pitch=int(note['pitch']),
        start=start,
        end=end,
    )
    instrument.notes.append(note)
    prev_start = start

  pm.instruments.append(instrument)
  pm.write(out_file)
  return pm
example_file = 'example.midi'
example_pm = notes_to_midi(
    raw_notes, out_file=example_file, instrument_name=instrument_name)
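notes_to_midi reverses the step encoding. Each note's start is the previous start plus its step, and its end is the start plus the duration, so start times are a running sum of steps. The plain-Python sketch below shows that bookkeeping with itertools.accumulate on hypothetical (pitch, step, duration) rows. The helper name to_start_end is illustrative only.

```python
from itertools import accumulate

# Hypothetical (pitch, step, duration) rows; step and duration in seconds.
rows = [(60, 0.0, 0.5), (64, 0.25, 0.5), (67, 0.75, 0.5)]

def to_start_end(rows):
    """Rebuild (pitch, start, end) from (pitch, step, duration), like notes_to_midi."""
    # Running sum of steps: start_i = step_0 + step_1 + ... + step_i
    starts = accumulate(step for _, step, _ in rows)
    return [(pitch, start, start + duration)
            for (pitch, _, duration), start in zip(rows, starts)]

print(to_start_end(rows))  # [(60, 0.0, 0.5), (64, 0.25, 0.75), (67, 1.0, 1.5)]
```

midi_to_notes gives the first note a step of 0, and notes_to_midi starts from prev_start = 0. As a result, the rebuilt track begins at time 0 even if the first note of the original file started later.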

Play the generated MIDI file and see if there is any difference.

display_audio(example_pm)

As before, you can write files.download(example_file) to download and play this file.

Create the training dataset

Create the training dataset by extracting notes from the MIDI files. You can start with a small number of files and experiment with more later. This may take a couple of minutes.

num_files = 5
all_notes = []
for f in filenames[:num_files]:
  notes = midi_to_notes(f)
  all_notes.append(notes)

all_notes = pd.concat(all_notes)
n_notes = len(all_notes)
print('Number of notes parsed:', n_notes)
Number of notes parsed: 15315

Next, create a tf.data.Dataset from the parsed notes.

key_order = ['pitch', 'step', 'duration']
train_notes = np.stack([all_notes[key] for key in key_order], axis=1)
notes_ds = tf.data.Dataset.from_tensor_slices(train_notes)
notes_ds.element_spec
TensorSpec(shape=(3,), dtype=tf.float64, name=None)

You will train the model on batches of sequences of notes. Each example consists of a sequence of notes as the input features and the next note as the label. This way, the model is trained to predict the next note in a sequence. You can find a diagram describing this process (and more details) in the Text generation with an RNN tutorial.

You can use the Dataset.window function to create the features and labels in this format. Each window holds seq_length + 1 notes: the first seq_length notes are the input, and the extra note is the label.

def create_sequences(
    dataset: tf.data.Dataset, 
    seq_length: int,
    vocab_size = 128,
) -> tf.data.Dataset:
  """Returns TF Dataset of sequence and label examples."""
  # Take 1 extra element per window for the label
  seq_length = seq_length + 1

  windows = dataset.window(seq_length, shift=1, stride=1,
                           drop_remainder=True)

  # `flat_map` flattens the "dataset of datasets" into a dataset of tensors
  flatten = lambda x: x.batch(seq_length, drop_remainder=True)
  sequences = windows.flat_map(flatten)

  # Normalize note pitch
  def scale_pitch(x):
    x = x/[vocab_size,1.0,1.0]
    return x

  # Split the labels
  def split_labels(sequences):
    inputs = sequences[:-1]
    labels_dense = sequences[-1]
    labels = {key:labels_dense[i] for i,key in enumerate(key_order)}

    return scale_pitch(inputs), labels

  return sequences.map(split_labels, num_parallel_calls=tf.data.AUTOTUNE)
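Conceptually, create_sequences slides a window of seq_length + 1 elements over the notes, shifting by one note each time. It then splits each window into inputs (all but the last element) and a label (the last element). The pure-Python sketch below mimics that split on a toy list of integers, which stand in for note rows. It does not use tf.data, and the name make_windows is hypothetical.

```python
def make_windows(items, seq_length):
    """Split a sequence into (inputs, label) pairs, like window(seq_length + 1, shift=1)."""
    size = seq_length + 1  # one extra element for the label
    examples = []
    for i in range(len(items) - size + 1):  # drop_remainder: keep only full windows
        window = items[i:i + size]
        examples.append((window[:-1], window[-1]))
    return examples

for inputs, label in make_windows([10, 11, 12, 13, 14], seq_length=3):
    print(inputs, '->', label)
# [10, 11, 12] -> 13
# [11, 12, 13] -> 14
```

A list of n items yields n - seq_length examples, because consecutive windows overlap in all but one element.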

Set the sequence length for each example. Experiment with different lengths (e.g. 50, 100, 150) to see which one works best for the data, or use hyperparameter tuning. The size of the vocabulary (vocab_size) is set to 128, representing all the pitches (0 to 127) supported by pretty_midi.

seq_length = 25
vocab_size = 128
seq_ds = create_sequences(notes_ds, seq_length, vocab_size)
seq_ds.element_spec
(TensorSpec(shape=(25, 3), dtype=tf.float64, name=None),
 {'pitch': TensorSpec(shape=(), dtype=tf.float64, name=None),
  'step': TensorSpec(shape=(), dtype=tf.float64, name=None),
  'duration': TensorSpec(shape=(), dtype=tf.float64, name=None)})

Each input sequence has shape (25, 3). The model takes 25 notes as input, each described by pitch, step and duration, and learns to predict the following note as output.

for seq, target in seq_ds.take(1):
  print('sequence shape:', seq.shape)
  print('sequence elements (first 10):', seq[0: 10])
  print()
  print('target:', target)
sequence shape: (25, 3)
sequence elements (first 10): tf.Tensor(
[[0.625      0.         0.23828125]
 [0.6015625  0.04036458 0.2421875 ]
 [0.5859375  0.22395833 0.06510417]
 [0.5625     0.09505208 0.0703125 ]
 [0.53125    0.11067708 0.1640625 ]
 [0.5859375  0.05598958 0.12760417]
 [0.5625     0.09244792 0.08333333]
 [0.625      0.08333333 0.17317708]
 [0.6015625  0.01822917 0.15104167]
 [0.5859375  0.109375   0.04036458]], shape=(10, 3), dtype=float64)

target: {'pitch': <tf.Tensor: shape=(), dtype=float64, numpy=68.0>, 'step': <tf.Tensor: shape=(), dtype=float64, numpy=0.11588541666666652>, 'duration': <tf.Tensor: shape=(), dtype=float64, numpy=0.15494791666666696>}
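In the output above, only the first column of the inputs has been divided by vocab_size. Step and duration stay in seconds, and the label pitch (68.0) stays a raw MIDI number because scale_pitch is applied only to the inputs. The sketch below restates that per-column scaling in plain Python and recovers the first two pitches of the printed sequence. The helpers scale_row and unscale_pitch are hypothetical.

```python
VOCAB_SIZE = 128  # number of MIDI pitches, as in the tutorial

def scale_row(row, vocab_size=VOCAB_SIZE):
    """Scale only the pitch of a (pitch, step, duration) row, like scale_pitch."""
    pitch, step, duration = row
    return (pitch / vocab_size, step, duration)

def unscale_pitch(scaled, vocab_size=VOCAB_SIZE):
    """Map a scaled pitch back to the nearest MIDI note number."""
    return round(scaled * vocab_size)

print(scale_row((80, 0.0, 0.23828125)))  # (0.625, 0.0, 0.23828125)
print(unscale_pitch(0.6015625))          # 77
```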

Batch the examples, and configure the dataset for performance.

batch_size = 64
buffer_size = n_notes - seq_length  # the number of items in the dataset
train_ds = (seq_ds
            .shuffle(buffer_size)
            .batch(batch_size, drop_remainder=True)
            .cache()
            .prefetch(tf.data.AUTOTUNE))
train_ds.element_spec
(TensorSpec(shape=(64, 25, 3), dtype=tf.float64, name=None),
 {'pitch': TensorSpec(shape=(64,), dtype=tf.float64, name=None),
  'step': TensorSpec(shape=(64,), dtype=tf.float64, name=None),
  'duration': TensorSpec(shape=(64,), dtype=tf.float64, name=None)})
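The buffer_size above equals the number of windows. Sliding a window of seq_length + 1 notes with shift 1 over n_notes notes yields n_notes - seq_length examples. With the numbers from this run, the same arithmetic explains the 238 steps per epoch that model.evaluate and model.fit report later, because drop_remainder=True discards the final partial batch.

```python
n_notes = 15315   # from "Number of notes parsed" above
seq_length = 25
batch_size = 64

num_windows = n_notes - seq_length        # windows of seq_length + 1 notes, shift 1
num_batches = num_windows // batch_size   # drop_remainder=True drops the partial batch
leftover = num_windows % batch_size       # examples discarded each epoch

print(num_windows, num_batches, leftover)  # 15290 238 58
```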

Create and train the model

The model will have three outputs, one for each note variable. For step and duration, you will use a custom loss function based on mean squared error that encourages the model to output non-negative values.

def mse_with_positive_pressure(y_true: tf.Tensor, y_pred: tf.Tensor):
  mse = (y_true - y_pred) ** 2
  positive_pressure = 10 * tf.maximum(-y_pred, 0.0)
  return tf.reduce_mean(mse + positive_pressure)
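To see what the penalty does, here is a scalar, pure-Python restatement of mse_with_positive_pressure for a single prediction. The tutorial's version works on tensors and averages over the batch, and the name mse_with_positive_pressure_scalar is illustrative only. In the example, both predictions are 0.25 away from the target, but the negative one also pays 10 times its magnitude.

```python
def mse_with_positive_pressure_scalar(y_true: float, y_pred: float) -> float:
    """Scalar version of the tutorial's loss for one example."""
    mse = (y_true - y_pred) ** 2
    positive_pressure = 10 * max(-y_pred, 0.0)  # nonzero only when y_pred < 0
    return mse + positive_pressure

print(mse_with_positive_pressure_scalar(0.5, 0.75))   # 0.0625
print(mse_with_positive_pressure_scalar(0.5, 0.25))   # 0.0625
print(mse_with_positive_pressure_scalar(0.0, -0.25))  # 0.0625 + 2.5 = 2.5625
```

Negative step or duration values have no musical meaning, so the extra linear term pushes the model toward non-negative outputs. For non-negative predictions, the loss is unchanged.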
input_shape = (seq_length, 3)
learning_rate = 0.005

inputs = tf.keras.Input(input_shape)
x = tf.keras.layers.LSTM(128)(inputs)

outputs = {
  'pitch': tf.keras.layers.Dense(128, name='pitch')(x),
  'step': tf.keras.layers.Dense(1, name='step')(x),
  'duration': tf.keras.layers.Dense(1, name='duration')(x),
}

model = tf.keras.Model(inputs, outputs)

loss = {
      'pitch': tf.keras.losses.SparseCategoricalCrossentropy(
          from_logits=True),
      'step': mse_with_positive_pressure,
      'duration': mse_with_positive_pressure,
}

optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)

model.compile(loss=loss, optimizer=optimizer)

model.summary()
Model: "model"
__________________________________________________________________________________________________
 Layer (type)                Output Shape                 Param #   Connected to                  
==================================================================================================
 input_1 (InputLayer)        [(None, 25, 3)]              0         []                            
                                                                                                  
 lstm (LSTM)                 (None, 128)                  67584     ['input_1[0][0]']             
                                                                                                  
 duration (Dense)            (None, 1)                    129       ['lstm[0][0]']                
                                                                                                  
 pitch (Dense)               (None, 128)                  16512     ['lstm[0][0]']                
                                                                                                  
 step (Dense)                (None, 1)                    129       ['lstm[0][0]']                
                                                                                                  
==================================================================================================
Total params: 84354 (329.51 KB)
Trainable params: 84354 (329.51 KB)
Non-trainable params: 0 (0.00 Byte)
__________________________________________________________________________________________________

Calling model.evaluate shows that the pitch loss is significantly greater than the step and duration losses. Note that loss is the total loss, computed by summing the individual losses, and is currently dominated by the pitch loss.

losses = model.evaluate(train_ds, return_dict=True)
losses
238/238 [==============================] - 4s 3ms/step - loss: 6.1272 - duration_loss: 0.6039 - pitch_loss: 4.8544 - step_loss: 0.6689
{'loss': 6.127169609069824,
 'duration_loss': 0.603919267654419,
 'pitch_loss': 4.854353427886963,
 'step_loss': 0.6688962578773499}

One way to balance this is to use the loss_weights argument to compile:

model.compile(
    loss=loss,
    loss_weights={
        'pitch': 0.05,
        'step': 1.0,
        'duration':1.0,
    },
    optimizer=optimizer,
)

The loss then becomes the weighted sum of the individual losses.

model.evaluate(train_ds, return_dict=True)
238/238 [==============================] - 2s 3ms/step - loss: 1.5155 - duration_loss: 0.6039 - pitch_loss: 4.8544 - step_loss: 0.6689
{'loss': 1.515533447265625,
 'duration_loss': 0.603919267654419,
 'pitch_loss': 4.854353427886963,
 'step_loss': 0.6688962578773499}
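You can check both totals by hand from the per-output losses printed above. The unweighted sum matches the first evaluation. With the pitch loss scaled by 0.05 and the other two by 1.0, the weighted sum matches the second. Small differences in the last digits come from float32 arithmetic inside TensorFlow.

```python
losses = {'pitch': 4.854353427886963,
          'step': 0.6688962578773499,
          'duration': 0.603919267654419}
loss_weights = {'pitch': 0.05, 'step': 1.0, 'duration': 1.0}

unweighted = sum(losses.values())
weighted = sum(loss_weights[k] * losses[k] for k in losses)

print(round(unweighted, 4))  # 6.1272
print(round(weighted, 4))    # 1.5155
```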

Train the model.

callbacks = [
    tf.keras.callbacks.ModelCheckpoint(
        filepath='./training_checkpoints/ckpt_{epoch}',
        save_weights_only=True),
    tf.keras.callbacks.EarlyStopping(
        monitor='loss',
        patience=5,
        verbose=1,
        restore_best_weights=True),
]
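With monitor='loss', patience=5 and restore_best_weights=True, EarlyStopping ends training once the training loss has failed to improve for 5 consecutive epochs, then restores the weights from the best epoch. The function below is a simplified, hypothetical re-implementation of that rule for illustration only. It is not the Keras code, and it ignores options such as min_delta and baseline. It is applied to the total loss from the first 15 epochs of the training log.

```python
def early_stopping_epoch(losses, patience=5):
    """Return (stop_epoch, best_epoch), 1-based; stop_epoch is None if never triggered."""
    best = float('inf')
    best_epoch = None
    wait = 0
    for epoch, loss in enumerate(losses, start=1):
        if loss < best:  # any decrease counts as an improvement here
            best, best_epoch, wait = loss, epoch, 0
        else:
            wait += 1
            if wait >= patience:  # `patience` epochs in a row without improvement
                return epoch, best_epoch
    return None, best_epoch

# Total loss for epochs 1-15, copied from the training log.
history = [0.4896, 0.4510, 0.4494, 0.4464, 0.4413, 0.4357, 0.4338, 0.4249,
           0.4201, 0.4151, 0.4153, 0.4065, 0.4220, 0.4051, 0.4007]
print(early_stopping_epoch(history))                                   # (None, 15)
print(early_stopping_epoch([0.5, 0.4, 0.41, 0.42, 0.43], patience=3))  # (5, 2)
```

The occasional uptick in the log, such as epoch 13, resets nothing by itself. Training would stop only after five non-improving epochs in a row.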
%%time
epochs = 50

history = model.fit(
    train_ds,
    epochs=epochs,
    callbacks=callbacks,
)
Epoch 1/50
238/238 [==============================] - 4s 5ms/step - loss: 0.4896 - duration_loss: 0.2237 - pitch_loss: 4.2569 - step_loss: 0.0530
Epoch 2/50
238/238 [==============================] - 1s 4ms/step - loss: 0.4510 - duration_loss: 0.2010 - pitch_loss: 4.0747 - step_loss: 0.0463
Epoch 3/50
238/238 [==============================] - 1s 4ms/step - loss: 0.4494 - duration_loss: 0.2006 - pitch_loss: 4.0595 - step_loss: 0.0457
Epoch 4/50
238/238 [==============================] - 1s 4ms/step - loss: 0.4464 - duration_loss: 0.1989 - pitch_loss: 4.0458 - step_loss: 0.0451
Epoch 5/50
238/238 [==============================] - 1s 4ms/step - loss: 0.4413 - duration_loss: 0.1948 - pitch_loss: 4.0342 - step_loss: 0.0448
Epoch 6/50
238/238 [==============================] - 1s 4ms/step - loss: 0.4357 - duration_loss: 0.1896 - pitch_loss: 4.0273 - step_loss: 0.0447
Epoch 7/50
238/238 [==============================] - 1s 4ms/step - loss: 0.4338 - duration_loss: 0.1884 - pitch_loss: 4.0228 - step_loss: 0.0442
Epoch 8/50
238/238 [==============================] - 1s 4ms/step - loss: 0.4249 - duration_loss: 0.1821 - pitch_loss: 3.9918 - step_loss: 0.0432
Epoch 9/50
238/238 [==============================] - 1s 4ms/step - loss: 0.4201 - duration_loss: 0.1790 - pitch_loss: 3.9779 - step_loss: 0.0422
Epoch 10/50
238/238 [==============================] - 1s 4ms/step - loss: 0.4151 - duration_loss: 0.1753 - pitch_loss: 3.9627 - step_loss: 0.0417
Epoch 11/50
238/238 [==============================] - 1s 4ms/step - loss: 0.4153 - duration_loss: 0.1765 - pitch_loss: 3.9280 - step_loss: 0.0424
Epoch 12/50
238/238 [==============================] - 1s 4ms/step - loss: 0.4065 - duration_loss: 0.1705 - pitch_loss: 3.9253 - step_loss: 0.0397
Epoch 13/50
238/238 [==============================] - 1s 4ms/step - loss: 0.4220 - duration_loss: 0.1824 - pitch_loss: 3.9317 - step_loss: 0.0430
Epoch 14/50
238/238 [==============================] - 1s 4ms/step - loss: 0.4051 - duration_loss: 0.1689 - pitch_loss: 3.9061 - step_loss: 0.0409
Epoch 15/50
238/238 [==============================] - 1s 4ms/step - loss: 0.4007 - duration_loss: 0.1681 - pitch_loss: 3.8994 - step_loss: 0.0376
Epoch 16/50
238/238 [==============================] - 1s 4ms/step - loss: 0.3938 - duration_loss: 0.1608 - pitch_loss: 3.8912 - step_loss: 0.0385
Epoch 17/50
238/238 [==============================] - 1s 4ms/step - loss: 0.3726 - duration_loss: 0.1431 - pitch_loss: 3.8987 - step_loss: 0.0346
Epoch 18/50
238/238 [==============================] - 1s 4ms/step - loss: 0.3663 - duration_loss: 0.1380 - pitch_loss: 3.8902 - step_loss: 0.0337
Epoch 19/50
238/238 [==============================] - 1s 4ms/step - loss: 0.3573 - duration_loss: 0.1306 - pitch_loss: 3.8760 - step_loss: 0.0330
Epoch 20/50
238/238 [==============================] - 1s 4ms/step - loss: 0.3678 - duration_loss: 0.1400 - pitch_loss: 3.8820 - step_loss: 0.0337
Epoch 21/50
238/238 [==============================] - 1s 4ms/step - loss: 0.3552 - duration_loss: 0.1311 - pitch_loss: 3.8700 - step_loss: 0.0307
Epoch 22/50
238/238 [==============================] - 1s 4ms/step - loss: 0.3552 - duration_loss: 0.1313 - pitch_loss: 3.8636 - step_loss: 0.0307
Epoch 23/50
238/238 [==============================] - 1s 4ms/step - loss: 0.3446 - duration_loss: 0.1239 - pitch_loss: 3.8685 - step_loss: 0.0272
Epoch 24/50
238/238 [==============================] - 1s 4ms/step - loss: 0.3460 - duration_loss: 0.1255 - pitch_loss: 3.8604 - step_loss: 0.0275
Epoch 25/50
238/238 [==============================] - 1s 4ms/step - loss: 0.3446 - duration_loss: 0.1242 - pitch_loss: 3.8534 - step_loss: 0.0277
Epoch 26/50
238/238 [==============================] - 1s 4ms/step - loss: 0.3326 - duration_loss: 0.1156 - pitch_loss: 3.8457 - step_loss: 0.0247
Epoch 27/50
238/238 [==============================] - 1s 4ms/step - loss: 0.3315 - duration_loss: 0.1149 - pitch_loss: 3.8485 - step_loss: 0.0241
Epoch 28/50
238/238 [==============================] - 1s 4ms/step - loss: 0.3263 - duration_loss: 0.1113 - pitch_loss: 3.8590 - step_loss: 0.0220
Epoch 29/50
238/238 [==============================] - 1s 4ms/step - loss: 0.3319 - duration_loss: 0.1151 - pitch_loss: 3.8535 - step_loss: 0.0241
Epoch 30/50
238/238 [==============================] - 1s 4ms/step - loss: 0.3310 - duration_loss: 0.1139 - pitch_loss: 3.8613 - step_loss: 0.0240
Epoch 31/50
238/238 [==============================] - 1s 4ms/step - loss: 0.3178 - duration_loss: 0.1051 - pitch_loss: 3.8405 - step_loss: 0.0206
Epoch 32/50
238/238 [==============================] - 1s 4ms/step - loss: 0.3314 - duration_loss: 0.1151 - pitch_loss: 3.8618 - step_loss: 0.0232
Epoch 33/50
238/238 [==============================] - 1s 4ms/step - loss: 0.3216 - duration_loss: 0.1074 - pitch_loss: 3.8391 - step_loss: 0.0223
Epoch 34/50
238/238 [==============================] - 1s 4ms/step - loss: 0.3124 - duration_loss: 0.1006 - pitch_loss: 3.8407 - step_loss: 0.0198
Epoch 35/50
238/238 [==============================] - 1s 4ms/step - loss: 0.3081 - duration_loss: 0.0962 - pitch_loss: 3.8260 - step_loss: 0.0206
Epoch 36/50
238/238 [==============================] - 1s 4ms/step - loss: 0.2982 - duration_loss: 0.0891 - pitch_loss: 3.8339 - step_loss: 0.0174
Epoch 37/50
238/238 [==============================] - 1s 4ms/step - loss: 0.2957 - duration_loss: 0.0882 - pitch_loss: 3.8265 - step_loss: 0.0162
Epoch 38/50
238/238 [==============================] - 1s 4ms/step - loss: 0.2986 - duration_loss: 0.0905 - pitch_loss: 3.8198 - step_loss: 0.0171
Epoch 39/50
238/238 [==============================] - 1s 4ms/step - loss: 0.2936 - duration_loss: 0.0862 - pitch_loss: 3.8019 - step_loss: 0.0173
Epoch 40/50
238/238 [==============================] - 1s 4ms/step - loss: 0.2920 - duration_loss: 0.0842 - pitch_loss: 3.8052 - step_loss: 0.0175
Epoch 41/50
238/238 [==============================] - 1s 4ms/step - loss: 0.2907 - duration_loss: 0.0846 - pitch_loss: 3.8012 - step_loss: 0.0161
Epoch 42/50
238/238 [==============================] - 1s 4ms/step - loss: 0.2953 - duration_loss: 0.0895 - pitch_loss: 3.7930 - step_loss: 0.0162
Epoch 43/50
238/238 [==============================] - 1s 4ms/step - loss: 0.2892 - duration_loss: 0.0846 - pitch_loss: 3.7831 - step_loss: 0.0154
Epoch 44/50
238/238 [==============================] - 1s 4ms/step - loss: 0.2830 - duration_loss: 0.0783 - pitch_loss: 3.7743 - step_loss: 0.0161
Epoch 45/50
238/238 [==============================] - 1s 4ms/step - loss: 0.2783 - duration_loss: 0.0749 - pitch_loss: 3.7638 - step_loss: 0.0152
Epoch 46/50
238/238 [==============================] - 1s 4ms/step - loss: 0.2729 - duration_loss: 0.0711 - pitch_loss: 3.7510 - step_loss: 0.0142
Epoch 47/50
238/238 [==============================] - 1s 4ms/step - loss: 0.2750 - duration_loss: 0.0744 - pitch_loss: 3.7449 - step_loss: 0.0134
Epoch 48/50
238/238 [==============================] - 1s 4ms/step - loss: 0.2710 - duration_loss: 0.0717 - pitch_loss: 3.7243 - step_loss: 0.0130
Epoch 49/50
238/238 [==============================] - 1s 4ms/step - loss: 0.2741 - duration_loss: 0.0731 - pitch_loss: 3.7208 - step_loss: 0.0150
Epoch 50/50
238/238 [==============================] - 1s 4ms/step - loss: 0.2775 - duration_loss: 0.0767 - pitch_loss: 3.7284 - step_loss: 0.0144
CPU times: user 1min 13s, sys: 9.49 s, total: 1min 23s
Wall time: 54.1 s
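With monitor='loss' and patience=5, EarlyStopping stops training only after five consecutive epochs without a new best loss. The sketch below is a simplified, plain-Python version of that patience rule (not the Keras source: it assumes min_delta=0 and no baseline), applied to the per-epoch total losses from the log above. The longest run without improvement is two epochs, so the rule never fires and all 50 epochs run; the lowest loss occurs at epoch 48.

```python
def early_stop_epoch(losses, patience=5):
    """Return the 1-based epoch at which a patience rule would stop, or None.

    Simplified sketch: an epoch counts as an improvement only if its loss is
    strictly lower than the best loss seen so far.
    """
    best = float('inf')
    wait = 0  # epochs since the last improvement
    for epoch, loss in enumerate(losses, start=1):
        if loss < best:
            best, wait = loss, 0
        else:
            wait += 1
            if wait >= patience:
                return epoch
    return None

# Total loss per epoch, copied from the training log above
losses = [
    0.4896, 0.4510, 0.4494, 0.4464, 0.4413, 0.4357, 0.4338, 0.4249, 0.4201, 0.4151,
    0.4153, 0.4065, 0.4220, 0.4051, 0.4007, 0.3938, 0.3726, 0.3663, 0.3573, 0.3678,
    0.3552, 0.3552, 0.3446, 0.3460, 0.3446, 0.3326, 0.3315, 0.3263, 0.3319, 0.3310,
    0.3178, 0.3314, 0.3216, 0.3124, 0.3081, 0.2982, 0.2957, 0.2986, 0.2936, 0.2920,
    0.2907, 0.2953, 0.2892, 0.2830, 0.2783, 0.2729, 0.2750, 0.2710, 0.2741, 0.2775,
]

print(early_stop_epoch(losses))       # None: patience is never exhausted
print(losses.index(min(losses)) + 1)  # 48: epoch with the lowest loss
```

For comparison, the same sketch with patience=2 returns 25, because epochs 24 and 25 both fail to beat the 0.3446 reached at epoch 23.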
plt.plot(history.epoch, history.history['loss'], label='total loss')
plt.legend()  # the label above is only shown once a legend is drawn
plt.show()


Generate notes

To use the model to generate notes, you will first need to provide a starting sequence of notes. The function below generates one note from a sequence of notes.

For note pitch, it draws a sample from the softmax distribution of notes produced by the model, and does not simply pick the note with the highest probability. Always picking the note with the highest probability would lead to repetitive sequences of notes being generated.

The temperature parameter can be used to control the randomness of notes generated. You can find more details on temperature in Text generation with an RNN.

def predict_next_note(
    notes: np.ndarray, 
    model: tf.keras.Model, 
    temperature: float = 1.0) -> tuple[int, float, float]:
  """Generates a note as a tuple of (pitch, step, duration), using a trained sequence model."""

  assert temperature > 0

  # Add batch dimension
  inputs = tf.expand_dims(notes, 0)

  # verbose=0 suppresses the progress bar that would otherwise print on every call
  predictions = model.predict(inputs, verbose=0)
  pitch_logits = predictions['pitch']
  step = predictions['step']
  duration = predictions['duration']

  pitch_logits /= temperature
  pitch = tf.random.categorical(pitch_logits, num_samples=1)
  pitch = tf.squeeze(pitch, axis=-1)
  duration = tf.squeeze(duration, axis=-1)
  step = tf.squeeze(step, axis=-1)

  # `step` and `duration` values should be non-negative
  step = tf.maximum(0, step)
  duration = tf.maximum(0, duration)

  return int(pitch), float(step), float(duration)
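To see what dividing the logits by temperature does, the NumPy sketch below computes the softmax of a few made-up logits (hypothetical values, unrelated to the trained model) at several temperatures. Higher temperatures flatten the distribution, so sampling picks less likely pitches more often, while the most likely pitch stays the same. Always taking the highest-probability note corresponds to the limit as temperature approaches zero.

```python
import numpy as np

def softmax_with_temperature(logits: np.ndarray, temperature: float) -> np.ndarray:
    """Softmax of logits / temperature, stabilized by subtracting the max."""
    scaled = logits / temperature
    scaled = scaled - scaled.max()
    exp = np.exp(scaled)
    return exp / exp.sum()

def entropy(probs: np.ndarray) -> float:
    """Shannon entropy in nats; larger means a flatter, more random distribution."""
    return float(-(probs * np.log(probs)).sum())

# Hypothetical logits for four candidate pitches
logits = np.array([2.0, 1.0, 0.5, -1.0])

rng = np.random.default_rng(seed=0)
for t in (0.5, 1.0, 2.0):
    probs = softmax_with_temperature(logits, t)
    sample = rng.choice(len(logits), p=probs)  # plays the role of tf.random.categorical
    print(f"T={t}: probs={np.round(probs, 3)}, entropy={entropy(probs):.3f}, sample={sample}")
```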

Now generate some notes. You can play around with temperature and the starting sequence in input_notes and see what happens.

temperature = 2.0
num_predictions = 120

sample_notes = np.stack([raw_notes[key] for key in key_order], axis=1)

# The initial sequence of notes; pitch is normalized the same way as the
# training sequences
input_notes = (
    sample_notes[:seq_length] / np.array([vocab_size, 1, 1]))

generated_notes = []
prev_start = 0
for _ in range(num_predictions):
  pitch, step, duration = predict_next_note(input_notes, model, temperature)
  start = prev_start + step
  end = start + duration
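  input_note = (pitch, step, duration)
  generated_notes.append((*input_note, start, end))
  # Normalize pitch before feeding the note back in, matching the scaling
  # applied to the seed sequence above (and to the training inputs)
  next_input = np.array([pitch / vocab_size, step, duration])
  input_notes = np.delete(input_notes, 0, axis=0)
  input_notes = np.append(input_notes, np.expand_dims(next_input, 0), axis=0)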
  prev_start = start

generated_notes = pd.DataFrame(
    generated_notes, columns=(*key_order, 'start', 'end'))
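The generation loop stores each note relative to the previous one: step is the time since the previous note started, so absolute start times are a running sum of steps, and each end time is start plus duration. The loop also keeps the model input at a fixed length by dropping the oldest note and appending the newest. The NumPy sketch below isolates those two pieces of bookkeeping; the function names and sample values are illustrative, and np.vstack stands in for the np.delete/np.append pair used above.

```python
import numpy as np

def accumulate_times(steps, durations):
    """Convert per-note (step, duration) pairs into absolute (start, end) times."""
    starts = np.cumsum(steps)              # start_i = start_{i-1} + step_i, with start_0 = 0
    ends = starts + np.asarray(durations)  # end_i = start_i + duration_i
    return starts, ends

def slide_window(window: np.ndarray, new_row: np.ndarray) -> np.ndarray:
    """Drop the oldest row and append new_row, keeping the window length fixed."""
    return np.vstack([window[1:], new_row])

starts, ends = accumulate_times([0.5, 0.25, 0.0], [0.4, 0.3, 0.2])
print(starts.tolist())  # [0.5, 0.75, 0.75]; a step of 0.0 starts a note with the previous one

window = np.arange(9).reshape(3, 3)
print(slide_window(window, np.array([9, 9, 9])).tolist())  # [[3, 4, 5], [6, 7, 8], [9, 9, 9]]
```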
generated_notes.head(10)
out_file = 'output.mid'
out_pm = notes_to_midi(
    generated_notes, out_file=out_file, instrument_name=instrument_name)
display_audio(out_pm)

In Colab, you can also download the generated MIDI file by adding the two lines below:

from google.colab import files
files.download(out_file)

Visualize the generated notes.

plot_piano_roll(generated_notes)


Check the distributions of pitch, step and duration.

plot_distributions(generated_notes)


In the above plots, you will notice that the distributions of the generated note variables have changed. Since there is a feedback loop between the model's outputs and inputs, the model tends to generate similar, repetitive sequences of outputs. This is particularly noticeable for step and duration: unlike pitch, they are predicted directly rather than sampled, and they are trained with the MSE loss, which pushes predictions toward average values. For pitch, you can increase the randomness by increasing the temperature in predict_next_note.
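The pull toward average values follows from the MSE loss itself: for a fixed input, the constant prediction that minimizes mean squared error is the mean of the possible targets. The sketch below checks this numerically with a grid search; the target values are made up and stand in for the different step values that might follow the same context.

```python
import numpy as np

# Hypothetical step values (in seconds) that might follow the same context
targets = np.array([0.0, 0.1, 0.1, 0.5, 1.2])

# Evaluate the MSE of many constant predictions and keep the best one
candidates = np.linspace(0.0, 1.5, 1501)  # grid spacing of 0.001
mse = ((targets[None, :] - candidates[:, None]) ** 2).mean(axis=1)
best = candidates[np.argmin(mse)]

print(f"mean of targets: {targets.mean():.3f}")  # 0.380
print(f"MSE-optimal constant: {best:.3f}")        # 0.380
```

A single prediction of 0.38 matches none of the individual targets, which is why MSE-trained outputs tend to vary less than the data they were trained on. Pitch is less affected because predict_next_note samples it from the predicted distribution instead of returning one best estimate.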

Next steps

This tutorial demonstrated the mechanics of using an RNN to generate sequences of notes from a dataset of MIDI files. To learn more, you can visit the closely related Text generation with an RNN tutorial, which contains additional diagrams and explanations.

One alternative to using RNNs for music generation is to use GANs. Rather than generating a sequence one step at a time, a GAN-based approach can generate an entire sequence in parallel. The Magenta team has done impressive work on this approach with GANSynth. You can also find many wonderful music and art projects and open-source code on the Magenta project website.