ctheodoris hchen725 commited on
Commit
69e6887
·
verified ·
1 Parent(s): 54b408b

Prepare data cell level pass cell state dict instead of genes (#483)

Browse files

- Update geneformer/classifier.py (107b2b84dd2a79cdba7470322d6c3637cc1d41ed)


Co-authored-by: Han Chen <[email protected]>

Files changed (1) hide show
  1. geneformer/classifier.py +9 -3
geneformer/classifier.py CHANGED
@@ -437,14 +437,20 @@ class Classifier:
437
  )
438
  # rename cell state column to "label"
439
  data = cu.rename_cols(data, self.cell_state_dict["state_key"])
 
 
 
 
 
440
 
 
441
  # convert classes to numerical labels and save as id_class_dict
442
  # of note, will label all genes in gene_class_dict
443
  # if (cross-)validating, genes will be relabeled in column "labels" for each split
444
  # at the time of training with Classifier.validate
445
- data, id_class_dict = cu.label_classes(
446
- self.classifier, data, self.gene_class_dict, self.nproc
447
- )
448
 
449
  # save id_class_dict for future reference
450
  id_class_output_path = (
 
437
  )
438
  # rename cell state column to "label"
439
  data = cu.rename_cols(data, self.cell_state_dict["state_key"])
440
+
441
+ # convert classes to numerical labels and save as id_class_dict
442
+ data, id_class_dict = cu.label_classes(
443
+ self.classifier, data, self.cell_state_dict, self.nproc
444
+ )
445
 
446
+ elif self.classifier == "gene":
447
  # convert classes to numerical labels and save as id_class_dict
448
  # of note, will label all genes in gene_class_dict
449
  # if (cross-)validating, genes will be relabeled in column "labels" for each split
450
  # at the time of training with Classifier.validate
451
+ data, id_class_dict = cu.label_classes(
452
+ self.classifier, data, self.gene_class_dict, self.nproc
453
+ )
454
 
455
  # save id_class_dict for future reference
456
  id_class_output_path = (