This project is a simple REST API that lets users train a DreamBooth model, charges for usage, and runs inference on the trained model. It runs on top of Replicate and uses the following:
- Redis
- PostgreSQL
- Flyway (used for database migrations)
- gRPC
- Prometheus and Grafana
- Docker
- The user sends a POST request to `/api/v1/payments/create` with the following JSON body:
  - `plan_id` - Required. ID of the plan. Plans can differ in the number of images the user can generate and in other features.
  - `promocode_id` - Optional. Promocode ID in Stripe. If it is incorrect, the payment link is created without a promocode.
  - `version_id` - Required if `plan.is_init` is true. Every plan has an `is_init` column, which is used to identify whether the user is purchasing an add-on, for example an extra number of images to generate. If this value is not set or is incorrect, the user gets an error; otherwise the webhook receives the payment confirmation and the plan with the extra features is bound to the provided `version_id`.
- The user pays using the payment link and gets a `payment_id`, then prepares images and sends a request to `/api/v1/versions/train/{id}`, where `{id}` is the `payment_id` received after payment.
  - I have a middleware that sets a key in Redis to block extra requests from the same user (`payment_id`). It takes some time to prepare the data and receive a success response from Replicate (meaning that the training process has started), so without this lock it would be possible to abuse the system. Once the endpoint is done, the key is deleted from Redis; otherwise it expires automatically after 5 minutes.
  - Images are sent to imager, where they are checked for being more than 512x512 pixels and for having a supported content type. If an image does not meet the requirements, that is recorded in the response, so the user receives all issues after the first request. The checks run concurrently.
  - If everything is alright, imager builds a zip archive and uploads it to the bucket. After that, a request is sent to Replicate and the user gets a `version_id`.
- With the `version_id`, the user can call `/api/v1/versions/info/{id}` and get info about the version status. If the version is ready, the user gets extended info about it. At the same time, a cron task fetches running versions from Postgres and checks whether they are ready.
- Once a version's `pushed_at` field is not null, the user can send requests to the trained model and run inference on prompts using the endpoint `/api/v1/prompts/create/{id}`.
  - When the user sends a request, this endpoint is also locked by the same middleware, so there cannot be concurrent requests for the same model.
  - When you run inference on any model on Replicate, it returns a `prediction_id`, which is then used to get info about the running prediction. To continuously check the prediction status, I use a dedicated polling function.
  - After the prediction is done, I transfer the images from Replicate to Cloudflare using a solution similar to the one you have seen before; however, I added `sync.WaitGroup`, since I do not need to process the images until each of them is uploaded to Cloudflare.
- The user can send a request to `/api/v1/prompts/list/{id}` and get info about completed prompts. To achieve that in one query, I implemented a custom type.
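The custom type itself is not shown here, so the following is only one common way to do this in Go, not necessarily what the project does: a slice type that implements `sql.Scanner`, so that a single row can carry all images of a prompt, for example from a `json_agg(...)` column. The `ImageList` name and the table and column names in the query are hypothetical.

```go
package main

import (
	"database/sql"
	"encoding/json"
	"fmt"
)

// listPromptsQuery is a hypothetical query: table and column names are
// assumptions. It returns one row per prompt with its images aggregated.
const listPromptsQuery = `
SELECT p.id, p.text, json_agg(i.url) AS images
FROM prompts p
JOIN images i ON i.prompt_id = p.id
WHERE p.version_id = $1
GROUP BY p.id`

// ImageList is a hypothetical custom type that database/sql can scan
// a JSON array column into.
type ImageList []string

// Compile-time check that *ImageList satisfies sql.Scanner.
var _ sql.Scanner = (*ImageList)(nil)

// Scan decodes a JSON array delivered by the driver as []byte or string.
// A SQL NULL becomes a nil list.
func (l *ImageList) Scan(src any) error {
	switch v := src.(type) {
	case nil:
		*l = nil
		return nil
	case []byte:
		return json.Unmarshal(v, l)
	case string:
		return json.Unmarshal([]byte(v), l)
	default:
		return fmt.Errorf("ImageList: unsupported source type %T", src)
	}
}

func main() {
	// Simulate what rows.Scan(&id, &text, &images) would pass for the third column.
	var images ImageList
	if err := images.Scan([]byte(`["https://example.com/1.png","https://example.com/2.png"]`)); err != nil {
		panic(err)
	}
	fmt.Println(len(images), images[0])
}
```

With a type like this, listing prompts needs no second query per prompt to load its images.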