wanicca committed on
Commit
1a04d3c
·
1 Parent(s): b18ec20

lora filter 允许比例 (allow per-layer ratio coefficients), e.g. 0.2*0-4

Browse files
Files changed (1) hide show
  1. rwkv_lora.py +22 -7
rwkv_lora.py CHANGED
@@ -7,38 +7,53 @@ import types, gc, os, time, re
7
  import torch
8
  from torch.nn import functional as F
9
 
10
- def get_filter_keys(layer_filter):
11
  if layer_filter:
12
  layers = []
 
13
  for layer in layer_filter.split(' '):
 
 
 
 
 
14
  if layer.isdecimal():
15
  layers.append(int(layer))
 
16
  elif '-' in layer:
17
  start,_,end = layer.partition('-')
18
  start,end = int(start),int(end)
19
  layers.extend(range(start,end+1))
 
 
20
  else:
21
  raise NotImplementedError("layer_filter Not implemented:",layer_filter)
22
  layers = sorted(set(layers))
23
  layer_prefixes = tuple(f"blocks.{l}." for l in layers)
24
- def filter_keys(keys):
25
  new_keys = []
26
  for key in keys:
27
- if key.startswith("blocks."):
28
  if not key.startswith(layer_prefixes):
29
  continue
30
  new_keys.append(key)
31
  return new_keys
32
-
 
 
 
 
33
  else:
34
  def filter_keys(keys):
35
  return keys
36
- return filter_keys
 
 
37
 
38
  def lora_merge(base_model,lora,lora_alpha,device="cuda",layer_filter=None,):
39
  print(f"Loading LoRA: {lora}")
40
  print(f"LoRA alpha={lora_alpha}, layer_filter={layer_filter}")
41
- filter_keys = get_filter_keys(layer_filter)
42
  w: Dict[str, torch.Tensor] = torch.load(base_model, map_location='cpu')
43
  # merge LoRA-only slim checkpoint into the main weights
44
  w_lora: Dict[str, torch.Tensor] = torch.load(lora, map_location='cpu')
@@ -61,7 +76,7 @@ def lora_merge(base_model,lora,lora_alpha,device="cuda",layer_filter=None,):
61
  w[k] = w[k].to(device=device)
62
  w[lora_A] = w[lora_A].to(device=device)
63
  w[lora_B] = w[lora_B].to(device=device)
64
- w[k] += w[lora_B] @ w[lora_A] * (lora_alpha / lora_r)
65
  output_w[k] = w[k].to(device='cpu', copy=True)
66
  del w[k]
67
  del w[lora_A]
 
7
  import torch
8
  from torch.nn import functional as F
9
 
10
def get_filter_keys_and_merge_coef(layer_filter):
    """Build key-filter and merge-coefficient functions from a layer filter spec.

    ``layer_filter`` is a whitespace-separated list of tokens, each either a
    single layer index (``"5"``), an inclusive range (``"0-4"``), or either
    form scaled by a coefficient (``"0.2*0-4"``).

    Returns a ``(filter_keys, merge_coef)`` pair:
      * ``filter_keys(keys)`` — drops ``blocks.*`` keys whose layer is not
        selected; non-block keys always pass through.
      * ``merge_coef(key)`` — the coefficient for that key's layer, or 1 when
        no coefficient applies.

    When ``layer_filter`` is falsy, both functions are pass-throughs.
    Raises ``NotImplementedError`` for tokens in an unsupported format.
    """
    if layer_filter:
        layers = []
        layer_coef = {}
        # .split() (not .split(' ')) so repeated/leading spaces don't produce
        # empty tokens that would raise NotImplementedError below.
        for token in layer_filter.split():
            if '*' in token:
                coef_str, _, token = token.partition('*')
                coef = float(coef_str)
            else:
                coef = 1
            if token.isdecimal():
                idx = int(token)
                layers.append(idx)
                layer_coef[idx] = coef
            elif '-' in token:
                start, _, end = token.partition('-')
                span = range(int(start), int(end) + 1)
                layers.extend(span)
                for l in span:
                    layer_coef[l] = coef
            else:
                raise NotImplementedError("layer_filter Not implemented:", layer_filter)
        layers = sorted(set(layers))
        layer_prefixes = tuple(f"blocks.{l}." for l in layers)

        def filter_keys(keys):
            # Filter out keys that start with "blocks." but are not in the
            # allowed layer range.
            new_keys = []
            for key in keys:
                if key.startswith("blocks.") and not key.startswith(layer_prefixes):
                    continue
                new_keys.append(key)
            return new_keys

        def merge_coef(key):
            # Coefficient for the key's layer; default 1.  Parse the layer
            # index once, and only when it is actually numeric.
            if key.startswith('blocks.'):
                idx = key.split('.')[1]
                if idx.isdecimal() and int(idx) in layer_coef:
                    return layer_coef[int(idx)]
            return 1
    else:
        def filter_keys(keys):
            return keys

        def merge_coef(key):
            return 1
    return filter_keys, merge_coef
52
 
53
  def lora_merge(base_model,lora,lora_alpha,device="cuda",layer_filter=None,):
54
  print(f"Loading LoRA: {lora}")
55
  print(f"LoRA alpha={lora_alpha}, layer_filter={layer_filter}")
56
+ filter_keys,merge_coef = get_filter_keys_and_merge_coef(layer_filter)
57
  w: Dict[str, torch.Tensor] = torch.load(base_model, map_location='cpu')
58
  # merge LoRA-only slim checkpoint into the main weights
59
  w_lora: Dict[str, torch.Tensor] = torch.load(lora, map_location='cpu')
 
76
  w[k] = w[k].to(device=device)
77
  w[lora_A] = w[lora_A].to(device=device)
78
  w[lora_B] = w[lora_B].to(device=device)
79
+ w[k] += w[lora_B] @ w[lora_A] * (lora_alpha / lora_r) * merge_coef(k)
80
  output_w[k] = w[k].to(device='cpu', copy=True)
81
  del w[k]
82
  del w[lora_A]