Add get_metrics.py

This commit is contained in:
Kacper Donat 2020-05-26 22:10:33 +02:00
parent 7c2485ce9b
commit dad63aa496
4 changed files with 49 additions and 2 deletions

3
.gitignore vendored
View File

@ -145,4 +145,5 @@ cython_debug/
/dataset/processed/
!/dataset/processed/.gitkeep
/dataset/manifest.csv
/eval/
/eval/
metrics.csv

View File

@ -3,3 +3,13 @@ image: Dockerfile
raport.pdf: raport.md
pandoc -f markdown-implicit_figures -V geometry:margin=1in $^ -o $@
manifest: dataset/midi/*.mid dataset/wav/*.wav
python prepare_dataset.py
dataset: manifest
python /opt/conda/envs/magenta/lib/python3.7/site-packages/magenta/models/onsets_frames_transcription/onsets_frames_transcription_create_tfrecords.py --csv="./dataset/manifest.csv" --output_directory="./dataset/processed" --wav_dir="./dataset/wav" --midi_dir="./dataset/midi" --expected_splits="test"
test: dataset
onsets_frames_transcription_infer --model_dir="${MODEL_DIR}" --output_dir="./eval/" --examples_path=./dataset/processed/test.tfrecord* --hparams="use_cudnn=false" --preprocess_examples=True

36
get_metrics.py Normal file
View File

@ -0,0 +1,36 @@
import tensorflow as tf
import pandas as pd
from glob import glob, fnmatch
from argparse import ArgumentParser
parser = ArgumentParser(description="Extract metrics from TF logs.")
parser.add_argument('--log', '-d', dest="pattern", default="./eval/event*", help="Glob pattern for TF event logs.")
parser.add_argument('--output', '-o', dest="output", default="./metrics.csv", help="Name of output csv file.")
parser.add_argument('metric', default=['*'], nargs='*', type=str, help="Metrics to extract");
args = parser.parse_args()
def extract_metrics(path):
runlog = pd.DataFrame(columns=['metric', 'value'])
for e in tf.train.summary_iterator(path):
for v in e.summary.value:
if any([fnmatch.fnmatch(v.tag, pattern) for pattern in args.metric]):
runlog = runlog.append({
'metric': v.tag,
'value':v.simple_value
}, ignore_index=True)
return runlog
metrics = pd.DataFrame()
for path in glob(args.pattern):
try:
metrics = metrics.append(extract_metrics(path))
except:
print(f'Event file corrupted: {path}, skipping...')
print(metrics)
metrics.to_csv(args.output)

View File

@ -20,7 +20,7 @@ if not args.no_convert:
wav = midi.replace('.mid', '.wav').replace('/midi/', '/wav/')
processed = processed + 1
if 'n' in args and processed > args.n:
if args.n is not None and processed > args.n:
break
if path.isfile(wav):