Source code for psf_generator.utils.handle_data

"""
A collection of functions to handle loading and saving of data and images.

- `image` uses common image formats, e.g., `.tif`
- `npy` uses numpy data format `.npy` for images
- `csv` uses `.csv` for statistical data

Notes
-----
    `save_image` follows the convention (spatial dimensions, channels), i.e., it changes the axes of the input image.
    For tests, we save images in `.npy` format to avoid this inconvenience.

"""
import csv
import os
import typing as tp

import numpy as np
import skimage.io as skio
import torch

from psf_generator.utils.misc import convert_tensor_to_array


def load_image(filepath: str):
    """
    Read an image file from disk.

    Parameters
    ----------
    filepath : str
        Location of the image file to read.

    Returns
    -------
    output
        Whatever `skimage.io.imread` returns for the given file.

    Raises
    ------
    FileNotFoundError
        If `filepath` does not point to an existing regular file.

    """
    if os.path.isfile(filepath):
        return skio.imread(filepath)
    raise FileNotFoundError(f'{filepath} does not exist')
def save_image(filepath: str, image: tp.Union[torch.Tensor, np.ndarray]):
    """
    Save image in specified format to specified location.

    The parent directory of `filepath` is created if it does not exist yet.

    Parameters
    ----------
    filepath : str
        Path to save the file. May be a bare file name, in which case the
        file is written to the current working directory.
    image : torch.Tensor or np.ndarray
        Image to be saved. It is passed through `convert_tensor_to_array`
        before being written.

    Notes
    -----
    Scikit-image and tifffile both follow the convention of putting the
    channel dimension after x and y.
    The saved tif image thus has dimension (z, x, y, channels) instead of
    (z, channels, x, y).

    """
    image = convert_tensor_to_array(image)
    # `os.path.dirname` returns '' for a bare file name and `os.makedirs('')`
    # raises FileNotFoundError, so only create a directory when there is one.
    directory = os.path.dirname(filepath)
    if directory:
        os.makedirs(directory, exist_ok=True)
    # check_contrast=False turns off scikit-image's low-contrast check.
    skio.imsave(filepath, image, check_contrast=False)
def save_as_npy(filepath: str, input_data: tp.Union[torch.Tensor, np.ndarray]):
    """
    Save data as a numpy array in a .npy file.

    The parent directory of `filepath` is created if it does not exist yet.

    Parameters
    ----------
    filepath : str
        Path to save the file. May be a bare file name, in which case the
        file is written to the current working directory.
    input_data : torch.Tensor or np.ndarray
        Data to be saved. It is passed through `convert_tensor_to_array`
        before being written.

    Notes
    -----
    `numpy.save` appends the `.npy` extension to `filepath` when it is
    missing, so the file on disk may be named `filepath + '.npy'`.

    """
    input_data = convert_tensor_to_array(input_data)
    # `os.path.dirname` returns '' for a bare file name and `os.makedirs('')`
    # raises FileNotFoundError, so only create a directory when there is one.
    directory = os.path.dirname(filepath)
    if directory:
        os.makedirs(directory, exist_ok=True)
    np.save(filepath, input_data)
def load_from_npy(filepath: str) -> np.ndarray:
    """
    Read back an array stored in a `.npy` file.

    Parameters
    ----------
    filepath : str
        Location of the `.npy` file.

    Returns
    -------
    output : np.ndarray
        The array obtained from `numpy.load`.

    """
    array = np.load(filepath)
    return array
def save_stats_as_csv(filepath: str, data: list):
    """
    Save statistical data to a csv file for further analysis or plotting.

    Statistical data such as the runtime values is saved as a list of tuples
    (index, value). Each element of `data` becomes one row of the file. The
    parent directory of `filepath` is created if it does not exist yet, and
    an existing file at `filepath` is overwritten.

    Parameters
    ----------
    filepath : str
        Path to the file to store the statistics. May be a bare file name,
        in which case the file is written to the current working directory.
    data : list
        Statistics to be saved.

    """
    # `os.path.dirname` returns '' for a bare file name and `os.makedirs('')`
    # raises FileNotFoundError, so only create a directory when there is one.
    directory = os.path.dirname(filepath)
    if directory:
        os.makedirs(directory, exist_ok=True)
    # newline='' lets the csv module control line endings itself.
    with open(filepath, 'w', newline='') as csv_file:
        csv.writer(csv_file).writerows(data)
def load_stats_from_csv(filepath: str) -> tp.List[tp.Tuple[int, float]]:
    """
    Load data from a csv file.

    Each non-empty row is read as an (index, value) pair: the first column is
    converted to `int` and the second to `float`. Blank lines are skipped.

    Parameters
    ----------
    filepath : str
        Path to the csv file.

    Returns
    -------
    data : list of tuple of (int, float)
        One (index, value) tuple per non-empty row, in file order.

    Raises
    ------
    FileNotFoundError
        If `filepath` does not point to an existing regular file.

    """
    if not os.path.isfile(filepath):
        raise FileNotFoundError(f'File {filepath} does not exist')
    with open(filepath, newline='') as csv_file:
        reader = csv.reader(csv_file, delimiter=',')
        # csv.reader yields [] for a blank line (e.g. a trailing newline in a
        # hand-edited file); indexing such a row would raise IndexError.
        return [(int(row[0]), float(row[1])) for row in reader if row]