arjunanand13 commited on
Commit
baab6b5
verified
1 Parent(s): 8036f6f

Upload 4 files

Browse files
Files changed (5) hide show
  1. .gitattributes +1 -0
  2. app (1).py +31 -0
  3. delay_tyre.mp4 +3 -0
  4. process.py +111 -0
  5. requirements (1).txt +26 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ delay_tyre.mp4 filter=lfs diff=lfs merge=lfs -text
app (1).py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+
3
+ from process import inference
4
+
5
+ def clickit(video, prompt):
6
+ return inference(
7
+ video,
8
+ prompt
9
+ )
10
+
11
+ with gr.Blocks() as blok:
12
+ with gr.Row():
13
+ with gr.Column():
14
+ video = gr.Video(
15
+ label="video input",
16
+ )
17
+ prompt = gr.Text(
18
+ label="Prompt",
19
+ value="Please describe this video in detail."
20
+ )
21
+ with gr.Column():
22
+ button = gr.Button("Caption it", variant="primary")
23
+ text = gr.Text(label="Output")
24
+
25
+ button.click(
26
+ fn=clickit,
27
+ inputs=[video, prompt],
28
+ outputs=[text]
29
+ )
30
+
31
+ blok.launch()
delay_tyre.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:63a245902a9509f492fda6537c84ab53c3582f868503982b53419c01fee6e592
3
+ size 7352910
process.py ADDED
@@ -0,0 +1,111 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import io
2
+ import spaces
3
+ import argparse
4
+ import numpy as np
5
+ import torch
6
+ from decord import cpu, VideoReader, bridge
7
+ from transformers import AutoModelForCausalLM, AutoTokenizer
8
+ from transformers import BitsAndBytesConfig
9
+
10
+ MODEL_PATH = "THUDM/cogvlm2-llama3-caption"
11
+ DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
12
+ TORCH_TYPE = torch.bfloat16 if torch.cuda.is_available() and torch.cuda.get_device_capability()[
13
+ 0] >= 8 else torch.float16
14
+
15
+ parser = argparse.ArgumentParser(description="CogVLM2-Video CLI Demo")
16
+ parser.add_argument('--quant', type=int, choices=[4, 8], help='Enable 4-bit or 8-bit precision loading', default=4)
17
+ args = parser.parse_args([])
18
+
19
+ def load_video(video_data, strategy='chat'):
20
+ bridge.set_bridge('torch')
21
+ mp4_stream = video_data
22
+ num_frames = 24
23
+ decord_vr = VideoReader(io.BytesIO(mp4_stream), ctx=cpu(0))
24
+ frame_id_list = None
25
+ total_frames = len(decord_vr)
26
+
27
+ if strategy == 'base':
28
+ clip_end_sec = 60
29
+ clip_start_sec = 0
30
+ start_frame = int(clip_start_sec * decord_vr.get_avg_fps())
31
+ end_frame = min(total_frames,
32
+ int(clip_end_sec * decord_vr.get_avg_fps())) if clip_end_sec is not None else total_frames
33
+ frame_id_list = np.linspace(start_frame, end_frame - 1, num_frames, dtype=int)
34
+ elif strategy == 'chat':
35
+ timestamps = decord_vr.get_frame_timestamp(np.arange(total_frames))
36
+ timestamps = [i[0] for i in timestamps]
37
+ max_second = round(max(timestamps)) + 1
38
+ frame_id_list = []
39
+ for second in range(max_second):
40
+ closest_num = min(timestamps, key=lambda x: abs(x - second))
41
+ index = timestamps.index(closest_num)
42
+ frame_id_list.append(index)
43
+ if len(frame_id_list) >= num_frames:
44
+ break
45
+
46
+ video_data = decord_vr.get_batch(frame_id_list)
47
+ video_data = video_data.permute(3, 0, 1, 2)
48
+ return video_data
49
+
50
+ # Configure quantization
51
+ quantization_config = BitsAndBytesConfig(
52
+ load_in_4bit=True,
53
+ bnb_4bit_compute_dtype=TORCH_TYPE,
54
+ bnb_4bit_use_double_quant=True,
55
+ bnb_4bit_quant_type="nf4"
56
+ )
57
+
58
+ tokenizer = AutoTokenizer.from_pretrained(
59
+ MODEL_PATH,
60
+ trust_remote_code=True,
61
+ )
62
+
63
+ model = AutoModelForCausalLM.from_pretrained(
64
+ MODEL_PATH,
65
+ torch_dtype=TORCH_TYPE,
66
+ trust_remote_code=True,
67
+ quantization_config=quantization_config,
68
+ device_map="auto"
69
+ ).eval()
70
+
71
+ @spaces.GPU
72
+ def predict(prompt, video_data, temperature):
73
+ strategy = 'chat'
74
+ video = load_video(video_data, strategy=strategy)
75
+ history = []
76
+ query = prompt
77
+ inputs = model.build_conversation_input_ids(
78
+ tokenizer=tokenizer,
79
+ query=query,
80
+ images=[video],
81
+ history=history,
82
+ template_version=strategy
83
+ )
84
+
85
+ inputs = {
86
+ 'input_ids': inputs['input_ids'].unsqueeze(0).to(DEVICE),
87
+ 'token_type_ids': inputs['token_type_ids'].unsqueeze(0).to(DEVICE),
88
+ 'attention_mask': inputs['attention_mask'].unsqueeze(0).to(DEVICE),
89
+ 'images': [[inputs['images'][0].to(DEVICE).to(TORCH_TYPE)]],
90
+ }
91
+
92
+ gen_kwargs = {
93
+ "max_new_tokens": 2048,
94
+ "pad_token_id": 128002,
95
+ "top_k": 1,
96
+ "do_sample": False,
97
+ "top_p": 0.1,
98
+ "temperature": temperature,
99
+ }
100
+
101
+ with torch.no_grad():
102
+ outputs = model.generate(**inputs, **gen_kwargs)
103
+ outputs = outputs[:, inputs['input_ids'].shape[1]:]
104
+ response = tokenizer.decode(outputs[0], skip_special_tokens=True)
105
+ return response
106
+
107
+ def inference(video, prompt):
108
+ temperature = 0.1
109
+ video_data = open(video, 'rb').read()
110
+ response = predict(prompt, video_data, temperature)
111
+ return response
requirements (1).txt ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ decord>=0.6.0
2
+ #鏍规嵁https://download.pytorch.org/whl/torch/锛宲ython鐗堟湰涓篬3.8,3.11]
3
+ torch==2.1.0
4
+ torchvision== 0.16.0
5
+ pytorchvideo==0.1.5
6
+ xformers
7
+ transformers==4.42.4
8
+ #git+https://github.com/huggingface/transformers.git
9
+ huggingface-hub>=0.23.0
10
+ pillow
11
+ chainlit>=1.0
12
+ pydantic>=2.7.1
13
+ timm>=0.9.16
14
+ openai>=1.30.1
15
+ loguru>=0.7.2
16
+ pydantic>=2.7.1
17
+ einops
18
+ sse-starlette>=2.1.0
19
+ flask
20
+ gunicorn
21
+ gevent
22
+ requests
23
+ gradio
24
+ accelerate
25
+ bitsandbytes>=0.39.0
26
+ spaces