Source code for bldfm.config_parser

"""
YAML configuration parser and dataclass schema for BLDFM.

Defines the configuration hierarchy:
    BLDFMConfig
    ├── DomainConfig
    ├── List[TowerConfig]
    ├── MetConfig
    ├── SolverConfig
    ├── OutputConfig
    └── ParallelConfig
"""

import yaml
import numpy as np
from dataclasses import dataclass, field
from pathlib import Path
from typing import List, Optional, Tuple, Union

# --- Dataclasses ---


@dataclass
class TowerConfig:
    """A single measurement tower in the configuration.

    ``lat``/``lon`` give the geographic position and ``z_m`` the height
    value read from the config. ``x``/``y`` are local-frame coordinates that
    stay at 0.0 until :meth:`compute_local_xy` fills them in.
    """

    name: str
    lat: float
    lon: float
    z_m: float
    # Local-frame coordinates, derived from lat/lon and a reference origin.
    x: float = field(default=0.0)
    y: float = field(default=0.0)

    def compute_local_xy(self, ref_lat: float, ref_lon: float):
        """Project this tower's lat/lon into the local frame.

        Delegates to ``abltk.plotting.geo.latlon_to_xy`` with the given
        reference origin and stores the result on ``self.x`` / ``self.y``.
        """
        # Imported inside the method, so abltk is only needed when this runs.
        from abltk.plotting.geo import latlon_to_xy

        local = latlon_to_xy(self.lat, self.lon, ref_lat, ref_lon)
        self.x, self.y = local
@dataclass
class DomainConfig:
    """Settings for the computational domain.

    ``nx``, ``ny``, ``xmax``, ``ymax`` and ``nz`` are mandatory; all other
    attributes have defaults.
    """

    nx: int
    ny: int
    xmax: float
    ymax: float
    nz: int
    modes: Tuple[int, int] = field(default=(512, 512))
    halo: Optional[float] = field(default=None)
    # Reference origin; when both are set, tower local x/y are derived from it.
    ref_lat: Optional[float] = field(default=None)
    ref_lon: Optional[float] = field(default=None)
    output_levels: Optional[List[int]] = field(default=None)
    full_output: bool = field(default=False)
@dataclass
class MetConfig:
    """Meteorological forcing data (scalar for single-time, list for timeseries).

    Each of ``ustar``, ``mol``, ``wind_speed`` and ``wind_dir`` is either a
    scalar (used for every timestep) or a list (one value per timestep).
    After :meth:`validate` passes, all list-valued fields share one length,
    which is the number of timesteps.
    """

    ustar: Optional[Union[float, List[float]]] = None
    mol: Union[float, List[float]] = 1e9
    wind_speed: Union[float, List[float]] = 5.0
    wind_dir: Union[float, List[float]] = 270.0
    z0: Optional[float] = None
    timestamps: Optional[List[str]] = None

    # Fields that may be a scalar or a per-timestep list (unannotated, so it
    # is a plain class attribute rather than a dataclass field).
    _SERIES_FIELDS = ("ustar", "mol", "wind_speed", "wind_dir")

    def _series_lengths(self) -> dict:
        """Map each list-valued series field name to its length."""
        return {
            name: len(getattr(self, name))
            for name in self._SERIES_FIELDS
            if isinstance(getattr(self, name), list)
        }

    @property
    def n_timesteps(self) -> int:
        """Number of timesteps in the timeseries.

        Any list-valued series field (``ustar``, ``mol``, ``wind_speed`` or
        ``wind_dir``) defines the length; 1 when every field is a scalar.
        Checking all four keeps this consistent with :meth:`validate`.
        """
        lengths = self._series_lengths()
        if not lengths:
            return 1
        # Once validate() has passed all lengths agree; take the first.
        return next(iter(lengths.values()))

    def get_step(self, i: int) -> dict:
        """Get met parameters for a single timestep.

        Parameters
        ----------
        i : int
            Timestep index.

        Returns
        -------
        dict
            Scalar values for ``ustar``, ``mol``, ``wind_speed``,
            ``wind_dir`` and ``timestamp``, plus ``z0`` when it is set.
            ``timestamp`` falls back to ``i`` when no timestamps are given.
        """

        def _get(val, idx):
            if val is None:
                return None
            # Lists are indexed per step; scalars are broadcast to every step.
            return val[idx] if isinstance(val, list) else val

        result = {name: _get(getattr(self, name), i) for name in self._SERIES_FIELDS}
        if self.z0 is not None:
            result["z0"] = self.z0
        if self.timestamps is not None:
            result["timestamp"] = self.timestamps[i]
        else:
            result["timestamp"] = i
        return result

    def validate(self):
        """Validate the met inputs.

        Raises
        ------
        ValueError
            If neither ``ustar`` nor ``z0`` is provided, if list-valued
            fields differ in length, if the timeseries is empty, or if
            ``timestamps`` does not match the timeseries length.
        """
        if self.ustar is None and self.z0 is None:
            raise ValueError("MetConfig requires at least one of 'ustar' or 'z0'")

        list_fields = self._series_lengths()
        if not list_fields:
            return  # all scalars, fine

        lengths = set(list_fields.values())
        if len(lengths) > 1:
            raise ValueError(
                f"Met timeseries arrays must all have the same length. "
                f"Got: {list_fields}"
            )
        n = lengths.pop()
        if n == 0:
            raise ValueError("Met timeseries arrays must not be empty")
        if self.timestamps is not None and len(self.timestamps) != n:
            raise ValueError(
                f"timestamps length ({len(self.timestamps)}) does not match "
                f"met array length ({n})"
            )
