0
0
Fork 0
mirror of https://github.com/ohmyzsh/ohmyzsh.git synced 2024-09-19 04:01:21 +02:00

fix(dependencies): improve typing

This commit is contained in:
Carlo Sala 2024-05-09 18:27:01 +02:00
parent 13c8a10e39
commit a258eb4547

View file

@ -4,7 +4,7 @@ import subprocess
import sys
import timeit
from copy import deepcopy
from typing import Optional, TypedDict
from typing import Literal, NotRequired, TypedDict
import requests
import yaml
@ -49,20 +49,24 @@ class DependencyDict(TypedDict):
repo: str
branch: str
version: str
precopy: Optional[str]
postcopy: Optional[str]
precopy: NotRequired[str]
postcopy: NotRequired[str]
class DependencyYAML(TypedDict):
dependencies: dict[str, DependencyDict]
class UpdateStatus(TypedDict):
has_updates: bool
version: Optional[str]
compare_url: Optional[str]
head_ref: Optional[str]
head_url: Optional[str]
class UpdateStatusFalse(TypedDict):
has_updates: Literal[False]
class UpdateStatusTrue(TypedDict):
has_updates: Literal[True]
version: str
compare_url: str
head_ref: str
head_url: str
class CommandRunner:
@ -105,7 +109,9 @@ class DependencyStore:
with CodeTimer(f"store deepcopy: {path}"):
store_copy = deepcopy(DependencyStore.store)
dependency = store_copy["dependencies"].get(path, {})
dependency = store_copy["dependencies"].get(path)
if dependency is None:
raise ValueError(f"Dependency {path} {version} not found")
dependency["version"] = version
store_copy["dependencies"][path] = dependency
@ -171,7 +177,7 @@ class Dependency:
else:
status = GitHub.check_updates(repo, remote_branch, version)
if status["has_updates"]:
if status["has_updates"] is True:
short_sha = status["head_ref"][:8]
new_version = status["version"] if is_tag else short_sha
@ -212,10 +218,10 @@ Check out the [list of changes]({status['compare_url']}).
case CommandRunner.Exception:
# Print error message
print(
f"Error running {e.stage} command: {e.returncode}",
f"Error running {e.stage} command: {e.returncode}", # pyright: ignore[reportAttributeAccessIssue]
file=sys.stderr,
)
print(e.stderr, file=sys.stderr)
print(e.stderr, file=sys.stderr) # pyright: ignore[reportAttributeAccessIssue]
case shutil.Error:
print(f"Error copying files: {e}", file=sys.stderr)
@ -378,7 +384,7 @@ class Git:
class GitHub:
@staticmethod
def check_newer_tag(repo, current_tag) -> UpdateStatus:
def check_newer_tag(repo, current_tag) -> UpdateStatusFalse | UpdateStatusTrue:
# GET /repos/:owner/:repo/git/refs/tags
url = f"https://api.github.com/repos/{repo}/git/refs/tags"
@ -417,7 +423,7 @@ class GitHub:
)
@staticmethod
def check_updates(repo, branch, version) -> UpdateStatus:
def check_updates(repo, branch, version) -> UpdateStatusFalse | UpdateStatusTrue:
# TODO: add support for semver updating (based on tags)
# Check if upstream github repo has a new version
# GitHub API URL for comparing two commits