# triposr-s3 / main.py
# (Hugging Face page residue: uploaded by ashh757, "Upload 18 files",
#  commit 8c93108 verified)
import os
from utils import process_image, run_model
from boto3 import Session
import torch
import pickle
import datetime
import gzip
# Shared S3 client used by upload_to_s3(). Credentials and region are read
# from the standard AWS_* environment variables; any variable that is unset
# is passed through as None.
# NOTE(review): None should leave boto3 to its default credential/region
# resolution for that setting — confirm for the deployment environment.
session = Session(
    region_name=os.environ.get('AWS_DEFAULT_REGION'),
    aws_secret_access_key=os.environ.get('AWS_SECRET_ACCESS_KEY'),
    aws_access_key_id=os.environ.get('AWS_ACCESS_KEY_ID'),
)
s3 = session.client('s3')
def load_model(path='model_quantized_compressed.pkl.gz'):
    """Load the pickled model from a gzip-compressed pickle file.

    Args:
        path: Location of the ``.pkl.gz`` file. Defaults to
            ``'model_quantized_compressed.pkl.gz'``, resolved relative to the
            current working directory (same as the previous hard-coded name).

    Returns:
        The unpickled model object.

    Raises:
        FileNotFoundError: if ``path`` does not exist.
        Errors from ``gzip``/``pickle`` propagate unchanged if the file is not
        valid gzip data or not a valid pickle stream.

    Security: unpickling can execute arbitrary code. Only load trusted,
    locally controlled files here — never user-supplied data.
    """
    # NOTE(review): this module imports torch; presumably the pickle contains
    # torch objects that need torch importable at load time — confirm.
    #
    # Unpickle directly from the decompressing stream rather than reading the
    # whole decompressed payload into a bytes object first; this avoids
    # holding the raw bytes and the rebuilt model in memory simultaneously.
    with gzip.open(path, 'rb') as f_in:
        model = pickle.load(f_in)
    print("Model Loaded")
    return model
def upload_to_s3(file_path, bucket_name, s3_key):
    """Upload a local file to S3 and return its ``s3://`` URI.

    The upload goes through the module-level ``s3`` client.

    Args:
        file_path: Local path of the file to upload (opened in binary mode).
        bucket_name: Destination bucket name.
        s3_key: Destination object key inside the bucket.

    Returns:
        The string ``'s3://<bucket_name>/<s3_key>'``.
    """
    with open(file_path, 'rb') as payload:
        s3.upload_fileobj(payload, bucket_name, s3_key)
    return 's3://{}/{}'.format(bucket_name, s3_key)
def generate_mesh(image_path, output_dir, model, bucket_name='vasana-bkt1'):
    """Generate a mesh from an image, upload input and output to S3.

    Steps: preprocess the image with ``process_image``, run ``run_model`` on
    the result, then upload both the original input image and the generated
    mesh file to S3 via ``upload_to_s3``.

    Args:
        image_path: Local path of the input image; passed to ``process_image``
            and uploaded as-is.
        output_dir: Directory passed to both ``process_image`` and
            ``run_model``.
        model: Model object passed to ``run_model``.
        bucket_name: Destination S3 bucket (defaults to the previously
            hard-coded ``'vasana-bkt1'``).

    Returns:
        The generated mesh file path, as returned by ``run_model``.
    """
    print('Process start')
    image = process_image(image_path, output_dir)
    print('Process end')

    print('Run start')
    output_file_path = run_model(model, image, output_dir)
    print('Run end')

    # Take the timestamp once, so the input image and its mesh share the same
    # key prefix. Calling now() separately for each key could straddle a
    # second boundary and give the pair different timestamps.
    # NOTE(review): keys have only second resolution (local, naive time), so
    # two runs in the same second with the same basenames overwrite each other.
    stamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
    input_s3_key = f'input_images/{stamp}-{os.path.basename(image_path)}'
    output_s3_key = f'output_meshes/{stamp}-{os.path.basename(output_file_path)}'

    input_s3_url = upload_to_s3(image_path, bucket_name, input_s3_key)
    output_s3_url = upload_to_s3(output_file_path, bucket_name, output_s3_key)
    print(f'Files uploaded to S3:\nInput Image: {input_s3_url}\nOutput Mesh: {output_s3_url}')
    return output_file_path