summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--Lib/ntpath.py56
-rw-r--r--Lib/test/test_ntpath.py7
2 files changed, 47 insertions, 16 deletions
diff --git a/Lib/ntpath.py b/Lib/ntpath.py
index cf7c353..1be2961 100644
--- a/Lib/ntpath.py
+++ b/Lib/ntpath.py
@@ -42,22 +42,46 @@ def join(a, *p):
"""Join two or more pathname components, inserting "\\" as needed"""
path = a
for b in p:
- # If path starts with a raw drive letter (e.g. "C:"), and b doesn't
- # start with a drive letter, path+b is correct, and regardless of\
- # whether b is absolute on its own.
- if len(path) >= 2 and path[1] == ":" and splitdrive(b)[0] == "":
- if path[-1] in "/\\" and b[:1] in "/\\":
- b = b[1:]
-
- # In any other case, if b is absolute it wipes out the path so far.
- elif isabs(b) or path == "":
- path = ""
-
- # Else make sure a separator appears between the pieces.
- elif path[-1:] not in "/\\":
- b = "\\" + b
-
- path += b
+ b_wins = 0 # set to 1 iff b makes path irrelevant
+ if path == "":
+ b_wins = 1
+
+ elif isabs(b):
+ # This probably wipes out path so far. However, it's more
+ # complicated if path begins with a drive letter:
+ # 1. join('c:', '/a') == 'c:/a'
+ # 2. join('c:/', '/a') == 'c:/a'
+ # But
+ # 3. join('c:/a', '/b') == '/b'
+ # 4. join('c:', 'd:/') = 'd:/'
+ # 5. join('c:/', 'd:/') = 'd:/'
+ if path[1:2] != ":" or b[1:2] == ":":
+ # Path doesn't start with a drive letter, or cases 4 and 5.
+ b_wins = 1
+
+ # Else path has a drive letter, and b doesn't but is absolute.
+ elif len(path) > 3 or (len(path) == 3 and
+ path[-1] not in "/\\"):
+ # case 3
+ b_wins = 1
+
+ if b_wins:
+ path = b
+ else:
+ # Join, and ensure there's a separator.
+ assert len(path) > 0
+ if path[-1] in "/\\":
+ if b and b[0] in "/\\":
+ path += b[1:]
+ else:
+ path += b
+ elif path[-1] == ":":
+ path += b
+ elif b:
+ if b[0] in "/\\":
+ path += b
+ else:
+ path += "\\" + b
return path
diff --git a/Lib/test/test_ntpath.py b/Lib/test/test_ntpath.py
index 7386900..fe997b3 100644
--- a/Lib/test/test_ntpath.py
+++ b/Lib/test/test_ntpath.py
@@ -66,6 +66,13 @@ tester('ntpath.join("a\\", "b", "c")', 'a\\b\\c')
tester('ntpath.join("a", "b\\", "c")', 'a\\b\\c')
tester('ntpath.join("a", "b", "\\c")', '\\c')
tester('ntpath.join("d:\\", "\\pleep")', 'd:\\pleep')
+tester('ntpath.join("d:\\", "a", "b")', 'd:\\a\\b')
+tester("ntpath.join('c:', '/a')", 'c:/a')
+tester("ntpath.join('c:/', '/a')", 'c:/a')
+tester("ntpath.join('c:/a', '/b')", '/b')
+tester("ntpath.join('c:', 'd:/')", 'd:/')
+tester("ntpath.join('c:/', 'd:/')", 'd:/')
+tester("ntpath.join('c:/', 'd:/a/b')", 'd:/a/b')
if errors:
raise TestFailed(str(errors) + " errors.")