import csv
import os
from typing import Optional, TextIO

from torchmate.callbacks import Callback


class CSVLogger(Callback):
    """Logs training history to a CSV file.

    One row is written per epoch. Each row contains an ``"Epoch"`` column followed
    by one column per key of ``trainer.history``. The set of columns is captured
    when the experiment begins.

    Epoch numbers are counted by this callback instance, starting at 1. When
    appending, they are not continued from the rows already present in the file.

    Args:
        filename (str): The path to the CSV file to log to. Missing parent directories are created.
        separator (str, optional): The field separator to use in the CSV file. Defaults to ``","``.
        append (bool, optional): Whether to append to an existing file or create a new one.
            When appending, the header row is written only if the file does not exist yet
            or is empty. Defaults to ``False``.

    **Example Usage:**

    .. code-block:: python

        csv_logger = CSVLogger(filename="logs/log.csv")
        callbacks = [csv_logger]


    """

    def __init__(self, filename: str, separator: str = ",", append: bool = False):
        super().__init__()
        self.filename = filename
        self.separator = separator
        self.append = append
        self.current_epoch = 0
        # Opened in on_experiment_begin. None means no file is currently open,
        # which lets on_experiment_end be a safe no-op.
        self.csvfile: Optional[TextIO] = None
        self.writer: Optional[csv.DictWriter] = None

    def on_experiment_begin(self, trainer) -> None:
        """Open the CSV file and prepare the writer at the beginning of the experiment.

        Args:
            trainer (Trainer): The trainer object. The keys of ``trainer.history`` at this
                point become the CSV columns after the leading ``"Epoch"`` column.

        Returns:
            None
        """
        file_dir = os.path.dirname(self.filename)
        if file_dir:
            # exist_ok avoids a race between an existence check and the creation.
            os.makedirs(file_dir, exist_ok=True)

        # Decide on the header before opening. Appending to a missing or empty
        # file still needs a header row, otherwise the CSV has no column names.
        needs_header = (
            not self.append
            or not os.path.exists(self.filename)
            or os.path.getsize(self.filename) == 0
        )

        # Do not leak a handle if the experiment is started again without having ended.
        if self.csvfile is not None and not self.csvfile.closed:
            self.csvfile.close()

        open_flag = "a" if self.append else "w"
        # NOTE(review): columns are fixed here. csv.DictWriter raises ValueError on
        # writerow() for keys not in fieldnames, so keys added to trainer.history
        # later would fail. Confirm the history is pre-populated before this runs.
        fieldnames = ["Epoch"] + list(trainer.history.keys())
        self.csvfile = open(self.filename, open_flag, newline="", encoding="utf-8")
        self.writer = csv.DictWriter(self.csvfile, fieldnames=fieldnames, delimiter=self.separator)

        if needs_header:
            self.writer.writeheader()
            self.csvfile.flush()

        return None

    def on_epoch_end(self, trainer) -> None:
        """Log training metrics after each epoch.

        Args:
            trainer (Trainer): The trainer object. For every key, the value at index
                ``current_epoch - 1`` of ``trainer.history[key]`` is written.

        Returns:
            None
        """
        self.current_epoch += 1
        index = self.current_epoch - 1  # current_epoch is 1-based; history is indexed from 0

        row = {"Epoch": self.current_epoch}
        row.update({key: values[index] for key, values in trainer.history.items()})

        self.writer.writerow(row)
        # Flush every epoch so the log is readable while training is still running.
        self.csvfile.flush()
        return None

    def on_experiment_end(self, trainer) -> None:
        """Close the CSV file at the end of the experiment.

        This does nothing if no file is open, for example when the experiment never
        began or this method was already called.

        Args:
            trainer (Trainer): The trainer object.

        Returns:
            None
        """
        if self.csvfile is not None:
            self.csvfile.close()
            self.csvfile = None
            self.writer = None

        return None