diff --git a/src/hla_algorithm/hla_algorithm.py b/src/hla_algorithm/hla_algorithm.py index 07240f8..9e81e02 100644 --- a/src/hla_algorithm/hla_algorithm.py +++ b/src/hla_algorithm/hla_algorithm.py @@ -80,8 +80,8 @@ def __init__( @classmethod def use_config( cls, - standards_path: Optional[str] = None, - frequencies_path: Optional[str] = None, + standards_path: str | Path | None = None, + frequencies_path: str | Path | None = None, ) -> "HLAAlgorithm": """ An alternate constructor that accepts file paths for the configuration. @@ -90,11 +90,11 @@ def use_config( frequencies: Optional[dict[HLA_LOCUS, dict[HLAProteinPair, int]]] = None if standards_path is not None: - with open(standards_path) as f: + with Path(standards_path).open() as f: processed_stds = cls.read_hla_standards(f) if frequencies_path is not None: - with open(frequencies_path) as f: + with Path(frequencies_path).open() as f: frequencies = cls.read_hla_frequencies(f) return cls(processed_stds, frequencies) @@ -138,9 +138,9 @@ def load_default_hla_standards() -> LoadedStandards: :return: List of known HLA standards :rtype: list[HLAStandard] """ - with open( + with ( HLAAlgorithm.DEFAULT_CONFIG_DIR / "hla_standards.yaml" - ) as standards_file: + ).open() as standards_file: return HLAAlgorithm.read_hla_standards(standards_file) FREQUENCY_LOCUS_COLUMNS: dict[HLA_LOCUS, tuple[str, str]] = { @@ -202,7 +202,7 @@ def load_default_hla_frequencies() -> dict[HLA_LOCUS, dict[HLAProteinPair, int]] :rtype: dict[HLA_LOCUS, dict[HLAProteinPair, int]] """ hla_freqs: dict[HLA_LOCUS, dict[HLAProteinPair, int]] - with open(HLAAlgorithm.DEFAULT_CONFIG_DIR / "hla_frequencies.csv") as f: + with (HLAAlgorithm.DEFAULT_CONFIG_DIR / "hla_frequencies.csv").open() as f: hla_freqs = HLAAlgorithm.read_hla_frequencies(f) return hla_freqs diff --git a/src/hla_algorithm/interpret_from_json.py b/src/hla_algorithm/interpret_from_json.py index 554c68f..bc91616 100644 --- a/src/hla_algorithm/interpret_from_json.py +++ b/src/hla_algorithm/interpret_from_json.py @@ -3,6 +3,8 @@ import argparse import json import logging +import sys +from pathlib import Path from .hla_algorithm import HLAAlgorithm from .interpret_from_json_lib import HLAInput, HLAResult @@ -17,14 +19,18 @@ def main(): ) parser.add_argument( "infile", - type=argparse.FileType("r"), + type=str, help='Input file containing the JSON input (use "-" to read from stdin)', ) args: argparse.Namespace = parser.parse_args() hla_input_str: str = "" - with args.infile: - for line in args.infile: + if args.infile == "-": + input_file = sys.stdin + else: + input_file = Path(args.infile).open() + with input_file: + for line in input_file: hla_input_str += f"{line}\n" hla_input: HLAInput = HLAInput(**json.loads(hla_input_str)) diff --git a/src/hla_algorithm/interpret_from_json_lib.py b/src/hla_algorithm/interpret_from_json_lib.py index f58d046..1be5341 100644 --- a/src/hla_algorithm/interpret_from_json_lib.py +++ b/src/hla_algorithm/interpret_from_json_lib.py @@ -1,5 +1,5 @@ from datetime import datetime -from typing import Optional +from pathlib import Path from pydantic import BaseModel, Field @@ -24,11 +24,11 @@ class HLAInput(BaseModel): seq1: str - seq2: Optional[str] + seq2: str | None locus: HLA_LOCUS - threshold: Optional[int] = None - hla_std_path: Optional[str] = None - hla_freq_path: Optional[str] = None + threshold: int | None = None + hla_std_path: Path | None = None + hla_freq_path: Path | None = None def check_sequences(self) -> list[str]: errors: list[str] = [] @@ -113,7 +113,7 @@ class HLAResult(BaseModel): alleles_version: str = "" alleles_last_updated: datetime = Field(default_factory=datetime.now) b5701: bool = False - dist_b5701: Optional[int] = None + dist_b5701: int | None = None errors: list[str] = Field(default_factory=list) all_mismatches: dict[str, HLAMatchAdaptor] = Field(default_factory=dict) @@ -144,7 +144,9 @@ def build_from_interpretation( return HLAResult( seqs=seqs, - alleles_all=[f"{x[0]} - {x[1]}" for x in sort_allele_pairs(aps.allele_pairs)], + alleles_all=[ + f"{x[0]} - {x[1]}" for x in sort_allele_pairs(aps.allele_pairs) + ], alleles_clean=alleles_clean, alleles_for_mismatches=f"{rep_ap[0]} - {rep_ap[1]}", mismatches=[str(x) for x in match_details.mismatches], diff --git a/src/hla_algorithm/reformat_old_alleles.py b/src/hla_algorithm/reformat_old_alleles.py index 84d25e3..5557f13 100644 --- a/src/hla_algorithm/reformat_old_alleles.py +++ b/src/hla_algorithm/reformat_old_alleles.py @@ -4,13 +4,14 @@ import csv import logging from datetime import datetime +from pathlib import Path from typing import cast import yaml from .utils import ( - GroupedAllele, HLA_LOCUS, + GroupedAllele, HLARawStandard, StoredHLAStandards, group_identical_alleles, @@ -28,22 +29,22 @@ def main(): parser.add_argument( "a_standards", help="CSV file containing all HLA-A alleles", - type=str, + type=Path, ) parser.add_argument( "b_standards", help="CSV file containing all HLA-B alleles", - type=str, + type=Path, ) parser.add_argument( "c_standards", help="CSV file containing all HLA-C alleles", - type=str, + type=Path, ) parser.add_argument( "--output", help="filename to store the reformatted standards in YAML", - type=str, + type=Path, default="reformatted_hla_standards.yaml", ) parser.add_argument( @@ -84,7 +85,7 @@ def main(): grouped_alleles: dict[HLA_LOCUS, list[GroupedAllele]] = {"A": [], "B": [], "C": []} for locus in ("A", "B", "C"): logger.info(f"Grouping HLA-{locus} alleles....") - with open(input_filenames_by_locus[locus]) as f: + with input_filenames_by_locus[locus].open() as f: standards_csv: csv.DictReader = csv.DictReader( f, fieldnames=("allele", "exon2", "exon3"), @@ -114,7 +115,7 @@ def main(): ) logger.info(f"Writing HLA standards to {args.output}....") - with open(args.output, "w") as f: + with args.output.open("w") as f: yaml.safe_dump(standards_for_saving.model_dump(), f) logger.info("Done.") diff --git a/src/hla_algorithm/update_alleles.py b/src/hla_algorithm/update_alleles.py index d62e18d..bfd20b1 100644 --- a/src/hla_algorithm/update_alleles.py +++ b/src/hla_algorithm/update_alleles.py @@ -8,6 +8,7 @@ import time from datetime import datetime from io import StringIO +from pathlib import Path from typing import Final, Optional, TypedDict, cast import Bio @@ -215,13 +216,13 @@ def main(): parser.add_argument( "--output", help="filename to store the unreduced standards (YAML format)", - type=str, + type=Path, default="hla_standards.yaml", ) parser.add_argument( "--checksum", help="filename to store the MD5 checksum of the retrieved data in", - type=str, + type=Path, default="hla_nuc.fasta.checksum.txt", ) parser.add_argument( @@ -242,8 +243,7 @@ def main(): parser.add_argument( "--dump_full_fasta_to", help="if specified, the full original FASTA file is dumped to the specified path", - type=str, - default="", + type=Path, ) parser.add_argument( "--standard_report_interval", @@ -278,16 +278,14 @@ def main(): f"{retrieval_datetime}." ) - if args.dump_full_fasta_to != "": + if args.dump_full_fasta_to is not None: logger.info(f"Dumping the full FASTA file to {args.dump_full_fasta_to}.") - with open(args.dump_full_fasta_to, "w") as f: - f.write(alleles_str) + args.dump_full_fasta_to.write_text(alleles_str) # Compute the checksum. md5_calc = hashlib.md5() md5_calc.update(alleles_str.encode()) - with open(args.checksum, "w") as f: - f.write(f"{md5_calc.hexdigest()} {HLA_ALLELES_FILENAME}\n") + args.checksum.write_text(f"{md5_calc.hexdigest()} {HLA_ALLELES_FILENAME}\n") raw_standards: dict[HLA_LOCUS, list[HLARawStandard]] = collate_standards( list(Bio.SeqIO.parse(StringIO(alleles_str), "fasta")), @@ -313,7 +311,7 @@ def main(): # First, prepare the unreduced YAML output. logger.info(f"Writing HLA standards to {args.output}....") - with open(args.output, "w") as f: + with args.output.open("w") as f: yaml.safe_dump(standards_for_saving.model_dump(), f) logger.info("Done.") diff --git a/src/scripts/measure_resources.py b/src/scripts/measure_resources.py index 904502a..3d0066d 100644 --- a/src/scripts/measure_resources.py +++ b/src/scripts/measure_resources.py @@ -90,7 +90,7 @@ def main(): } ) - with open(args.output_csv, "w") as f: + with args.output_csv.open("w") as f: resource_summary_writer = csv.DictWriter( f, fieldnames=("sample_name", "wall_clock_time", "max_memory_usage_kb"),