import time
import functools
import random
import math
import traceback

import torch
from torch import nn
import gpytorch
from botorch.models import SingleTaskGP
from botorch.models.gp_regression import MIN_INFERRED_NOISE_LEVEL
from botorch.fit import fit_gpytorch_model
from gpytorch.mlls import ExactMarginalLogLikelihood
from gpytorch.likelihoods import GaussianLikelihood
from gpytorch.priors.torch_priors import GammaPrior
from gpytorch.constraints import GreaterThan

from bar_distribution import BarDistribution
from utils import default_device
from .utils import get_batch_to_dataloader
from . import fast_gp


def get_model(x, y, hyperparameters: dict, sample=True):
    aug_batch_shape = SingleTaskGP(x, y.unsqueeze(-1))._aug_batch_shape
    noise_prior = GammaPrior(hyperparameters.get('noise_concentration', 1.1),
                             hyperparameters.get('noise_rate', 0.05))
    noise_prior_mode = (noise_prior.concentration - 1) / noise_prior.rate
    likelihood = GaussianLikelihood(
        noise_prior=noise_prior,
        batch_shape=aug_batch_shape,
        noise_constraint=GreaterThan(
            MIN_INFERRED_NOISE_LEVEL,
            transform=None,
            initial_value=noise_prior_mode,
        ),
    )
    model = SingleTaskGP(x, y.unsqueeze(-1),
                         covar_module=gpytorch.kernels.ScaleKernel(
                             gpytorch.kernels.MaternKernel(
                                 nu=hyperparameters.get('nu', 2.5),
                                 ard_num_dims=x.shape[-1],
                                 batch_shape=aug_batch_shape,
                                 lengthscale_prior=gpytorch.priors.GammaPrior(
                                     hyperparameters.get('lengthscale_concentration', 3.0),
                                     hyperparameters.get('lengthscale_rate', 6.0)),
                             ),
                             batch_shape=aug_batch_shape,
                             outputscale_prior=gpytorch.priors.GammaPrior(
                                 hyperparameters.get('outputscale_concentration', .5),
                                 hyperparameters.get('outputscale_rate', 0.15)),
                         ),
                         likelihood=likelihood)

    likelihood = model.likelihood
    if sample:
        sampled_model = model.pyro_sample_from_prior()
        return sampled_model, sampled_model.likelihood
    else:
        assert not hyperparameters.get('sigmoid', False) and not hyperparameters.get('y_minmax_norm', False), \
            "Sigmoid and y_minmax_norm can only be used to sample models..."
        return model, likelihood


@torch.no_grad()
def get_batch(batch_size, seq_len, num_features, device=default_device, hyperparameters=None,
              batch_size_per_gp_sample=None, num_outputs=1, fix_to_range=None, equidistant_x=False):
    '''
    This function is very similar to the equivalent in .fast_gp. The only difference is that this function operates
    over a mixture of GP priors.
    :param batch_size:
    :param seq_len:
    :param num_features:
    :param device:
    :param hyperparameters:
    :param batch_size_per_gp_sample:
    :param num_outputs:
    :param fix_to_range:
    :param equidistant_x:
    :return:
    '''
    assert num_outputs == 1
    hyperparameters = hyperparameters or {}
    with gpytorch.settings.fast_computations(*hyperparameters.get('fast_computations', (True, True, True))):
        batch_size_per_gp_sample = (batch_size_per_gp_sample or max(batch_size // 10, 1))
        assert batch_size % batch_size_per_gp_sample == 0
        total_num_candidates = batch_size * (2 ** (fix_to_range is not None))
        num_candidates = batch_size_per_gp_sample * (2 ** (fix_to_range is not None))
        if equidistant_x:
            assert num_features == 1
            x = torch.linspace(0, 1., seq_len).unsqueeze(0).repeat(total_num_candidates, 1).unsqueeze(-1)
        else:
            x = torch.rand(total_num_candidates, seq_len, num_features, device=device)
        samples = []
        for i in range(0, total_num_candidates, num_candidates):
            # number of active input dimensions for this GP sample, drawn uniformly at random
            num_of_dims = random.randint(1, num_features)
            model, likelihood = get_model(x[i:i + num_candidates, ..., :num_of_dims],
                                          torch.zeros(num_candidates, x.shape[1]), hyperparameters)
            x[i:i + num_candidates, ..., num_of_dims:] = 0
            # stretch the active dimensions to (roughly) compensate for the zeroed-out ones
            x[i:i + num_candidates, ..., :num_of_dims] *= num_features / num_of_dims
            #print(model.covar_module.base_kernel.lengthscale)
            model.to(device)
            # trained_model = ExactGPModel(train_x, train_y, likelihood).cuda()
            # trained_model.eval()
            successful_sample = 0
            throwaway_share = 0.
            while successful_sample < 1:
                with gpytorch.settings.prior_mode(True):
                    # evaluate only on the active dimensions the model was built for; the remaining columns are zero
                    d = model(x[i:i + num_candidates, ..., :num_of_dims])
                    d = likelihood(d)
                    sample = d.sample()  # bs_per_gp_s x T
                    if hyperparameters.get('y_minmax_norm'):
                        sample = ((sample - sample.min(1)[0]) / (sample.max(1)[0] - sample.min(1)[0]))
                    if hyperparameters.get('sigmoid'):
                        sample = sample.sigmoid()

                    if fix_to_range is None:
                        samples.append(sample.transpose(0, 1))
                        successful_sample = True
                        continue

                    smaller_mask = sample < fix_to_range[0]
                    larger_mask = sample >= fix_to_range[1]
                    in_range_mask = ~(smaller_mask | larger_mask).any(1)
                    throwaway_share += (~in_range_mask[:batch_size_per_gp_sample]).sum() / batch_size_per_gp_sample
                    if in_range_mask.sum() < batch_size_per_gp_sample:
                        successful_sample -= 1
                        if successful_sample < 100:
                            print("Please change the hyper-parameters (e.g. decrease outputscale_mean); it "
                                  "seems like the range is set too tight for your hyper-parameters.")
                        continue

                    x[i:i + batch_size_per_gp_sample] = x[i:i + num_candidates][in_range_mask][:batch_size_per_gp_sample]
                    sample = sample[in_range_mask][:batch_size_per_gp_sample]
                    samples.append(sample.transpose(0, 1))
                    successful_sample = True

            if random.random() < .01:
                print('throwaway share', throwaway_share / (batch_size // batch_size_per_gp_sample))

        #print(f'took {time.time() - start}')
        sample = torch.cat(samples, 1)
        x = x.view(-1, batch_size, seq_len, num_features)[0]
        # TODO think about enabling the line below
        #sample = sample - sample[0, :].unsqueeze(0).expand(*sample.shape)
        x = x.transpose(0, 1)
        assert x.shape[:2] == sample.shape[:2]
        target_sample = sample
    return x, sample, target_sample  # x.shape = (T,B,H)


class DataLoader(get_batch_to_dataloader(get_batch)):
    num_outputs = 1

    @torch.no_grad()
    def validate(self, model, step_size=1, start_pos=0):
        if isinstance(model.criterion, BarDistribution):
            (x, y), target_y = self.gbm(**self.get_batch_kwargs, fuse_x_y=self.fuse_x_y)
            model.eval()
            losses = []
            for eval_pos in range(start_pos, len(x), step_size):
                logits = model((x, y), single_eval_pos=eval_pos)
                means = model.criterion.mean(logits)  # num_evals x batch_size
                mse = nn.MSELoss()
                losses.append(mse(means[0], target_y[eval_pos]))
            model.train()
            return torch.stack(losses)
        else:
            return 123.  # placeholder; validation is only implemented for BarDistribution criteria
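
# A minimal usage sketch (comments only, so importing this module has no side effects). The
# shapes follow the `x.shape = (T,B,H)` note in `get_batch` above; the concrete argument
# values are illustrative assumptions, not defaults of this module:
#
#   x, y, target_y = get_batch(batch_size=8, seq_len=50, num_features=2,
#                              hyperparameters={'nu': 2.5}, device='cpu')
#   # x: (seq_len, batch_size, num_features); y and target_y: (seq_len, batch_size).
#   # Every sub-batch of `batch_size_per_gp_sample` curves is drawn from one GP whose
#   # hyperparameters were sampled from the priors defined in `get_model`.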


@torch.enable_grad()
def get_fitted_model(x, y, hyperparameters, device):
    # fit the Gaussian process
    model, likelihood = get_model(x, y, hyperparameters, sample=False)
    #print(model.covar_module.base_kernel.lengthscale)
    model.to(device)
    mll = ExactMarginalLogLikelihood(likelihood, model)
    model.train()
    fit_gpytorch_model(mll)
    #print(model.covar_module.base_kernel.lengthscale)
    return model, likelihood


evaluate = functools.partial(fast_gp.evaluate, get_model_on_device=get_fitted_model)


def get_mcmc_model(x, y, hyperparameters, device, num_samples, warmup_steps):
    from pyro.infer.mcmc import NUTS, MCMC
    import pyro
    x = x.to(device)
    y = y.to(device)
    model, likelihood = get_model(x, y, hyperparameters, sample=False)
    model.to(device)

    def pyro_model(x, y):
        sampled_model = model.pyro_sample_from_prior()
        _ = sampled_model.likelihood(sampled_model(x))
        return y

    nuts_kernel = NUTS(pyro_model, adapt_step_size=True)
    mcmc_run = MCMC(nuts_kernel, num_samples=num_samples, warmup_steps=warmup_steps)
    #print(x.shape)
    mcmc_run.run(x, y)
    model.pyro_load_from_samples(mcmc_run.get_samples())
    model.eval()
    # test_x = torch.linspace(0, 1, 101).unsqueeze(-1)
    # test_y = torch.sin(test_x * (2 * math.pi))
    # expanded_test_x = test_x.unsqueeze(0).repeat(num_samples, 1, 1)
    # output = model(expanded_test_x)
    #print(x.shape)
    return model, likelihood
    # output = model(x[-1].unsqueeze(1).repeat(1, num_samples, 1))
    # return output.mean


def get_mean_logdensity(dists, x: torch.Tensor, full_range=None):
    means = torch.cat([d.mean.squeeze() for d in dists], 0)
    vars = torch.cat([d.variance.squeeze() for d in dists], 0)
    assert len(means.shape) == 1 and len(vars.shape) == 1
    dist = torch.distributions.Normal(means, vars.sqrt())
    #logprobs = torch.cat([d.log_prob(x) for d in dists], 0)
    logprobs = dist.log_prob(x)
    if full_range is not None:
        # renormalize by the probability mass that falls inside the admissible range
        used_weight = 1. - (dist.cdf(torch.tensor(full_range[0])) + (1. - dist.cdf(torch.tensor(full_range[1]))))
        if torch.isinf(-torch.log(used_weight)).any() or torch.isinf(torch.log(used_weight)).any():
            print('factor is inf', -torch.log(used_weight))
        logprobs -= torch.log(used_weight)
    assert len(logprobs.shape) == 1
    #print(logprobs)
    return torch.logsumexp(logprobs, 0) - math.log(len(logprobs))


def evaluate_(x, y, y_non_noisy, hyperparameters=None, device=default_device, num_samples=100, warmup_steps=300,
              full_range=None, min_seq_len=0, use_likelihood=False):
    with gpytorch.settings.fast_computations(*hyperparameters.get('fast_computations', (True, True, True))), \
            gpytorch.settings.fast_pred_var(False):
        x = x.to(device)
        y = y.to(device)
        start_time = time.time()
        losses_after_t = [.0] if min_seq_len == 0 else []
        all_losses = []
        for t in range(max(min_seq_len, 1), len(x)):
            #print('Timestep', t)
            loss_sum = 0.
            step_losses = []
            start_step = time.time()
            for b_i in range(x.shape[1]):
                done = 0
                while done < 1:
                    try:
                        model, likelihood = get_mcmc_model(x[:t, b_i], y[:t, b_i], hyperparameters, device,
                                                           num_samples=num_samples, warmup_steps=warmup_steps)
                        model.eval()

                        with torch.no_grad():
                            dists = model(x[t, b_i, :].unsqueeze(0))  # TODO check what is going on here! Does the GP interpret the input wrong?
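                            # `model` now carries one set of GP hyperparameters per MCMC draw, so this call
                            # yields a batch of `num_samples` predictive normals for the single test point;
                            # `get_mean_logdensity` below averages their densities (logsumexp - log n) to
                            # approximate the fully Bayesian predictive log-density.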
                            if use_likelihood:
                                dists = likelihood(dists)
                            l = -get_mean_logdensity([dists], y[t, b_i], full_range)
                        done = 1
                    except Exception as e:
                        done -= 1
                        print('Trying again..')
                        print(traceback.format_exc())
                        print(e)
                    finally:
                        if done < -10:
                            print('Too many retries...')
                            exit()
                step_losses.append(l.item())
                #print('loss', l.item())
                print(f'current average loss at step {t} is {sum(step_losses)/len(step_losses)} with {(time.time()-start_step)/len(step_losses)} s per eval.')
                loss_sum += l
            loss_sum /= x.shape[1]
            all_losses.append(step_losses)
            print(f'loss after step {t} is {loss_sum}')
            losses_after_t.append(loss_sum)
            print(f'losses so far {torch.tensor(losses_after_t)}')
    return torch.tensor(losses_after_t), time.time() - start_time, all_losses


if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--batch_size', type=int)
    parser.add_argument('--seq_len', type=int)
    parser.add_argument('--min_seq_len', type=int, default=0)
    parser.add_argument('--warmup_steps', type=int)
    parser.add_argument('--num_samples', type=int)
    parser.add_argument('--min_y', type=int)
    parser.add_argument('--max_y', type=int)
    parser.add_argument('--dim', type=int, default=1)
    parser.add_argument('--use_likelihood', default=True, type=bool)
    parser.add_argument('--device', default='cpu')
    parser.add_argument('--outputscale_concentration', default=2., type=float)
    parser.add_argument('--noise_concentration', default=1.1, type=float)
    parser.add_argument('--noise_rate', default=.05, type=float)
    args = parser.parse_args()

    print('min_y:', args.min_y)
    full_range = (None if args.min_y is None else (args.min_y, args.max_y))
    hps = {'outputscale_concentration': args.outputscale_concentration,
           'noise_concentration': args.noise_concentration,
           'noise_rate': args.noise_rate,
           'fast_computations': (False, False, False)}
    x, y, _ = get_batch(args.batch_size, args.seq_len, args.dim, fix_to_range=full_range, hyperparameters=hps)
    print('RESULT:', evaluate_(x, y, y, device=args.device, warmup_steps=args.warmup_steps,
                               num_samples=args.num_samples, full_range=full_range,
                               min_seq_len=args.min_seq_len, hyperparameters=hps,
                               use_likelihood=args.use_likelihood))
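
# Example invocation of the script above (illustrative only; substitute the module path for
# wherever this file lives in your package, and choose flag values that suit your compute budget):
#
#   python -m <package>.<this_module> --batch_size 4 --seq_len 10 --warmup_steps 100 \
#       --num_samples 50 --min_y -1 --max_y 1 --dim 1 --device cpu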