Uploaded model

  • Developed by: mnm373
  • License: Gemma Terms of Use
  • Finetuned from model : gemma-2-9b

This gemma2 model was trained 2x faster with Unsloth and Huggingface's TRL library.

Usage

松尾研大規模言語モデル講座2024コンペの推論方法を以下に記載します。

# 必要なライブラリをインストール
!pip uninstall unsloth -y
!pip install --upgrade --no-cache-dir "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git"
!pip install --upgrade torch
!pip install --upgrade xformers


# 必要なライブラリを読み込み
from unsloth import FastLanguageModel
import json
from tqdm import tqdm
import re


# モデルをロード
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "mnm373/gemma-2-9b-it-v3_lora",
    load_in_4bit = True,
    trust_remote_code=True,
)
FastLanguageModel.for_inference(model) # Enable native 2x faster inference

# データセットの読み込み
# 事前にデータをアップロードしてください
datasets = []
with open("./elyza-tasks-100-TV_0.jsonl", "r") as f:
    item = ""
    for line in f:
      line = line.strip()
      item += line
      if item.endswith("}"):
        datasets.append(json.loads(item))
        item = ""

# 推論の実行
results = []
for dt in tqdm(datasets):
    input_text = dt["input"]

    prompt = f"""### 指示\n{input_text}\n### 回答\n"""

    inputs = tokenizer([prompt], return_tensors="pt").to(model.device)

    outputs = model.generate(**inputs, max_new_tokens=1024, use_cache=True, do_sample=False, repetition_penalty=1.2)
    prediction = tokenizer.decode(outputs[0], skip_special_tokens=True).split('\n### 回答\n')[-1]

    # 不要なフレーズを削除
    if prediction.startswith("こんにちは!"):
        prediction = prediction.lstrip("こんにちは!")
    if prediction.startswith("もちろんです!"):
        prediction = prediction.lstrip("もちろんです!")

    phrases_to_remove = [
        "ユーモアを交えてお答えしますね。",
        "ユーモアを交えつつお答えしますね。"
    ]
    for phrase in phrases_to_remove:
        prediction = prediction.replace(phrase, "")

    # 不要な空白や改行をトリミング
    prediction = prediction.strip()

    results.append({"task_id": dt["task_id"], "input": input_text, "output": prediction})


# 結果をjsonlで保存。
json_file_id = re.sub(".*/", "", adapter_id)
with open(f"/content/{json_file_id}_output.jsonl", 'w', encoding='utf-8') as f:
    for result in results:
        json.dump(result, f, ensure_ascii=False)
        f.write('\n')
Downloads last month

-

Downloads are not tracked for this model. How to track
Inference Providers NEW
This model is not currently available via any of the supported Inference Providers.
The model cannot be deployed to the HF Inference API: The model has no pipeline_tag.