@dataclass
class SolverConfig:
    """Solver options; every attribute has a default."""

    closure: str = field(default="MOST")
    precision: str = field(default="single")
    footprint: bool = field(default=False)
    surface_flux_shape: str = field(default="diamond")
    analytic: bool = field(default=False)
    # Optional (float, float) pair; None when not configured.
    src_loc: Optional[Tuple[float, float]] = field(default=None)
@dataclass
class OutputConfig:
    """Output settings: ``format`` (default "netcdf") and ``directory``
    (default "./output")."""

    format: str = field(default="netcdf")
    directory: str = field(default="./output")
@dataclass
class ParallelConfig:
    """Parallelism settings. Defaults: one thread, one worker, cache off."""

    num_threads: int = field(default=1)
    max_workers: int = field(default=1)
    use_cache: bool = field(default=False)
@dataclass
class BLDFMConfig:
    """Top-level BLDFM configuration.

    On construction, tower local coordinates are derived from the domain's
    reference origin (when one is given) and the met section is validated.
    """

    domain: DomainConfig
    towers: List[TowerConfig]
    met: MetConfig
    solver: SolverConfig = field(default_factory=SolverConfig)
    output: OutputConfig = field(default_factory=OutputConfig)
    parallel: ParallelConfig = field(default_factory=ParallelConfig)

    def __post_init__(self):
        """Compute local tower coordinates and validate the met section.

        Raises
        ------
        ValueError
            If exactly one of ``domain.ref_lat`` / ``domain.ref_lon`` is set,
            or if ``met.validate()`` rejects the met inputs.
        """
        ref_lat, ref_lon = self.domain.ref_lat, self.domain.ref_lon
        # A half-specified origin would otherwise be skipped silently, leaving
        # every tower at its default x = y = 0.0; reject it explicitly.
        if (ref_lat is None) != (ref_lon is None):
            raise ValueError(
                "DomainConfig requires both 'ref_lat' and 'ref_lon' (or neither); "
                f"got ref_lat={ref_lat!r}, ref_lon={ref_lon!r}"
            )
        if ref_lat is not None:
            for tower in self.towers:
                tower.compute_local_xy(ref_lat, ref_lon)
        self.met.validate()
# --- YAML parsing ---


def _to_float(val, name: str, allow_list: bool = False):
    """Coerce a numeric config value to ``float``; ``None`` passes through.

    PyYAML resolves scalars with YAML 1.1 rules. Under those rules,
    scientific notation without a decimal point or an exponent sign
    (``1e9``, ``1.0e9``, ``1e-3``) is loaded as a *string*, not a number.
    Casting here keeps such strings out of the numerical fields.

    Parameters
    ----------
    val : object
        Raw value from the configuration dictionary.
    name : str
        Dotted config key, used only in the error message.
    allow_list : bool
        If True, a list/tuple is converted element-wise to a list of floats
        (per-timestep series). Otherwise only scalars are accepted.

    Raises
    ------
    ValueError
        If the value, or any list element, cannot be converted to float.
    """
    if val is None:
        return None
    try:
        if allow_list and isinstance(val, (list, tuple)):
            return [float(v) for v in val]
        return float(val)
    except (TypeError, ValueError) as err:
        raise ValueError(
            f"Config value '{name}' must be numeric, got {val!r}"
        ) from err


def _parse_tower(d: dict) -> TowerConfig:
    """Build a TowerConfig from one entry of the ``towers`` list."""
    return TowerConfig(
        name=d["name"],
        lat=_to_float(d["lat"], "towers.lat"),
        lon=_to_float(d["lon"], "towers.lon"),
        z_m=_to_float(d["z_m"], "towers.z_m"),
    )


def _parse_domain(d: dict) -> DomainConfig:
    """Build a DomainConfig from the ``domain`` section."""
    modes = d.get("modes", [512, 512])
    output_levels = d.get("output_levels")
    return DomainConfig(
        nx=d["nx"],
        ny=d["ny"],
        xmax=float(d["xmax"]),
        ymax=float(d["ymax"]),
        nz=d["nz"],
        modes=tuple(modes),
        halo=_to_float(d.get("halo"), "domain.halo"),
        ref_lat=_to_float(d.get("ref_lat"), "domain.ref_lat"),
        ref_lon=_to_float(d.get("ref_lon"), "domain.ref_lon"),
        output_levels=output_levels,
        full_output=d.get("full_output", False),
    )


