# Source code for quick_vis.slice_plot

"""2D slice down the middle of the domain."""

import importlib
import os
import pickle as pl
import sys
from inspect import getmembers, isfunction

import numpy as np
import yt
from skimage.measure import find_contours
from yt.units.yt_array import YTArray

sys.path.append(os.path.abspath(os.path.join(sys.argv[0], "../../")))
import matplotlib.pyplot as plt  # noqa: E402

import ytscripts.utilities as utils  # noqa: E402
import ytscripts.ytargs as ytargs  # noqa: E402

# Use LaTeX text rendering when a LaTeX installation is detected; otherwise
# report that matplotlib's standard fonts will be used.
if not utils.is_latex_available():
    print("LaTeX not available, using standard font.")
else:
    plt.rc("text", usetex=True)


def get_parser():
    """Build the argument wrapper for the slice-plot script.

    Returns:
        ytargs.ytVisArgs: Visualization argument wrapper with the orientation,
        2D-visualization, and slice argument groups registered.
    """
    vis_args = ytargs.ytVisArgs()
    # Register the argument groups SlicePlot needs, in a fixed order
    for add_group in ("orientation_args", "vis_2d_args", "slice_args"):
        getattr(vis_args, add_group)()
    return vis_args
def get_base_parser():
    """Return the parser object wrapped by ``get_parser()``.

    Used primarily for generating documentation. The returned object is
    whatever ``ytVisArgs.get_parser()`` yields (presumably an
    ``argparse.ArgumentParser`` — confirm in ``ytscripts.ytargs``).
    """
    vis_args = get_parser()
    base_parser = vis_args.get_parser()
    return base_parser
def get_args(parser):
    """Parse command-line arguments, optionally overridden by an input file.

    Args:
        parser: Argument wrapper (as built by ``get_parser``) exposing
            ``parse_args()`` and ``override_args(args, ifile)``.

    Returns:
        The parsed arguments. When ``ifile`` is set, this is whatever
        ``parser.override_args`` returns (indexed like a dict by callers);
        otherwise it is ``vars()`` of the command-line namespace.
    """
    # Parse the command line exactly once; the previous version parsed it
    # twice and immediately discarded the first result.
    init_args = parser.parse_args()
    if init_args.ifile:
        # Values from the input file override the command-line arguments
        return parser.override_args(init_args, init_args.ifile)
    # Return the parsed arguments as a dict
    return vars(init_args)
def plot_contours(contour, ax, left_edge, dxy, color, linewidth):
    """Draw contour lines on a matplotlib axes in physical coordinates.

    Args:
        contour: Iterable of 2D arrays of ``(row, col)`` pixel coordinates,
            e.g. as returned by ``skimage.measure.find_contours``.
        ax: Matplotlib axes to draw on.
        left_edge: Physical coordinate of pixel index 0 as
            ``[horizontal, vertical]``.
        dxy: Pixel size as ``[horizontal, vertical]``.
        color: Line color passed to ``ax.plot``.
        linewidth: Line width passed to ``ax.plot``.
    """
    line_style = {
        "alpha": 1.0,
        "color": color,
        "zorder": 10,
        "linewidth": linewidth,
    }
    for segment in contour:
        # Column index maps to the horizontal axis, row index to the vertical
        cols = segment[:, 1]
        rows = segment[:, 0]
        ax.plot(
            cols * dxy[0] + left_edge[0],
            rows * dxy[1] + left_edge[1],
            **line_style,
        )
