# Source code for pylorenzmie.lmtool.FitWidget

from datetime import datetime
from pathlib import Path
import warnings
import numpy as np
from numpy.typing import NDArray
import pandas as pd
import pyqtgraph as pg
from pyqtgraph.Qt.QtCore import (pyqtProperty, pyqtSignal, pyqtSlot,
                                  QObject, QRectF, QThread)
from pyqtgraph.Qt.QtWidgets import QFileDialog
from pylorenzmie.analysis import Optimizer
from pylorenzmie.theory import LorenzMie


class _OptimizeWorker(QObject):
    '''Runs the optimizer's ``optimize()`` call in a background thread.

    The owner moves an instance onto a ``QThread`` and connects the
    thread's ``started`` signal to :meth:`run`. The outcome is reported
    through signals rather than a return value.
    '''

    #: Carries the value returned by ``optimize()`` (a ``pd.Series``).
    finished = pyqtSignal(object)
    #: Carries ``str(exception)`` when the optimization raises.
    error = pyqtSignal(str)

    def __init__(self, optimizer: Optimizer) -> None:
        super().__init__()
        # Optimizer instance shared with the owning widget.
        self._optimizer = optimizer

    @pyqtSlot()
    def run(self) -> None:
        '''Execute the fit and emit ``finished`` or ``error``.'''
        try:
            outcome = self._optimizer.optimize()
            self.finished.emit(outcome)
        except Exception as exc:
            # An exception raised here would not reach the GUI thread,
            # so forward its message as a signal instead.
            self.error.emit(str(exc))


