Prepare data cell level pass cell state dict instead of genes (#483)
Browse files- Update geneformer/classifier.py (107b2b84dd2a79cdba7470322d6c3637cc1d41ed)
Co-authored-by: Han Chen <[email protected]>
- geneformer/classifier.py +9 -3
geneformer/classifier.py
CHANGED
@@ -437,14 +437,20 @@ class Classifier:
|
|
437 |
)
|
438 |
# rename cell state column to "label"
|
439 |
data = cu.rename_cols(data, self.cell_state_dict["state_key"])
|
|
|
|
|
|
|
|
|
|
|
440 |
|
|
|
441 |
# convert classes to numerical labels and save as id_class_dict
|
442 |
# of note, will label all genes in gene_class_dict
|
443 |
# if (cross-)validating, genes will be relabeled in column "labels" for each split
|
444 |
# at the time of training with Classifier.validate
|
445 |
-
|
446 |
-
|
447 |
-
|
448 |
|
449 |
# save id_class_dict for future reference
|
450 |
id_class_output_path = (
|
|
|
437 |
)
|
438 |
# rename cell state column to "label"
|
439 |
data = cu.rename_cols(data, self.cell_state_dict["state_key"])
|
440 |
+
|
441 |
+
# convert classes to numerical labels and save as id_class_dict
|
442 |
+
data, id_class_dict = cu.label_classes(
|
443 |
+
self.classifier, data, self.cell_state_dict, self.nproc
|
444 |
+
)
|
445 |
|
446 |
+
elif self.classifier == "gene":
|
447 |
# convert classes to numerical labels and save as id_class_dict
|
448 |
# of note, will label all genes in gene_class_dict
|
449 |
# if (cross-)validating, genes will be relabeled in column "labels" for each split
|
450 |
# at the time of training with Classifier.validate
|
451 |
+
data, id_class_dict = cu.label_classes(
|
452 |
+
self.classifier, data, self.gene_class_dict, self.nproc
|
453 |
+
)
|
454 |
|
455 |
# save id_class_dict for future reference
|
456 |
id_class_output_path = (
|