import os
import re
import shutil
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
from pathlib import Path
from typing import Literal
import huggingface_hub as hf # type: ignore
import numpy as np
import pandas as pd
import pyvista as pv
from tqdm import tqdm
from cooldata.metadata import SystemParameters, df_row_to_system_parameters
# Names of the fields that may be requested from the volume data
# (used as the ``scalars`` argument when plotting the volume).
VolumeFieldType = Literal[
    "Velocity_0",
    "Velocity_1",
    "Velocity_2",
    "Pressure",
    "Temperature",
    "TurbulentKineticEnergy",
    "TurbulentDissipationRate",
]
# Names of the fields that may be requested from the surface data
# (used as the ``scalars`` argument when plotting the surface).
SurfaceFieldType = Literal[
    "AreaMagnitude",
    "HeatTransferCoefficient",
    "Normal_0",
    "Normal_1",
    "Normal_2",
    "Temperature",
    "Pressure",
    "WallShearStressMagnitude",
    "WallShearStress_0",
    "WallShearStress_1",
    "WallShearStress_2",
]
[docs]
class PyvistaSample:
"""
A PyVista sample representing a 3D flow field dataset. The data is stored on disk and can be loaded to a PyVista object when needed
"""
def __init__(
self,
volume_path: str | Path,
surface_path: str | Path,
metadata: SystemParameters | None = None,
):
self.volume_path = Path(volume_path)
self.surface_path = Path(surface_path)
self._volume_data: pv.MultiBlock | None = None
self._surface_data: pv.MultiBlock | None = None
self.metadata = metadata
self._bounds: tuple[float, float, float, float, float, float] | None = None
@property
def surface_data(self):
if self._surface_data is None:
self._surface_data = pv.read(self.surface_path)
# TODO: Check if data is a valid flow field dataset, i.e., has the necessary point data
return self._surface_data
@property
def volume_data(self):
if self._volume_data is None:
self._volume_data = pv.read(self.volume_path)
# TODO: Check if data is a valid flow field dataset, i.e., has the necessary point data
return self._volume_data
[docs]
def plot_surface(self, field: SurfaceFieldType):
self.surface_data.plot(scalars=field)
[docs]
def plot_volume(self, field: VolumeFieldType):
self.volume_data[0][0][0].plot(scalars=field, opacity=0.7)
[docs]
def get_points(self) -> np.ndarray:
"""
Returns the points of the dataset as a numpy array.
Returns:
--------
np.ndarray: The points of the dataset. Shape: (n_points, 3)
"""
# TODO check if same works for non cgns
block = self.volume_data[0][0][0]
return block.cell_centers().points
[docs]
def get_surface_points(self, block_index: int) -> np.ndarray:
"""
Returns the points of the surface dataset as a numpy array.
Returns:
--------
np.ndarray: The points of the surface dataset. Shape: (n_points, 3)
"""
block = self.surface_data[0][block_index]
return block.points
[docs]
def get_labeled_surface_points(self) -> np.ndarray:
"""
Returns the surface points of the dataset with their block index as a numpy array.
Returns:
--------
np.ndarray: The labeled points of the dataset. Shape: (n_points, 4)
"""
labeled_points = []
for i, block in enumerate(self.surface_data[0]):
labeled_points.append(
np.hstack((block.points, np.full((block.n_points, 1), i)))
)
return np.vstack(labeled_points)
@property
def is_loaded(self) -> bool:
"""
Returns True if the volume and surface data are loaded, False otherwise.
"""
return self._volume_data is not None and self._surface_data is not None
[docs]
def load(self):
self.volume_data
self.surface_data
[docs]
def unload(self):
# TODO: Test if this frees the memory
self._volume_data = None
self._surface_data = None
[docs]
def get_bounding_box(self):
"""
Returns the bounding box of the volume data.
The bounding box is a six-tuple (xmin, xmax, ymin, ymax, zmin, zmax).
"""
if self._bounds is not None:
return self._bounds
was_loaded = self.is_loaded
self._bounds = self.volume_data.bounds
if not was_loaded:
self.unload()
return self._bounds
@property
def design_id(self) -> int:
"""
Returns the design ID of the sample, which is extracted from the file name.
The design ID is assumed to be the second last part of the file name, split by underscores.
"""
stem = self.volume_path.stem
id = re.search(r"_(\d+)", stem)
if id:
return int(id.group(1))
raise ValueError("Design ID not found")
def _repr_html_(self) -> str:
if not hasattr(self, "metadata") or self.metadata is None:
return "<p>No metadata available for this sample to generate HTML representation.</p>"
# Ensure metadata object is not None and has quads and cylinders attributes
if not hasattr(self.metadata, "quads") or not hasattr(
self.metadata, "cylinders"
):
return "<p>Metadata object is malformed (missing quads or cylinders attributes).</p>"
_design_id_str: str
try:
_design_id_str = str(self.design_id)
except ValueError:
_design_id_str = "N_A" # Use a safe fallback for ID
try:
bounds = self.get_bounding_box() # (xmin, xmax, ymin, ymax, zmin, zmax)
except Exception as e:
return f"""<p>Error getting bounding box for sample ID {_design_id_str}: {e}.
Make sure volume data is accessible and valid.</p>"""
xmin, xmax, ymin, ymax, _, _ = bounds
world_width = xmax - xmin
world_height = ymax - ymin
if (
world_width <= 1e-6 or world_height <= 1e-6
): # Using a small epsilon for float comparison
return f"""<p>Bounding box for sample ID {_design_id_str} has zero or negligible 2D dimensions
(width: {world_width:.2e}, height: {world_height:.2e}). Cannot render.</p>"""
# SVG viewport dimensions (fixed size for the output image)
svg_width = 300
svg_height = 200
padding = 0.05 * max(world_width, world_height)
view_box_xmin = xmin - padding
view_box_ymin = ymin - padding
view_box_width = world_width + 2 * padding
view_box_height = world_height + 2 * padding
# Determine a reasonable stroke width based on world dimensions
# This aims for a stroke that's roughly 0.5% of the main dimension
dynamic_stroke_width = 0.005 * max(world_width, world_height)
html_parts = [
'<div style="font-family: Arial, sans-serif; margin: 10px; padding: 10px; border: 1px solid #e0e0e0; border-radius: 5px; background-color: #fdfdfd;">',
f' <h4 style="margin-top:0; margin-bottom: 10px; color: #333;">Design ID: {_design_id_str}</h4>',
f' <svg width="{svg_width}" height="{svg_height}" viewBox="{view_box_xmin} {view_box_ymin} {view_box_width} {view_box_height}" style="border:1px solid #ccc; background-color: #ffffff;">',
]
# Unique ID for clipPath, incorporating object id for robustness in some edge cases
clip_path_id = f"clipPath_PyvistaSample_{_design_id_str}_{hex(id(self))[-6:]}"
html_parts.append(" <defs>")
html_parts.append(f' <clipPath id="{clip_path_id}">')
html_parts.append(
f' <rect x="{xmin}" y="{ymin}" width="{world_width}" height="{world_height}" />'
)
html_parts.append(" </clipPath>")
html_parts.append(" </defs>")
html_parts.append(
f' <rect x="{xmin}" y="{ymin}" width="{world_width}" height="{world_height}" fill="none" stroke="#bbbbbb" stroke-width="{dynamic_stroke_width}" stroke-dasharray="{2 * dynamic_stroke_width},{2 * dynamic_stroke_width}" />'
)
html_parts.append(f' <g clip-path="url(#{clip_path_id})">')
quad_colors = ["#1f77b4", "#ff7f0e", "#2ca02c", "#d62728", "#9467bd", "#8c564b"]
for i, quader in enumerate(self.metadata.quads):
q_x = quader.position.x - quader.size_x / 2
q_y = quader.position.y - quader.size_y / 2
color = quad_colors[i % len(quad_colors)]
html_parts.append(
f' <rect x="{q_x}" y="{q_y}" width="{quader.size_x}" height="{quader.size_y}" fill="{color}" fill-opacity="0.6" stroke="{color}" stroke-opacity="0.9" stroke-width="{0.5 * dynamic_stroke_width}">'
)
html_parts.append(
f" <title>Quader {i + 1}\\nTemp: {quader.temperature:.1f}\\nPos: ({quader.position.x:.2f}, {quader.position.y:.2f})\\nSize: ({quader.size_x:.2f}, {quader.size_y:.2f})</title>"
)
html_parts.append(" </rect>")
cyl_colors = ["#17becf", "#bcbd22", "#e377c2", "#7f7f7f", "#aec7e8", "#ffbb78"]
for i, cylinder in enumerate(self.metadata.cylinders):
color = cyl_colors[i % len(cyl_colors)]
html_parts.append(
f' <circle cx="{cylinder.position.x}" cy="{cylinder.position.y}" r="{cylinder.radius}" fill="{color}" fill-opacity="0.6" stroke="{color}" stroke-opacity="0.9" stroke-width="{0.5 * dynamic_stroke_width}">'
)
html_parts.append(
f" <title>Cylinder {i + 1}\\nTemp: {cylinder.temperature:.1f}\\nPos: ({cylinder.position.x:.2f}, {cylinder.position.y:.2f})\\nRadius: {cylinder.radius:.2f}</title>"
)
html_parts.append(" </circle>")
html_parts.append(" </g>")
html_parts.append(" </svg>")
html_parts.append(
f' <p style="font-size: 0.8em; color: #666; margin-top: 8px; margin-bottom: 0;">Bounding Box (xmin, xmax, ymin, ymax): ({xmin:.2f}, {xmax:.2f}, {ymin:.2f}, {ymax:.2f})</p>'
)
html_parts.append("</div>")
return "\n".join(html_parts)
[docs]
class PyvistaFlowFieldDataset:
"""
The main class for working with the cooldata dataset.
"""
def __init__(self, samples: list[PyvistaSample]):
self.samples = samples
def __len__(self):
return len(self.samples)
def __getitem__(self, idx: int):
return self.samples[idx]
[docs]
def slice(self, start: int, end: int):
"""
Returns a slice of the dataset.
"""
return PyvistaFlowFieldDataset(self.samples[start:end])
[docs]
def shuffle(self):
"""
Shuffles the dataset in place.
"""
np.random.shuffle(self.samples)
[docs]
@classmethod
def try_from_directory(
cls, data_dir: str | Path, num_samples: int
) -> "None | PyvistaFlowFieldDataset":
data_dir = os.path.abspath(data_dir)
data_dir = Path(data_dir)
volume_dir = data_dir / "volume"
surface_dir = data_dir / "surface"
metadata_file = data_dir / "metadata.parquet"
if not metadata_file.exists():
print(f"Metadata file not found at {metadata_file}.")
return None
os.makedirs(volume_dir, exist_ok=True)
os.makedirs(surface_dir, exist_ok=True)
volume_files = list(volume_dir.glob("*.cgns"))
surface_files = list(surface_dir.glob("*.cgns"))
volume_indices = [int(f.stem.split("_")[-2]) for f in volume_files]
surface_indices = [int(f.stem.split("_")[-2]) for f in surface_files]
volume_indices.sort()
surface_indices.sort()
if volume_indices != surface_indices:
return None
volume_files = sorted(volume_files, key=lambda x: int(x.stem.split("_")[-2]))
surface_files = sorted(surface_files, key=lambda x: int(x.stem.split("_")[-2]))
samples = [PyvistaSample(v, s) for v, s in zip(volume_files, surface_files)]
if len(samples) < num_samples:
return None
samples = samples[:num_samples]
ds = cls(samples)
metadata_df = pd.read_parquet(metadata_file)
ds.add_metadata(metadata_df)
print(f"Loaded {len(ds)} samples from '{data_dir}'.")
return ds
[docs]
@classmethod
def load_from_huggingface(
cls, data_dir: str | Path, num_samples=3
) -> "PyvistaFlowFieldDataset":
"""
Download the given number of samples from huggingface to data_dir.
Args:
- data_dir: The directory to download the data to.
- num_samples: The number of samples to download
"""
import warnings
warnings.warn(
"load_from_huggingface() is deprecated. "
"Use MetadataFilter from cooldata.metadata instead:\n"
" from cooldata.metadata import MetadataFilter\n"
" ds = MetadataFilter('path/to/metadata.parquet').load_random(n=num_samples)",
FutureWarning,
stacklevel=2,
)
loaded = cls.try_from_directory(data_dir, num_samples)
if loaded is not None:
print(f"Loaded {len(loaded)} samples from '{data_dir}'.")
return loaded
data_dir = os.path.abspath(data_dir)
data_dir = Path(data_dir)
volume_dir = data_dir / "volume"
surface_dir = data_dir / "surface"
tmp_dir = data_dir / "tmp"
os.makedirs(tmp_dir, exist_ok=True)
# remove existing files
if os.path.exists(volume_dir):
shutil.rmtree(volume_dir)
if os.path.exists(surface_dir):
shutil.rmtree(surface_dir)
os.makedirs(volume_dir, exist_ok=True)
os.makedirs(surface_dir, exist_ok=True)
repo_id = "datasets/bgce/cooldata-v2"
fs = hf.HfFileSystem()
# download metadata file
metadata_file = f"{repo_id}/metadata.parquet"
local_metadata_path = os.path.join(data_dir, "metadata.parquet")
fs.download(metadata_file, local_metadata_path)
metadata_df = pd.read_parquet(local_metadata_path)
runs = fs.glob(f"{repo_id}/runs/run_*", detail=False)
samples: list[PyvistaSample] = []
runs = sorted(runs, key=lambda x: int(x.split("/")[-1].removeprefix("run_")))
for run in runs:
run_name = os.path.basename(run)
zip_files_in_run = list(fs.glob(f"{run}/batch_*.zip", detail=False))
zip_files_in_run = sorted(
zip_files_in_run,
key=lambda x: int(
x.split("/")[-1].removeprefix("batch_").removesuffix(".zip")
),
)
for zip_file in zip_files_in_run:
local_path = os.path.join(tmp_dir, run_name, os.path.basename(zip_file))
try:
fs.download(zip_file, local_path)
except Exception as e:
print(f"Failed to download {zip_file} for run {run_name}: {e}")
continue
# Extract the zip file
unzip_dir = os.path.join(
tmp_dir, run_name, os.path.basename(zip_file).removesuffix(".zip")
)
os.makedirs(unzip_dir, exist_ok=True)
shutil.unpack_archive(local_path, unzip_dir)
# Match indices
volume_files = list(Path(unzip_dir).glob("volume_design_*_p.cgns"))
surface_files = list(Path(unzip_dir).glob("surface_design_*_p.cgns"))
volume_indices = [int(f.stem.split("_")[-2]) for f in volume_files]
surface_indices = [int(f.stem.split("_")[-2]) for f in surface_files]
volume_indices.sort()
surface_indices.sort()
if volume_indices != surface_indices:
print(f"Skipping {run_name} due to mismatched indices.")
continue
volume_files = sorted(
volume_files, key=lambda x: int(x.stem.split("_")[-2])
)
surface_files = sorted(
surface_files, key=lambda x: int(x.stem.split("_")[-2])
)
for v, s in zip(volume_files, surface_files):
# Copy the files to the new directory
shutil.copy(v, volume_dir)
shutil.copy(s, surface_dir)
moved_volume_path = volume_dir / v.name
moved_surface_path = surface_dir / s.name
samples.append(PyvistaSample(moved_volume_path, moved_surface_path))
if len(samples) >= num_samples:
shutil.rmtree(tmp_dir)
ds = cls(samples)
ds.add_metadata(metadata_df)
print(f"Loaded {len(ds)} samples from '{data_dir}'.")
return ds
# Clean up temporary directory
shutil.rmtree(tmp_dir)
ds = cls(samples)
ds.add_metadata(metadata_df)
print(f"Loaded {len(ds)} samples from '{data_dir}'.")
return ds
[docs]
def get_bounds(self):
"""
Returns the bounding box of the volume data.
The bounding box is a six-tuple (xmin, xmax, ymin, ymax, zmin, zmax).
"""
bounds = (np.inf, -np.inf, np.inf, -np.inf, np.inf, -np.inf)
self.unload()
with ProcessPoolExecutor() as executor:
sample_bounds_list = list(
tqdm(
executor.map(get_sample_bounds, self.samples),
total=len(self.samples),
desc="Computing bounds",
)
)
for sample_bounds in sample_bounds_list:
bounds = (
min(bounds[0], sample_bounds[0]),
max(bounds[1], sample_bounds[1]),
min(bounds[2], sample_bounds[2]),
max(bounds[3], sample_bounds[3]),
min(bounds[4], sample_bounds[4]),
max(bounds[5], sample_bounds[5]),
)
return bounds
[docs]
def load_to_memory(self):
"""
Load all samples into memory.
"""
for sample in self.samples:
sample.load()
[docs]
def unload(self):
"""
Unload all samples from memory.
"""
for sample in self.samples:
sample.unload()
def get_sample_bounds(sample: "PyvistaSample"):
    """
    Return the bounding box of a single sample.

    Thin wrapper around ``sample.get_bounding_box()``. Kept as a plain
    function so it can be passed to ``executor.map``.
    """
    return sample.get_bounding_box()