summary refs log tree commit diff stats
path: root/ganarchy/git.py
diff options
context:
space:
mode:
Diffstat (limited to 'ganarchy/git.py')
-rw-r--r--ganarchy/git.py72
1 files changed, 53 insertions, 19 deletions
diff --git a/ganarchy/git.py b/ganarchy/git.py
index a658022..f8ccfcd 100644
--- a/ganarchy/git.py
+++ b/ganarchy/git.py
@@ -23,7 +23,7 @@
 
 import subprocess
 
-class GitError(LookupError):
+class GitError(Exception):
     """Raised when a git operation fails, generally due to a
     missing commit or branch, or network connection issues.
     """
@@ -54,13 +54,39 @@ class Git:
             GitError: If an error occurs.
         """
         try:
-            subprocess.check_call(
+            subprocess.run(
                 self.base + ("merge-base", "--is-ancestor", commit, local_head),
-                stdout=subprocess.DEVNULL,
-                stderr=subprocess.DEVNULL
+                check=True,
+                stdout=subprocess.PIPE,
+                stderr=subprocess.PIPE
             )
         except subprocess.CalledProcessError as e:
-            raise GitError from e
+            raise GitError("check history") from e
+
+    def check_branchname(self, branchname):
+        """Checks if the given branchname is a valid branch name.
+        Raises if it isn't.
+
+        Args:
+            branchname (str): Name of branch.
+
+        Raises:
+            GitError: If an error occurs.
+        """
+        try:
+            # TODO check that this rstrip is safe
+            out = subprocess.run(
+                self.base + ("check-ref-format", "--branch", branchname),
+                check=True,
+                stdout=subprocess.PIPE,
+                stderr=subprocess.PIPE
+            ).stdout.decode("utf-8").rstrip('\r\n')
+            # protect against @{-1}/@{-n} ("previous checkout operation")
+            # is also fairly future-proofed, I hope?
+            if out != branchname:
+                raise GitError("check branchname", out, branchname)
+        except subprocess.CalledProcessError as e:
+            raise GitError("check branchname") from e
 
     def force_fetch(self, url, remote_head, local_head):
         """Fetches a remote head into a local head.
@@ -76,12 +102,14 @@ class Git:
             GitError: If an error occurs.
         """
         try:
-            subprocess.check_output(
+            subprocess.run(
                 self.base + ("fetch", "-q", url, "+" + remote_head + ":" + local_head),
-                stderr=subprocess.STDOUT
+                check=True,
+                stdout=subprocess.PIPE,
+                stderr=subprocess.PIPE
             )
         except subprocess.CalledProcessError as e:
-            raise GitError from e
+            raise GitError(e.output) from e
 
     def get_count(self, first_hash, last_hash):
         """Returns a count of the commits added since ``first_hash``
@@ -96,10 +124,12 @@ class Git:
             if an error occurs.
         """
         try:
-            res = subprocess.check_output(
+            res = subprocess.run(
                 self.base + ("rev-list", "--count", first_hash + ".." + last_hash, "--"),
-                stderr=subprocess.DEVNULL
-            ).decode("utf-8").strip()
+                check=True,
+                stdout=subprocess.PIPE,
+                stderr=subprocess.PIPE
+            ).stdout.decode("utf-8").strip()
             return int(res)
         except subprocess.CalledProcessError as e:
             return 0
@@ -114,12 +144,14 @@ class Git:
             GitError: If an error occurs.
         """
         try:
-            return subprocess.check_output(
+            return subprocess.run(
                 self.base + ("show", target, "-s", "--format=format:%H", "--"),
-                stderr=subprocess.DEVNULL
-            ).decode("utf-8")
+                check=True,
+                stdout=subprocess.PIPE,
+                stderr=subprocess.PIPE
+            ).stdout.decode("utf-8")
         except subprocess.CalledProcessError as e:
-            raise GitError from e
+            raise GitError("") from e
 
     def get_commit_message(self, target):
         """Returns the commit message for a given target.
@@ -131,9 +163,11 @@ class Git:
             GitError: If an error occurs.
         """
         try:
-            return subprocess.check_output(
+            return subprocess.run(
                 self.base + ("show", target, "-s", "--format=format:%B", "--"),
-                stderr=subprocess.DEVNULL
-            ).decode("utf-8", "replace")
+                check=True,
+                stdout=subprocess.PIPE,
+                stderr=subprocess.PIPE
+            ).stdout.decode("utf-8", "replace")
         except subprocess.CalledProcessError as e:
-            raise GitError from e
+            raise GitError("") from e