def main():
    """Render and save 2D slice plots for every dataset in a time series.

    Workflow:
        1. Parse command-line arguments (optionally overridden by an input
           file) and the plotting configuration from ``utils.get_configs()``.
        2. Load the dataset series and set up units, the slice center, and
           the optional plot box (``pbox``).
        3. For each dataset, build a ``yt.SlicePlot``, apply the requested
           annotations (timestamp, grids, cell edges, contours, grid
           statistics, EB masking), export it to a matplotlib figure, save
           it as a PNG in the image directory, optionally pickle the figure,
           and close it.

    Exits via ``sys.exit`` when ``pbox`` and ``contour`` are combined, when
    the ``contour`` argument count is not a multiple of 3, or when the slice
    normal is not one of x/y/z while drawing contours.
    """
    # Parse the input arguments
    parser = get_parser()
    args = get_args(parser)

    # Parse the configuration options
    configs = utils.get_configs()

    # Import user-defined functions (UDFs); every function in the module is
    # stored by name (only "rm_eb" is consulted below).
    udf_funcs = {}
    if args["add_udf"]:
        udf_path = os.path.abspath(os.path.join(sys.argv[0], "../../", "udfs"))
        sys.path.append(udf_path)
        # Drop the last three characters (assumed to be a ".py" suffix)
        udf_mod = importlib.import_module("udfs" + f""".{args["add_udf"][:-3]}""")
        udf_funcs.update(dict(getmembers(udf_mod, isfunction)))

    # Enable tex parsing for plots
    if args["use_tex"]:
        plt.rc("text", usetex=True)

    # Make the output directory for images
    if args["outpath"]:
        imgpath = args["outpath"]
    else:
        imgpath = os.path.abspath(
            os.path.join(sys.argv[0], "../../outdata/", "images")
        )
    os.makedirs(imgpath, exist_ok=True)

    # Override the units if needed; the EB volume-fraction field name also
    # differs between the two unit systems.
    if args["SI"]:
        units_override = {
            "length_unit": (1.0, "m"),
            "time_unit": (1.0, "s"),
            "mass_unit": (1.0, "kg"),
            "velocity_unit": (1.0, "m/s"),
        }
        axes_unit = "m"
        eb_var_name = "volFrac"
    else:
        units_override = None
        axes_unit = "cm"
        eb_var_name = "vfrac"

    # Load data files into dataset series
    ts, index_dict = utils.load_dataseries(
        datapath=args["datapath"],
        pname=args["pname"],
        units_override=units_override,
        nskip=args["nskip"],
    )

    # Get base attributes
    base_attributes = utils.get_attributes(ds=ts[0])

    # Level-0 reference volume: cell volume times the summed EB volume
    # fraction over the level-0 covering grid. Used below to normalize the
    # per-level coverage percentages.
    if args["grid_info"]:
        dx0, dy0, dz0 = np.array(base_attributes["dxyz"])
        data = ts[0].covering_grid(
            level=0,
            left_edge=base_attributes["left_edge"],
            dims=base_attributes["dimensions"],
            ds=ts[0],
        )
        num_cells_0 = float(data[eb_var_name].sum())
        domain_volume = dx0 * dy0 * dz0 * num_cells_0

    if args["verbose"]:
        print(f"""The fields in this dataset are: {base_attributes["field_list"]}""")
        print(
            f"""The derived fields in this dataset are: """
            f"""{base_attributes["derived_field_list"]}"""
        )

    # Set the center of the plot for loading the data
    if args["center"] is not None:
        slc_center = args["center"]
    else:
        # Set the center based on the plt data
        slc_center = (
            base_attributes["right_edge"] + base_attributes["left_edge"]
        ) / 2.0
    # Provide a slight offset to avoid grid-alignment visualization issues.
    # NOTE(review): if args["center"] is a plain Python list, ``+=`` extends
    # the list rather than adding element-wise — confirm its type upstream.
    slc_center += YTArray(args["grid_offset"], base_attributes["length_unit"])

    # Compute the center and width of the image for plotting.
    # pbox is [x_lo, y_lo, x_hi, y_hi] in the plot's axes units.
    if args["pbox"]:
        pbox_center = [
            (args["pbox"][2] + args["pbox"][0]) / 2.0,
            (args["pbox"][3] + args["pbox"][1]) / 2.0,
        ]
        pbox_width = (
            (args["pbox"][2] - args["pbox"][0], axes_unit),
            (args["pbox"][3] - args["pbox"][1], axes_unit),
        )
        if args["contour"]:
            sys.exit("joint pbox and contour options are currently broken...")

    # Loop over all datasets in the time series
    yt.enable_parallelism()
    for ds in ts.piter(dynamic=True):
        if all(hasattr(ds.fields.boxlib, f"velocity{comp}") for comp in "xyz"):
            utils.define_velocity_fields(ds)

        # Visualize the gradient field, if requested
        if args["gradient"]:
            vis_field = utils.get_gradient_field(ds, args["field"], args["gradient"])
        else:
            vis_field = args["field"]

        # Get updated attributes for each plt file
        ds_attributes = utils.get_attributes(ds=ds)

        # Default image resolution: the native resolution of the two in-plane
        # directions for each slice normal
        res = ds_attributes["resolution"]
        slc_res = {
            "x": (res[1], res[2]),
            "y": (res[2], res[0]),
            "z": (res[0], res[1]),
        }
        if args["normal"] == "y":
            # Square buffer for y-normal slices (presumably because the axes
            # are swapped below)
            max_res = max(res[2], res[0])
            slc_res["y"] = (max_res, max_res)

        # Set index according to dict
        index = index_dict[str(ds)]

        # Plot the field
        slc = yt.SlicePlot(
            ds=ds,
            normal=args["normal"],
            fields=vis_field,
            center=slc_center,
            buff_size=(
                tuple(args["buff"])
                if args["buff"] is not None
                else slc_res[args["normal"]]
            ),
        )
        if args["normal"] == "y":
            slc.swap_axes()
        slc.set_axes_unit(axes_unit)
        slc.set_origin("native")
        # Same truthiness test as the pbox setup above, so pbox_width and
        # pbox_center are always defined here.
        if args["pbox"]:
            slc.set_width(pbox_width)
            slc.set_center(pbox_center)
        if args["fbounds"] is not None:
            slc.set_zlim(vis_field, args["fbounds"][0], args["fbounds"][1])
        if not args["no_time"]:
            slc.annotate_timestamp(draw_inset_box=True)

        grids = args["grids"]
        if grids:
            # Explicit (alpha, min_level, max_level, linewidth) settings
            slc.annotate_grids(
                alpha=grids[0],
                min_level=grids[1],
                max_level=grids[2],
                linewidth=grids[3],
            )
        elif isinstance(grids, (list, tuple)):
            # An empty list (flag given without values) uses yt's defaults;
            # this branch was previously unreachable.
            slc.annotate_grids()

        # annotate the cell edges of the mesh
        if args["cells"]:
            slc.annotate_cell_edges(
                line_width=float(args["cells"][0]),
                alpha=float(args["cells"][1]),
                color=args["cells"][2],
            )

        slc.set_log(vis_field, args["plot_log"])
        slc.set_cmap(field=vis_field, cmap=args["cmap"])

        # Set a shorter colorbar label for gradient fields
        if args["gradient"]:
            if args["gradient"] == "magnitude":
                new_label = rf"""|$\nabla$ {args["field"]}|"""
            else:
                new_label = rf"""$\nabla_{args["gradient"]}$ {args["field"]}"""
            slc.set_colorbar_label(field=vis_field, label=new_label)

        # Remove the units
        if args["no_units"]:
            # In-plane axis names for each normal (horizontal, vertical)
            norm_dict = {"x": ["y", "z"], "y": ["x", "z"], "z": ["x", "y"]}
            slc.set_colorbar_label(
                field=vis_field,
                label=(
                    configs["vis_field_attrs"][vis_field]["label"]
                    if vis_field in configs["vis_field_attrs"]
                    else vis_field
                ),
            )
            slc.set_xlabel(f"""${norm_dict[args["normal"]][0]}$""")
            slc.set_ylabel(f"""${norm_dict[args["normal"]][1]}$""")

        # Override the colorbar label
        if vis_field in configs["vis_field_attrs"] and not args["no_units"]:
            slc.set_colorbar_label(
                field=vis_field,
                label=configs["vis_field_attrs"][vis_field]["label"],
            )

        # Remove the colorbar label if plotting on top
        if configs["cbar_attrs"]["label"]["loc"] == "top":
            slc.set_colorbar_label(field=vis_field, label="")

        slc.set_font_size(configs["plot_attrs"]["base"]["fontsize"])

        # Convert the slice to matplotlib figure
        fig = slc.export_to_mpl_figure(
            nrows_ncols=(1, 1),
            cbar_pad=configs["cbar_attrs"]["base"]["pad"],
            cbar_location=configs["cbar_attrs"]["base"]["loc"],
        )

        # Get the image axes (0) and colorbar axes (1) from the figure handle
        ax = fig.axes[0]
        axlabel_size = configs["plot_attrs"]["axes"]["labelsize"]
        ax.tick_params(axis="x", labelsize=axlabel_size)
        ax.tick_params(axis="y", labelsize=axlabel_size)
        axc = fig.axes[1]
        axc.tick_params(axis="x", labelsize=axlabel_size)
        axc.tick_params(axis="y", labelsize=axlabel_size)
        if configs["cbar_attrs"]["label"]["loc"] == "top":
            axc.set_title(
                (
                    configs["vis_field_attrs"][vis_field]["label"]
                    if vis_field in configs["vis_field_attrs"]
                    else vis_field
                ),
                fontsize=configs["cbar_attrs"]["base"]["fontsize"],
                pad=configs["cbar_attrs"]["title"]["pad"],
            )
            axc.set_xlabel("")

        if args["contour"] is not None:
            xres, yres, zres = np.array(ds_attributes["resolution"])
            lx, ly, lz = np.array(ds_attributes["left_edge"])
            rx, ry, rz = np.array(ds_attributes["right_edge"])
            # Native cell sizes, used to map contour pixel indices to space.
            # NOTE(review): find_contours runs on slc.frb, whose size is
            # args["buff"] (or the square y-normal buffer); when that differs
            # from the native resolution these contours are mis-scaled.
            dx = (rx - lx) / xres
            dy = (ry - ly) / yres
            dz = (rz - lz) / zres

            # contour must be a multiple of three arguments
            if len(args["contour"]) % 3 != 0:
                sys.exit(
                    "Contour argument must be a multiple of 3! [FIELD, VALUE, COLOR]"
                )
            num_contours = len(args["contour"]) // 3

            # Physical origin and pixel size of the (horizontal, vertical)
            # image axes for each slice normal.
            # NOTE(review): the "y" entry uses (z, x) ordering although the
            # plot is swap_axes()'d above — confirm the orientation.
            contour_axes = {
                "x": ([ly, lz], [dy, dz]),
                "y": ([lz, lx], [dz, dx]),
                "z": ([lx, ly], [dx, dy]),
            }

            # Compute and plot the contours
            for icnt in range(num_contours):
                linewidth = 1.0 if args["clw"] is None else args["clw"][icnt]
                idx = icnt * 3
                cfield, clevel, ccolor = args["contour"][idx : idx + 3]
                contour = find_contours(image=slc.frb[cfield], level=clevel)
                if args["normal"] not in contour_axes:
                    sys.exit(f"""Normal {args["normal"]} is not in [x, y, z]!""")
                edge, dxy = contour_axes[args["normal"]]
                plot_contours(
                    contour=contour,
                    ax=ax,
                    left_edge=edge,
                    dxy=dxy,
                    color=ccolor,
                    linewidth=linewidth,
                )

        plt_fname = f"""{vis_field}_{args["normal"]}_{str(index).zfill(5)}"""

        # Add grid information to the slice plot
        if args["grid_info"]:
            dx0, dy0, dz0 = np.array(ds_attributes["dxyz"])
            level_data = ds.index.level_stats[0 : ds.index.max_level + 1]
            total_cells = 0
            cell_vol_percents = np.zeros(np.shape(level_data))
            for ilev, lev in enumerate(level_data):
                # Cell size per level, assuming a refinement ratio of 2
                dx = dx0 / (2**ilev)
                dy = dy0 / (2**ilev)
                dz = dz0 / (2**ilev)
                # lev[1] is presumably the level's cell count — confirm
                total_cells += lev[1]
                # Percent of the level-0 reference volume, capped at 100%
                cell_vol_percents[ilev] = np.minimum(
                    lev[1] * dx * dy * dz / domain_volume * 100, 100
                )

            # Define text with grid info for levels grid_info[2]..grid_info[3]
            level_lines = [
                f"Level {int(ilev)} vol: {cell_vol_percents[int(ilev)]:.1f}%\n"
                for ilev in np.arange(args["grid_info"][2], args["grid_info"][3] + 1)
            ]
            text_string = "".join(level_lines) + f"{total_cells * 3 / 1e6:.0f}M DOF"

            # Add text at (grid_info[0], grid_info[1])
            ax.text(
                x=args["grid_info"][0],
                y=args["grid_info"][1],
                s=text_string,
                color="white",
                bbox=dict(facecolor="black", edgecolor="white", boxstyle="round"),
            )

        # Remove the EB boundary defined by vfrac < 0.5
        if args["rm_eb"]:
            if args["pbox"]:
                extent = [
                    args["pbox"][0],
                    args["pbox"][2],
                    args["pbox"][1],
                    args["pbox"][3],
                ]
            else:
                lx, ly, lz = np.array(ds_attributes["left_edge"])
                rx, ry, rz = np.array(ds_attributes["right_edge"])
                if args["normal"] == "x":
                    extent = [ly, ry, lz, rz]
                elif args["normal"] == "y":
                    extent = [lz, rz, lx, rx]
                elif args["normal"] == "z":
                    extent = [lx, rx, ly, ry]
            # TODO: move the default rm_eb function into utils of top of script
            if "rm_eb" in udf_funcs:
                ax = udf_funcs["rm_eb"](ax)
            else:
                vfrac = slc.frb[("boxlib", eb_var_name)].to_ndarray()
                # Overlay a constant-valued image that is masked (transparent)
                # wherever vfrac > 0.5, so only the EB-covered region is drawn
                m_vfrac = np.ma.array(
                    args["rm_eb"] * np.ones(np.shape(vfrac)),
                    mask=(vfrac > 0.5),
                    fill_value=np.nan,
                )
                ax.imshow(
                    m_vfrac,
                    origin="lower",
                    extent=extent,
                    aspect="equal",
                    cmap="binary",
                    vmin=0.0,
                    vmax=1.0,
                )

        fig.tight_layout()
        fig.savefig(
            os.path.join(imgpath, f"{plt_fname}.png"),
            dpi=args["dpi"],
        )

        # Dump the figure handle as pickle for later modifications
        if args["pickle"]:
            with open(os.path.join(imgpath, f"{plt_fname}.pickle"), "wb") as f:
                pl.dump(fig, f)

        # Release the figure. yt's export_to_mpl_figure creates it through
        # pyplot, so without this every dataset in the series would leave one
        # open figure behind (memory growth and matplotlib's max-open warning).
        plt.close(fig)
# Run the slice-plot workflow only when executed as a script, not on import.
if __name__ == "__main__":
    main()