def _parse_met(d: Optional[dict]) -> MetConfig:
    """Build a MetConfig from the ``met`` section (scalars or per-step lists)."""
    # An empty ``met:`` section loads as None; treat it like {} so that the
    # MetConfig validation reports the missing ustar/z0 instead of crashing.
    d = d or {}
    return MetConfig(
        ustar=_to_float(d.get("ustar"), "met.ustar", allow_list=True),
        mol=_to_float(d.get("mol", 1e9), "met.mol", allow_list=True),
        wind_speed=_to_float(d.get("wind_speed", 5.0), "met.wind_speed", allow_list=True),
        wind_dir=_to_float(d.get("wind_dir", 270.0), "met.wind_dir", allow_list=True),
        z0=_to_float(d.get("z0"), "met.z0"),
        timestamps=d.get("timestamps"),
    )


def _parse_solver(d: Optional[dict]) -> SolverConfig:
    """Build a SolverConfig; a missing/empty section yields defaults."""
    if d is None:
        return SolverConfig()
    src_loc = d.get("src_loc")
    if src_loc is not None:
        src_loc = tuple(_to_float(src_loc, "solver.src_loc", allow_list=True))
    return SolverConfig(
        closure=d.get("closure", "MOST"),
        precision=d.get("precision", "single"),
        footprint=d.get("footprint", False),
        surface_flux_shape=d.get("surface_flux_shape", "diamond"),
        analytic=d.get("analytic", False),
        src_loc=src_loc,
    )


def _parse_output(d: Optional[dict]) -> OutputConfig:
    """Build an OutputConfig; a missing/empty section yields defaults."""
    if d is None:
        return OutputConfig()
    return OutputConfig(
        format=d.get("format", "netcdf"),
        directory=d.get("directory", "./output"),
    )


def _parse_parallel(d: Optional[dict]) -> ParallelConfig:
    """Build a ParallelConfig; a missing/empty section yields defaults."""
    if d is None:
        return ParallelConfig()
    return ParallelConfig(
        num_threads=d.get("num_threads", 1),
        max_workers=d.get("max_workers", 1),
        use_cache=d.get("use_cache", False),
    )
def load_config(path: Union[str, Path]) -> BLDFMConfig:
    """Load a BLDFM configuration from a YAML file.

    Parameters
    ----------
    path : str or Path
        Path to the YAML configuration file.

    Returns
    -------
    BLDFMConfig
        Parsed and validated configuration.

    Raises
    ------
    FileNotFoundError
        If ``path`` does not exist.
    ValueError
        If the file is empty. Errors from ``parse_config_dict`` also
        propagate.
    """
    path = Path(path)
    if not path.exists():
        raise FileNotFoundError(f"Config file not found: {path}")
    # Binary mode lets PyYAML detect the encoding itself (BOM, else UTF-8)
    # instead of decoding with the platform's locale default.
    with path.open("rb") as f:
        # safe_load builds only plain Python objects (no arbitrary tags).
        raw = yaml.safe_load(f)
    # safe_load returns None for an empty or comment-only document.
    if raw is None:
        raise ValueError(f"Config file is empty: {path}")
    return parse_config_dict(raw)
def parse_config_dict(raw: dict) -> BLDFMConfig:
    """Parse a BLDFM configuration from a dictionary.

    Parameters
    ----------
    raw : dict
        Raw configuration dictionary (e.g. from YAML). Must contain the
        ``domain``, ``towers`` and ``met`` sections; ``solver``, ``output``
        and ``parallel`` are optional.

    Returns
    -------
    BLDFMConfig
        Parsed and validated configuration.

    Raises
    ------
    TypeError
        If ``raw`` is not a mapping.
    ValueError
        If a required section is missing or ``towers`` is not a list.
        Errors from section validation also propagate.
    """
    from collections.abc import Mapping

    if not isinstance(raw, Mapping):
        raise TypeError(
            f"Config must be a mapping of sections, got {type(raw).__name__}"
        )
    for section in ("domain", "towers", "met"):
        if section not in raw:
            raise ValueError(f"Config must include '{section}' section")
    # e.g. ``towers:`` with no entries loads as None and is not iterable.
    if not isinstance(raw["towers"], (list, tuple)):
        raise ValueError("Config 'towers' section must be a list of tower entries")

    domain = _parse_domain(raw["domain"])
    towers = [_parse_tower(t) for t in raw["towers"]]
    met = _parse_met(raw["met"])
    solver = _parse_solver(raw.get("solver"))
    output = _parse_output(raw.get("output"))
    parallel = _parse_parallel(raw.get("parallel"))

    return BLDFMConfig(
        domain=domain,
        towers=towers,
        met=met,
        solver=solver,
        output=output,
        parallel=parallel,
    )