Source code for nasbenchapi.nb_api

"""
Unified API wrappers for NASBench-101/201/301.

This module provides wrappers around the benchmark-specific implementations 
to expose a consistent interface across all benchmarks.
"""

import random
from typing import Any, Dict, Iterable, Iterator, List, Optional, Set

from .base import NASBenchBase
from .nasbench101_api import NASBench101 as _NB101, Arch101
from .nasbench201_api import NASBench201 as _NB201
from .nasbench301_api import NASBench301 as _NB301


class NASBench101(NASBenchBase):
    """Unified NB101 API wrapper.

    Every benchmark operation is forwarded to the wrapped ``NASBench101``
    implementation held in ``self.api``; this class only adds the common
    ``NASBenchBase`` surface (names, datasets, splits, budgets).
    """

    def __init__(self, data_path: Optional[str] = None, verbose: bool = True):
        # The underlying implementation receives the path and verbosity as-is.
        self.api = _NB101(data_path, verbose=verbose)

    def load(self, data_path: Optional[str] = None) -> "NASBench101":
        """Build and return a new NB101 wrapper for *data_path*.

        Per the original contract the path may also come from the
        environment when ``None`` (resolved by the underlying loader).
        A fresh instance with default verbosity is returned; ``self`` is
        not modified.
        """
        return NASBench101(data_path)

    def bench_name(self) -> str:
        """Return the short benchmark identifier."""
        return 'nb101'

    def datasets(self) -> List[str]:
        """Return the datasets this benchmark reports results for."""
        return ['cifar10']

    def splits(self, dataset: str) -> List[str]:
        """Return the supported splits (identical for every *dataset*)."""
        return ['train', 'val', 'test']

    def available_budgets(
        self,
        dataset: Optional[str] = None,
        split: Optional[str] = None,
    ) -> Optional[List[Any]]:
        """Always return ``None``: NB101 has no explicit training budgets."""
        return None

    # ---- Pure delegation to the underlying NB101 implementation ----

    def op_set(self) -> List[str]:
        """Return the operations that may appear in an NB101 cell."""
        return self.api.op_set()

    def decode(self, encoding: Dict[str, str]) -> Arch101:
        """Turn a native NB101 encoding into an ``Arch101``."""
        return self.api.decode(encoding)

    def encode(self, arch: Arch101) -> Dict[str, str]:
        """Turn an ``Arch101`` into its native NB101 encoding."""
        return self.api.encode(arch)

    def id(self, arch: Arch101) -> str:
        """Return the identifier the underlying API assigns to *arch*."""
        return self.api.id(arch)

    def random_sample(self, n: int = 1, seed: Optional[int] = None) -> List[Arch101]:
        """Draw *n* architectures from the loaded NB101 data."""
        return self.api.random_sample(n, seed)

    def iter_all(self) -> Iterator[Arch101]:
        """Return an iterator over every loaded NB101 architecture."""
        return self.api.iter_all()

    def mutate(self, arch: Arch101, rng: random.Random, kind: Optional[str] = None) -> Arch101:
        """Return a mutated copy of *arch* using *rng* (delegated)."""
        return self.api.mutate(arch, rng, kind)

    def query(
        self,
        arch: Arch101,
        dataset: str = 'cifar10',
        split: str = 'val',
        seed: Optional[int] = None,
        budget: Optional[Any] = None,
        average: bool = False,
        summary: bool = False,
    ) -> Any:
        """Look up performance metrics for *arch* via the underlying API."""
        return self.api.query(
            arch, dataset, split, seed, budget,
            average=average, summary=summary,
        )

    def is_valid(self, arch: Arch101) -> bool:
        """Return whether the underlying API considers *arch* valid."""
        return self.api.is_valid(arch)

    def train_time(self, arch: Arch101, dataset: str = 'cifar10') -> Optional[float]:
        """Return the recorded training time for *arch* on *dataset*."""
        return self.api.train_time(arch, dataset)

    def _supports_arch(self, arch: Arch101) -> bool:
        """Return True when *arch* is valid and present in the loaded data.

        A match is either the architecture id being a key of
        ``latest_by_arch``, or any record there sharing both the
        ``adjacency_str`` and ``operations_str`` of the encoded *arch*.
        """
        api = self.api
        if not api.is_valid(arch):
            return False
        latest_by_arch = api.data.get('latest_by_arch', {})
        if api.id(arch) in latest_by_arch:
            return True
        encoded = api.encode(arch)
        wanted_adjacency = encoded.get('adjacency_str')
        wanted_operations = encoded.get('operations_str')
        return any(
            record.get('adjacency_str') == wanted_adjacency
            and record.get('operations_str') == wanted_operations
            for record in latest_by_arch.values()
        )
