Skip to content
Snippets Groups Projects
Commit 5ceb285d authored by Loic Huder's avatar Loic Huder
Browse files

Add support of Path filepaths in `convert_graph`

parent 815c51ec
No related branches found
No related tags found
1 merge request!234Add support of Path filepaths in `convert_graph`
Pipeline #193687 passed
......@@ -25,16 +25,16 @@ GraphRepresentation = enum.Enum(
def dump(
graph: networkx.DiGraph,
destination=None,
destination: Optional[Union[str, Path]] = None,
representation: Optional[Union[GraphRepresentation, str]] = None,
**kw,
) -> Union[str, dict]:
) -> Union[str, Path, dict]:
"""From runtime to persistent representation"""
if isinstance(representation, str):
representation = GraphRepresentation.__members__[representation]
if representation is None:
if isinstance(destination, str):
filename = destination.lower()
if isinstance(destination, (str, Path)):
filename = str(destination).lower()
if filename.endswith(".json"):
representation = GraphRepresentation.json
elif filename.endswith((".yml", ".yaml")):
......@@ -46,6 +46,8 @@ def dump(
return _networkx_to_dict(graph)
if representation == GraphRepresentation.json:
if destination is None:
raise TypeError("Destination should be specified when dumping to json")
dictrepr = dump(graph)
makedirs_from_filename(destination)
kw.setdefault("indent", 2)
......@@ -58,6 +60,8 @@ def dump(
return json.dumps(dictrepr, **kw)
if representation == GraphRepresentation.yaml:
if destination is None:
raise TypeError("Destination should be specified when dumping to yaml")
dictrepr = dump(graph)
makedirs_from_filename(destination)
with open(destination, mode="w") as f:
......@@ -65,7 +69,9 @@ def dump(
return destination
if representation == GraphRepresentation.json_module:
package, _, file = destination.rpartition(".")
if destination is None:
raise TypeError("Destination should be specified when dumping to json")
package, _, file = str(destination).rpartition(".")
assert package, f"No package provided when saving graph to '{destination}'"
destination = os.path.join(_package_path(package), f"{file}.json")
return dump(graph, destination=destination, representation="json", **kw)
......
from pathlib import Path
from typing import Hashable, Optional, Union
import networkx
from ewoksutils.import_utils import qualname
......@@ -107,10 +108,10 @@ class TaskGraph:
def dump(
self,
destination=None,
destination: Optional[Union[str, Path]] = None,
representation: Optional[Union[serialize.GraphRepresentation, str]] = None,
**kw,
) -> Optional[Union[str, dict]]:
) -> Optional[Union[str, Path, dict]]:
return serialize.dump(
self.graph, destination=destination, representation=representation, **kw
)
......
from typing import Iterable, Optional, Tuple
from pathlib import Path
import pytest
from ewokscore import execute_graph
......@@ -68,13 +68,14 @@ def test_start_nodes():
@pytest.mark.parametrize(
"representation", (None, "json", "json_dict", "json_string", "yaml")
)
def test_serialize_graph(graph_name, representation, tmpdir):
@pytest.mark.parametrize("path_format", (str, Path))
def test_serialize_graph(graph_name, representation, path_format, tmpdir):
graph, _ = get_graph(graph_name)
ewoksgraph = load_graph(graph)
if representation == "yaml":
destination = str(tmpdir / "file.yml")
destination = path_format(tmpdir / "file.yml")
elif representation == "json":
destination = str(tmpdir / "file.json")
destination = path_format(tmpdir / "file.json")
else:
destination = None
inmemorydump = ewoksgraph.dump(destination, representation=representation)
......@@ -89,17 +90,18 @@ def test_serialize_graph(graph_name, representation, tmpdir):
@pytest.mark.parametrize("graph_name", graph_names())
def test_convert_graph(graph_name, tmpdir):
@pytest.mark.parametrize("path_format", (str, Path))
def test_convert_graph(graph_name, path_format, tmpdir):
graph, _ = get_graph(graph_name)
ewoksgraph = load_graph(graph)
assert_convert_graph(convert_graph, ewoksgraph, tmpdir)
assert_convert_graph(convert_graph, ewoksgraph, path_format, tmpdir)
def assert_convert_graph(
convert_graph,
ewoksgraph,
path_format,
tmpdir,
representations: Optional[Iterable[Tuple[dict, dict, Optional[str]]]] = None,
):
"""All graph `representations` need to be known by `convert_graph`. It will always
test the basic representations (e.g. json and yaml) in addition to the provided
......@@ -115,15 +117,13 @@ def assert_convert_graph(
(dict(), {"representation": "json_dict"}, None),
(dict(), {"representation": "json_string"}, None),
]
if representations:
conversion_chain.extend(representations)
conversion_chain.append(non_serialized_representation)
source = ewoksgraph
for convert_from, convert_to in zip(conversion_chain[:-1], conversion_chain[1:]):
load_options, _, _ = convert_from
_, save_options, fileext = convert_to
if fileext:
destination = str(tmpdir / f"file.{fileext}")
destination = path_format(tmpdir / f"file.{fileext}")
else:
destination = None
result = convert_graph(
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment