Add get_metrics.py
This commit is contained in:
parent
7c2485ce9b
commit
dad63aa496
3
.gitignore
vendored
3
.gitignore
vendored
@ -145,4 +145,5 @@ cython_debug/
|
|||||||
/dataset/processed/
|
/dataset/processed/
|
||||||
!/dataset/processed/.gitkeep
|
!/dataset/processed/.gitkeep
|
||||||
/dataset/manifest.csv
|
/dataset/manifest.csv
|
||||||
/eval/
|
/eval/
|
||||||
|
metrics.csv
|
10
Makefile
10
Makefile
@ -3,3 +3,13 @@ image: Dockerfile
|
|||||||
|
|
||||||
raport.pdf: raport.md
|
raport.pdf: raport.md
|
||||||
pandoc -f markdown-implicit_figures -V geometry:margin=1in $^ -o $@
|
pandoc -f markdown-implicit_figures -V geometry:margin=1in $^ -o $@
|
||||||
|
|
||||||
|
manifest: dataset/midi/*.mid dataset/wav/*.wav
|
||||||
|
python prepare_dataset.py
|
||||||
|
|
||||||
|
dataset: manifest
|
||||||
|
python /opt/conda/envs/magenta/lib/python3.7/site-packages/magenta/models/onsets_frames_transcription/onsets_frames_transcription_create_tfrecords.py --csv="./dataset/manifest.csv" --output_directory="./dataset/processed" --wav_dir="./dataset/wav" --midi_dir="./dataset/midi" --expected_splits="test"
|
||||||
|
|
||||||
|
test: dataset
|
||||||
|
onsets_frames_transcription_infer --model_dir="${MODEL_DIR}" --output_dir="./eval/" --examples_path=./dataset/processed/test.tfrecord* --hparams="use_cudnn=false" --preprocess_examples=True
|
||||||
|
|
36
get_metrics.py
Normal file
36
get_metrics.py
Normal file
@ -0,0 +1,36 @@
|
|||||||
|
import tensorflow as tf
|
||||||
|
import pandas as pd
|
||||||
|
from glob import glob, fnmatch
|
||||||
|
from argparse import ArgumentParser
|
||||||
|
|
||||||
|
parser = ArgumentParser(description="Extract metrics from TF logs.")
|
||||||
|
parser.add_argument('--log', '-d', dest="pattern", default="./eval/event*", help="Glob pattern for TF event logs.")
|
||||||
|
parser.add_argument('--output', '-o', dest="output", default="./metrics.csv", help="Name of output csv file.")
|
||||||
|
parser.add_argument('metric', default=['*'], nargs='*', type=str, help="Metrics to extract");
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
def extract_metrics(path):
|
||||||
|
runlog = pd.DataFrame(columns=['metric', 'value'])
|
||||||
|
|
||||||
|
for e in tf.train.summary_iterator(path):
|
||||||
|
for v in e.summary.value:
|
||||||
|
if any([fnmatch.fnmatch(v.tag, pattern) for pattern in args.metric]):
|
||||||
|
runlog = runlog.append({
|
||||||
|
'metric': v.tag,
|
||||||
|
'value':v.simple_value
|
||||||
|
}, ignore_index=True)
|
||||||
|
|
||||||
|
return runlog
|
||||||
|
|
||||||
|
metrics = pd.DataFrame()
|
||||||
|
for path in glob(args.pattern):
|
||||||
|
try:
|
||||||
|
metrics = metrics.append(extract_metrics(path))
|
||||||
|
except:
|
||||||
|
print(f'Event file corrupted: {path}, skipping...')
|
||||||
|
|
||||||
|
print(metrics)
|
||||||
|
|
||||||
|
metrics.to_csv(args.output)
|
@ -20,7 +20,7 @@ if not args.no_convert:
|
|||||||
wav = midi.replace('.mid', '.wav').replace('/midi/', '/wav/')
|
wav = midi.replace('.mid', '.wav').replace('/midi/', '/wav/')
|
||||||
|
|
||||||
processed = processed + 1
|
processed = processed + 1
|
||||||
if 'n' in args and processed > args.n:
|
if args.n is not None and processed > args.n:
|
||||||
break
|
break
|
||||||
|
|
||||||
if path.isfile(wav):
|
if path.isfile(wav):
|
||||||
|
Loading…
Reference in New Issue
Block a user