Source code for nasbenchapi.nasbench301_api


import pickle
import random
import hashlib
import json
import time
from pathlib import Path
from dataclasses import dataclass
from typing import Dict, Any, Optional, List, Iterator, Tuple

from .utils import resolve_path, sizeof_fmt, display_path

try:
    # Optional import to minimize pip overhead
    from tqdm import tqdm
    HAS_TQDM = True
except ImportError:
    HAS_TQDM = False


[docs] class NASBench301: """NASBench-301 API. Supports directory-converted JSON payloads or reserialized pickle. """ # NB301 uses DARTS-style search space OPS = [ 'max_pool_3x3', 'avg_pool_3x3', 'skip_connect', 'sep_conv_3x3', 'sep_conv_5x5', 'dil_conv_3x3', 'dil_conv_5x5', 'none' ] NUM_OPS = len(OPS) NUM_NODES = 4 # Intermediate nodes in the cell NUM_EDGES_PER_NODE = 2 # Each node has 2 input edges
[docs] def __init__(self, pickle_path: Optional[str] = None, verbose: bool = True): self.path = resolve_path('301', pickle_path) self.verbose = verbose self.data: Any = None self._arch_keys = [] self._path_to_index: Dict[str, int] = {} self._load()
def _load(self) -> None: start_time = time.perf_counter() size = self.path.stat().st_size if self.verbose: print(f"Loading NB301 from {display_path(self.path)} ({sizeof_fmt(size)})") with open(self.path, 'rb') as f: if HAS_TQDM and self.verbose and size > 0: bar = tqdm(total=size, unit='B', unit_scale=True, desc='Reading') raw = bytearray() chunk = f.read(1024 * 1024) while chunk: raw.extend(chunk) bar.update(len(chunk)) chunk = f.read(1024 * 1024) bar.close() # Unpickling stage if self.verbose: print("Unpickling data...") self.data = pickle.loads(bytes(raw)) if self.verbose: print("Unpickling complete.") else: if self.verbose: print("Unpickling data...") self.data = pickle.load(f) if self.verbose: print("Unpickling complete.") entry_count = None # Cache architecture keys if data is a dict if isinstance(self.data, dict): if 'entries' in self.data: entries = self.data.get('entries', []) self._arch_keys = list(range(len(entries))) entry_count = len(self._arch_keys) self._path_to_index = {} for idx, entry in enumerate(entries): if isinstance(entry, dict): path = entry.get('path') if isinstance(path, str): self._path_to_index[path] = idx else: self._arch_keys = list(self.data.keys()) entry_count = len(self._arch_keys) self._path_to_index = {} else: self._arch_keys = [] self._path_to_index = {} loaded_count = len(self._arch_keys) if entry_count is None: entry_count = loaded_count if loaded_count else None if self.verbose: elapsed = time.perf_counter() - start_time if loaded_count: print(f"[NB301] Loaded {loaded_count} entries in {elapsed:.2f}s") else: print(f"[NB301] Loaded NB301 payload in {elapsed:.2f}s")
[docs] def get_statistics(self) -> Dict[str, Any]: if isinstance(self.data, dict) and 'entries' in self.data: return { 'benchmark': 'nasbench301', 'files': self.data.get('num_files', len(self.data.get('entries', []))), } return { 'benchmark': 'nasbench301', 'files': None, }
[docs] def random_sample(self, n: int = 1, seed: Optional[int] = None): """Sample random architectures from NB301 search space. NB301 uses DARTS-style search space which is continuous, but we can sample discrete architectures from it. Args: n: Number of samples to return. seed: Optional random seed. Returns: List of sampled architecture representations. """ import random as rnd if seed is not None: rnd.seed(seed) if self._arch_keys: # Sample from loaded architectures return rnd.sample(self._arch_keys, min(n, len(self._arch_keys))) else: # Generate random DARTS-style architectures # Each cell has 4 nodes, each node selects 2 operations from previous nodes samples = [] for _ in range(n): # Normal cell and reduction cell normal = [] reduce = [] for node_idx in range(self.NUM_NODES): # For each node, select 2 edges (predecessors and operations) for _ in range(self.NUM_EDGES_PER_NODE): # Can connect to any previous node (0 to node_idx+1) pred = rnd.randint(0, node_idx + 1) op = rnd.choice(self.OPS) normal.append((op, pred)) reduce.append((op, pred)) samples.append({'normal': normal, 'reduce': reduce}) return samples
[docs] def iter_all(self): """Iterate over all architectures in the loaded data. Returns: Iterator over architecture keys/indices. """ if self._arch_keys: return iter(self._arch_keys) return iter(())
[docs] def get_index(self, arch: Any) -> Optional[int]: """Get the index for an architecture in the loaded data. Args: arch: Architecture representation (dict, index, or path string). Returns: Integer index if architecture is found in loaded data, None otherwise. """ return self._index_from_arch(arch)
[docs] def query(self, arch: Any, dataset: str = 'cifar10', split: str = 'val', seed: Optional[int] = None, budget: Optional[Any] = None) -> Dict[str, Any]: """Query performance metrics for an architecture from loaded data. Note: NB301 directory payloads include recorded learning curves and final metrics. Validation queries read from the stored curves; test queries report the declared final accuracy for each entry. Args: arch: Architecture representation (dict or index). dataset: Dataset name ('cifar10', 'cifar100'). split: Split name ('val', 'test'). seed: Optional seed for reproducibility (unused). budget: Optional budget specification (unused). Returns: Dictionary with keys: metric, metric_name, cost, std, info. """ metric_name = f'{split}_acc' info: Dict[str, Any] = { 'requested_dataset': dataset, 'split': split, } entries = self.data.get('entries') if isinstance(self.data, dict) else None if not isinstance(entries, list): info['error'] = 'NB301 data not loaded' return { 'metric': None, 'metric_name': metric_name, 'cost': None, 'std': None, 'info': info, } idx = self._index_from_arch(arch) if idx is None or idx < 0 or idx >= len(entries): info.update({'error': 'entry not found', 'arch': arch}) return { 'metric': None, 'metric_name': metric_name, 'cost': None, 'std': None, 'info': info, } entry = entries[idx] parsed = entry.get('parsed', {}) if isinstance(entry, dict) else {} dataset_actual = self._infer_dataset(parsed) if dataset_actual: info['dataset'] = dataset_actual path = entry.get('path') if isinstance(entry, dict) else None if isinstance(path, str): info['path'] = path optimizer = parsed.get('created_by_optimizer') if optimizer: info['optimizer'] = optimizer declared_budget = parsed.get('budget') if declared_budget is not None: info['declared_budget'] = declared_budget cost = parsed.get('runtime') try: cost_value = float(cost) if cost is not None else None except (TypeError, ValueError): cost_value = None metric = None epoch_used: Optional[int] = None epochs_available: Optional[int] = None if split == 'val': curve = self._get_learning_curve(parsed, 'Train/val_accuracy') metric, epoch_used, epochs_available = self._curve_metric(curve, budget) if metric is None: metric = self._final_metric(parsed, 'val') if epochs_available is not None: info['epochs_available'] = epochs_available if epoch_used is not None: info['epoch_used'] = epoch_used elif split == 'test': metric = self._final_metric(parsed, 'test') if declared_budget is not None: info['epoch_used'] = declared_budget else: info['error'] = f"split '{split}' not supported for NB301" return { 'metric': None, 'metric_name': metric_name, 'cost': cost_value, 'std': None, 'info': info, } if dataset_actual and dataset and dataset_actual != dataset: info['note'] = f"entry recorded on dataset '{dataset_actual}'" try: metric_value = float(metric) if metric is not None else None except (TypeError, ValueError): metric_value = None return { 'metric': metric_value, 'metric_name': metric_name, 'cost': cost_value, 'std': None, 'info': info, }
# --- Internal helpers ------------------------------------------------- def _index_from_arch(self, arch: Any) -> Optional[int]: if not isinstance(self.data, dict) or 'entries' not in self.data: return None entries = self.data.get('entries', []) if not isinstance(entries, list): return None if isinstance(arch, int): return arch if 0 <= arch < len(entries) else None if isinstance(arch, str): return self._path_to_index.get(arch) if isinstance(arch, dict): idx = arch.get('index') if isinstance(idx, int) and 0 <= idx < len(entries): return idx path = arch.get('path') if isinstance(path, str) and path in self._path_to_index: return self._path_to_index[path] # Direct match if arch in entries: return entries.index(arch) parsed = arch.get('parsed') if parsed is not None: for cand_idx, candidate in enumerate(entries): if isinstance(candidate, dict) and candidate.get('parsed') == parsed: return cand_idx return None def _infer_dataset(self, parsed: Dict[str, Any]) -> Optional[str]: info_list = parsed.get('info') if isinstance(info_list, list) and info_list: first = info_list[0] if isinstance(first, dict): path = first.get('dataset_path', '') if isinstance(path, str): lowered = path.lower() if 'cifar100' in lowered: return 'cifar100' if 'cifar10' in lowered: return 'cifar10' if 'imagenet16-120' in lowered or 'imagenet16' in lowered: return 'ImageNet16-120' return None def _get_learning_curve(self, parsed: Dict[str, Any], key: str) -> List[float]: curves = parsed.get('learning_curves') if isinstance(curves, dict): curve = curves.get(key) if isinstance(curve, list): numeric = [] for v in curve: try: numeric.append(float(v)) except (TypeError, ValueError): continue return numeric return [] def _curve_metric(self, curve: List[float], budget: Optional[Any]) -> Tuple[Optional[float], Optional[int], Optional[int]]: if not curve: return None, None, None total_epochs = len(curve) epoch_used = total_epochs if budget is not None: try: epoch = int(budget) if epoch < 1: return None, None, total_epochs if epoch > total_epochs: epoch_used = total_epochs else: epoch_used = epoch except (TypeError, ValueError): epoch_used = total_epochs value = curve[epoch_used - 1] try: metric = float(value) except (TypeError, ValueError): metric = None return metric, epoch_used, total_epochs def _final_metric(self, parsed: Dict[str, Any], split: str) -> Optional[float]: info_list = parsed.get('info') record = info_list[0] if isinstance(info_list, list) and info_list and isinstance(info_list[0], dict) else {} if split == 'val': for key in ('val_accuracy_final', 'val_accuracy'): value = record.get(key) if value is not None: try: return float(value) except (TypeError, ValueError): continue elif split == 'test': value = parsed.get('test_accuracy') if value is None and isinstance(record, dict): value = record.get('test_accuracy') if value is not None: try: return float(value) except (TypeError, ValueError): return None return None