"""Define functions and tools used for conversion.
"""
from __future__ import annotations
import typing
#: Maps a compression method to the file extension it uses, for the methods
#: whose extension differs from the method name
compression_exts = {
    "gzip": "gz",
    "zstd": "zst",
}

#: Maps each format to the list of compressions available for it
format_compressions = {
    # NOTE: writing with `bz2` or `brotli` is very slow
    "csv": [
        "gzip",
        "lz4",
        "zstd",
        "bz2",
        "brotli",
    ],
    "feather": [
        "uncompressed",
        "lz4",
        "zstd",
    ],
    "parquet": [
        "none",
        "snappy",
        "gzip",
        "zstd",
        "lz4",
    ],
    "root": [
        "zlib",
        "lzma",
        "lz4",
        "zstd",
    ],
    "rroot": [
        "zlib",
        "lzma",
        "lz4",
        "zstd",
    ],
}

#: Maps a format to the compression name that actually stands for
#: "no compression" in that format
format_no_compressions = {
    "feather": "uncompressed",
    "parquet": "none",
}

#: Compression level to use in a ROOT file for each compression algorithm,
#: as recommended in https://root.cern/doc/master/Compression_8h_source.html
root_compression_levels = {
    "zlib": 1,
    "lzma": 8,
    "lz4": 4,
    "zstd": 5,
}

