sd_comfy_api_v2.py 5.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172
# Example: queue a ComfyUI workflow over the HTTP API and use the websocket API to detect when execution has finished.
# Once execution is done, download the generated images via the /history and /view endpoints.
import base64
import io
import json
import os
import random
import sys
import urllib.parse
import urllib.request
import uuid

import websocket  # NOTE: websocket-client (https://github.com/websocket-client/websocket-client)
from PIL import Image
  13. server_address = "127.0.0.1:8188"
  14. client_id = str(uuid.uuid4())
  15. def convert_base64_string_to_object(base64_string):
  16. bytes = base64.b64decode(base64_string)
  17. string = bytes.decode("ascii")
  18. return json.loads(string)
  19. def set_filename(json_obj, title, new_prefix):
  20. for key, value in json_obj.items():
  21. if isinstance(value, dict):
  22. if value.get("_meta", {}).get("title") == title:
  23. if "inputs" in value and "filename_prefix" in value["inputs"]:
  24. value["inputs"]["filename_prefix"] = new_prefix
  25. else:
  26. result = set_filename(value, title, new_prefix)
  27. if result:
  28. return result
  29. return None
  30. def find_node(json_obj, title):
  31. for key, value in json_obj.items():
  32. if isinstance(value, dict):
  33. if value.get("_meta", {}).get("title") == title:
  34. return value
  35. else:
  36. result = find_node(value, title)
  37. if result:
  38. return result
  39. return None
  40. def queue_prompt(prompt):
  41. p = {"prompt": prompt, "client_id": client_id}
  42. data = json.dumps(p).encode('utf-8')
  43. req = urllib.request.Request("http://{}/prompt".format(server_address), data=data)
  44. return json.loads(urllib.request.urlopen(req).read())
  45. def get_prompt(ai_scene_info):
  46. with open(
  47. "D://Git//ap-canvas-creation-module//04_stable_diffusion//workflows//canvas_3d_to_img_standard_V1.json",
  48. "r",
  49. ) as f:
  50. prompt_text_json = f.read()
  51. prompt = json.loads(prompt_text_json)
  52. #set the text prompt for our positive CLIPTextEncode
  53. positive_text = ai_scene_info["ai_scene"]["prompt"]["positive_prompt"]
  54. negative_text = ai_scene_info["ai_scene"]["prompt"]["negative_prompt"]
  55. image_path = "D://Git//ap-canvas-creation-module//03_blender//sd_blender//sample_scene//Renders//15a314a1-8ba1-4e0e-ad0c-f605b06f89f8//"
  56. image_base_path = image_path + "base0001.jpg"
  57. image_alpha_products_path = image_path + "alpha_products0001.jpg"
  58. # image_depth_path = image_path + "depth0001.png"
  59. prompt = json.loads(prompt_text_json)
  60. set_filename(prompt, "Save Image", "{project_id}/basic_api_example".format(project_id=ai_scene_info["project_id"]))
  61. ksampler_main = find_node(prompt, "KSampler")
  62. ksampler_main["inputs"]["noise_seed"] = random.randint(0, 1000000)
  63. ksampler_main = find_node(prompt, "KSampler")
  64. ksampler_main["inputs"]["steps"] = ai_scene_info["ai_scene"]["settings"]["steps"]
  65. ksampler_main["inputs"]["cfg"] = ai_scene_info["ai_scene"]["settings"]["cfg"]
  66. prompt_positive = find_node(prompt, "positive_CLIPTextEncodeSDXL")
  67. prompt_positive["inputs"]["text_g"] = positive_text
  68. prompt_positive["inputs"]["text_l"] = positive_text
  69. prompt_negative = find_node(prompt, "negative_CLIPTextEncodeSDXL")
  70. prompt_negative["inputs"]["text_g"] = negative_text
  71. prompt_negative["inputs"]["text_l"] = negative_text
  72. image_base = find_node(prompt, "image_base")
  73. image_base["inputs"]["image"] = image_base_path
  74. image_base = find_node(prompt, "image_product_mask")
  75. image_base["inputs"]["image"] = image_alpha_products_path
  76. image_base = find_node(prompt, "image_depth")
  77. # image_base["inputs"]["image"] = image_depth_path
  78. return prompt
  79. def get_image(filename, subfolder, folder_type):
  80. data = {"filename": filename, "subfolder": subfolder, "type": folder_type}
  81. url_values = urllib.parse.urlencode(data)
  82. with urllib.request.urlopen("http://{}/view?{}".format(server_address, url_values)) as response:
  83. return response.read()
  84. def get_history(prompt_id):
  85. with urllib.request.urlopen("http://{}/history/{}".format(server_address, prompt_id)) as response:
  86. return json.loads(response.read())
  87. def get_images(ws, prompt):
  88. prompt_id = queue_prompt(prompt)['prompt_id']
  89. output_images = {}
  90. while True:
  91. out = ws.recv()
  92. if isinstance(out, str):
  93. message = json.loads(out)
  94. if message['type'] == 'executing':
  95. data = message['data']
  96. if data['node'] is None and data['prompt_id'] == prompt_id:
  97. break #Execution is done
  98. else:
  99. continue #previews are binary data
  100. history = get_history(prompt_id)[prompt_id]
  101. for node_id in history['outputs']:
  102. node_output = history['outputs'][node_id]
  103. images_output = []
  104. if 'images' in node_output:
  105. for image in node_output['images']:
  106. image_data = get_image(image['filename'], image['subfolder'], image['type'])
  107. images_output.append(image_data)
  108. output_images[node_id] = images_output
  109. return output_images
  110. def main():
  111. argv = sys.argv
  112. try:
  113. argv = argv[argv.index("--") + 1 :]
  114. ai_scene_info = convert_base64_string_to_object(argv[0])
  115. print("loading scene data", ai_scene_info)
  116. except Exception as e:
  117. print("Error:", e)
  118. prompt = get_prompt(ai_scene_info)
  119. ws = websocket.WebSocket()
  120. ws.connect("ws://{}/ws?clientId={}".format(server_address, client_id))
  121. images = get_images(ws, prompt)
  122. #Commented out code to display the output images:
  123. # for node_id in images:
  124. # for image_data in images[node_id]:
  125. # image = Image.open(io.BytesIO(image_data))
  126. # image.show()
# Run the client only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()