# Source code for nasbenchapi.nasbench201_api


import pickle
import random
import hashlib
import json
import time
from pathlib import Path
from dataclasses import dataclass
from typing import Dict, Any, Optional, List, Iterator

from .utils import resolve_path, sizeof_fmt, display_path

try:
    # Optional import to minimize pip overhead
    from tqdm import tqdm
    HAS_TQDM = True
except ImportError:
    HAS_TQDM = False



class NASBench201:
    """Read-only accessor for a pickled NAS-Bench-201 (NB201) data file.

    The loader understands two layouts of the unpickled object:

    * a dict containing an ``'arch2infos'`` mapping (architecture index ->
      per-architecture record); index <-> architecture-string lookup tables
      are built from each record's ``['full']['arch_str']`` when present;
    * any other dict, whose top-level keys are then used as architecture keys
      (a note is printed in verbose mode because they may just be metadata).

    Architectures are exchanged as strings of the form
    ``|op~0|+|op~0|op~1|+|op~0|op~1|op~2|`` and can be converted to and from
    an integer index in ``0 .. NUM_OPS ** NUM_EDGES - 1`` (0..15624), where
    edge 0 is the least-significant base-``NUM_OPS`` digit.
    """

    # NB201 operations for the search space
    OPS = ['none', 'skip_connect', 'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3']
    NUM_OPS = len(OPS)
    NUM_EDGES = 6  # 6 edges in the cell

    # Size of each read when streaming the pickle behind a progress bar.
    _READ_CHUNK_BYTES = 1024 * 1024
    # Defaults applied by ``query`` when the caller passes None.
    _DEFAULT_SEED = 777  # Official NB201 seed
    _DEFAULT_EPOCH = 199  # Final epoch (0-199)

    def __init__(self, pickle_path: Optional[str] = None, verbose: bool = True):
        """Resolve the data file location and load it immediately.

        Args:
            pickle_path: Optional explicit path to the pickle file; it is
                handed to ``resolve_path('201', pickle_path)``.
            verbose: When True, progress and summary messages are printed.
        """
        self.path = resolve_path('201', pickle_path)
        self.verbose = verbose
        self.data: Any = None
        self._arch_keys = []
        # Optional mappings if available from loaded data
        self._idx_to_str: Dict[int, str] = {}
        self._str_to_idx: Dict[str, int] = {}
        self._load()

    def _load(self) -> None:
        """Unpickle ``self.path`` into ``self.data`` and index its architectures.

        In verbose mode with ``tqdm`` available and a non-empty file, the file
        is streamed in chunks behind a progress bar before unpickling;
        otherwise it is unpickled straight from the file handle.
        """
        start_time = time.perf_counter()
        size = self.path.stat().st_size
        if self.verbose:
            print(f"Loading NB201 from {display_path(self.path)} ({sizeof_fmt(size)})")

        # NOTE(security): unpickling can execute arbitrary code. Only load
        # benchmark files obtained from a trusted source.
        with open(self.path, 'rb') as f:
            # ``self.verbose`` is tested first so that a quiet load never
            # touches the optional progress-bar machinery.
            if self.verbose and size > 0 and HAS_TQDM:
                raw = self._read_with_progress(f, size)
                print("Unpickling data...")
                # pickle.loads accepts any bytes-like object, so the buffer is
                # used directly instead of being copied into ``bytes`` first.
                self.data = pickle.loads(raw)
                print("Unpickling complete.")
            else:
                if self.verbose:
                    print("Unpickling data...")
                self.data = pickle.load(f)
                if self.verbose:
                    print("Unpickling complete.")

        fallback_keys = self._build_arch_index()

        if self.verbose:
            elapsed = time.perf_counter() - start_time
            print(f"[NB201] Loaded {len(self._arch_keys)} architectures in {elapsed:.2f}s")
            if fallback_keys is not None:
                print(f"[NB201] Note: top-level keys: {fallback_keys}")
                print("[NB201] Note: These may be metadata keys, not architectures")

    def _read_with_progress(self, handle, size: int) -> bytearray:
        """Read ``handle`` to EOF in chunks while updating a tqdm progress bar."""
        raw = bytearray()
        bar = tqdm(total=size, unit='B', unit_scale=True, desc='Reading')
        try:
            for chunk in iter(lambda: handle.read(self._READ_CHUNK_BYTES), b''):
                raw.extend(chunk)
                bar.update(len(chunk))
        finally:
            # Close the bar even if a read fails so the terminal is restored.
            bar.close()
        return raw

    def _build_arch_index(self) -> Optional[list]:
        """Populate ``_arch_keys`` and the index <-> arch-string mappings.

        Returns:
            The top-level keys when the data is a dict *without*
            ``'arch2infos'`` and has fewer than 100 keys (so they can be shown
            to the user as a hint); otherwise None.
        """
        self._arch_keys = []
        self._idx_to_str = {}
        self._str_to_idx = {}

        if not isinstance(self.data, dict):
            return None

        if 'arch2infos' not in self.data:
            # Fallback to top-level keys
            self._arch_keys = list(self.data.keys())
            return self._arch_keys if len(self._arch_keys) < 100 else None

        arch2infos = self.data['arch2infos']
        # Ensure indices are integers
        self._arch_keys = [int(k) for k in arch2infos.keys()]

        # Build index<->arch_str mappings when available. Entries are taken
        # from ``items()`` so the stored key is used for access; looking the
        # entry up again by the int-converted key would miss string keys.
        try:
            for raw_key, entry in arch2infos.items():
                if not isinstance(entry, dict):
                    continue
                full = entry.get('full')
                if not isinstance(full, dict):
                    continue
                arch_str = full.get('arch_str')
                if isinstance(arch_str, str):
                    idx = int(raw_key)
                    self._idx_to_str[idx] = arch_str
                    # If duplicates exist, keep the first seen mapping
                    self._str_to_idx.setdefault(arch_str, idx)
        except Exception as e:  # best effort: the mappings are optional
            if self.verbose:
                print(f"Warning: failed to build NB201 arch_str mappings: {e}")
        return None

    def get_statistics(self) -> Dict[str, Any]:
        """Return a small summary of the loaded data.

        Returns:
            ``{'benchmark': 'nasbench201', 'entries': n}`` where ``n`` is the
            size of ``'arch2infos'`` when present, else the number of
            top-level keys of a dict, else None.
        """
        entries = None
        if isinstance(self.data, dict):
            if 'arch2infos' in self.data:
                entries = len(self.data['arch2infos'])
            else:
                entries = len(self.data)
        return {
            'benchmark': 'nasbench201',
            'entries': entries,
        }

    def random_sample(self, n: int = 1, seed: Optional[int] = None) -> List[str]:
        """Sample random architectures from NB201 search space.

        When architecture keys were loaded, up to ``n`` distinct keys are
        drawn without replacement (fewer if fewer are available). Otherwise
        ``n`` indices are drawn uniformly, with replacement, from the full
        index space.

        Args:
            n: Number of samples to return.
            seed: Optional random seed. A seeded call uses its own generator,
                so the state of the global ``random`` module is left alone.

        Returns:
            List of sampled architecture strings in NB201 canonical format.
        """
        # random.Random(seed) produces the same stream as seeding the module
        # level generator, without the global side effect.
        rng = random.Random(seed) if seed is not None else random

        if self._arch_keys:
            # Sample from loaded architectures
            count = min(n, len(self._arch_keys))
            idxs = rng.sample(self._arch_keys, count)
        else:
            # Sample uniformly from the full index space (0..15624)
            max_idx = (self.NUM_OPS ** self.NUM_EDGES) - 1
            idxs = [rng.randint(0, max_idx) for _ in range(n)]

        return [self._arch_str_for(i) for i in idxs]

    def iter_all(self) -> Iterator[str]:
        """Iterate over all architectures in the loaded data.

        Returns:
            Iterator over architecture strings (empty when nothing is loaded).
        """
        if not self._arch_keys:
            return iter(())
        return (self._arch_str_for(i) for i in self._arch_keys)

    def _arch_str_for(self, arch_idx) -> str:
        """Return the stored arch string for ``arch_idx``, decoding it if unknown.

        The decode step runs only when no stored string exists, so a key that
        has a stored string never triggers range validation.
        """
        known = self._idx_to_str.get(arch_idx)
        if known is not None:
            return known
        return self._index_to_arch_str(arch_idx)

    def get_index(self, arch: str) -> int:
        """Convert an architecture string to its canonical index.

        Args:
            arch: NB201 architecture string
                (e.g., '|none~0|+|skip_connect~0|nor_conv_1x1~1|+|...')

        Returns:
            Integer index (0..15624) corresponding to the architecture.

        Raises:
            ValueError: If the architecture string is invalid or cannot be parsed.
        """
        # Try mapping first (faster if loaded)
        mapped = self._str_to_idx.get(arch)
        if mapped is not None:
            return mapped
        # Otherwise decode from string
        return self._arch_str_to_index(arch)

    def query(self, arch: str, dataset: str = 'cifar10', split: str = 'val',
              seed: Optional[int] = None, budget: Optional[Any] = None) -> Dict[str, Any]:
        """Query performance metrics for an architecture from loaded data.

        For ``'val'``/``'valid'`` and ``'test'`` the requested epoch is tried
        first, then epochs 199 and 200. A ``'test'`` query that finds nothing
        falls back to the validation results. Whenever such a fallback
        supplies the metric, ``info`` additionally carries ``'epoch_used'``
        and/or ``'split_used'`` so the caller can tell.

        Args:
            arch: NB201 architecture string.
            dataset: Dataset name ('cifar10', 'cifar100', 'ImageNet16-120').
            split: Split name ('val', 'test', or 'train').
            seed: Optional seed (default: 777 for official NB201).
            budget: Optional epoch number (0-199, default: 199 for final epoch).

        Returns:
            Dictionary with keys: metric, metric_name, cost, std, info.
            ``metric`` and ``cost`` are None when nothing matching was found.
        """
        metric_name = f'{split}_acc'

        # Convert to index; an unparsable architecture is reported as
        # "not found" below instead of raising.
        try:
            arch_idx = self.get_index(arch)
        except (ValueError, TypeError, AttributeError):
            arch_idx = None

        if not isinstance(self.data, dict) or 'arch2infos' not in self.data:
            return self._make_result(
                None, None, metric_name,
                {'note': 'NB201 data not loaded', 'arch': arch},
            )

        arch_data = self._lookup_entry(arch_idx) if arch_idx is not None else None
        if arch_data is None:
            return self._make_result(
                None, None, metric_name,
                {'error': 'architecture not found', 'arch': arch},
            )

        if seed is None:
            seed = self._DEFAULT_SEED
        if budget is None:
            budget = self._DEFAULT_EPOCH

        # Build clean info dict with only essential metadata
        info = {
            'arch_index': arch_idx,
            'dataset': dataset,
            'split': split,
            'seed': seed,
            'epoch': budget,
        }

        full_data = arch_data.get('full') if isinstance(arch_data, dict) else None
        if not isinstance(full_data, dict):
            full_data = {}

        # Add architecture string if available
        if 'arch_str' in full_data:
            info['arch_str'] = full_data['arch_str']
        elif isinstance(arch, str):
            info['arch_str'] = arch

        # Add basic arch info
        if 'arch_config' in full_data:
            info['params'] = full_data.get('params')
            info['flop'] = full_data.get('flop')

        # Results are keyed by (dataset name, seed) tuples.
        all_results = full_data.get('all_results') or {}
        valid_key = (f'{dataset}-valid', seed)
        metric = None
        cost = None

        if split == 'train':
            metric, cost = self._lookup_train(all_results, (dataset, seed), budget)
        elif split in ('val', 'valid', 'test'):
            if split == 'test':
                result_key, prefix = (dataset, seed), 'ori-test@'
            else:
                result_key, prefix = valid_key, 'x-valid@'

            budgets = self._candidate_budgets(budget)
            hit = self._lookup_eval(all_results, result_key, prefix, budgets)
            # Fallback: if requesting test and not found, try validation
            if hit is None and split == 'test':
                hit = self._lookup_eval(all_results, valid_key, 'x-valid@', budgets)
                if hit is not None:
                    info['split_used'] = 'val'
            if hit is not None:
                metric, cost, epoch_used = hit
                if epoch_used != budget:
                    info['epoch_used'] = epoch_used

        return self._make_result(metric, cost, metric_name, info)

    # --- Internal query helpers ---------------------------------------------------

    @staticmethod
    def _make_result(metric, cost, metric_name: str, info: Dict[str, Any]) -> Dict[str, Any]:
        """Assemble the result dict returned by ``query``."""
        return {
            'metric': float(metric) if metric is not None else None,
            'metric_name': metric_name,
            'cost': float(cost) if cost is not None else None,
            'std': None,
            'info': info,
        }

    def _lookup_entry(self, arch_idx: int):
        """Fetch the ``arch2infos`` record for ``arch_idx`` (int or string key)."""
        arch2infos = self.data['arch2infos']
        entry = arch2infos.get(arch_idx)
        if entry is None:
            # The index is always handled as an int here, but the pickle may
            # have stored its keys as strings.
            entry = arch2infos.get(str(arch_idx))
        return entry

    @staticmethod
    def _candidate_budgets(budget) -> list:
        """Return epochs to try: the requested one first, then 199 and 200."""
        budgets = [budget] if isinstance(budget, int) else []
        # common final epochs in NB201
        for final_epoch in (199, 200):
            if final_epoch not in budgets:
                budgets.append(final_epoch)
        return budgets

    @staticmethod
    def _lookup_train(all_results, result_key, budget):
        """Return ``(train accuracy, train time)`` at epoch ``budget``.

        Either element is None when the corresponding value is missing.
        """
        if result_key not in all_results:
            return None, None
        result = all_results[result_key]
        accs = result.get('train_acc1es') or {}
        times = result.get('train_times') or {}
        metric = accs[budget] if budget in accs else None
        cost = times[budget] if budget in times else None
        return metric, cost

    @staticmethod
    def _lookup_eval(all_results, result_key, key_prefix: str, budgets):
        """Return ``(accuracy, time, epoch)`` for the first epoch with a result.

        Args:
            all_results: Mapping of ``(dataset, seed)`` -> result dict.
            result_key: The ``(dataset, seed)`` tuple to look in.
            key_prefix: Prefix of the metric key (e.g. ``'x-valid@'``); the
                epoch number is appended to it.
            budgets: Epochs to try, in priority order.

        Returns:
            A 3-tuple, or None when ``result_key`` is absent or no epoch
            matched. ``time`` is None when no time was recorded for the epoch.
        """
        if result_key not in all_results:
            return None
        result = all_results[result_key]
        eval_acc = result.get('eval_acc1es', {})
        eval_times = result.get('eval_times', {})
        for epoch in budgets:
            metric_key = f"{key_prefix}{epoch}"
            if metric_key in eval_acc:
                return eval_acc[metric_key], eval_times.get(metric_key), epoch
        return None

    # --- Internal encoding helpers -------------------------------------------------

    def _index_to_arch_str(self, arch_idx: int) -> str:
        """Convert an architecture index (0..15624) to NB201 arch string.

        Uses a fixed edge order: (1<-0), (2<-0), (2<-1), (3<-0), (3<-1), (3<-2).

        Raises:
            ValueError: If ``arch_idx`` is outside the index space.
        """
        ops = [self.OPS[op_id] for op_id in self._index_to_ops(arch_idx)]
        return '|{}~0|+|{}~0|{}~1|+|{}~0|{}~1|{}~2|'.format(*ops)

    def _arch_str_to_index(self, arch_str: str) -> int:
        """Convert a NB201 arch string to its canonical index (0..15624)."""
        return self._ops_to_index(self._arch_str_to_ops(arch_str))

    def _index_to_ops(self, idx: int) -> list:
        """Convert index to a list of 6 op IDs (0..4) in canonical order.

        The index is read as a base-``NUM_OPS`` number whose least-significant
        digit is edge 0.

        Raises:
            ValueError: If ``idx`` is negative or too large.
        """
        if idx < 0 or idx >= (self.NUM_OPS ** self.NUM_EDGES):
            raise ValueError(f"NB201 index out of range: {idx}")
        op_ids = []
        remaining = idx
        for _ in range(self.NUM_EDGES):
            remaining, digit = divmod(remaining, self.NUM_OPS)
            op_ids.append(digit)
        return op_ids

    def _ops_to_index(self, op_ids: list) -> int:
        """Convert a list of 6 op IDs (0..4) to canonical index.

        Raises:
            ValueError: If the list length is not 6 or an ID is out of range.
        """
        if len(op_ids) != self.NUM_EDGES:
            raise ValueError("NB201 requires 6 operation IDs")
        idx = 0
        place_value = 1
        for raw_id in op_ids:
            oid = int(raw_id)
            if oid < 0 or oid >= self.NUM_OPS:
                raise ValueError(f"Invalid op id {oid} for NB201")
            idx += oid * place_value
            place_value *= self.NUM_OPS
        return idx

    def _arch_str_to_ops(self, arch_str: str) -> list:
        """Parse a NB201 arch string into 6 operation IDs in canonical order.

        Only the operation names are used; the ``~k`` source-node suffixes are
        not checked, so tokens are assumed to already be in canonical order.

        Raises:
            ValueError: If 6 ``op~k`` tokens cannot be found or an operation
                name is not in ``OPS``.
        """
        # Extract tokens like 'op~0', 'op~1', ... in order by treating the
        # '|' and '+' separators as whitespace.
        flattened = arch_str.replace('|', ' ').replace('+', ' ')
        tokens = [tok for tok in flattened.split() if '~' in tok]

        if len(tokens) != self.NUM_EDGES:
            # Second attempt: take the '|'-terminated segments as tokens.
            # Text after the last '|' is not terminated and is ignored.
            segments = (seg.strip() for seg in arch_str.split('|')[:-1])
            tokens = [seg for seg in segments if '~' in seg]
            if len(tokens) != self.NUM_EDGES:
                raise ValueError(f"Cannot parse NB201 arch_str: {arch_str}")

        op_ids = []
        for token in tokens:
            op = token.split('~')[0]
            if op not in self.OPS:
                raise ValueError(f"Unknown NB201 op '{op}' in arch_str")
            op_ids.append(self.OPS.index(op))
        return op_ids