diff --git a/0mq/comm_node.dir/concore2.py b/0mq/comm_node.dir/concore2.py index a018ddf6..f3baf181 100644 --- a/0mq/comm_node.dir/concore2.py +++ b/0mq/comm_node.dir/concore2.py @@ -130,6 +130,16 @@ def unchanged(): olds = s return False +def _resolve_port_file_path(base_path, port_num, name): + port_dir = os.path.abspath(base_path + str(port_num)) + file_path = os.path.abspath(os.path.join(port_dir, name)) + try: + if os.path.commonpath([port_dir, file_path]) != port_dir: + raise ValueError + except ValueError: + raise ValueError(f"Invalid file name '{name}' for port {port_num}") + return file_path + def read(port_identifier, name, initstr_val): global s, simtime, retrycount @@ -154,12 +164,12 @@ def read(port_identifier, name, initstr_val): try: file_port_num = int(port_identifier) + file_path = _resolve_port_file_path(inpath, file_port_num, name) except ValueError: - print(f"Error: Invalid port identifier '{port_identifier}' for file operation. Must be integer or ZMQ name.") + print(f"Error: Invalid port identifier '{port_identifier}' or file name '{name}' for file operation.") return default_return_val time.sleep(delay) - file_path = os.path.join(inpath+str(file_port_num), name) ins = "" try: @@ -220,9 +230,9 @@ def write(port_identifier, name, val, delta=0): file_path = os.path.join("../"+port_identifier, name) else: file_port_num = int(port_identifier) - file_path = os.path.join(outpath+str(file_port_num), name) + file_path = _resolve_port_file_path(outpath, file_port_num, name) except ValueError: - print(f"Error: Invalid port identifier '{port_identifier}' for file operation. Must be integer or ZMQ name.") + print(f"Error: Invalid port identifier '{port_identifier}' or file name '{name}' for file operation.") return if isinstance(val, str): diff --git a/0mq/funbody.dir/concore2.py b/0mq/funbody.dir/concore2.py index a018ddf6..f3baf181 100644 --- a/0mq/funbody.dir/concore2.py +++ b/0mq/funbody.dir/concore2.py @@ -130,6 +130,16 @@ def unchanged(): olds = s return False +def _resolve_port_file_path(base_path, port_num, name): + port_dir = os.path.abspath(base_path + str(port_num)) + file_path = os.path.abspath(os.path.join(port_dir, name)) + try: + if os.path.commonpath([port_dir, file_path]) != port_dir: + raise ValueError + except ValueError: + raise ValueError(f"Invalid file name '{name}' for port {port_num}") + return file_path + def read(port_identifier, name, initstr_val): global s, simtime, retrycount @@ -154,12 +164,12 @@ def read(port_identifier, name, initstr_val): try: file_port_num = int(port_identifier) + file_path = _resolve_port_file_path(inpath, file_port_num, name) except ValueError: - print(f"Error: Invalid port identifier '{port_identifier}' for file operation. Must be integer or ZMQ name.") + print(f"Error: Invalid port identifier '{port_identifier}' or file name '{name}' for file operation.") return default_return_val time.sleep(delay) - file_path = os.path.join(inpath+str(file_port_num), name) ins = "" try: @@ -220,9 +230,9 @@ def write(port_identifier, name, val, delta=0): file_path = os.path.join("../"+port_identifier, name) else: file_port_num = int(port_identifier) - file_path = os.path.join(outpath+str(file_port_num), name) + file_path = _resolve_port_file_path(outpath, file_port_num, name) except ValueError: - print(f"Error: Invalid port identifier '{port_identifier}' for file operation. Must be integer or ZMQ name.") + print(f"Error: Invalid port identifier '{port_identifier}' or file name '{name}' for file operation.") return if isinstance(val, str): diff --git a/0mq/funbody2.dir/concore2.py b/0mq/funbody2.dir/concore2.py index a018ddf6..f3baf181 100644 --- a/0mq/funbody2.dir/concore2.py +++ b/0mq/funbody2.dir/concore2.py @@ -130,6 +130,16 @@ def unchanged(): olds = s return False +def _resolve_port_file_path(base_path, port_num, name): + port_dir = os.path.abspath(base_path + str(port_num)) + file_path = os.path.abspath(os.path.join(port_dir, name)) + try: + if os.path.commonpath([port_dir, file_path]) != port_dir: + raise ValueError + except ValueError: + raise ValueError(f"Invalid file name '{name}' for port {port_num}") + return file_path + def read(port_identifier, name, initstr_val): global s, simtime, retrycount @@ -154,12 +164,12 @@ def read(port_identifier, name, initstr_val): try: file_port_num = int(port_identifier) + file_path = _resolve_port_file_path(inpath, file_port_num, name) except ValueError: - print(f"Error: Invalid port identifier '{port_identifier}' for file operation. Must be integer or ZMQ name.") + print(f"Error: Invalid port identifier '{port_identifier}' or file name '{name}' for file operation.") return default_return_val time.sleep(delay) - file_path = os.path.join(inpath+str(file_port_num), name) ins = "" try: @@ -220,9 +230,9 @@ def write(port_identifier, name, val, delta=0): file_path = os.path.join("../"+port_identifier, name) else: file_port_num = int(port_identifier) - file_path = os.path.join(outpath+str(file_port_num), name) + file_path = _resolve_port_file_path(outpath, file_port_num, name) except ValueError: - print(f"Error: Invalid port identifier '{port_identifier}' for file operation. Must be integer or ZMQ name.") + print(f"Error: Invalid port identifier '{port_identifier}' or file name '{name}' for file operation.") return if isinstance(val, str): diff --git a/0mq/funbody_distributed.dir/concore2.py b/0mq/funbody_distributed.dir/concore2.py index a018ddf6..f3baf181 100644 --- a/0mq/funbody_distributed.dir/concore2.py +++ b/0mq/funbody_distributed.dir/concore2.py @@ -130,6 +130,16 @@ def unchanged(): olds = s return False +def _resolve_port_file_path(base_path, port_num, name): + port_dir = os.path.abspath(base_path + str(port_num)) + file_path = os.path.abspath(os.path.join(port_dir, name)) + try: + if os.path.commonpath([port_dir, file_path]) != port_dir: + raise ValueError + except ValueError: + raise ValueError(f"Invalid file name '{name}' for port {port_num}") + return file_path + def read(port_identifier, name, initstr_val): global s, simtime, retrycount @@ -154,12 +164,12 @@ def read(port_identifier, name, initstr_val): try: file_port_num = int(port_identifier) + file_path = _resolve_port_file_path(inpath, file_port_num, name) except ValueError: - print(f"Error: Invalid port identifier '{port_identifier}' for file operation. Must be integer or ZMQ name.") + print(f"Error: Invalid port identifier '{port_identifier}' or file name '{name}' for file operation.") return default_return_val time.sleep(delay) - file_path = os.path.join(inpath+str(file_port_num), name) ins = "" try: @@ -220,9 +230,9 @@ def write(port_identifier, name, val, delta=0): file_path = os.path.join("../"+port_identifier, name) else: file_port_num = int(port_identifier) - file_path = os.path.join(outpath+str(file_port_num), name) + file_path = _resolve_port_file_path(outpath, file_port_num, name) except ValueError: - print(f"Error: Invalid port identifier '{port_identifier}' for file operation. Must be integer or ZMQ name.") + print(f"Error: Invalid port identifier '{port_identifier}' or file name '{name}' for file operation.") return if isinstance(val, str): diff --git a/0mq/funbody_zmq.dir/concore2.py b/0mq/funbody_zmq.dir/concore2.py index a018ddf6..f3baf181 100644 --- a/0mq/funbody_zmq.dir/concore2.py +++ b/0mq/funbody_zmq.dir/concore2.py @@ -130,6 +130,16 @@ def unchanged(): olds = s return False +def _resolve_port_file_path(base_path, port_num, name): + port_dir = os.path.abspath(base_path + str(port_num)) + file_path = os.path.abspath(os.path.join(port_dir, name)) + try: + if os.path.commonpath([port_dir, file_path]) != port_dir: + raise ValueError + except ValueError: + raise ValueError(f"Invalid file name '{name}' for port {port_num}") + return file_path + def read(port_identifier, name, initstr_val): global s, simtime, retrycount @@ -154,12 +164,12 @@ def read(port_identifier, name, initstr_val): try: file_port_num = int(port_identifier) + file_path = _resolve_port_file_path(inpath, file_port_num, name) except ValueError: - print(f"Error: Invalid port identifier '{port_identifier}' for file operation. Must be integer or ZMQ name.") + print(f"Error: Invalid port identifier '{port_identifier}' or file name '{name}' for file operation.") return default_return_val time.sleep(delay) - file_path = os.path.join(inpath+str(file_port_num), name) ins = "" try: @@ -220,9 +230,9 @@ def write(port_identifier, name, val, delta=0): file_path = os.path.join("../"+port_identifier, name) else: file_port_num = int(port_identifier) - file_path = os.path.join(outpath+str(file_port_num), name) + file_path = _resolve_port_file_path(outpath, file_port_num, name) except ValueError: - print(f"Error: Invalid port identifier '{port_identifier}' for file operation. Must be integer or ZMQ name.") + print(f"Error: Invalid port identifier '{port_identifier}' or file name '{name}' for file operation.") return if isinstance(val, str): diff --git a/0mq/funbody_zmq2.dir/concore2.py b/0mq/funbody_zmq2.dir/concore2.py index a018ddf6..f3baf181 100644 --- a/0mq/funbody_zmq2.dir/concore2.py +++ b/0mq/funbody_zmq2.dir/concore2.py @@ -130,6 +130,16 @@ def unchanged(): olds = s return False +def _resolve_port_file_path(base_path, port_num, name): + port_dir = os.path.abspath(base_path + str(port_num)) + file_path = os.path.abspath(os.path.join(port_dir, name)) + try: + if os.path.commonpath([port_dir, file_path]) != port_dir: + raise ValueError + except ValueError: + raise ValueError(f"Invalid file name '{name}' for port {port_num}") + return file_path + def read(port_identifier, name, initstr_val): global s, simtime, retrycount @@ -154,12 +164,12 @@ def read(port_identifier, name, initstr_val): try: file_port_num = int(port_identifier) + file_path = _resolve_port_file_path(inpath, file_port_num, name) except ValueError: - print(f"Error: Invalid port identifier '{port_identifier}' for file operation. Must be integer or ZMQ name.") + print(f"Error: Invalid port identifier '{port_identifier}' or file name '{name}' for file operation.") return default_return_val time.sleep(delay) - file_path = os.path.join(inpath+str(file_port_num), name) ins = "" try: @@ -220,9 +230,9 @@ def write(port_identifier, name, val, delta=0): file_path = os.path.join("../"+port_identifier, name) else: file_port_num = int(port_identifier) - file_path = os.path.join(outpath+str(file_port_num), name) + file_path = _resolve_port_file_path(outpath, file_port_num, name) except ValueError: - print(f"Error: Invalid port identifier '{port_identifier}' for file operation. Must be integer or ZMQ name.") + print(f"Error: Invalid port identifier '{port_identifier}' or file name '{name}' for file operation.") return if isinstance(val, str): diff --git a/0mq/funcall.dir/concore2.py b/0mq/funcall.dir/concore2.py index a018ddf6..f3baf181 100644 --- a/0mq/funcall.dir/concore2.py +++ b/0mq/funcall.dir/concore2.py @@ -130,6 +130,16 @@ def unchanged(): olds = s return False +def _resolve_port_file_path(base_path, port_num, name): + port_dir = os.path.abspath(base_path + str(port_num)) + file_path = os.path.abspath(os.path.join(port_dir, name)) + try: + if os.path.commonpath([port_dir, file_path]) != port_dir: + raise ValueError + except ValueError: + raise ValueError(f"Invalid file name '{name}' for port {port_num}") + return file_path + def read(port_identifier, name, initstr_val): global s, simtime, retrycount @@ -154,12 +164,12 @@ def read(port_identifier, name, initstr_val): try: file_port_num = int(port_identifier) + file_path = _resolve_port_file_path(inpath, file_port_num, name) except ValueError: - print(f"Error: Invalid port identifier '{port_identifier}' for file operation. Must be integer or ZMQ name.") + print(f"Error: Invalid port identifier '{port_identifier}' or file name '{name}' for file operation.") return default_return_val time.sleep(delay) - file_path = os.path.join(inpath+str(file_port_num), name) ins = "" try: @@ -220,9 +230,9 @@ def write(port_identifier, name, val, delta=0): file_path = os.path.join("../"+port_identifier, name) else: file_port_num = int(port_identifier) - file_path = os.path.join(outpath+str(file_port_num), name) + file_path = _resolve_port_file_path(outpath, file_port_num, name) except ValueError: - print(f"Error: Invalid port identifier '{port_identifier}' for file operation. Must be integer or ZMQ name.") + print(f"Error: Invalid port identifier '{port_identifier}' or file name '{name}' for file operation.") return if isinstance(val, str): diff --git a/0mq/funcall2.dir/concore2.py b/0mq/funcall2.dir/concore2.py index a018ddf6..f3baf181 100644 --- a/0mq/funcall2.dir/concore2.py +++ b/0mq/funcall2.dir/concore2.py @@ -130,6 +130,16 @@ def unchanged(): olds = s return False +def _resolve_port_file_path(base_path, port_num, name): + port_dir = os.path.abspath(base_path + str(port_num)) + file_path = os.path.abspath(os.path.join(port_dir, name)) + try: + if os.path.commonpath([port_dir, file_path]) != port_dir: + raise ValueError + except ValueError: + raise ValueError(f"Invalid file name '{name}' for port {port_num}") + return file_path + def read(port_identifier, name, initstr_val): global s, simtime, retrycount @@ -154,12 +164,12 @@ def read(port_identifier, name, initstr_val): try: file_port_num = int(port_identifier) + file_path = _resolve_port_file_path(inpath, file_port_num, name) except ValueError: - print(f"Error: Invalid port identifier '{port_identifier}' for file operation. Must be integer or ZMQ name.") + print(f"Error: Invalid port identifier '{port_identifier}' or file name '{name}' for file operation.") return default_return_val time.sleep(delay) - file_path = os.path.join(inpath+str(file_port_num), name) ins = "" try: @@ -220,9 +230,9 @@ def write(port_identifier, name, val, delta=0): file_path = os.path.join("../"+port_identifier, name) else: file_port_num = int(port_identifier) - file_path = os.path.join(outpath+str(file_port_num), name) + file_path = _resolve_port_file_path(outpath, file_port_num, name) except ValueError: - print(f"Error: Invalid port identifier '{port_identifier}' for file operation. Must be integer or ZMQ name.") + print(f"Error: Invalid port identifier '{port_identifier}' or file name '{name}' for file operation.") return if isinstance(val, str): diff --git a/0mq/funcall_distributed.dir/concore2.py b/0mq/funcall_distributed.dir/concore2.py index a018ddf6..f3baf181 100644 --- a/0mq/funcall_distributed.dir/concore2.py +++ b/0mq/funcall_distributed.dir/concore2.py @@ -130,6 +130,16 @@ def unchanged(): olds = s return False +def _resolve_port_file_path(base_path, port_num, name): + port_dir = os.path.abspath(base_path + str(port_num)) + file_path = os.path.abspath(os.path.join(port_dir, name)) + try: + if os.path.commonpath([port_dir, file_path]) != port_dir: + raise ValueError + except ValueError: + raise ValueError(f"Invalid file name '{name}' for port {port_num}") + return file_path + def read(port_identifier, name, initstr_val): global s, simtime, retrycount @@ -154,12 +164,12 @@ def read(port_identifier, name, initstr_val): try: file_port_num = int(port_identifier) + file_path = _resolve_port_file_path(inpath, file_port_num, name) except ValueError: - print(f"Error: Invalid port identifier '{port_identifier}' for file operation. Must be integer or ZMQ name.") + print(f"Error: Invalid port identifier '{port_identifier}' or file name '{name}' for file operation.") return default_return_val time.sleep(delay) - file_path = os.path.join(inpath+str(file_port_num), name) ins = "" try: @@ -220,9 +230,9 @@ def write(port_identifier, name, val, delta=0): file_path = os.path.join("../"+port_identifier, name) else: file_port_num = int(port_identifier) - file_path = os.path.join(outpath+str(file_port_num), name) + file_path = _resolve_port_file_path(outpath, file_port_num, name) except ValueError: - print(f"Error: Invalid port identifier '{port_identifier}' for file operation. Must be integer or ZMQ name.") + print(f"Error: Invalid port identifier '{port_identifier}' or file name '{name}' for file operation.") return if isinstance(val, str): diff --git a/0mq/funcall_zmq.dir/concore2.py b/0mq/funcall_zmq.dir/concore2.py index a018ddf6..f3baf181 100644 --- a/0mq/funcall_zmq.dir/concore2.py +++ b/0mq/funcall_zmq.dir/concore2.py @@ -130,6 +130,16 @@ def unchanged(): olds = s return False +def _resolve_port_file_path(base_path, port_num, name): + port_dir = os.path.abspath(base_path + str(port_num)) + file_path = os.path.abspath(os.path.join(port_dir, name)) + try: + if os.path.commonpath([port_dir, file_path]) != port_dir: + raise ValueError + except ValueError: + raise ValueError(f"Invalid file name '{name}' for port {port_num}") + return file_path + def read(port_identifier, name, initstr_val): global s, simtime, retrycount @@ -154,12 +164,12 @@ def read(port_identifier, name, initstr_val): try: file_port_num = int(port_identifier) + file_path = _resolve_port_file_path(inpath, file_port_num, name) except ValueError: - print(f"Error: Invalid port identifier '{port_identifier}' for file operation. Must be integer or ZMQ name.") + print(f"Error: Invalid port identifier '{port_identifier}' or file name '{name}' for file operation.") return default_return_val time.sleep(delay) - file_path = os.path.join(inpath+str(file_port_num), name) ins = "" try: @@ -220,9 +230,9 @@ def write(port_identifier, name, val, delta=0): file_path = os.path.join("../"+port_identifier, name) else: file_port_num = int(port_identifier) - file_path = os.path.join(outpath+str(file_port_num), name) + file_path = _resolve_port_file_path(outpath, file_port_num, name) except ValueError: - print(f"Error: Invalid port identifier '{port_identifier}' for file operation. Must be integer or ZMQ name.") + print(f"Error: Invalid port identifier '{port_identifier}' or file name '{name}' for file operation.") return if isinstance(val, str): diff --git a/0mq/funcall_zmq2.dir/concore2.py b/0mq/funcall_zmq2.dir/concore2.py index a018ddf6..f3baf181 100644 --- a/0mq/funcall_zmq2.dir/concore2.py +++ b/0mq/funcall_zmq2.dir/concore2.py @@ -130,6 +130,16 @@ def unchanged(): olds = s return False +def _resolve_port_file_path(base_path, port_num, name): + port_dir = os.path.abspath(base_path + str(port_num)) + file_path = os.path.abspath(os.path.join(port_dir, name)) + try: + if os.path.commonpath([port_dir, file_path]) != port_dir: + raise ValueError + except ValueError: + raise ValueError(f"Invalid file name '{name}' for port {port_num}") + return file_path + def read(port_identifier, name, initstr_val): global s, simtime, retrycount @@ -154,12 +164,12 @@ def read(port_identifier, name, initstr_val): try: file_port_num = int(port_identifier) + file_path = _resolve_port_file_path(inpath, file_port_num, name) except ValueError: - print(f"Error: Invalid port identifier '{port_identifier}' for file operation. Must be integer or ZMQ name.") + print(f"Error: Invalid port identifier '{port_identifier}' or file name '{name}' for file operation.") return default_return_val time.sleep(delay) - file_path = os.path.join(inpath+str(file_port_num), name) ins = "" try: @@ -220,9 +230,9 @@ def write(port_identifier, name, val, delta=0): file_path = os.path.join("../"+port_identifier, name) else: file_port_num = int(port_identifier) - file_path = os.path.join(outpath+str(file_port_num), name) + file_path = _resolve_port_file_path(outpath, file_port_num, name) except ValueError: - print(f"Error: Invalid port identifier '{port_identifier}' for file operation. Must be integer or ZMQ name.") + print(f"Error: Invalid port identifier '{port_identifier}' or file name '{name}' for file operation.") return if isinstance(val, str): diff --git a/TestConcoredockerApi.java b/TestConcoredockerApi.java index 90472430..d62c8de4 100644 --- a/TestConcoredockerApi.java +++ b/TestConcoredockerApi.java @@ -27,6 +27,8 @@ public static void main(String[] args) { testReadFileNotFound(); testReadRetriesExceeded(); testReadParseError(); + testReadTraversalBlocked(); + testWriteTraversalBlocked(); System.out.println("\n=== Results: " + passed + " passed, " + failed + " failed out of " + (passed + failed) + " tests ==="); if (failed > 0) { @@ -230,4 +232,23 @@ static void testReadParseError() { check("read parse error: status", concoredocker.ReadStatus.PARSE_ERROR, result.status); check("read parse error: data is default", 1, result.data.size()); } + + static void testReadTraversalBlocked() { + Path tmp = makeTempDir(); + concoredocker.resetState(); + concoredocker.setInPath(tmp.toString()); + + concoredocker.ReadResult result = concoredocker.read(1, "../escape", "[0.0, 7.0]"); + check("read traversal blocked: status", concoredocker.ReadStatus.PARSE_ERROR, result.status); + check("read traversal blocked: returns default", 1, result.data.size()); + } + + static void testWriteTraversalBlocked() { + Path tmp = makeTempDir(1); + concoredocker.resetState(); + concoredocker.setOutPath(tmp.toString()); + + concoredocker.write(1, "../escape", Collections.singletonList((Object) 1.0), 0); + check("write traversal blocked: no escaped file", false, Files.exists(tmp.resolve("escape"))); + } } diff --git a/concore.java b/concore.java index 3cb7d021..52b96f41 100644 --- a/concore.java +++ b/concore.java @@ -226,6 +226,15 @@ private static String portPath(String base, int portNum) { return base + portNum; } + private static Path resolvePortFilePath(String base, int portNum, String name) throws IOException { + Path portDir = Paths.get(portPath(base, portNum)).toAbsolutePath().normalize(); + Path filePath = portDir.resolve(name).normalize(); + if (!filePath.startsWith(portDir)) { + throw new IOException("Invalid file name '" + name + "' for port " + portNum); + } + return filePath; + } + // package-level helpers for testing with temp directories static void setInPath(String path) { inpath = path; } static void setOutPath(String path) { outpath = path; } @@ -268,7 +277,14 @@ public static ReadResult read(int port, String name, String initstr) { // initstr not parseable as list; defaultVal stays empty } - String filePath = Paths.get(portPath(inpath, port), name).toString(); + Path filePathObj; + try { + filePathObj = resolvePortFilePath(inpath, port, name); + } catch (IOException | RuntimeException e) { + System.out.println("Invalid path for port " + port + " and name '" + name + "': " + e.getMessage()); + return new ReadResult(ReadStatus.PARSE_ERROR, defaultVal); + } + String filePath = filePathObj.toString(); try { Thread.sleep(delay); } catch (InterruptedException e) { @@ -279,7 +295,7 @@ public static ReadResult read(int port, String name, String initstr) { String ins; try { - ins = new String(Files.readAllBytes(Paths.get(filePath))); + ins = new String(Files.readAllBytes(filePathObj)); } catch (IOException e) { System.out.println("File " + filePath + " not found, using default value."); s += initstr; @@ -296,7 +312,7 @@ public static ReadResult read(int port, String name, String initstr) { return new ReadResult(ReadStatus.TIMEOUT, defaultVal); } try { - ins = new String(Files.readAllBytes(Paths.get(filePath))); + ins = new String(Files.readAllBytes(filePathObj)); } catch (IOException e) { System.out.println("Retry " + (attempts + 1) + ": Error reading " + filePath); } @@ -438,7 +454,8 @@ private static String toJsonLiteral(Object obj) { */ public static void write(int port, String name, Object val, int delta) { try { - String path = Paths.get(portPath(outpath, port), name).toString(); + Path pathObj = resolvePortFilePath(outpath, port, name); + String path = pathObj.toString(); StringBuilder content = new StringBuilder(); if (val instanceof String) { Thread.sleep(2 * delay); @@ -470,7 +487,7 @@ public static void write(int port, String name, Object val, int delta) { System.out.println("write must have list or str"); return; } - Files.write(Paths.get(path), content.toString().getBytes()); + Files.write(pathObj, content.toString().getBytes()); } catch (InterruptedException e) { Thread.currentThread().interrupt(); System.out.println("skipping " + outpath + "/" + port + "/" + name); diff --git a/concore_base.py b/concore_base.py index 9173289b..5bac48bd 100644 --- a/concore_base.py +++ b/concore_base.py @@ -194,6 +194,15 @@ def parse_params(sparams): params[key] = value return params + +def _resolve_port_file_path(mod, base_path, port_num, name): + """Resolve a file path under a port directory and block path traversal.""" + port_dir = os.path.abspath(mod._port_path(base_path, port_num)) + file_path = os.path.abspath(os.path.join(port_dir, name)) + if os.path.commonpath([port_dir, file_path]) != port_dir: + raise ValueError(f"Invalid file name '{name}' for port {port_num}") + return file_path + def load_params(params_file): try: if os.path.exists(params_file): @@ -293,14 +302,16 @@ def read(mod, port_identifier, name, initstr_val): # Case 2: File-based port try: file_port_num = int(port_identifier) + file_path = _resolve_port_file_path(mod, mod.inpath, file_port_num, name) except ValueError: - logger.error(f"Error: Invalid port identifier '{port_identifier}' for file operation. Must be integer or ZMQ name.") + logger.error( + f"Error: Invalid port identifier '{port_identifier}' or file name '{name}' " + f"for file operation. Must stay inside the port directory." + ) last_read_status = "PARSE_ERROR" return default_return_val, False time.sleep(mod.delay) - port_dir = mod._port_path(mod.inpath, file_port_num) - file_path = os.path.join(port_dir, name) ins = "" file_not_found = False @@ -394,10 +405,12 @@ def write(mod, port_identifier, name, val, delta=0): # Case 2: File-based port try: file_port_num = int(port_identifier) - port_dir = mod._port_path(mod.outpath, file_port_num) - file_path = os.path.join(port_dir, name) + file_path = _resolve_port_file_path(mod, mod.outpath, file_port_num, name) except ValueError: - logger.error(f"Error: Invalid port identifier '{port_identifier}' for file operation. Must be integer or ZMQ name.") + logger.error( + f"Error: Invalid port identifier '{port_identifier}' or file name '{name}' " + f"for file operation. Must stay inside the port directory." + ) return # File writing rules diff --git a/concoredocker.java b/concoredocker.java index d6c7c21f..65c61701 100644 --- a/concoredocker.java +++ b/concoredocker.java @@ -1,5 +1,6 @@ import java.io.IOException; import java.nio.file.Files; +import java.nio.file.Path; import java.nio.file.Paths; import java.util.HashMap; import java.util.Map; @@ -137,6 +138,15 @@ public static void defaultMaxTime(double defaultValue) { static double getSimtime() { return simtime; } static void resetState() { s = ""; olds = ""; simtime = 0; } + private static Path resolvePortFilePath(String base, int portNum, String name) throws IOException { + Path portDir = Paths.get(base, String.valueOf(portNum)).toAbsolutePath().normalize(); + Path filePath = portDir.resolve(name).normalize(); + if (!filePath.startsWith(portDir)) { + throw new IOException("Invalid file name '" + name + "' for port " + portNum); + } + return filePath; + } + public static boolean unchanged() { if (olds.equals(s)) { s = ""; @@ -172,7 +182,14 @@ public static ReadResult read(int port, String name, String initstr) { // initstr not parseable as list; defaultVal stays empty } - String filePath = inpath + "/" + port + "/" + name; + Path filePathObj; + try { + filePathObj = resolvePortFilePath(inpath, port, name); + } catch (IOException | RuntimeException e) { + System.out.println("Invalid path for port " + port + " and name '" + name + "': " + e.getMessage()); + return new ReadResult(ReadStatus.PARSE_ERROR, defaultVal); + } + String filePath = filePathObj.toString(); try { Thread.sleep(delay); } catch (InterruptedException e) { @@ -183,7 +200,7 @@ public static ReadResult read(int port, String name, String initstr) { String ins; try { - ins = new String(Files.readAllBytes(Paths.get(filePath))); + ins = new String(Files.readAllBytes(filePathObj)); } catch (IOException e) { System.out.println("File " + filePath + " not found, using default value."); s += initstr; @@ -200,7 +217,7 @@ public static ReadResult read(int port, String name, String initstr) { return new ReadResult(ReadStatus.TIMEOUT, defaultVal); } try { - ins = new String(Files.readAllBytes(Paths.get(filePath))); + ins = new String(Files.readAllBytes(filePathObj)); } catch (IOException e) { System.out.println("Retry " + (attempts + 1) + ": Error reading " + filePath); } @@ -342,7 +359,8 @@ private static String toJsonLiteral(Object obj) { */ public static void write(int port, String name, Object val, int delta) { try { - String path = outpath + "/" + port + "/" + name; + Path pathObj = resolvePortFilePath(outpath, port, name); + String path = pathObj.toString(); StringBuilder content = new StringBuilder(); if (val instanceof String) { Thread.sleep(2 * delay); @@ -374,7 +392,7 @@ public static void write(int port, String name, Object val, int delta) { System.out.println("write must have list or str"); return; } - Files.write(Paths.get(path), content.toString().getBytes()); + Files.write(pathObj, content.toString().getBytes()); } catch (InterruptedException e) { Thread.currentThread().interrupt(); System.out.println("skipping " + outpath + "/" + port + "/" + name); diff --git a/measurements/comm_node_test.dir/concore2.py b/measurements/comm_node_test.dir/concore2.py index a018ddf6..f3baf181 100644 --- a/measurements/comm_node_test.dir/concore2.py +++ b/measurements/comm_node_test.dir/concore2.py @@ -130,6 +130,16 @@ def unchanged(): olds = s return False +def _resolve_port_file_path(base_path, port_num, name): + port_dir = os.path.abspath(base_path + str(port_num)) + file_path = os.path.abspath(os.path.join(port_dir, name)) + try: + if os.path.commonpath([port_dir, file_path]) != port_dir: + raise ValueError + except ValueError: + raise ValueError(f"Invalid file name '{name}' for port {port_num}") + return file_path + def read(port_identifier, name, initstr_val): global s, simtime, retrycount @@ -154,12 +164,12 @@ def read(port_identifier, name, initstr_val): try: file_port_num = int(port_identifier) + file_path = _resolve_port_file_path(inpath, file_port_num, name) except ValueError: - print(f"Error: Invalid port identifier '{port_identifier}' for file operation. Must be integer or ZMQ name.") + print(f"Error: Invalid port identifier '{port_identifier}' or file name '{name}' for file operation.") return default_return_val time.sleep(delay) - file_path = os.path.join(inpath+str(file_port_num), name) ins = "" try: @@ -220,9 +230,9 @@ def write(port_identifier, name, val, delta=0): file_path = os.path.join("../"+port_identifier, name) else: file_port_num = int(port_identifier) - file_path = os.path.join(outpath+str(file_port_num), name) + file_path = _resolve_port_file_path(outpath, file_port_num, name) except ValueError: - print(f"Error: Invalid port identifier '{port_identifier}' for file operation. Must be integer or ZMQ name.") + print(f"Error: Invalid port identifier '{port_identifier}' or file name '{name}' for file operation.") return if isinstance(val, str): diff --git a/measurements/throughput_test/concore.py b/measurements/throughput_test/concore.py index ce3ab0a9..3c3bf89f 100644 --- a/measurements/throughput_test/concore.py +++ b/measurements/throughput_test/concore.py @@ -156,6 +156,16 @@ def unchanged(): olds = s return False +def _resolve_port_file_path(base_path, port_num, name): + port_dir = os.path.abspath(base_path + str(port_num)) + file_path = os.path.abspath(os.path.join(port_dir, name)) + try: + if os.path.commonpath([port_dir, file_path]) != port_dir: + raise ValueError + except ValueError: + raise ValueError(f"Invalid file name '{name}' for port {port_num}") + return file_path + def read(port_identifier, name, initstr_val): global s, simtime, retrycount @@ -180,12 +190,12 @@ def read(port_identifier, name, initstr_val): try: file_port_num = int(port_identifier) + file_path = _resolve_port_file_path(inpath, file_port_num, name) except ValueError: - print(f"Error: Invalid port identifier '{port_identifier}' for file operation. Must be integer or ZMQ name.") + print(f"Error: Invalid port identifier '{port_identifier}' or file name '{name}' for file operation.") return default_return_val time.sleep(delay) - file_path = os.path.join(inpath+str(file_port_num), name) ins = "" try: @@ -246,9 +256,9 @@ def write(port_identifier, name, val, delta=0): file_path = os.path.join("../"+port_identifier, name) else: file_port_num = int(port_identifier) - file_path = os.path.join(outpath+str(file_port_num), name) + file_path = _resolve_port_file_path(outpath, file_port_num, name) except ValueError: - print(f"Error: Invalid port identifier '{port_identifier}' for file operation. Must be integer or ZMQ name.") + print(f"Error: Invalid port identifier '{port_identifier}' or file name '{name}' for file operation.") return if isinstance(val, str): diff --git a/nintan/powermetermax.dir/concore2.py b/nintan/powermetermax.dir/concore2.py index 6471e0b7..9834ede5 100644 --- a/nintan/powermetermax.dir/concore2.py +++ b/nintan/powermetermax.dir/concore2.py @@ -64,11 +64,22 @@ def unchanged(): olds = s return False +def _resolve_port_file_path(base_path, port_num, name): + port_dir = os.path.abspath(base_path + str(port_num)) + file_path = os.path.abspath(os.path.join(port_dir, name)) + try: + if os.path.commonpath([port_dir, file_path]) != port_dir: + raise ValueError + except ValueError: + raise ValueError("bad file name") + return file_path + def read(port, name, initstr): global s,simtime,retrycount time.sleep(delay) try: - infile = open(inpath+str(port)+"/"+name); + file_path = _resolve_port_file_path(inpath, port, name) + infile = open(file_path); ins = infile.read() except: ins = initstr @@ -89,7 +100,8 @@ def write(port, name, val, delta=0): print("mywrite must have list or str") quit() try: - with open(outpath+str(port)+"/"+name,"w") as outfile: + file_path = _resolve_port_file_path(outpath, port, name) + with open(file_path,"w") as outfile: if isinstance(val,list): outfile.write(str([simtime+delta]+val)) simtime += delta diff --git a/ratc/learn3.dir/concore2.py b/ratc/learn3.dir/concore2.py index 6471e0b7..9834ede5 100644 --- a/ratc/learn3.dir/concore2.py +++ b/ratc/learn3.dir/concore2.py @@ -64,11 +64,22 @@ def unchanged(): olds = s return False +def _resolve_port_file_path(base_path, port_num, name): + port_dir = os.path.abspath(base_path + str(port_num)) + file_path = os.path.abspath(os.path.join(port_dir, name)) + try: + if os.path.commonpath([port_dir, file_path]) != port_dir: + raise ValueError + except ValueError: + raise ValueError("bad file name") + return file_path + def read(port, name, initstr): global s,simtime,retrycount time.sleep(delay) try: - infile = open(inpath+str(port)+"/"+name); + file_path = _resolve_port_file_path(inpath, port, name) + infile = open(file_path); ins = infile.read() except: ins = initstr @@ -89,7 +100,8 @@ def write(port, name, val, delta=0): print("mywrite must have list or str") quit() try: - with open(outpath+str(port)+"/"+name,"w") as outfile: + file_path = _resolve_port_file_path(outpath, port, name) + with open(file_path,"w") as outfile: if isinstance(val,list): outfile.write(str([simtime+delta]+val)) simtime += delta diff --git a/tests/test_concore.py b/tests/test_concore.py index dc98ced5..1ac684ef 100644 --- a/tests/test_concore.py +++ b/tests/test_concore.py @@ -582,6 +582,21 @@ def send_json_with_retry(self, message): delattr(concore, "simtime") +class TestFilePathTraversalGuard: + """File write should not allow escaping the target port directory.""" + + def test_file_write_blocks_traversal_name(self, temp_dir): + import concore + + concore.simtime = 0 + concore.outpath = os.path.join(temp_dir, "out") + os.makedirs(os.path.join(temp_dir, "out1"), exist_ok=True) + + concore.write(1, "../escape.txt", [1.0], delta=0) + + assert not os.path.exists(os.path.join(temp_dir, "escape.txt")) + + class TestPidRegistry: """Tests for the Windows PID registry mechanism (Issue #391).""" diff --git a/tests/test_read_status.py b/tests/test_read_status.py index b9fe33d9..21bddd5d 100644 --- a/tests/test_read_status.py +++ b/tests/test_read_status.py @@ -109,6 +109,32 @@ def test_last_read_status_is_parse_error(self): assert self.concore.last_read_status == "PARSE_ERROR" +class TestReadFileTraversalBlocked: + """read() rejects traversal names and returns PARSE_ERROR.""" + + @pytest.fixture(autouse=True) + def setup(self, temp_dir, monkeypatch): + import concore + + self.concore = concore + monkeypatch.setattr(concore, "delay", 0) + monkeypatch.setattr(concore, "inpath", os.path.join(temp_dir, "in")) + + in_dir = os.path.join(temp_dir, "in1") + os.makedirs(in_dir, exist_ok=True) + with open(os.path.join(in_dir, "ym"), "w") as f: + f.write("[10, 3.14]") + + def test_returns_default_and_false_on_traversal_name(self): + data, ok = self.concore.read(1, "../ym", "[0, 0.0]") + assert ok is False + assert data == [0, 0.0] + + def test_last_read_status_is_parse_error_for_traversal_name(self): + self.concore.read(1, "../ym", "[0, 0.0]") + assert self.concore.last_read_status == "PARSE_ERROR" + + class TestReadFileRetriesExceeded: """read() returns (default, False) with RETRIES_EXCEEDED when file is empty."""