class NASBench201(NASBenchBase):
    """Unified NB201 API wrapper.

    Sampling, iteration, queries and index/string conversion are forwarded
    to the wrapped ``NASBench201`` implementation in ``self.api``. This
    class additionally discovers the available training budgets from the
    loaded ``arch2infos`` payload.
    """

    def __init__(self, data_path: Optional[str] = None, verbose: bool = True):
        self.api = _NB201(data_path, verbose=verbose)
        # Lazily filled by _ensure_budget_cache():
        # canonical dataset -> split -> sorted list of budgets.
        self._budget_cache: Optional[Dict[str, Dict[str, List[int]]]] = None

    def load(self, data_path: Optional[str] = None) -> "NASBench201":
        """Load NB201 from the provided path or environment.

        Returns a new wrapper instance; ``self`` is not modified.
        """
        return NASBench201(data_path)

    def bench_name(self) -> str:
        """Short benchmark name."""
        return 'nb201'

    def datasets(self) -> List[str]:
        """Available datasets."""
        return ['cifar10', 'cifar100', 'ImageNet16-120']

    def splits(self, dataset: str) -> List[str]:
        """Supported splits (the same for every dataset)."""
        return ['train', 'val', 'test']

    def available_budgets(
        self,
        dataset: Optional[str] = None,
        split: Optional[str] = None,
    ) -> Optional[List[int]]:
        """Return available training budgets for NB201.

        Args:
            dataset: Dataset name (case-insensitive prefix match, e.g.
                ``'cifar10-valid'`` counts as ``'cifar10'``). ``None`` unions
                over every dataset.
            split: ``'train'``, ``'val'`` or ``'test'``. ``None`` unions over
                every split.

        Returns:
            Sorted budgets, or ``None`` when nothing is known for the request.
        """
        self._ensure_budget_cache()
        if not self._budget_cache:
            return None
        if dataset is None:
            return self._union_budgets(self._budget_cache.values(), split)
        canonical = self._canonical_dataset_name(dataset)
        if canonical is None or canonical not in self._budget_cache:
            return None
        ds_budgets = self._budget_cache[canonical]
        if split is None:
            return self._union_budgets([ds_budgets], None)
        values = ds_budgets.get(split)
        return list(values) if values else None

    # Delegate to underlying API

    def decode(self, encoding: Any) -> Any:
        """Decode architecture (pass-through for now)."""
        return encoding

    def encode(self, arch: Any) -> Any:
        """Encode architecture (pass-through for now)."""
        return arch

    def id(self, arch: Any) -> str:
        """Stable identifier: SHA-256 of the key-sorted JSON ``{'arch': arch}``."""
        import hashlib
        import json
        h = hashlib.sha256()
        h.update(json.dumps({'arch': arch}, sort_keys=True).encode("utf-8"))
        return h.hexdigest()

    def random_sample(self, n: int = 1, seed: Optional[int] = None) -> List[Any]:
        """Random sample from NB201."""
        return self.api.random_sample(n, seed)

    def random_sample_str(self, n: int = 1, seed: Optional[int] = None) -> List[str]:
        """Random sample NB201 architectures as arch strings."""
        return self.api.random_sample_str(n, seed)

    def iter_all(self) -> Iterator[Any]:
        """Iterate all NB201 architectures."""
        return self.api.iter_all()

    def mutate(self, arch: Any, rng: random.Random, kind: Optional[str] = None) -> Any:
        """Mutate architecture (no-op for now: returns *arch* unchanged)."""
        return arch

    def query(
        self,
        arch: Any,
        dataset: str = 'cifar10',
        split: str = 'val',
        seed: Optional[int] = None,
        budget: Optional[Any] = None,
    ) -> Dict[str, Any]:
        """Query performance metrics for architecture."""
        return self.api.query(arch, dataset, split, seed, budget)

    # Conversion helpers (benchmark-specific)

    def index_to_arch_str(self, idx: int) -> str:
        """Convert NB201 index (0..15624) to arch string."""
        return self.api.index_to_arch_str(idx)

    def arch_str_to_index(self, arch_str: str) -> int:
        """Convert NB201 arch string to canonical index (0..15624)."""
        return self.api.arch_str_to_index(arch_str)

    def _supports_arch(self, arch: Any) -> bool:
        """Check whether an architecture exists in the loaded NB201 dataset.

        Integers are treated as indices and strings as arch strings; any
        other type is unsupported.
        """
        if isinstance(arch, int):
            if self.api._arch_keys:
                return arch in self.api._arch_keys
            # No key table loaded: fall back to the theoretical index range.
            max_idx = (self.api.NUM_OPS ** self.api.NUM_EDGES) - 1
            return 0 <= arch <= max_idx
        if isinstance(arch, str):
            try:
                arch_idx = self.api.get_index(arch)
            except Exception:
                # Best effort: an unparseable string is simply unsupported.
                return False
            if self.api._arch_keys:
                return arch_idx in self.api._arch_keys
            return arch_idx is not None
        return False

    # --- Budget helpers -----------------------------------------------------

    @staticmethod
    def _union_budgets(
        split_maps: Iterable[Dict[str, List[int]]],
        split: Optional[str],
    ) -> Optional[List[int]]:
        """Union budgets of *split* (or all splits) across *split_maps*."""
        combined: Set[int] = set()
        for split_map in split_maps:
            if split is None:
                for values in split_map.values():
                    combined.update(values)
            else:
                combined.update(split_map.get(split, []))
        return sorted(combined) if combined else None

    def _ensure_budget_cache(self) -> None:
        """Populate ``self._budget_cache`` once from ``api.data['arch2infos']``.

        The method walks ``arch_info['full']['all_results']`` records, keyed by
        ``(dataset, seed)``, and stops after the first architecture that
        contributes any dataset. The scan is best-effort: a missing or
        malformed payload yields an empty cache instead of an error.
        """
        if self._budget_cache is not None:
            return
        cache: Dict[str, Dict[str, Set[int]]] = {}
        data = getattr(self.api, 'data', {})
        if not isinstance(data, dict):
            self._budget_cache = {}
            return
        arch2infos = data.get('arch2infos', {})
        if not isinstance(arch2infos, dict) or not arch2infos:
            self._budget_cache = {}
            return
        for arch_info in arch2infos.values():
            if not isinstance(arch_info, dict):
                continue
            full = arch_info.get('full')
            if not isinstance(full, dict):
                continue
            all_results = full.get('all_results')
            if not isinstance(all_results, dict):
                continue
            for result_key, result in all_results.items():
                # Expected key shape is (dataset, seed); skip anything else
                # instead of crashing on tuple unpacking.
                if not (isinstance(result_key, tuple) and len(result_key) == 2):
                    continue
                if not isinstance(result, dict):
                    continue
                canonical = self._canonical_dataset_name(result_key[0])
                if canonical is None:
                    continue
                store = cache.setdefault(canonical, {})
                self._accumulate_budgets(store, result)
            # Budgets are consistent across architectures; first populated entry is enough.
            if cache:
                break
        # Convert sets to sorted lists for stability; empty splits are dropped.
        self._budget_cache = {
            ds: {split: sorted(values) for split, values in split_map.items() if values}
            for ds, split_map in cache.items()
        }

    def _canonical_dataset_name(self, dataset: Optional[str]) -> Optional[str]:
        """Map a raw dataset key to the canonical name used by ``datasets()``.

        Matching is a case-insensitive prefix test. An unknown string is
        returned unchanged, and a non-string (or ``None``) gives ``None``.
        """
        if dataset is None:
            return None
        key = str(dataset).lower()
        # Test the longer prefix first: 'cifar100'.startswith('cifar10') is
        # True, so checking 'cifar10' first would fold CIFAR-100 into CIFAR-10.
        if key.startswith('cifar100'):
            return 'cifar100'
        if key.startswith('cifar10'):  # also covers 'cifar10-valid'
            return 'cifar10'
        if key.startswith('imagenet16-120'):
            return 'ImageNet16-120'
        return dataset if isinstance(dataset, str) else None

    def _accumulate_budgets(self, store: Dict[str, Set[int]], result: Dict[str, Any]) -> None:
        """Add the budgets found in one result record to *store* (split -> set).

        Budgets come from three places:
        - ``'val'``: the ``N`` in ``eval_acc1es`` keys of the form ``x-valid@N``.
        - ``'test'``: the ``N`` in ``eval_acc1es`` keys of the form ``ori-test@N``.
        - ``'train'``: the integer keys of ``train_acc1es`` when it is a dict,
          or its positions when it is a list/tuple.
        """
        eval_acc = result.get('eval_acc1es', {})
        if isinstance(eval_acc, dict):
            for metric_key in eval_acc:
                if not isinstance(metric_key, str):
                    continue
                if metric_key.startswith('x-valid@'):
                    budget = self._parse_budget_suffix(metric_key, 'x-valid@')
                    if budget is not None:
                        store.setdefault('val', set()).add(budget)
                elif metric_key.startswith('ori-test@'):
                    budget = self._parse_budget_suffix(metric_key, 'ori-test@')
                    if budget is not None:
                        store.setdefault('test', set()).add(budget)
        train_acc = result.get('train_acc1es')
        if isinstance(train_acc, dict):
            budgets = []
            for key in train_acc:
                try:
                    budgets.append(int(key))
                except (TypeError, ValueError):
                    continue
            if budgets:
                store.setdefault('train', set()).update(budgets)
        elif isinstance(train_acc, (list, tuple)):
            store.setdefault('train', set()).update(range(len(train_acc)))

    @staticmethod
    def _parse_budget_suffix(metric_key: str, prefix: str) -> Optional[int]:
        """Parse the integer after *prefix* in *metric_key*; ``None`` if not an int.

        Callers guarantee that *metric_key* starts with *prefix*.
        """
        try:
            return int(metric_key[len(prefix):])
        except ValueError:
            return None
