feat: restrict non credential provider interactions (#871)

* wip: add provider field to sqlite user table

* feat: disable invites when credentials provider is not used

* wip: add migration for provider field in user table with sqlite

* wip: remove fields that can not be modified by non credential users

* wip: make username, mail and avatar disabled instead of hidden

* wip: external users membership of group cannot be managed manually

* feat: add alerts to inform about disabled fields and managing group members

* wip: add mysql migration for provider on user table

* chore: fix format issues

* chore: address pull request feedback

* fix: build issue

* fix: deepsource issues

* fix: tests not working

* feat: restrict login to specific auth providers

* chore: address pull request feedback

* fix: deepsource issue
This commit is contained in:
Meier Lukas
2024-07-27 11:38:51 +02:00
committed by GitHub
parent eba4052522
commit 6f7327b774
36 changed files with 2989 additions and 116 deletions

View File

@@ -2,6 +2,7 @@ import { notFound } from "next/navigation";
import { Card, Center, Stack, Text, Title } from "@mantine/core";
import { auth } from "@homarr/auth/next";
import { isProviderEnabled } from "@homarr/auth/server";
import { and, db, eq } from "@homarr/db";
import { invites } from "@homarr/db/schema/sqlite";
import { getScopedI18n } from "@homarr/translation/server";
@@ -19,6 +20,8 @@ interface InviteUsagePageProps {
}
export default async function InviteUsagePage({ params, searchParams }: InviteUsagePageProps) {
if (!isProviderEnabled("credentials")) notFound();
const session = await auth();
if (session) notFound();

View File

@@ -22,6 +22,7 @@ import {
IconUsersGroup,
} from "@tabler/icons-react";
import { isProviderEnabled } from "@homarr/auth/server";
import { getScopedI18n } from "@homarr/translation/server";
import { MainHeader } from "~/components/layout/header";
@@ -65,6 +66,7 @@ export default async function ManageLayout({ children }: PropsWithChildren) {
label: t("items.users.items.invites"),
icon: IconMailForward,
href: "/manage/users/invites",
hidden: !isProviderEnabled("credentials"),
},
{
label: t("items.users.items.groups"),

View File

@@ -3,6 +3,7 @@ import { Card, Group, SimpleGrid, Space, Stack, Text } from "@mantine/core";
import { IconArrowRight } from "@tabler/icons-react";
import { api } from "@homarr/api/server";
import { isProviderEnabled } from "@homarr/auth/server";
import { getScopedI18n } from "@homarr/translation/server";
import { DynamicBreadcrumb } from "~/components/navigation/dynamic-breadcrumb";
@@ -14,6 +15,7 @@ interface LinkProps {
subtitle: string;
count: number;
href: string;
hidden?: boolean;
}
export async function generateMetadata() {
@@ -42,6 +44,7 @@ export default async function ManagementPage() {
title: t("statistic.createUser"),
},
{
hidden: !isProviderEnabled("credentials"),
count: statistics.countInvites,
href: "/manage/users/invites",
subtitle: t("statisticLabel.authentication"),
@@ -72,24 +75,27 @@ export default async function ManagementPage() {
<HeroBanner />
<Space h="md" />
<SimpleGrid cols={{ xs: 1, sm: 2, md: 3 }}>
{links.map((link, index) => (
<Card component={Link} href={link.href} key={`link-${index}`} withBorder>
<Group justify="space-between" wrap="nowrap">
<Group wrap="nowrap">
<Text size="2.4rem" fw="bolder">
{link.count}
</Text>
<Stack gap={0}>
<Text c="red" size="xs">
{link.subtitle}
</Text>
<Text fw="bold">{link.title}</Text>
</Stack>
</Group>
<IconArrowRight />
</Group>
</Card>
))}
{links.map(
(link) =>
!link.hidden && (
<Card component={Link} href={link.href} key={link.href} withBorder>
<Group justify="space-between" wrap="nowrap">
<Group wrap="nowrap">
<Text size="2.4rem" fw="bolder">
{link.count}
</Text>
<Stack gap={0}>
<Text c="red" size="xs">
{link.subtitle}
</Text>
<Text fw="bold">{link.title}</Text>
</Stack>
</Group>
<IconArrowRight />
</Group>
</Card>
),
)}
</SimpleGrid>
</>
);

View File

@@ -93,24 +93,38 @@ export const UserProfileAvatarForm = ({ user }: UserProfileAvatarForm) => {
});
}, [mutate, user.id, openConfirmModal, tManageAvatar]);
const isCredentialsUser = user.provider === "credentials";
return (
<Box pos="relative">
<Menu opened={opened} keepMounted onChange={toggle} position="bottom-start" withArrow>
<Menu
opened={opened}
keepMounted
onChange={isCredentialsUser ? toggle : undefined}
position="bottom-start"
withArrow
>
<Menu.Target>
<UnstyledButton onClick={toggle}>
<UnstyledButton
component={isCredentialsUser ? undefined : "div"}
style={{ cursor: !isCredentialsUser ? "default" : undefined }}
onClick={isCredentialsUser ? toggle : undefined}
>
<UserAvatar user={user} size={200} />
<Button
component="div"
pos="absolute"
bottom={0}
left={0}
size="compact-md"
fw="normal"
variant="default"
leftSection={<IconPencil size={18} stroke={1.5} />}
>
{t("common.action.edit")}
</Button>
{isCredentialsUser && (
<Button
component="div"
pos="absolute"
bottom={0}
left={0}
size="compact-md"
fw="normal"
variant="default"
leftSection={<IconPencil size={18} stroke={1.5} />}
>
{t("common.action.edit")}
</Button>
)}
</UnstyledButton>
</Menu.Target>
<Menu.Dropdown>

View File

@@ -51,8 +51,12 @@ export const UserProfileForm = ({ user }: UserProfileFormProps) => {
},
});
// Only credentials users can edit their profile
const isProviderCredentials = user.provider === "credentials";
const handleSubmit = useCallback(
(values: FormType) => {
if (!isProviderCredentials) return;
mutate({
...values,
id: user.id,
@@ -64,14 +68,25 @@ export const UserProfileForm = ({ user }: UserProfileFormProps) => {
return (
<form onSubmit={form.onSubmit(handleSubmit)}>
<Stack>
<TextInput label={t("user.field.username.label")} withAsterisk {...form.getInputProps("name")} />
<TextInput label={t("user.field.email.label")} {...form.getInputProps("email")} />
<TextInput
disabled={!isProviderCredentials}
label={t("user.field.username.label")}
withAsterisk
{...form.getInputProps("name")}
/>
<TextInput
disabled={!isProviderCredentials}
label={t("user.field.email.label")}
{...form.getInputProps("email")}
/>
<Group justify="end">
<Button type="submit" color="teal" disabled={!form.isDirty()} loading={isPending}>
{t("common.action.saveChanges")}
</Button>
</Group>
{isProviderCredentials && (
<Group justify="end">
<Button type="submit" color="teal" disabled={!form.isDirty()} loading={isPending}>
{t("common.action.saveChanges")}
</Button>
</Group>
)}
</Stack>
</form>
);

View File

@@ -1,5 +1,6 @@
import { notFound } from "next/navigation";
import { Box, Group, Stack, Title } from "@mantine/core";
import { Alert, Box, Group, Stack, Title } from "@mantine/core";
import { IconExclamationCircle } from "@tabler/icons-react";
import { api } from "@homarr/api/server";
import { auth } from "@homarr/auth/next";
@@ -53,8 +54,14 @@ export default async function EditUserPage({ params }: Props) {
notFound();
}
const isCredentialsUser = user.provider === "credentials";
return (
<Stack>
<Alert variant="light" color="yellow" icon={<IconExclamationCircle size="1rem" stroke={1.5} />}>
{t("management.page.user.fieldsDisabledExternalProvider")}
</Alert>
<Title>{tGeneral("title")}</Title>
<Group gap="xl">
<Box flex={1}>
@@ -67,13 +74,15 @@ export default async function EditUserPage({ params }: Props) {
<ProfileLanguageChange />
<DangerZoneRoot>
<DangerZoneItem
label={t("user.action.delete.label")}
description={t("user.action.delete.description")}
action={<DeleteUserButton user={user} />}
/>
</DangerZoneRoot>
{isCredentialsUser && (
<DangerZoneRoot>
<DangerZoneItem
label={t("user.action.delete.label")}
description={t("user.action.delete.description")}
action={<DeleteUserButton user={user} />}
/>
</DangerZoneRoot>
)}
</Stack>
);
}

View File

@@ -28,6 +28,8 @@ export default async function Layout({ children, params }: PropsWithChildren<Lay
notFound();
}
const isCredentialsUser = user.provider === "credentials";
return (
<ManageContainer size="xl">
<DynamicBreadcrumb
@@ -57,11 +59,13 @@ export default async function Layout({ children, params }: PropsWithChildren<Lay
label={tUser("setting.general.title")}
icon={<IconSettings size="1rem" stroke={1.5} />}
/>
<NavigationLink
href={`/manage/users/${params.userId}/security`}
label={tUser("setting.security.title")}
icon={<IconShieldLock size="1rem" stroke={1.5} />}
/>
{isCredentialsUser && (
<NavigationLink
href={`/manage/users/${params.userId}/security`}
label={tUser("setting.security.title")}
icon={<IconShieldLock size="1rem" stroke={1.5} />}
/>
)}
</Stack>
</Stack>
</GridCol>

View File

@@ -28,6 +28,10 @@ export default async function UserSecurityPage({ params }: Props) {
notFound();
}
if (user.provider !== "credentials") {
notFound();
}
return (
<Stack>
<Title>{tSecurity("title")}</Title>

View File

@@ -1,8 +1,11 @@
import Link from "next/link";
import { Anchor, Center, Group, Stack, Table, TableTbody, TableTd, TableTr, Text, Title } from "@mantine/core";
import { Alert, Anchor, Center, Group, Stack, Table, TableTbody, TableTd, TableTr, Text, Title } from "@mantine/core";
import { IconExclamationCircle } from "@tabler/icons-react";
import type { RouterOutputs } from "@homarr/api";
import { api } from "@homarr/api/server";
import { env } from "@homarr/auth/env.mjs";
import { isProviderEnabled } from "@homarr/auth/server";
import { getI18n, getScopedI18n } from "@homarr/translation/server";
import { SearchInput, UserAvatar } from "@homarr/ui";
@@ -28,9 +31,22 @@ export default async function GroupsDetailPage({ params, searchParams }: GroupsD
group.members.filter((member) => member.name?.toLowerCase().includes(searchParams.search!.trim().toLowerCase()))
: group.members;
const providerTypes = isProviderEnabled("credentials")
? env.AUTH_PROVIDERS.length > 1
? "mixed"
: "credentials"
: "external";
return (
<Stack>
<Title>{tMembers("title")}</Title>
{providerTypes !== "credentials" && (
<Alert variant="light" color="yellow" icon={<IconExclamationCircle size="1rem" stroke={1.5} />}>
{t(`group.memberNotice.${providerTypes}`)}
</Alert>
)}
<Group justify="space-between">
<SearchInput
placeholder={t("common.rtl", {
@@ -39,7 +55,9 @@ export default async function GroupsDetailPage({ params, searchParams }: GroupsD
})}
defaultValue={searchParams.search}
/>
<AddGroupMember groupId={group.id} presentUserIds={group.members.map((member) => member.id)} />
{isProviderEnabled("credentials") && (
<AddGroupMember groupId={group.id} presentUserIds={group.members.map((member) => member.id)} />
)}
</Group>
{filteredMembers.length === 0 && (
<Center py="sm">
@@ -60,7 +78,7 @@ export default async function GroupsDetailPage({ params, searchParams }: GroupsD
}
interface RowProps {
member: RouterOutputs["group"]["getPaginated"]["items"][number]["members"][number];
member: RouterOutputs["group"]["getById"]["members"][number];
groupId: string;
}
@@ -70,13 +88,13 @@ const Row = ({ member, groupId }: RowProps) => {
<TableTd>
<Group>
<UserAvatar size="sm" user={member} />
<Anchor component={Link} href={`/manage/users/${member.id}`}>
<Anchor component={Link} href={`/manage/users/${member.id}/general`}>
{member.name}
</Anchor>
</Group>
</TableTd>
<TableTd w={100}>
<RemoveGroupMember user={member} groupId={groupId} />
{member.provider === "credentials" && <RemoveGroupMember user={member} groupId={groupId} />}
</TableTd>
</TableTr>
);

View File

@@ -1,9 +1,16 @@
import { notFound } from "next/navigation";
import { api } from "@homarr/api/server";
import { isProviderEnabled } from "@homarr/auth/server";
import { DynamicBreadcrumb } from "~/components/navigation/dynamic-breadcrumb";
import { InviteListComponent } from "./_components/invite-list";
export default async function InvitesOverviewPage() {
if (!isProviderEnabled("credentials")) {
notFound();
}
const initialInvites = await api.invite.getAll();
return (
<>

View File

@@ -22,18 +22,24 @@ export const MainNavigation = ({ headerSection, footerSection, links }: MainNavi
component={ScrollArea}
>
{links.map((link, index) => {
if (link.hidden) {
return null;
}
const { icon: TablerIcon, ...props } = link;
const Icon = <TablerIcon size={20} stroke={1.5} />;
let clientLink: ClientNavigationLink;
if ("items" in props) {
clientLink = {
...props,
items: props.items.map((item) => {
return {
...item,
icon: <item.icon size={20} stroke={1.5} />,
};
}),
items: props.items
.filter((item) => !item.hidden)
.map((item) => {
return {
...item,
icon: <item.icon size={20} stroke={1.5} />,
};
}),
} as ClientNavigationLink;
} else {
clientLink = props as ClientNavigationLink;
@@ -49,6 +55,7 @@ export const MainNavigation = ({ headerSection, footerSection, links }: MainNavi
interface CommonNavigationLinkProps {
label: string;
icon: TablerIcon;
hidden?: boolean;
}
interface NavigationLinkHref extends CommonNavigationLinkProps {

View File

@@ -57,6 +57,7 @@ export const groupRouter = createTRPCRouter({
name: true,
email: true,
image: true,
provider: true,
},
},
},

View File

@@ -6,9 +6,11 @@ import { invites } from "@homarr/db/schema/sqlite";
import { z } from "@homarr/validation";
import { createTRPCRouter, protectedProcedure } from "../trpc";
import { throwIfCredentialsDisabled } from "./invite/checks";
export const inviteRouter = createTRPCRouter({
getAll: protectedProcedure.query(async ({ ctx }) => {
throwIfCredentialsDisabled();
const dbInvites = await ctx.db.query.invites.findMany({
orderBy: asc(invites.expirationDate),
columns: {
@@ -32,6 +34,7 @@ export const inviteRouter = createTRPCRouter({
}),
)
.mutation(async ({ ctx, input }) => {
throwIfCredentialsDisabled();
const id = createId();
const token = randomBytes(20).toString("hex");
@@ -54,6 +57,7 @@ export const inviteRouter = createTRPCRouter({
}),
)
.mutation(async ({ ctx, input }) => {
throwIfCredentialsDisabled();
const dbInvite = await ctx.db.query.invites.findFirst({
where: eq(invites.id, input.id),
});

View File

@@ -0,0 +1,12 @@
import { TRPCError } from "@trpc/server";
import { env } from "@homarr/auth/env.mjs";
export const throwIfCredentialsDisabled = () => {
if (!env.AUTH_PROVIDERS.includes("credentials")) {
throw new TRPCError({
code: "FORBIDDEN",
message: "Credentials provider is disabled",
});
}
};

View File

@@ -170,8 +170,8 @@ describe("byId should return group by id including members and permissions", ()
expect(result.members.length).toBe(1);
const userKeys = Object.keys(result.members[0] ?? {});
expect(userKeys.length).toBe(4);
expect(["id", "name", "email", "image"].some((key) => userKeys.includes(key)));
expect(userKeys.length).toBe(5);
expect(["id", "name", "email", "image", "provider"].some((key) => userKeys.includes(key)));
expect(result.permissions.length).toBe(1);
expect(result.permissions[0]).toBe("admin");
});

View File

@@ -22,6 +22,15 @@ vi.mock("@homarr/auth", async () => {
return { ...mod, auth: () => ({}) as Session };
});
// Mock the env module to return the credentials provider
vi.mock("@homarr/auth/env.mjs", () => {
return {
env: {
AUTH_PROVIDERS: ["credentials"],
},
};
});
describe("all should return all existing invites without sensitive informations", () => {
test("invites should not contain sensitive informations", async () => {
// Arrange

View File

@@ -13,6 +13,15 @@ vi.mock("@homarr/auth", async () => {
return { ...mod, auth: () => ({}) as Session };
});
// Mock the env module to return the credentials provider
vi.mock("@homarr/auth/env.mjs", () => {
return {
env: {
AUTH_PROVIDERS: ["credentials"],
},
};
});
describe("initUser should initialize the first user", () => {
it("should throw an error if a user already exists", async () => {
const db = createDb();
@@ -230,6 +239,7 @@ describe("editProfile shoud update user", () => {
password: null,
image: null,
homeBoardId: null,
provider: "credentials",
});
});
@@ -270,6 +280,7 @@ describe("editProfile shoud update user", () => {
password: null,
image: null,
homeBoardId: null,
provider: "credentials",
});
});
});
@@ -294,6 +305,7 @@ describe("delete should delete user", () => {
password: null,
salt: null,
homeBoardId: null,
provider: "ldap" as const,
},
{
id: userToDelete,
@@ -314,6 +326,7 @@ describe("delete should delete user", () => {
password: null,
salt: null,
homeBoardId: null,
provider: "oidc" as const,
},
];

View File

@@ -4,12 +4,17 @@ import { createSaltAsync, hashPasswordAsync } from "@homarr/auth";
import type { Database } from "@homarr/db";
import { and, createId, eq, schema } from "@homarr/db";
import { groupMembers, groupPermissions, groups, invites, users } from "@homarr/db/schema/sqlite";
import type { SupportedAuthProvider } from "@homarr/definitions";
import { logger } from "@homarr/log";
import { validation, z } from "@homarr/validation";
import { createTRPCRouter, protectedProcedure, publicProcedure } from "../trpc";
import { throwIfCredentialsDisabled } from "./invite/checks";
export const userRouter = createTRPCRouter({
initUser: publicProcedure.input(validation.user.init).mutation(async ({ ctx, input }) => {
throwIfCredentialsDisabled();
const firstUser = await ctx.db.query.users.findFirst({
columns: {
id: true,
@@ -40,6 +45,7 @@ export const userRouter = createTRPCRouter({
});
}),
register: publicProcedure.input(validation.user.registrationApi).mutation(async ({ ctx, input }) => {
throwIfCredentialsDisabled();
const inviteWhere = and(eq(invites.id, input.inviteId), eq(invites.token, input.token));
const dbInvite = await ctx.db.query.invites.findFirst({
columns: {
@@ -56,7 +62,7 @@ export const userRouter = createTRPCRouter({
});
}
await checkUsernameAlreadyTakenAndThrowAsync(ctx.db, input.username);
await checkUsernameAlreadyTakenAndThrowAsync(ctx.db, "credentials", input.username);
await createUserAsync(ctx.db, input);
@@ -64,7 +70,8 @@ export const userRouter = createTRPCRouter({
await ctx.db.delete(invites).where(inviteWhere);
}),
create: publicProcedure.input(validation.user.create).mutation(async ({ ctx, input }) => {
await checkUsernameAlreadyTakenAndThrowAsync(ctx.db, input.username);
throwIfCredentialsDisabled();
await checkUsernameAlreadyTakenAndThrowAsync(ctx.db, "credentials", input.username);
await createUserAsync(ctx.db, input);
}),
@@ -93,6 +100,7 @@ export const userRouter = createTRPCRouter({
columns: {
id: true,
image: true,
provider: true,
},
where: eq(users.id, input.userId),
});
@@ -104,6 +112,13 @@ export const userRouter = createTRPCRouter({
});
}
if (user.provider !== "credentials") {
throw new TRPCError({
code: "FORBIDDEN",
message: "Profile image can not be changed for users with external providers",
});
}
await ctx.db
.update(users)
.set({
@@ -112,13 +127,14 @@ export const userRouter = createTRPCRouter({
.where(eq(users.id, input.userId));
}),
getAll: publicProcedure.query(async ({ ctx }) => {
return ctx.db.query.users.findMany({
return await ctx.db.query.users.findMany({
columns: {
id: true,
name: true,
email: true,
emailVerified: true,
image: true,
provider: true,
},
});
}),
@@ -139,6 +155,7 @@ export const userRouter = createTRPCRouter({
email: true,
emailVerified: true,
image: true,
provider: true,
},
where: eq(users.id, input.userId),
});
@@ -154,7 +171,7 @@ export const userRouter = createTRPCRouter({
}),
editProfile: publicProcedure.input(validation.user.editProfile).mutation(async ({ input, ctx }) => {
const user = await ctx.db.query.users.findFirst({
columns: { email: true },
columns: { email: true, provider: true },
where: eq(users.id, input.id),
});
@@ -165,7 +182,14 @@ export const userRouter = createTRPCRouter({
});
}
await checkUsernameAlreadyTakenAndThrowAsync(ctx.db, input.name, input.id);
if (user.provider !== "credentials") {
throw new TRPCError({
code: "FORBIDDEN",
message: "Username and email can not be changed for users with external providers",
});
}
await checkUsernameAlreadyTakenAndThrowAsync(ctx.db, "credentials", input.name, input.id);
const emailDirty = input.email && user.email !== input.email;
await ctx.db
@@ -190,26 +214,38 @@ export const userRouter = createTRPCRouter({
});
}
const dbUser = await ctx.db.query.users.findFirst({
columns: {
id: true,
password: true,
salt: true,
provider: true,
},
where: eq(users.id, input.userId),
});
if (!dbUser) {
throw new TRPCError({
code: "NOT_FOUND",
message: "User not found",
});
}
if (dbUser.provider !== "credentials") {
throw new TRPCError({
code: "FORBIDDEN",
message: "Password can not be changed for users with external providers",
});
}
// Admins can change the password of other users without providing the previous password
const isPreviousPasswordRequired = ctx.session.user.id === input.userId;
logger.info(
`User ${user.id} is changing password for user ${input.userId}, previous password is required: ${isPreviousPasswordRequired}`,
);
if (isPreviousPasswordRequired) {
const dbUser = await ctx.db.query.users.findFirst({
columns: {
id: true,
password: true,
salt: true,
},
where: eq(users.id, input.userId),
});
if (!dbUser) {
throw new TRPCError({
code: "NOT_FOUND",
message: "User not found",
});
}
const previousPasswordHash = await hashPasswordAsync(input.previousPassword, dbUser.salt ?? "");
const isValid = previousPasswordHash === dbUser.password;
@@ -249,9 +285,14 @@ const createUserAsync = async (db: Database, input: z.infer<typeof validation.us
return userId;
};
const checkUsernameAlreadyTakenAndThrowAsync = async (db: Database, username: string, ignoreId?: string) => {
const checkUsernameAlreadyTakenAndThrowAsync = async (
db: Database,
provider: SupportedAuthProvider,
username: string,
ignoreId?: string,
) => {
const user = await db.query.users.findFirst({
where: eq(users.name, username.toLowerCase()),
where: and(eq(users.name, username.toLowerCase()), eq(users.provider, provider)),
});
if (!user) return;

View File

@@ -0,0 +1,9 @@
import type { SupportedAuthProvider } from "@homarr/definitions";
import { env } from "../env.mjs";
export const isProviderEnabled = (provider: SupportedAuthProvider) => {
// The question mark is placed there because isProviderEnabled is called during static build of about page
// eslint-disable-next-line @typescript-eslint/no-unnecessary-condition
return env.AUTH_PROVIDERS?.includes(provider);
};

View File

@@ -1,7 +1,7 @@
import bcrypt from "bcrypt";
import type { Database } from "@homarr/db";
import { eq } from "@homarr/db";
import { and, eq } from "@homarr/db";
import { users } from "@homarr/db/schema/sqlite";
import { logger } from "@homarr/log";
import type { validation, z } from "@homarr/validation";
@@ -11,7 +11,7 @@ export const authorizeWithBasicCredentialsAsync = async (
credentials: z.infer<typeof validation.user.signIn>,
) => {
const user = await db.query.users.findFirst({
where: eq(users.name, credentials.name),
where: and(eq(users.name, credentials.name), eq(users.provider, "credentials")),
});
if (!user?.password) {

View File

@@ -1,7 +1,8 @@
import type { Adapter } from "@auth/core/adapters";
import { CredentialsSignin } from "@auth/core/errors";
import { createId } from "@homarr/db";
import type { Database } from "@homarr/db";
import { and, createId, eq } from "@homarr/db";
import { users } from "@homarr/db/schema/sqlite";
import { logger } from "@homarr/log";
import type { validation } from "@homarr/validation";
import { z } from "@homarr/validation";
@@ -10,7 +11,7 @@ import { env } from "../../../env.mjs";
import { LdapClient } from "../ldap-client";
export const authorizeWithLdapCredentialsAsync = async (
adapter: Adapter,
db: Database,
credentials: z.infer<typeof validation.user.signIn>,
) => {
logger.info(`user ${credentials.name} is trying to log in using LDAP. Connecting to LDAP server...`);
@@ -89,18 +90,30 @@ export const authorizeWithLdapCredentialsAsync = async (
await client.disconnectAsync();
// Create or update user in the database
let user = await adapter.getUserByEmail?.(mailResult.data);
let user = await db.query.users.findFirst({
columns: {
id: true,
name: true,
image: true,
email: true,
emailVerified: true,
provider: true,
},
where: and(eq(users.email, mailResult.data), eq(users.provider, "ldap")),
});
if (!user) {
logger.info(`User ${credentials.name} not found in the database. Creating...`);
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
user = await adapter.createUser!({
user = {
id: createId(),
name: credentials.name,
email: mailResult.data,
emailVerified: new Date(), // assume email is verified
});
image: null,
provider: "ldap",
};
await db.insert(users).values(user);
logger.info(`User ${credentials.name} created successfully.`);
}
@@ -108,11 +121,9 @@ export const authorizeWithLdapCredentialsAsync = async (
if (user.name !== credentials.name) {
logger.warn(`User ${credentials.name} found in the database but with different name. Updating...`);
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
user = await adapter.updateUser!({
id: user.id,
name: credentials.name,
});
user.name = credentials.name;
await db.update(users).set({ name: user.name }).where(eq(users.id, user.id));
logger.info(`User ${credentials.name} updated successfully.`);
}

View File

@@ -3,7 +3,6 @@ import type Credentials from "@auth/core/providers/credentials";
import type { Database } from "@homarr/db";
import { validation } from "@homarr/validation";
import { adapter } from "../../adapter";
import { authorizeWithBasicCredentialsAsync } from "./authorization/basic-authorization";
import { authorizeWithLdapCredentialsAsync } from "./authorization/ldap-authorization";
@@ -32,7 +31,7 @@ export const createCredentialsConfiguration = (db: Database) =>
const data = await validation.user.signIn.parseAsync(credentials);
if (data.credentialType === "ldap") {
return await authorizeWithLdapCredentialsAsync(adapter, data).catch(() => null);
return await authorizeWithLdapCredentialsAsync(db, data).catch(() => null);
}
return await authorizeWithBasicCredentialsAsync(db, data);

View File

@@ -32,6 +32,7 @@ export const OidcProvider = (headers: ReadonlyHeaders | null): OIDCConfig<Profil
// Use the name as the username if the preferred_username is an email address
name: profile.preferred_username.includes("@") ? profile.name : profile.preferred_username,
email: profile.email,
provider: "oidc",
};
},
});

View File

@@ -1,9 +1,8 @@
import type { Adapter } from "@auth/core/adapters";
import { CredentialsSignin } from "@auth/core/errors";
import { DrizzleAdapter } from "@auth/drizzle-adapter";
import { describe, expect, test, vi } from "vitest";
import { createId, eq } from "@homarr/db";
import type { Database } from "@homarr/db";
import { and, createId, eq } from "@homarr/db";
import { users } from "@homarr/db/schema/sqlite";
import { createDb } from "@homarr/db/test";
@@ -32,7 +31,7 @@ describe("authorizeWithLdapCredentials", () => {
// Act
const act = () =>
authorizeWithLdapCredentialsAsync(null as unknown as Adapter, {
authorizeWithLdapCredentialsAsync(null as unknown as Database, {
name: "test",
password: "test",
credentialType: "ldap",
@@ -55,7 +54,7 @@ describe("authorizeWithLdapCredentials", () => {
// Act
const act = () =>
authorizeWithLdapCredentialsAsync(null as unknown as Adapter, {
authorizeWithLdapCredentialsAsync(null as unknown as Database, {
name: "test",
password: "test",
credentialType: "ldap",
@@ -85,7 +84,7 @@ describe("authorizeWithLdapCredentials", () => {
// Act
const act = () =>
authorizeWithLdapCredentialsAsync(null as unknown as Adapter, {
authorizeWithLdapCredentialsAsync(null as unknown as Database, {
name: "test",
password: "test",
credentialType: "ldap",
@@ -118,7 +117,7 @@ describe("authorizeWithLdapCredentials", () => {
// Act
const act = () =>
authorizeWithLdapCredentialsAsync(null as unknown as Adapter, {
authorizeWithLdapCredentialsAsync(null as unknown as Database, {
name: "test",
password: "test",
credentialType: "ldap",
@@ -132,7 +131,6 @@ describe("authorizeWithLdapCredentials", () => {
test("should authorize user with correct credentials and create user", async () => {
// Arrange
const db = createDb();
const adapter = DrizzleAdapter(db);
const spy = vi.spyOn(ldapClient, "LdapClient");
spy.mockImplementation(
() =>
@@ -151,7 +149,7 @@ describe("authorizeWithLdapCredentials", () => {
);
// Act
const result = await authorizeWithLdapCredentialsAsync(adapter, {
const result = await authorizeWithLdapCredentialsAsync(db, {
name: "test",
password: "test",
credentialType: "ldap",
@@ -166,13 +164,68 @@ describe("authorizeWithLdapCredentials", () => {
expect(dbUser?.id).toBe(result.id);
expect(dbUser?.email).toBe("test@gmail.com");
expect(dbUser?.emailVerified).not.toBeNull();
expect(dbUser?.provider).toBe("ldap");
});
test("should authorize user with correct credentials and create user with same email when credentials user already exists", async () => {
// Arrange
const db = createDb();
const spy = vi.spyOn(ldapClient, "LdapClient");
const salt = await createSaltAsync();
spy.mockImplementation(
() =>
({
bindAsync: vi.fn(() => Promise.resolve()),
searchAsync: vi.fn(() =>
Promise.resolve([
{
dn: "test",
mail: "test@gmail.com",
},
]),
),
disconnectAsync: vi.fn(),
}) as unknown as ldapClient.LdapClient,
);
await db.insert(users).values({
id: createId(),
name: "test",
salt,
password: await hashPasswordAsync("test", salt),
email: "test@gmail.com",
provider: "credentials",
});
// Act
const result = await authorizeWithLdapCredentialsAsync(db, {
name: "test",
password: "test",
credentialType: "ldap",
});
// Assert
expect(result.name).toBe("test");
const dbUser = await db.query.users.findFirst({
where: and(eq(users.name, "test"), eq(users.provider, "ldap")),
});
expect(dbUser).toBeDefined();
expect(dbUser?.id).toBe(result.id);
expect(dbUser?.email).toBe("test@gmail.com");
expect(dbUser?.emailVerified).not.toBeNull();
expect(dbUser?.provider).toBe("ldap");
const credentialsUser = await db.query.users.findFirst({
where: and(eq(users.name, "test"), eq(users.provider, "credentials")),
});
expect(credentialsUser).toBeDefined();
expect(credentialsUser?.id).not.toBe(result.id);
});
test("should authorize user with correct credentials and update name", async () => {
// Arrange
const userId = createId();
const db = createDb();
const adapter = DrizzleAdapter(db);
const salt = await createSaltAsync();
await db.insert(users).values({
id: userId,
@@ -180,10 +233,11 @@ describe("authorizeWithLdapCredentials", () => {
salt,
password: await hashPasswordAsync("test", salt),
email: "test@gmail.com",
provider: "ldap",
});
// Act
const result = await authorizeWithLdapCredentialsAsync(adapter, {
const result = await authorizeWithLdapCredentialsAsync(db, {
name: "test",
password: "test",
credentialType: "ldap",
@@ -200,5 +254,6 @@ describe("authorizeWithLdapCredentials", () => {
expect(dbUser?.id).toBe(userId);
expect(dbUser?.name).toBe("test");
expect(dbUser?.email).toBe("test@gmail.com");
expect(dbUser?.provider).toBe("ldap");
});
});

View File

@@ -1 +1,2 @@
export { hasQueryAccessToIntegrationsAsync } from "./permissions/integration-query-permissions";
export { isProviderEnabled } from "./providers/check-provider";

View File

@@ -0,0 +1 @@
ALTER TABLE `user` ADD `provider` varchar(64) DEFAULT 'credentials' NOT NULL;

File diff suppressed because it is too large Load Diff

View File

@@ -36,6 +36,13 @@
"when": 1720113913876,
"tag": "0004_noisy_giant_girl",
"breakpoints": true
},
{
"idx": 5,
"version": "5",
"when": 1722068832607,
"tag": "0005_soft_microbe",
"breakpoints": true
}
]
}

View File

@@ -0,0 +1 @@
ALTER TABLE `user` ADD `provider` text DEFAULT 'credentials' NOT NULL;

File diff suppressed because it is too large Load Diff

View File

@@ -36,6 +36,13 @@
"when": 1720036615408,
"tag": "0004_peaceful_red_ghost",
"breakpoints": true
},
{
"idx": 5,
"version": "6",
"when": 1722014142492,
"tag": "0005_lean_random",
"breakpoints": true
}
]
}

View File

@@ -13,6 +13,7 @@ import type {
IntegrationPermission,
IntegrationSecretKind,
SectionKind,
SupportedAuthProvider,
WidgetKind,
} from "@homarr/definitions";
import { backgroundImageAttachments, backgroundImageRepeats, backgroundImageSizes } from "@homarr/definitions";
@@ -25,6 +26,7 @@ export const users = mysqlTable("user", {
image: text("image"),
password: text("password"),
salt: text("salt"),
provider: varchar("provider", { length: 64 }).$type<SupportedAuthProvider>().default("credentials").notNull(),
homeBoardId: varchar("homeBoardId", { length: 64 }).references((): AnyMySqlColumn => boards.id, {
onDelete: "set null",
}),

View File

@@ -15,6 +15,7 @@ import type {
IntegrationPermission,
IntegrationSecretKind,
SectionKind,
SupportedAuthProvider,
WidgetKind,
} from "@homarr/definitions";
@@ -26,6 +27,7 @@ export const users = sqliteTable("user", {
image: text("image"),
password: text("password"),
salt: text("salt"),
provider: text("provider").$type<SupportedAuthProvider>().default("credentials").notNull(),
homeBoardId: text("homeBoardId").references((): AnySQLiteColumn => boards.id, {
onDelete: "set null",
}),

View File

@@ -0,0 +1,2 @@
export const supportedAuthProviders = ["credentials", "oidc", "ldap"] as const;
export type SupportedAuthProvider = (typeof supportedAuthProviders)[number];

View File

@@ -4,3 +4,4 @@ export * from "./section";
export * from "./widget";
export * from "./permissions";
export * from "./docker";
export * from "./auth";

View File

@@ -195,6 +195,10 @@ export default {
},
},
},
memberNotice: {
mixed: "Some members are from external providers and cannot be managed here",
external: "All members are from external providers and cannot be managed here",
},
action: {
create: {
label: "New group",
@@ -1334,6 +1338,8 @@ export default {
},
user: {
back: "Back to users",
fieldsDisabledExternalProvider:
"Certain fields are disabled because they are managed by an external authentication provider.",
setting: {
general: {
title: "General",
@@ -1379,7 +1385,7 @@ export default {
},
},
invite: {
title: "Manager user invites",
title: "Manage user invites",
action: {
new: {
title: "New invite",