class FitWidget(pg.GraphicsLayoutWidget):
    '''Three-panel widget showing the ROI, model fit, and normalized
    residuals.

    The left panel shows the cropped data, the middle panel the model
    hologram computed by :attr:`optimizer`, and the right panel the
    residuals ``(data - model) / noise``, where ``noise`` is the model
    instrument's noise. Fitting can run synchronously (:meth:`optimize`)
    or in a background thread (:meth:`optimizeAsync`).
    '''

    #: Emitted when an optimization thread is launched.
    optimizationStarted = pyqtSignal()
    #: Emitted with the ``pd.Series`` result when optimization succeeds.
    optimizationFinished = pyqtSignal(object)
    #: Emitted with an error message string when optimization fails.
    optimizationError = pyqtSignal(str)

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self._configurePlot()
        self.optimizer = Optimizer()
        # Fraction of pixels randomly selected for each fit (see mask()).
        self.fraction = 0.25
        self.datafile = None
        self.result = None
        self._data = None
        self.rect = None
        self._thread = None
        self._worker = None
        # Full (unmasked) coordinates to restore when a background fit ends.
        self._full_coordinates = None
        # True while the model holds masked coordinates for a background
        # fit; the model must not be rendered during that time.
        self._optimizing = False

    def _configurePlot(self) -> None:
        '''Build the three linked image panels and the residual colorbar.'''
        self.ci.layout.setContentsMargins(0, 0, 0, 0)
        self.setBackground('w')
        pen = pg.mkPen('k', width=2)
        plots = [self.addPlot(row=0, column=column) for column in range(3)]
        for plot in plots:
            plot.getAxis('bottom').setPen(pen)
            plot.getAxis('left').setPen(pen)
            plot.setAspectLocked()
            plot.enableAutoRange(axis='xy', enable=True)
        options = dict(border=pen, axisOrder='row-major')
        self.region = pg.ImageItem(**options)
        self.fit = pg.ImageItem(**options)
        self.residuals = pg.ImageItem(**options)
        for plot, item in zip(plots, (self.region, self.fit, self.residuals)):
            plot.addItem(item)
        # Pan/zoom of the fit and residual panels follows the data panel.
        for plot in plots[1:]:
            plot.setXLink(plots[0])
        for plot in plots[1:]:
            plot.setYLink(plots[0])
        self._regionPlot = plots[0]
        cm = pg.colormap.get('CET-D1')
        self.residuals.setColorMap(cm)
        self.residuals.setLevels((-10, 10))
        cb = pg.ColorBarItem(values=(-10, 10), limits=(-10, 10),
                             interactive=False, colorMap=cm, pen=pen)
        self.addItem(cb)

    def mask(self, data: NDArray[float]) -> NDArray[bool]:
        '''Return a random boolean mask selecting pixels for fitting.

        A fresh random subset is drawn on every call, which prevents
        the optimizer from overfitting to a fixed pixel pattern.
        Saturated pixels (equal to the image maximum) are always
        excluded.

        Parameters
        ----------
        data : ndarray
            Hologram pixel values (any shape).

        Returns
        -------
        mask : ndarray of bool, shape (data.size,)
            Each pixel is selected with probability :attr:`fraction`.
        '''
        data = data.flatten()
        mask = np.random.choice([True, False], data.size,
                                p=[self.fraction, 1-self.fraction])
        mask[data == np.max(data)] = False
        return mask

    def _loadSubset(self, data: NDArray[float],
                    coordinates: NDArray[float]) -> NDArray[float]:
        '''Load a random pixel subset into the optimizer.

        Sets ``optimizer.data`` and the model coordinates to the masked
        subset, and returns the full coordinates reshaped to (2, npts)
        so the caller can restore them after fitting.
        '''
        mask = self.mask(data)
        coordinates = coordinates.reshape((2, -1))
        self.optimizer.data = data.flatten()[mask]
        self.optimizer.model.coordinates = coordinates[:, mask]
        return coordinates

    def optimize(self, data: NDArray[float],
                 coordinates: NDArray[float]) -> pd.Series:
        '''Fit the model to data and update the display.

        Parameters
        ----------
        data : ndarray
            Normalized hologram crop.
        coordinates : ndarray, shape (2, npts)
            Pixel coordinates for the crop.

        Returns
        -------
        result : pandas.Series
            Fitted parameters and uncertainties.

        Raises
        ------
        Exception
            Whatever ``Optimizer.optimize`` raises. The model's full
            coordinates are restored before the exception propagates.
        '''
        coordinates = self._loadSubset(data, coordinates)
        self._data = data
        try:
            self.result = self.optimizer.optimize()
        finally:
            # Restore the full grid even on failure; otherwise later
            # display updates would reshape a masked-size hologram.
            self.optimizer.model.coordinates = coordinates
        self._updateFitDisplay()
        return self.result

    def optimizeAsync(self, data: NDArray[float],
                      coordinates: NDArray[float]) -> None:
        '''Start optimization in a background thread.

        Returns immediately. Emits :attr:`optimizationStarted` on entry,
        then :attr:`optimizationFinished` (or :attr:`optimizationError`)
        when the thread completes. The request is ignored while a
        previous background optimization, including its thread cleanup,
        is still in progress. Use :meth:`optimize` for synchronous
        (blocking) operation.

        Parameters
        ----------
        data : ndarray
            Normalized hologram crop.
        coordinates : ndarray, shape (2, npts)
            Pixel coordinates for the crop.
        '''
        # Busy until _onThreadFinished clears the reference. Checking only
        # isRunning() would let a new thread start before the old thread's
        # queued cleanup runs, and that cleanup would then drop the new
        # thread's reference.
        if self._thread is not None:
            return
        self._full_coordinates = self._loadSubset(data, coordinates)
        self._data = data
        self._optimizing = True

        worker = _OptimizeWorker(self.optimizer)
        thread = QThread()
        worker.moveToThread(thread)
        thread.started.connect(worker.run)
        worker.finished.connect(self._onWorkerFinished)
        worker.error.connect(self._onWorkerError)
        worker.finished.connect(thread.quit)
        worker.error.connect(thread.quit)
        thread.finished.connect(self._onThreadFinished)
        self._worker = worker
        self._thread = thread
        thread.start()
        self.optimizationStarted.emit()

    def _endAsync(self) -> None:
        '''Restore the full model coordinates after a background fit.'''
        self.optimizer.model.coordinates = self._full_coordinates
        self._optimizing = False

    @pyqtSlot(object)
    def _onWorkerFinished(self, result: pd.Series) -> None:
        self._endAsync()
        self.result = result
        self._updateFitDisplay()
        self.optimizationFinished.emit(result)

    @pyqtSlot(str)
    def _onWorkerError(self, message: str) -> None:
        # Restore coordinates on failure too, so the model again matches
        # self._data for subsequent display updates.
        self._endAsync()
        self.optimizationError.emit(message)

    @pyqtSlot()
    def _onThreadFinished(self) -> None:
        if self._thread is not None:
            # QThread.finished is emitted from the thread just before it
            # exits; wait for it to stop fully before releasing the last
            # reference to the QThread object.
            self._thread.wait()
        self._worker = None
        self._thread = None

    def shutdown(self) -> None:
        '''Stop any running optimization thread and wait for it to finish.'''
        if self._thread is not None and self._thread.isRunning():
            self._thread.quit()
            self._thread.wait()

    def setData(self, data: NDArray[float], rect: QRectF,
                coordinates: NDArray[float]) -> None:
        '''Display data, compute and show the current model prediction.

        Parameters
        ----------
        data : ndarray
            Cropped hologram region.
        rect : QRectF
            Screen rectangle for positioning the images.
        coordinates : ndarray, shape (2, npts)
            Pixel coordinates for the cropped region.
        '''
        self._data = data
        self.rect = rect
        full = coordinates.reshape(2, -1)
        if self._optimizing:
            # A background fit is using masked coordinates; apply the new
            # grid when that fit ends (see _endAsync).
            self._full_coordinates = full
        else:
            self.optimizer.model.coordinates = full
        self.region.setImage(data)
        self.region.setRect(rect)
        self._updateFitDisplay()

    def showEvent(self, event) -> None:
        '''Refresh the panels whenever the widget becomes visible.'''
        super().showEvent(event)
        self._updateFitDisplay()

    def _updateFitDisplay(self) -> None:
        '''Recompute the model hologram and refresh all three display panels.

        No-op when data has not been loaded, the widget is not visible,
        or a background fit currently holds masked model coordinates.
        '''
        if self._data is None or self.rect is None:
            return
        if not self.isVisible():
            return
        if self._optimizing:
            return
        hologram = self.optimizer.model.hologram().reshape(self._data.shape)
        hologram = np.clip(hologram, np.min(self._data), np.max(self._data))
        self.fit.setImage(hologram)
        noise = self.optimizer.model.instrument.noise
        self.residuals.setImage((self._data - hologram) / noise)
        self.fit.setRect(self.rect)
        self.residuals.setRect(self.rect)
        self._regionPlot.autoRange()

    def refreshPreview(self, properties: dict | None = None) -> None:
        '''Recompute and redisplay the model prediction without re-optimizing.

        Parameters
        ----------
        properties : dict, optional
            Model properties to apply before recomputing the hologram.
        '''
        if self._data is None:
            return
        if properties:
            self.optimizer.model.properties = properties
        self._updateFitDisplay()

    @pyqtProperty(LorenzMie)
    def model(self) -> LorenzMie:
        return self.optimizer.model

    @model.setter
    def model(self, model: LorenzMie) -> None:
        self.optimizer.model = model

    @pyqtProperty(dict)
    def properties(self) -> LorenzMie.Properties:
        return self.optimizer.model.properties

    @properties.setter
    def properties(self, properties: LorenzMie.Properties) -> None:
        self.optimizer.model.properties = properties

    @pyqtSlot(str, object)
    def setSetting(self, name: str, value: LorenzMie.Property) -> None:
        '''Set the pixel fraction or forward a setting to the optimizer.'''
        if name == 'fraction':
            self.fraction = value
        else:
            self.optimizer.settings[name] = value

    def filename(self) -> str:
        '''Return a timestamped default path under ``~/data/lmtool``.

        The directory, including any missing parents, is created if
        needed.
        '''
        directory = Path('~/data/lmtool').expanduser()
        directory.mkdir(parents=True, exist_ok=True)
        timestamp = datetime.now().strftime('%m_%d_%Y-%H_%M_%S')
        return str(directory / f'result_{timestamp}.h5')

    @pyqtSlot()
    def saveResult(self, filename: str | None = None) -> None:
        '''Write the fit result and metadata to an HDF5 file.

        Does nothing if no result is available. Uses :meth:`filename`
        when ``filename`` is not given.
        '''
        if self.result is None:
            return
        filename = filename or self.filename()
        # Copy so that recording the data file does not modify the
        # optimizer's own metadata object.
        metadata = self.optimizer.metadata.copy()
        metadata['datafile'] = self.datafile
        with warnings.catch_warnings():
            # pandas may warn when storing object-dtype values (such as
            # the datafile string) in HDF5; that is expected here.
            warnings.filterwarnings(
                'ignore', category=pd.errors.PerformanceWarning)
            self.result.to_hdf(filename, 'result', mode='w')
            metadata.to_hdf(filename, 'metadata')

    def saveJson(self, filename: str) -> None:
        '''Write the fit result, metadata and data file name as JSON.'''
        if self.result is None:
            return
        s = pd.concat([self.result, self.optimizer.metadata])
        s['datafile'] = self.datafile
        s.to_json(filename, indent=4)

    @pyqtSlot()
    def saveResultAs(self) -> None:
        '''Prompt for a file name and save as HDF5 or JSON.

        The format is chosen from the file's extension. When the user
        omits the extension, it is inferred from the selected filter and
        appended.
        '''
        get = QFileDialog.getSaveFileName
        filename, selected = get(self, 'Save Results', self.filename(),
                                 'HDF5 (*.h5);;JSON (*.json)')
        if not filename:
            return
        suffix = Path(filename).suffix.lower()
        if not suffix:
            suffix = '.json' if 'json' in selected.lower() else '.h5'
            filename += suffix
        if suffix == '.h5':
            self.saveResult(filename)
        elif suffix == '.json':
            self.saveJson(filename)