Nu Appleblossom commited on
Commit
ae9cccb
·
1 Parent(s): 9e22acb

switching from graphviz to networkx

Browse files
Files changed (2) hide show
  1. app.py +28 -29
  2. requirements.txt +1 -2
app.py CHANGED
@@ -15,7 +15,6 @@ from dotenv import load_dotenv
15
  from huggingface_hub import hf_hub_download
16
  import spaces
17
  import traceback
18
- from graphviz import Digraph
19
  from PIL import Image, ImageDraw, ImageFont
20
  from io import BytesIO
21
  import functools
@@ -320,39 +319,39 @@ def add_nodes_edges(dot, node, config, max_weight, min_weight, parent=None, is_r
320
  for child in node.get('children', []):
321
  add_nodes_edges(dot, child, config, max_weight, min_weight, parent=node_id, is_root=False, depth=depth+1, trim_cutoff=trim_cutoff)
322
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
323
  def create_tree_diagram(data, config, max_weight, min_weight, trim_cutoff=0):
324
- import os
325
- from graphviz import Digraph
326
 
327
- # Ensure the system can find the dot command
328
- os.environ["PATH"] += os.pathsep + "/usr/bin"
 
 
 
329
 
330
- # Initialize the Digraph object and explicitly set the dot command path
331
- dot = Digraph(comment='Definition Tree', format='png', engine='dot')
332
- dot.command = '/usr/bin/dot'
333
-
334
- dot.attr(rankdir='LR', size='5040,5000', margin='0.06', nodesep='0.06', ranksep='1', dpi='120', bgcolor='white')
335
 
336
- add_nodes_edges(dot, data, config, max_weight, min_weight, trim_cutoff=trim_cutoff)
337
 
338
- # Save to a temporary file first
339
- temp_filename = "temp_tree_diagram"
340
- dot.render(temp_filename, format='png', cleanup=True)
341
-
342
- # Read the file back into a BytesIO object
343
- with open(f"{temp_filename}.png", "rb") as f:
344
- output = BytesIO(f.read())
345
-
346
- # Add white background
347
- with Image.open(output) as img:
348
- bg = Image.new("RGB", (img.width, 5000), (255, 255, 255))
349
- y_offset = (5000 - img.height) // 2
350
- bg.paste(img, (0, y_offset))
351
- final_output = BytesIO()
352
- bg.save(final_output, 'PNG')
353
- final_output.seek(0)
354
-
355
- return final_output
356
 
357
 
358
 
 
15
  from huggingface_hub import hf_hub_download
16
  import spaces
17
  import traceback
 
18
  from PIL import Image, ImageDraw, ImageFont
19
  from io import BytesIO
20
  import functools
 
319
  for child in node.get('children', []):
320
  add_nodes_edges(dot, child, config, max_weight, min_weight, parent=node_id, is_root=False, depth=depth+1, trim_cutoff=trim_cutoff)
321
 
322
+ import networkx as nx
323
+ import matplotlib.pyplot as plt
324
+
325
+ def add_nodes_edges_nx(G, node, parent=None, is_root=True):
326
+ node_id = str(id(node))
327
+ token = node.get('token', '').strip()
328
+
329
+ if is_root or token:
330
+ G.add_node(node_id, label=token if not is_root else "*")
331
+ if parent:
332
+ G.add_edge(parent, node_id)
333
+
334
+ for child in node.get('children', []):
335
+ add_nodes_edges_nx(G, child, parent=node_id, is_root=False)
336
+
337
  def create_tree_diagram(data, config, max_weight, min_weight, trim_cutoff=0):
338
+ G = nx.DiGraph()
339
+ add_nodes_edges_nx(G, data)
340
 
341
+ # Draw the tree using matplotlib
342
+ plt.figure(figsize=(12, 12))
343
+ pos = nx.spring_layout(G, k=0.5, iterations=50)
344
+ labels = nx.get_node_attributes(G, 'label')
345
+ nx.draw(G, pos, labels=labels, with_labels=True, node_size=5000, node_color="lightblue", font_size=10, font_weight="bold", edge_color="gray", arrows=False)
346
 
347
+ # Save the image to a BytesIO object
348
+ output = BytesIO()
349
+ plt.savefig(output, format='png')
350
+ plt.close()
351
+ output.seek(0)
352
 
353
+ return output
354
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
355
 
356
 
357
 
requirements.txt CHANGED
@@ -12,6 +12,5 @@ spaces
12
  graphviz
13
  Pillow
14
  accelerate
 
15
  #fnord
16
- #fnord
17
- #fnord
 
12
  graphviz
13
  Pillow
14
  accelerate
15
+ networkx
16
  #fnord