from train import train, Losses
import priors
import encoders

from collections import defaultdict
from priors.utils import trunc_norm_sampler_f, gamma_sampler_f
from utils import get_uniform_single_eval_pos_sampler
import torch
import math
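
# Helpers for saving/loading transformer checkpoints and for building a model
# (prior get_batch pipeline, encoders, loss and training loop) from a config dict.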

def save_model(model, path, filename, config_sample):
    config_sample = {**config_sample}

    def make_serializable(config_sample):
        if isinstance(config_sample, dict):
            config_sample = {k: make_serializable(config_sample[k]) for k in config_sample}
        if isinstance(config_sample, list):
            config_sample = [make_serializable(v) for v in config_sample]
        if callable(config_sample):
            config_sample = str(config_sample)
        return config_sample

    #if 'num_features_used' in config_sample:
    #    del config_sample['num_features_used']
    #config_sample['num_classes_as_str'] = str(config_sample['num_classes'])
    #del config_sample['num_classes']

    config_sample = make_serializable(config_sample)

    torch.save((model.state_dict(), None, config_sample), os.path.join(path, filename))

import subprocess as sp
import os  # os is also used by save_model above; it is in scope there because this import runs at module load time


def get_gpu_memory():
    command = "nvidia-smi"
    memory_free_info = sp.check_output(command.split()).decode('ascii')
    return memory_free_info
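

# load_model rebuilds the architecture via get_model(should_train=False) and then restores
# the saved weights. To do so it rewrites the loaded config for cheap, evaluation-only
# construction (batch_size=1, 2 classes, bptt=10) while keeping the original training
# values under *_in_training keys.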
def load_model(path, filename, device, eval_positions, verbose):
    # TODO: This function only restores evaluation functionality, but training can't be continued. It is also not flexible.

    model_state, optimizer_state, config_sample = torch.load(
        os.path.join(path, filename), map_location='cpu')

    if ('differentiable_hyperparameters' in config_sample
            and 'prior_mlp_activations' in config_sample['differentiable_hyperparameters']):
        config_sample['differentiable_hyperparameters']['prior_mlp_activations']['choice_values_used'] = \
            config_sample['differentiable_hyperparameters']['prior_mlp_activations']['choice_values']
        config_sample['differentiable_hyperparameters']['prior_mlp_activations']['choice_values'] = [
            torch.nn.Tanh for k in config_sample['differentiable_hyperparameters']['prior_mlp_activations']['choice_values']]

    config_sample['categorical_features_sampler'] = lambda: lambda x: ([], [], [])
    config_sample['num_features_used_in_training'] = config_sample['num_features_used']
    config_sample['num_features_used'] = lambda: config_sample['num_features']
    config_sample['num_classes_in_training'] = config_sample['num_classes']
    config_sample['num_classes'] = 2
    config_sample['batch_size_in_training'] = config_sample['batch_size']
    config_sample['batch_size'] = 1
    config_sample['bptt_in_training'] = config_sample['bptt']
    config_sample['bptt'] = 10
    config_sample['bptt_extra_samples_in_training'] = config_sample['bptt_extra_samples']
    config_sample['bptt_extra_samples'] = None

    #print('Memory', str(get_gpu_memory()))

    model = get_model(config_sample, device=device, should_train=False, verbose=verbose)
    module_prefix = 'module.'
    model_state = {k.replace(module_prefix, ''): v for k, v in model_state.items()}
    model[2].load_state_dict(model_state)
    model[2].to(device)

    return model, config_sample
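
# Example (illustrative sketch, not part of this module): restoring a checkpoint for
# evaluation only. The directory and filename below are placeholders; train() returns a
# tuple whose element at index 2 is the transformer module.
#
#   model, config_sample = load_model('checkpoints', 'my_model.cpkt', device='cpu',
#                                     eval_positions=[1000], verbose=0)
#   transformer = model[2]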

def fix_loaded_config_sample(loaded_config_sample, config):
    def copy_to_sample(*k):
        t, s = loaded_config_sample, config
        for k_ in k[:-1]:
            t = t[k_]
            s = s[k_]
        t[k[-1]] = s[k[-1]]

    copy_to_sample('num_features_used')
    copy_to_sample('num_classes')
    copy_to_sample('differentiable_hyperparameters', 'prior_mlp_activations', 'choice_values')

def load_config_sample(path, template_config):
    model_state, optimizer_state, loaded_config_sample = torch.load(path, map_location='cpu')
    fix_loaded_config_sample(loaded_config_sample, template_config)
    return loaded_config_sample

def get_default_spec(test_datasets, valid_datasets):
    bptt = 10000
    eval_positions = [1000, 2000, 3000, 4000, 5000]  # list(2 ** np.array([4, 5, 6, 7, 8, 9, 10, 11, 12]))
    max_features = max([X.shape[1] for (_, X, _, _, _, _) in test_datasets]
                       + [X.shape[1] for (_, X, _, _, _, _) in valid_datasets])
    max_splits = 5

    return bptt, eval_positions, max_features, max_splits
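
# The get_*_prior_hyperparameters helpers translate a (sampled) config dict into the
# hyperparameters expected by the corresponding prior. Most of them collapse dict-valued
# entries to their first value, and gamma / truncated-normal parameters are turned into
# sampler callables.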

def get_mlp_prior_hyperparameters(config):
    config = {hp: (list(config[hp].values())[0]) if type(config[hp]) is dict else config[hp] for hp in config}

    if "prior_sigma_gamma_k" in config:
        sigma_sampler = gamma_sampler_f(config["prior_sigma_gamma_k"], config["prior_sigma_gamma_theta"])
        config['init_std'] = sigma_sampler
    if "prior_noise_std_gamma_k" in config:
        noise_std_sampler = gamma_sampler_f(config["prior_noise_std_gamma_k"], config["prior_noise_std_gamma_theta"])
        config['noise_std'] = noise_std_sampler

    return config

def get_gp_mix_prior_hyperparameters(config):
    # NOTE: 'categorical_data' and 'y_minmax_norm' are filled from 'prior_y_minmax_norm'
    # and 'prior_lengthscale_concentration' respectively; verify that this mapping is intended.
    return {'lengthscale_concentration': config["prior_lengthscale_concentration"],
            'nu': config["prior_nu"],
            'outputscale_concentration': config["prior_outputscale_concentration"],
            'categorical_data': config["prior_y_minmax_norm"],
            'y_minmax_norm': config["prior_lengthscale_concentration"],
            'noise_concentration': config["prior_noise_concentration"],
            'noise_rate': config["prior_noise_rate"]}

def get_gp_prior_hyperparameters(config):
    return {hp: (list(config[hp].values())[0]) if type(config[hp]) is dict else config[hp] for hp in config}

def get_meta_gp_prior_hyperparameters(config):
    config = {hp: (list(config[hp].values())[0]) if type(config[hp]) is dict else config[hp] for hp in config}

    if "outputscale_mean" in config:
        outputscale_sampler = trunc_norm_sampler_f(config["outputscale_mean"],
                                                   config["outputscale_mean"] * config["outputscale_std_f"])
        config['outputscale'] = outputscale_sampler
    if "lengthscale_mean" in config:
        lengthscale_sampler = trunc_norm_sampler_f(config["lengthscale_mean"],
                                                   config["lengthscale_mean"] * config["lengthscale_std_f"])
        config['lengthscale'] = lengthscale_sampler

    return config
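
# get_model builds the prior sampling pipeline (prior_bag / mlp / gp / gp_mix, optionally
# wrapped by flexible_categorical and differentiable_prior), chooses the input encoder and
# loss from the config, and hands everything to train(). With should_train=False the
# training loop runs for 0 epochs, so only the model is constructed.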

def get_model(config, device, should_train=True, verbose=False, state_dict=None, epoch_callback=None):
    extra_kwargs = {}
    verbose_train, verbose_prior = verbose >= 1, verbose >= 2
    config['verbose'] = verbose_prior

    if 'aggregate_k_gradients' not in config or config['aggregate_k_gradients'] is None:
        config['aggregate_k_gradients'] = math.ceil(config['batch_size'] * ((config['nlayers'] * config['emsize'] * config['bptt'] * config['bptt']) / 10824640000))

    config['num_steps'] = math.ceil(config['num_steps'] * config['aggregate_k_gradients'])
    config['batch_size'] = math.ceil(config['batch_size'] / config['aggregate_k_gradients'])
    config['recompute_attn'] = config['recompute_attn'] if 'recompute_attn' in config else False
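
    # make_get_batch binds a prior module (and optionally an inner get_batch) into a closure
    # with a uniform get_batch(batch_size, seq_len, num_features, hyperparameters, device)
    # signature, so priors can be nested (e.g. flexible_categorical wrapping mlp or fast_gp).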
    def make_get_batch(model_proto, **extra_kwargs):
        extra_kwargs = defaultdict(lambda: None, **extra_kwargs)
        return (lambda batch_size, seq_len, num_features, hyperparameters
                , device, model_proto=model_proto, get_batch=extra_kwargs['get_batch']
                , prior_bag_priors=extra_kwargs['prior_bag_priors']: model_proto.get_batch(
                    batch_size=batch_size
                    , seq_len=seq_len
                    , device=device
                    , get_batch=get_batch
                    , hyperparameters=hyperparameters
                    , num_features=num_features))

    if config['prior_type'] == 'prior_bag':
        # Prior bag combines priors
        get_batch_gp = make_get_batch(priors.fast_gp)
        get_batch_mlp = make_get_batch(priors.mlp)
        if 'flexible' in config and config['flexible']:
            get_batch_gp = make_get_batch(priors.flexible_categorical, **{'get_batch': get_batch_gp})
            get_batch_mlp = make_get_batch(priors.flexible_categorical, **{'get_batch': get_batch_mlp})
        prior_bag_hyperparameters = {'prior_bag_get_batch': (get_batch_gp, get_batch_mlp)
                                     , 'prior_bag_exp_weights_1': 2.0}
        prior_hyperparameters = {**get_mlp_prior_hyperparameters(config), **get_gp_prior_hyperparameters(config)
                                 , **prior_bag_hyperparameters}
        model_proto = priors.prior_bag
    else:
        if config['prior_type'] == 'mlp':
            prior_hyperparameters = get_mlp_prior_hyperparameters(config)
            model_proto = priors.mlp
        elif config['prior_type'] == 'gp':
            prior_hyperparameters = get_gp_prior_hyperparameters(config)
            model_proto = priors.fast_gp
        elif config['prior_type'] == 'gp_mix':
            prior_hyperparameters = get_gp_mix_prior_hyperparameters(config)
            model_proto = priors.fast_gp_mix
        else:
            raise Exception(f"Unknown prior_type {config['prior_type']}")

        if 'flexible' in config and config['flexible']:
            get_batch_base = make_get_batch(model_proto)
            extra_kwargs['get_batch'] = get_batch_base
            model_proto = priors.flexible_categorical

    use_style = False

    if 'differentiable' in config and config['differentiable']:
        get_batch_base = make_get_batch(model_proto, **extra_kwargs)
        extra_kwargs = {'get_batch': get_batch_base, 'differentiable_hyperparameters': config['differentiable_hyperparameters']}
        model_proto = priors.differentiable_prior
        use_style = True
    print(f"Using style prior: {use_style}")

    if (('nan_prob_no_reason' in config and config['nan_prob_no_reason'] > 0.0) or
            ('nan_prob_a_reason' in config and config['nan_prob_a_reason'] > 0.0) or
            ('nan_prob_unknown_reason' in config and config['nan_prob_unknown_reason'] > 0.0)):
        encoder = encoders.NanHandlingEncoder
    else:
        encoder = encoders.Linear

    num_outputs = config['num_outputs'] if 'num_outputs' in config else 1
    if config['max_num_classes'] == 2:
        if 'joint_loss' in config and config['joint_loss']:
            loss = JointBCELossWithLogits  # not imported in this file; must be in scope when used
        else:
            loss = Losses.bce
    elif config['max_num_classes'] > 2:
        loss = Losses.ce(torch.ones((config['max_num_classes'])))
    else:
        # Fewer than 2 classes is treated as regression with a bar distribution over (-10, 10).
        # BarDistribution and get_bucket_limits are not imported in this file and must be in
        # scope when used (e.g. from the bar_distribution module).
        loss = BarDistribution(borders=get_bucket_limits(500, full_range=(-10, 10)))

    aggregate_k_gradients = 1 if 'aggregate_k_gradients' not in config else config['aggregate_k_gradients']
    check_is_compatible = False if 'multiclass_loss_type' not in config else (config['multiclass_loss_type'] == 'compatible')
    config['multiclass_type'] = config['multiclass_type'] if 'multiclass_type' in config else 'rank'
    config['mix_activations'] = config['mix_activations'] if 'mix_activations' in config else False

    config['bptt_extra_samples'] = config['bptt_extra_samples'] if 'bptt_extra_samples' in config else None
    config['eval_positions'] = [int(config['bptt'] * 0.95)] if config['bptt_extra_samples'] is None else [int(config['bptt'])]

    epochs = 0 if not should_train else config['epochs']
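    # With epochs == 0 (should_train=False), train() only assembles the DataLoader, encoders
    # and transformer and returns them without running any optimization steps.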
    model = train(model_proto.DataLoader
                  , loss
                  , encoder
                  , style_encoder_generator=encoders.StyleEncoder if use_style else None
                  , emsize=config['emsize']
                  , nhead=config['nhead']
                  , y_encoder_generator=encoders.get_Canonical(config['max_num_classes']) if config.get('canonical_y_encoder', False) else encoders.Linear
                  , pos_encoder_generator=None
                  , batch_size=config['batch_size']
                  , nlayers=config['nlayers']
                  , nhid=config['emsize'] * config['nhid_factor']
                  , epochs=epochs
                  , total_available_time_in_s=config.get('total_available_time_in_s', None)
                  , warmup_epochs=20
                  , bptt=config['bptt']
                  , gpu_device=device
                  , dropout=config['dropout']
                  , steps_per_epoch=config['num_steps']
                  , single_eval_pos_gen=get_uniform_single_eval_pos_sampler(config['bptt'])
                  , load_weights_from_this_state_dict=state_dict
                  , aggregate_k_gradients=aggregate_k_gradients
                  , check_is_compatible=check_is_compatible
                  , recompute_attn=config['recompute_attn']
                  , epoch_callback=epoch_callback
                  , bptt_extra_samples=config['bptt_extra_samples']
                  , extra_prior_kwargs_dict={
                      'num_features': config['num_features']
                      , 'fuse_x_y': False
                      , 'hyperparameters': prior_hyperparameters
                      , 'num_outputs': num_outputs
                      , 'dynamic_batch_size': 1 if ('num_global_att_tokens' in config and config['num_global_att_tokens']) else 2
                      , **extra_kwargs
                  }
                  , lr=config['lr']
                  , verbose=verbose_train
                  , weight_decay=config.get('weight_decay', 0.0)
                  , normalize_labels=True)

    return model
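

# Example (illustrative sketch, not part of this module): building and saving a model.
# The config dict must provide the keys read above (prior_type, emsize, nhead, nlayers,
# nhid_factor, bptt, batch_size, num_steps, epochs, lr, dropout, num_features,
# max_num_classes, prior hyperparameters, ...). The directory and filename are placeholders.
#
#   model = get_model(config, device='cuda:0', should_train=True, verbose=1)
#   save_model(model[2], 'checkpoints', 'my_model.cpkt', config)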