#: Supported formats, in the order they are declared in
#: :py:data:`format_compressions`
formats = list(format_compressions)
[docs]
def get_io_function(
    action: typing.Literal["r", "w", "read", "write"],
    format: str,
) -> typing.Callable:
    """Get the function to read or write in a given format.

    Args:
        action: What to return

            * ``r`` or ``read`` for the reading function
            * ``w`` or ``write`` for the writing function
        format: ``csv``, ``feather``, ``parquet``, ``root`` (handled with
            ``uproot``) or ``rroot`` (handled with PyROOT)

    Returns:
        The function to read or write the given format.

        For ``csv``, ``feather`` and ``parquet``, the ``pyarrow`` functions
        are returned (except for the CSV writer, which is a local wrapper
        that adds compression). For ``root`` and ``rroot``, local wrappers
        that exchange :py:class:`pyarrow.Table` objects are returned; they
        read from / write to a tree called ``tree``.

        For the writing function, the table to write is the first argument
        and the path is the second. The compression can be provided using
        the keyword argument ``compression`` of the function that is returned.

        For the reading function, the path is the first argument and
        no compression has to be provided.

    Raises:
        AssertionError: if ``format`` or ``action`` is not supported.
        NotImplementedError: if ``format`` is declared in ``formats`` but has
            no reader/writer implemented here.

    Notes:
        Requesting the ``rroot`` format calls ``ROOT.EnableImplicitMT()``
        as a side effect, even before the returned function is used.
    """
    import pyarrow as pa

    # Raised explicitly rather than through ``assert`` so that the validation
    # is not stripped under ``python -O``. The exception type is kept for
    # callers that already rely on it.
    if format not in formats:
        raise AssertionError(
            f"Unknown format {format!r}, expected one of {formats}"
        )
    if action not in ("r", "w", "read", "write"):
        raise AssertionError(
            f"Unknown action {action!r}, expected 'r', 'w', 'read' or 'write'"
        )

    read = action in ("r", "read")

    if format == "csv":
        import pyarrow.csv as pac

        if read:
            return pac.read_csv

        def write_csv(
            table: pa.Table,
            path: str,
            compression: typing.Optional[str] = None,
        ):
            """Write a CSV file for a given compression.

            Args:
                table: the pyarrow table to write
                path: the path where to write
                compression: compression to use, or ``None`` to write
                    a plain CSV file
            """
            if compression is None:
                pac.write_csv(table, path)
            else:
                # ``pyarrow.csv.write_csv`` is given a compressing stream
                # so that the file is compressed while it is written.
                with pa.CompressedOutputStream(path, compression) as out:
                    pac.write_csv(table, out)

        return write_csv

    if format == "feather":
        import pyarrow.feather as pf

        return pf.read_feather if read else pf.write_feather

    if format == "parquet":
        import pyarrow.parquet as pq

        return pq.read_table if read else pq.write_table

    if format == "root":
        import uproot

        if read:

            def read_root(
                path: str,
                columns: typing.Optional[str | typing.List[str]] = None,
                **kwargs,
            ) -> pa.Table:
                """Read the tree ``tree`` of a ROOT file with ``uproot``.

                Args:
                    path: path to the ROOT file
                    columns: passed as ``expressions`` to ``arrays``;
                        ``None`` is forwarded unchanged
                    **kwargs: forwarded to ``uproot.open``

                Returns:
                    A pyarrow table built from the arrays that were read.
                """
                # The file is opened in a ``with`` block so that it is closed
                # once the arrays have been loaded as NumPy arrays.
                # (``decompression_executor`` was left disabled in the
                # original code; only the interpretation is threaded.)
                with uproot.open(
                    path,
                    interpretation_executor=uproot.ThreadPoolExecutor(),
                    **kwargs,
                ) as rfile:
                    arrays = rfile["tree"].arrays(
                        library="np", expressions=columns
                    )
                return pa.Table.from_arrays(
                    list(arrays.values()), names=list(arrays.keys())
                )

            return read_root

        def write_root(
            table: pa.Table,
            path: str,
            compression: typing.Optional[str] = None,
            **kwargs,
        ):
            """Write a table in the tree ``tree`` of a ROOT file with ``uproot``.

            Args:
                table: the pyarrow table to write
                path: the path where to write (the file is recreated)
                compression: name of the compression algorithm, or ``None``
                    to use ``ZLIB`` with level 0
                **kwargs: forwarded to ``uproot.recreate``
            """
            if compression is None:
                codec = uproot.compression.ZLIB(0)
            else:
                # The algorithm class is looked up from its upper-cased name;
                # level 9 is used when no level is configured for it.
                codec_class = getattr(uproot.compression, compression.upper())
                codec = codec_class(root_compression_levels.get(compression, 9))

            with uproot.recreate(path, compression=codec, **kwargs) as rfile:
                rfile["tree"] = dict(zip(table.column_names, table.columns))

        return write_root

    if format == "rroot":
        import ROOT

        ROOT.EnableImplicitMT()

        if read:

            def read_root(
                path: str,
                columns: typing.Optional[str | typing.List[str]] = None,
                **kwargs,
            ) -> pa.Table:
                """Read the tree ``tree`` of a ROOT file with ``RDataFrame``.

                Args:
                    path: path to the ROOT file
                    columns: passed as ``columns`` to ``AsNumpy``
                    **kwargs: forwarded to ``ROOT.RDataFrame``

                Returns:
                    A pyarrow table built from the arrays that were read.
                """
                rdataframe = ROOT.RDataFrame("tree", path, **kwargs)
                arrays = rdataframe.AsNumpy(columns=columns)
                return pa.Table.from_arrays(
                    list(arrays.values()), names=list(arrays.keys())
                )

            return read_root

        def write_root(
            table: pa.Table,
            path: str,
            compression: typing.Optional[str] = None,
        ):
            """Write a table in the tree ``tree`` of a ROOT file with PyROOT.

            Args:
                table: the pyarrow table to write
                path: the path where to write
                compression: name of the compression algorithm, or ``None``
                    to set the compression level to 0
            """
            arrays = {
                column_name: column.to_numpy()
                for column_name, column in zip(table.column_names, table.columns)
            }
            rdataframe = ROOT.RDF.MakeNumpyDataFrame(arrays)

            rsnapshotoptions = ROOT.RDF.RSnapshotOptions()
            if compression is None:
                rsnapshotoptions.fCompressionLevel = 0
            else:
                rsnapshotoptions.fCompressionLevel = root_compression_levels.get(
                    compression, 9
                )
                # The algorithm constant is looked up by name,
                # e.g. ``zstd`` gives ``ROOT.kZSTD``.
                rsnapshotoptions.fCompressionAlgorithm = getattr(
                    ROOT, "k" + compression.upper()
                )

            rdataframe.Snapshot("tree", path, "", rsnapshotoptions)

        return write_root

    # Only reachable if a format is added to ``formats`` without a matching
    # branch above; fail loudly instead of silently returning ``None``.
    raise NotImplementedError(
        f"No I/O function is implemented for format {format!r}"
    )
[docs]
def get_extension(format: str, compression: typing.Optional[str] = None) -> str:
    """Get the file extension from format and compression.

    Args:
        format: one of the formats listed in :py:data:`formats`
        compression: compression algorithm to use. See :py:data:`format_compressions`
            for the possible compression given a format.

    Returns:
        Extension of the file given its format and compression, starting with
        a dot. It is ``.<format>`` when there is no compression and
        ``.<format>.<compression extension>`` otherwise, where the compression
        extension is looked up in :py:data:`compression_exts` and defaults to
        the name of the compression itself.
    """
    assert format in formats
    assert (compression is None) or (compression in format_compressions[format])

    format_ext = "." + format

    # No compression requested at all
    if compression is None:
        return format_ext

    # The compression name is the one that means "no compression"
    # for this format
    if (
        format in format_no_compressions
        and format_no_compressions[format] == compression
    ):
        return format_ext

    compression_ext = compression_exts.get(compression, compression)
    return format_ext + "." + compression_ext