CodeGemma Release collection · 18 items · Updated · 78
This repository is publicly accessible, but you have to accept the conditions to access its files and content.
To access CodeGemma on Hugging Face, you’re required to review and agree to Google’s usage license. To do this, please make sure you’re logged in to Hugging Face and click below. Requests are processed immediately.
Log in or Sign Up to review the conditions and access this model content.
This repository corresponds to the CodeGemma 2B checkpoint for use with Gemma PyTorch. If you're looking for the `transformers` implementation or a more detailed model card, visit https://huggingface.co/google/codegemma-2b.
Model Page: CodeGemma
Resources and Technical Documentation:
Terms of Use: Terms
Authors: Google
from gemma.config import GemmaConfig, get_config_for_7b, get_config_for_2b
from gemma.model import GemmaForCausalLM
from gemma.tokenizer import Tokenizer
import contextlib
import os
import torch
# Model size to load; "2b" selects the 2B config below, anything else the 7B.
VARIANT = "2b"
# Torch device string used for torch.device(); e.g. "cpu" (or "cuda" if available).
MACHINE_TYPE = "cpu"
# Local directory holding the downloaded checkpoint and tokenizer.
# Derived from VARIANT so that switching the variant also switches the
# directory, keeping it consistent with the checkpoint name built from VARIANT.
weights_dir = f"codegemma-{VARIANT}-pytorch"
@contextlib.contextmanager
def _set_default_tensor_type(dtype: torch.dtype):
    """Temporarily make ``dtype`` torch's default floating-point dtype.

    While the ``with`` block runs, floating-point tensors created without an
    explicit dtype use ``dtype``. On exit, including when the block raises,
    the default dtype that was active on entry is restored.

    Args:
        dtype: Floating-point dtype to use as the default inside the context.

    Yields:
        None.
    """
    # Remember the caller's default rather than assuming float32, so nested
    # uses and callers who changed the default themselves are not clobbered.
    previous_dtype = torch.get_default_dtype()
    torch.set_default_dtype(dtype)
    try:
        yield
    finally:
        # Restore even if the body (e.g. model construction) raised.
        torch.set_default_dtype(previous_dtype)
# Select the architecture config: any VARIANT containing "2b" uses the 2B
# config, everything else falls back to the 7B config.
if "2b" in VARIANT:
    model_config = get_config_for_2b()
else:
    model_config = get_config_for_7b()
model_config.tokenizer = os.path.join(weights_dir, "tokenizer.model")

device = torch.device(MACHINE_TYPE)

# Construct the model while the dtype reported by model_config.get_dtype() is
# the torch default, then load the checkpoint, move it to the target device
# and switch it to inference (eval) mode.
with _set_default_tensor_type(model_config.get_dtype()):
    model = GemmaForCausalLM(model_config)
    ckpt_path = os.path.join(weights_dir, f'codegemma-{VARIANT}.pt')
    model.load_weights(ckpt_path)
    model = model.to(device).eval()
# Fill-in-the-middle prompt: the model generates the text that belongs between
# <|fim_prefix|> and <|fim_suffix|>, emitted after <|fim_middle|>. The suffix
# calls sys.exit, so the expected middle is presumably the module name `sys`.
# The `if` body is indented so the suffix is valid Python for the model.
FIM_PROMPT = """<|fim_prefix|>import <|fim_suffix|>if __name__ == "__main__":
    sys.exit(0)<|fim_middle|>"""

# Keep and print the return value of generate(). In a notebook the last
# expression is displayed automatically, but when run as a script the
# completion would otherwise be silently discarded.
completion = model.generate(
    FIM_PROMPT,
    device=device,
    output_len=100,
)
print(completion)