class NASBench301(NASBenchBase):
    """Unified NB301 API wrapper.

    Sampling, iteration and queries are forwarded to the wrapped
    ``NASBench301`` implementation in ``self.api``. Available budgets are
    derived from the validation learning curves and ``budget`` fields of the
    loaded entries.
    """

    def __init__(self, data_path: Optional[str] = None, verbose: bool = True):
        self.api = _NB301(data_path, verbose=verbose)
        # Lazily filled by _ensure_budget_cache(): dataset -> split -> sorted budgets.
        self._budget_cache: Optional[Dict[str, Dict[str, List[int]]]] = None

    def load(self, data_path: Optional[str] = None) -> "NASBench301":
        """Build and return a new NB301 wrapper for *data_path*.

        The path may come from the environment when ``None`` (resolved by the
        underlying loader). ``self`` is not modified.
        """
        return NASBench301(data_path)

    def bench_name(self) -> str:
        """Return the short benchmark identifier."""
        return 'nb301'

    def datasets(self) -> List[str]:
        """Return the datasets this wrapper advertises."""
        return ['cifar10', 'cifar100']

    def splits(self, dataset: str) -> List[str]:
        """Return the supported splits (identical for every *dataset*)."""
        return ['val', 'test']

    def available_budgets(
        self,
        dataset: Optional[str] = None,
        split: Optional[str] = None,
    ) -> Optional[List[int]]:
        """Return the budgets recorded in the NB301 learning curves.

        ``dataset=None`` unions over all datasets and ``split=None`` unions
        over all splits. Dataset names are matched exactly. Returns ``None``
        when nothing is known for the request.
        """
        self._ensure_budget_cache()
        budget_cache = self._budget_cache
        if not budget_cache:
            return None

        def union_of(split_maps):
            # Merge either the requested split or every split of each map.
            merged: Set[int] = set()
            for split_map in split_maps:
                if split is None:
                    for budgets in split_map.values():
                        merged.update(budgets)
                else:
                    merged.update(split_map.get(split, []))
            return sorted(merged) if merged else None

        if dataset is None:
            return union_of(budget_cache.values())
        per_split = budget_cache.get(dataset)
        if per_split is None:
            return None
        if split is None:
            return union_of([per_split])
        chosen = per_split.get(split)
        return list(chosen) if chosen else None

    # ---- Delegation / placeholders ----

    def decode(self, encoding: Any) -> Any:
        """Return *encoding* unchanged (decoding is not implemented yet)."""
        return encoding

    def encode(self, arch: Any) -> Any:
        """Return *arch* unchanged (encoding is not implemented yet)."""
        return arch

    def id(self, arch: Any) -> str:
        """Return the SHA-256 hex digest of the key-sorted JSON ``{'arch': arch}``."""
        import hashlib
        import json
        payload = json.dumps({'arch': arch}, sort_keys=True).encode("utf-8")
        return hashlib.sha256(payload).hexdigest()

    def random_sample(self, n: int = 1, seed: Optional[int] = None) -> List[Any]:
        """Draw *n* architectures via the underlying NB301 API."""
        return self.api.random_sample(n, seed)

    def iter_all(self) -> Iterator[Any]:
        """Return an iterator over every NB301 architecture."""
        return self.api.iter_all()

    def mutate(self, arch: Any, rng: random.Random, kind: Optional[str] = None) -> Any:
        """Return *arch* unchanged (mutation is not implemented yet)."""
        return arch

    def query(
        self,
        arch: Any,
        dataset: str = 'cifar10',
        split: str = 'val',
        seed: Optional[int] = None,
        budget: Optional[Any] = None,
    ) -> Dict[str, Any]:
        """Look up performance metrics for *arch* via the underlying API."""
        return self.api.query(arch, dataset, split, seed, budget)

    def _supports_arch(self, arch: Any) -> bool:
        """Return True when *arch* (an index or a dict) is in the NB301 payload.

        An integer is accepted when it is a valid position in
        ``api._arch_keys``. A dict is accepted when ``api.get_index`` finds it.
        """
        api = self.api
        if isinstance(arch, int):
            keys = api._arch_keys
            return bool(keys) and 0 <= arch < len(keys)
        if isinstance(arch, dict):
            return api.get_index(arch) is not None
        return False

    def _ensure_budget_cache(self) -> None:
        """Populate ``self._budget_cache`` once from ``api.data['entries']``.

        Each dict entry contributes to the dataset named by
        ``api._infer_dataset``:
        - ``'val'``: epochs ``1..len(curve)`` of its ``Train/val_accuracy`` curve.
        - ``'test'``: its integer ``budget`` field, when present.

        A missing or malformed payload yields an empty cache.
        """
        if self._budget_cache is not None:
            return
        data = getattr(self.api, 'data', {})
        entries = data.get('entries') if isinstance(data, dict) else None
        if not isinstance(entries, list):
            self._budget_cache = {}
            return

        collected: Dict[str, Dict[str, Set[int]]] = {}
        for entry in entries:
            if not isinstance(entry, dict):
                continue
            parsed = entry.get('parsed', {})
            dataset_name = self.api._infer_dataset(parsed)
            if not dataset_name:
                continue
            per_split = collected.setdefault(dataset_name, {})
            val_curve = self.api._get_learning_curve(parsed, 'Train/val_accuracy')
            if val_curve:
                per_split.setdefault('val', set()).update(range(1, len(val_curve) + 1))
            budget = parsed.get('budget')
            if budget is not None:
                try:
                    per_split.setdefault('test', set()).add(int(budget))
                except (TypeError, ValueError):
                    # Non-integer budgets are ignored (best effort).
                    pass

        # Freeze into sorted lists; splits that ended up empty are dropped.
        self._budget_cache = {
            ds: {name: sorted(values) for name, values in split_map.items() if values}
            for ds, split_map in collected.items()
        }