blob: b06595ce2be9f1cf97ece4e2cb10bdf778cc67e1 [file] [log] [blame]
use clap::ArgMatches;
use diesel::backend::Backend;
use diesel::query_builder::QueryBuilder;
use diesel::QueryResult;
use diesel_table_macro_syntax::{ColumnDef, TableDecl};
use std::collections::HashMap;
use std::error::Error;
use std::fs::File;
use std::io::Read;
use std::path::Path;
use syn::visit::Visit;
use crate::config::Config;
use crate::database::InferConnection;
use crate::infer_schema_internals::{
filter_table_names, load_table_names, ColumnDefinition, ColumnType, ForeignKeyConstraint,
TableData, TableName,
};
use crate::print_schema::DocConfig;
/// Returns the table of SQL type names that are treated as interchangeable
/// when comparing a database column type with a schema column type.
///
/// Each key maps to the list of alternative spellings accepted for it.
fn compatible_type_list() -> HashMap<&'static str, Vec<&'static str>> {
    // (canonical name, accepted alias)
    const ALIASES: [(&str, &str); 4] = [
        ("integer", "int4"),
        ("bigint", "int8"),
        ("smallint", "int2"),
        ("text", "varchar"),
    ];

    ALIASES
        .iter()
        .map(|&(canonical, alias)| (canonical, vec![alias]))
        .collect()
}
/// Computes the SQL needed to bring the connected database in line with the
/// `table!` / `joinable!` declarations found in the given schema file.
///
/// The schema file (resolved relative to the project root) is parsed with
/// `syn`, the database is introspected through [`InferConnection`], and the
/// two views are compared table by table.
///
/// Returns `(up_sql, down_sql)`: the statements that apply the change and the
/// statements that revert it, each statement followed by a newline.
///
/// # Errors
///
/// Returns an error if
/// * the filter configuration, project root, schema file or database
///   introspection cannot be loaded,
/// * the schema file or one of its `table!` / `joinable!` bodies fails to
///   parse,
/// * a table present on both sides has a different primary key,
/// * a table that would be dropped has a composite foreign key,
/// * SQL generation fails.
pub fn generate_sql_based_on_diff_schema(
    config: Config,
    matches: &ArgMatches,
    schema_file_path: &Path,
) -> Result<(String, String), Box<dyn Error + Send + Sync>> {
    let config = config.set_filter(matches)?;

    // Load and parse the schema file.
    let project_root = crate::find_project_root()?;
    let schema_path = project_root.join(schema_file_path);
    let mut schema_file = File::open(schema_path)?;
    let mut content = String::new();
    schema_file.read_to_string(&mut content)?;
    let syn_file = syn::parse_file(&content)?;

    let mut tables_from_schema = SchemaCollector::default();
    tables_from_schema.visit_file(&syn_file);

    // Load the current state of the database.
    let mut conn = InferConnection::from_matches(matches);
    let tables_from_database = filter_table_names(
        load_table_names(&mut conn, None)?,
        &config.print_schema.filter,
    );
    let foreign_keys =
        crate::infer_schema_internals::load_foreign_key_constraints(&mut conn, None)?;

    // Foreign keys found in the database, grouped by child table.
    let foreign_key_map =
        foreign_keys
            .into_iter()
            .fold(HashMap::<_, Vec<_>>::new(), |mut acc, t| {
                acc.entry(t.child_table.rust_name.clone())
                    .or_default()
                    .push(t);
                acc
            });

    // `joinable!` declarations from the schema file, grouped by child table.
    let mut expected_fk_map = tables_from_schema.joinable.into_iter().try_fold(
        HashMap::<_, Vec<_>>::new(),
        |mut acc, t| {
            t.map(|t| {
                acc.entry(t.child_table.to_string()).or_default().push(t);
                acc
            })
        },
    )?;

    // Explicit primary keys of every table in the schema file, keyed by the
    // table's Rust name. A malformed `table!` body is reported as an error
    // instead of panicking.
    let table_pk_key_list = tables_from_schema
        .table_decls
        .iter()
        .map(|t| {
            let t = t.as_ref().map_err(|e| e.clone())?;
            Ok((
                t.table_name.to_string(),
                t.primary_keys.as_ref().map(|keys| {
                    keys.keys
                        .iter()
                        .map(ToString::to_string)
                        .collect::<Vec<_>>()
                }),
            ))
        })
        .collect::<Result<HashMap<_, _>, syn::Error>>()?;

    // Tables declared in the schema file, keyed by lowercased SQL name so the
    // key matches the lookup by the database's SQL table name below (the Rust
    // name differs from it when `#[sql_name = "..."]` is used).
    let mut expected_schema_map = tables_from_schema
        .table_decls
        .into_iter()
        .map(|t| {
            let t = t?;
            Ok((t.sql_name.to_lowercase(), t))
        })
        .collect::<Result<HashMap<_, _>, syn::Error>>()?;

    let mut schema_diff = Vec::new();

    for table in tables_from_database {
        let columns = crate::infer_schema_internals::load_table_data(
            &mut conn,
            table.clone(),
            &crate::print_schema::ColumnSorting::OrdinalPosition,
            DocConfig::NoDocComments,
        )?;
        if let Some(t) = expected_schema_map.remove(&table.sql_name.to_lowercase()) {
            // Table exists on both sides: compare primary keys and columns.
            let mut primary_keys_in_db =
                crate::infer_schema_internals::get_primary_keys(&mut conn, &table)?;
            primary_keys_in_db.sort();
            let mut primary_keys_in_schema = t
                .primary_keys
                .map(|pk| pk.keys.iter().map(|k| k.to_string()).collect::<Vec<_>>())
                .unwrap_or_else(|| vec!["id".into()]);
            primary_keys_in_schema.sort();
            if primary_keys_in_db != primary_keys_in_schema {
                return Err("Cannot change primary keys with --diff-schema yet".into());
            }

            let mut expected_column_map = t
                .column_defs
                .into_iter()
                .map(|c| (c.sql_name.to_lowercase(), c))
                .collect::<HashMap<_, _>>();

            let mut added_columns = Vec::new();
            let mut removed_columns = Vec::new();
            let mut changed_columns = Vec::new();

            for c in columns.column_data {
                if let Some(def) = expected_column_map.remove(&c.sql_name.to_lowercase()) {
                    let tpe = ColumnType::for_column_def(&def)?;
                    if !is_same_type(&c.ty, tpe) {
                        changed_columns.push((c, def));
                    }
                } else {
                    removed_columns.push(c);
                }
            }

            // Whatever is left in the map exists only in the schema file.
            added_columns.extend(expected_column_map.into_values());

            schema_diff.push(SchemaDiff::ChangeTable {
                table: t.sql_name,
                added_columns,
                removed_columns,
                changed_columns,
            });
        } else {
            // Table exists only in the database: it has to be dropped.
            let foreign_keys = foreign_key_map
                .get(&table.rust_name)
                .cloned()
                .unwrap_or_default();
            if foreign_keys
                .iter()
                .any(|fk| fk.foreign_key_columns.len() != 1 || fk.primary_key_columns.len() != 1)
            {
                return Err(
                    "Tables with composite foreign keys are not supported by --diff-schema".into(),
                );
            }
            schema_diff.push(SchemaDiff::DropTable {
                table,
                columns,
                foreign_keys,
            });
        }
    }

    // Everything still in `expected_schema_map` exists only in the schema file.
    schema_diff.extend(expected_schema_map.into_values().map(|t| {
        let foreign_keys = expected_fk_map
            .remove(&t.table_name.to_string())
            .unwrap_or_default()
            .into_iter()
            .filter_map(|j| {
                // The foreign key points at the primary key of the *parent*
                // table, so that is the table whose key has to be looked up.
                let referenced_table = table_pk_key_list.get(&j.parent_table.to_string())?;
                match referenced_table {
                    // No explicit primary key in `table!` means `id`.
                    None => Some((j, "id".into())),
                    Some(pks) if pks.len() == 1 => Some((j, pks.first()?.to_string())),
                    // Composite primary keys cannot be referenced here.
                    Some(_) => None,
                }
            })
            .collect();
        SchemaDiff::CreateTable {
            to_create: t,
            foreign_keys,
        }
    }));

    let mut up_sql = String::new();
    let mut down_sql = String::new();

    for diff in schema_diff {
        let up = match conn {
            #[cfg(feature = "postgres")]
            InferConnection::Pg(_) => {
                let mut qb = diesel::pg::PgQueryBuilder::default();
                diff.generate_up_sql(&mut qb)?;
                qb.finish()
            }
            #[cfg(feature = "sqlite")]
            InferConnection::Sqlite(_) => {
                let mut qb = diesel::sqlite::SqliteQueryBuilder::default();
                diff.generate_up_sql(&mut qb)?;
                qb.finish()
            }
            #[cfg(feature = "mysql")]
            InferConnection::Mysql(_) => {
                let mut qb = diesel::mysql::MysqlQueryBuilder::default();
                diff.generate_up_sql(&mut qb)?;
                qb.finish()
            }
        };

        let down = match conn {
            #[cfg(feature = "postgres")]
            InferConnection::Pg(_) => {
                let mut qb = diesel::pg::PgQueryBuilder::default();
                diff.generate_down_sql(&mut qb)?;
                qb.finish()
            }
            #[cfg(feature = "sqlite")]
            InferConnection::Sqlite(_) => {
                let mut qb = diesel::sqlite::SqliteQueryBuilder::default();
                diff.generate_down_sql(&mut qb)?;
                qb.finish()
            }
            #[cfg(feature = "mysql")]
            InferConnection::Mysql(_) => {
                let mut qb = diesel::mysql::MysqlQueryBuilder::default();
                diff.generate_down_sql(&mut qb)?;
                qb.finish()
            }
        };

        up_sql += &up;
        up_sql += "\n";
        down_sql += &down;
        down_sql += "\n";
    }

    Ok((up_sql, down_sql))
}
/// Decides whether two column types should be considered equal.
///
/// The array, nullability, unsigned and max-length attributes must all be
/// equal. The schemas must be equal, with `pg_catalog` on one side and no
/// schema on the other also accepted. The type names must then either be
/// equal ignoring case or be listed as alternatives in
/// `compatible_type_list`.
fn is_same_type(ty: &ColumnType, tpe: ColumnType) -> bool {
    let same_attributes = ty.is_array == tpe.is_array
        && ty.is_nullable == tpe.is_nullable
        && ty.is_unsigned == tpe.is_unsigned
        && ty.max_length == tpe.max_length;
    if !same_attributes {
        return false;
    }

    let schemas_match = match (ty.schema.as_deref(), tpe.schema.as_deref()) {
        (left, right) if left == right => true,
        // An absent schema is treated as `pg_catalog`.
        (Some("pg_catalog"), None) | (None, Some("pg_catalog")) => true,
        _ => false,
    };
    if !schemas_match {
        return false;
    }

    let left_name = ty.sql_name.to_lowercase();
    let right_name = tpe.sql_name.to_lowercase();
    if left_name == right_name {
        return true;
    }

    // The names differ: consult the alias table, trying the first name as
    // the key and falling back to the second one.
    let aliases = compatible_type_list();
    if let Some(accepted) = aliases.get(left_name.as_str()) {
        return accepted.contains(&right_name.as_str());
    }
    if let Some(accepted) = aliases.get(right_name.as_str()) {
        return accepted.contains(&left_name.as_str());
    }
    false
}
/// A single difference between the database and the schema file, from which
/// the "up" and "down" SQL statements are generated.
#[allow(clippy::enum_variant_names)]
enum SchemaDiff {
/// The table was found in the database but has no matching `table!`
/// declaration; the up SQL drops it, the down SQL re-creates it.
DropTable {
// Name of the table as loaded from the database.
table: TableName,
// Column and primary key information loaded from the database,
// used to re-create the table in the down SQL.
columns: TableData,
// Foreign keys of this table as loaded from the database.
foreign_keys: Vec<ForeignKeyConstraint>,
},
/// The table is declared in the schema file but was not found in the
/// database; the up SQL creates it, the down SQL drops it.
CreateTable {
// The parsed `table!` declaration.
to_create: TableDecl,
// `joinable!` declarations for this table, each paired with the name
// of the column it references in the parent table.
foreign_keys: Vec<(Joinable, String)>,
},
/// The table exists on both sides but its columns differ.
ChangeTable {
// SQL name of the table, taken from the `table!` declaration.
table: String,
// Columns that are only present in the schema file.
added_columns: Vec<ColumnDef>,
// Columns that are only present in the database.
removed_columns: Vec<ColumnDefinition>,
// Columns present on both sides with differing types:
// (definition from the database, definition from the schema file).
changed_columns: Vec<(ColumnDefinition, ColumnDef)>,
},
}
impl SchemaDiff {
fn generate_up_sql<DB>(&self, query_builder: &mut impl QueryBuilder<DB>) -> QueryResult<()>
where
DB: Backend,
{
match self {
SchemaDiff::DropTable { table, .. } => {
generate_drop_table(query_builder, &table.sql_name.to_lowercase())?;
}
SchemaDiff::CreateTable {
to_create,
foreign_keys,
} => {
let table = &to_create.sql_name.to_lowercase();
let primary_keys = to_create
.primary_keys
.as_ref()
.map(|keys| keys.keys.iter().map(|k| k.to_string()).collect())
.unwrap_or_else(|| vec![String::from("id")]);
let column_data = to_create
.column_defs
.iter()
.map(|c| {
let ty = ColumnType::for_column_def(c)
.map_err(diesel::result::Error::QueryBuilderError)?;
Ok(ColumnDefinition {
sql_name: c.sql_name.to_lowercase(),
rust_name: c.sql_name.clone(),
ty,
comment: None,
})
})
.collect::<QueryResult<Vec<_>>>()?;
let foreign_keys = foreign_keys
.iter()
.map(|(f, pk)| {
(
f.parent_table.to_string(),
f.ref_column.to_string(),
pk.clone(),
)
})
.collect::<Vec<_>>();
generate_create_table(
query_builder,
table,
&column_data,
&primary_keys,
&foreign_keys,
)?;
}
SchemaDiff::ChangeTable {
table,
added_columns,
removed_columns,
changed_columns,
} => {
for c in removed_columns
.iter()
.chain(changed_columns.iter().map(|(a, _)| a))
{
generate_drop_column(query_builder, &table.to_lowercase(), &c.sql_name)?;
query_builder.push_sql("\n");
}
for c in added_columns
.iter()
.chain(changed_columns.iter().map(|(_, b)| b))
{
generate_add_column(
query_builder,
&table.to_lowercase(),
&c.column_name.to_string().to_lowercase(),
&ColumnType::for_column_def(c)
.map_err(diesel::result::Error::QueryBuilderError)?,
)?;
query_builder.push_sql("\n");
}
}
}
Ok(())
}
fn generate_down_sql<DB>(&self, query_builder: &mut impl QueryBuilder<DB>) -> QueryResult<()>
where
DB: Backend,
{
match self {
SchemaDiff::DropTable {
table,
columns,
foreign_keys,
} => {
let fk = foreign_keys
.iter()
.map(|fk| {
(
fk.parent_table.rust_name.clone(),
fk.foreign_key_columns_rust[0].clone(),
fk.primary_key_columns[0].clone(),
)
})
.collect::<Vec<_>>();
generate_create_table(
query_builder,
&table.sql_name.to_lowercase(),
&columns.column_data,
&columns.primary_key,
&fk,
)?;
}
SchemaDiff::CreateTable { to_create, .. } => {
generate_drop_table(query_builder, &to_create.sql_name.to_lowercase())?;
}
SchemaDiff::ChangeTable {
table,
added_columns,
removed_columns,
changed_columns,
} => {
for c in added_columns
.iter()
.chain(changed_columns.iter().map(|(_, b)| b))
{
generate_drop_column(
query_builder,
&table.to_lowercase(),
&c.column_name.to_string().to_lowercase(),
)?;
query_builder.push_sql("\n");
}
for c in removed_columns
.iter()
.chain(changed_columns.iter().map(|(a, _)| a))
{
generate_add_column(
query_builder,
&table.to_lowercase(),
&c.sql_name.to_lowercase(),
&c.ty,
)?;
query_builder.push_sql("\n");
}
}
}
Ok(())
}
}
/// Appends `ALTER TABLE <table> ADD COLUMN <column> <type>;` to
/// `query_builder`.
///
/// Table and column names are pushed as identifiers; the type is rendered by
/// `generate_column_type_name`. Fails if the query builder rejects one of the
/// identifiers.
fn generate_add_column<DB: Backend>(
    query_builder: &mut impl QueryBuilder<DB>,
    table: &str,
    column: &str,
    ty: &ColumnType,
) -> QueryResult<()> {
    // ALTER TABLE <table>
    query_builder.push_sql("ALTER TABLE ");
    query_builder.push_identifier(table)?;

    // ADD COLUMN <column> <type>
    query_builder.push_sql(" ADD COLUMN ");
    query_builder.push_identifier(column)?;
    generate_column_type_name(query_builder, ty);

    query_builder.push_sql(";");
    Ok(())
}
/// Appends `ALTER TABLE <table> DROP COLUMN <column>;` to `query_builder`.
///
/// Table and column names are pushed as identifiers. Fails if the query
/// builder rejects one of them.
fn generate_drop_column<DB: Backend>(
    query_builder: &mut impl QueryBuilder<DB>,
    table: &str,
    column: &str,
) -> QueryResult<()> {
    // ALTER TABLE <table>
    query_builder.push_sql("ALTER TABLE ");
    query_builder.push_identifier(table)?;

    // DROP COLUMN <column>
    query_builder.push_sql(" DROP COLUMN ");
    query_builder.push_identifier(column)?;

    query_builder.push_sql(";");
    Ok(())
}
/// Appends a `CREATE TABLE` statement for `table` to `query_builder`.
///
/// * `column_data` - the columns, emitted in the given order.
/// * `primary_keys` - Rust names of the primary key columns. A single key is
///   emitted inline on its column, several keys as a trailing
///   `PRIMARY KEY(...)` clause.
/// * `foreign_keys` - `(parent table, child column, parent key column)`
///   triples; the child column is matched against each column's Rust name.
///
/// Fails if the query builder rejects one of the identifiers.
fn generate_create_table<DB: Backend>(
    query_builder: &mut impl QueryBuilder<DB>,
    table: &str,
    column_data: &[ColumnDefinition],
    primary_keys: &[String],
    foreign_keys: &[(String, String, String)],
) -> QueryResult<()> {
    let has_single_primary_key = primary_keys.len() == 1;

    query_builder.push_sql("CREATE TABLE ");
    query_builder.push_identifier(table)?;
    query_builder.push_sql("(\n");

    // Foreign keys are collected while walking the columns and emitted as
    // separate clauses afterwards.
    let mut pending_foreign_keys = Vec::with_capacity(foreign_keys.len());

    for (position, column) in column_data.iter().enumerate() {
        if position > 0 {
            query_builder.push_sql(",\n");
        }
        query_builder.push_sql("\t");
        query_builder.push_identifier(&column.sql_name)?;
        generate_column_type_name(query_builder, &column.ty);

        if has_single_primary_key && primary_keys.contains(&column.rust_name) {
            query_builder.push_sql(" PRIMARY KEY");
        }

        let reference = foreign_keys
            .iter()
            .find(|(_, child_column, _)| child_column == &column.rust_name);
        if let Some((parent_table, _, parent_key)) = reference {
            pending_foreign_keys.push((column, parent_table, parent_key));
        }
    }

    if primary_keys.len() > 1 {
        query_builder.push_sql(",\n\tPRIMARY KEY(");
        for (position, key) in primary_keys.iter().enumerate() {
            if position > 0 {
                query_builder.push_sql(", ");
            }
            query_builder.push_identifier(key)?;
        }
        query_builder.push_sql(")");
    }

    // MySQL parses but ignores inline `REFERENCES` specifications (as defined
    // in the SQL standard) that are part of a column specification; it only
    // honours them as part of a separate `FOREIGN KEY` specification.
    //
    // https://dev.mysql.com/doc/refman/8.0/en/ansi-diff-foreign-keys.html
    for (column, parent_table, parent_key) in pending_foreign_keys {
        query_builder.push_sql(",\n\tFOREIGN KEY (");
        query_builder.push_identifier(&column.sql_name)?;
        query_builder.push_sql(") REFERENCES ");
        query_builder.push_identifier(parent_table)?;
        query_builder.push_sql("(");
        query_builder.push_identifier(parent_key)?;
        query_builder.push_sql(")");
    }

    query_builder.push_sql("\n);\n");
    Ok(())
}
/// Appends the SQL type of a column, preceded by a space, to `query_builder`.
///
/// The pieces are emitted in the order
/// `<TYPE>[(max_length)][ UNSIGNED][[]][ NOT NULL]`.
///
/// `UNSIGNED` and the array marker `[]` belong to the data type, so they have
/// to come before the `NOT NULL` constraint: `INT NOT NULL UNSIGNED` and
/// `TEXT NOT NULL[]` are not valid column definitions, whereas
/// `INT UNSIGNED NOT NULL` and `TEXT[] NOT NULL` are.
fn generate_column_type_name<DB>(query_builder: &mut impl QueryBuilder<DB>, ty: &ColumnType)
where
    DB: Backend,
{
    // TODO: handle schema
    query_builder.push_sql(&format!(" {}", ty.sql_name.to_uppercase()));
    if let Some(max_length) = ty.max_length {
        query_builder.push_sql(&format!("({max_length})"));
    }
    // Type modifiers first ...
    if ty.is_unsigned {
        query_builder.push_sql(" UNSIGNED");
    }
    if ty.is_array {
        query_builder.push_sql("[]");
    }
    // ... column constraints last.
    if !ty.is_nullable {
        query_builder.push_sql(" NOT NULL");
    }
}
/// Appends `DROP TABLE IF EXISTS <table>;` to `query_builder`.
///
/// The table name is pushed as an identifier. Fails if the query builder
/// rejects it.
fn generate_drop_table<DB: Backend>(
    query_builder: &mut impl QueryBuilder<DB>,
    table: &str,
) -> QueryResult<()> {
    // TODO: handle schema?
    query_builder.push_sql("DROP TABLE IF EXISTS ");
    query_builder.push_identifier(table)?;
    query_builder.push_sql(";");

    Ok(())
}
/// The parsed body of a `joinable!` invocation, written as
/// `child_table -> parent_table (ref_column)`.
struct Joinable {
// Table named on the right-hand side of the arrow.
parent_table: syn::Ident,
// Table named on the left-hand side of the arrow.
child_table: syn::Ident,
// Column named inside the parentheses.
ref_column: syn::Ident,
}
impl syn::parse::Parse for Joinable {
fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
let child_table = input.parse()?;
let _arrow: syn::Token![->] = input.parse()?;
let parent_table = input.parse()?;
let content;
syn::parenthesized!(content in input);
let ref_column = content.parse()?;
Ok(Self {
child_table,
parent_table,
ref_column,
})
}
}
/// Syntax tree visitor that collects the parse results of every `table!` and
/// `joinable!` macro invocation it encounters.
#[derive(Default)]
struct SchemaCollector {
// One entry per `table!` invocation; `Err` if its body failed to parse.
table_decls: Vec<Result<TableDecl, syn::Error>>,
// One entry per `joinable!` invocation; `Err` if its body failed to parse.
joinable: Vec<Result<Joinable, syn::Error>>,
}
impl<'ast> syn::visit::Visit<'ast> for SchemaCollector {
    /// Records the parsed body of macros whose path ends in `table` or
    /// `joinable`, then continues with the default traversal.
    fn visit_macro(&mut self, i: &'ast syn::Macro) {
        // Only the last path segment is inspected, so both `table!` and a
        // qualified path such as `diesel::table!` are recognised.
        if let Some(segment) = i.path.segments.last() {
            if segment.ident == "table" {
                self.table_decls.push(i.parse_body());
            } else if segment.ident == "joinable" {
                self.joinable.push(i.parse_body());
            }
        }
        syn::visit::visit_macro(self, i)
    }
}