Add get_metrics.py
This commit is contained in:
parent
7c2485ce9b
commit
dad63aa496
3
.gitignore
vendored
3
.gitignore
vendored
@ -145,4 +145,5 @@ cython_debug/
|
||||
/dataset/processed/
|
||||
!/dataset/processed/.gitkeep
|
||||
/dataset/manifest.csv
|
||||
/eval/
|
||||
/eval/
|
||||
metrics.csv
|
10
Makefile
10
Makefile
@ -3,3 +3,13 @@ image: Dockerfile
|
||||
|
||||
raport.pdf: raport.md
|
||||
pandoc -f markdown-implicit_figures -V geometry:margin=1in $^ -o $@
|
||||
|
||||
manifest: dataset/midi/*.mid dataset/wav/*.wav
|
||||
python prepare_dataset.py
|
||||
|
||||
dataset: manifest
|
||||
python /opt/conda/envs/magenta/lib/python3.7/site-packages/magenta/models/onsets_frames_transcription/onsets_frames_transcription_create_tfrecords.py --csv="./dataset/manifest.csv" --output_directory="./dataset/processed" --wav_dir="./dataset/wav" --midi_dir="./dataset/midi" --expected_splits="test"
|
||||
|
||||
test: dataset
|
||||
onsets_frames_transcription_infer --model_dir="${MODEL_DIR}" --output_dir="./eval/" --examples_path=./dataset/processed/test.tfrecord* --hparams="use_cudnn=false" --preprocess_examples=True
|
||||
|
36
get_metrics.py
Normal file
36
get_metrics.py
Normal file
@ -0,0 +1,36 @@
|
||||
import tensorflow as tf
|
||||
import pandas as pd
|
||||
from glob import glob, fnmatch
|
||||
from argparse import ArgumentParser
|
||||
|
||||
parser = ArgumentParser(description="Extract metrics from TF logs.")
|
||||
parser.add_argument('--log', '-d', dest="pattern", default="./eval/event*", help="Glob pattern for TF event logs.")
|
||||
parser.add_argument('--output', '-o', dest="output", default="./metrics.csv", help="Name of output csv file.")
|
||||
parser.add_argument('metric', default=['*'], nargs='*', type=str, help="Metrics to extract");
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
|
||||
def extract_metrics(path):
|
||||
runlog = pd.DataFrame(columns=['metric', 'value'])
|
||||
|
||||
for e in tf.train.summary_iterator(path):
|
||||
for v in e.summary.value:
|
||||
if any([fnmatch.fnmatch(v.tag, pattern) for pattern in args.metric]):
|
||||
runlog = runlog.append({
|
||||
'metric': v.tag,
|
||||
'value':v.simple_value
|
||||
}, ignore_index=True)
|
||||
|
||||
return runlog
|
||||
|
||||
metrics = pd.DataFrame()
|
||||
for path in glob(args.pattern):
|
||||
try:
|
||||
metrics = metrics.append(extract_metrics(path))
|
||||
except:
|
||||
print(f'Event file corrupted: {path}, skipping...')
|
||||
|
||||
print(metrics)
|
||||
|
||||
metrics.to_csv(args.output)
|
@ -20,7 +20,7 @@ if not args.no_convert:
|
||||
wav = midi.replace('.mid', '.wav').replace('/midi/', '/wav/')
|
||||
|
||||
processed = processed + 1
|
||||
if 'n' in args and processed > args.n:
|
||||
if args.n is not None and processed > args.n:
|
||||
break
|
||||
|
||||
if path.isfile(wav):
|
||||
|
Loading…
Reference in New Issue
Block a user