diff --git a/wren-ui/e2e/specs/connectTrino.spec.ts b/wren-ui/e2e/specs/connectTrino.spec.ts
index d4ef0ad25d..44c24c740d 100644
--- a/wren-ui/e2e/specs/connectTrino.spec.ts
+++ b/wren-ui/e2e/specs/connectTrino.spec.ts
@@ -34,6 +34,34 @@ test.describe('Test Trino data source', () => {
await expect(page).toHaveURL('/setup/models', { timeout: 60000 });
});
+ test('Connect Trino data source with OAuth2 successfully', async ({ page }) => {
+ await page.goto('/setup/connection');
+
+ await page.locator('button').filter({ hasText: 'Trino' }).click();
+
+ await page.getByLabel('Display name').click();
+ await page.getByLabel('Display name').fill('test-trino-oauth2');
+ await page.getByLabel('Host').click();
+ await page.getByLabel('Host').fill(testConfig.trino.host);
+ await page.getByLabel('Port').click();
+ await page.getByLabel('Port').fill(testConfig.trino.port);
+ await page.getByLabel('Catalog').click();
+ await page.getByLabel('Catalog').fill(testConfig.trino.catalog);
+ await page.getByLabel('Schema').click();
+ await page.getByLabel('Schema').fill(testConfig.trino.schema);
+ await page.getByLabel('Username').click();
+ await page.getByLabel('Username').fill(testConfig.trino.username);
+ await page.getByLabel('OAuth2 Client ID').click();
+ await page.getByLabel('OAuth2 Client ID').fill(testConfig.trino.clientId);
+ await page.getByLabel('OAuth2 Client Secret').click();
+ await page.getByLabel('OAuth2 Client Secret').fill(testConfig.trino.clientSecret);
+ await page.getByLabel('OAuth2 Token URL').click();
+ await page.getByLabel('OAuth2 Token URL').fill(testConfig.trino.tokenUrl);
+
+ await page.getByRole('button', { name: 'Next' }).click();
+ await expect(page).toHaveURL('/setup/models', { timeout: 60000 });
+ });
+
test('Setup all models', onboarding.setupModels);
test(
diff --git a/wren-ui/src/apollo/server/adaptors/ibisAdaptor.ts b/wren-ui/src/apollo/server/adaptors/ibisAdaptor.ts
index a36af9d8ea..dd3e363f64 100644
--- a/wren-ui/src/apollo/server/adaptors/ibisAdaptor.ts
+++ b/wren-ui/src/apollo/server/adaptors/ibisAdaptor.ts
@@ -53,6 +53,7 @@ export interface IbisTrinoConnectionInfo {
schema: string;
user: string;
password: string;
+ auth?: OAuth2Authentication;
}
export interface IbisSnowflakeConnectionInfo {
@@ -63,6 +64,12 @@ export interface IbisSnowflakeConnectionInfo {
schema: string;
}
+export interface OAuth2Authentication {
+ clientId: string;
+ clientSecret: string;
+ tokenUrl: string;
+}
+
export type IbisConnectionInfo =
| UrlBasedConnectionInfo
| HostBasedConnectionInfo
diff --git a/wren-ui/src/apollo/server/adaptors/tests/ibisAdaptor.test.ts b/wren-ui/src/apollo/server/adaptors/tests/ibisAdaptor.test.ts
index 5f3d30a6b2..325abb9cd0 100644
--- a/wren-ui/src/apollo/server/adaptors/tests/ibisAdaptor.test.ts
+++ b/wren-ui/src/apollo/server/adaptors/tests/ibisAdaptor.test.ts
@@ -90,6 +90,18 @@ describe('IbisAdaptor', () => {
schema: 'my-schema',
};
+ const mockOAuth2TrinoConnectionInfo = {
+ schemas: 'my-catalog.my-schema',
+ host: 'localhost',
+ port: 5450,
+ username: 'my-username',
+ auth: {
+ clientId: 'my-client-id',
+ clientSecret: 'my-client-secret',
+ tokenUrl: 'https://example.com/token',
+ },
+ };
+
const mockManifest: Manifest = {
catalog: 'wrenai', // eg: "test-catalog"
schema: 'wrenai', // eg: "test-schema"
@@ -274,6 +286,34 @@ describe('IbisAdaptor', () => {
);
});
+ it('should get trino constraints with OAuth2', async () => {
+ const mockResponse = { data: [] };
+ mockedAxios.post.mockResolvedValue(mockResponse);
+
+ const result = await ibisAdaptor.getConstraints(
+ DataSourceName.TRINO,
+ mockOAuth2TrinoConnectionInfo,
+ );
+
+ const { username, host, port, schemas, auth } = mockOAuth2TrinoConnectionInfo;
+ const schemasArray = schemas.split(',');
+ const [catalog, schema] = schemasArray[0].split('.');
+ const expectConnectionInfo = {
+ connectionUrl: `trino://${username}@${host}:${port}/${catalog}/${schema}`,
+ auth: {
+ clientId: auth.clientId,
+ clientSecret: auth.clientSecret,
+ tokenUrl: auth.tokenUrl,
+ },
+ };
+
+ expect(result).toEqual([]);
+ expect(mockedAxios.post).toHaveBeenCalledWith(
+ `${ibisServerEndpoint}/v2/connector/trino/metadata/constraints`,
+ { connectionInfo: expectConnectionInfo },
+ );
+ });
+
it('should get snowflake constraints', async () => {
const mockResponse = { data: [] };
mockedAxios.post.mockResolvedValue(mockResponse);
diff --git a/wren-ui/src/components/pages/setup/dataSources/TrinoProperties.tsx b/wren-ui/src/components/pages/setup/dataSources/TrinoProperties.tsx
index f0381a1573..549d02f091 100644
--- a/wren-ui/src/components/pages/setup/dataSources/TrinoProperties.tsx
+++ b/wren-ui/src/components/pages/setup/dataSources/TrinoProperties.tsx
@@ -89,6 +89,42 @@ export default function TrinoProperties({ mode }: Props) {
>